From b7fb48cf21887ca4ea8006c97f196bd0e6476307 Mon Sep 17 00:00:00 2001
From: Alycia Lisito <alycia.lisito@inria.fr>
Date: Wed, 24 Jul 2024 16:55:29 +0200
Subject: [PATCH] zgetrf: make zgetrf working on multiple mpi processes

---
 compute/pzgetrf.c                             | 16 +++++---
 compute/zgetrf.c                              |  4 --
 .../starpu/codelets/codelet_zgetrf_batched.c  | 39 +++++++++++++++++++
 .../starpu/codelets/codelet_zgetrf_blocked.c  | 39 ++++++++++++++++++-
 .../starpu/codelets/codelet_zgetrf_percol.c   | 19 +++++++++
 .../starpu/codelets/codelet_zipiv_allreduce.c |  2 +-
 6 files changed, 107 insertions(+), 12 deletions(-)

diff --git a/compute/pzgetrf.c b/compute/pzgetrf.c
index 0d17f0034..a56e1c9a1 100644
--- a/compute/pzgetrf.c
+++ b/compute/pzgetrf.c
@@ -148,7 +148,7 @@ chameleon_pzgetrf_panel_facto_percol( struct chameleon_pzgetrf_s *ws,
         }
 
         /* Reduce globally (between MPI processes) */
-        INSERT_TASK_zipiv_redux( A, options, ipiv, ws->proc_involved, k, h, tempkn );
+        INSERT_TASK_zipiv_allreduce( A, options, ipiv, ws->proc_involved, k, h, tempkn );
     }
 
     /* Flush temporary data used for the pivoting */
@@ -194,7 +194,7 @@ chameleon_pzgetrf_panel_facto_percol_batched( struct chameleon_pzgetrf_s *ws,
         }
         INSERT_TASK_zgetrf_panel_offdiag_batched_flush( options, A, k, clargs, ipiv );
 
-        INSERT_TASK_zipiv_redux( A, options, ipiv, ws->proc_involved, k, h, tempkn );
+        INSERT_TASK_zipiv_allreduce( A, options, ipiv, ws->proc_involved, k, h, tempkn );
     }
 
     free( clargs );
@@ -252,7 +252,7 @@ chameleon_pzgetrf_panel_facto_blocked( struct chameleon_pzgetrf_s *ws,
 
             assert( j <= minmn );
             /* Reduce globally (between MPI processes) */
-            INSERT_TASK_zipiv_redux( A, options, ipiv, ws->proc_involved, k, j, tempkn );
+            INSERT_TASK_zipiv_allreduce( A, options, ipiv, ws->proc_involved, k, j, tempkn );
 
             if ( ( b < (nbblock-1) ) && ( h == hmax-1 ) ) {
                 INSERT_TASK_zgetrf_blocked_trsm(
@@ -282,8 +282,8 @@ chameleon_pzgetrf_panel_facto_blocked_batched( struct chameleon_pzgetrf_s *ws,
 {
     int m, h, b, nbblock, hmax, j;
     int tempkm, tempkn, tempmm, minmn;
-    void **clargs = malloc( sizeof(char *) * A->p );
-    memset( clargs, 0, sizeof(char *) * A->p );
+    void **clargs = malloc( sizeof(char *) );
+    memset( clargs, 0, sizeof(char *) );
 
     tempkm = k == A->mt-1 ? A->m-k*A->mb : A->mb;
     tempkn = k == A->nt-1 ? A->n-k*A->nb : A->nb;
@@ -314,7 +314,7 @@ chameleon_pzgetrf_panel_facto_blocked_batched( struct chameleon_pzgetrf_s *ws,
 
             assert( j <= minmn );
             /* Reduce globally (between MPI processes) */
-            INSERT_TASK_zipiv_redux( A, options, ipiv, ws->proc_involved, k, j, tempkn );
+            INSERT_TASK_zipiv_allreduce( A, options, ipiv, ws->proc_involved, k, j, tempkn );
 
             if ( (b < (nbblock-1)) && (h == hmax-1) ) {
                 INSERT_TASK_zgetrf_blocked_trsm(
@@ -340,6 +340,7 @@ chameleon_pzgetrf_panel_facto( struct chameleon_pzgetrf_s *ws,
                                int                         k,
                                RUNTIME_option_t           *options )
 {
+#if defined ( CHAMELEON_USE_MPI )
     int *proc_involved = malloc( sizeof( int ) * chameleon_min( A->p, A->mt - k) );
     int  b;
 
@@ -357,6 +358,7 @@ chameleon_pzgetrf_panel_facto( struct chameleon_pzgetrf_s *ws,
 	free( proc_involved );
         return;
     }
+#endif
 
     /* TODO: Should be replaced by a function pointer */
     switch( ws->alg ) {
@@ -387,7 +389,9 @@ chameleon_pzgetrf_panel_facto( struct chameleon_pzgetrf_s *ws,
     default:
         chameleon_pzgetrf_panel_facto_nopiv( ws, A, ipiv, k, options );
     }
+#if defined ( CHAMELEON_USE_MPI )
     free( proc_involved );
+#endif
 }
 
 /**
diff --git a/compute/zgetrf.c b/compute/zgetrf.c
index 508a78125..8fb6734d3 100644
--- a/compute/zgetrf.c
+++ b/compute/zgetrf.c
@@ -95,10 +95,6 @@ CHAMELEON_zgetrf_WS_Alloc( const CHAM_desc_t *A )
         chameleon_warning( "CHAMELEON_BATCH_SIZE", "CHAMELEON_GETRF_BATCH_SIZE must be smaller than CHAMELEON_BATCH_SIZE, please recompile with the right CHAMELEON_BATCH_SIZE, or reduce the CHAMELEON_GETRF_BATCH_SIZE value\n" );
         ws->batch_size = CHAMELEON_BATCH_SIZE;
     }
-    if ( (ws->batch_size > 1) && (CHAMELEON_Comm_rank() > 1) ) {
-        chameleon_warning( "CHAMELEON_BATCH_SIZE", "CHAMELEON_GETRF_BATCH_SIZE is unavailable in distributed, value forced to 1\n" );
-        ws->batch_size = 1;
-    }
 
     /* Allocation of U for permutation of the panels */
     if ( ws->alg == ChamGetrfNoPivPerColumn ) {
diff --git a/runtime/starpu/codelets/codelet_zgetrf_batched.c b/runtime/starpu/codelets/codelet_zgetrf_batched.c
index 1ead5ec17..ab9b96020 100644
--- a/runtime/starpu/codelets/codelet_zgetrf_batched.c
+++ b/runtime/starpu/codelets/codelet_zgetrf_batched.c
@@ -79,6 +79,14 @@ INSERT_TASK_zgetrf_panel_offdiag_batched( const RUNTIME_option_t *options,
     void (*callback)(void*) = NULL;
     struct cl_getrf_batched_args_t *clargs = *clargs_ptr;
     int rankA = A->get_rankof( A, Am, An );
+    if ( rankA != A->myrank ) {
+        return;
+    }
+#if !defined ( HAVE_STARPU_NONE_NONZERO )
+    /* STARPU_NONE can't be equal to 0 */
+    fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" );
+    assert( 0 );
+#endif
 
     /* Handle cache */
     CHAMELEON_BEGIN_ACCESS_DECLARATION;
@@ -138,6 +146,11 @@ INSERT_TASK_zgetrf_panel_offdiag_batched_flush( const RUNTIME_option_t *options,
     void (*callback)(void*) = NULL;
     struct cl_getrf_batched_args_t *clargs = *clargs_ptr;
     int rankA = A->myrank;
+#if !defined ( HAVE_STARPU_NONE_NONZERO )
+    /* STARPU_NONE can't be equal to 0 */
+    fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" );
+    assert( 0 );
+#endif
 
     if ( clargs == NULL ) {
         return;
@@ -241,6 +254,27 @@ INSERT_TASK_zgetrf_panel_blocked_batched( const RUNTIME_option_t *options,
     int accessU, access_npiv, access_ipiv, access_ppiv;
     struct cl_getrf_batched_args_t *clargs = *clargs_ptr;
     int rankA = A->get_rankof(A, Am, An);
+#if !defined ( HAVE_STARPU_NONE_NONZERO )
+    /* STARPU_NONE can't be equal to 0 */
+    fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" );
+    assert( 0 );
+#endif
+
+#if defined ( CHAMELEON_USE_MPI )
+    if ( ( Am == An ) && ( h % ib == 0 ) && ( h > 0 ) ) {
+        starpu_mpi_cache_flush( options->sequence->comm,
+                                RTBLKADDR(U, CHAMELEON_Complex64_t, Um, Un) );
+    }
+
+    if ( rankA != A->myrank ) {
+        if ( ( h % ib == 0 ) && ( h > 0 ) && ( A->myrank == A->get_rankof( A, An, An ) ) ) {
+            starpu_mpi_get_data_on_node_detached( options->sequence->comm,
+                                                  RTBLKADDR(U, CHAMELEON_Complex64_t, Um, Un),
+                                                  rankA, NULL, NULL );
+        }
+        return;
+    }
+#endif
 
     /* Handle cache */
     CHAMELEON_BEGIN_ACCESS_DECLARATION;
@@ -325,6 +359,11 @@ INSERT_TASK_zgetrf_panel_blocked_batched_flush( const RUNTIME_option_t *options,
     void (*callback)(void*) = NULL;
     struct cl_getrf_batched_args_t *clargs = *clargs_ptr;
     int rankA = A->myrank;
+#if !defined ( HAVE_STARPU_NONE_NONZERO )
+    /* STARPU_NONE can't be equal to 0 */
+    fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" );
+    assert( 0 );
+#endif
 
     if ( clargs == NULL ) {
         return;
diff --git a/runtime/starpu/codelets/codelet_zgetrf_blocked.c b/runtime/starpu/codelets/codelet_zgetrf_blocked.c
index d11d27365..fff1d0723 100644
--- a/runtime/starpu/codelets/codelet_zgetrf_blocked.c
+++ b/runtime/starpu/codelets/codelet_zgetrf_blocked.c
@@ -98,6 +98,21 @@ void INSERT_TASK_zgetrf_blocked_diag( const RUNTIME_option_t *options,
     void (*callback)(void*) = options->profiling ? cl_zgetrf_blocked_diag_callback : NULL;
     const char *cl_name = "zgetrf_blocked_diag";
     int rankA           = A->get_rankof(A, Am, An);
+#if !defined ( HAVE_STARPU_NONE_NONZERO )
+    /* STARPU_NONE can't be equal to 0 */
+    fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" );
+    assert( 0 );
+#endif
+
+#if defined ( CHAMELEON_USE_MPI )
+    if ( ( h % ib == 0 ) && ( h > 0 ) ) {
+        starpu_mpi_cache_flush( options->sequence->comm, RTBLKADDR(U, CHAMELEON_Complex64_t, Um, Un) );
+    }
+
+    if ( rankA != A->myrank ) {
+        return;
+    }
+#endif
 
     int access_ipiv = ( h == 0 )       ? STARPU_W    : STARPU_RW;
     int access_npiv = ( h == ipiv->n ) ? STARPU_R    : STARPU_REDUX;
@@ -111,7 +126,7 @@ void INSERT_TASK_zgetrf_blocked_diag( const RUNTIME_option_t *options,
     else if ( h%ib == 0 ) {
         accessU = STARPU_R;
     }
-    else if ( h%ib == 1 ) {
+    else if ( ( h%ib == 1 ) || ( ib == 1 ) ) {
         accessU = STARPU_W;
     }
 
@@ -213,6 +228,24 @@ void INSERT_TASK_zgetrf_blocked_offdiag( const RUNTIME_option_t *options,
     int access_ppiv = ( h == 0 )       ? STARPU_NONE : STARPU_R;
     int accessU     = ((h%ib == 0) && (h > 0)) ? STARPU_R : STARPU_NONE;
     int rankA       = A->get_rankof(A, Am, An);
+#if !defined ( HAVE_STARPU_NONE_NONZERO )
+    /* STARPU_NONE can't be equal to 0 */
+    fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" );
+    assert( 0 );
+#endif
+
+#if defined ( CHAMELEON_USE_MPI )
+    if ( rankA != A->myrank ) {
+        if ( ( accessU != STARPU_NONE ) &&
+             ( A->myrank == A->get_rankof( A, An, An ) ) )
+        {
+            starpu_mpi_get_data_on_node_detached( options->sequence->comm,
+                                                  RTBLKADDR(U, CHAMELEON_Complex64_t, Um, Un),
+                                                  rankA, NULL, NULL );
+        }
+        return;
+    }
+#endif
 
     void (*callback)(void*) = options->profiling ? cl_zgetrf_blocked_offdiag_callback : NULL;
     const char *cl_name = "zgetrf_blocked_offdiag";
@@ -312,6 +345,10 @@ void INSERT_TASK_zgetrf_blocked_trsm( const RUNTIME_option_t *options,
     cl_name = chameleon_codelet_name( cl_name, 1,
                                       U->get_blktile( U, Um, Un ) );
 
+    if ( U->myrank != U->get_rankof(U, Um, Un) ) {
+        return;
+    }
+
     rt_starpu_insert_task(
         codelet,
         STARPU_VALUE,             &m,                   sizeof(int),
diff --git a/runtime/starpu/codelets/codelet_zgetrf_percol.c b/runtime/starpu/codelets/codelet_zgetrf_percol.c
index df2301782..c8ff33aff 100644
--- a/runtime/starpu/codelets/codelet_zgetrf_percol.c
+++ b/runtime/starpu/codelets/codelet_zgetrf_percol.c
@@ -85,6 +85,16 @@ void INSERT_TASK_zgetrf_percol_diag( const RUNTIME_option_t *options,
     struct starpu_codelet *codelet = &cl_zgetrf_percol_diag;
     void (*callback)(void*) = options->profiling ? cl_zgetrf_percol_diag_callback : NULL;
     const char *cl_name = "zgetrf_percol_diag";
+    int rankA           = A->get_rankof(A, Am, An);
+#if !defined ( HAVE_STARPU_NONE_NONZERO )
+    /* STARPU_NONE can't be equal to 0 */
+    fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" );
+    assert( 0 );
+#endif
+
+    if ( rankA != A->myrank ) {
+        return;
+    }
 
     int access_ipiv = ( h == 0 )       ? STARPU_W    : STARPU_RW;
     int access_npiv = ( h == ipiv->n ) ? STARPU_R    : STARPU_REDUX;
@@ -162,6 +172,15 @@ void INSERT_TASK_zgetrf_percol_offdiag( const RUNTIME_option_t *options,
     int access_npiv = ( h == ipiv->n ) ? STARPU_R    : STARPU_REDUX;
     int access_ppiv = ( h == 0 )       ? STARPU_NONE : STARPU_R;
     int rankA       = A->get_rankof(A, Am, An);
+#if !defined ( HAVE_STARPU_NONE_NONZERO )
+    /* STARPU_NONE can't be equal to 0 */
+    fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" );
+    assert( 0 );
+#endif
+
+    if ( rankA != A->myrank ) {
+        return;
+    }
 
     /* Handle cache */
     CHAMELEON_BEGIN_ACCESS_DECLARATION;
diff --git a/runtime/starpu/codelets/codelet_zipiv_allreduce.c b/runtime/starpu/codelets/codelet_zipiv_allreduce.c
index 9856258bb..13a41ceb0 100644
--- a/runtime/starpu/codelets/codelet_zipiv_allreduce.c
+++ b/runtime/starpu/codelets/codelet_zipiv_allreduce.c
@@ -19,7 +19,7 @@
 #include "runtime_codelet_z.h"
 #include <coreblas/cblas_wrapper.h>
 
-#if defined ( CHAMELEON_USE_MPI )
+#if defined(CHAMELEON_USE_MPI)
 struct cl_redux_args_t {
     int h;
     int n;
-- 
GitLab