From 8731353a25f43c5502f297a4dc567537c54a2354 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Wed, 28 Jun 2017 17:04:15 +0200
Subject: [PATCH] Make uniform datatypes between spm and pastix modules + minor

---
 spm.c             | 142 ++++++++++++++++++++++++++++++++++++++++++----
 spm.h             |  22 +++++--
 spm_read_driver.c |   2 +-
 3 files changed, 151 insertions(+), 15 deletions(-)

diff --git a/spm.c b/spm.c
index 44adced1..0a563ba2 100644
--- a/spm.c
+++ b/spm.c
@@ -188,6 +188,128 @@ spmUpdateComputedFields( pastix_spm_t *spm )
     spm->gnnzexp = spm->nnzexp;
 }
 
+/**
+ *******************************************************************************
+ *
+ * @brief Init the spm structure.
+ *
+ *******************************************************************************
+ *
+ * @param[inout] spm
+ *          The sparse matrix to init.
+ *
+ *******************************************************************************/
+void
+spm( pastix_spm_t      *spm,
+     pastix_symmetry_t  mtxtype,
+     pastix_coeftype_t  flttype,
+     pastix_fmttype_t   fmttype,
+     pastix_int_t       n,
+     pastix_int_t       nnz,
+     pastix_int_t      *colptr,
+     pastix_int_t      *rowptr,
+     void              *values,
+     pastix_int_t      *loc2glob,
+     pastix_int_t       dof,
+     pastix_layout_t    layout,
+     pastix_int_t      *dofs )
+{
+    spmInit( spm );
+
+    if ( ( mtxtype != PastixGeneral ) &&
+         ( mtxtype != PastixGeneral ) &&
+         ( mtxtype != PastixGeneral ) )
+    {
+        fprintf(stderr, "spm: The sparse matrix type must be PastixGeneral, PastixSymmetric or PastixHermitian\n");
+        return;
+    }
+    spm->mtxtype = mtxtype;
+
+    if ( ( flttype != PastixPattern   ) &&
+         ( flttype != PastixFloat     ) &&
+         ( flttype != PastixDouble    ) &&
+         ( flttype != PastixComplex32 ) &&
+         ( flttype != PastixComplex64 ) )
+    {
+        fprintf(stderr, "spm: The sparse matrix coefficient type must be PastixPattern, PastixFloat, PastixDouble, PastixComplex32, or PastixComplex64\n");
+        return;
+    }
+    spm->flttype = flttype;
+
+    if ( ( fmttype != PastixCSC ) &&
+         ( fmttype != PastixCSR ) &&
+         ( fmttype != PastixIJV ) )
+    {
+        fprintf(stderr, "spm: The sparse matrix format type must be PastixCSC, PastixCSR, or PastixIJV\n");
+        return;
+    }
+    spm->fmttype = fmttype;
+
+    if ( n <= 0 )
+    {
+        fprintf(stderr, "spm: The local matrix size n, must be strictly positive\n");
+        return;
+    }
+    spm->n = n;
+
+    if ( nnz <= 0 )
+    {
+        fprintf(stderr, "spm: The number of non zeros in the local matrix must be strictly positive\n");
+        return;
+    }
+    spm->nnz = nnz;
+
+    if ( colptr == NULL ) {
+        fprintf(stderr, "spm: The colptr array must be provided and of size n+1, if PastixCSC, nnz otherwise\n");
+        return;
+    }
+    spm->colptr   = colptr;
+
+    if ( rowptr == NULL ) {
+        fprintf(stderr, "spm: The rowptr array must be provided and of size n+1, if PastixCSR, nnz otherwise\n");
+        return;
+    }
+    spm->rowptr   = rowptr;
+
+
+    if ( loc2glob != NULL ) {
+        fprintf(stderr, "spm: The distributed interface is not supported for now\n");
+        return;
+    }
+    spm->loc2glob = NULL;
+
+    if ( (flttype != PastixPattern) && (values == NULL) ) {
+        fprintf(stderr, "spm: The values array of size nnz, and of type flttype must be provided\n");
+        return;
+    }
+    spm->values = values;
+
+    spm->dof = dof;
+    spm->layout = PastixColMajor;
+
+    if ( spm->dof != 1 ) {
+        if ( ( layout != PastixColMajor ) &&
+             ( layout != PastixRowMajor ) )
+        {
+            fprintf(stderr, "spm: The sparse matrix layout for multi-dof must be PastixColMajor or PastixRowMajor\n");
+            return;
+        }
+        spm->layout = layout;
+
+        if ( dof < 1 ) {
+            if ( dofs == NULL ) {
+                fprintf(stderr, "spm: The dofs array must be provided when dof < 1\n");
+                return;
+            }
+            spm->dofs = dofs;
+        }
+        else {
+            spm->dofs = NULL;
+        }
+    }
+    spmUpdateComputedFields( spm );
+}
+
 /**
  *******************************************************************************
  *
@@ -826,9 +948,9 @@ spmPrintInfo( const pastix_spm_t* spm, FILE *stream )
     char *mtxtypestr[4] = { "General", "Symmetric", "Hermitian", "Incorrect" };
     char *flttypestr[7] = { "Pattern", "", "Float", "Double", "Complex32", "Complex64", "Incorrect" };
     char *fmttypestr[4] = { "CSC", "CSR", "IJV", "Incorrect" };
-    int  mtxtype = spm->mtxtype - PastixGeneral;
-    int  flttype = spm->flttype - PastixPattern;
-    int  fmttype = spm->fmttype - PastixCSC;
+    int mtxtype = spm->mtxtype - PastixGeneral;
+    int flttype = spm->flttype - PastixPattern;
+    int fmttype = spm->fmttype - PastixCSC;
 
     if (stream == NULL) {
         stream = stdout;
@@ -1078,10 +1200,10 @@ spmMatVec(const pastix_trans_t trans,
  *
  *******************************************************************************/
 int
-spmGenRHS( pastix_rhstype_t type, int nrhs,
+spmGenRHS( pastix_rhstype_t type, pastix_int_t nrhs,
            const pastix_spm_t  *spm,
-           void                *x, int ldx,
-           void                *b, int ldb )
+           void                *x, pastix_int_t ldx,
+           void                *b, pastix_int_t ldb )
 {
     static int (*ptrfunc[4])(pastix_rhstype_t, int,
                              const pastix_spm_t *,
@@ -1143,11 +1265,11 @@ spmGenRHS( pastix_rhstype_t type, int nrhs,
  *
  *******************************************************************************/
 int
-spmCheckAxb( int nrhs,
+spmCheckAxb( pastix_int_t nrhs,
              const pastix_spm_t  *spm,
-                   void *x0, int ldx0,
-                   void *b,  int ldb,
-             const void *x,  int ldx )
+                   void *x0, pastix_int_t ldx0,
+                   void *b,  pastix_int_t ldb,
+             const void *x,  pastix_int_t ldx )
 {
     static int (*ptrfunc[4])(int, const pastix_spm_t *,
                              void *, int, void *, int, const void *, int) =
diff --git a/spm.h b/spm.h
index a7219988..f5d65259 100644
--- a/spm.h
+++ b/spm.h
@@ -42,7 +42,7 @@
  *
  */
 typedef struct pastix_spm_s {
-    int               mtxtype; /**< Matrix structure: PastixGeneral, PastixSymmetric
+    pastix_symmetry_t mtxtype; /**< Matrix structure: PastixGeneral, PastixSymmetric
                                     or PastixHermitian.                                            */
     pastix_coeftype_t flttype; /**< avals datatype: PastixPattern, PastixFloat, PastixDouble,
                                     PastixComplex32 or PastixComplex64                             */
@@ -83,6 +83,20 @@ pastix_int_t  spmFindBase( const pastix_spm_t *spm );
 int           spmConvert( int ofmttype, pastix_spm_t *ospm );
 void          spmUpdateComputedFields( pastix_spm_t *spm );
 
+void          spm( pastix_spm_t      *spm,
+                   pastix_symmetry_t  mtxtype,
+                   pastix_coeftype_t  flttype,
+                   pastix_fmttype_t   fmttype,
+                   pastix_int_t       n,
+                   pastix_int_t       nnz,
+                   pastix_int_t      *colptr,
+                   pastix_int_t      *rowptr,
+                   void              *values,
+                   pastix_int_t      *loc2glob,
+                   pastix_int_t       dof,
+                   pastix_layout_t    layout,
+                   pastix_int_t      *dofs );
+
 /**
  * @}
  * @name SPM BLAS subroutines
@@ -108,8 +122,8 @@ pastix_spm_t *spmCheckAndCorrect( pastix_spm_t *spm );
  * @name SPM subroutines to check factorization/solve
  * @{
  */
-int           spmGenRHS( pastix_rhstype_t type, int nrhs, const pastix_spm_t *spm, void *x, int ldx, void *b, int ldb );
-int           spmCheckAxb( int nrhs, const pastix_spm_t *spm, void *x0, int ldx0, void *b, int ldb, const void *x, int ldx );
+int           spmGenRHS( pastix_rhstype_t type, pastix_int_t nrhs, const pastix_spm_t *spm, void *x, pastix_int_t ldx, void *b, pastix_int_t ldb );
+int           spmCheckAxb( pastix_int_t nrhs, const pastix_spm_t *spm, void *x0, pastix_int_t ldx0, void *b, pastix_int_t ldb, const void *x, pastix_int_t ldx );
 
 /**
  * @}
@@ -135,7 +149,7 @@ int           spmSave( pastix_spm_t *spm, FILE *outfile );
  * @{
  */
 int           spmReadDriver( pastix_driver_t  driver,
-                             char            *filename,
+                             const char      *filename,
                              pastix_spm_t    *spm,
                              MPI_Comm         pastix_comm );
 /**
diff --git a/spm_read_driver.c b/spm_read_driver.c
index 094a8e0c..c8543481 100644
--- a/spm_read_driver.c
+++ b/spm_read_driver.c
@@ -64,7 +64,7 @@
  *******************************************************************************/
 int
 spmReadDriver( pastix_driver_t  driver,
-               char            *filename,
+               const char      *filename,
                pastix_spm_t    *spm,
                MPI_Comm         comm )
 {
-- 
GitLab