From 488866d4968e6d2b0bf44eefe5a8004684f27638 Mon Sep 17 00:00:00 2001
From: Alycia Lisito <alycia.lisito@inria.fr>
Date: Mon, 21 Oct 2024 14:03:18 +0200
Subject: [PATCH] zgetrf: zperm allreduce MPI in task prepare codelet

---
 compute/pzgetrf.c                             |  5 +-
 include/chameleon/tasks_z.h                   | 26 +++----
 .../openmp/codelets/codelet_zperm_allreduce.c | 12 +--
 .../parsec/codelets/codelet_zperm_allreduce.c | 12 +--
 .../quark/codelets/codelet_zperm_allreduce.c  | 12 +--
 .../starpu/codelets/codelet_zperm_allreduce.c | 74 ++++++++++++-------
 6 files changed, 80 insertions(+), 61 deletions(-)

diff --git a/compute/pzgetrf.c b/compute/pzgetrf.c
index 4347f5710..ca7f4d120 100644
--- a/compute/pzgetrf.c
+++ b/compute/pzgetrf.c
@@ -433,8 +433,7 @@ chameleon_pzgetrf_panel_permute( struct chameleon_pzgetrf_s *ws,
                                     ipiv, k, A(k, n), A(m, n) );
         }
 
-        INSERT_TASK_zperm_allreduce( options, A, ipiv, k, k, n,
-                                     Wu(A->myrank, n), ws );
+        INSERT_TASK_zperm_allreduce( options, A, Wu(A->myrank, n), ipiv, k, k, n, ws );
     }
     break;
     default:
@@ -499,7 +498,7 @@ chameleon_pzgetrf_panel_permute_batched( struct chameleon_pzgetrf_s *ws,
         }
         INSERT_TASK_zlaswp_batched_flush( options, ipiv, k, A(k, n), Wu(A->myrank, n), clargs );
 
-        INSERT_TASK_zperm_allreduce( options, A, ipiv, k, k, n, Wu(A->myrank, n), ws );
+        INSERT_TASK_zperm_allreduce( options, A, Wu(A->myrank, n), ipiv, k, k, n, ws );
 
         free( clargs );
     }
