From 0e64bc8b2455a84da029f732e170c12d66a05e8a Mon Sep 17 00:00:00 2001
From: Alycia Lisito <alycia.lisito@inria.fr>
Date: Thu, 13 Feb 2025 17:43:55 +0100
Subject: [PATCH] zgetrf: send Akk before allreduce

---
 compute/pzgetrf.c                      | 58 +++++++++++++++++++-------
 control/descriptor_helpers.c           | 20 +++++++++
 include/chameleon/descriptor_helpers.h |  1 +
 3 files changed, 64 insertions(+), 15 deletions(-)

diff --git a/compute/pzgetrf.c b/compute/pzgetrf.c
index 7081a1f7f..635bbbb84 100644
--- a/compute/pzgetrf.c
+++ b/compute/pzgetrf.c
@@ -565,7 +565,7 @@ chameleon_pzgetrf_panel_update_ws( struct chameleon_pzgetrf_s *ws,
                                    RUNTIME_option_t           *options )
 {
     CHAM_context_t  *chamctxt = chameleon_context_self();
-    int m, n, tempmm, tempkn, q;
+    int m, n, tempmm, tempkn, tempkm, p, q, involved, np;
     int lookahead = chamctxt->lookahead;
     int P         = chameleon_desc_datadist_get_iparam(A, 0);
     int Q         = chameleon_desc_datadist_get_iparam(A, 1);
@@ -610,6 +610,44 @@ chameleon_pzgetrf_panel_update_ws( struct chameleon_pzgetrf_s *ws,
             RUNTIME_data_flush( options->sequence, A(m, k) );
         }
     }
+
+    tempkm = A->get_blkdim( A, k, DIM_m, A->m );
+    np = chameleon_desc_datadist_get_iparam(A, 1) * chameleon_desc_datadist_get_iparam(A, 0);
+#if defined(CHAMELEON_USE_MPI)
+    /* Send Akk for replicated trsm */
+    if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
+        for ( p = 0; p < np; p++ ) {
+            involved = 0;
+            for ( n = k+1; n < A->nt; n++ ) {
+                if ( chameleon_p_involved_in_panelk_2dbc( A, n, p ) ) {
+                    involved = 1;
+                    break;
+                }
+            }
+            if ( involved ) {
+                INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempkn,
+                                    A(k, k), Wu(p, k) );
+            }
+        }
+    }
+    else {
+        involved = 0;
+        for ( n = k+1; n < A->nt; n++ ) {
+            if ( chameleon_involved_in_panelk_2dbc( A, n ) ) {
+                involved = 1;
+                break;
+            }
+        }
+        if ( involved ) {
+            INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempkn,
+                                A(k, k), Wu(A->myrank, k) );
+        }
+    }
+#else
+    INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempkn,
+                        A(k, k), Wu(A->myrank, k) );
+#endif
+    RUNTIME_data_flush( options->sequence, A(k, k) );
 }
 
 static inline void
@@ -635,23 +673,14 @@ chameleon_pzgetrf_panel_update( struct chameleon_pzgetrf_s *ws,
 
     chameleon_pzgetrf_panel_permute_forward( ws, A, ipiv, k, n, options );
 
-    if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
-        for ( p = 0; p < ws->np_involved; p++ ) {
-            INSERT_TASK_ztrsm(
-                options,
-                ChamLeft, ChamLower, ChamNoTrans, ChamUnit,
-                tempkm, tempnn, A->mb,
-                zone, A(k, k),
-                      Wu(ws->proc_involved[p], n) );
-            RUNTIME_data_flush( options->sequence, Wu(ws->proc_involved[p], n) );
-        }
-    }
-    else if ( ws->involved ) {
+#if defined(CHAMELEON_USE_MPI)
+    if ( ws->involved )
+#endif
+    {
         INSERT_TASK_ztrsm(
             options,
             ChamLeft, ChamLower, ChamNoTrans, ChamUnit,
             tempkm, tempnn, A->mb,
-            zone, A(k, k),
             zone, Wu(A->myrank, k),
                   Wu(A->myrank, n) );
     }
@@ -677,7 +706,6 @@ chameleon_pzgetrf_panel_update( struct chameleon_pzgetrf_s *ws,
     }
 
     RUNTIME_data_flush( options->sequence, Wu(A->myrank, n) );
-    RUNTIME_data_flush( options->sequence, A(k, k) );
     RUNTIME_data_flush( options->sequence, A(k, n) );
 }
 
diff --git a/control/descriptor_helpers.c b/control/descriptor_helpers.c
index d5e143063..6a0492111 100644
--- a/control/descriptor_helpers.c
+++ b/control/descriptor_helpers.c
@@ -100,6 +100,26 @@ int chameleon_involved_in_panelk_2dbc( const CHAM_desc_t *A, int k ) {
     return ( myrank % chameleon_desc_datadist_get_iparam(A,1) == k % chameleon_desc_datadist_get_iparam(A,1) );
 }
 
+/**
+ * @brief Test if the MPI process p is involved in the panel k for 2DBC distributions.
+ *
+ * @param[in] A
+ *        The matrix descriptor.
+ *
+ * @param[in] k
+ *        The index of the panel to test.
+ *
+ * @param[in] p
+ *        The rank of the MPI process.
+ *
+ * @return 1 if the current MPI process contributes to the panel k.
+ *         0 if the current MPI process doesn't contribute to the panel k.
+ *
+ */
+int chameleon_p_involved_in_panelk_2dbc( const CHAM_desc_t *A, int k, int p ) {
+    return ( p % chameleon_desc_datadist_get_iparam(A,1) == k % chameleon_desc_datadist_get_iparam(A,1) );
+}
+
 /**
  * @brief Test if the current MPI process is involved in the panel k for 2DBC distributions.
  *
diff --git a/include/chameleon/descriptor_helpers.h b/include/chameleon/descriptor_helpers.h
index 9e60ef27d..f8caf5080 100644
--- a/include/chameleon/descriptor_helpers.h
+++ b/include/chameleon/descriptor_helpers.h
@@ -64,6 +64,7 @@ int chameleon_getrankof_custom        ( const CHAM_desc_t *A, int m, int n );
  */
 
 int chameleon_involved_in_panelk_2dbc( const CHAM_desc_t *A, int An );
+int chameleon_p_involved_in_panelk_2dbc( const CHAM_desc_t *A, int k, int p );
 void chameleon_get_proc_involved_in_panelk_2dbc( const CHAM_desc_t *A,
                                                  int                k,
                                                  int                n,
-- 
GitLab