From 8731353a25f43c5502f297a4dc567537c54a2354 Mon Sep 17 00:00:00 2001 From: Mathieu Faverge <mathieu.faverge@inria.fr> Date: Wed, 28 Jun 2017 17:04:15 +0200 Subject: [PATCH] Make uniform datatypes between spm and pastix modules + minor --- spm.c | 142 ++++++++++++++++++++++++++++++++++++++++++---- spm.h | 22 +++++-- spm_read_driver.c | 2 +- 3 files changed, 151 insertions(+), 15 deletions(-) diff --git a/spm.c b/spm.c index 44adced1..0a563ba2 100644 --- a/spm.c +++ b/spm.c @@ -188,6 +188,128 @@ spmUpdateComputedFields( pastix_spm_t *spm ) spm->gnnzexp = spm->nnzexp; } +/** + ******************************************************************************* + * + * @brief Init the spm structure. + * + ******************************************************************************* + * + * @param[inout] spm + * The sparse matrix to init. + * + *******************************************************************************/ +void +spm( pastix_spm_t *spm, + pastix_symmetry_t mtxtype, + pastix_coeftype_t flttype, + pastix_fmttype_t fmttype, + pastix_int_t n, + pastix_int_t nnz, + pastix_int_t *colptr, + pastix_int_t *rowptr, + void *values, + pastix_int_t *loc2glob, + pastix_int_t dof, + pastix_layout_t layout, + pastix_int_t *dofs ) +{ + spmInit( spm ); + + if ( ( mtxtype != PastixGeneral ) && + ( mtxtype != PastixGeneral ) && + ( mtxtype != PastixGeneral ) ) + { + fprintf(stderr, "spm: The sparse matrix type must be PastixGeneral, PastixSymmetric or PastixHermitian\n"); + return; + } + spm->mtxtype = mtxtype; + + if ( ( flttype != PastixPattern ) && + ( flttype != PastixFloat ) && + ( flttype != PastixDouble ) && + ( flttype != PastixComplex32 ) && + ( flttype != PastixComplex64 ) ) + { + fprintf(stderr, "spm: The sparse matrix coefficient type must be PastixPattern, PastixFloat, PastixDouble, PastixComplex32, or PastixComplex64\n"); + return; + } + spm->flttype = flttype; + + if ( ( fmttype != PastixCSC ) && + ( fmttype != PastixCSR ) && + ( fmttype != PastixIJV ) ) + { + fprintf(stderr, "spm: The sparse matrix format type must be PastixCSC, PastixCSR, or PastixIJV\n"); + return; + } + spm->fmttype = fmttype; + + if ( n <= 0 ) + { + fprintf(stderr, "spm: The local matrix size n, must be strictly positive\n"); + return; + } + spm->n = n; + + if ( nnz <= 0 ) + { + fprintf(stderr, "spm: The number of non zeros in the local matrix must be strictly positive\n"); + return; + } + spm->nnz = nnz; + + if ( colptr == NULL ) { + fprintf(stderr, "spm: The colptr array must be provided and of size n+1, if PastixCSC, nnz otherwise\n"); + return; + } + spm->colptr = colptr; + + if ( rowptr == NULL ) { + fprintf(stderr, "spm: The rowptr array must be provided and of size n+1, if PastixCSR, nnz otherwise\n"); + return; + } + spm->rowptr = rowptr; + + + if ( loc2glob != NULL ) { + fprintf(stderr, "spm: The distributed interface is not supported for now\n"); + return; + } + spm->loc2glob = NULL; + + if ( (flttype != PastixPattern) && (values == NULL) ) { + fprintf(stderr, "spm: The values array of size nnz, and of type flttype must be provided\n"); + return; + } + spm->values = values; + + spm->dof = dof; + spm->layout = PastixColMajor; + + if ( spm->dof != 1 ) { + if ( ( layout != PastixColMajor ) && + ( layout != PastixRowMajor ) ) + { + fprintf(stderr, "spm: The sparse matrix layout for multi-dof must be PastixColMajor or PastixRowMajor\n"); + return; + } + spm->layout = layout; + + if ( dof < 1 ) { + if ( dofs == NULL ) { + fprintf(stderr, "spm: The dofs array must be provided when dof < 1\n"); + return; + } + spm->dofs = dofs; + } + else { + spm->dofs = NULL; + } + } + spmUpdateComputedFields( spm ); +} + /** ******************************************************************************* * @@ -826,9 +948,9 @@ spmPrintInfo( const pastix_spm_t* spm, FILE *stream ) char *mtxtypestr[4] = { "General", "Symmetric", "Hermitian", "Incorrect" }; char *flttypestr[7] = { "Pattern", "", "Float", "Double", "Complex32", "Complex64", "Incorrect" }; char *fmttypestr[4] = { "CSC", "CSR", "IJV", "Incorrect" }; - int mtxtype = spm->mtxtype - PastixGeneral; - int flttype = spm->flttype - PastixPattern; - int fmttype = spm->fmttype - PastixCSC; + int mtxtype = spm->mtxtype - PastixGeneral; + int flttype = spm->flttype - PastixPattern; + int fmttype = spm->fmttype - PastixCSC; if (stream == NULL) { stream = stdout; @@ -1078,10 +1200,10 @@ spmMatVec(const pastix_trans_t trans, * *******************************************************************************/ int -spmGenRHS( pastix_rhstype_t type, int nrhs, +spmGenRHS( pastix_rhstype_t type, pastix_int_t nrhs, const pastix_spm_t *spm, - void *x, int ldx, - void *b, int ldb ) + void *x, pastix_int_t ldx, + void *b, pastix_int_t ldb ) { static int (*ptrfunc[4])(pastix_rhstype_t, int, const pastix_spm_t *, @@ -1143,11 +1265,11 @@ spmGenRHS( pastix_rhstype_t type, int nrhs, * *******************************************************************************/ int -spmCheckAxb( int nrhs, +spmCheckAxb( pastix_int_t nrhs, const pastix_spm_t *spm, - void *x0, int ldx0, - void *b, int ldb, - const void *x, int ldx ) + void *x0, pastix_int_t ldx0, + void *b, pastix_int_t ldb, + const void *x, pastix_int_t ldx ) { static int (*ptrfunc[4])(int, const pastix_spm_t *, void *, int, void *, int, const void *, int) = diff --git a/spm.h b/spm.h index a7219988..f5d65259 100644 --- a/spm.h +++ b/spm.h @@ -42,7 +42,7 @@ * */ typedef struct pastix_spm_s { - int mtxtype; /**< Matrix structure: PastixGeneral, PastixSymmetric + pastix_symmetry_t mtxtype; /**< Matrix structure: PastixGeneral, PastixSymmetric or PastixHermitian. */ pastix_coeftype_t flttype; /**< avals datatype: PastixPattern, PastixFloat, PastixDouble, PastixComplex32 or PastixComplex64 */ @@ -83,6 +83,20 @@ pastix_int_t spmFindBase( const pastix_spm_t *spm ); int spmConvert( int ofmttype, pastix_spm_t *ospm ); void spmUpdateComputedFields( pastix_spm_t *spm ); +void spm( pastix_spm_t *spm, + pastix_symmetry_t mtxtype, + pastix_coeftype_t flttype, + pastix_fmttype_t fmttype, + pastix_int_t n, + pastix_int_t nnz, + pastix_int_t *colptr, + pastix_int_t *rowptr, + void *values, + pastix_int_t *loc2glob, + pastix_int_t dof, + pastix_layout_t layout, + pastix_int_t *dofs ); + /** * @} * @name SPM BLAS subroutines @@ -108,8 +122,8 @@ pastix_spm_t *spmCheckAndCorrect( pastix_spm_t *spm ); * @name SPM subroutines to check factorization/solve * @{ */ -int spmGenRHS( pastix_rhstype_t type, int nrhs, const pastix_spm_t *spm, void *x, int ldx, void *b, int ldb ); -int spmCheckAxb( int nrhs, const pastix_spm_t *spm, void *x0, int ldx0, void *b, int ldb, const void *x, int ldx ); +int spmGenRHS( pastix_rhstype_t type, pastix_int_t nrhs, const pastix_spm_t *spm, void *x, pastix_int_t ldx, void *b, pastix_int_t ldb ); +int spmCheckAxb( pastix_int_t nrhs, const pastix_spm_t *spm, void *x0, pastix_int_t ldx0, void *b, pastix_int_t ldb, const void *x, pastix_int_t ldx ); /** * @} @@ -135,7 +149,7 @@ int spmSave( pastix_spm_t *spm, FILE *outfile ); * @{ */ int spmReadDriver( pastix_driver_t driver, - char *filename, + const char *filename, pastix_spm_t *spm, MPI_Comm pastix_comm ); /** diff --git a/spm_read_driver.c b/spm_read_driver.c index 094a8e0c..c8543481 100644 --- a/spm_read_driver.c +++ b/spm_read_driver.c @@ -64,7 +64,7 @@ *******************************************************************************/ int spmReadDriver( pastix_driver_t driver, - char *filename, + const char *filename, pastix_spm_t *spm, MPI_Comm comm ) { -- GitLab