From 6b88697d24bbc62aeaf8f7d9337e741b3d248ba7 Mon Sep 17 00:00:00 2001
From: Alycia Lisito <alycia.lisito@inria.fr>
Date: Thu, 5 Sep 2024 15:06:06 +0200
Subject: [PATCH] zgetrf: zperm distributed use Wu for update

---
 compute/pzgetrf.c                             | 125 ++++++++++++++----
 compute/zgetrf.c                              |  10 ++
 control/compute_z.h                           |  16 ++-
 include/chameleon/tasks_z.h                   |  12 +-
 .../openmp/codelets/codelet_zlaswp_batched.c  |  44 +++---
 .../parsec/codelets/codelet_zlaswp_batched.c  |  44 +++---
 .../quark/codelets/codelet_zlaswp_batched.c   |  44 +++---
 runtime/starpu/codelets/codelet_zlaswp.c      |   6 +
 .../starpu/codelets/codelet_zlaswp_batched.c  |  34 +++--
 9 files changed, 233 insertions(+), 102 deletions(-)

diff --git a/compute/pzgetrf.c b/compute/pzgetrf.c
index 5ecbe2959..5b63d1e01 100644
--- a/compute/pzgetrf.c
+++ b/compute/pzgetrf.c
@@ -26,6 +26,7 @@
 #define A(m,n)  A,        m, n
 #define U(m,n)  &(ws->U), m, n
 #define Up(m,n)  &(ws->Up), m, n
+#define Wu(m,n)  &(ws->Wu), m, n
 
 /*
  * All the functions below are panel factorization variant.
@@ -389,6 +390,19 @@ chameleon_pzgetrf_panel_permute( struct chameleon_pzgetrf_s *ws,
         int m;
         int tempkm, tempkn, tempnn, minmn;
 
+        chameleon_get_proc_involved_in_panelk_2dbc( A, k, n, ws );
+        if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
+            INSERT_TASK_zperm_allreduce_send_perm( options, ipiv, k, A->myrank, ws->np_involved, ws->proc_involved );
+            INSERT_TASK_zperm_allreduce_send_invp( options, ipiv, k, A, k, n );
+        }
+        if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
+            INSERT_TASK_zperm_allreduce_send_A( options, A, k, n, A->myrank, ws->np_involved, ws->proc_involved );
+        }
+
+        if ( !ws->involved ) {
+            return;
+        }
+
         tempkm = k == A->mt-1 ? A->m-k*A->mb : A->mb;
         tempkn = k == A->nt-1 ? A->n-k*A->nb : A->nb;
         tempnn = n == A->nt-1 ? A->n-n*A->nb : A->nb;
@@ -396,28 +410,26 @@ chameleon_pzgetrf_panel_permute( struct chameleon_pzgetrf_s *ws,
 
         /* Extract selected rows into U */
         INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
-                            A(k, n), U(k, n) );
+                            A(k, n), Wu(A->myrank, n) );
 
         /*
          * perm array is made of size tempkm for the first row especially.
          * Otherwise, the final copy back to the tile may copy only a partial tile
          */
         INSERT_TASK_zlaswp_get( options, k*A->mb, tempkm,
-                                ipiv, k, A(k, n), U(k, n) );
+                                ipiv, k, A(k, n), Wu(A->myrank, n) );
 
         for(m=k+1; m<A->mt; m++){
             /* Extract selected rows into A(k, n) */
             INSERT_TASK_zlaswp_get( options, m*A->mb, minmn,
-                                    ipiv, k, A(m, n), U(k, n) );
+                                    ipiv, k, A(m, n), Wu(A->myrank, n) );
             /* Copy rows from A(k,n) into their final position */
             INSERT_TASK_zlaswp_set( options, m*A->mb, minmn,
                                     ipiv, k, A(k, n), A(m, n) );
         }
 
-        INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
-                            U(k, n), A(k, n) );
-
-        RUNTIME_data_flush( options->sequence, U(k, n) );
+        INSERT_TASK_zperm_allreduce( options, A, ipiv, k, k, n,
+                                     Wu(A->myrank, n), ws );
     }
     break;
     default:
@@ -440,6 +452,20 @@ chameleon_pzgetrf_panel_permute_batched( struct chameleon_pzgetrf_s *ws,
     {
         int m;
         int tempkm, tempkn, tempnn, minmn;
+
+        chameleon_get_proc_involved_in_panelk_2dbc( A, k, n, ws );
+        if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
+            INSERT_TASK_zperm_allreduce_send_perm( options, ipiv, k, A->myrank, ws->np_involved, ws->proc_involved );
+            INSERT_TASK_zperm_allreduce_send_invp( options, ipiv, k, A, k, n );
+        }
+        if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
+            INSERT_TASK_zperm_allreduce_send_A( options, A, k, n, A->myrank, ws->np_involved, ws->proc_involved );
+        }
+
+        if ( !ws->involved ) {
+            return;
+        }
+
         void **clargs = malloc( sizeof(char *) );
         *clargs = NULL;
 
@@ -450,25 +476,23 @@ chameleon_pzgetrf_panel_permute_batched( struct chameleon_pzgetrf_s *ws,
 
         /* Extract selected rows into U */
         INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
-                            A(k, n), U(k, n) );
+                            A(k, n), Wu(A->myrank, n) );
 
         /*
          * perm array is made of size tempkm for the first row especially.
          * Otherwise, the final copy back to the tile may copy only a partial tile
          */
         INSERT_TASK_zlaswp_get( options, k*A->mb, tempkm,
-                                ipiv, k, A(k, n), U(k, n) );
+                                ipiv, k, A(k, n), Wu(A->myrank, n) );
 
         for(m=k+1; m<A->mt; m++){
-            INSERT_TASK_zlaswp_batched( options, m*A->mb, minmn, k, m, n, (void *)ws,
-                                        ipiv, k, A, &(ws->U), clargs );
+            INSERT_TASK_zlaswp_batched( options, m*A->mb, minmn, (void *)ws, ipiv, k,
+                                        A(m, n), A(k, n), Wu(A->myrank, n), clargs );
         }
-        INSERT_TASK_zlaswp_batched_flush( options, k, n, ipiv, k, A, &(ws->U), clargs );
+        INSERT_TASK_zlaswp_batched_flush( options, ipiv, k, A(k, n), Wu(A->myrank, n), clargs );
 
-        INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
-                            U(k, n), A(k, n) );
+        INSERT_TASK_zperm_allreduce( options, A, ipiv, k, k, n, Wu(A->myrank, n), ws );
 
-        RUNTIME_data_flush( options->sequence, U(k, n) );
         free( clargs );
     }
     break;
@@ -488,7 +512,7 @@ chameleon_pzgetrf_panel_update( struct chameleon_pzgetrf_s *ws,
     const CHAMELEON_Complex64_t zone  = (CHAMELEON_Complex64_t) 1.0;
     const CHAMELEON_Complex64_t mzone = (CHAMELEON_Complex64_t)-1.0;
 
-    int m, tempkm, tempmm, tempnn;
+    int m, tempkm, tempmm, tempnn, rankAmn, p;
 
     tempkm = k == A->mt-1 ? A->m-k*A->mb : A->mb;
     tempnn = n == A->nt-1 ? A->n-n*A->nb : A->nb;
@@ -500,25 +524,44 @@ chameleon_pzgetrf_panel_update( struct chameleon_pzgetrf_s *ws,
         chameleon_pzgetrf_panel_permute( ws, A, ipiv, k, n, options );
     }
 
-    INSERT_TASK_ztrsm(
-        options,
-        ChamLeft, ChamLower, ChamNoTrans, ChamUnit,
-        tempkm, tempnn, A->mb,
-        zone, A(k, k),
-              A(k, n) );
+    if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
+        for ( p = 0; p < ws->np_involved; p++ ) {
+            INSERT_TASK_ztrsm(
+                options,
+                ChamLeft, ChamLower, ChamNoTrans, ChamUnit,
+                tempkm, tempnn, A->mb,
+                zone, A(k, k),
+                      Wu(ws->proc_involved[p], n) );
+        }
+    }
+    else if ( ws->involved ) {
+        INSERT_TASK_ztrsm(
+            options,
+            ChamLeft, ChamLower, ChamNoTrans, ChamUnit,
+            tempkm, tempnn, A->mb,
+            zone, A(k, k),
+                  Wu(A->myrank, n) );
+    }
 
     for (m = k+1; m < A->mt; m++) {
         tempmm = m == A->mt-1 ? A->m-m*A->mb : A->mb;
+        rankAmn = A->get_rankof( A, m, n );
 
         INSERT_TASK_zgemm(
             options,
             ChamNoTrans, ChamNoTrans,
             tempmm, tempnn, A->mb, A->mb,
             mzone, A(m, k),
-                   A(k, n),
+                   Wu(rankAmn, n),
             zone,  A(m, n) );
     }
 
+    if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
+        INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
+                            Wu(A->myrank, n), A(k, n) );
+    }
+
+    RUNTIME_data_flush( options->sequence, Wu(A->myrank, n) );
     RUNTIME_data_flush( options->sequence, A(k, n) );
 }
 
@@ -534,7 +577,7 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
     CHAM_context_t  *chamctxt;
     RUNTIME_option_t options;
 
-    int k, m, n;
+    int k, m, n, tempkm, tempnn;
     int min_mnt = chameleon_min( A->mt, A->nt );
 
     chamctxt = chameleon_context_self();
@@ -559,7 +602,11 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
 
         for (n = k+1; n < A->nt; n++) {
             options.priority = A->nt-n;
-            chameleon_pzgetrf_panel_update( ws, A, IPIV, k, n, &options );
+            if ( chameleon_involved_in_panelk_2dbc( A, k ) ||
+                 chameleon_involved_in_panelk_2dbc( A, n ) )
+            {
+                chameleon_pzgetrf_panel_update( ws, A, IPIV, k, n, &options );
+            }
         }
 
         /* Flush panel k */
@@ -574,7 +621,19 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
     if ( ws->batch_size > 0 ) {
         for (k = 1; k < min_mnt; k++) {
             for (n = 0; n < k; n++) {
-                chameleon_pzgetrf_panel_permute_batched( ws, A, IPIV, k, n, &options );
+                if ( chameleon_involved_in_panelk_2dbc( A, k ) ||
+                     chameleon_involved_in_panelk_2dbc( A, n ) )
+                {
+                    chameleon_pzgetrf_panel_permute_batched( ws, A, IPIV, k, n, &options );
+                    if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
+                        tempkm = k == A->mt-1 ? A->m-k*A->mb : A->mb;
+                        tempnn = n == A->nt-1 ? A->n-n*A->nb : A->nb;
+                        INSERT_TASK_zlacpy( &options, ChamUpperLower, tempkm, tempnn,
+                                            Wu(A->myrank, n), A(k, n) );
+                        RUNTIME_data_flush( sequence, A(k, n) );
+                    }
+                }
+                RUNTIME_data_flush( sequence, Wu(A->myrank, n) );
             }
             RUNTIME_perm_flushk( sequence, IPIV, k );
         }
@@ -582,7 +641,19 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
     else {
         for (k = 1; k < min_mnt; k++) {
             for (n = 0; n < k; n++) {
-                chameleon_pzgetrf_panel_permute( ws, A, IPIV, k, n, &options );
+                if ( chameleon_involved_in_panelk_2dbc( A, k ) ||
+                     chameleon_involved_in_panelk_2dbc( A, n ) )
+                {
+                    chameleon_pzgetrf_panel_permute( ws, A, IPIV, k, n, &options );
+                    if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
+                        tempkm = k == A->mt-1 ? A->m-k*A->mb : A->mb;
+                        tempnn = n == A->nt-1 ? A->n-n*A->nb : A->nb;
+                        INSERT_TASK_zlacpy( &options, ChamUpperLower, tempkm, tempnn,
+                                            Wu(A->myrank, n), A(k, n) );
+                        RUNTIME_data_flush( sequence, A(k, n) );
+                    }
+                }
+                RUNTIME_data_flush( sequence, Wu(A->myrank, n) );
             }
             RUNTIME_perm_flushk( sequence, IPIV, k );
         }
diff --git a/compute/zgetrf.c b/compute/zgetrf.c
index da434379c..b7e8f87b6 100644
--- a/compute/zgetrf.c
+++ b/compute/zgetrf.c
@@ -118,6 +118,11 @@ CHAMELEON_zgetrf_WS_Alloc( const CHAM_desc_t *A )
                              A->m, A->n, 0, 0,
                              A->m, A->n, A->p, A->q,
                              NULL, NULL, A->get_rankof_init, A->get_rankof_init_arg );
