From 981e38a4c5f4f4938bff6eaf282fd513083c2203 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Tue, 6 Sep 2022 10:48:25 +0200
Subject: [PATCH] descriptor: Add a get_rankof_init function to store the
 initial distribution, and exploit the rank stored in the tile

---
 compute/pzgenm2.c          |  2 +-
 compute/zgemm.c            |  2 +-
 compute/zhemm.c            |  2 +-
 compute/zsymm.c            |  2 +-
 control/compute_z.h        |  2 +-
 control/descriptor.c       | 23 +++++++++++++++--------
 control/descriptor.h       |  6 ++++++
 include/chameleon/struct.h |  2 ++
 8 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/compute/pzgenm2.c b/compute/pzgenm2.c
index 26689f874..bb440cc2d 100644
--- a/compute/pzgenm2.c
+++ b/compute/pzgenm2.c
@@ -42,7 +42,7 @@ chameleon_pzgenm2( double tol, const CHAM_desc_t *A, double *result,
     int cnt, maxiter;
     double e0, normx, normsx, beta, scl;
 
-    if ( A->get_rankof != chameleon_getrankof_2d ) {
+    if ( A->get_rankof_init != chameleon_getrankof_2d ) {
         sequence->status = CHAMELEON_ERR_NOT_SUPPORTED;
     }
 
diff --git a/compute/zgemm.c b/compute/zgemm.c
index 9ce47be60..f634ed6be 100644
--- a/compute/zgemm.c
+++ b/compute/zgemm.c
@@ -171,7 +171,7 @@ void *CHAMELEON_zgemm_WS_Alloc( cham_trans_t       transA __attribute__((unused)
 
     /* Now that we have decided which algorithm, let's allocate the required data structures. */
     if ( (options->alg == ChamGemmAlgSummaC ) &&
-         (C->get_rankof == chameleon_getrankof_2d ) )
+         (C->get_rankof_init == chameleon_getrankof_2d ) )
     {
         int lookahead = chamctxt->lookahead;
 
diff --git a/compute/zhemm.c b/compute/zhemm.c
index 23c7fdf98..6f2eb33d6 100644
--- a/compute/zhemm.c
+++ b/compute/zhemm.c
@@ -150,7 +150,7 @@ void *CHAMELEON_zhemm_WS_Alloc( cham_side_t        side __attribute__((unused)),
 
     /* Now that we have decided which algorithm, let's allocate the required data structures. */
     if ( (options->alg == ChamGemmAlgSummaC ) &&
-         (C->get_rankof == chameleon_getrankof_2d ) )
+         (C->get_rankof_init == chameleon_getrankof_2d ) )
     {
         int lookahead = chamctxt->lookahead;
 
diff --git a/compute/zsymm.c b/compute/zsymm.c
index 397c8f65c..cf9f19427 100644
--- a/compute/zsymm.c
+++ b/compute/zsymm.c
@@ -150,7 +150,7 @@ void *CHAMELEON_zsymm_WS_Alloc( cham_side_t        side __attribute__((unused)),
 
     /* Now that we have decided which algorithm, let's allocate the required data structures. */
     if ( (options->alg == ChamGemmAlgSummaC ) &&
-         (C->get_rankof == chameleon_getrankof_2d ) )
+         (C->get_rankof_init == chameleon_getrankof_2d ) )
     {
         int lookahead = chamctxt->lookahead;
 
diff --git a/control/compute_z.h b/control/compute_z.h
index 1ee61f628..b58984838 100644
--- a/control/compute_z.h
+++ b/control/compute_z.h
@@ -242,7 +242,7 @@ chameleon_zdesc_copy_and_restrict( const CHAM_desc_t *descIn,
                               m, n, 0, 0, m, n, descIn->p, descIn->q,
                               descIn->get_blkaddr,
                               descIn->get_blkldd,
-                              descIn->get_rankof );
+                              descIn->get_rankof_init );
     return rc;
 }
 
diff --git a/control/descriptor.c b/control/descriptor.c
index f391f1ac8..7e4e6117b 100644
--- a/control/descriptor.c
+++ b/control/descriptor.c
@@ -88,17 +88,18 @@ int chameleon_desc_mat_free( CHAM_desc_t *desc )
     return CHAMELEON_SUCCESS;
 }
 
-void chameleon_desc_init_tiles( CHAM_desc_t *desc )
+void chameleon_desc_init_tiles( CHAM_desc_t *desc, blkrankof_fct_t rankof )
 {
     CHAM_tile_t *tile;
     int ii, jj;
 
+    assert( rankof != chameleon_getrankof_tile );
     desc->tiles = malloc( desc->lmt * desc->lnt * sizeof(CHAM_tile_t) );
 
     tile = desc->tiles;
     for( jj=0; jj<desc->lnt; jj++ ) {
         for( ii=0; ii<desc->lmt; ii++, tile++ ) {
-            int rank = desc->get_rankof( desc, ii, jj );
+            int rank = rankof( desc, ii, jj );
             tile->format = CHAMELEON_TILE_FULLRANK;
             tile->rank   = rank;
             tile->m      = ii == desc->lmt-1 ? desc->lm - ii * desc->mb : desc->mb;
@@ -216,7 +217,8 @@ int chameleon_desc_init_internal( CHAM_desc_t *desc, const char *name, void *mat
     desc->get_blktile = chameleon_desc_gettile;
     desc->get_blkaddr = get_blkaddr ? get_blkaddr : chameleon_getaddr_ccrb;
     desc->get_blkldd  = get_blkldd  ? get_blkldd  : chameleon_getblkldd_ccrb;
-    desc->get_rankof  = get_rankof  ? get_rankof  : chameleon_getrankof_2d;
+    desc->get_rankof  = chameleon_getrankof_tile;
+    desc->get_rankof_init = get_rankof ? get_rankof : chameleon_getrankof_2d;
 
     /* Matrix properties */
     desc->dtyp = dtyp;
@@ -332,7 +334,7 @@ int chameleon_desc_init_internal( CHAM_desc_t *desc, const char *name, void *mat
     desc->A12 = (size_t)(            desc->llm%mb)*(size_t)(desc->lln - desc->lln%nb) + desc->A21;
     desc->A22 = (size_t)(desc->llm - desc->llm%mb)*(size_t)(            desc->lln%nb) + desc->A12;
 
-    chameleon_desc_init_tiles( desc );
+    chameleon_desc_init_tiles( desc, desc->get_rankof_init );
 
     /* Create runtime specific structure like registering data */
     RUNTIME_desc_create( desc );
@@ -791,7 +793,7 @@ CHAM_desc_t *CHAMELEON_Desc_Copy( const CHAM_desc_t *descin, void *mat )
     CHAMELEON_Desc_Create_User( &descout, mat,
                                 descin->dtyp, descin->mb, descin->nb, descin->bsiz,
                                 descin->lm, descin->ln, descin->i, descin->j, descin->m, descin->n, descin->p, descin->q,
-                                NULL, NULL, descin->get_rankof );
+                                NULL, NULL, descin->get_rankof_init );
     return descout;
 }
 
@@ -826,7 +828,7 @@ CHAM_desc_t *CHAMELEON_Desc_CopyOnZero( const CHAM_desc_t *descin, void *mat )
     CHAMELEON_Desc_Create_User( &descout, mat,
                                 descin->dtyp, descin->mb, descin->nb, descin->bsiz,
                                 descin->lm, descin->ln, descin->i, descin->j, descin->m, descin->n, 1, 1,
-                                NULL, NULL, descin->get_rankof );
+                                NULL, NULL, descin->get_rankof_init );
     return descout;
 }
 
@@ -1004,6 +1006,7 @@ chameleon_desc_print( const CHAM_desc_t *desc, int shift )
             trank    = desc->get_rankof( desc, m, n );
             tile     = desc->get_blktile( desc, m, n );
             tiledesc = tile->mat;
+            assert( trank == tile->rank );
 
             ptr = ( tile->format == CHAMELEON_TILE_DESC ) ? (intptr_t)(tiledesc->mat) : (intptr_t)(tile->mat);
 
@@ -1140,7 +1143,7 @@ int CHAMELEON_Desc_Change_Distribution_Async( cham_uplo_t         uplo,
     }
 
     /* Nothing to do if the new mapping is the same as the original one */
-    if ( ( new_get_rankof == desc->get_rankof ) ||
+    if ( ( new_get_rankof == desc->get_rankof_init ) ||
          ( RUNTIME_comm_size( chamctxt ) == 1 ) )
     {
         return CHAMELEON_SUCCESS;
@@ -1165,12 +1168,16 @@ int CHAMELEON_Desc_Change_Distribution_Async( cham_uplo_t         uplo,
         mmin = ( uplo == ChamLower ) ? chameleon_min( n,   desc->mt ) : 0;
         mmax = ( uplo == ChamUpper ) ? chameleon_min( n+1, desc->mt ) : desc->mt;
         for ( m = mmin; m < mmax; m++ ) {
+            CHAM_tile_t *tile = desc->get_blktile( desc, m, n );
+            int rank = new_get_rankof( desc, m, n );
+
             RUNTIME_data_migrate( sequence, desc, m, n, new_get_rankof( desc, m, n ) );
+            tile->rank = rank;
         }
     }
 
     /* Actually change data location in Chameleon */
-    desc->get_rankof = new_get_rankof;
+    desc->get_rankof_init = new_get_rankof;
 
     return CHAMELEON_SUCCESS;
 }
diff --git a/control/descriptor.h b/control/descriptor.h
index 47f0edbdf..1e0d883f6 100644
--- a/control/descriptor.h
+++ b/control/descriptor.h
@@ -51,6 +51,12 @@ inline static int   chameleon_getblkldd_ccrb(const CHAM_desc_t *A, int m);
 int chameleon_getrankof_2d(const CHAM_desc_t *desc, int m, int n);
 int chameleon_getrankof_2d_diag(const CHAM_desc_t *desc, int m, int n);
 
+static inline int chameleon_getrankof_tile(const CHAM_desc_t *desc, int m, int n) { /* rank accessor with the same (desc, m, n) -> int shape as the other getrankof helpers */
+    CHAM_tile_t *tile = desc->get_blktile( desc, m, n ); /* tile (m, n) fetched through the descriptor's get_blktile callback */
+    assert( tile != NULL ); /* debug-build guard only: assert() is compiled out under NDEBUG */
+    return tile->rank; /* the rank recorded in the tile itself */
+} /* NOTE(review): static inline in a header gives every translation unit its own copy, so pointer comparisons against this function are only meaningful within one .c file - confirm callers */
+
 int chameleon_desc_init_internal( CHAM_desc_t *desc, const char *name, void *mat,
                                   cham_flttype_t dtyp, int mb, int nb,
                                   int lm, int ln, int m, int n, int p, int q,
diff --git a/include/chameleon/struct.h b/include/chameleon/struct.h
index cf751e199..785185550 100644
--- a/include/chameleon/struct.h
+++ b/include/chameleon/struct.h
@@ -80,6 +80,8 @@ struct chameleon_desc_s {
     blkldd_fct_t    get_blkldd;
     // function to get chameleon tiles MPI rank
     blkrankof_fct_t get_rankof;
+    // distribution function (user-supplied, or chameleon_getrankof_2d by default) used to assign the tiles' MPI ranks
+    blkrankof_fct_t get_rankof_init;
     CHAM_tile_t *tiles; // pointer to the array of tiles descriptors
     void *mat;        // pointer to the beginning of the matrix
     size_t A21;       // pointer to the beginning of the matrix A21
-- 
GitLab