Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 5f4375dc authored by KUHN Matthieu's avatar KUHN Matthieu Committed by Mathieu Faverge
Browse files

Removed unnecessary functions. Changed interfaces for matvec to spm_?_spmv and matmat to spm_?_spmm

parent c10a83e1
No related branches found
No related tags found
1 merge request!4Feature/matvec 4all
...@@ -1024,17 +1024,17 @@ spmMatVec( spm_trans_t trans, ...@@ -1024,17 +1024,17 @@ spmMatVec( spm_trans_t trans,
} }
switch (spm->flttype) { switch (spm->flttype) {
case SpmFloat: case SpmFloat:
rc = s_spmCSCMatVec( trans, alpha, espm, x, beta, y ); rc = spm_s_spmv( trans, alpha, espm, x, beta, y );
break; break;
case SpmComplex32: case SpmComplex32:
rc = c_spmCSCMatVec( trans, alpha, espm, x, beta, y ); rc = spm_c_spmv( trans, alpha, espm, x, beta, y );
break; break;
case SpmComplex64: case SpmComplex64:
rc = z_spmCSCMatVec( trans, alpha, espm, x, beta, y ); rc = spm_z_spmv( trans, alpha, espm, x, beta, y );
break; break;
case SpmDouble: case SpmDouble:
default: default:
rc = d_spmCSCMatVec( trans, alpha, espm, x, beta, y ); rc = spm_d_spmv( trans, alpha, espm, x, beta, y );
} }
if ( spm != espm ) { if ( spm != espm ) {
...@@ -1118,17 +1118,17 @@ spmMatMat( spm_trans_t trans, ...@@ -1118,17 +1118,17 @@ spmMatMat( spm_trans_t trans,
} }
switch (A->flttype) { switch (A->flttype) {
case SpmFloat: case SpmFloat:
rc = s_spmCSCMatMat( trans, n, alpha, espm, B, ldb, beta, C, ldc ); rc = spm_s_spmm( trans, n, alpha, espm, B, ldb, beta, C, ldc );
break; break;
case SpmComplex32: case SpmComplex32:
rc = c_spmCSCMatMat( trans, n, alpha, espm, B, ldb, beta, C, ldc ); rc = spm_c_spmm( trans, n, alpha, espm, B, ldb, beta, C, ldc );
break; break;
case SpmComplex64: case SpmComplex64:
rc = z_spmCSCMatMat( trans, n, alpha, espm, B, ldb, beta, C, ldc ); rc = spm_z_spmm( trans, n, alpha, espm, B, ldb, beta, C, ldc );
break; break;
case SpmDouble: case SpmDouble:
default: default:
rc = d_spmCSCMatMat( trans, n, alpha, espm, B, ldb, beta, C, ldc ); rc = spm_d_spmm( trans, n, alpha, espm, B, ldb, beta, C, ldc );
break; break;
} }
......
...@@ -41,8 +41,8 @@ spm_complex64_t *z_spm2dense( const spmatrix_t *spm ); ...@@ -41,8 +41,8 @@ spm_complex64_t *z_spm2dense( const spmatrix_t *spm );
/** /**
* Matrix-Vector and matrix-matrix product routines * Matrix-Vector and matrix-matrix product routines
*/ */
int z_spmCSCMatVec(const spm_trans_t trans, const void *alpha, const spmatrix_t *spm, const void *x, const void *beta, void *y); int spm_z_spmv(const spm_trans_t trans, const void *alpha, const spmatrix_t *spm, const void *x, const void *beta, void *y);
int z_spmCSCMatMat(const spm_trans_t trans, spm_int_t n, const void *alpha, const spmatrix_t *A, const void *B, spm_int_t ldb, const void *beta, void *Cptr, spm_int_t ldc ); int spm_z_spmm(const spm_trans_t trans, spm_int_t n, const void *alpha, const spmatrix_t *A, const void *B, spm_int_t ldb, const void *beta, void *Cptr, spm_int_t ldc );
/** /**
* Norm computation routines * Norm computation routines
......
...@@ -296,7 +296,7 @@ z_spmGenRHS( spm_rhstype_t type, int nrhs, ...@@ -296,7 +296,7 @@ z_spmGenRHS( spm_rhstype_t type, int nrhs,
} }
/* Compute B */ /* Compute B */
rc = z_spmCSCMatMat( SpmNoTrans, nrhs, &zone, spm, xptr, ldx, &zzero, bptr, ldb ); rc = spm_z_spmm( SpmNoTrans, nrhs, &zone, spm, xptr, ldx, &zzero, bptr, ldb );
if ( x == NULL ) { if ( x == NULL ) {
free(xptr); free(xptr);
......
...@@ -193,36 +193,29 @@ int z_loopMatIJV(const spm_int_t baseval, ...@@ -193,36 +193,29 @@ int z_loopMatIJV(const spm_int_t baseval,
* @ingroup spm_dev_matvec * @ingroup spm_dev_matvec
* *
* @brief compute the matrix-vector product: * @brief compute the matrix-vector product:
* y = alpha * op( A ) * x + beta * y * y = alpha * A + beta * y
*
* A is a SpmGeneral spm, where op( X ) is one of
*
* op( X ) = X or op( X ) = X' or op( X ) = conjg( X' )
* *
* alpha and beta are scalars, and x and y are vectors. * A is a SpmHermitian spm, alpha and beta are scalars, and x and y are
* vectors, and A a symm.
* *
******************************************************************************* *******************************************************************************
* *
* @param[in] trans * @param[in] trans
* Specifies whether the matrix spm is transposed, not transposed or * TODO
* conjugate transposed:
* = SpmNoTrans: A is not transposed;
* = SpmTrans: A is transposed;
* = SpmConjTrans: A is conjugate transposed.
* *
* @param[in] alpha * @param[in] alphaptr
* alpha specifies the scalar alpha * alpha specifies the scalar alpha
* *
* @param[in] spm * @param[in] spm
* The SpmGeneral spm. * The SpmHermitian spm.
* *
* @param[in] x * @param[in] xptr
* The vector x. * The vector x.
* *
* @param[in] beta * @param[in] betaptr
* beta specifies the scalar beta * beta specifies the scalar beta
* *
* @param[inout] y * @param[inout] yptr
* The vector y. * The vector y.
* *
******************************************************************************* *******************************************************************************
...@@ -232,21 +225,25 @@ int z_loopMatIJV(const spm_int_t baseval, ...@@ -232,21 +225,25 @@ int z_loopMatIJV(const spm_int_t baseval,
* *
*******************************************************************************/ *******************************************************************************/
int int
z_spmv(const spm_trans_t trans, spm_z_spmv(const spm_trans_t trans,
spm_complex64_t alpha, const void *alphaptr,
const spmatrix_t *spm, const spmatrix_t *spm,
const spm_complex64_t *x, const void *xptr,
spm_complex64_t beta, const void *betaptr,
spm_complex64_t *y ) void *yptr )
{ {
spm_complex64_t *yptr = (spm_complex64_t*)y; const spm_complex64_t *x = (const spm_complex64_t*)xptr;
spm_complex64_t *y = (spm_complex64_t*)yptr;
spm_complex64_t alpha, beta;
spm_int_t baseval, i; spm_int_t baseval, i;
spm_int_t (*getRow(spm_int_t,spmatrix_t));
spm_int_t (*getCol(spm_int_t,spmatrix_t));
const spm_fmttype_t fmt = spm->fmttype; const spm_fmttype_t fmt = spm->fmttype;
const spm_mtxtype_t mtxtype = spm->mtxtype; const spm_mtxtype_t mtxtype = spm->mtxtype;
z_vectorUpdater_t updateVect; z_vectorUpdater_t updateVect;
alpha = *((const spm_complex64_t *)alphaptr);
beta = *((const spm_complex64_t *)betaptr);
if ( (spm == NULL) || (x == NULL) || (y == NULL ) ) if ( (spm == NULL) || (x == NULL) || (y == NULL ) )
{ {
return SPM_ERR_BADPARAMETER; return SPM_ERR_BADPARAMETER;
...@@ -300,31 +297,31 @@ z_spmv(const spm_trans_t trans, ...@@ -300,31 +297,31 @@ z_spmv(const spm_trans_t trans,
/* first, y = beta*y */ /* first, y = beta*y */
if( beta == 0. ) { if( beta == 0. ) {
memset( yptr, 0, spm->gN * sizeof(spm_complex64_t) ); memset( y, 0, spm->gN * sizeof(spm_complex64_t) );
} }
else { else {
for( i=0; i<spm->gN; i++, yptr++ ) { for( i=0; i<spm->gN; i++, y++ ) {
(*yptr) *= beta; (*y) *= beta;
} }
yptr = y; y = yptr;
} }
baseval = spmFindBase( spm ); baseval = spmFindBase( spm );
if( alpha != 0. ) { if( alpha != 0. ) {
/** /**
* Select the appropriate matrix looper * Select the appropriate matrix looper depending on matrix format
*/ */
if( fmt == SpmCSC ) if( fmt == SpmCSC )
{ {
return z_loopMatCSC(baseval, alpha, spm, x, yptr, updateVect); return z_loopMatCSC(baseval, alpha, spm, x, y, updateVect);
} }
else if( fmt == SpmCSR ) else if( fmt == SpmCSR )
{ {
return z_loopMatCSR(baseval, alpha, spm, x, yptr, updateVect); return z_loopMatCSR(baseval, alpha, spm, x, y, updateVect);
} }
else if( fmt == SpmIJV ) else if( fmt == SpmIJV )
{ {
return z_loopMatIJV(baseval, alpha, spm, x, yptr, updateVect); return z_loopMatIJV(baseval, alpha, spm, x, y, updateVect);
} }
else else
{ {
...@@ -334,382 +331,6 @@ z_spmv(const spm_trans_t trans, ...@@ -334,382 +331,6 @@ z_spmv(const spm_trans_t trans,
return SPM_ERR_BADPARAMETER; return SPM_ERR_BADPARAMETER;
} }
/**
*******************************************************************************
*
* @ingroup spm_dev_matvec
*
* @brief compute the matrix-vector product:
* y = alpha * op( A ) * x + beta * y
*
* A is a SpmGeneral spm, where op( X ) is one of
*
* op( X ) = X or op( X ) = X' or op( X ) = conjg( X' )
*
* alpha and beta are scalars, and x and y are vectors.
*
*******************************************************************************
*
* @param[in] trans
* Specifies whether the matrix spm is transposed, not transposed or
* conjugate transposed:
* = SpmNoTrans: A is not transposed;
* = SpmTrans: A is transposed;
* = SpmConjTrans: A is conjugate transposed.
*
* @param[in] alpha
* alpha specifies the scalar alpha
*
* @param[in] spm
* The SpmGeneral spm.
*
* @param[in] x
* The vector x.
*
* @param[in] beta
* beta specifies the scalar beta
*
* @param[inout] y
* The vector y.
*
*******************************************************************************
*
* @retval SPM_SUCCESS if the y vector has been computed succesfully,
* @retval SPM_ERR_BADPARAMETER otherwise.
*
*******************************************************************************/
int
z_spmGeCSCv(const spm_trans_t trans,
spm_complex64_t alpha,
const spmatrix_t *spm,
const spm_complex64_t *x,
spm_complex64_t beta,
spm_complex64_t *y )
{
const spm_complex64_t *valptr = (spm_complex64_t*)(spm->values);
const spm_complex64_t *xptr = (const spm_complex64_t*)x;
spm_complex64_t *yptr = (spm_complex64_t*)y;
spm_int_t col, row, i, baseval;
if ( (spm == NULL) || (x == NULL) || (y == NULL ) )
{
return SPM_ERR_BADPARAMETER;
}
if( spm->mtxtype != SpmGeneral )
{
return SPM_ERR_BADPARAMETER;
}
baseval = spmFindBase( spm );
/* first, y = beta*y */
if( beta == 0. ) {
memset( yptr, 0, spm->gN * sizeof(spm_complex64_t) );
}
else {
for( i=0; i<spm->gN; i++, yptr++ ) {
(*yptr) *= beta;
}
yptr = y;
}
if( alpha != 0. ) {
/**
* SpmNoTrans
*/
if( trans == SpmNoTrans )
{
for( col=0; col < spm->gN; col++ )
{
for( i=spm->colptr[col]; i<spm->colptr[col+1]; i++ )
{
row = spm->rowptr[i-baseval]-baseval;
yptr[row] += alpha * valptr[i-baseval] * xptr[col];
}
}
}
/**
* SpmTrans
*/
else if( trans == SpmTrans )
{
for( col=0; col < spm->gN; col++ )
{
for( i=spm->colptr[col]; i<spm->colptr[col+1]; i++ )
{
row = spm->rowptr[i-baseval]-baseval;
yptr[col] += alpha * valptr[i-baseval] * xptr[row];
}
}
}
#if defined(PRECISION_c) || defined(PRECISION_z)
else if( trans == SpmConjTrans )
{
for( col=0; col < spm->gN; col++ )
{
for( i=spm->colptr[col]; i<spm->colptr[col+1]; i++ )
{
row = spm->rowptr[i-baseval]-baseval;
yptr[col] += alpha * conj( valptr[i-baseval] ) * xptr[row];
}
}
}
#endif
else
{
return SPM_ERR_BADPARAMETER;
}
}
return SPM_SUCCESS;
}
/**
*******************************************************************************
*
* @ingroup spm_dev_matvec
*
* @brief compute the matrix-vector product:
* y = alpha * A + beta * y
*
* A is a SpmSymmetric spm, alpha and beta are scalars, and x and y are
* vectors, and A a symm.
*
*******************************************************************************
*
* @param[in] alpha
* alpha specifies the scalar alpha
*
* @param[in] spm
* The SpmSymmetric spm.
*
* @param[in] x
* The vector x.
*
* @param[in] beta
* beta specifies the scalar beta
*
* @param[inout] y
* The vector y.
*
*******************************************************************************
*
* @retval SPM_SUCCESS if the y vector has been computed succesfully,
* @retval SPM_ERR_BADPARAMETER otherwise.
*
*******************************************************************************/
int
z_spmSyCSCv( spm_complex64_t alpha,
const spmatrix_t *spm,
const spm_complex64_t *x,
spm_complex64_t beta,
spm_complex64_t *y )
{
const spm_complex64_t *valptr = (spm_complex64_t*)spm->values;
const spm_complex64_t *xptr = x;
spm_complex64_t *yptr = y;
spm_int_t col, row, i, baseval;
if ( (spm == NULL) || (x == NULL) || (y == NULL ) )
{
return SPM_ERR_BADPARAMETER;
}
if( spm->mtxtype != SpmSymmetric )
{
return SPM_ERR_BADPARAMETER;
}
baseval = spmFindBase( spm );
/* First, y = beta*y */
if( beta == 0. ) {
memset( yptr, 0, spm->gN * sizeof(spm_complex64_t) );
}
else {
for( i=0; i<spm->gN; i++, yptr++ ) {
(*yptr) *= beta;
}
yptr = y;
}
if( alpha != 0. ) {
for( col=0; col < spm->gN; col++ )
{
for( i=spm->colptr[col]; i < spm->colptr[col+1]; i++ )
{
row = spm->rowptr[i-baseval]-baseval;
yptr[row] += alpha * valptr[i-baseval] * xptr[col];
if( col != row )
{
yptr[col] += alpha * valptr[i-baseval] * xptr[row];
}
}
}
}
return SPM_SUCCESS;
}
#if defined(PRECISION_c) || defined(PRECISION_z)
/**
*******************************************************************************
*
* @ingroup spm_dev_matvec
*
* @brief compute the matrix-vector product:
* y = alpha * A + beta * y
*
* A is a SpmHermitian spm, alpha and beta are scalars, and x and y are
* vectors, and A a symm.
*
*******************************************************************************
*
* @param[in] alpha
* alpha specifies the scalar alpha
*
* @param[in] spm
* The SpmHermitian spm.
*
* @param[in] x
* The vector x.
*
* @param[in] beta
* beta specifies the scalar beta
*
* @param[inout] y
* The vector y.
*
*******************************************************************************
*
* @retval SPM_SUCCESS if the y vector has been computed succesfully,
* @retval SPM_ERR_BADPARAMETER otherwise.
*
*******************************************************************************/
int
z_spmHeCSCv( spm_complex64_t alpha,
const spmatrix_t *spm,
const spm_complex64_t *x,
spm_complex64_t beta,
spm_complex64_t *y )
{
const spm_complex64_t *valptr = (spm_complex64_t*)spm->values;
const spm_complex64_t *xptr = x;
spm_complex64_t *yptr = y;
spm_int_t col, row, i, baseval;
if ( (spm == NULL) || (x == NULL) || (y == NULL ) )
{
return SPM_ERR_BADPARAMETER;
}
if( spm->mtxtype != SpmHermitian )
{
return SPM_ERR_BADPARAMETER;
}
/* First, y = beta*y */
if( beta == 0. ) {
memset( yptr, 0, spm->gN * sizeof(spm_complex64_t) );
}
else {
for( i=0; i<spm->gN; i++, yptr++ ) {
(*yptr) *= beta;
}
yptr = y;
}
baseval = spmFindBase( spm );
if( alpha != 0. ) {
for( col=0; col < spm->gN; col++ )
{
for( i=spm->colptr[col]; i < spm->colptr[col+1]; i++ )
{
row=spm->rowptr[i-baseval]-baseval;
if( col != row ) {
yptr[row] += alpha * valptr[i-baseval] * xptr[col];
yptr[col] += alpha * conj( valptr[i-baseval] ) * xptr[row];
}
else {
yptr[row] += alpha * creal(valptr[i-baseval]) * xptr[col];
}
}
}
}
return SPM_SUCCESS;
}
#endif
/**
*******************************************************************************
*
* @ingroup spm_dev_matvec
*
* @brief compute the matrix-vector product:
* y = alpha * A + beta * y
*
* A is a SpmHermitian spm, alpha and beta are scalars, and x and y are
* vectors, and A a symm.
*
*******************************************************************************
*
* @param[in] trans
* TODO
*
* @param[in] alphaptr
* alpha specifies the scalar alpha
*
* @param[in] spm
* The SpmHermitian spm.
*
* @param[in] xptr
* The vector x.
*
* @param[in] betaptr
* beta specifies the scalar beta
*
* @param[inout] yptr
* The vector y.
*
*******************************************************************************
*
* @retval SPM_SUCCESS if the y vector has been computed succesfully,
* @retval SPM_ERR_BADPARAMETER otherwise.
*
*******************************************************************************/
int
z_spmCSCMatVec(const spm_trans_t trans,
const void *alphaptr,
const spmatrix_t *spm,
const void *xptr,
const void *betaptr,
void *yptr )
{
const spm_complex64_t *x = (const spm_complex64_t*)xptr;
spm_complex64_t *y = (spm_complex64_t*)yptr;
spm_complex64_t alpha, beta;
alpha = *((const spm_complex64_t *)alphaptr);
beta = *((const spm_complex64_t *)betaptr);
// switch (spm->mtxtype) {
//#if defined(PRECISION_z) || defined(PRECISION_c)
// case SpmHermitian:
// return z_spmHeCSCv( alpha, spm, x, beta, y );
//#endif
// case SpmSymmetric:
// return z_spmSyCSCv( alpha, spm, x, beta, y );
// case SpmGeneral:
// default:
// return z_spmGeCSCv( trans, alpha, spm, x, beta, y );
// }
return z_spmv( trans, alpha, spm, x, beta, y);
}
/** /**
******************************************************************************* *******************************************************************************
* *
...@@ -764,7 +385,7 @@ z_spmCSCMatVec(const spm_trans_t trans, ...@@ -764,7 +385,7 @@ z_spmCSCMatVec(const spm_trans_t trans,
* *
*******************************************************************************/ *******************************************************************************/
int int
z_spmCSCMatMat(const spm_trans_t trans, spm_z_spmm(const spm_trans_t trans,
spm_int_t n, spm_int_t n,
const void *alphaptr, const void *alphaptr,
const spmatrix_t *A, const spmatrix_t *A,
...@@ -776,33 +397,10 @@ z_spmCSCMatMat(const spm_trans_t trans, ...@@ -776,33 +397,10 @@ z_spmCSCMatMat(const spm_trans_t trans,
{ {
const spm_complex64_t *B = (const spm_complex64_t*)Bptr; const spm_complex64_t *B = (const spm_complex64_t*)Bptr;
spm_complex64_t *C = (spm_complex64_t*)Cptr; spm_complex64_t *C = (spm_complex64_t*)Cptr;
spm_complex64_t alpha, beta;
int i, rc = SPM_SUCCESS; int i, rc = SPM_SUCCESS;
alpha = *((const spm_complex64_t *)alphaptr);
beta = *((const spm_complex64_t *)betaptr);
// switch (A->mtxtype) {
//#if defined(PRECISION_z) || defined(PRECISION_c)
// case SpmHermitian:
// for( i=0; i<n; i++ ){
// rc = z_spmHeCSCv( alpha, A, B + i * ldb, beta, C + i *ldc );
// }
// break;
//#endif
// case SpmSymmetric:
// for( i=0; i<n; i++ ){
// rc = z_spmSyCSCv( alpha, A, B + i * ldb, beta, C + i *ldc );
// }
// break;
// case SpmGeneral:
// default:
// for( i=0; i<n; i++ ){
// rc = z_spmGeCSCv( trans, alpha, A, B + i * ldb, beta, C + i *ldc );
// }
// }
for( i=0; i<n; i++ ){ for( i=0; i<n; i++ ){
rc = z_spmv( trans, alpha, A, B + i * ldb, beta, C + i *ldc ); rc = spm_z_spmv( trans, alphaptr, A, B + i * ldb, betaptr, C + i *ldc );
} }
return rc; return rc;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment