From 470e5c5fd3f1e6a68c1feb5079dbc4a0c3572113 Mon Sep 17 00:00:00 2001
From: Alycia Lisito <alycia.lisito@inria.fr>
Date: Tue, 7 Jan 2025 12:28:27 +0100
Subject: [PATCH] zgetrf: clean permutation

---
 compute/pzgetrf.c | 160 ++++++++++++++++++++++++++--------------------
 1 file changed, 89 insertions(+), 71 deletions(-)

diff --git a/compute/pzgetrf.c b/compute/pzgetrf.c
index 98f9d0470..7081a1f7f 100644
--- a/compute/pzgetrf.c
+++ b/compute/pzgetrf.c
@@ -338,10 +338,12 @@ chameleon_pzgetrf_panel_facto( struct chameleon_pzgetrf_s *ws,
                                int                         k,
                                RUNTIME_option_t           *options )
 {
+#if defined(CHAMELEON_USE_MPI)
     chameleon_get_proc_involved_in_panelk_2dbc( A, k, k, ws );
     if ( !ws->involved ) {
         return;
     }
+#endif
 
     /* TODO: Should be replaced by a function pointer */
     switch( ws->alg ) {
@@ -392,19 +394,6 @@ chameleon_pzgetrf_panel_permute( struct chameleon_pzgetrf_s *ws,
         int tempkm, tempkn, tempnn, minmn;
         int withlacpy;
 
-        chameleon_get_proc_involved_in_panelk_2dbc( A, k, n, ws );
-        if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
-            INSERT_TASK_zperm_allreduce_send_perm( options, ipiv, k, A->myrank, ws->np_involved, ws->proc_involved );
-            INSERT_TASK_zperm_allreduce_send_invp( options, ipiv, k, A, k, n );
-        }
-        if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
-            INSERT_TASK_zperm_allreduce_send_A( options, A, k, n, A->myrank, ws->np_involved, ws->proc_involved );
-        }
-
-        if ( !ws->involved ) {
-            return;
-        }
-
         tempkm = A->get_blkdim( A, k, DIM_m, A->m );
         tempkn = A->get_blkdim( A, k, DIM_n, A->n );
         tempnn = A->get_blkdim( A, n, DIM_n, A->n );
@@ -457,19 +446,6 @@ chameleon_pzgetrf_panel_permute_batched( struct chameleon_pzgetrf_s *ws,
         int tempkm, tempkn, tempnn, minmn;
         int withlacpy;
 
-        chameleon_get_proc_involved_in_panelk_2dbc( A, k, n, ws );
-        if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
-            INSERT_TASK_zperm_allreduce_send_perm( options, ipiv, k, A->myrank, ws->np_involved, ws->proc_involved );
-            INSERT_TASK_zperm_allreduce_send_invp( options, ipiv, k, A, k, n );
-        }
-        if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
-            INSERT_TASK_zperm_allreduce_send_A( options, A, k, n, A->myrank, ws->np_involved, ws->proc_involved );
-        }
-
-        if ( !ws->involved ) {
-            return;
-        }
-
         void **clargs = malloc( sizeof(char *) );
         *clargs = NULL;
 
@@ -508,6 +484,80 @@ chameleon_pzgetrf_panel_permute_batched( struct chameleon_pzgetrf_s *ws,
     }
 }
 
