diff --git a/spm.c b/spm.c index 673c229ae8d4829508d6bd2654d2acab0972cdea..ff7cdbfaab771463a7ad86c9129dd9504c889f55 100644 --- a/spm.c +++ b/spm.c @@ -1311,95 +1311,50 @@ spmScalMatrix(double alpha, pastix_spm_t* spm) * ******************************************************************************* * - * @param[in] alpha - * The scaling parameter. - * - * @param[in] spm - * The spm structure to know the type of the vector. - * - * @param[inout] x - * The vector to scal. - * - *******************************************************************************/ -void -spmScalVector(double alpha, pastix_spm_t* spm, void *x) -{ - switch(spm->flttype) - { - case PastixPattern: - break; - case PastixFloat: - cblas_sscal(spm->n, (float)alpha, x, 1); - break; - case PastixComplex32: - cblas_csscal(spm->n, (float)alpha, x, 1); - break; - case PastixComplex64: - cblas_zdscal(spm->n, alpha, x, 1); - break; - case PastixDouble: - default: - cblas_dscal(spm->n, alpha, x, 1); - } -} - -/** - ******************************************************************************* - * - * @brief Scale a dense matrix corresponding to a set of RHS (wrapper to - * LAPACKE_xlascl) - * - * A = alpha * A - * - ******************************************************************************* - * * @param[in] flt - * Datatype of the matrix that must be: + * Datatype of the elements in the vector that must be: * @arg PastixFloat * @arg PastixDouble * @arg PastixComplex32 * @arg PastixComplex64 * - * @param[in] m - * Number of rows of the matrix A. - * * @param[in] n - * Number of columns of the matrix A. + * Number of elements in the input vectors * * @param[in] alpha - * The scaling parameter. + * The scaling parameter. * - * @param[inout] A - * The dense matrix to scale of size lda-by-n + * @param[inout] x + * The vector to scal of size ( 1 + (n-1) * abs(incx) ), and of type + * defined by flt. * - * @param[in] lda - * Defines the leading dimension of A. lda >= m. + * @param[in] inc + * Storage spacing between elements of x. * *******************************************************************************/ void -spmScalRHS( pastix_coeftype_t flt, - double alpha, - pastix_int_t m, - pastix_int_t n, - void *A, - pastix_int_t lda ) +spmScalVector( pastix_coeftype_t flt, + double alpha, + pastix_int_t n, + void *x, + pastix_int_t inc ) { - switch(flt) + switch(spm->flttype) { case PastixPattern: break; case PastixFloat: - LAPACKE_slascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda); + cblas_sscal( n, (float)alpha, x, incx ); break; case PastixComplex32: - LAPACKE_clascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda); + cblas_csscal( n, (float)alpha, x, incx ); break; case PastixComplex64: - LAPACKE_zlascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda); + cblas_zdscal( n, alpha, x, incx ); break; case PastixDouble: default: - LAPACKE_dlascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda); + cblas_dscal( n, alpha, x, incx ); } } diff --git a/spm.h b/spm.h index 9686a79fba312694e650a6de164dd32740a62d7e..a677e6b8322bf683369f08e972b2abb224da5ecf 100644 --- a/spm.h +++ b/spm.h @@ -97,8 +97,7 @@ int spmMatMat( pastix_trans_t trans, pastix_int_t n, const void *B, pastix_int_t ldb, const void *beta, void *C, pastix_int_t ldc ); void spmScalMatrix( double alpha, pastix_spm_t *spm ); -void spmScalVector( double alpha, pastix_spm_t *spm, void *x ); -void spmScalRHS( pastix_coeftype_t flt, double alpha, pastix_int_t m, pastix_int_t n, void *A, pastix_int_t lda ); +void spmScalVector( pastix_coeftype_t flt, double alpha, pastix_int_t n, void *x, pastix_int_t incx ); /** * @}