From a5b17df2b0526370583cef91b0d6c6d64f8c6103 Mon Sep 17 00:00:00 2001
From: tdelarue <tony.delarue@inria.fr>
Date: Mon, 7 Sep 2020 16:45:57 +0200
Subject: [PATCH] SpmSort now works with multidof

---
 src/spm.c                  |  62 ++++----
 src/z_spm.c                | 283 ++++++++++++++++++++++++++++++++++++-
 src/z_spm.h                |   1 +
 src/z_spm_expand.c         |   2 +-
 tests/CMakeLists.txt       |   3 +-
 tests/spm_dof_sort_tests.c | 152 ++++++++++++++++++++
 6 files changed, 475 insertions(+), 28 deletions(-)
 create mode 100644 tests/spm_dof_sort_tests.c

diff --git a/src/spm.c b/src/spm.c
index 8c92c26a..a27458ce 100644
--- a/src/spm.c
+++ b/src/spm.c
@@ -540,11 +540,7 @@ spmNorm( spm_normtype_t   ntype,
 /**
  *******************************************************************************
  *
- * @brief Sort the subarray of edges of each vertex in a CSC or CSR format.
- *
- * Nothing is performed if IJV format is used.
- *
- * @warning This function should NOT be called if dof is greater than 1.
+ * @brief Sort the subarray of edges of each vertex.
  *
  *******************************************************************************
  *
@@ -563,27 +559,43 @@ int
 spmSort( spmatrix_t *spm )
 {
     if ( (spm->dof != 1) && (spm->flttype != SpmPattern) ) {
-        assert( 0 );
-        fprintf(stderr, "ERROR: spmSort should not be called with non expanded matrices including values\n");
+        switch (spm->flttype) {
+        case SpmFloat:
+            s_spmSortMultidof( spm );
+            break;
+        case SpmDouble:
+            d_spmSortMultidof( spm );
+            break;
+        case SpmComplex32:
+            c_spmSortMultidof( spm );
+            break;
+        case SpmComplex64:
+            z_spmSortMultidof( spm );
+            break;
+        default:
+            return SPM_ERR_BADPARAMETER;
+        }
     }
-    switch (spm->flttype) {
-    case SpmPattern:
-        p_spmSort( spm );
-        break;
-    case SpmFloat:
-        s_spmSort( spm );
-        break;
-    case SpmDouble:
-        d_spmSort( spm );
-        break;
-    case SpmComplex32:
-        c_spmSort( spm );
-        break;
-    case SpmComplex64:
-        z_spmSort( spm );
-        break;
-    default:
-        return SPM_ERR_BADPARAMETER;
+    else {
+        switch (spm->flttype) {
+        case SpmPattern:
+            p_spmSort( spm );
+            break;
+        case SpmFloat:
+            s_spmSort( spm );
+            break;
+        case SpmDouble:
+            d_spmSort( spm );
+            break;
+        case SpmComplex32:
+            c_spmSort( spm );
+            break;
+        case SpmComplex64:
+            z_spmSort( spm );
+            break;
+        default:
+            return SPM_ERR_BADPARAMETER;
+        }
     }
     return SPM_SUCCESS;
 }
diff --git a/src/z_spm.c b/src/z_spm.c
index cc64cfef..366a84c5 100644
--- a/src/z_spm.c
+++ b/src/z_spm.c
@@ -78,7 +78,7 @@ z_spmSort( spmatrix_t *spm )
             size = rowptr[1] - rowptr[0];
 
 #if defined(PRECISION_p)
-            spmIntSort1Asc1( rowptr, size );
+            spmIntSort1Asc1( colptr, size );
 #else
             sortptr[0] = colptr;
             sortptr[1] = values;
@@ -103,6 +103,287 @@ z_spmSort( spmatrix_t *spm )
     }
 }
 
+/**
+ * @brief Sort subarrays of rowptr and values
+ *
+ * @param[in] rowptr
+ *          The original rowptr.
+ *
+ *  @param[in] rowtmp
+ *          The sorted copy of the rowptr.
+ *
+ * @param[in] values
+ *          The original values array.
+ *
+ * @param[inout] valtmp
+ *          The copy of the valptr to sort.
+ *
+ * @param[in] dofs
+ *          The pointer to the dofs array.
+ *
+ * @param[in] dof
+ *          SPM dof value.
+ *
+ * @param[in] baseval
+ *          SPM baseval.
+ *
+ * @param[in] dofj
+ *          Current colum dof.
+ *
+ * @param[in] size
+ *          Size of the current subarray.
+ *
+ * @return number of value sorted in the value array
+ */
+static inline spm_int_t
+z_spm_sort_multidof_csx_values( const spm_int_t       *rowptr,
+                                const spm_int_t       *rowtmp,
+                                const spm_complex64_t *values,
+                                      spm_complex64_t *valtmp,
+                                const spm_int_t       *dofs,
+                                      spm_int_t        dof,
+                                      spm_int_t        baseval,
+                                      spm_int_t        dofj,
+                                      spm_int_t        size )
+{
+    spm_int_t i, ig, dofi;
+    spm_int_t k = 0;
+    spm_int_t memory, count, added = 0;
+
+    while (k < size)
+    {
+        memory = 0;
+        while ( (k < (size - 1)) && (rowtmp[k] == rowtmp[k+1]) )
+        {
+            memory++;
+            k++;
+        }
+
+        count = 0;
+        for ( i = 0; i < size; i++)
+        {
+            ig   = rowptr[i] - baseval;
+            dofi = (dof > 0) ? dof : dofs[ig+1] - dofs[ig];
+            if ( rowtmp[k] != rowptr[i] ) {
+                count += dofj * dofi;
+                continue;
+            }
+            /*
+             * The matrix isn't merged.
+             * We have to make sure that we don't copy the same information.
+             */
+            memcpy( valtmp + added,
+                    values + count,
+                    dofi * dofj * sizeof(spm_complex64_t) );
+            added += dofi * dofj;
+
+            if ( memory > 0 ) {
+                memory--;
+                continue;
+            }
+
+            k++;
+            break;
+        }
+    }
+    return added;
+}
+
+/**
+ * @brief Sort a IJV matrix.
+ *
+ * @param[in] spm
+ *          Pointer to the spm structure.
+ *
+ * @param[inout] newcol
+ *          The sorted copy of the colptr.
+ *
+ * @param[inout] newrow
+ *          The sorted copy of the rowptr.
+ *
+ * @param[inout] newval
+ *          The copy of the valptr to sort.
+ */
+static inline void
+z_spm_sort_multidof_ijv_values( const spmatrix_t *spm,
+                                spm_int_t        *newcol,
+                                spm_int_t        *newrow,
+                                spm_complex64_t  *newval )
+{
+    spm_int_t       *colptr;
+    spm_int_t       *rowptr;
+    spm_complex64_t *values;
+    spm_int_t       *dofs;
+    spm_int_t        i, ig, jg, dofi, dofj, dof2;
+    spm_int_t        size, baseval;
+    spm_int_t        k = 0;
+    spm_int_t        count, memory = 0;
+
+    values  = spm->values;
+    dofs    = spm->dofs;
+    size    = spm->nnz;
+    baseval = spmFindBase(spm);
+    while (k < size)
+    {
+        while ( (newcol[0] == newcol[1]) && (newrow[0] == newrow[1]) )
+        {
+            newcol++;
+            newrow++;
+            memory++;
+            k++;
+        }
+
+        jg   = *newcol - baseval;
+        dofj = (spm->dof > 0) ? spm->dof : dofs[jg+1] - dofs[jg];
+        ig   = *newrow - baseval;
+        dofi = (spm->dof > 0) ? spm->dof : dofs[ig+1] - dofs[ig];
+        dof2 = dofi * dofj;
+
+        count  = 0;
+        colptr = spm->colptr;
+        rowptr = spm->rowptr;
+        for ( i = 0; i < size; i++, colptr++, rowptr++ )
+        {
+            jg   = *colptr - baseval;
+            dofj = (spm->dof > 0) ? spm->dof : dofs[jg+1] - dofs[jg];
+            ig   = *rowptr - baseval;
+            dofi = (spm->dof > 0) ? spm->dof : dofs[ig+1] - dofs[ig];
+
+            if ( ((*newcol) != (*colptr)) || ((*newrow) != (*rowptr)) ) {
+                count += dofj * dofi;
+                continue;
+            }
+
+            memcpy( newval,
+                    values + count,
+                    dof2 * sizeof(spm_complex64_t) );
+            newval += dof2;
+
+            if( memory > 0 ) {
+                memory--;
+                continue;
+            }
+
+            newcol++;
+            newrow++;
+            k++;
+            break;
+        }
+        assert(memory == 0);
+    }
+}
+
+/**
+ *******************************************************************************
+ *
+ * @ingroup spm_dev_check
+ *
+ * @brief This routine sorts the spm matrix.
+ *
+ * For the CSC and CSR formats, the subarray of edges for each vertex are sorted.
+ * For the IJV format, the edges are sorted first by column indexes, and then
+ * by row indexes. To perform a sort first by row, second by column, please swap
+ * the colptr and rowptr of the structure before calling the subroutine.
+ * This routine is used for multidof matrices. It's way less efficient than the
+ * single dof one.
+ *
+ *******************************************************************************
+ *
+ * @param[inout] spm
+ *          On entry, the pointer to the sparse matrix structure.
+ *          On exit, the same sparse matrix with subarrays of edges sorted by
+ *          ascending order.
+ *
+ *******************************************************************************/
+void
+z_spmSortMultidof( spmatrix_t *spm )
+{
+    spm_int_t       *colptr, *newcol, *coltmp;
+    spm_int_t       *rowptr, *newrow, *rowtmp;
+    spm_complex64_t *values, *newval, *valtmp;
+    spm_int_t        size, n = spm->n;
+
+    newrow = malloc( spm->nnz    * sizeof(spm_int_t) );
+    newval = malloc( spm->nnzexp * sizeof(spm_complex64_t) );
+    values = spm->values;
+
+    if ( spm->fmttype != SpmIJV ) {
+        spm_int_t *loc2glob = spm->loc2glob;
+        spm_int_t *dofs = spm->dofs;
+        spm_int_t  j, jg, dofj, baseval;
+        spm_int_t  added;
+
+        baseval = spmFindBase(spm);
+        rowtmp  = newrow;
+        valtmp  = newval;
+        colptr  = (spm->fmttype == SpmCSC) ? spm->colptr : spm->rowptr;
+        rowptr  = (spm->fmttype == SpmCSC) ? spm->rowptr : spm->colptr;
+
+        memcpy( newrow, rowptr, spm->nnz * sizeof(spm_int_t) );
+        for (j=0; j<n; j++, colptr++, loc2glob++)
+        {
+            size = colptr[1] - colptr[0];
+            jg   = (spm->loc2glob == NULL) ? j : *loc2glob - baseval;
+            dofj = (spm->dof > 0) ? spm->dof : dofs[jg+1] - dofs[jg];
+
+            /* Sort rowptr */
+            spmIntSort1Asc1( rowtmp, size );
+
+            /* Sort values */
+            added = z_spm_sort_multidof_csx_values( rowptr, rowtmp, values, valtmp, dofs,
+                                                    spm->dof, baseval, dofj, size );
+
+            rowptr += size;
+            rowtmp += size;
+            values += added;
+            valtmp += added;
+        }
+
+        if(spm->fmttype == SpmCSC) {
+            memcpy( spm->rowptr, newrow, spm->nnz * sizeof( spm_int_t ) );
+        }
+        else {
+            memcpy( spm->colptr, newrow, spm->nnz * sizeof( spm_int_t ) );
+        }
+    }
+
+    else {
+        void *sortptr[2];
+
+        colptr = spm->colptr;
+        rowptr = spm->rowptr;
+        size   = spm->nnz;
+        newcol = malloc( size * sizeof(spm_int_t) );
+
+        memcpy( newcol, colptr, size * sizeof(spm_int_t) );
+        memcpy( newrow, rowptr, size * sizeof(spm_int_t) );
+
+        sortptr[0] = newcol;
+        sortptr[1] = newrow;
+
+        /* Sort the colptr and the rowptr */
+        spmIntMSortIntAsc( sortptr, size );
+
+        coltmp = newcol;
+        rowtmp = newrow;
+        valtmp = newval;
+
+        /* Sort values */
+        z_spm_sort_multidof_ijv_values( spm, coltmp, rowtmp, valtmp );
+
+        memcpy(spm->colptr, newcol, spm->nnz    * sizeof( spm_int_t ));
+        memcpy(spm->rowptr, newrow, spm->nnz    * sizeof( spm_int_t ));
+
+        free(newcol);
+    }
+
+    memcpy( spm->values, newval, spm->nnzexp * sizeof( spm_complex64_t ) );
+
+    free(newrow);
+    free(newval);
+}
+
+
 /**
  *******************************************************************************
  *
diff --git a/src/z_spm.h b/src/z_spm.h
index 7cd4c301..f2758d4e 100644
--- a/src/z_spm.h
+++ b/src/z_spm.h
@@ -70,6 +70,7 @@ double z_spmNorm( spm_normtype_t ntype, const spmatrix_t *spm );
  * Extra routines
  */
 void      z_spmSort( spmatrix_t *spm );
+void      z_spmSortMultidof( spmatrix_t *spm );
 spm_int_t z_spmMergeDuplicate( spmatrix_t *spm );
 spm_int_t z_spmSymmetrize( spmatrix_t *spm );
 
diff --git a/src/z_spm_expand.c b/src/z_spm_expand.c
index f57aef7b..ad757e3f 100644
--- a/src/z_spm_expand.c
+++ b/src/z_spm_expand.c
@@ -282,7 +282,7 @@ z_spmCSRExpand( const spmatrix_t *spm_in, spmatrix_t *spm_out )
         newrow[1] = newrow[0];
         for(k=oldrow[0]; k<oldrow[1]; k++)
         {
-            jg = oldcol[k-baseval] - baseval;
+            jg   = oldcol[k-baseval] - baseval;
             dofj = (spm_in->dof > 0 ) ? spm_in->dof : dofs[jg+1] - dofs[jg];
             newrow[1] += dofj;
 
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 0040b2ac..7257f221 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -46,6 +46,7 @@ set (TESTS
   spm_dof_expand_tests.c
   spm_dof_norm_tests.c
   spm_dof_matvec_tests.c
+  spm_dof_sort_tests.c
   )
 
 if ( SPM_WITH_MPI )
@@ -67,7 +68,7 @@ endforeach()
 set( SPM_TESTS
   spm_convert_tests spm_norm_tests spm_matvec_tests )
 set( SPM_DOF_TESTS
-  spm_dof_expand_tests spm_dof_norm_tests spm_dof_matvec_tests)
+  spm_dof_expand_tests spm_dof_norm_tests spm_dof_matvec_tests spm_dof_sort_tests)
 set( SPM_MPI_TESTS
   spm_scatter_gather_tests
   spm_dist_norm_tests
diff --git a/tests/spm_dof_sort_tests.c b/tests/spm_dof_sort_tests.c
new file mode 100644
index 00000000..1921bae6
--- /dev/null
+++ b/tests/spm_dof_sort_tests.c
@@ -0,0 +1,152 @@
+/**
+ *
+ * @file spm_dof_sort_tests.c
+ *
+ * Tests and validate the spm_sort routines when the spm_tests.hold constant and/or variadic dofs.
+ *
+ * @copyright 2015-2017 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                      Univ. Bordeaux. All rights reserved.
+ *
+ * @version 6.0.0
+ * @author Mathieu Faverge
+ * @author Delarue Tony
+ * @date 2020-09-07
+ *
+ **/
+#include <stdint.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <time.h>
+#include <spm_tests.h>
+
+#define PRINT_RES(_ret_)                        \
+    if(_ret_) {                                 \
+        printf("FAILED(%d)\n", _ret_);          \
+        err++;                                  \
+    }                                           \
+    else {                                      \
+        printf("SUCCESS\n");                    \
+    }
+
+static inline int
+spm_sort_check( spmatrix_t *spm )
+{
+    spmatrix_t expand1, expand2;
+    int rc;
+
+    spmExpand( spm, &expand1 );
+
+    spmSort( spm );
+    spmSort( &expand1 );
+
+    spmExpand( spm, &expand2 );
+
+    rc = spmCompare( &expand1, &expand2 );
+
+    spmExit( &expand1 );
+    spmExit( &expand2 );
+
+    return rc;
+}
+
+int main (int argc, char **argv)
+{
+    spmatrix_t    original, *spm;
+    spm_driver_t  driver;
+    char         *filename;
+    spm_mtxtype_t spmtype, mtxtype;
+    spm_fmttype_t fmttype;
+    int baseval;
+    int rc = SPM_SUCCESS;
+    int err = 0;
+    int i, dofmax = 4;
+
+#if defined(SPM_WITH_MPI)
+    MPI_Init( &argc, &argv );
+#endif
+
+    /**
+     * Get options from command line
+     */
+    spmGetOptions( argc, argv,
+                   &driver, &filename );
+
+    rc = spmReadDriver( driver, filename, &original );
+    free(filename);
+
+    if ( rc != SPM_SUCCESS ) {
+        fprintf(stderr, "ERROR: Could not read the file, stop the test !!!\n");
+        return EXIT_FAILURE;
+    }
+
+    spmtype = original.mtxtype;
+    printf(" -- SPM Sort Dof Test --\n");
+
+    for( i=0; i<2; i++ )
+    {
+        for( mtxtype=SpmGeneral; mtxtype<=SpmHermitian; mtxtype++ )
+        {
+            if ( (mtxtype == SpmHermitian) &&
+                 ( ((original.flttype != SpmComplex64) && (original.flttype != SpmComplex32)) ||
+                   (spmtype != SpmHermitian) ) )
+            {
+                continue;
+            }
+            if ( (mtxtype != SpmGeneral) &&
+                 (spmtype == SpmGeneral) )
+            {
+                continue;
+            }
+            original.mtxtype = mtxtype;
+
+            for( baseval=0; baseval<2; baseval++ )
+            {
+                spmBase( &original, baseval );
+
+                for( fmttype=SpmCSC; fmttype<=SpmIJV; fmttype++ )
+                {
+                    spmConvert( fmttype, &original );
+                    spm = spmDofExtend( &original, i, dofmax );
+                    if ( spm == NULL ) {
+                        fprintf( stderr, "FAILED to extend matrix\n" );
+                        PRINT_RES(1);
+                        continue;
+                    }
+
+                    printf( " Case: %s / %s / %s / %d / %s\n",
+                            fltnames[spm->flttype],
+                            dofname[i+1],
+                            mtxnames[mtxtype - SpmGeneral],
+                            baseval,
+                            fmtnames[spm->fmttype] );
+
+                    rc = spm_sort_check( spm );
+                    err = (rc == 0) ? err : err + 1;
+                    PRINT_RES(rc);
+
+                    spmExit( spm );
+                    free(spm);
+                    spm = NULL;
+                }
+            }
+        }
+    }
+    spmExit( &original );
+
+#if defined(SPM_WITH_MPI)
+    MPI_Finalize();
+#endif
+
+    if( err == 0 ) {
+        printf(" -- All tests PASSED --\n");
+        return EXIT_SUCCESS;
+    }
+    else
+    {
+        printf(" -- %d tests FAILED --\n", err);
+        return EXIT_FAILURE;
+    }
+}
-- 
GitLab