+static inline void
+chameleon_pzgetrf_panel_permute_forward( struct chameleon_pzgetrf_s *ws,
+                                         CHAM_desc_t                *A,
+                                         CHAM_ipiv_t                *ipiv,
+                                         int                         k,
+                                         int                         n,
+                                         RUNTIME_option_t           *options )
+{
+#if defined(CHAMELEON_USE_MPI)
+    chameleon_get_proc_involved_in_panelk_2dbc( A, k, n, ws );
+    if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
+        INSERT_TASK_zperm_allreduce_send_perm( options, ipiv, k, A->myrank, ws->np_involved, ws->proc_involved );
+        INSERT_TASK_zperm_allreduce_send_invp( options, ipiv, k, A, k, n );
+    }
+    if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
+        INSERT_TASK_zperm_allreduce_send_A( options, A, k, n, A->myrank, ws->np_involved, ws->proc_involved );
+    }
+
+    if ( !ws->involved ) {
+        return;
+    }
+#endif
+
+    if ( ws->batch_size_swap > 0 ) {
+        chameleon_pzgetrf_panel_permute_batched( ws, A, ipiv, k, n, options );
+    }
+    else {
+        chameleon_pzgetrf_panel_permute( ws, A, ipiv, k, n, options );
+    }
+}
+
+static inline void
+chameleon_pzgetrf_panel_permute_backward( struct chameleon_pzgetrf_s *ws,
+                                          CHAM_desc_t                *A,
+                                          CHAM_ipiv_t                *ipiv,
+                                          int                         k,
+                                          int                         n,
+                                          RUNTIME_option_t           *options,
+                                          RUNTIME_sequence_t         *sequence )
+{
+    int tempkm, tempnn;
+
+#if defined(CHAMELEON_USE_MPI)
+    chameleon_get_proc_involved_in_panelk_2dbc( A, k, n, ws );
+    if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
+        INSERT_TASK_zperm_allreduce_send_perm( options, ipiv, k, A->myrank, ws->np_involved, ws->proc_involved );
+        INSERT_TASK_zperm_allreduce_send_invp( options, ipiv, k, A, k, n );
+    }
+    if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
+        INSERT_TASK_zperm_allreduce_send_A( options, A, k, n, A->myrank, ws->np_involved, ws->proc_involved );
+    }
+
+    if ( !ws->involved ) {
+        return;
+    }
+#endif
+
+    if ( ws->batch_size_swap > 0 ) {
+        chameleon_pzgetrf_panel_permute_batched( ws, A, ipiv, k, n, options );
+    }
+    else {
+        chameleon_pzgetrf_panel_permute( ws, A, ipiv, k, n, options );
+    }
+
+    if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
+
+        tempkm = A->get_blkdim( A, k, DIM_m, A->m );
+        tempnn = A->get_blkdim( A, n, DIM_n, A->n );
+        INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
+                            Wu(A->myrank, n), A(k, n) );
+        RUNTIME_data_flush( sequence, A(k, n) );
+    }
+}
+
 static inline void
 chameleon_pzgetrf_panel_update_ws( struct chameleon_pzgetrf_s *ws,
                                    CHAM_desc_t                *A,
@@ -515,7 +565,7 @@ chameleon_pzgetrf_panel_update_ws( struct chameleon_pzgetrf_s *ws,
                                    RUNTIME_option_t           *options )
 {
     CHAM_context_t  *chamctxt = chameleon_context_self();
-    int m, tempmm, tempkn, q;
+    int m, n, tempmm, tempkn, q;
     int lookahead = chamctxt->lookahead;
     int P         = chameleon_desc_datadist_get_iparam(A, 0);
     int Q         = chameleon_desc_datadist_get_iparam(A, 1);
@@ -583,12 +633,7 @@ chameleon_pzgetrf_panel_update( struct chameleon_pzgetrf_s *ws,
     tempkm = A->get_blkdim( A, k, DIM_m, A->m );
     tempnn = A->get_blkdim( A, n, DIM_n, A->n );
 
-    if ( ws->batch_size_swap > 0 ) {
-        chameleon_pzgetrf_panel_permute_batched( ws, A, ipiv, k, n, options );
-    }
-    else {
-        chameleon_pzgetrf_panel_permute( ws, A, ipiv, k, n, options );
-    }
+    chameleon_pzgetrf_panel_permute_forward( ws, A, ipiv, k, n, options );
 
     if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
         for ( p = 0; p < ws->np_involved; p++ ) {
@@ -607,6 +652,7 @@ chameleon_pzgetrf_panel_update( struct chameleon_pzgetrf_s *ws,
             ChamLeft, ChamLower, ChamNoTrans, ChamUnit,
             tempkm, tempnn, A->mb,
             zone, A(k, k),
+            zone, Wu(A->myrank, k),
                   Wu(A->myrank, n) );
     }
 
@@ -682,54 +728,26 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
         }
 
         /* Flush panel k */
-        for (m = k; m < A->mt; m++) {
+        for (m = k+1; m < A->mt; m++) {
             RUNTIME_data_flush( sequence, A(m, k) );
         }
+        RUNTIME_data_flush( sequence, Wu(A->myrank, k) );
 
         RUNTIME_iteration_pop( chamctxt );
     }
     CHAMELEON_Desc_Flush( &(ws->Wl), sequence );
 
     /* Backward pivoting */
-    if ( ws->batch_size > 0 ) {
-        for (k = 1; k < min_mnt; k++) {
-            for (n = 0; n < k; n++) {
-                if ( chameleon_involved_in_panelk_2dbc( A, k ) ||
-                    chameleon_involved_in_panelk_2dbc( A, n ) )
-                {
-                    chameleon_pzgetrf_panel_permute_batched( ws, A, IPIV, k, n, &options );
-                    if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
-                        tempkm = A->get_blkdim( A, k, DIM_m, A->m );
-                        tempnn = A->get_blkdim( A, n, DIM_n, A->n );
-                        INSERT_TASK_zlacpy( &options, ChamUpperLower, tempkm, tempnn,
-                                            Wu(A->myrank, n), A(k, n) );
-                        RUNTIME_data_flush( sequence, A(k, n) );
-                    }
-                }
-                RUNTIME_data_flush( sequence, Wu(A->myrank, n) );
-            }
-            RUNTIME_perm_flushk( sequence, IPIV, k );
-        }
-    }
-    else {
-        for (k = 1; k < min_mnt; k++) {
-            for (n = 0; n < k; n++) {
-                if ( chameleon_involved_in_panelk_2dbc( A, k ) ||
-                    chameleon_involved_in_panelk_2dbc( A, n ) )
-                {
-                    chameleon_pzgetrf_panel_permute( ws, A, IPIV, k, n, &options );
-                    if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
-                        tempkm = A->get_blkdim( A, k, DIM_m, A->m );
-                        tempnn = A->get_blkdim( A, n, DIM_n, A->n );
-                        INSERT_TASK_zlacpy( &options, ChamUpperLower, tempkm, tempnn,
-                                            Wu(A->myrank, n), A(k, n) );
-                        RUNTIME_data_flush( sequence, A(k, n) );
-                    }
-                }
-                RUNTIME_data_flush( sequence, Wu(A->myrank, n) );
+    for (k = 1; k < min_mnt; k++) {
+        for (n = 0; n < k; n++) {
+            if ( chameleon_involved_in_panelk_2dbc( A, k ) ||
+                 chameleon_involved_in_panelk_2dbc( A, n ) )
+            {
+                chameleon_pzgetrf_panel_permute_backward( ws, A, IPIV, k, n, &options, sequence );
             }
-            RUNTIME_perm_flushk( sequence, IPIV, k );
+            RUNTIME_data_flush( sequence, Wu(A->myrank, n) );
         }
+        RUNTIME_perm_flushk( sequence, IPIV, k );
     }
     CHAMELEON_Desc_Flush( &(ws->Wu), sequence );
 
-- 
GitLab