diff --git a/src/z_spm_matrixvector.c b/src/z_spm_matrixvector.c index 00b1407e5a1d5f0daa91616eea61c344ab30ac3d..cf6078b08780acbb26adaf8c770896aa8664fee2 100644 --- a/src/z_spm_matrixvector.c +++ b/src/z_spm_matrixvector.c @@ -53,24 +53,26 @@ struct __spm_zmatvec_s { spm_complex64_t *y; spm_int_t incy; - __conj_fct_t conj_fct; + __conj_fct_t conjA_fct; + __conj_fct_t conjAt_fct; __loop_fct_t loop_fct; }; static inline int __spm_zmatvec_sy_csr( const __spm_zmatvec_t *args ) { - spm_int_t baseval = args->baseval; - spm_int_t n = args->n; - spm_complex64_t alpha = args->alpha; - const spm_int_t *rowptr = args->rowptr; - const spm_int_t *colptr = args->colptr; - const spm_complex64_t *values = args->values; - const spm_complex64_t *x = args->x; - spm_int_t incx = args->incx; - spm_complex64_t *y = args->y; - spm_int_t incy = args->incy; - const __conj_fct_t conj_fct = args->conj_fct; + spm_int_t baseval = args->baseval; + spm_int_t n = args->n; + spm_complex64_t alpha = args->alpha; + const spm_int_t *rowptr = args->rowptr; + const spm_int_t *colptr = args->colptr; + const spm_complex64_t *values = args->values; + const spm_complex64_t *x = args->x; + spm_int_t incx = args->incx; + spm_complex64_t *y = args->y; + spm_int_t incy = args->incy; + const __conj_fct_t conjA_fct = args->conjA_fct; + const __conj_fct_t conjAt_fct = args->conjAt_fct; spm_int_t col, row, i; for( col=0; col<n; col++, colptr++ ) @@ -80,11 +82,11 @@ __spm_zmatvec_sy_csr( const __spm_zmatvec_t *args ) row = *rowptr - baseval; if ( row != col ) { - y[ row * incy ] += alpha * conj_fct( *values ) * x[ col * incx ]; - y[ col * incy ] += alpha * ( *values ) * x[ row * incx ]; + y[ row * incy ] += alpha * conjAt_fct( *values ) * x[ col * incx ]; + y[ col * incy ] += alpha * conjA_fct( *values ) * x[ row * incx ]; } else { - y[ col * incy ] += alpha * ( *values ) * x[ row * incx ]; + y[ col * incy ] += alpha * conjA_fct( *values ) * x[ row * incx ]; } } } @@ -94,17 +96,18 @@ __spm_zmatvec_sy_csr( const __spm_zmatvec_t *args ) static inline int __spm_zmatvec_sy_csc( const __spm_zmatvec_t *args ) { - spm_int_t baseval = args->baseval; - spm_int_t n = args->n; - spm_complex64_t alpha = args->alpha; - const spm_int_t *rowptr = args->rowptr; - const spm_int_t *colptr = args->colptr; - const spm_complex64_t *values = args->values; - const spm_complex64_t *x = args->x; - spm_int_t incx = args->incx; - spm_complex64_t *y = args->y; - spm_int_t incy = args->incy; - const __conj_fct_t conj_fct = args->conj_fct; + spm_int_t baseval = args->baseval; + spm_int_t n = args->n; + spm_complex64_t alpha = args->alpha; + const spm_int_t *rowptr = args->rowptr; + const spm_int_t *colptr = args->colptr; + const spm_complex64_t *values = args->values; + const spm_complex64_t *x = args->x; + spm_int_t incx = args->incx; + spm_complex64_t *y = args->y; + spm_int_t incy = args->incy; + const __conj_fct_t conjA_fct = args->conjA_fct; + const __conj_fct_t conjAt_fct = args->conjAt_fct; spm_int_t col, row, i; for( col=0; col<n; col++, colptr++ ) @@ -114,11 +117,11 @@ __spm_zmatvec_sy_csc( const __spm_zmatvec_t *args ) row = *rowptr - baseval; if ( row != col ) { - y[ row * incy ] += alpha * ( *values ) * x[ col * incx ]; - y[ col * incy ] += alpha * conj_fct( *values ) * x[ row * incx ]; + y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ]; + y[ col * incy ] += alpha * conjAt_fct( *values ) * x[ row * incx ]; } else { - y[ col * incy ] += alpha * ( *values ) * x[ row * incx ]; + y[ col * incy ] += alpha * conjA_fct( *values ) * x[ row * incx ]; } } } @@ -128,17 +131,17 @@ __spm_zmatvec_sy_csc( const __spm_zmatvec_t *args ) static inline int __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args ) { - spm_int_t baseval = args->baseval; - spm_int_t n = args->n; - spm_complex64_t alpha = args->alpha; - const spm_int_t *rowptr = args->rowptr; - const spm_int_t *colptr = args->colptr; - const spm_complex64_t *values = args->values; - const spm_complex64_t *x = args->x; - spm_int_t incx = args->incx; - spm_complex64_t *y = args->y; - spm_int_t incy = args->incy; - const __conj_fct_t conj_fct = args->conj_fct; + spm_int_t baseval = args->baseval; + spm_int_t n = args->n; + spm_complex64_t alpha = args->alpha; + const spm_int_t *rowptr = args->rowptr; + const spm_int_t *colptr = args->colptr; + const spm_complex64_t *values = args->values; + const spm_complex64_t *x = args->x; + spm_int_t incx = args->incx; + spm_complex64_t *y = args->y; + spm_int_t incy = args->incy; + const __conj_fct_t conjA_fct = args->conjA_fct; spm_int_t col, row, i; if ( args->follow_x ) { @@ -147,7 +150,7 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args ) for( i=colptr[0]; i<colptr[1]; i++, rowptr++, values++ ) { row = *rowptr - baseval; - y[ row * incy ] += alpha * conj_fct( *values ) * (*x); + y[ row * incy ] += alpha * conjA_fct( *values ) * (*x); } } } @@ -157,7 +160,7 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args ) for( i=colptr[0]; i<colptr[1]; i++, rowptr++, values++ ) { row = *rowptr - baseval; - *y += alpha * conj_fct( *values ) * x[ row * incx ]; + *y += alpha * conjA_fct( *values ) * x[ row * incx ]; } } } @@ -167,17 +170,18 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args ) static inline int __spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args ) { - spm_int_t baseval = args->baseval; - spm_int_t nnz = args->nnz; - spm_complex64_t alpha = args->alpha; - const spm_int_t *rowptr = args->rowptr; - const spm_int_t *colptr = args->colptr; - const spm_complex64_t *values = args->values; - const spm_complex64_t *x = args->x; - spm_int_t incx = args->incx; - spm_complex64_t *y = args->y; - spm_int_t incy = args->incy; - const __conj_fct_t conj_fct = args->conj_fct; + spm_int_t baseval = args->baseval; + spm_int_t nnz = args->nnz; + spm_complex64_t alpha = args->alpha; + const spm_int_t *rowptr = args->rowptr; + const spm_int_t *colptr = args->colptr; + const spm_complex64_t *values = args->values; + const spm_complex64_t *x = args->x; + spm_int_t incx = args->incx; + spm_complex64_t *y = args->y; + spm_int_t incy = args->incy; + const __conj_fct_t conjA_fct = args->conjA_fct; + const __conj_fct_t conjAt_fct = args->conjAt_fct; spm_int_t col, row, i; for( i=0; i<nnz; i++, colptr++, rowptr++, values++ ) @@ -186,11 +190,11 @@ __spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args ) col = *colptr - baseval; if ( row != col ) { - y[ row * incy ] += alpha * ( *values ) * x[ col * incx ]; - y[ col * incy ] += alpha * conj_fct( *values ) * x[ row * incx ]; + y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ]; + y[ col * incy ] += alpha * conjAt_fct( *values ) * x[ row * incx ]; } else { - y[ row * incy ] += alpha * ( *values ) * x[ col * incx ]; + y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ]; } } return SPM_SUCCESS; @@ -199,17 +203,17 @@ __spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args ) static inline int __spm_zmatvec_ge_ijv( const __spm_zmatvec_t *args ) { - spm_int_t baseval = args->baseval; - spm_int_t nnz = args->nnz; - spm_complex64_t alpha = args->alpha; - const spm_int_t *rowptr = args->rowptr; - const spm_int_t *colptr = args->colptr; - const spm_complex64_t *values = args->values; - const spm_complex64_t *x = args->x; - spm_int_t incx = args->incx; - spm_complex64_t *y = args->y; - spm_int_t incy = args->incy; - const __conj_fct_t conj_fct = args->conj_fct; + spm_int_t baseval = args->baseval; + spm_int_t nnz = args->nnz; + spm_complex64_t alpha = args->alpha; + const spm_int_t *rowptr = args->rowptr; + const spm_int_t *colptr = args->colptr; + const spm_complex64_t *values = args->values; + const spm_complex64_t *x = args->x; + spm_int_t incx = args->incx; + spm_complex64_t *y = args->y; + spm_int_t incy = args->incy; + const __conj_fct_t conjA_fct = args->conjA_fct; spm_int_t col, row, i; for( i=0; i<nnz; i++, colptr++, rowptr++, values++ ) @@ -217,7 +221,7 @@ __spm_zmatvec_ge_ijv( const __spm_zmatvec_t *args ) row = *rowptr - baseval; col = *colptr - baseval; - y[ row * incy ] += alpha * conj_fct( *values ) * x[ col * incx ]; + y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ]; } return SPM_SUCCESS; } @@ -245,6 +249,117 @@ __spm_zlascl( spm_complex64_t alpha, #endif + +static inline int +__spm_zmatvec_args_init( __spm_zmatvec_t *args, + spm_side_t side, + spm_trans_t transA, + spm_complex64_t alpha, + const spmatrix_t *A, + const spm_complex64_t *B, + spm_int_t ldb, + spm_complex64_t *C, + spm_int_t ldc ) +{ + spm_int_t incx, incy; + + if ( side == SpmLeft ) { + incx = 1; + incy = 1; + } + else { + incx = ldb; + incy = ldc; + } + + args->follow_x = 0; + args->baseval = spmFindBase( A ); + args->n = A->n; + args->nnz = A->nnz; + args->alpha = alpha; + args->rowptr = A->rowptr; + args->colptr = A->colptr; + args->values = A->values; + args->x = B; + args->incx = incx; + args->y = C; + args->incy = incy; + args->conjA_fct = __fct_id; + args->conjAt_fct = __fct_id; + +#if defined(PRECISION_c) || defined(PRECISION_z) + if ( A->mtxtype != SpmHermitian ) { + if ( transA == SpmConjTrans ) { + args->conjA_fct = __fct_conj; + args->conjAt_fct = __fct_conj; + } + } + else { + if ( transA == SpmTrans ) { + args->conjA_fct = __fct_conj; + args->conjAt_fct = __fct_id; + } + else { + args->conjA_fct = __fct_id; + args->conjAt_fct = __fct_conj; + } + } +#endif + + args->loop_fct = NULL; + + switch( A->fmttype ) { + case SpmCSC: + { + /* Switch pointers and side to get the correct behaviour */ + if ( ((side == SpmLeft) && (transA == SpmNoTrans)) || + ((side == SpmRight) && (transA != SpmNoTrans)) ) + { + args->follow_x = 1; + } + else { + args->follow_x = 0; + } + args->loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csc; + } + break; + case SpmCSR: + { + /* Switch pointers and side to get the correct behaviour */ + if ( ((side == SpmLeft) && (transA != SpmNoTrans)) || + ((side == SpmRight) && (transA == SpmNoTrans)) ) + { + args->follow_x = 1; + } + else { + args->follow_x = 0; + } + args->colptr = A->rowptr; + args->rowptr = A->colptr; + args->loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csr; + } + break; + case SpmIJV: + { + if ( ((side == SpmLeft) && (transA != SpmNoTrans)) || + ((side == SpmRight) && (transA == SpmNoTrans)) ) + { + const __conj_fct_t tmp_fct = args->conjA_fct; + args->conjA_fct = args->conjAt_fct; + args->conjAt_fct = tmp_fct; + args->colptr = A->rowptr; + args->rowptr = A->colptr; + } + args->loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_ijv : __spm_zmatvec_sy_ijv; + } + break; + default: + return SPM_ERR_BADPARAMETER; + } + + return SPM_SUCCESS; +} + /** ******************************************************************************* * @@ -342,7 +457,7 @@ spm_zspmm( spm_side_t side, spm_int_t ldc ) { int rc = SPM_SUCCESS; - spm_int_t M, N, incx, incy, ldx, ldy, r; + spm_int_t M, N, ldx, ldy, r; __spm_zmatvec_t args; if ( transB != SpmNoTrans ) { @@ -355,8 +470,6 @@ spm_zspmm( spm_side_t side, M = A->n; N = K; - incx = 1; - incy = 1; ldx = ldb; ldy = ldc; } @@ -364,8 +477,6 @@ spm_zspmm( spm_side_t side, M = K; N = A->n; - incx = ldb; - incy = ldc; ldx = 1; ldy = 1; } @@ -381,76 +492,8 @@ spm_zspmm( spm_side_t side, return SPM_SUCCESS; } - { - args.follow_x = 0; - args.baseval = spmFindBase( A ); - args.n = A->n; - args.nnz = A->nnz; - args.alpha = alpha; - args.rowptr = A->rowptr; - args.colptr = A->colptr; - args.values = A->values; - args.x = B; - args.incx = incx; - args.y = C; - args.incy = incy; - args.conj_fct = __fct_id; - args.loop_fct = NULL; - } - -#if defined(PRECISION_c) || defined(PRECISION_z) - if ( ( transA == SpmConjTrans ) || - ( A->mtxtype == SpmHermitian ) ) - { - args.conj_fct = __fct_conj; - } -#endif - - switch( A->fmttype ) { - case SpmCSC: - { - /* Switch pointers and side to get the correct behaviour */ - if ( ((side == SpmLeft) && (transA == SpmNoTrans)) || - ((side == SpmRight) && (transA != SpmNoTrans)) ) - { - args.follow_x = 1; - } - else { - args.follow_x = 0; - } - args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csc; - } - break; - case SpmCSR: - { - /* Switch pointers and side to get the correct behaviour */ - if ( ((side == SpmLeft) && (transA != SpmNoTrans)) || - ((side == SpmRight) && (transA == SpmNoTrans)) ) - { - args.follow_x = 1; - } - else { - args.follow_x = 0; - } - args.colptr = A->rowptr; - args.rowptr = A->colptr; - args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csr; - } - break; - case SpmIJV: - { - if ( ((side == SpmLeft) && (transA != SpmNoTrans)) || - ((side == SpmRight) && (transA == SpmNoTrans)) ) - { - args.colptr = A->rowptr; - args.rowptr = A->colptr; - } - args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_ijv : __spm_zmatvec_sy_ijv; - } - break; - default: - return SPM_ERR_BADPARAMETER; - } + __spm_zmatvec_args_init( &args, side, transA, + alpha, A, B, ldb, C, ldc ); for( r=0; (r < N) && (rc == SPM_SUCCESS); r++ ) { args.x = B + r * ldx; @@ -528,59 +571,8 @@ spm_zspmv( spm_trans_t trans, return SPM_SUCCESS; } - { - args.follow_x = 0; - args.baseval = spmFindBase( A ); - args.n = A->n; - args.nnz = A->nnz; - args.alpha = alpha; - args.rowptr = A->rowptr; - args.colptr = A->colptr; - args.values = A->values; - args.x = x; - args.incx = incx; - args.y = y; - args.incy = incy; - args.conj_fct = __fct_id; - args.loop_fct = NULL; - } - -#if defined(PRECISION_c) || defined(PRECISION_z) - if ( ( trans == SpmConjTrans ) || - ( A->mtxtype == SpmHermitian ) ) - { - args.conj_fct = __fct_conj; - } -#endif - - switch( A->fmttype ) { - case SpmCSC: - { - args.follow_x = (trans == SpmNoTrans) ? 1 : 0; - args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csc; - } - break; - case SpmCSR: - { - /* Switch pointers and side to get the correct behaviour */ - args.follow_x = (trans == SpmNoTrans) ? 0 : 1; - args.colptr = A->rowptr; - args.rowptr = A->colptr; - args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csr; - } - break; - case SpmIJV: - { - if ( trans != SpmNoTrans ) { - args.colptr = A->rowptr; - args.rowptr = A->colptr; - } - args.loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_ijv : __spm_zmatvec_sy_ijv; - } - break; - default: - return SPM_ERR_BADPARAMETER; - } + __spm_zmatvec_args_init( &args, SpmLeft, trans, + alpha, A, x, incx, y, incy ); rc = args.loop_fct( &args ); diff --git a/tests/spm_matvec_tests.c b/tests/spm_matvec_tests.c index 2ccd492b949d17475d4468deeea9139ed5e92a59..3bc1e705fe4959f59f1f828712a7bf9cbd35773a 100644 --- a/tests/spm_matvec_tests.c +++ b/tests/spm_matvec_tests.c @@ -97,10 +97,6 @@ int main (int argc, char **argv) { continue; } - if ( (spm.mtxtype != SpmGeneral) && (t != SpmNoTrans) ) - { - continue; - } printf(" Case %s - %s - %d - %s:\n", mtxnames[mtxtype - SpmGeneral], fmtnames[fmttype - SpmCSC],