From 20929b52450da275467af5a875fac39a3a2c1f23 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Tue, 25 Oct 2016 13:57:24 +0200
Subject: [PATCH] Move spmExpand to spm.c

---
 spm.c          | 77 +++++++++++++++++++++++++++++++++++++++++++++++---
 spm_expand.c   | 21 --------------
 z_spm_expand.c |  2 +-
 3 files changed, 74 insertions(+), 26 deletions(-)

diff --git a/spm.c b/spm.c
index 54e3afc7..d8f954ef 100644
--- a/spm.c
+++ b/spm.c
@@ -85,7 +85,7 @@ void
 spmInit( pastix_spm_t *spm )
 {
     spm->mtxtype = PastixGeneral;
-    spm->flttype = PastixComplex64;
+    spm->flttype = PastixDouble;
     spm->fmttype = PastixCSC;
 
     spm->gN   = 0;
@@ -183,7 +183,9 @@ spmBase( pastix_spm_t *spm,
 	return;
 
     n   = spm->n;
-    nnz = spm->colptr[n] - spm->colptr[0];
+    nnz = spm->nnz;
+
+    assert( nnz == (spm->colptr[n] - spm->colptr[0]) );
 
     for (i = 0; i <= n; i++) {
         spm->colptr[i] += baseadj;
@@ -272,6 +274,9 @@ spmFindBase( const pastix_spm_t *spm )
 int
 spmConvert( int ofmttype, pastix_spm_t *spm )
 {
+    if ( spm->dof != 1 ) {
+        spm = spmExpand( spm );
+    }
     if ( conversionTable[spm->fmttype][ofmttype][spm->flttype] ) {
         return conversionTable[spm->fmttype][ofmttype][spm->flttype]( spm );
     }
@@ -331,6 +336,10 @@ spmNorm( int ntype,
 {
     double tmp;
 
+    if ( spm->dof != 1 ) {
+        fprintf(stderr, "WARNING: spm expanded due to non implemented norm for non-expanded spm\n");
+        spm = spmExpand( spm );
+    }
     switch (spm->flttype) {
     case PastixFloat:
         tmp = (double)s_spmNorm( ntype, spm );
@@ -381,6 +390,10 @@ spmNorm( int ntype,
 int
 spmSort( pastix_spm_t *spm )
 {
+    if ( spm->dof != 1 ) {
+        fprintf(stderr, "WARNING: spm expanded due to non implemented sort for non-expanded spm\n");
+        spm = spmExpand( spm );
+    }
     switch (spm->flttype) {
     case PastixPattern:
         p_spmSort( spm );
@@ -433,6 +446,10 @@ spmSort( pastix_spm_t *spm )
 pastix_int_t
 spmMergeDuplicate( pastix_spm_t *spm )
 {
+    if ( spm->dof != 1 ) {
+        fprintf(stderr, "WARNING: spm expanded due to non implemented merge for non-expanded spm\n");
+        spm = spmExpand( spm );
+    }
     switch (spm->flttype) {
     case PastixPattern:
         return p_spmMergeDuplicate( spm );
@@ -484,6 +501,10 @@ spmMergeDuplicate( pastix_spm_t *spm )
 pastix_int_t
 spmSymmetrize( pastix_spm_t *spm )
 {
+    if ( spm->dof != 1 ) {
+        fprintf(stderr, "WARNING: spm expanded due to non implemented symmetrize for non-expanded spm\n");
+        spm = spmExpand( spm );
+    }
     switch (spm->flttype) {
     case PastixPattern:
         return p_spmSymmetrize( spm );
@@ -551,6 +572,11 @@ spmCheckAndCorrect( pastix_spm_t *spm )
     /* PaStiX works on CSC matrices */
     spmConvert( PastixCSC, newspm );
 
+    if ( newspm->dof != 1 ) {
+        fprintf(stderr, "WARNING: newspm expanded due to missing check functions implementations\n");
+        newspm = spmExpand( newspm );
+    }
+
     /* Sort the rowptr for each column */
     spmSort( newspm );
 
@@ -580,7 +606,7 @@ spmCheckAndCorrect( pastix_spm_t *spm )
      * have been made
      */
     if (( spm->fmttype != newspm->fmttype ) ||
-        ( spm->nnz     != newspm->nnz     ) )
+        ( spm->nnzexp  != newspm->nnzexp  ) )
     {
         return newspm;
     }
@@ -633,13 +659,56 @@ spmCopy( const pastix_spm_t *spm )
         memcpy( newspm->loc2glob, spm->loc2glob, spm->n * sizeof(pastix_int_t));
     }
     if(spm->values != NULL) {
-        size_t valsize = spm->nnz * pastix_size_of( spm->flttype );
+        size_t valsize = spm->nnzexp * pastix_size_of( spm->flttype );
         newspm->values = malloc(valsize);
         memcpy( newspm->values, spm->values, valsize);
     }
     return newspm;
 }
 
+/**
+ *******************************************************************************
+ *
+ * @ingroup pastix_spm
+ *
+ * @brief Expand a multi-dof spm matrix into an spm with constant dof to 1.
+ *
+ * Duplicate the spm data structure given as parameter. All new arrays are
+ * allocated and copied from the original matrix. Both matrices need to be
+ * freed.
+ *
+ *******************************************************************************
+ *
+ * @param[in] spm
+ *          The sparse matrix to copy.
+ *
+ *******************************************************************************
+ *
+ * @return
+ *          The copy of the sparse matrix.
+ *
+ *******************************************************************************/
+void
+spmExpand(pastix_spm_t* spm)
+{
+    switch(spm->flttype)
+    {
+    case PastixFloat:
+        s_spmExpand(spm);
+        break;
+    case PastixComplex32:
+        c_spmExpand(spm);
+        break;
+    case PastixComplex64:
+        z_spmExpand(spm);
+        break;
+    case PastixDouble:
+    default:
+        d_spmExpand(spm);
+        break;
+    }
+}
+
 /**
  *******************************************************************************
  *
diff --git a/spm_expand.c b/spm_expand.c
index ac66de23..2957feb9 100644
--- a/spm_expand.c
+++ b/spm_expand.c
@@ -22,27 +22,6 @@
 #include "s_spm.h"
 #include "p_spm.h"
 
-void
-spmExpand(pastix_spm_t* spm)
-{
-    switch(spm->flttype)
-    {
-    case PastixFloat:
-        s_spmExpand(spm);
-        break;
-    case PastixComplex32:
-        c_spmExpand(spm);
-        break;
-    case PastixComplex64:
-        z_spmExpand(spm);
-        break;
-    case PastixDouble:
-    default:
-        d_spmExpand(spm);
-        break;
-    }
-}
-
 void print_tab_int(pastix_int_t* tab, pastix_int_t size)
 {
     int i;
diff --git a/z_spm_expand.c b/z_spm_expand.c
index 3e461691..672b9222 100644
--- a/z_spm_expand.c
+++ b/z_spm_expand.c
@@ -21,7 +21,7 @@
  * TODO: This function is incorrect
  */
 int
-z_spmExpand(pastix_spm_t* spm)
+z_spmExpand(pastix_spm_t *spm)
 {
     pastix_int_t i, col, row, cpt, dofj, dofi, baseval;
     pastix_complex64_t *oldvalptr;
-- 
GitLab