From 96d9c0fbfe9bbd36a652c8a0fcbcbdc3cad6f073 Mon Sep 17 00:00:00 2001
From: KUHN Matthieu <matthieu.kuhn@inria.fr>
Date: Thu, 5 Apr 2018 14:49:35 +0200
Subject: [PATCH] Matrix vector product for all Spm storage formats using function pointers

---
 src/spm.c                |   2 +-
 src/z_spm_matrixvector.c | 373 ++++++++++++++++++++++++++++++++++++---
 tests/spm_matvec_tests.c |  89 +++++-----
 3 files changed, 391 insertions(+), 73 deletions(-)

diff --git a/src/spm.c b/src/spm.c
index b2737556..e49114d7 100644
--- a/src/spm.c
+++ b/src/spm.c
@@ -1015,7 +1015,7 @@ spmMatVec( spm_trans_t trans,
     spmatrix_t *espm = (spmatrix_t*)spm;
     int rc = SPM_SUCCESS;
 
-    if ( spm->fmttype != SpmCSC ) {
+    if ( spm->fmttype != SpmCSC && spm->fmttype != SpmCSR && spm->fmttype != SpmIJV ) {
         return SPM_ERR_BADPARAMETER;
     }
 
diff --git a/src/z_spm_matrixvector.c b/src/z_spm_matrixvector.c
index f0fc6c63..dd2a9776 100644
--- a/src/z_spm_matrixvector.c
+++ b/src/z_spm_matrixvector.c
@@ -17,6 +17,347 @@
 #include "common.h"
 #include "z_spm.h"
 
+/**
+ * @brief Elementary update of y induced by one stored entry of the matrix.
+ *
+ * pos is the position of the entry in the index arrays, shifted by baseval;
+ * row and col are the 0-based coordinates of that entry.
+ */
+typedef void (*z_vectorUpdater_t)(const spm_complex64_t  alpha,
+                                  const spm_int_t        baseval,
+                                  const spm_int_t        pos,
+                                  const spm_int_t        row,
+                                  const spm_int_t        col,
+                                  const spm_complex64_t *x,
+                                  const spm_complex64_t *val,
+                                  spm_complex64_t       *y);
+
+/**
+ * @brief Traversal of the stored entries of a matrix in one storage format,
+ * calling updateVect on each of them.
+ */
+typedef int (*z_matLooper_t)(const spm_int_t        baseval,
+                             const spm_complex64_t  alpha,
+                             const spmatrix_t      *spm,
+                             const spm_complex64_t *x,
+                             spm_complex64_t       *y,
+                             z_vectorUpdater_t      updateVect);
+
+/**
+ * @brief Identity, used when the value must not be conjugated.
+ */
+static spm_complex64_t
+z_idFunc(spm_complex64_t val)
+{
+    return val;
+}
+
+/**
+ * @brief y[idy] += alpha * conj_func( val[pos-baseval] ) * x[idx]
+ */
+static void
+z_updateVectCore(const spm_complex64_t  alpha,
+                 const spm_int_t        baseval,
+                 const spm_int_t        pos,
+                 const spm_int_t        idy,
+                 const spm_int_t        idx,
+                 const spm_complex64_t *x,
+                 const spm_complex64_t *val,
+                 spm_complex64_t       *y,
+                 spm_complex64_t      (*conj_func)(spm_complex64_t))
+{
+    y[idy] += alpha * conj_func(val[pos-baseval]) * x[idx];
+}
+
+/**
+ * @brief General matrix, not transposed: y[row] += alpha * a * x[col]
+ */
+static void
+z_updateVectNoTrans(const spm_complex64_t  alpha,
+                    const spm_int_t        baseval,
+                    const spm_int_t        pos,
+                    const spm_int_t        row,
+                    const spm_int_t        col,
+                    const spm_complex64_t *x,
+                    const spm_complex64_t *val,
+                    spm_complex64_t       *y)
+{
+    z_updateVectCore( alpha, baseval, pos, row, col, x, val, y, z_idFunc );
+}
+
+/**
+ * @brief General matrix, transposed: y[col] += alpha * a * x[row]
+ */
+static void
+z_updateVectTrans(const spm_complex64_t  alpha,
+                  const spm_int_t        baseval,
+                  const spm_int_t        pos,
+                  const spm_int_t        row,
+                  const spm_int_t        col,
+                  const spm_complex64_t *x,
+                  const spm_complex64_t *val,
+                  spm_complex64_t       *y)
+{
+    z_updateVectCore( alpha, baseval, pos, col, row, x, val, y, z_idFunc );
+}
+
+#if defined(PRECISION_c) || defined(PRECISION_z)
+/**
+ * @brief General matrix, conjugate transposed:
+ * y[col] += alpha * conj(a) * x[row]
+ */
+static void
+z_updateVectConjTrans(const spm_complex64_t  alpha,
+                      const spm_int_t        baseval,
+                      const spm_int_t        pos,
+                      const spm_int_t        row,
+                      const spm_int_t        col,
+                      const spm_complex64_t *x,
+                      const spm_complex64_t *val,
+                      spm_complex64_t       *y)
+{
+    z_updateVectCore( alpha, baseval, pos, col, row, x, val, y, conj );
+}
+#endif
+
+/**
+ * @brief Symmetric matrix: the stored entry is applied as is, and mirrored
+ * when it does not lie on the diagonal.
+ */
+static void
+z_updateVectSy(const spm_complex64_t  alpha,
+               const spm_int_t        baseval,
+               const spm_int_t        pos,
+               const spm_int_t        row,
+               const spm_int_t        col,
+               const spm_complex64_t *x,
+               const spm_complex64_t *val,
+               spm_complex64_t       *y)
+{
+    z_updateVectCore( alpha, baseval, pos, row, col, x, val, y, z_idFunc );
+    if ( col != row ) {
+        z_updateVectCore( alpha, baseval, pos, col, row, x, val, y, z_idFunc );
+    }
+}
+
+#if defined(PRECISION_c) || defined(PRECISION_z)
+/**
+ * @brief Hermitian matrix: the stored entry is applied as is, and its
+ * conjugate is mirrored when it does not lie on the diagonal.
+ */
+static void
+z_updateVectHe(const spm_complex64_t  alpha,
+               const spm_int_t        baseval,
+               const spm_int_t        pos,
+               const spm_int_t        row,
+               const spm_int_t        col,
+               const spm_complex64_t *x,
+               const spm_complex64_t *val,
+               spm_complex64_t       *y)
+{
+    z_updateVectCore( alpha, baseval, pos, row, col, x, val, y, z_idFunc );
+    if ( col != row ) {
+        z_updateVectCore( alpha, baseval, pos, col, row, x, val, y, conj );
+    }
+}
+#endif
+
+/**
+ * @brief Apply updateVect to every stored entry of a CSC matrix.
+ */
+static int
+z_loopMatCSC(const spm_int_t        baseval,
+             const spm_complex64_t  alpha,
+             const spmatrix_t      *spm,
+             const spm_complex64_t *x,
+             spm_complex64_t       *y,
+             z_vectorUpdater_t      updateVect)
+{
+    const spm_complex64_t *valptr = (const spm_complex64_t*)(spm->values);
+    spm_int_t col, row, i;
+
+    for ( col = 0; col < spm->gN; col++ ) {
+        for ( i = spm->colptr[col]; i < spm->colptr[col+1]; i++ ) {
+            row = spm->rowptr[i-baseval] - baseval;
+            updateVect( alpha, baseval, i, row, col, x, valptr, y );
+        }
+    }
+    return SPM_SUCCESS;
+}
+
+/**
+ * @brief Apply updateVect to every stored entry of a CSR matrix.
+ */
+static int
+z_loopMatCSR(const spm_int_t        baseval,
+             const spm_complex64_t  alpha,
+             const spmatrix_t      *spm,
+             const spm_complex64_t *x,
+             spm_complex64_t       *y,
+             z_vectorUpdater_t      updateVect)
+{
+    const spm_complex64_t *valptr = (const spm_complex64_t*)(spm->values);
+    spm_int_t col, row, i;
+
+    for ( row = 0; row < spm->gN; row++ ) {
+        for ( i = spm->rowptr[row]; i < spm->rowptr[row+1]; i++ ) {
+            col = spm->colptr[i-baseval] - baseval;
+            updateVect( alpha, baseval, i, row, col, x, valptr, y );
+        }
+    }
+    return SPM_SUCCESS;
+}
+
+/**
+ * @brief Apply updateVect to every stored entry of an IJV matrix.
+ */
+static int
+z_loopMatIJV(const spm_int_t        baseval,
+             const spm_complex64_t  alpha,
+             const spmatrix_t      *spm,
+             const spm_complex64_t *x,
+             spm_complex64_t       *y,
+             z_vectorUpdater_t      updateVect)
+{
+    const spm_complex64_t *valptr = (const spm_complex64_t*)(spm->values);
+    spm_int_t col, row, k;
+
+    for ( k = 0; k < spm->gnnz; k++ ) {
+        row = spm->rowptr[k] - baseval;
+        col = spm->colptr[k] - baseval;
+        /* The updater expects a position shifted by baseval */
+        updateVect( alpha, baseval, k + baseval, row, col, x, valptr, y );
+    }
+    return SPM_SUCCESS;
+}
+
+/**
+ *******************************************************************************
+ *
+ * @ingroup spm_dev_matvec
+ *
+ * @brief compute the matrix-vector product:
+ *          y = alpha * op( A ) * x + beta * y
+ *
+ * A is stored in CSC, CSR or IJV format. When A is SpmGeneral, op( X ) is one
+ * of
+ *
+ *    op( X ) = X  or op( X ) = X' or op( X ) = conjg( X' )
+ *
+ * When A is SpmSymmetric or SpmHermitian, trans is ignored and op( X ) = X.
+ *
+ * alpha and beta are scalars, and x and y are vectors.
+ *
+ *******************************************************************************
+ *
+ * @param[in] trans
+ *          Specifies whether the matrix spm is transposed, not transposed or
+ *          conjugate transposed:
+ *          = SpmNoTrans:   A is not transposed;
+ *          = SpmTrans:     A is transposed;
+ *          = SpmConjTrans: A is conjugate transposed.
+ *
+ * @param[in] alpha
+ *          alpha specifies the scalar alpha
+ *
+ * @param[in] spm
+ *          The sparse matrix A.
+ *
+ * @param[in] x
+ *          The vector x.
+ *
+ * @param[in] beta
+ *          beta specifies the scalar beta
+ *
+ * @param[inout] y
+ *          The vector y.
+ *
+ *******************************************************************************
+ *
+ * @retval SPM_SUCCESS if the y vector has been computed successfully,
+ * @retval SPM_ERR_BADPARAMETER if a pointer is NULL or if the storage format,
+ *         the matrix type or trans is not supported; y is left untouched in
+ *         that case.
+ *
+ *******************************************************************************/
+int
+z_spmv(const spm_trans_t      trans,
+       spm_complex64_t        alpha,
+       const spmatrix_t      *spm,
+       const spm_complex64_t *x,
+       spm_complex64_t        beta,
+       spm_complex64_t       *y )
+{
+    z_vectorUpdater_t updateVect;
+    z_matLooper_t     loopMat;
+    spm_int_t         i;
+
+    /* Validate the pointers before any access to spm */
+    if ( (spm == NULL) || (x == NULL) || (y == NULL) ) {
+        return SPM_ERR_BADPARAMETER;
+    }
+
+    /* Select the matrix looper matching the storage format */
+    if ( spm->fmttype == SpmCSC ) {
+        loopMat = z_loopMatCSC;
+    }
+    else if ( spm->fmttype == SpmCSR ) {
+        loopMat = z_loopMatCSR;
+    }
+    else if ( spm->fmttype == SpmIJV ) {
+        loopMat = z_loopMatIJV;
+    }
+    else {
+        return SPM_ERR_BADPARAMETER;
+    }
+
+    /* Select the vector updater matching the matrix type and trans */
+    if ( spm->mtxtype == SpmGeneral ) {
+        if ( trans == SpmNoTrans ) {
+            updateVect = z_updateVectNoTrans;
+        }
+        else if ( trans == SpmTrans ) {
+            updateVect = z_updateVectTrans;
+        }
+#if defined(PRECISION_c) || defined(PRECISION_z)
+        else if ( trans == SpmConjTrans ) {
+            updateVect = z_updateVectConjTrans;
+        }
+#endif
+        else {
+            return SPM_ERR_BADPARAMETER;
+        }
+    }
+    else if ( spm->mtxtype == SpmSymmetric ) {
+        updateVect = z_updateVectSy;
+    }
+#if defined(PRECISION_c) || defined(PRECISION_z)
+    else if ( spm->mtxtype == SpmHermitian ) {
+        updateVect = z_updateVectHe;
+    }
+#endif
+    else {
+        return SPM_ERR_BADPARAMETER;
+    }
+
+    /* First, y = beta * y */
+    if ( beta == 0. ) {
+        memset( y, 0, spm->gN * sizeof(spm_complex64_t) );
+    }
+    else {
+        for ( i = 0; i < spm->gN; i++ ) {
+            y[i] *= beta;
+        }
+    }
+
+    /* Nothing to accumulate when alpha is zero: y = beta * y is the result */
+    if ( alpha == 0. ) {
+        return SPM_SUCCESS;
+    }
+
+    return loopMat( spmFindBase( spm ), alpha, spm, x, y, updateVect );
+}
+
 /**
  *******************************************************************************
  *
@@ -378,17 +719,7 @@ z_spmCSCMatVec(const spm_trans_t trans,
     alpha = *((const spm_complex64_t *)alphaptr);
     beta = *((const spm_complex64_t *)betaptr);
 
-    switch (spm->mtxtype) {
-#if defined(PRECISION_z) || defined(PRECISION_c)
-    case SpmHermitian:
-        return z_spmHeCSCv( alpha, spm, x, beta, y );
-#endif
-    case SpmSymmetric:
-        return z_spmSyCSCv( alpha, spm, x, beta, y );
-    case SpmGeneral:
-    default:
-        return z_spmGeCSCv( trans, alpha, spm, x, beta, y );
-    }
+    return z_spmv( trans, alpha, spm, x, beta, y );
 }
 
 /**
@@ -463,24 +794,8 @@ z_spmCSCMatMat(const spm_trans_t trans,
     alpha = *((const spm_complex64_t *)alphaptr);
     beta = *((const spm_complex64_t *)betaptr);
 
-    switch (A->mtxtype) {
-#if defined(PRECISION_z) || defined(PRECISION_c)
-    case SpmHermitian:
-        for( i=0; i<n; i++ ){
-            rc = z_spmHeCSCv( alpha, A, B + i * ldb, beta, C + i *ldc );
-        }
-        break;
-#endif
-    case SpmSymmetric:
-        for( i=0; i<n; i++ ){
-            rc = z_spmSyCSCv( alpha, A, B + i * ldb, beta, C + i *ldc );
-        }
-        break;
-    case SpmGeneral:
-    default:
-        for( i=0; i<n; i++ ){
-            rc = z_spmGeCSCv( trans, alpha, A, B + i * ldb, beta, C + i *ldc );
-        }
+    for( i=0; i<n; i++ ){
+        rc = z_spmv( trans, alpha, A, B + i * ldb, beta, C + i *ldc );
     }
     return rc;
 }
diff --git a/tests/spm_matvec_tests.c b/tests/spm_matvec_tests.c
index 7ad94e56..9a1f894f 100644
--- a/tests/spm_matvec_tests.c
+++ b/tests/spm_matvec_tests.c
@@ -39,13 +39,14 @@ int s_spm_matvec_check( int trans, const spmatrix_t *spm );
 char* fltnames[] = { "Pattern", "", "Float", "Double", "Complex32", "Complex64" };
 char* transnames[] = { "NoTrans", "Trans", "ConjTrans" };
 char* mtxnames[] = { "General", "Symmetric", "Hermitian" };
+char* mtxfmts[] = { "CSC", "CSR", "IJV" };
 
 int main (int argc, char **argv)
 {
     spmatrix_t spm;
     spm_driver_t driver;
     char *filename;
-    int t,spmtype, mtxtype, baseval;
+    int t,spmtype, mtxtype, mtxfmt, baseval;
     int rc = SPM_SUCCESS;
     int err = 0;
 
@@ -67,7 +68,6 @@ int main (int argc, char **argv)
     /**
-     * Only CSC is supported for now
+     * The matrix is converted to each storage format in the loops below
      */
-    spmConvert( SpmCSC, &spm );
     spmtype = spm.mtxtype;
     printf(" -- SPM Matrix-Vector Test --\n");
 
@@ -77,56 +77,59 @@ int main (int argc, char **argv)
     {
         printf(" Baseval : %d\n", baseval );
         spmBase( &spm, baseval );
-
-        for( mtxtype=SpmGeneral; mtxtype<=SpmHermitian; mtxtype++ )
+        for( mtxfmt=SpmCSC; mtxfmt<=SpmIJV; mtxfmt++ )
         {
-            if ( (mtxtype == SpmHermitian) &&
-                 ( ((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)) ||
-                   (spmtype != SpmHermitian) ) )
-            {
-                continue;
-            }
-            if ( (mtxtype != SpmGeneral) &&
-                 (spmtype == SpmGeneral) )
-            {
-                continue;
-            }
-            spm.mtxtype = mtxtype;
-
-            for( t=SpmNoTrans; t<=SpmConjTrans; t++ )
+            spmConvert( mtxfmt, &spm );
+            for( mtxtype=SpmGeneral; mtxtype<=SpmHermitian; mtxtype++ )
             {
-                if ( (t == SpmConjTrans) &&
-                     ((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)))
+                if ( (mtxtype == SpmHermitian) &&
+                     ( ((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)) ||
+                       (spmtype != SpmHermitian) ) )
                 {
                     continue;
                 }
-                if ( (spm.mtxtype != SpmGeneral) && (t != SpmNoTrans) )
+                if ( (mtxtype != SpmGeneral) &&
+                     (spmtype == SpmGeneral) )
                 {
                     continue;
                 }
+                spm.mtxtype = mtxtype;
 
-                printf(" Case %s - %d - %s:\n",
-                       mtxnames[mtxtype - SpmGeneral], baseval,
-                       transnames[t - SpmNoTrans] );
-
-                switch( spm.flttype ){
-                case SpmComplex64:
-                    rc = z_spm_matvec_check( t, &spm );
-                    break;
-
-                case SpmComplex32:
-                    rc = c_spm_matvec_check( t, &spm );
-                    break;
-
-                case SpmFloat:
-                    rc = s_spm_matvec_check( t, &spm );
-                    break;
-
-                case SpmDouble:
-                default:
-                    rc = d_spm_matvec_check( t, &spm );
+                for( t=SpmNoTrans; t<=SpmConjTrans; t++ )
+                {
+                    if ( (t == SpmConjTrans) &&
+                         ((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)))
+                    {
+                        continue;
+                    }
+                    if ( (spm.mtxtype != SpmGeneral) && (t != SpmNoTrans) )
+                    {
+                        continue;
+                    }
+
+                    printf(" Case %s - %s - %d - %s:\n",
+                           mtxnames[mtxtype - SpmGeneral], mtxfmts[mtxfmt - SpmCSC],
+                           baseval, transnames[t - SpmNoTrans] );
+
+                    switch( spm.flttype ){
+                    case SpmComplex64:
+                        rc = z_spm_matvec_check( t, &spm );
+                        break;
+
+                    case SpmComplex32:
+                        rc = c_spm_matvec_check( t, &spm );
+                        break;
+
+                    case SpmFloat:
+                        rc = s_spm_matvec_check( t, &spm );
+                        break;
+
+                    case SpmDouble:
+                    default:
+                        rc = d_spm_matvec_check( t, &spm );
+                    }
+                    PRINT_RES(rc);
                 }
-                PRINT_RES(rc);
             }
         }
     }
-- 
GitLab