From 7fc32871fc9a203c8c3236041da2e575384f9a0b Mon Sep 17 00:00:00 2001
From: Ana Hourcau <ahourcau@sirocco15.plafrim.cluster>
Date: Mon, 19 Aug 2024 10:49:35 +0200
Subject: [PATCH] Preventing the conversion in half precision for diagonal
 tiles

---
 compute/pzgered.c |  54 +++++++++++++++----------
 compute/pzhered.c | 100 ++++++++++++++++++++++++----------------------
 2 files changed, 85 insertions(+), 69 deletions(-)

diff --git a/compute/pzgered.c b/compute/pzgered.c
index c43e9a0e4..e7feeed9d 100644
--- a/compute/pzgered.c
+++ b/compute/pzgered.c
@@ -28,8 +28,10 @@
 #define W( desc, m, n ) (desc), (m), (n)
 
 static inline void
-chameleon_pzgered_frb( cham_uplo_t uplo,
-                       CHAM_desc_t *A, CHAM_desc_t *Wnorm, CHAM_desc_t *Welt,
+chameleon_pzgered_frb( cham_uplo_t       uplo,
+                       CHAM_desc_t      *A,
+                       CHAM_desc_t      *Wnorm,
+                       CHAM_desc_t      *Welt,
                        RUNTIME_option_t *options )
 {
     double alpha = 1.0;
@@ -155,14 +157,17 @@ chameleon_pzgered_frb( cham_uplo_t uplo,
 /**
  *
  */
-void chameleon_pzgered( cham_uplo_t uplo, double prec, CHAM_desc_t *A,
-                         RUNTIME_sequence_t *sequence, RUNTIME_request_t *request )
+void chameleon_pzgered( cham_uplo_t         uplo,
+                        double              prec,
+                        CHAM_desc_t        *A,
+                        RUNTIME_sequence_t *sequence,
+                        RUNTIME_request_t  *request )
 {
     CHAM_context_t *chamctxt;
     RUNTIME_option_t options;
     CHAM_desc_t Wcol;
     CHAM_desc_t Welt;
-    double gnorm, threshold, eps;
+    double gnorm, threshold, eps, eps_diag, threshold_diag;
 
     int workmt, worknt;
     int m, n;
@@ -202,37 +207,36 @@ void chameleon_pzgered( cham_uplo_t uplo, double prec, CHAM_desc_t *A,
     /**
      * Reduce the precision of the tiles if possible
      */
+    eps_diag = CHAMELEON_slamch();
     if ( prec < 0. ) {
-#if !defined(CHAMELEON_SIMULATION)
-        eps = LAPACKE_dlamch_work('e');
-#else
-#if defined(PRECISION_z) || defined(PRECISION_d)
-        eps = 1.e-15;
-#else
-        eps = 1.e-7;
-#endif
-#endif
+        eps = CHAMELEON_dlamch();
     }
     else {
         eps = prec;
     }
     threshold = (eps * gnorm) / (double)(chameleon_min(A->mt, A->nt));
+    threshold_diag = ( eps < eps_diag ) ? threshold : (eps_diag * gnorm) / (double)(chameleon_min(A->mt, A->nt));
 
 #if defined(CHAMELEON_DEBUG_GERED)
     fprintf( stderr,
              "[%2d] The norm of A is:           %e\n"
              "[%2d] The requested precision is: %e\n"
-             "[%2d] The computed threshold is:  %e\n",
+             "[%2d] The computed threshold is:  %e\n"
+             "[%2d] The threshold diag is :     %e\n",
              A->myrank, gnorm,
              A->myrank, eps,
-             A->myrank, threshold );
+             A->myrank, threshold,
+             A->myrank, threshold_diag );
 #endif
-    for(m = 0; m < A->mt; m++) {
+
+    for(m = 0; m < A->mt; m++)
+    {
         int tempmm = ( m == (A->mt-1) ) ? A->m - m * A->mb : A->mb;
         int nmin   = ( uplo == ChamUpper ) ? m                         : 0;
         int nmax   = ( uplo == ChamLower ) ? chameleon_min(m+1, A->nt) : A->nt;
 
-        for(n = nmin; n < nmax; n++) {
+        for(n = nmin; n < nmax; n++)
+        {
             int tempnn = ( n == (A->nt-1) ) ? A->n - n * A->nb : A->nb;
 
             /*
@@ -241,8 +245,14 @@ void chameleon_pzgered( cham_uplo_t uplo, double prec, CHAM_desc_t *A,
              * ||A_{i,j}||_F  < u_{high} * || A ||_F / (nt * u_{low})
              * ||A_{i,j}||_F  < threshold / u_{low}
              */
-            INSERT_TASK_zgered( &options, threshold,
-                                tempmm, tempnn, A( m, n ), W( &Wcol, m, n ) );
+            if ( m == n ) {
+                INSERT_TASK_zgered( &options, threshold_diag,
+                                    tempmm, tempnn, A( m, n ), W( &Wcol, m, n ) );
+            }
+            else {
+                INSERT_TASK_zgered( &options, threshold,
+                                    tempmm, tempnn, A( m, n ), W( &Wcol, m, n ) );
+            }
         }
     }
 
@@ -250,6 +260,6 @@ void chameleon_pzgered( cham_uplo_t uplo, double prec, CHAM_desc_t *A,
     RUNTIME_sequence_wait( chamctxt, sequence );
 
     chameleon_desc_destroy( &Wcol );
-    RUNTIME_options_ws_free(&options);
-    RUNTIME_options_finalize(&options, chamctxt);
+    RUNTIME_options_ws_free( &options );
+    RUNTIME_options_finalize( &options, chamctxt );
 }
diff --git a/compute/pzhered.c b/compute/pzhered.c
index cc32f4243..869d748cf 100644
--- a/compute/pzhered.c
+++ b/compute/pzhered.c
@@ -28,8 +28,11 @@
 #define W(desc, m, n) (desc), (m), (n)
 
 static inline void
-chameleon_pzhered_frb( cham_trans_t trans, cham_uplo_t uplo,
-                       CHAM_desc_t *A, CHAM_desc_t *Wnorm, CHAM_desc_t *Welt,
+chameleon_pzhered_frb( cham_trans_t      trans,
+                       cham_uplo_t       uplo,
+                       CHAM_desc_t      *A,
+                       CHAM_desc_t      *Wnorm,
+                       CHAM_desc_t      *Welt,
                        RUNTIME_option_t *options )
 {
     double alpha = 1.0;
@@ -84,8 +87,7 @@ chameleon_pzhered_frb( cham_trans_t trans, cham_uplo_t uplo,
         {
             int tempnn = (n == (NT - 1)) ? N - n * A->nb : A->nb;
 
-            if (n == m)
-            {
+            if ( n == m ) {
                 if ( trans == ChamConjTrans ) {
                     INSERT_TASK_zhessq(
                         options, ChamEltwise, uplo, tempmm,
@@ -97,8 +99,7 @@ chameleon_pzhered_frb( cham_trans_t trans, cham_uplo_t uplo,
                         A(m, n), W( Wnorm, m, n) );
                 }
             }
-            else
-            {
+            else {
                 INSERT_TASK_zgessq(
                     options, ChamEltwise, tempmm, tempnn,
                     A(m, n), W( Wnorm, m, n ));
@@ -166,7 +167,7 @@ chameleon_pzhered_frb( cham_trans_t trans, cham_uplo_t uplo,
     {
         for (n = 0; n < A->q; n++)
         {
-            if ((m != 0) || (n != 0))
+            if ( ( m != 0 ) || ( n != 0 ) )
             {
                 INSERT_TASK_dlacpy(
                     options,
@@ -180,14 +181,18 @@ chameleon_pzhered_frb( cham_trans_t trans, cham_uplo_t uplo,
 /**
  *
  */
-void chameleon_pzhered( cham_trans_t trans, cham_uplo_t uplo, double prec, CHAM_desc_t *A,
-                        RUNTIME_sequence_t *sequence, RUNTIME_request_t *request )
+void chameleon_pzhered( cham_trans_t        trans,
+                        cham_uplo_t         uplo,
+                        double              prec,
+                        CHAM_desc_t        *A,
+                        RUNTIME_sequence_t *sequence,
+                        RUNTIME_request_t  *request )
 {
     CHAM_context_t *chamctxt;
     RUNTIME_option_t options;
     CHAM_desc_t Wcol;
     CHAM_desc_t Welt;
-    double gnorm, threshold, eps;
+    double gnorm, threshold, eps, eps_diag, threshold_diag;
 
     int workmt, worknt;
     int m, n;
@@ -205,22 +210,22 @@ void chameleon_pzhered( cham_trans_t trans, cham_uplo_t uplo, double prec, CHAM_
     RUNTIME_options_ws_alloc(&options, 1, 0);
 
     /* Matrix to store the norm of each element */
-    chameleon_desc_init(&Wcol, CHAMELEON_MAT_ALLOC_GLOBAL, ChamRealDouble, 2, 1, 2,
-                        A->mt * 2, A->nt, 0, 0, A->mt * 2, A->nt, A->p, A->q,
-                        NULL, NULL, A->get_rankof_init, A->get_rankof_init_arg);
+    chameleon_desc_init( &Wcol, CHAMELEON_MAT_ALLOC_GLOBAL, ChamRealDouble, 2, 1, 2,
+                         A->mt * 2, A->nt, 0, 0, A->mt * 2, A->nt, A->p, A->q,
+                         NULL, NULL, A->get_rankof_init, A->get_rankof_init_arg );
 
     /* Matrix to compute the global frobenius norm */
-    chameleon_desc_init(&Welt, CHAMELEON_MAT_ALLOC_GLOBAL, ChamRealDouble, 2, 1, 2,
-                        workmt * 2, worknt, 0, 0, workmt * 2, worknt, A->p, A->q,
-                        NULL, NULL, NULL, NULL);
+    chameleon_desc_init( &Welt, CHAMELEON_MAT_ALLOC_GLOBAL, ChamRealDouble, 2, 1, 2,
+                         workmt * 2, worknt, 0, 0, workmt * 2, worknt, A->p, A->q,
+                         NULL, NULL, NULL, NULL );
 
     chameleon_pzhered_frb( trans, uplo, A, &Wcol, &Welt, &options );
 
-    CHAMELEON_Desc_Flush(&Wcol, sequence);
-    CHAMELEON_Desc_Flush(&Welt, sequence);
-    CHAMELEON_Desc_Flush(A, sequence);
+    CHAMELEON_Desc_Flush( &Wcol, sequence );
+    CHAMELEON_Desc_Flush( &Welt, sequence );
+    CHAMELEON_Desc_Flush( A,     sequence );
 
-    RUNTIME_sequence_wait(chamctxt, sequence);
+    RUNTIME_sequence_wait( chamctxt, sequence );
 
     gnorm = *((double *)Welt.get_blkaddr(&Welt, A->myrank / A->q, A->myrank % A->q));
     chameleon_desc_destroy(&Welt);
@@ -228,33 +233,28 @@ void chameleon_pzhered( cham_trans_t trans, cham_uplo_t uplo, double prec, CHAM_
     /**
      * Reduce the precision of the tiles if possible
      */
-    if (prec < 0.)
-    {
-#if !defined(CHAMELEON_SIMULATION)
-        eps = LAPACKE_dlamch_work('e');
-#else
-#if defined(PRECISION_z) || defined(PRECISION_d)
-        eps = 1.e-15;
-#else
-        eps = 1.e-7;
-#endif
-#endif
+    eps_diag = CHAMELEON_slamch();
+    if (prec < 0.) {
+        eps = CHAMELEON_dlamch();
     }
-    else
-    {
+    else {
         eps = prec;
     }
     threshold = (eps * gnorm) / (double)(chameleon_min(A->mt, A->nt));
+    threshold_diag = (eps < eps_diag) ? threshold : (eps_diag * gnorm) / (double)(chameleon_min(A->mt, A->nt));
 
 #if defined(CHAMELEON_DEBUG_GERED)
-    fprintf(stderr,
-            "[%2d] The norm of A is:           %e\n"
-            "[%2d] The requested precision is: %e\n"
-            "[%2d] The computed threshold is:  %e\n",
-            A->myrank, gnorm,
-            A->myrank, eps,
-            A->myrank, threshold);
+    fprintf( stderr,
+             "[%2d] The norm of A is:           %e\n"
+             "[%2d] The requested precision is: %e\n"
+             "[%2d] The computed threshold is:  %e\n"
+             "[%2d] The threshold diag is:      %e\n",
+             A->myrank, gnorm,
+             A->myrank, eps,
+             A->myrank, threshold,
+             A->myrank, threshold_diag );
 #endif
+
     for (m = 0; m < A->mt; m++)
     {
         int tempmm = (m == (A->mt - 1)) ? A->m - m * A->mb : A->mb;
@@ -271,15 +271,21 @@ void chameleon_pzhered( cham_trans_t trans, cham_uplo_t uplo, double prec, CHAM_
              * ||A_{i,j}||_F  < u_{high} * || A ||_F / (nt * u_{low})
              * ||A_{i,j}||_F  < threshold / u_{low}
              */
-            INSERT_TASK_zgered( &options, threshold,
-                                tempmm, tempnn, A( m, n ), W( &Wcol, m, n ) );
+            if ( m == n ) {
+                INSERT_TASK_zgered( &options, threshold_diag,
+                                    tempmm, tempnn, A( m, n ), W( &Wcol, m, n ) );
+            }
+            else {
+                INSERT_TASK_zgered( &options, threshold,
+                                    tempmm, tempnn, A( m, n ), W( &Wcol, m, n ) );
+            }
         }
     }
 
-    CHAMELEON_Desc_Flush(A, sequence);
-    RUNTIME_sequence_wait(chamctxt, sequence);
+    CHAMELEON_Desc_Flush( A, sequence );
+    RUNTIME_sequence_wait( chamctxt, sequence );
 
-    chameleon_desc_destroy(&Wcol);
-    RUNTIME_options_ws_free(&options);
-    RUNTIME_options_finalize(&options, chamctxt);
+    chameleon_desc_destroy( &Wcol );
+    RUNTIME_options_ws_free( &options );
+    RUNTIME_options_finalize( &options, chamctxt );
 }
-- 
GitLab