Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 96d9c0fb authored by KUHN Matthieu's avatar KUHN Matthieu Committed by Mathieu Faverge
Browse files

Matrix vector product for all Spm storage formats using function pointers

parent e58ae509
No related branches found
No related tags found
1 merge request!4Feature/matvec 4all
...@@ -131,6 +131,7 @@ spm_int_t * spmIntConvert( spm_int_t n, int *input ); ...@@ -131,6 +131,7 @@ spm_int_t * spmIntConvert( spm_int_t n, int *input );
void spmIntSort1Asc1( void * const pbase, const spm_int_t n ); void spmIntSort1Asc1( void * const pbase, const spm_int_t n );
void spmIntSort2Asc1( void * const pbase, const spm_int_t n ); void spmIntSort2Asc1( void * const pbase, const spm_int_t n );
void spmIntSort2Asc2( void * const pbase, const spm_int_t n ); void spmIntSort2Asc2( void * const pbase, const spm_int_t n );
void spmIntMSortIntAsc(void ** const pbase, const spm_int_t n);
void spmIntMSortIntAsc(void ** const pbase, const spm_int_t n); void spmIntMSortIntAsc(void ** const pbase, const spm_int_t n);
......
...@@ -1015,7 +1015,7 @@ spmMatVec( spm_trans_t trans, ...@@ -1015,7 +1015,7 @@ spmMatVec( spm_trans_t trans,
spmatrix_t *espm = (spmatrix_t*)spm; spmatrix_t *espm = (spmatrix_t*)spm;
int rc = SPM_SUCCESS; int rc = SPM_SUCCESS;
if ( spm->fmttype != SpmCSC ) { if ( spm->fmttype != SpmCSC && spm->fmttype != SpmCSR && spm->fmttype != SpmIJV ) {
return SPM_ERR_BADPARAMETER; return SPM_ERR_BADPARAMETER;
} }
......
...@@ -17,6 +17,317 @@ ...@@ -17,6 +17,317 @@
#include "common.h" #include "common.h"
#include "z_spm.h" #include "z_spm.h"
/**
 * @brief Signature of the per-entry update kernels used by the matrix loopers.
 *
 * One call accumulates the contribution of a single stored matrix entry into
 * the output vector y.
 *
 * alpha   - scalar multiplying the contribution.
 * baseval - index base of the matrix; the kernels subtract it from pos before
 *           indexing val.
 * pos     - position of the entry in val, expressed with the baseval offset.
 * row     - row index of the entry, as passed by the looper (the loopers in
 *           this file pass it already shifted to 0-based).
 * col     - column index of the entry, same convention as row.
 * x       - input vector.
 * val     - array of matrix values.
 * y       - output vector, updated in place.
 */
typedef void (*z_vectorUpdater_t)(const spm_complex64_t alpha,
const spm_int_t baseval,
const spm_int_t pos,
const spm_int_t row,
const spm_int_t col,
const spm_complex64_t *x,
const spm_complex64_t *val,
spm_complex64_t *y);
/**
 * @brief Identity transform on a matrix value.
 *
 * Passed to z_updateVectCore() as its value-transform callback when the
 * stored value must be used unchanged (i.e. in place of conj()).
 *
 * @param[in] val
 *          The value to forward.
 *
 * @return val, unmodified.
 */
spm_complex64_t
z_idFunc( spm_complex64_t val )
{
    return val;
}
/**
 * @brief Accumulate the contribution of one stored entry into y.
 *
 * Performs y[idy] += alpha * conj_func( val[pos - baseval] ) * x[idx].
 *
 * @param[in] alpha
 *          Scalar multiplying the contribution.
 * @param[in] baseval
 *          Offset subtracted from pos to index val.
 * @param[in] pos
 *          Position of the entry, expressed with the baseval offset.
 * @param[in] idy
 *          Index of the element of y that is updated.
 * @param[in] idx
 *          Index of the element of x that is read.
 * @param[in] x
 *          Input vector.
 * @param[in] val
 *          Array of matrix values.
 * @param[inout] y
 *          Output vector, updated in place.
 * @param[in] conj_func
 *          Transform applied to the matrix value before it is used
 *          (z_idFunc or conj in this file).
 */
void
z_updateVectCore( const spm_complex64_t  alpha,
                  const spm_int_t        baseval,
                  const spm_int_t        pos,
                  const spm_int_t        idy,
                  const spm_int_t        idx,
                  const spm_complex64_t *x,
                  const spm_complex64_t *val,
                  spm_complex64_t       *y,
                  spm_complex64_t      (*conj_func)(spm_complex64_t) )
{
    const spm_complex64_t entry = conj_func( val[pos - baseval] );

    /* Same association as the one-line form: (alpha * entry) * x[idx] */
    y[idy] += ( alpha * entry ) * x[idx];
}
/**
 * @brief Entry update for y += alpha * A * x (no transposition).
 *
 * The entry (row, col) contributes to y[row] using x[col]; the matrix value
 * is used unchanged. Parameters follow the z_vectorUpdater_t signature.
 */
void
z_updateVectNoTrans( const spm_complex64_t  alpha,
                     const spm_int_t        baseval,
                     const spm_int_t        pos,
                     const spm_int_t        row,
                     const spm_int_t        col,
                     const spm_complex64_t *x,
                     const spm_complex64_t *val,
                     spm_complex64_t       *y )
{
    z_updateVectCore( alpha, baseval, pos,
                      row, /* element of y that is updated */
                      col, /* element of x that is read    */
                      x, val, y, z_idFunc );
}
/**
 * @brief Entry update for y += alpha * A^t * x (transposition).
 *
 * The roles of row and col are swapped with respect to the non-transposed
 * case: the entry (row, col) contributes to y[col] using x[row]. The matrix
 * value is used unchanged. Parameters follow the z_vectorUpdater_t signature.
 */
void
z_updateVectTrans( const spm_complex64_t  alpha,
                   const spm_int_t        baseval,
                   const spm_int_t        pos,
                   const spm_int_t        row,
                   const spm_int_t        col,
                   const spm_complex64_t *x,
                   const spm_complex64_t *val,
                   spm_complex64_t       *y )
{
    z_updateVectCore( alpha, baseval, pos,
                      col, /* element of y that is updated */
                      row, /* element of x that is read    */
                      x, val, y, z_idFunc );
}
#if defined(PRECISION_c) || defined(PRECISION_z)
/**
 * @brief Entry update for y += alpha * conj(A^t) * x (conjugate transposition).
 *
 * Only compiled for the complex precisions. The entry (row, col) contributes
 * to y[col] using x[row], and the matrix value is conjugated before use.
 * Parameters follow the z_vectorUpdater_t signature.
 */
void
z_updateVectConjTrans( const spm_complex64_t  alpha,
                       const spm_int_t        baseval,
                       const spm_int_t        pos,
                       const spm_int_t        row,
                       const spm_int_t        col,
                       const spm_complex64_t *x,
                       const spm_complex64_t *val,
                       spm_complex64_t       *y )
{
    z_updateVectCore( alpha, baseval, pos,
                      col, /* element of y that is updated */
                      row, /* element of x that is read    */
                      x, val, y, conj );
}
#endif
/**
 * @brief Entry update for a symmetric matrix of which each off-diagonal
 *        entry is stored once.
 *
 * The entry (row, col) always contributes to y[row] using x[col]. When the
 * entry is off-diagonal, the mirrored entry (col, row) is applied as well,
 * with the same value. Parameters follow the z_vectorUpdater_t signature.
 */
void
z_updateVectSy( const spm_complex64_t  alpha,
                const spm_int_t        baseval,
                const spm_int_t        pos,
                const spm_int_t        row,
                const spm_int_t        col,
                const spm_complex64_t *x,
                const spm_complex64_t *val,
                spm_complex64_t       *y )
{
    /* Contribution of the stored entry itself */
    z_updateVectCore( alpha, baseval, pos, row, col, x, val, y, z_idFunc );

    /* A diagonal entry has no mirrored counterpart */
    if ( row == col ) {
        return;
    }

    /* Contribution of the implicit symmetric entry */
    z_updateVectCore( alpha, baseval, pos, col, row, x, val, y, z_idFunc );
}
#if defined(PRECISION_c) || defined(PRECISION_z)
/**
 * @brief Entry update for a Hermitian matrix of which each off-diagonal
 *        entry is stored once.
 *
 * Only compiled for the complex precisions. An off-diagonal entry (row, col)
 * contributes to y[row] with its value, and to y[col] with the conjugate of
 * its value. A diagonal entry contributes once, through conj().
 * Parameters follow the z_vectorUpdater_t signature.
 *
 * NOTE(review): the diagonal value is conjugated; this is a no-op only if
 * the stored diagonal is real, as a Hermitian matrix requires.
 */
void
z_updateVectHe( const spm_complex64_t  alpha,
                const spm_int_t        baseval,
                const spm_int_t        pos,
                const spm_int_t        row,
                const spm_int_t        col,
                const spm_complex64_t *x,
                const spm_complex64_t *val,
                spm_complex64_t       *y )
{
    /* Diagonal entry: single contribution */
    if ( row == col ) {
        z_updateVectCore( alpha, baseval, pos, row, col, x, val, y, conj );
        return;
    }

    /* Off-diagonal entry: stored value, then its conjugate mirror */
    z_updateVectCore( alpha, baseval, pos, row, col, x, val, y, z_idFunc );
    z_updateVectCore( alpha, baseval, pos, col, row, x, val, y, conj );
}
#endif
/**
 * @brief Apply an entry updater to every stored entry of a CSC matrix.
 *
 * Columns are visited from 0 to spm->gN - 1; for each column, spm->colptr
 * gives the range of positions of its entries and spm->rowptr the row
 * index of each entry.
 *
 * @param[in] baseval
 *          Index base; subtracted from positions and from row indices.
 * @param[in] alpha
 *          Scalar forwarded to the updater.
 * @param[in] spm
 *          The matrix to traverse.
 * @param[in] x
 *          Input vector, forwarded to the updater.
 * @param[inout] yptr
 *          Output vector, forwarded to the updater.
 * @param[in] updateVect
 *          Updater called once per stored entry.
 *
 * @return SPM_SUCCESS.
 */
int
z_loopMatCSC( const spm_int_t        baseval,
              const spm_complex64_t  alpha,
              const spmatrix_t      *spm,
              const spm_complex64_t *x,
              spm_complex64_t       *yptr,
              z_vectorUpdater_t      updateVect )
{
    const spm_complex64_t *values = (const spm_complex64_t *)(spm->values);
    spm_int_t              j, k, kend, irow;

    for ( j = 0; j < spm->gN; j++ ) {
        k    = spm->colptr[j];
        kend = spm->colptr[j + 1];

        while ( k < kend ) {
            irow = spm->rowptr[k - baseval] - baseval;
            updateVect( alpha, baseval, k, irow, j, x, values, yptr );
            k++;
        }
    }

    return SPM_SUCCESS;
}
/**
 * @brief Apply an entry updater to every stored entry of a CSR matrix.
 *
 * Rows are visited from 0 to spm->gN - 1; for each row, spm->rowptr gives
 * the range of positions of its entries and spm->colptr the column index
 * of each entry.
 *
 * @param[in] baseval
 *          Index base; subtracted from positions and from column indices.
 * @param[in] alpha
 *          Scalar forwarded to the updater.
 * @param[in] spm
 *          The matrix to traverse.
 * @param[in] x
 *          Input vector, forwarded to the updater.
 * @param[inout] yptr
 *          Output vector, forwarded to the updater.
 * @param[in] updateVect
 *          Updater called once per stored entry.
 *
 * @return SPM_SUCCESS.
 */
int
z_loopMatCSR( const spm_int_t        baseval,
              const spm_complex64_t  alpha,
              const spmatrix_t      *spm,
              const spm_complex64_t *x,
              spm_complex64_t       *yptr,
              z_vectorUpdater_t      updateVect )
{
    const spm_complex64_t *values = (const spm_complex64_t *)(spm->values);
    spm_int_t              r, k, kend, jcol;

    for ( r = 0; r < spm->gN; r++ ) {
        k    = spm->rowptr[r];
        kend = spm->rowptr[r + 1];

        while ( k < kend ) {
            jcol = spm->colptr[k - baseval] - baseval;
            updateVect( alpha, baseval, k, r, jcol, x, values, yptr );
            k++;
        }
    }

    return SPM_SUCCESS;
}
/**
 * @brief Apply an entry updater to every stored entry of an IJV matrix.
 *
 * The spm->gnnz entries are visited in storage order; spm->rowptr and
 * spm->colptr give the row and column index of each entry.
 *
 * @param[in] baseval
 *          Index base; subtracted from row and column indices, and added
 *          to the entry counter to build the position given to the updater.
 * @param[in] alpha
 *          Scalar forwarded to the updater.
 * @param[in] spm
 *          The matrix to traverse.
 * @param[in] x
 *          Input vector, forwarded to the updater.
 * @param[inout] yptr
 *          Output vector, forwarded to the updater.
 * @param[in] updateVect
 *          Updater called once per stored entry.
 *
 * @return SPM_SUCCESS.
 */
int
z_loopMatIJV( const spm_int_t        baseval,
              const spm_complex64_t  alpha,
              const spmatrix_t      *spm,
              const spm_complex64_t *x,
              spm_complex64_t       *yptr,
              z_vectorUpdater_t      updateVect )
{
    const spm_complex64_t *values = (const spm_complex64_t *)(spm->values);
    spm_int_t              k, irow, jcol;

    /* k is the 0-based entry number; the updater expects k + baseval */
    for ( k = 0; k < spm->gnnz; k++ ) {
        irow = spm->rowptr[k] - baseval;
        jcol = spm->colptr[k] - baseval;
        updateVect( alpha, baseval, k + baseval, irow, jcol, x, values, yptr );
    }

    return SPM_SUCCESS;
}
/**
 *******************************************************************************
 *
 * @ingroup spm_dev_matvec
 *
 * @brief Compute the matrix-vector product:
 *          y = alpha * op( A ) * x + beta * y
 *
 * A is a sparse matrix stored in CSC, CSR or IJV format. When A is
 * SpmGeneral, op( X ) is one of
 *
 *    op( X ) = X  or op( X ) = X' or op( X ) = conjg( X' )
 *
 * When A is SpmSymmetric or SpmHermitian, trans is not consulted and the
 * product is computed from the stored entries and their implicit mirrored
 * counterparts.
 *
 * alpha and beta are scalars, and x and y are vectors.
 *
 *******************************************************************************
 *
 * @param[in] trans
 *          Specifies whether the matrix spm is transposed, not transposed or
 *          conjugate transposed (only used when spm is SpmGeneral):
 *          = SpmNoTrans:   A is not transposed;
 *          = SpmTrans:     A is transposed;
 *          = SpmConjTrans: A is conjugate transposed (complex precisions
 *                          only).
 *
 * @param[in] alpha
 *          alpha specifies the scalar alpha. If alpha is zero, the matrix is
 *          not traversed and y = beta * y is returned.
 *
 * @param[in] spm
 *          The sparse matrix, in SpmCSC, SpmCSR or SpmIJV format.
 *
 * @param[in] x
 *          The vector x.
 *
 * @param[in] beta
 *          beta specifies the scalar beta. If beta is zero, y is overwritten
 *          with zeros before the product is accumulated.
 *
 * @param[inout] y
 *          The vector y; its first spm->gN elements are scaled by beta.
 *
 *******************************************************************************
 *
 * @retval SPM_SUCCESS if the y vector has been computed successfully,
 * @retval SPM_ERR_BADPARAMETER if a pointer is NULL, or if the storage
 *         format, the matrix type or trans is not supported; y is left
 *         untouched in that case.
 *
 *******************************************************************************/
int
z_spmv( const spm_trans_t      trans,
        spm_complex64_t        alpha,
        const spmatrix_t      *spm,
        const spm_complex64_t *x,
        spm_complex64_t        beta,
        spm_complex64_t       *y )
{
    z_vectorUpdater_t updateVect;
    spm_int_t         baseval, i;

    /* Validate the pointers before any field of spm is read */
    if ( (spm == NULL) || (x == NULL) || (y == NULL) ) {
        return SPM_ERR_BADPARAMETER;
    }

    /* Reject unsupported storage formats before y is modified */
    if ( (spm->fmttype != SpmCSC) &&
         (spm->fmttype != SpmCSR) &&
         (spm->fmttype != SpmIJV) )
    {
        return SPM_ERR_BADPARAMETER;
    }

    /**
     * Select the appropriate vector updater
     */
    if ( spm->mtxtype == SpmGeneral ) {
        if ( trans == SpmNoTrans ) {
            updateVect = &z_updateVectNoTrans;
        }
        else if ( trans == SpmTrans ) {
            updateVect = &z_updateVectTrans;
        }
#if defined(PRECISION_c) || defined(PRECISION_z)
        else if ( trans == SpmConjTrans ) {
            updateVect = &z_updateVectConjTrans;
        }
#endif
        else {
            return SPM_ERR_BADPARAMETER;
        }
    }
    else if ( spm->mtxtype == SpmSymmetric ) {
        updateVect = &z_updateVectSy;
    }
#if defined(PRECISION_z) || defined(PRECISION_c)
    else if ( spm->mtxtype == SpmHermitian ) {
        updateVect = &z_updateVectHe;
    }
#endif
    else {
        return SPM_ERR_BADPARAMETER;
    }

    /* First, y = beta * y */
    if ( beta == 0. ) {
        memset( y, 0, spm->gN * sizeof(spm_complex64_t) );
    }
    else {
        for ( i = 0; i < spm->gN; i++ ) {
            y[i] *= beta;
        }
    }

    /* Nothing to accumulate: y = beta * y is already the result */
    if ( alpha == 0. ) {
        return SPM_SUCCESS;
    }

    baseval = spmFindBase( spm );

    /**
     * Select the appropriate matrix looper
     */
    if ( spm->fmttype == SpmCSC ) {
        return z_loopMatCSC( baseval, alpha, spm, x, y, updateVect );
    }
    if ( spm->fmttype == SpmCSR ) {
        return z_loopMatCSR( baseval, alpha, spm, x, y, updateVect );
    }
    /* Only SpmIJV remains, the format has been validated above */
    return z_loopMatIJV( baseval, alpha, spm, x, y, updateVect );
}
/** /**
******************************************************************************* *******************************************************************************
* *
...@@ -148,6 +459,7 @@ z_spmGeCSCv(const spm_trans_t trans, ...@@ -148,6 +459,7 @@ z_spmGeCSCv(const spm_trans_t trans,
return SPM_SUCCESS; return SPM_SUCCESS;
} }
/** /**
******************************************************************************* *******************************************************************************
* *
...@@ -378,17 +690,18 @@ z_spmCSCMatVec(const spm_trans_t trans, ...@@ -378,17 +690,18 @@ z_spmCSCMatVec(const spm_trans_t trans,
alpha = *((const spm_complex64_t *)alphaptr); alpha = *((const spm_complex64_t *)alphaptr);
beta = *((const spm_complex64_t *)betaptr); beta = *((const spm_complex64_t *)betaptr);
switch (spm->mtxtype) { // switch (spm->mtxtype) {
#if defined(PRECISION_z) || defined(PRECISION_c) //#if defined(PRECISION_z) || defined(PRECISION_c)
case SpmHermitian: // case SpmHermitian:
return z_spmHeCSCv( alpha, spm, x, beta, y ); // return z_spmHeCSCv( alpha, spm, x, beta, y );
#endif //#endif
case SpmSymmetric: // case SpmSymmetric:
return z_spmSyCSCv( alpha, spm, x, beta, y ); // return z_spmSyCSCv( alpha, spm, x, beta, y );
case SpmGeneral: // case SpmGeneral:
default: // default:
return z_spmGeCSCv( trans, alpha, spm, x, beta, y ); // return z_spmGeCSCv( trans, alpha, spm, x, beta, y );
} // }
return z_spmv( trans, alpha, spm, x, beta, y);
} }
/** /**
...@@ -463,24 +776,27 @@ z_spmCSCMatMat(const spm_trans_t trans, ...@@ -463,24 +776,27 @@ z_spmCSCMatMat(const spm_trans_t trans,
alpha = *((const spm_complex64_t *)alphaptr); alpha = *((const spm_complex64_t *)alphaptr);
beta = *((const spm_complex64_t *)betaptr); beta = *((const spm_complex64_t *)betaptr);
switch (A->mtxtype) { // switch (A->mtxtype) {
#if defined(PRECISION_z) || defined(PRECISION_c) //#if defined(PRECISION_z) || defined(PRECISION_c)
case SpmHermitian: // case SpmHermitian:
for( i=0; i<n; i++ ){ // for( i=0; i<n; i++ ){
rc = z_spmHeCSCv( alpha, A, B + i * ldb, beta, C + i *ldc ); // rc = z_spmHeCSCv( alpha, A, B + i * ldb, beta, C + i *ldc );
} // }
break; // break;
#endif //#endif
case SpmSymmetric: // case SpmSymmetric:
for( i=0; i<n; i++ ){ // for( i=0; i<n; i++ ){
rc = z_spmSyCSCv( alpha, A, B + i * ldb, beta, C + i *ldc ); // rc = z_spmSyCSCv( alpha, A, B + i * ldb, beta, C + i *ldc );
} // }
break; // break;
case SpmGeneral: // case SpmGeneral:
default: // default:
for( i=0; i<n; i++ ){ // for( i=0; i<n; i++ ){
rc = z_spmGeCSCv( trans, alpha, A, B + i * ldb, beta, C + i *ldc ); // rc = z_spmGeCSCv( trans, alpha, A, B + i * ldb, beta, C + i *ldc );
} // }
// }
for( i=0; i<n; i++ ){
rc = z_spmv( trans, alpha, A, B + i * ldb, beta, C + i *ldc );
} }
return rc; return rc;
} }
...@@ -39,13 +39,14 @@ int s_spm_matvec_check( int trans, const spmatrix_t *spm ); ...@@ -39,13 +39,14 @@ int s_spm_matvec_check( int trans, const spmatrix_t *spm );
char* fltnames[] = { "Pattern", "", "Float", "Double", "Complex32", "Complex64" }; char* fltnames[] = { "Pattern", "", "Float", "Double", "Complex32", "Complex64" };
char* transnames[] = { "NoTrans", "Trans", "ConjTrans" }; char* transnames[] = { "NoTrans", "Trans", "ConjTrans" };
char* mtxnames[] = { "General", "Symmetric", "Hermitian" }; char* mtxnames[] = { "General", "Symmetric", "Hermitian" };
char* mtxfmts[] = { "CSC", "CSR", "IJV" };
int main (int argc, char **argv) int main (int argc, char **argv)
{ {
spmatrix_t spm; spmatrix_t spm;
spm_driver_t driver; spm_driver_t driver;
char *filename; char *filename;
int t,spmtype, mtxtype, baseval; int t,spmtype, mtxtype, mtxfmt, baseval;
int rc = SPM_SUCCESS; int rc = SPM_SUCCESS;
int err = 0; int err = 0;
...@@ -67,7 +68,6 @@ int main (int argc, char **argv) ...@@ -67,7 +68,6 @@ int main (int argc, char **argv)
/** /**
* Only CSC is supported for now * Only CSC is supported for now
*/ */
spmConvert( SpmCSC, &spm );
spmtype = spm.mtxtype; spmtype = spm.mtxtype;
printf(" -- SPM Matrix-Vector Test --\n"); printf(" -- SPM Matrix-Vector Test --\n");
...@@ -77,56 +77,59 @@ int main (int argc, char **argv) ...@@ -77,56 +77,59 @@ int main (int argc, char **argv)
{ {
printf(" Baseval : %d\n", baseval ); printf(" Baseval : %d\n", baseval );
spmBase( &spm, baseval ); spmBase( &spm, baseval );
for( mtxfmt=SpmCSC; mtxfmt<=SpmIJV; mtxfmt++ )
for( mtxtype=SpmGeneral; mtxtype<=SpmHermitian; mtxtype++ )
{ {
if ( (mtxtype == SpmHermitian) && spmConvert( mtxfmt, &spm );
( ((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)) || for( mtxtype=SpmGeneral; mtxtype<=SpmHermitian; mtxtype++ )
(spmtype != SpmHermitian) ) )
{
continue;
}
if ( (mtxtype != SpmGeneral) &&
(spmtype == SpmGeneral) )
{
continue;
}
spm.mtxtype = mtxtype;
for( t=SpmNoTrans; t<=SpmConjTrans; t++ )
{ {
if ( (t == SpmConjTrans) && if ( (mtxtype == SpmHermitian) &&
((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32))) ( ((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)) ||
(spmtype != SpmHermitian) ) )
{ {
continue; continue;
} }
if ( (spm.mtxtype != SpmGeneral) && (t != SpmNoTrans) ) if ( (mtxtype != SpmGeneral) &&
(spmtype == SpmGeneral) )
{ {
continue; continue;
} }
spm.mtxtype = mtxtype;
printf(" Case %s - %d - %s:\n", for( t=SpmNoTrans; t<=SpmConjTrans; t++ )
mtxnames[mtxtype - SpmGeneral], baseval, {
transnames[t - SpmNoTrans] ); if ( (t == SpmConjTrans) &&
((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)))
switch( spm.flttype ){ {
case SpmComplex64: continue;
rc = z_spm_matvec_check( t, &spm ); }
break; if ( (spm.mtxtype != SpmGeneral) && (t != SpmNoTrans) )
{
case SpmComplex32: continue;
rc = c_spm_matvec_check( t, &spm ); }
break;
printf(" Case %s - %s - %d - %s:\n",
case SpmFloat: mtxnames[mtxtype - SpmGeneral], mtxfmts[mtxfmt - SpmCSC],
rc = s_spm_matvec_check( t, &spm ); baseval, transnames[t - SpmNoTrans] );
break;
switch( spm.flttype ){
case SpmDouble: case SpmComplex64:
default: rc = z_spm_matvec_check( t, &spm );
rc = d_spm_matvec_check( t, &spm ); break;
case SpmComplex32:
rc = c_spm_matvec_check( t, &spm );
break;
case SpmFloat:
rc = s_spm_matvec_check( t, &spm );
break;
case SpmDouble:
default:
rc = d_spm_matvec_check( t, &spm );
}
PRINT_RES(rc);
} }
PRINT_RES(rc);
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment