diff --git a/spm.c b/spm.c
index 922ad3be2903f0437d4e643fa3d88d6737230110..9c95353379f7ebe13308bbc2df00bd47cbd6738f 100644
--- a/spm.c
+++ b/spm.c
@@ -768,51 +768,27 @@ spmExpand(const pastix_spm_t* spm)
  * functions, and simplify this one to have identical calls to all subfunction
  */
 int
-spmMatVec( int trans,
-           const void *alpha,
-           const pastix_spm_t *spm,
-           const void *x,
-           const void *beta,
-           void *y )
+spmMatVec(const pastix_trans_t trans,
+          const void *alpha,
+          const pastix_spm_t *spm,
+          const void *x,
+          const void *beta,
+          void *y )
 {
-    switch (spm->mtxtype) {
-    case PastixHermitian:
-        switch (spm->flttype) {
-        case PastixFloat:
-            return s_spmSyCSCv( *((const float*)alpha), spm, (const float*)x, *((const float*)beta), (float*)y );
-        case PastixComplex32:
-            return c_spmHeCSCv( *((const pastix_complex32_t*)alpha), spm, (const pastix_complex32_t*)x, *((const pastix_complex32_t*)beta), (pastix_complex32_t*)y );
-        case PastixComplex64:
-            return z_spmHeCSCv( *((const pastix_complex64_t*)alpha), spm, (const pastix_complex64_t*)x, *((const pastix_complex64_t*)beta), (pastix_complex64_t*)y );
-        case PastixDouble:
-        default:
-            return d_spmSyCSCv( *((const double*)alpha), spm, (const double*)x, *((const double*)beta), (double*)y );
-        }
-    case PastixSymmetric:
-        switch (spm->flttype) {
-        case PastixFloat:
-            return s_spmSyCSCv( *((const float*)alpha), spm, (const float*)x, *((const float*)beta), (float*)y );
-        case PastixComplex32:
-            return c_spmSyCSCv( *((const pastix_complex32_t*)alpha), spm, (const pastix_complex32_t*)x, *((const pastix_complex32_t*)beta), (pastix_complex32_t*)y );
-        case PastixComplex64:
-            return z_spmSyCSCv( *((const pastix_complex64_t*)alpha), spm, (const pastix_complex64_t*)x, *((const pastix_complex64_t*)beta), (pastix_complex64_t*)y );
-        case PastixDouble:
-        default:
-            return d_spmSyCSCv( *((const double*)alpha), spm, (const double*)x, *((const double*)beta), (double*)y );
-        }
-    case PastixGeneral:
+    if ( (spm == NULL) || (spm->fmttype != PastixCSC) ) {
+        return PASTIX_ERR_BADPARAMETER;
+    }
+
+    switch (spm->flttype) {
+    case PastixFloat:
+        return s_spmCSCMatVec( trans, alpha, spm, x, beta, y );
+    case PastixComplex32:
+        return c_spmCSCMatVec( trans, alpha, spm, x, beta, y );
+    case PastixComplex64:
+        return z_spmCSCMatVec( trans, alpha, spm, x, beta, y );
+    case PastixDouble:
     default:
-        switch (spm->flttype) {
-        case PastixFloat:
-            return s_spmGeCSCv( trans, *((const float*)alpha), spm, (const float*)x, *((const float*)beta), (float*)y );
-        case PastixComplex32:
-            return c_spmGeCSCv( trans, *((const pastix_complex32_t*)alpha), spm, (const pastix_complex32_t*)x, *((const pastix_complex32_t*)beta), (pastix_complex32_t*)y );
-        case PastixComplex64:
-            return z_spmGeCSCv( trans, *((const pastix_complex64_t*)alpha), spm, (const pastix_complex64_t*)x, *((const pastix_complex64_t*)beta), (pastix_complex64_t*)y );
-        case PastixDouble:
-        default:
-            return d_spmGeCSCv( trans, *((const double*)alpha), spm, (const double*)x, *((const double*)beta), (double*)y );
-        }
+        return d_spmCSCMatVec( trans, alpha, spm, x, beta, y );
     }
 }
 
diff --git a/spm.h b/spm.h
index 96f9219426b595d3a6485d9575a3c50700d67de0..d46a446dcb800e5c6c169b6c6b05a0b63623501a 100644
--- a/spm.h
+++ b/spm.h
@@ -127,7 +127,7 @@ void spmBase( pastix_spm_t *spm, int baseval );
 int spmConvert( int ofmttype, pastix_spm_t *ospm );
 pastix_int_t spmFindBase( const pastix_spm_t *spm );
 double spmNorm( int ntype, const pastix_spm_t *spm );
-int spmMatVec(int trans, const void *alpha, const pastix_spm_t *spm, const void *x, const void *beta, void *y );
+int spmMatVec(const pastix_trans_t trans, const void *alpha, const pastix_spm_t *spm, const void *x, const void *beta, void *y );
 int spmSort( pastix_spm_t *spm );
 pastix_int_t spmMergeDuplicate( pastix_spm_t *spm );
 
diff --git a/z_spm.h b/z_spm.h
index 0ac21a6956d5fb189fd9c5a83f12e11a6b281f9d..c846eeca1380e82be614cdc8ef9a5eeecf8f324d 100644
--- a/z_spm.h
+++ b/z_spm.h
@@ -41,9 +41,11 @@ pastix_complex64_t *z_spm2dense( const pastix_spm_t *spm );
 /**
  * Matrix-Vector product routines
  */
-int z_spmGeCSCv(int trans, pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
-int z_spmSyCSCv(           pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
-int z_spmHeCSCv(           pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
+int z_spmGeCSCv(const pastix_trans_t trans, pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
+int z_spmSyCSCv(                            pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
+int z_spmHeCSCv(                            pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
+
+int z_spmCSCMatVec(const pastix_trans_t trans, const void *alpha, const pastix_spm_t *csc, const void *x, const void *beta, void *y);
 
 /**
  * Extra routines
diff --git a/z_spm_genrhs.c b/z_spm_genrhs.c
index 25d0e17d0423e71e5872bfe2c58f4c6285f4fc32..8a2defa4358a5d5a5363cba587468842791d847f 100644
--- a/z_spm_genrhs.c
+++ b/z_spm_genrhs.c
@@ -26,6 +26,10 @@
 #define RndF_Mul 5.4210108624275222e-20f
 #define RndD_Mul 5.4210108624275222e-20
 
+static pastix_complex64_t mzone = (pastix_complex64_t)-1.;
+static pastix_complex64_t zone  = (pastix_complex64_t) 1.;
+static pastix_complex64_t zzero = (pastix_complex64_t) 0.;
+
 static inline unsigned long long int
 Rnd64_jump(unsigned long long int n, unsigned long long int seed ) {
   unsigned long long int a_k, c_k, ran;
@@ -260,19 +264,7 @@ z_spmGenRHS( int type, int nrhs,
                       spm->gN, 0, 0, 24356 );
     }
 
-    switch ( spm->mtxtype ) {
-#if defined(PRECISION_z) || defined(PRECISION_c)
-    case PastixHermitian:
-        rc = z_spmHeCSCv( 1., spm, xptr, 0., bptr );
-        break;
-#endif
-    case PastixSymmetric:
-        rc = z_spmSyCSCv( 1., spm, xptr, 0., bptr );
-        break;
-    case PastixGeneral:
-    default:
-        rc = z_spmGeCSCv( PastixNoTrans, 1., spm, xptr, 0., bptr );
-    }
+    rc = z_spmCSCMatVec( PastixNoTrans, &zone, spm, xptr, &zzero, bptr );
 
     if ( x == NULL ) {
         memFree_null(xptr);
@@ -339,8 +331,6 @@ z_spmCheckAxb( int nrhs,
                      void *b, int ldb,
                const void *x, int ldx )
 {
-    static pastix_complex64_t mzone = (pastix_complex64_t)-1.;
-    static pastix_complex64_t zone = (pastix_complex64_t) 1.;
     double normA, normB, normX, normX0, normR;
     double backward, forward, eps;
     int failure = 0;
diff --git a/z_spm_matrixvector.c b/z_spm_matrixvector.c
index 854c711a53547240c4152cf51abf321bf81cc1e1..48496628683c68000d4b137e895c83275b39dde4 100644
--- a/z_spm_matrixvector.c
+++ b/z_spm_matrixvector.c
@@ -63,16 +63,16 @@
  *
  *******************************************************************************/
 int
-z_spmGeCSCv( int trans,
+z_spmGeCSCv(const pastix_trans_t trans,
              pastix_complex64_t alpha,
              const pastix_spm_t *csc,
              const pastix_complex64_t *x,
              pastix_complex64_t beta,
              pastix_complex64_t *y )
 {
-    const pastix_complex64_t *valptr = (pastix_complex64_t*)csc->values;
-    const pastix_complex64_t *xptr = x;
-    pastix_complex64_t *yptr = y;
+    const pastix_complex64_t *valptr = (pastix_complex64_t*)(csc->values);
+    const pastix_complex64_t *xptr = (const pastix_complex64_t*)x;
+    pastix_complex64_t *yptr = (pastix_complex64_t*)y;
     pastix_int_t col, row, i, baseval;
 
     if ( (csc == NULL) || (x == NULL) || (y == NULL ) )
@@ -328,3 +328,79 @@ z_spmHeCSCv( pastix_complex64_t alpha,
     return PASTIX_SUCCESS;
 }
 #endif
+
+/**
+ *******************************************************************************
+ *
+ * @brief Untyped entry point to the CSC matrix-vector product kernels.
+ *
+ * Casts the generic arguments to pastix_complex64_t and dispatches on
+ * csc->mtxtype to z_spmHeCSCv(), z_spmSyCSCv() or z_spmGeCSCv(). When
+ * PRECISION_z and PRECISION_c are both undefined, Hermitian matrices are
+ * routed to the symmetric kernel instead of the general one.
+ *
+ *******************************************************************************
+ *
+ * @param[in] trans
+ *          Forwarded to z_spmGeCSCv() only; the symmetric and Hermitian
+ *          kernels are called without it.
+ *
+ * @param[in] alphaptr
+ *          Pointer to the pastix_complex64_t scalar alpha, passed by value to
+ *          the selected kernel.
+ *
+ * @param[in] csc
+ *          The sparse matrix; its mtxtype field selects the kernel.
+ *
+ * @param[in] xptr
+ *          The input vector, read as an array of pastix_complex64_t.
+ *
+ * @param[in] betaptr
+ *          Pointer to the pastix_complex64_t scalar beta, passed by value to
+ *          the selected kernel.
+ *
+ * @param[in,out] yptr
+ *          The vector handed to the kernel as y, as an array of
+ *          pastix_complex64_t.
+ *
+ *******************************************************************************
+ *
+ * @return PASTIX_ERR_BADPARAMETER if any pointer argument is NULL, otherwise
+ *         the value returned by the selected kernel.
+ *
+ *******************************************************************************/
+int
+z_spmCSCMatVec(const pastix_trans_t trans,
+               const void *alphaptr,
+               const pastix_spm_t *csc,
+               const void *xptr,
+               const void *betaptr,
+               void *yptr )
+{
+    const pastix_complex64_t *x = (const pastix_complex64_t*)xptr;
+    pastix_complex64_t *y = (pastix_complex64_t*)yptr;
+    pastix_complex64_t alpha, beta;
+
+    /* Validate every pointer before the first dereference below */
+    if ( (csc == NULL) || (alphaptr == NULL) || (betaptr == NULL) ||
+         (xptr == NULL) || (yptr == NULL) )
+    {
+        return PASTIX_ERR_BADPARAMETER;
+    }
+
+    alpha = *((const pastix_complex64_t *)alphaptr);
+    beta = *((const pastix_complex64_t *)betaptr);
+
+    switch (csc->mtxtype) {
+    case PastixHermitian:
+#if defined(PRECISION_z) || defined(PRECISION_c)
+        return z_spmHeCSCv( alpha, csc, x, beta, y );
+#endif
+        /* No Hermitian kernel in this precision: fall through to symmetric */
+    case PastixSymmetric:
+        return z_spmSyCSCv( alpha, csc, x, beta, y );
+    case PastixGeneral:
+    default:
+        return z_spmGeCSCv( trans, alpha, csc, x, beta, y );
+    }
+}