+        chameleon_desc_init( &(ws->Wu), CHAMELEON_MAT_ALLOC_TILE,
+                             ChamComplexDouble, A->mb, A->nb, A->mb*A->nb,
+                             A->mb * A->p * A->q, A->n, 0, 0,
+                             A->mb * A->p * A->q, A->n, A->p * A->q, 1,
+                             NULL, NULL, NULL, A->get_rankof_init_arg );
     }
 
     /* Set ib to 1 if per column algorithm */
@@ -180,6 +185,11 @@ CHAMELEON_zgetrf_WS_Free( void *user_ws )
     {
         chameleon_desc_destroy( &(ws->Up) );
     }
+    if ( ( ws->alg == ChamGetrfPPiv           ) ||
+         ( ws->alg == ChamGetrfPPivPerColumn  ) )
+    {
+        chameleon_desc_destroy( &(ws->Wu) );
+    }
     free( ws );
 }
 
diff --git a/control/compute_z.h b/control/compute_z.h
index 65d580ad9..acb9599f2 100644
--- a/control/compute_z.h
+++ b/control/compute_z.h
@@ -43,13 +43,15 @@ struct chameleon_pzgemm_s {
  * @brief Data structure to handle the GETRF workspaces with partial pivoting
  */
 struct chameleon_pzgetrf_s {
-    cham_getrf_t alg;
-    int          ib;         /**< Internal blocking parameter */
-    int          batch_size; /**< Batch size for the panel    */
-    CHAM_desc_t  U;
-    CHAM_desc_t  Up;
-    int         *proc_involved;
-    unsigned int involved:1;
+    cham_getrf_t   alg;
+    int            ib;         /**< Internal blocking parameter */
+    int            batch_size; /**< Batch size for the panel    */
+    CHAM_desc_t    U;
+    CHAM_desc_t    Up; /**< Workspace used for the panel factorization    */
+    CHAM_desc_t    Wu; /**< Workspace used for the permutation and update */
+    int           *proc_involved;
+    unsigned int   involved;
+    int            np_involved;
 };
 
 /**
diff --git a/include/chameleon/tasks_z.h b/include/chameleon/tasks_z.h
index 9b843c60a..5f1bbcd32 100644
--- a/include/chameleon/tasks_z.h
+++ b/include/chameleon/tasks_z.h
@@ -199,17 +199,17 @@ void INSERT_TASK_zlaswp_set( const RUNTIME_option_t *options,
                              const CHAM_desc_t *tileA, int tileAm, int tileAn,
                              const CHAM_desc_t *tileB, int tileBm, int tileBn );
 void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
-                                 int m0, int minmn, int k, int m, int n,
+                                 int m0, int minmn,
                                  void *ws,
                                  const CHAM_ipiv_t *ipiv, int ipivk,
-                                 const CHAM_desc_t *A,
-                                 const CHAM_desc_t *U,
+                                 const CHAM_desc_t *Am, int Amm, int Amn,
+                                 const CHAM_desc_t *Ak, int Akm, int Akn,
+                                 const CHAM_desc_t *U,  int Um,  int Un,
                                  void **clargs_ptr );
 void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
-                                       int k, int n,
                                        const CHAM_ipiv_t *ipiv, int ipivk,
-                                       const CHAM_desc_t *A,
-                                       const CHAM_desc_t *U,
+                                       const CHAM_desc_t *Ak, int Akm, int Akn,
+                                       const CHAM_desc_t *U,  int Um,  int Un,
                                        void **clargs_ptr );
 void INSERT_TASK_zlatro( const RUNTIME_option_t *options,
                          cham_uplo_t uplo, cham_trans_t trans, int m, int n, int mb,
diff --git a/runtime/openmp/codelets/codelet_zlaswp_batched.c b/runtime/openmp/codelets/codelet_zlaswp_batched.c
index 49ac5381c..07fd1eab8 100644
--- a/runtime/openmp/codelets/codelet_zlaswp_batched.c
+++ b/runtime/openmp/codelets/codelet_zlaswp_batched.c
@@ -21,45 +21,57 @@
 void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
                                  int                     m0,
                                  int                     minmn,
-                                 int                     k,
-                                 int                     m,
-                                 int                     n,
                                  void                   *ws,
                                  const CHAM_ipiv_t      *ipiv,
                                  int                     ipivk,
-                                 const CHAM_desc_t      *A,
-                                 const CHAM_desc_t      *Wu,
+                                 const CHAM_desc_t      *Am,
+                                 int                     Amm,
+                                 int                     Amn,
+                                 const CHAM_desc_t      *Ak,
+                                 int                     Akm,
+                                 int                     Akn,
+                                 const CHAM_desc_t      *U,
+                                 int                     Um,
+                                 int                     Un,
                                  void                  **clargs_ptr )
 {
     (void)options;
     (void)m0;
     (void)minmn;
-    (void)k;
-    (void)m;
-    (void)n;
     (void)ws;
     (void)ipiv;
     (void)ipivk;
-    (void)A;
-    (void)Wu;
+    (void)Am;
+    (void)Amm;
+    (void)Amn;
+    (void)Ak;
+    (void)Akm;
+    (void)Akn;
+    (void)U;
+    (void)Um;
+    (void)Un;
     (void)clargs_ptr;
 }
 
 void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
-                                       int                     k,
-                                       int                     n,
                                        const CHAM_ipiv_t      *ipiv,
                                        int                     ipivk,
-                                       const CHAM_desc_t      *A,
+                                       const CHAM_desc_t      *Ak,
+                                       int                     Akm,
+                                       int                     Akn,
                                        const CHAM_desc_t      *U,
+                                       int                     Um,
+                                       int                     Un,
                                        void                  **clargs_ptr )
 {
     (void)options;
-    (void)k;
-    (void)n;
     (void)ipiv;
     (void)ipivk;
-    (void)A;
+    (void)Ak;
+    (void)Akm;
+    (void)Akn;
     (void)U;
+    (void)Um;
+    (void)Un;
     (void)clargs_ptr;
 }
diff --git a/runtime/parsec/codelets/codelet_zlaswp_batched.c b/runtime/parsec/codelets/codelet_zlaswp_batched.c
index aa8726690..011d42e8b 100644
--- a/runtime/parsec/codelets/codelet_zlaswp_batched.c
+++ b/runtime/parsec/codelets/codelet_zlaswp_batched.c
@@ -21,45 +21,57 @@
 void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
                                  int                     m0,
                                  int                     minmn,
-                                 int                     k,
-                                 int                     m,
-                                 int                     n,
                                  void                   *ws,
                                  const CHAM_ipiv_t      *ipiv,
                                  int                     ipivk,
-                                 const CHAM_desc_t      *A,
-                                 const CHAM_desc_t      *Wu,
+                                 const CHAM_desc_t      *Am,
+                                 int                     Amm,
+                                 int                     Amn,
+                                 const CHAM_desc_t      *Ak,
+                                 int                     Akm,
+                                 int                     Akn,
+                                 const CHAM_desc_t      *U,
+                                 int                     Um,
+                                 int                     Un,
                                  void                  **clargs_ptr )
 {
     (void)options;
     (void)m0;
     (void)minmn;
-    (void)k;
-    (void)m;
-    (void)n;
     (void)ws;
     (void)ipiv;
     (void)ipivk;
-    (void)A;
-    (void)Wu;
+    (void)Am;
+    (void)Amm;
+    (void)Amn;
+    (void)Ak;
+    (void)Akm;
+    (void)Akn;
+    (void)U;
+    (void)Um;
+    (void)Un;
     (void)clargs_ptr;
 }
 
 void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
-                                       int                     k,
-                                       int                     n,
                                        const CHAM_ipiv_t      *ipiv,
                                        int                     ipivk,
-                                       const CHAM_desc_t      *A,
+                                       const CHAM_desc_t      *Ak,
+                                       int                     Akm,
+                                       int                     Akn,
                                        const CHAM_desc_t      *U,
+                                       int                     Um,
+                                       int                     Un,
                                        void                  **clargs_ptr )
 {
     (void)options;
-    (void)k;
-    (void)n;
     (void)ipiv;
     (void)ipivk;
-    (void)A;
+    (void)Ak;
+    (void)Akm;
+    (void)Akn;
     (void)U;
+    (void)Um;
+    (void)Un;
     (void)clargs_ptr;
 }
diff --git a/runtime/quark/codelets/codelet_zlaswp_batched.c b/runtime/quark/codelets/codelet_zlaswp_batched.c
index f96414f27..9ec2148fb 100644
--- a/runtime/quark/codelets/codelet_zlaswp_batched.c
+++ b/runtime/quark/codelets/codelet_zlaswp_batched.c
@@ -21,45 +21,57 @@
 void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
                                  int                     m0,
                                  int                     minmn,
-                                 int                     k,
-                                 int                     m,
-                                 int                     n,
                                  void                   *ws,
                                  const CHAM_ipiv_t      *ipiv,
                                  int                     ipivk,
-                                 const CHAM_desc_t      *A,
-                                 const CHAM_desc_t      *Wu,
+                                 const CHAM_desc_t      *Am,
+                                 int                     Amm,
+                                 int                     Amn,
+                                 const CHAM_desc_t      *Ak,
+                                 int                     Akm,
+                                 int                     Akn,
+                                 const CHAM_desc_t      *U,
+                                 int                     Um,
+                                 int                     Un,
                                  void                  **clargs_ptr )
 {
     (void)options;
     (void)m0;
     (void)minmn;
-    (void)k;
-    (void)m;
-    (void)n;
     (void)ws;
     (void)ipiv;
     (void)ipivk;
-    (void)A;
-    (void)Wu;
+    (void)Am;
+    (void)Amm;
+    (void)Amn;
+    (void)Ak;
+    (void)Akm;
+    (void)Akn;
+    (void)U;
+    (void)Um;
+    (void)Un;
     (void)clargs_ptr;
 }
 
 void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
-                                       int                     k,
-                                       int                     n,
                                        const CHAM_ipiv_t      *ipiv,
                                        int                     ipivk,
-                                       const CHAM_desc_t      *A,
+                                       const CHAM_desc_t      *Ak,
+                                       int                     Akm,
+                                       int                     Akn,
                                        const CHAM_desc_t      *U,
+                                       int                     Um,
+                                       int                     Un,
                                        void                  **clargs_ptr )
 {
     (void)options;
-    (void)k;
-    (void)n;
     (void)ipiv;
     (void)ipivk;
-    (void)A;
+    (void)Ak;
+    (void)Akm;
+    (void)Akn;
     (void)U;
+    (void)Um;
+    (void)Un;
     (void)clargs_ptr;
 }
diff --git a/runtime/starpu/codelets/codelet_zlaswp.c b/runtime/starpu/codelets/codelet_zlaswp.c
index ade365c68..96d3108a8 100644
--- a/runtime/starpu/codelets/codelet_zlaswp.c
+++ b/runtime/starpu/codelets/codelet_zlaswp.c
@@ -47,6 +47,9 @@ void INSERT_TASK_zlaswp_get( const RUNTIME_option_t *options,
                              const CHAM_desc_t *U, int Um, int Un )
 {
     struct starpu_codelet *codelet = &cl_zlaswp_get;
+    if ( A->get_rankof( A, Am, An) != A->myrank ) {
+        return; /* Tile A(Am, An) maps to another rank: no task is submitted from here */
+    }
 
     //void (*callback)(void*) = options->profiling ? cl_zlaswp_get_callback : NULL;
 
@@ -91,6 +94,9 @@ void INSERT_TASK_zlaswp_set( const RUNTIME_option_t *options,
                              const CHAM_desc_t *B, int Bm, int Bn )
 {
     struct starpu_codelet *codelet = &cl_zlaswp_set;
+    if ( B->get_rankof( B, Bm, Bn ) != B->myrank ) {
+        return; /* Tile B(Bm, Bn) maps to another rank: no task is submitted from here */
+    }
 
     //void (*callback)(void*) = options->profiling ? cl_zlaswp_set_callback : NULL;
 
diff --git a/runtime/starpu/codelets/codelet_zlaswp_batched.c b/runtime/starpu/codelets/codelet_zlaswp_batched.c
index 6af43659c..b17f26a48 100644
--- a/runtime/starpu/codelets/codelet_zlaswp_batched.c
+++ b/runtime/starpu/codelets/codelet_zlaswp_batched.c
@@ -57,21 +57,25 @@ CODELETS_CPU( zlaswp_batched, cl_zlaswp_batched_cpu_func )
 void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
                                  int                     m0,
                                  int                     minmn,
-                                 int                     k,
-                                 int                     m,
-                                 int                     n,
                                  void                   *ws,
                                  const CHAM_ipiv_t      *ipiv,
                                  int                     ipivk,
-                                 const CHAM_desc_t      *A,
-                                 const CHAM_desc_t      *Wu,
+                                 const CHAM_desc_t      *Am,
+                                 int                     Amm,
+                                 int                     Amn,
+                                 const CHAM_desc_t      *Ak,
+                                 int                     Akm,
+                                 int                     Akn,
+                                 const CHAM_desc_t      *U,
+                                 int                     Um,
+                                 int                     Un,
                                  void                  **clargs_ptr )
 {
     int task_num   = 0;
     int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size;
     int nhandles;
     struct cl_laswp_batched_args_t *clargs = *clargs_ptr;
-    if ( A->get_rankof( A, m, n) != A->myrank ) {
+    if ( Am->get_rankof( Am, Amm, Amn) != Am->myrank ) {
         return;
     }
 
@@ -84,7 +88,7 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
 
     task_num               = clargs->tasks_nbr;
     clargs->m0[ task_num ] = m0;
-    clargs->handle_mode[ task_num ].handle = RTBLKADDR(A, CHAMELEON_Complex64_t, m, n);
+    clargs->handle_mode[ task_num ].handle = RTBLKADDR(Am, CHAMELEON_Complex64_t, Amm, Amn);
     clargs->handle_mode[ task_num ].mode   = STARPU_RW;
     clargs->tasks_nbr ++;
 
@@ -95,8 +99,8 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
             STARPU_CL_ARGS,             clargs, sizeof(struct cl_laswp_batched_args_t),
             STARPU_R,                   RUNTIME_perm_getaddr( ipiv, ipivk ),
             STARPU_R,                   RUNTIME_invp_getaddr( ipiv, ipivk ),
-            STARPU_RW | STARPU_COMMUTE, RTBLKADDR(Wu, ChamComplexDouble, A->myrank, n),
-            STARPU_R,                   RTBLKADDR(A, ChamComplexDouble, k, n),
+            STARPU_RW | STARPU_COMMUTE, RTBLKADDR(U, ChamComplexDouble, Um, Un),
+            STARPU_R,                   RTBLKADDR(Ak, ChamComplexDouble, Akm, Akn),
             STARPU_DATA_MODE_ARRAY,     clargs->handle_mode, nhandles,
             STARPU_PRIORITY,            options->priority,
             STARPU_EXECUTE_ON_WORKER,   options->workerid,
@@ -108,12 +112,14 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
 }
 
 void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
-                                       int                     k,
-                                       int                     n,
                                        const CHAM_ipiv_t      *ipiv,
                                        int                     ipivk,
-                                       const CHAM_desc_t      *A,
+                                       const CHAM_desc_t      *Ak,
+                                       int                     Akm,
+                                       int                     Akn,
                                        const CHAM_desc_t      *U,
+                                       int                     Um,
+                                       int                     Un,
                                        void                  **clargs_ptr )
 {
     struct cl_laswp_batched_args_t *clargs   = *clargs_ptr;
@@ -129,8 +135,8 @@ void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
         STARPU_CL_ARGS,             clargs, sizeof(struct cl_laswp_batched_args_t),
         STARPU_R,                   RUNTIME_perm_getaddr( ipiv, ipivk ),
         STARPU_R,                   RUNTIME_invp_getaddr( ipiv, ipivk ),
-        STARPU_RW | STARPU_COMMUTE, RTBLKADDR(U, ChamComplexDouble, k, n),
-        STARPU_R,                   RTBLKADDR(A, ChamComplexDouble, k, n),
+        STARPU_RW | STARPU_COMMUTE, RTBLKADDR(U, ChamComplexDouble, Um, Un),
+        STARPU_R,                   RTBLKADDR(Ak, ChamComplexDouble, Akm, Akn),
         STARPU_DATA_MODE_ARRAY,     clargs->handle_mode, nhandles,
         STARPU_PRIORITY,            options->priority,
         STARPU_EXECUTE_ON_WORKER,   options->workerid,
-- 
GitLab