From eb1e9ed0f7532b109943e3cb419d8044c59e2b55 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Tue, 22 Dec 2020 21:38:06 +0100
Subject: [PATCH] Fix corner case in C examples

---
 examples/example_drivers.c |  4 +++-
 examples/example_lap1.c    |  4 +++-
 examples/example_lap2.c    |  4 +++-
 examples/example_mdof1.c   |  4 +++-
 examples/example_mdof2.c   |  4 +++-
 src/z_spm_genrhs.c         | 20 ++++++++++++++++----
 6 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/examples/example_drivers.c b/examples/example_drivers.c
index bf72128e..d132166e 100644
--- a/examples/example_drivers.c
+++ b/examples/example_drivers.c
@@ -65,7 +65,9 @@ int main( int argc, char **argv )
     /*
      * Scale the sparse matrix.
      */
-    spmScalMatrix( 1. / norm, &spm );
+    if ( norm > 0. ) {
+        spmScalMatrix( 1. / norm, &spm );
+    }
 
     /*
      * Create a random vector x to test products.
diff --git a/examples/example_lap1.c b/examples/example_lap1.c
index a19e5689..51263da8 100644
--- a/examples/example_lap1.c
+++ b/examples/example_lap1.c
@@ -189,7 +189,9 @@ int main( int argc, char **argv )
     /*
      * Scale the sparse matrix.
      */
-    spmScalMatrix( 1. / norm, &spm );
+    if ( norm > 0. ) {
+        spmScalMatrix( 1. / norm, &spm );
+    }
 
     /*
      * Create a random matrix x to test products (multiple right hand side).
diff --git a/examples/example_lap2.c b/examples/example_lap2.c
index 2be28b49..6a45e8d6 100644
--- a/examples/example_lap2.c
+++ b/examples/example_lap2.c
@@ -190,7 +190,9 @@ int main (int argc, char **argv)
     /*
      * Scale the sparse matrix.
      */
-    spmScalMatrix( 1. / norm, &spm );
+    if ( norm > 0. ) {
+        spmScalMatrix( 1. / norm, &spm );
+    }
 
     /*
      * Create a random matrix x to test products (multiple right hand side).
diff --git a/examples/example_mdof1.c b/examples/example_mdof1.c
index 8b8bcff3..9713fec2 100644
--- a/examples/example_mdof1.c
+++ b/examples/example_mdof1.c
@@ -254,7 +254,9 @@ int main( int argc, char **argv )
     /*
      * Scale the sparse matrix.
      */
-    spmScalMatrix( 1. / norm, &spm );
+    if ( norm > 0. ) {
+        spmScalMatrix( 1. / norm, &spm );
+    }
 
     /*
      * Create a random matrix x to test products (multiple right hand side).
diff --git a/examples/example_mdof2.c b/examples/example_mdof2.c
index 497c6e9d..c25fa626 100644
--- a/examples/example_mdof2.c
+++ b/examples/example_mdof2.c
@@ -217,7 +217,9 @@ int main (int argc, char **argv)
     /*
      * Scale the sparse matrix.
      */
-    spmScalMatrix( 1. / norm, &spm );
+    if ( norm > 0. ) {
+        spmScalMatrix( 1. / norm, &spm );
+    }
 
     /*
      * Create a random matrix x to test products (multiple right hand side).
diff --git a/src/z_spm_genrhs.c b/src/z_spm_genrhs.c
index 72e28d0b..e21fec9b 100644
--- a/src/z_spm_genrhs.c
+++ b/src/z_spm_genrhs.c
@@ -119,6 +119,9 @@ z_spmGenRHS( spm_rhstype_t     type,
     if ( type == SpmRhsRndB ) {
         /* Compute the spm norm to scale the b vector */
         spm_complex64_t norm = z_spmNorm( SpmFrobeniusNorm, spm );
+        if ( norm == 0. ) {
+            norm = 1.;
+        }
         z_spmGenMat( type, nrhs, spm, &norm, 24356, bptr, ldb );
 
         return SPM_SUCCESS;
@@ -228,11 +231,12 @@ z_spmCheckAxb( spm_fixdbl_t eps, int nrhs,
         double norm;
 
         norm  = LAPACKE_zlange( LAPACK_COL_MAJOR, 'I', spm->nexp, 1, zb + i * ldb, ldb );
-        normB = (norm > normB ) ? norm : normB;
+        normB = ( norm > normB ) ? norm : normB;
         norm  = LAPACKE_zlange( LAPACK_COL_MAJOR, 'I', spm->nexp, 1, zx + i * ldx, ldx );
-        normX = (norm > normX ) ? norm : normX;
+        normX = ( norm > normX ) ? norm : normX;
 
         nb2[i] = cblas_dznrm2( spm->nexp, zb + i * ldb, 1 );
+        if ( nb2[i] == 0. ) { nb2[i] = 1.; }
     }
     fprintf( stdout,
              "   || A ||_1                                               %e\n"
@@ -254,9 +258,16 @@ z_spmCheckAxb( spm_fixdbl_t eps, int nrhs,
         double nx   = cblas_dzasum( spm->nexp, zx + i * ldx, 1 );
         double nr   = cblas_dzasum( spm->nexp, zb + i * ldb, 1 );
         double nr2  = cblas_dznrm2( spm->nexp, zb + i * ldb, 1 ) / nb2[i];
-        double back =  ((nr / normA) / nx) / eps;
+        double back =  ( nr / eps );
         int fail = 0;
 
+        if ( normA > 0. ) {
+            nr = nr / normA;
+        }
+        if ( nx > 0. ) {
+            nr = nr / nx;
+        }
+
         normR    = (nr   > normR   ) ? nr   : normR;
         normR2   = (nr2  > normR2  ) ? nr2  : normR2;
         backward = (back > backward) ? back : backward;
@@ -307,7 +318,8 @@ z_spmCheckAxb( spm_fixdbl_t eps, int nrhs,
 
             nr = LAPACKE_zlange( LAPACK_COL_MAJOR, 'I', spm->nexp, 1, zx0, ldx0 );
 
-            forw = (nr / nx0) / eps;
+            forw = nr / eps;
+            if ( nx0 > 0. ) { forw = forw / nx0; }
 
             normX0  = ( nx   > normX0  ) ? nx   : normX0;
             normR   = ( nr   > normR   ) ? nr   : normR;
-- 
GitLab