From 0e56a1baee25627a0c87852b31c9a4844aeb2c83 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Tue, 21 Jan 2025 00:06:18 +0100
Subject: [PATCH] descriptor: Add the getblkdim function to prepare the switch
 in tile dimension computations

---
 control/descriptor.c                   |  1 +
 include/chameleon/constants.h          |  8 +++
 include/chameleon/descriptor_helpers.h | 84 ++++++++++++++++++++++++++
 include/chameleon/struct.h             | 12 ++--
 4 files changed, 100 insertions(+), 5 deletions(-)

diff --git a/control/descriptor.c b/control/descriptor.c
index ff2732b2c..4a1f2ef00 100644
--- a/control/descriptor.c
+++ b/control/descriptor.c
@@ -222,6 +222,7 @@ int chameleon_desc_init_internal( CHAM_desc_t *desc, const char *name, void *mat
 
     /* If one of the function get_* is NULL, we switch back to the default */
     desc->get_blktile = chameleon_desc_gettile;
+    desc->get_blkdim  = chameleon_getblkdim;
 
     /* Data addresses */
     if ( get_blkaddr ) {
diff --git a/include/chameleon/constants.h b/include/chameleon/constants.h
index d3cfb9620..09bfc942f 100644
--- a/include/chameleon/constants.h
+++ b/include/chameleon/constants.h
@@ -113,6 +113,14 @@ typedef enum chameleon_flttype_e {
 #define ChamConvertRealHalfToSingle   ChamConvert( ChamRealHalf,   ChamRealSingle )
 #define ChamConvertRealHalfToHalf     ChamConvert( ChamRealHalf,   ChamRealHalf   )
 
+/**
+ * @brief Matrix dimensions naming
+ */
+typedef enum chameleon_dim_e {
+    DIM_m = 0,
+    DIM_n = 1,
+} cham_dim_t;
+
 /**
  * @brief Matrix tile storage
  */
diff --git a/include/chameleon/descriptor_helpers.h b/include/chameleon/descriptor_helpers.h
index 7bfdeb77b..9e60ef27d 100644
--- a/include/chameleon/descriptor_helpers.h
+++ b/include/chameleon/descriptor_helpers.h
@@ -88,8 +88,92 @@ int chameleon_getblkldd_cm  ( const CHAM_desc_t *A, int m );
 int chameleon_getblkldd_ccrb( const CHAM_desc_t *A, int m );
 /**
  * @}
+ * @name Tile dimensions computation in algorithms
+ * @{
  */
 
+/**
+ *
+ * @ingroup Descriptor
+ *
+ * @brief Return tile dimension along the m dimension with regular tile sizes
+ *
+ * @param[in] A
+ *          The chameleon descriptor for which to compute the size
+ *
+ * @param[in] m
+ *          The row index of the tile
+ *
+ * @param[in] lm
+ *          The matrix row dimension against which to compute the size
+ *
+ * @retval The dimension of the tile along the row/first dimension with a limit on lm
+ *
+ */
+static inline int
+chameleon_getblkdim_m( const CHAM_desc_t *A, int m, int lm )
+{
+    return (((m + 1) * A->mb) > lm ) ? lm - m * A->mb : A->mb;
+}
+
+/**
+ *
+ * @ingroup Descriptor
+ *
+ * @brief Return tile dimension along the n dimension with regular tile sizes
+ *
+ * @param[in] A
+ *          The chameleon descriptor for which to compute the size
+ *
+ * @param[in] n
+ *          The column index of the tile
+ *
+ * @param[in] ln
+ *          The matrix column dimension against which to compute the size
+ *
+ * @retval The dimension of the tile along the column/second dimension with a limit on ln
+ *
+ */
+static inline int
+chameleon_getblkdim_n( const CHAM_desc_t *A, int n, int ln )
+{
+    return (((n + 1) * A->nb) > ln ) ? ln - n * A->nb : A->nb;
+}
+
+/**
+ *
+ * @ingroup Descriptor
+ *
+ * @brief Return tile dimension along the dim dimension with regular tile sizes
+ *
+ * @param[in] A
+ *          The chameleon descriptor for which to compute the size
+ *
+ * @param[in] m
+ *          The index of the tile in the given dimension
+ *
+ * @param[in] dim
+ *          The dimension on which to compute the size
+ *
+ * @param[in] lm
+ *          The matrix dimension along the chosen dim.
+ *
+ * @retval The dimension of the tile along the dim dimension with a limit on lm
+ *
+ */
+static inline int
+chameleon_getblkdim( const CHAM_desc_t *A, int m, cham_dim_t dim, int lm )
+{
+    if ( dim == 0 ) {
+        return chameleon_getblkdim_m( A, m, lm );
+    }
+    else {
+        return chameleon_getblkdim_n( A, m, lm );
+    }
+}
+/**
+ * @}
+ */
 #ifdef __cplusplus
 }
 #endif
diff --git a/include/chameleon/struct.h b/include/chameleon/struct.h
index e61f95b18..d36932500 100644
--- a/include/chameleon/struct.h
+++ b/include/chameleon/struct.h
@@ -80,6 +80,7 @@ typedef struct chameleon_desc_s CHAM_desc_t;
 
 typedef void*        (*blkaddr_fct_t)        ( const CHAM_desc_t*, int, int );
 typedef int          (*blkldd_fct_t)         ( const CHAM_desc_t*, int );
+typedef int          (*blkdim_fct_t)         ( const CHAM_desc_t*, int, cham_dim_t, int );
 typedef int          (*blkrankof_fct_t)      ( const CHAM_desc_t*, int, int );
 typedef int          (*datadist_access_fct_t)( const CHAM_desc_t*, int, ... );
 typedef CHAM_tile_t* (*blktile_fct_t)        ( const CHAM_desc_t*, int, int );
@@ -116,11 +117,12 @@ void chameleon_desc_set_datadist( CHAM_desc_t *to, cham_data_dist_t *from );
 
 struct chameleon_desc_s {
     const char *name;
-    blktile_fct_t   get_blktile;     /**> function to get chameleon tiles address           */
-    blkaddr_fct_t   get_blkaddr;     /**> function to get chameleon tiles address           */
-    blkldd_fct_t    get_blkldd;      /**> function to get chameleon tiles leading dimension */
-    blkrankof_fct_t get_rankof;      /**> function to get chameleon tiles MPI rank          */
-    blkrankof_fct_t get_rankof_init; /**> function to get chameleon tiles MPI rank          */
+    blktile_fct_t   get_blktile;     /**> function to get chameleon tiles address                     */
+    blkaddr_fct_t   get_blkaddr;     /**> function to get chameleon tiles address                     */
+    blkldd_fct_t    get_blkldd;      /**> function to get chameleon tiles leading dimension           */
+    blkdim_fct_t    get_blkdim;      /**> function to get chameleon tiles dimension within algorithms */
+    blkrankof_fct_t get_rankof;      /**> function to get chameleon tiles MPI rank                    */
+    blkrankof_fct_t get_rankof_init; /**> function to get chameleon tiles MPI rank                    */
 
     void* get_rankof_init_arg;
     CHAM_tile_t *tiles;  /**> pointer to the array of tiles descriptors  */
-- 
GitLab