diff --git a/spm.c b/spm.c index 4ddac6021ac4d8dc49c7b321ce9b86ca14e0a151..13155d8f6965e76874c26806196426bc42527b60 100644 --- a/spm.c +++ b/spm.c @@ -1035,6 +1035,101 @@ spmMatVec( pastix_trans_t trans, return rc; } +/** + ******************************************************************************* + * + * @brief Compute a matrix-matrix product. + * + * y = alpha * op(A) * B + beta * C + * + * where op(A) is one of: + * + * op( A ) = A or op( A ) = A' or op( A ) = conjg( A' ) + * + * alpha and beta are scalars, and x and y are vectors. + * + ******************************************************************************* + * + * @param[in] trans + * Specifies whether the matrix spm is transposed, not transposed or conjugate transposed: + * - PastixTrans + * - PastixNoTrans + * - PastixConjTrans + * + * @param[in] n + * The number of columns of the matrices B and C. + * + * @param[in] alpha + * alpha specifies the scalar alpha. + * + * @param[in] A + * The square sparse matrix A + * + * @param[in] B + * The matrix B of size ldb-by-n + * + * @param[in] ldb + * The leading dimension of the matrix B. ldb >= A->n + * + * @param[in] beta + * beta specifies the scalar beta. + * + * @param[inout] C + * The matrix C of size ldc-by-n + * + * @param[in] ldc + * The leading dimension of the matrix C. ldc >= A->n + * + ******************************************************************************* + * + * @retval PASTIX_SUCCESS if the y vector has been computed successfully, + * @retval PASTIX_ERR_BADPARAMETER otherwise. + * + *******************************************************************************/ +int +spmMatMat( pastix_trans_t trans, + pastix_int_t n, + const void *alpha, + const pastix_spm_t *A, + const void *B, + pastix_int_t ldb, + const void *beta, + void *C, + pastix_int_t ldc ) +{ + pastix_spm_t *espm = (pastix_spm_t*)A; + int rc = PASTIX_SUCCESS; + + if ( A->fmttype != PastixCSC ) { + return PASTIX_ERR_BADPARAMETER; + } + + if ( A->dof != 1 ) { + espm = spmExpand( A ); + } + switch (A->flttype) { + case PastixFloat: + rc = s_spmCSCMat( trans, alpha, espm, B, beta, C); + break; + case PastixComplex32: + rc = c_spmCSCMat( trans, alpha, espm, B, beta, C); + break; + case PastixComplex64: + rc = z_spmCSCMat( trans, alpha, espm, B, beta, C); + break; + case PastixDouble: + default: + rc = d_spmCSCMat( trans, alpha, espm, B, beta, C); + break; + } + + if ( A != espm ) { + spmExit( espm ); + free(espm); + } + return rc; +} + /** ******************************************************************************* * diff --git a/spm.h b/spm.h index 91df2c9172b613352d98db4a0605cf2019b87313..9686a79fba312694e650a6de164dd32740a62d7e 100644 --- a/spm.h +++ b/spm.h @@ -92,6 +92,10 @@ void spmGenFakeValues( pastix_spm_t *spm ); */ double spmNorm( pastix_normtype_t ntype, const pastix_spm_t *spm ); int spmMatVec( pastix_trans_t trans, const void *alpha, const pastix_spm_t *spm, const void *x, const void *beta, void *y ); +int spmMatMat( pastix_trans_t trans, pastix_int_t n, + const void *alpha, const pastix_spm_t *A, + const void *B, pastix_int_t ldb, + const void *beta, void *C, pastix_int_t ldc ); void spmScalMatrix( double alpha, pastix_spm_t *spm ); void spmScalVector( double alpha, pastix_spm_t *spm, void *x ); void spmScalRHS( pastix_coeftype_t flt, double alpha, pastix_int_t m, pastix_int_t n, void *A, pastix_int_t lda );