diff --git a/include/chameleon/tasks_z.h b/include/chameleon/tasks_z.h
index 402c92a3f..bf3831af5 100644
--- a/include/chameleon/tasks_z.h
+++ b/include/chameleon/tasks_z.h
@@ -600,6 +600,16 @@ void INSERT_TASK_zipiv_allreduce( const RUNTIME_option_t *options,
  * @param[in] A
  *          The descriptor of the matrix A.
  *
+ * @param[inout] U
+ *          The descriptor of the workspace used for the permutation in the LU
+ *          factorization with partial pivoting.
+ *
+ * @param[in] Um
+ *          The row index of the tile used in U.
+ *
+ * @param[in] Un
+ *          The column index of the tile used in U.
+ *
  * @param[in] ipiv
  *          The pivot structure that contains the informations for the LU
  *          factorization with partial pivoting.
@@ -613,16 +623,6 @@ void INSERT_TASK_zipiv_allreduce( const RUNTIME_option_t *options,
  * @param[in] n
  *          The number of columns in the tile U(Um, Un).
  *
- * @param[inout] U
- *          The descriptor of the worskpace used for the permutation in the LU
- *          factorization with partial pivoting.
- *
- * @param[in] Um
- *          The row index of the tile used in U.
- *
- * @param[in] Un
- *          The column index of the tile used in U.
- *
  * @param[in] ws
  *          The workspace to handle the data in the LU factorization with
  *          partial pivoting.
@@ -631,13 +631,13 @@ void INSERT_TASK_zipiv_allreduce( const RUNTIME_option_t *options,
  */
 void INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
                                   const CHAM_desc_t      *A,
+                                  CHAM_desc_t            *U,
+                                  int                     Um,
+                                  int                     Un,
                                   CHAM_ipiv_t            *ipiv,
                                   int                     ipivk,
                                   int                     k,
                                   int                     n,
-                                  CHAM_desc_t            *U,
-                                  int                     Um,
-                                  int                     Un,
                                   void                   *ws );
 
 /**
diff --git a/runtime/openmp/codelets/codelet_zperm_allreduce.c b/runtime/openmp/codelets/codelet_zperm_allreduce.c
index cb77c806b..7aeb24fae 100644
--- a/runtime/openmp/codelets/codelet_zperm_allreduce.c
+++ b/runtime/openmp/codelets/codelet_zperm_allreduce.c
@@ -71,23 +71,23 @@ INSERT_TASK_zperm_allreduce_send_invp( const RUNTIME_option_t *options,
 void
 INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
                              const CHAM_desc_t      *A,
+                             CHAM_desc_t            *U,
+                             int                     Um,
+                             int                     Un,
                              CHAM_ipiv_t            *ipiv,
                              int                     ipivk,
                              int                     k,
                              int                     n,
-                             CHAM_desc_t            *U,
-                             int                     Um,
-                             int                     Un,
                              void                   *ws )
 {
     (void)options;
     (void)A;
+    (void)U;
+    (void)Um;
+    (void)Un;
     (void)ipiv;
     (void)ipivk;
     (void)k;
     (void)n;
-    (void)U;
-    (void)Um;
-    (void)Un;
     (void)ws;
 }
diff --git a/runtime/parsec/codelets/codelet_zperm_allreduce.c b/runtime/parsec/codelets/codelet_zperm_allreduce.c
index 30890f811..5acfa4a2b 100644
--- a/runtime/parsec/codelets/codelet_zperm_allreduce.c
+++ b/runtime/parsec/codelets/codelet_zperm_allreduce.c
@@ -71,23 +71,23 @@ INSERT_TASK_zperm_allreduce_send_invp( const RUNTIME_option_t *options,
 void
 INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
                              const CHAM_desc_t      *A,
+                             CHAM_desc_t            *U,
+                             int                     Um,
+                             int                     Un,
                              CHAM_ipiv_t            *ipiv,
                              int                     ipivk,
                              int                     k,
                              int                     n,
-                             CHAM_desc_t            *U,
-                             int                     Um,
-                             int                     Un,
                              void                   *ws )
 {
     (void)options;
     (void)A;
+    (void)U;
+    (void)Um;
+    (void)Un;
     (void)ipiv;
     (void)ipivk;
     (void)k;
     (void)n;
-    (void)U;
-    (void)Um;
-    (void)Un;
     (void)ws;
 }
diff --git a/runtime/quark/codelets/codelet_zperm_allreduce.c b/runtime/quark/codelets/codelet_zperm_allreduce.c
index 52281451d..f6c5f98e6 100644
--- a/runtime/quark/codelets/codelet_zperm_allreduce.c
+++ b/runtime/quark/codelets/codelet_zperm_allreduce.c
@@ -71,23 +71,23 @@ INSERT_TASK_zperm_allreduce_send_invp( const RUNTIME_option_t *options,
 void
 INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
                              const CHAM_desc_t      *A,
+                             CHAM_desc_t            *U,
+                             int                     Um,
+                             int                     Un,
                              CHAM_ipiv_t            *ipiv,
                              int                     ipivk,
                              int                     k,
                              int                     n,
-                             CHAM_desc_t            *U,
-                             int                     Um,
-                             int                     Un,
                              void                   *ws )
 {
     (void)options;
     (void)A;
+    (void)U;
+    (void)Um;
+    (void)Un;
     (void)ipiv;
     (void)ipivk;
     (void)k;
     (void)n;
-    (void)U;
-    (void)Um;
-    (void)Un;
     (void)ws;
 }
diff --git a/runtime/starpu/codelets/codelet_zperm_allreduce.c b/runtime/starpu/codelets/codelet_zperm_allreduce.c
index 4c33a2e50..1c8d44164 100644
--- a/runtime/starpu/codelets/codelet_zperm_allreduce.c
+++ b/runtime/starpu/codelets/codelet_zperm_allreduce.c
@@ -102,14 +102,14 @@ INSERT_TASK_zperm_allreduce_recv( const RUNTIME_option_t *options,
 {
     struct cl_redux_args_t *clargs;
     clargs = malloc( sizeof( struct cl_redux_args_t ) );
-    clargs->tempmm = tempmm;
-    clargs->n      = n;
-    clargs->p      = p;
-    clargs->q      = q;
-    clargs->p_first  = p_first;
-    clargs->me     = me;
-    clargs->shift  = shift;
-    clargs->np_inv = np;
+    clargs->tempmm  = tempmm;
+    clargs->n       = n;
+    clargs->p       = p;
+    clargs->q       = q;
+    clargs->p_first = p_first;
+    clargs->me      = me;
+    clargs->shift   = shift;
+    clargs->np_inv  = np;
 
     rt_starpu_insert_task(
         &cl_zperm_allreduce,
@@ -124,20 +124,19 @@ INSERT_TASK_zperm_allreduce_recv( const RUNTIME_option_t *options,
     starpu_mpi_cache_flush( options->sequence->comm, RTBLKADDR(U, CHAMELEON_Complex64_t, src, n) );
 }
 
-void
-INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
-                             const CHAM_desc_t      *A,
-                             CHAM_ipiv_t            *ipiv,
-                             int                     ipivk,
-                             int                     k,
-                             int                     n,
-                             CHAM_desc_t            *U,
-                             int                     Um,
-                             int                     Un,
-                             void                   *ws )
+static void
+zperm_allreduce_chameleon_starpu_task( const RUNTIME_option_t     *options,
+                                       const CHAM_desc_t          *A,
+                                       CHAM_desc_t                *U,
+                                       int                         Um,
+                                       int                         Un,
+                                       CHAM_ipiv_t                *ipiv,
+                                       int                         ipivk,
+                                       int                         k,
+                                       int                         n,
+                                       struct chameleon_pzgetrf_s *ws)
 {
-    struct chameleon_pzgetrf_s *tmp = (struct chameleon_pzgetrf_s *)ws;
-    int *proc_involved = tmp->proc_involved;
+    int *proc_involved = ws->proc_involved;
     int  np_involved   = chameleon_min( chameleon_desc_datadist_get_iparam(A, 0), A->mt - k);
     int  np_iter       = np_involved;
     int  p_recv, p_send, me, p_first;
@@ -169,6 +168,27 @@ INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
     }
 }
 
+void
+INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
+                             const CHAM_desc_t      *A,
+                             CHAM_desc_t            *U,
+                             int                     Um,
+                             int                     Un,
+                             CHAM_ipiv_t            *ipiv,
+                             int                     ipivk,
+                             int                     k,
+                             int                     n,
+                             void                   *ws )
+{
+    struct chameleon_pzgetrf_s *tmp = (struct chameleon_pzgetrf_s *)ws;
+    cham_getrf_allreduce_t alg = tmp->alg_allreduce;
+    switch( alg ) {
+    case ChamStarPUTasks:
+    default:
+        zperm_allreduce_chameleon_starpu_task( options, A, U, Um, Un, ipiv, ipivk, k, n, tmp );
+    }
+}
+
 void
 INSERT_TASK_zperm_allreduce_send_A( const RUNTIME_option_t *options,
                                     CHAM_desc_t            *A,
@@ -284,24 +304,24 @@ INSERT_TASK_zperm_allreduce_send_invp( const RUNTIME_option_t *options,
 void
 INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
                              const CHAM_desc_t      *A,
+                             CHAM_desc_t            *U,
+                             int                     Um,
+                             int                     Un,
                              CHAM_ipiv_t            *ipiv,
                              int                     ipivk,
                              int                     k,
                              int                     n,
-                             CHAM_desc_t            *U,
-                             int                     Um,
-                             int                     Un,
                              void                   *ws )
 {
     (void)options;
     (void)A;
+    (void)U;
+    (void)Um;
+    (void)Un;
     (void)ipiv;
     (void)ipivk;
     (void)k;
     (void)n;
-    (void)U;
-    (void)Um;
-    (void)Un;
     (void)ws;
 }
 #endif
-- 
GitLab