diff --git a/z_spm.h b/z_spm.h index 3e7507ffff5e41a1df0ce9ee328bd722649660dc..3a2e4e667e00481ebe380032b4e9060ee7b8ca6a 100644 --- a/z_spm.h +++ b/z_spm.h @@ -38,9 +38,10 @@ int z_spmConvertIJV2CSR( pastix_spm_t *spm ); pastix_complex64_t *z_spm2dense( const pastix_spm_t *spm ); /** - * Matrix-Vector product routines + * Matrix-Vector and matrix-matrix product routines */ int z_spmCSCMatVec(const pastix_trans_t trans, const void *alpha, const pastix_spm_t *spm, const void *x, const void *beta, void *y); +int z_spmCSCMatMat(const pastix_trans_t trans, pastix_int_t n, const void *alpha, const pastix_spm_t *A, const void *B, pastix_int_t ldb, const void *beta, void *Cptr, pastix_int_t ldc ); /** * Norm computation routines diff --git a/z_spm_matrixvector.c b/z_spm_matrixvector.c index 7f7ea7e5d1da115245deb336e65afd688941c3c0..133040ca9687e02b86b946197abf8c1636a668ac 100644 --- a/z_spm_matrixvector.c +++ b/z_spm_matrixvector.c @@ -391,3 +391,97 @@ z_spmCSCMatVec(const pastix_trans_t trans, return z_spmGeCSCv( trans, alpha, spm, x, beta, y ); } } + +/** + ******************************************************************************* + * + * @ingroup spm_dev_matvec + * + * @brief Compute a matrix-matrix product. + * + * y = alpha * op(A) * B + beta * C + * + * where op(A) is one of: + * + * op( A ) = A or op( A ) = A' or op( A ) = conjg( A' ) + * + * alpha and beta are scalars, and x and y are vectors. + * + ******************************************************************************* + * + * @param[in] trans + * Specifies whether the matrix spm is transposed, not transposed or conjugate transposed: + * - PastixTrans + * - PastixNoTrans + * - PastixConjTrans + * + * @param[in] n + * The number of columns of the matrices B and C. + * + * @param[in] alpha + * alpha specifies the scalar alpha. + * + * @param[in] A + * The square sparse matrix A + * + * @param[in] B + * The matrix B of size ldb-by-n + * + * @param[in] ldb + * The leading dimension of the matrix B. ldb >= A->n + * + * @param[in] beta + * beta specifies the scalar beta. + * + * @param[inout] C + * The matrix C of size ldc-by-n + * + * @param[in] ldc + * The leading dimension of the matrix C. ldc >= A->n + * + ******************************************************************************* + * + * @retval PASTIX_SUCCESS if the y vector has been computed successfully, + * @retval PASTIX_ERR_BADPARAMETER otherwise. + * + *******************************************************************************/ +int +z_spmCSCMatMat(const pastix_trans_t trans, + pastix_int_t n, + const void *alphaptr, + const pastix_spm_t *A, + const void *Bptr, + pastix_int_t ldb, + const void *betaptr, + void *Cptr, + pastix_int_t ldc ) +{ + const pastix_complex64_t *B = (const pastix_complex64_t*)Bptr; + pastix_complex64_t *C = (pastix_complex64_t*)Cptr; + pastix_complex64_t alpha, beta; + int i, rc = PASTIX_SUCCESS; + + alpha = *((const pastix_complex64_t *)alphaptr); + beta = *((const pastix_complex64_t *)betaptr); + + switch (A->mtxtype) { +#if defined(PRECISION_z) || defined(PRECISION_c) + case PastixHermitian: + for( i=0; i<n; i++ ){ + rc = z_spmHeCSCv( alpha, A, B + i * ldb, beta, C + i *ldc ); + } + break; +#endif + case PastixSymmetric: + for( i=0; i<n; i++ ){ + rc = z_spmSyCSCv( alpha, A, B + i * ldb, beta, C + i *ldc ); + } + break; + case PastixGeneral: + default: + for( i=0; i<n; i++ ){ + rc = z_spmGeCSCv( trans, alpha, A, B + i * ldb, beta, C + i *ldc ); + } + } + return rc; +}