Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 8a3e00ef authored by Mathieu Faverge's avatar Mathieu Faverge
Browse files

Simplify code for matvec

parent 41e1c9c4
No related branches found
No related tags found
No related merge requests found
......@@ -768,51 +768,27 @@ spmExpand(const pastix_spm_t* spm)
* functions, and simplify this one to have identical calls to all subfunction
*/
int
spmMatVec( int trans,
const void *alpha,
const pastix_spm_t *spm,
const void *x,
const void *beta,
void *y )
spmMatVec(const pastix_trans_t trans,
const void *alpha,
const pastix_spm_t *spm,
const void *x,
const void *beta,
void *y )
{
switch (spm->mtxtype) {
case PastixHermitian:
switch (spm->flttype) {
case PastixFloat:
return s_spmSyCSCv( *((const float*)alpha), spm, (const float*)x, *((const float*)beta), (float*)y );
case PastixComplex32:
return c_spmHeCSCv( *((const pastix_complex32_t*)alpha), spm, (const pastix_complex32_t*)x, *((const pastix_complex32_t*)beta), (pastix_complex32_t*)y );
case PastixComplex64:
return z_spmHeCSCv( *((const pastix_complex64_t*)alpha), spm, (const pastix_complex64_t*)x, *((const pastix_complex64_t*)beta), (pastix_complex64_t*)y );
case PastixDouble:
default:
return d_spmSyCSCv( *((const double*)alpha), spm, (const double*)x, *((const double*)beta), (double*)y );
}
case PastixSymmetric:
switch (spm->flttype) {
case PastixFloat:
return s_spmSyCSCv( *((const float*)alpha), spm, (const float*)x, *((const float*)beta), (float*)y );
case PastixComplex32:
return c_spmSyCSCv( *((const pastix_complex32_t*)alpha), spm, (const pastix_complex32_t*)x, *((const pastix_complex32_t*)beta), (pastix_complex32_t*)y );
case PastixComplex64:
return z_spmSyCSCv( *((const pastix_complex64_t*)alpha), spm, (const pastix_complex64_t*)x, *((const pastix_complex64_t*)beta), (pastix_complex64_t*)y );
case PastixDouble:
default:
return d_spmSyCSCv( *((const double*)alpha), spm, (const double*)x, *((const double*)beta), (double*)y );
}
case PastixGeneral:
if ( spm->fmttype != PastixCSC ) {
return PASTIX_ERR_BADPARAMETER;
}
switch (spm->flttype) {
case PastixFloat:
return s_spmCSCMatVec( trans, alpha, spm, x, beta, y );
case PastixComplex32:
return c_spmCSCMatVec( trans, alpha, spm, x, beta, y );
case PastixComplex64:
return z_spmCSCMatVec( trans, alpha, spm, x, beta, y );
case PastixDouble:
default:
switch (spm->flttype) {
case PastixFloat:
return s_spmGeCSCv( trans, *((const float*)alpha), spm, (const float*)x, *((const float*)beta), (float*)y );
case PastixComplex32:
return c_spmGeCSCv( trans, *((const pastix_complex32_t*)alpha), spm, (const pastix_complex32_t*)x, *((const pastix_complex32_t*)beta), (pastix_complex32_t*)y );
case PastixComplex64:
return z_spmGeCSCv( trans, *((const pastix_complex64_t*)alpha), spm, (const pastix_complex64_t*)x, *((const pastix_complex64_t*)beta), (pastix_complex64_t*)y );
case PastixDouble:
default:
return d_spmGeCSCv( trans, *((const double*)alpha), spm, (const double*)x, *((const double*)beta), (double*)y );
}
return d_spmCSCMatVec( trans, alpha, spm, x, beta, y );
}
}
......
......@@ -127,7 +127,7 @@ void spmBase( pastix_spm_t *spm, int baseval );
int spmConvert( int ofmttype, pastix_spm_t *ospm );
pastix_int_t spmFindBase( const pastix_spm_t *spm );
double spmNorm( int ntype, const pastix_spm_t *spm );
int spmMatVec(int trans, const void *alpha, const pastix_spm_t *spm, const void *x, const void *beta, void *y );
int spmMatVec(const pastix_trans_t trans, const void *alpha, const pastix_spm_t *spm, const void *x, const void *beta, void *y );
int spmSort( pastix_spm_t *spm );
pastix_int_t spmMergeDuplicate( pastix_spm_t *spm );
......
......@@ -41,9 +41,11 @@ pastix_complex64_t *z_spm2dense( const pastix_spm_t *spm );
/**
* Matrix-Vector product routines
*/
int z_spmGeCSCv(int trans, pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
int z_spmSyCSCv( pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
int z_spmHeCSCv( pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
int z_spmGeCSCv(const pastix_trans_t trans, pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
int z_spmSyCSCv( pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
int z_spmHeCSCv( pastix_complex64_t alpha, const pastix_spm_t *csc, const pastix_complex64_t *x, pastix_complex64_t beta, pastix_complex64_t *y);
int z_spmCSCMatVec(const pastix_trans_t trans, const void *alpha, const pastix_spm_t *csc, const void *x, const void *beta, void *y);
/**
* Extra routines
......
......@@ -26,6 +26,10 @@
#define RndF_Mul 5.4210108624275222e-20f
#define RndD_Mul 5.4210108624275222e-20
static pastix_complex64_t mzone = (pastix_complex64_t)-1.;
static pastix_complex64_t zone = (pastix_complex64_t) 1.;
static pastix_complex64_t zzero = (pastix_complex64_t) 0.;
static inline unsigned long long int
Rnd64_jump(unsigned long long int n, unsigned long long int seed ) {
unsigned long long int a_k, c_k, ran;
......@@ -260,19 +264,7 @@ z_spmGenRHS( int type, int nrhs,
spm->gN, 0, 0, 24356 );
}
switch ( spm->mtxtype ) {
#if defined(PRECISION_z) || defined(PRECISION_c)
case PastixHermitian:
rc = z_spmHeCSCv( 1., spm, xptr, 0., bptr );
break;
#endif
case PastixSymmetric:
rc = z_spmSyCSCv( 1., spm, xptr, 0., bptr );
break;
case PastixGeneral:
default:
rc = z_spmGeCSCv( PastixNoTrans, 1., spm, xptr, 0., bptr );
}
rc = z_spmCSCMatVec( PastixNoTrans, &zone, spm, xptr, &zzero, bptr );
if ( x == NULL ) {
memFree_null(xptr);
......@@ -339,8 +331,6 @@ z_spmCheckAxb( int nrhs,
void *b, int ldb,
const void *x, int ldx )
{
static pastix_complex64_t mzone = (pastix_complex64_t)-1.;
static pastix_complex64_t zone = (pastix_complex64_t) 1.;
double normA, normB, normX, normX0, normR;
double backward, forward, eps;
int failure = 0;
......
......@@ -63,16 +63,16 @@
*
*******************************************************************************/
int
z_spmGeCSCv( int trans,
z_spmGeCSCv(const pastix_trans_t trans,
pastix_complex64_t alpha,
const pastix_spm_t *csc,
const pastix_complex64_t *x,
pastix_complex64_t beta,
pastix_complex64_t *y )
{
const pastix_complex64_t *valptr = (pastix_complex64_t*)csc->values;
const pastix_complex64_t *xptr = x;
pastix_complex64_t *yptr = y;
const pastix_complex64_t *valptr = (pastix_complex64_t*)(csc->values);
const pastix_complex64_t *xptr = (const pastix_complex64_t*)x;
pastix_complex64_t *yptr = (pastix_complex64_t*)y;
pastix_int_t col, row, i, baseval;
if ( (csc == NULL) || (x == NULL) || (y == NULL ) )
......@@ -328,3 +328,32 @@ z_spmHeCSCv( pastix_complex64_t alpha,
return PASTIX_SUCCESS;
}
#endif
int
z_spmCSCMatVec(const pastix_trans_t trans,
const void *alphaptr,
const pastix_spm_t *csc,
const void *xptr,
const void *betaptr,
void *yptr )
{
const pastix_complex64_t *x = (const pastix_complex64_t*)xptr;
pastix_complex64_t *y = (pastix_complex64_t*)yptr;
pastix_complex64_t alpha, beta;
alpha = *((const pastix_complex64_t *)alphaptr);
beta = *((const pastix_complex64_t *)betaptr);
switch (csc->mtxtype) {
#if defined(PRECISION_z) || defined(PRECISION_c)
case PastixHermitian:
return z_spmHeCSCv( alpha, csc, x, beta, y );
#endif
case PastixSymmetric:
return z_spmSyCSCv( alpha, csc, x, beta, y );
case PastixGeneral:
default:
return z_spmGeCSCv( trans, alpha, csc, x, beta, y );
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment