From caf35f95bd8644b92a9c219b02f9ea701c8bfdcb Mon Sep 17 00:00:00 2001
From: tdelarue <tony.delarue@inria.fr>
Date: Tue, 13 Oct 2020 15:08:48 +0200
Subject: [PATCH] Add ChanckAndCorrect tests

---
 tests/CMakeLists.txt                     |   4 +
 tests/spm_check_and_correct_tests.c      | 229 ++++++++++++++++++++
 tests/spm_dist_check_and_correct_tests.c | 261 +++++++++++++++++++++++
 3 files changed, 494 insertions(+)
 create mode 100644 tests/spm_check_and_correct_tests.c
 create mode 100644 tests/spm_dist_check_and_correct_tests.c

diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 09b01b87..2c973a3c 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -48,6 +48,7 @@ set (TESTS
   spm_dof_norm_tests.c
   spm_dof_matvec_tests.c
   spm_dof_sort_tests.c
+  spm_check_and_correct_tests.c
   )
 
 if ( SPM_WITH_MPI )
@@ -57,6 +58,7 @@ if ( SPM_WITH_MPI )
     spm_dist_genrhs_tests.c
     spm_dist_matvec_tests.c
     spm_dist_sort_tests.c
+    spm_dist_check_and_correct_tests.c
     )
 endif()
 
@@ -71,6 +73,7 @@ set( SPM_TESTS
   spm_convert_tests
   spm_norm_tests
   spm_matvec_tests
+  spm_check_and_correct_tests
   )
 set( SPM_DOF_TESTS
   spm_dof_expand_tests
@@ -84,6 +87,7 @@ set( SPM_MPI_TESTS
   spm_dist_genrhs_tests
   spm_dist_matvec_tests
   spm_dist_sort_tests
+  spm_dist_check_and_correct_tests
   )
 
 # List of run types
diff --git a/tests/spm_check_and_correct_tests.c b/tests/spm_check_and_correct_tests.c
new file mode 100644
index 00000000..b18793aa
--- /dev/null
+++ b/tests/spm_check_and_correct_tests.c
@@ -0,0 +1,229 @@
+/**
+ *
+ * @file spm_check_and_correct_tests.c
+ *
+ * Tests and validate the spmCheckAndCorrect routines.
+ *
+ * @copyright 2015-2020 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                      Univ. Bordeaux. All rights reserved.
+ *
+ * @version 1.0.0
+ * @author Mathieu Faverge
+ * @author Delarue Tony
+ * @date 2020-09-30
+ *
+ **/
+#include <stdint.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <time.h>
+#include <spm_tests.h>
+
+#define PRINT_RES(_ret_)                        \
+    if(_ret_) {                                 \
+        printf("FAILED(%d)\n", _ret_);          \
+        err++;                                  \
+    }                                           \
+    else {                                      \
+        printf("SUCCESS\n");                    \
+    }
+
+static inline int
+spm_check_and_correct_check_merge_duplicate(const spmatrix_t *spm)
+{
+    spm_int_t  i, j;
+    spm_int_t  size, n;
+    spm_int_t *colptr = (spm->fmttype == SpmCSC) ? spm->colptr : spm->rowptr;
+    spm_int_t *rowptr = (spm->fmttype == SpmCSC) ? spm->rowptr : spm->colptr;
+
+    n       = spm->n;
+    for ( j = 0; j < n; j++, colptr++)
+    {
+        size = colptr[1] - colptr[0] - 1;
+        for ( i = 0; i < size; i++, rowptr++ )
+        {
+            /* MergeDuplicate should have been called */
+            if( rowptr[0] == rowptr[1] ) {
+                return 1;
+            }
+        }
+        rowptr++;
+    }
+    return 0;
+}
+
+static inline int
+spm_check_and_correct_check_symmetrize(const spmatrix_t *spm)
+{
+    spm_int_t  i, j, col, row, n, baseval;
+    spm_int_t  found, index, size;
+    spm_int_t *colptr = (spm->fmttype == SpmCSC) ? spm->colptr : spm->rowptr;
+    spm_int_t *coltmp = colptr;
+    spm_int_t *rowptr = (spm->fmttype == SpmCSC) ? spm->rowptr : spm->colptr;
+    spm_int_t *rowtmp = rowptr;
+
+    if ( spm->mtxtype != SpmGeneral ) {
+        return 0;
+    }
+
+    n       = spm->n;
+    baseval = spmFindBase(spm);
+    for ( col = 0; col < n; col++, coltmp++)
+    {
+        for ( i = coltmp[0]; i < coltmp[1]; i++, rowtmp++ )
+        {
+            row = *rowtmp - baseval;
+
+            index = colptr[row] - baseval;
+            size  = colptr[row + 1] - colptr[row];
+            /* Check the symmetry */
+            for ( j = 0; j < size; j++)
+            {
+                found = rowptr[index + j] - baseval;
+                if( found == col ) {
+                    break;
+                }
+                /* We've sort the matrix */
+                if( found > col ) {
+                    return 1;
+                }
+            }
+        }
+    }
+    return 0;
+}
+
+static inline int
+spm_check_and_correct_check(const spmatrix_t *spm)
+{
+    int rc1 = 0, rc2 = 0;
+    int new;
+    spmatrix_t spm_out;
+
+    new = spmCheckAndCorrect( spm, &spm_out );
+    if( new ) {
+        assert(spm_out.fmttype == SpmCSC);
+        /* Sort it */
+        spmSort( &spm_out );
+
+        rc1 = spm_check_and_correct_check_merge_duplicate( &spm_out );
+        rc2 = spm_check_and_correct_check_symmetrize( &spm_out );
+
+        spmExit(&spm_out);
+    }
+
+    return rc1 + rc2;
+}
+
+int main (int argc, char **argv)
+{
+    spmatrix_t    original, *spm;
+    spm_driver_t driver;
+    char *filename;
+    spm_mtxtype_t spmtype, mtxtype;
+    spm_fmttype_t fmttype;
+    int baseval, dof, to_free;
+    int rc = SPM_SUCCESS;
+    int err = 0;
+
+#if defined(SPM_WITH_MPI)
+    MPI_Init( &argc, &argv );
+#endif
+
+    /**
+     * Get options from command line
+     */
+    spmGetOptions( argc, argv,
+                   &driver, &filename );
+
+    rc = spmReadDriver( driver, filename, &original );
+    free(filename);
+
+    if ( rc != SPM_SUCCESS ) {
+        fprintf(stderr, "ERROR: Could not read the file, stop the test !!!\n");
+        return EXIT_FAILURE;
+    }
+
+    spmtype = original.mtxtype;
+    printf(" -- SPM CheckAndCorrect Test --\n");
+
+    for( fmttype=SpmCSC; fmttype<=SpmIJV; fmttype++ ) {
+
+        spmConvert( fmttype, &original );
+
+        for( dof=-1; dof<2; dof++ )
+        {
+            if ( dof >= 0 ) {
+                spm = spmDofExtend( &original, dof, 4 );
+                to_free = 1;
+            }
+            else {
+                spm = malloc(sizeof(spmatrix_t));
+                memcpy( spm, &original, sizeof(spmatrix_t) );
+                to_free = 0;
+            }
+
+            if ( spm == NULL ) {
+                fprintf( stderr, "Issue to extend the matrix\n" );
+                continue;
+            }
+
+            for( baseval=0; baseval<2; baseval++ )
+            {
+                spmBase( spm, baseval );
+
+                for( mtxtype=SpmGeneral; mtxtype<=SpmHermitian; mtxtype++ )
+                {
+                    if ( (mtxtype == SpmHermitian) &&
+                        ( ((spm->flttype != SpmComplex64) && (spm->flttype != SpmComplex32)) ||
+                        (spmtype != SpmHermitian) ) )
+                    {
+                        continue;
+                    }
+                    if ( (mtxtype != SpmGeneral) &&
+                        (spmtype == SpmGeneral) )
+                    {
+                        continue;
+                    }
+                    spm->mtxtype = mtxtype;
+
+                    printf(" Case: %s / %s / %d / %s\n",
+                        fltnames[spm->flttype],
+                        fmtnames[spm->fmttype],
+                        baseval,
+                        mtxnames[mtxtype - SpmGeneral] );
+
+                    rc = spm_check_and_correct_check(spm);
+                    err = (rc == 0) ? err : err+1;
+                    PRINT_RES(rc);
+                }
+            }
+
+            if ( spm != &original ) {
+                if( to_free ){
+                    spmExit( spm  );
+                }
+                free( spm );
+            }
+
+        }
+    }
+    spmExit( &original  );
+
+#if defined(SPM_WITH_MPI)
+    MPI_Finalize();
+#endif
+
+    if( err == 0 ) {
+        printf(" -- All tests PASSED --\n");
+        return EXIT_SUCCESS;
+    }
+    else
+    {
+        printf(" -- %d tests FAILED --\n", err);
+        return EXIT_FAILURE;
+    }
+}
diff --git a/tests/spm_dist_check_and_correct_tests.c b/tests/spm_dist_check_and_correct_tests.c
new file mode 100644
index 00000000..48c463f0
--- /dev/null
+++ b/tests/spm_dist_check_and_correct_tests.c
@@ -0,0 +1,261 @@
+/**
+ *
+ * @file spm_dist_check_and_correct_tests.c
+ *
+ * Tests and validate the spmCheckAndCorrect distributed routines.
+ *
+ * @copyright 2015-2020 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                      Univ. Bordeaux. All rights reserved.
+ *
+ * @version 1.0.0
+ * @author Mathieu Faverge
+ * @author Delarue Tony
+ * @date 2020-07-20
+ *
+ **/
+#include <stdint.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <time.h>
+#include "spm_tests.h"
+
+#if !defined(SPM_WITH_MPI)
+#error "This test should not be compiled in non distributed version"
+#endif
+
+/*------------------------------------------------------------------------
+ *  Check the symmetrization of the solution
+ */
+static inline int
+spm_check_and_correct_check_merge_duplicate(const spmatrix_t *spm)
+{
+    spm_int_t  i, j;
+    spm_int_t  size, n;
+    spm_int_t *colptr = (spm->fmttype == SpmCSC) ? spm->colptr : spm->rowptr;
+    spm_int_t *rowptr = (spm->fmttype == SpmCSC) ? spm->rowptr : spm->colptr;
+
+    n       = spm->n;
+    for ( j = 0; j < n; j++, colptr++)
+    {
+        size = colptr[1] - colptr[0] - 1;
+        for ( i = 0; i < size; i++, rowptr++ )
+        {
+            /* MergeDuplicate should have been called */
+            if( rowptr[0] == rowptr[1] ) {
+                return 1;
+            }
+        }
+        rowptr++;
+    }
+    return 0;
+}
+
+static inline int
+spm_check_and_correct_check_symmetrize(const spmatrix_t *spm)
+{
+    spm_int_t  i, j, col, row, n, baseval;
+    spm_int_t  found, index, size;
+    spm_int_t *colptr = (spm->fmttype == SpmCSC) ? spm->colptr : spm->rowptr;
+    spm_int_t *coltmp = colptr;
+    spm_int_t *rowptr = (spm->fmttype == SpmCSC) ? spm->rowptr : spm->colptr;
+    spm_int_t *rowtmp = rowptr;
+
+    if ( spm->mtxtype != SpmGeneral ) {
+        return 0;
+    }
+
+    n       = spm->n;
+    baseval = spmFindBase(spm);
+    for ( col = 0; col < n; col++, coltmp++)
+    {
+        for ( i = coltmp[0]; i < coltmp[1]; i++, rowtmp++ )
+        {
+            row = *rowtmp - baseval;
+
+            index = colptr[row] - baseval;
+            size  = colptr[row + 1] - colptr[row];
+            /* Check the symmetry */
+            for ( j = 0; j < size; j++)
+            {
+                found = rowptr[index + j] - baseval;
+                if( found == col ) {
+                    break;
+                }
+                /* We've sort the matrix */
+                if( found > col ) {
+                    return 1;
+                }
+            }
+        }
+    }
+    return 0;
+}
+
+int
+spm_dist_check_and_correct_check( const spmatrix_t *dist )
+{
+    spmatrix_t spm_out, *gathered;
+    int rc, rc1, rc2;
+
+    rc = spmCheckAndCorrect( dist, &spm_out );
+
+    rc1 = 0;
+    rc2 = 0;
+    if( rc == 1 ) {
+
+        /* Sort it */
+        spmSort( &spm_out );
+
+        gathered = spmGather( &spm_out, -1 );
+
+        rc1 = spm_check_and_correct_check_merge_duplicate( gathered );
+        rc2 = spm_check_and_correct_check_symmetrize( gathered );
+
+        spmExit(&spm_out);
+        spmExit(gathered);
+        free(gathered);
+    }
+
+    return rc1+rc2;
+}
+
+int main (int argc, char **argv)
+{
+    char         *filename;
+    spmatrix_t    original, *spmdist, *spm;
+    spm_driver_t  driver;
+    int clustnbr = 1;
+    int clustnum = 0;
+    spm_mtxtype_t mtxtype;
+    spm_fmttype_t fmttype;
+    int baseval, distbycol = 1;
+    int rc = SPM_SUCCESS;
+    int err = 0;
+    int dof, dofmax = 4;
+    int to_free = 0;
+
+    MPI_Init( &argc, &argv );
+
+    /**
+     * Get options from command line
+     */
+    spmGetOptions( argc, argv,
+                   &driver, &filename );
+
+    rc = spmReadDriver( driver, filename, &original );
+    free(filename);
+
+    if ( rc != SPM_SUCCESS ) {
+        fprintf(stderr, "ERROR: Could not read the file, stop the test !!!\n");
+        return EXIT_FAILURE;
+    }
+
+    MPI_Comm_size( MPI_COMM_WORLD, &clustnbr );
+    MPI_Comm_rank( MPI_COMM_WORLD, &clustnum );
+
+    spmPrintInfo( &original, stdout );
+
+    if ( clustnum == 0 ) {
+        printf(" -- SPM check_and_correct Test --\n");
+    }
+
+    for( fmttype=SpmCSC; fmttype<=SpmIJV; fmttype++ )
+    {
+        /* This routine only concerns CSC matrices, and CSR2CSC doesn't exist with MPI */
+        if (fmttype == SpmCSR) {
+            continue;
+        }
+
+        if ( spmConvert( fmttype, &original ) != SPM_SUCCESS ) {
+            fprintf( stderr, "Issue to convert to %s format\n", fmtnames[fmttype] );
+            continue;
+        }
+
+        for( dof=-1; dof<2; dof++ )
+        {
+            if ( dof >= 0 ) {
+                spm = spmDofExtend( &original, dof, dofmax );
+                to_free = 1;
+            }
+            else {
+                spm = malloc(sizeof(spmatrix_t));
+                memcpy( spm, &original, sizeof(spmatrix_t) );
+                to_free = 0;
+            }
+
+            if ( spm == NULL ) {
+                fprintf( stderr, "Issue to extend the matrix\n" );
+                continue;
+            }
+
+            spmdist = spmScatter( spm, -1, NULL, distbycol, -1, MPI_COMM_WORLD );
+            if ( spmdist == NULL ) {
+                fprintf( stderr, "Failed to scatter the spm\n" );
+                err++;
+                continue;
+            }
+
+            for( baseval=0; baseval<2; baseval++ )
+            {
+                spmBase( spmdist, baseval );
+
+                for( mtxtype=SpmGeneral; mtxtype<=SpmHermitian; mtxtype++ )
+                {
+                    if ( (mtxtype == SpmHermitian) &&
+                        ( ((original.flttype != SpmComplex64) &&
+                           (original.flttype != SpmComplex32)) ||
+                          (original.mtxtype != SpmHermitian) ) )
+                    {
+                        continue;
+                    }
+
+                    if ( (mtxtype != SpmGeneral) &&
+                         (original.mtxtype == SpmGeneral) )
+                    {
+                        continue;
+                    }
+
+                    spmdist->mtxtype = mtxtype;
+
+                    if ( clustnum == 0 ) {
+                        printf( " Case: %s / %s / %d / %s / %d\n",
+                                fltnames[spmdist->flttype],
+                                fmtnames[spmdist->fmttype], baseval,
+                                mtxnames[mtxtype - SpmGeneral], (int)spm->dof );
+                    }
+
+                    rc = spm_dist_check_and_correct_check( spmdist );
+                    err = (rc != 0) ? err+1 : err;
+                }
+            }
+            spmExit( spmdist );
+            free( spmdist );
+
+            if ( spm != &original ) {
+                if( to_free ){
+                    spmExit( spm  );
+                }
+                free( spm );
+            }
+        }
+    }
+
+    spmExit(&original);
+
+    MPI_Finalize();
+
+    if( err == 0 ) {
+        if(clustnum == 0) {
+            printf(" -- All tests PASSED --\n");
+        }
+        return EXIT_SUCCESS;
+    }
+    else
+    {
+        printf(" -- %d tests FAILED --\n", err);
+        return EXIT_FAILURE;
+    }
+}
-- 
GitLab