From 2a75a75044287506046eeb85329764602dc9e85e Mon Sep 17 00:00:00 2001 From: Pierre Ramet <pierre.ramet@inria.fr> Date: Mon, 19 Feb 2018 17:13:50 +0100 Subject: [PATCH] scal multi RHS --- spm.c | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- spm.h | 5 +++-- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/spm.c b/spm.c index 95562345..4ddac602 100644 --- a/spm.c +++ b/spm.c @@ -26,6 +26,7 @@ #include "p_spm.h" #include <cblas.h> +#include <lapacke.h> #if !defined(DOXYGEN_SHOULD_SKIP_THIS) @@ -1184,7 +1185,7 @@ spmCheckAxb( pastix_int_t nrhs, * *******************************************************************************/ void -spmScalMatrix(const double alpha, pastix_spm_t* spm) +spmScalMatrix(double alpha, pastix_spm_t* spm) { switch(spm->flttype) { @@ -1225,7 +1226,7 @@ spmScalMatrix(const double alpha, pastix_spm_t* spm) * *******************************************************************************/ void -spmScalVector(const double alpha, pastix_spm_t* spm, void *x) +spmScalVector(double alpha, pastix_spm_t* spm, void *x) { switch(spm->flttype) { @@ -1246,6 +1247,57 @@ spmScalVector(const double alpha, pastix_spm_t* spm, void *x) } } +/** + ******************************************************************************* + * + * @brief Scale a dense matrix corresponding to a set of RHS (wrapper to LAPACKE_xlascl) + * + * A = alpha * A + * + ******************************************************************************* + * + * @param[in] flt + * Datatype. + * + * @param[in] m + * Number of rows of the matrix A. + * + * @param[in] n + * Number of columns of the matrix A. + * + * @param[in] alpha + * The scaling parameter. + * + * @param[inout] A + * The matrix of RHS to scal. + * + * @param[in] lda + * Defines the leading dimension of A when multiple right hand sides + * are available. lda >= m. + * + *******************************************************************************/ +void +spmScalRHS(pastix_coeftype_t flt, double alpha, pastix_int_t m, pastix_int_t n, void *A, pastix_int_t lda) +{ + switch(flt) + { + case PastixPattern: + break; + case PastixFloat: + LAPACKE_slascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda); + break; + case PastixComplex32: + LAPACKE_clascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda); + break; + case PastixComplex64: + LAPACKE_zlascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda); + break; + case PastixDouble: + default: + LAPACKE_dlascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda); + } +} + /** * @} */ diff --git a/spm.h b/spm.h index 0f782dcb..91df2c91 100644 --- a/spm.h +++ b/spm.h @@ -92,8 +92,9 @@ void spmGenFakeValues( pastix_spm_t *spm ); */ double spmNorm( pastix_normtype_t ntype, const pastix_spm_t *spm ); int spmMatVec( pastix_trans_t trans, const void *alpha, const pastix_spm_t *spm, const void *x, const void *beta, void *y ); -void spmScalMatrix( const double alpha, pastix_spm_t *spm ); -void spmScalVector( const double alpha, pastix_spm_t *spm, void *x ); +void spmScalMatrix( double alpha, pastix_spm_t *spm ); +void spmScalVector( double alpha, pastix_spm_t *spm, void *x ); +void spmScalRHS( pastix_coeftype_t flt, double alpha, pastix_int_t m, pastix_int_t n, void *A, pastix_int_t lda ); /** * @} -- GitLab