diff --git a/spm.c b/spm.c index bc015845f214018fcab286f1b0bf8c7151933ea2..44adced19722e18f20b4b2c2d691165ef9fe3a6d 100644 --- a/spm.c +++ b/spm.c @@ -25,6 +25,8 @@ #include "s_spm.h" #include "p_spm.h" +#include <cblas.h> + #if !defined(DOXYGEN_SHOULD_SKIP_THIS) static int (*conversionTable[3][3][6])(pastix_spm_t*) = { @@ -1179,7 +1181,7 @@ spmCheckAxb( int nrhs, * *******************************************************************************/ void -spmScal(const pastix_complex64_t alpha, pastix_spm_t* spm) +spmScalMatrix(const pastix_complex64_t alpha, pastix_spm_t* spm) { switch(spm->flttype) { @@ -1200,6 +1202,47 @@ spmScal(const pastix_complex64_t alpha, pastix_spm_t* spm) } } +/** + ******************************************************************************* + * + * @brief Scale a vector according to the spm type. + * + * x = alpha * x + * + ******************************************************************************* + * + * @param[in] alpha + * The scaling parameter. + * + * @param[in] spm + * The spm structure to know the type of the vector. + * + * @param[inout] x + * The vector to scal. + * + *******************************************************************************/ +void +spmScalVector(const double alpha, pastix_spm_t* spm, void *x) +{ + switch(spm->flttype) + { + case PastixPattern: + break; + case PastixFloat: + cblas_sscal(spm->n, alpha, x, 1); + break; + case PastixComplex32: + cblas_csscal(spm->n, alpha, x, 1); + break; + case PastixComplex64: + cblas_zdscal(spm->n, alpha, x, 1); + break; + case PastixDouble: + default: + cblas_dscal(spm->n, alpha, x, 1); + } +} + /** * @} */ diff --git a/spm.h b/spm.h index a0a9ba0de8f2d0cdc09c9ddaf7e5dd24ba2c534a..f09c563933c3796662aed7a6fd67628c8c34af55 100644 --- a/spm.h +++ b/spm.h @@ -90,7 +90,8 @@ void spmUpdateComputedFields( pastix_spm_t *spm ); */ double spmNorm( pastix_normtype_t ntype, const pastix_spm_t *spm ); int spmMatVec(const pastix_trans_t trans, const void *alpha, const pastix_spm_t *spm, const void *x, const void *beta, void *y ); -void spmScal( const pastix_complex64_t alpha, pastix_spm_t* spm ); +void spmScalMatrix( const pastix_complex64_t alpha, pastix_spm_t* spm ); +void spmScalVector( const double alpha, pastix_spm_t* spm, void *x ); /** * @}