From 2bc31a0dffc58418453fb025d67fec3eca40f035 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Mon, 22 May 2023 15:25:54 -0400
Subject: [PATCH] compute/gerst: Add a function to restore the initial
 precision of a given matrix

---
 cmake_modules/local_subs.py              |   2 +
 compute/CMakeLists.txt                   |   2 +
 compute/pzgerst.c                        |  58 +++++++++
 compute/zgerst.c                         | 148 +++++++++++++++++++++++
 control/compute_z.h                      |   2 +
 include/chameleon/chameleon_z.h          |   2 +
 include/chameleon/tasks_z.h              |   3 +
 runtime/CMakeLists.txt                   |   1 +
 runtime/starpu/CMakeLists.txt            |   1 -
 runtime/starpu/codelets/codelet_zgerst.c | 118 ++++++++++++++++++
 10 files changed, 336 insertions(+), 1 deletion(-)
 create mode 100644 compute/pzgerst.c
 create mode 100644 compute/zgerst.c
 create mode 100644 runtime/starpu/codelets/codelet_zgerst.c

diff --git a/cmake_modules/local_subs.py b/cmake_modules/local_subs.py
index 5318928a2..a833ae816 100644
--- a/cmake_modules/local_subs.py
+++ b/cmake_modules/local_subs.py
@@ -39,6 +39,7 @@ _extra_blas = [
     ('',                     'slatm1',               'dlatm1',               'slatm1',               'dlatm1'              ),
     ('',                     'sgenm2',               'dgenm2',               'cgenm2',               'zgenm2'              ),
     ('',                     'slag2c_fake',          'dlag2z_fake',          'slag2c',               'dlag2z'              ),
+    ('',                     'slag2d',               'slag2d',               'clag2z',               'clag2z'              ),
     ('',                     'slag2h',               'dlag2h',               'slag2h',               'dlag2h'              ),
     ('',                     'hlag2s',               'hlag2d',               'hlag2s',               'hlag2d'              ),
     ('',                     'slag2h',               'dlag2h',               'clag2x',               'zlag2x'              ),
@@ -49,6 +50,7 @@ _extra_blas = [
     ('',                     'sgersum',              'dgersum',              'cgersum',              'zgersum'             ),
     ('',                     'sprint',               'dprint',               'cprint',               'zprint'              ),
     ('',                     'sgered',               'dgered',               'cgered',               'zgered'              ),
+    ('',                     'sgerst',               'dgerst',               'cgerst',               'zgerst'              ),
 ]
 
 _extra_BLAS = [ [ x.upper() for x in row ] for row in _extra_blas ]
diff --git a/compute/CMakeLists.txt b/compute/CMakeLists.txt
index b9dc95f8b..1c7b5a0b2 100644
--- a/compute/CMakeLists.txt
+++ b/compute/CMakeLists.txt
@@ -193,8 +193,10 @@ set(ZSRC
     ##################
     pzlag2c.c
     pzgered.c
+    pzgerst.c
     ###
     zgered.c
+    zgerst.c
     #zcgels.c
     #zcgesv.c
     #zcposv.c
diff --git a/compute/pzgerst.c b/compute/pzgerst.c
new file mode 100644
index 000000000..86d01e168
--- /dev/null
+++ b/compute/pzgerst.c
@@ -0,0 +1,58 @@
+/**
+ *
+ * @file pzgerst.c
+ *
+ * @copyright 2012-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                      Univ. Bordeaux. All rights reserved.
+ *
+ ***
+ *
+ * @brief Chameleon zgerst parallel algorithm
+ *
+ * @version 1.3.0
+ * @author Mathieu Faverge
+ * @date 2023-07-06
+ * @precisions normal z -> d
+ *
+ */
+#include "control/common.h"
+
+#define A(m,n) A,  m,  n
+#define B(m,n) B,  m,  n
+
+void chameleon_pzgerst( cham_uplo_t         uplo,
+                        CHAM_desc_t        *A,
+                        RUNTIME_sequence_t *sequence,
+                        RUNTIME_request_t  *request )
+{
+    CHAM_context_t *chamctxt;
+    RUNTIME_option_t options;
+    int m, n;
+
+    chamctxt = chameleon_context_self();
+    if (sequence->status != CHAMELEON_SUCCESS) {
+        return;
+    }
+    RUNTIME_options_init(&options, chamctxt, sequence, request);
+
+    for(m = 0; m < A->mt; m++) {
+        int tempmm = ( m == (A->mt-1) ) ? A->m - m * A->mb : A->mb;
+        int nmin   = ( uplo == ChamUpper ) ? m                         : 0;
+        int nmax   = ( uplo == ChamLower ) ? chameleon_min(m+1, A->nt) : A->nt;
+
+        for(n = nmin; n < nmax; n++) {
+            CHAM_tile_t *tile = A->get_blktile( A, m, n );
+
+            if (( tile->rank == A->myrank ) &&
+                ( tile->flttype != ChamComplexDouble ) )
+            {
+                int tempnn = ( n == (A->nt-1) ) ? A->n - n * A->nb : A->nb;
+
+                INSERT_TASK_zgerst( &options,
+                                     tempmm, tempnn, A( m, n ) );
+            }
+        }
+    }
+
+    RUNTIME_options_finalize(&options, chamctxt);
+}
diff --git a/compute/zgerst.c b/compute/zgerst.c
new file mode 100644
index 000000000..8d283a6f8
--- /dev/null
+++ b/compute/zgerst.c
@@ -0,0 +1,148 @@
+/**
+ *
+ * @file zgerst.c
+ *
+ * @copyright 2012-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                      Univ. Bordeaux. All rights reserved.
+ *
+ ***
+ *
+ * @brief Chameleon auxiliary routines to restore the original precision of a matrix.
+ *
+ * @version 1.3.0
+ * @author Mathieu Faverge
+ * @author Yuxi Hong
+ * @date 2023-07-06
+ * @precisions normal z -> d
+ *
+ */
+#include "control/common.h"
+
+/**
+ ********************************************************************************
+ *
+ * @ingroup CHAMELEON_Complex64_t_Tile
+ *
+ * @brief Restore the original precision of a given matrix that may have been
+ * used in reduced precision during some computations. See
+ * CHAMELEON_zgered_Tile() to introduce mixed-precision tiles into the matrix.
+ *
+ *******************************************************************************
+ *
+ * @param[in] uplo
+ *          Specifies the shape of the matrix A:
+ *          = ChamUpper: A is upper triangular;
+ *          = ChamLower: A is lower triangular;
+ *          = ChamUpperLower: A is general.
+ *
+ * @param[in] A
+ *          Descriptor of the CHAMELEON matrix to restore.
+ *
+ *******************************************************************************
+ *
+ * @retval CHAMELEON_SUCCESS successful exit
+ *
+ *******************************************************************************
+ *
+ * @sa CHAMELEON_zgered_Tile
+ * @sa CHAMELEON_zgered_Tile_Async
+ * @sa CHAMELEON_zgerst_Tile
+ * @sa CHAMELEON_cgerst_Tile
+ * @sa CHAMELEON_dgerst_Tile
+ * @sa CHAMELEON_sgerst_Tile
+ *
+ */
+int CHAMELEON_zgerst_Tile( cham_uplo_t  uplo,
+                            CHAM_desc_t *A )
+{
+    CHAM_context_t *chamctxt;
+    RUNTIME_sequence_t *sequence = NULL;
+    RUNTIME_request_t request = RUNTIME_REQUEST_INITIALIZER;
+    int status;
+
+    chamctxt = chameleon_context_self();
+    if (chamctxt == NULL) {
+        chameleon_fatal_error("CHAMELEON_zgerst_Tile", "CHAMELEON not initialized");
+        return CHAMELEON_ERR_NOT_INITIALIZED;
+    }
+    chameleon_sequence_create( chamctxt, &sequence );
+
+    CHAMELEON_zgerst_Tile_Async( uplo, A, sequence, &request );
+
+    CHAMELEON_Desc_Flush( A, sequence );
+
+    chameleon_sequence_wait( chamctxt, sequence );
+    status = sequence->status;
+    chameleon_sequence_destroy( chamctxt, sequence );
+    return status;
+}
+
+/**
+ ********************************************************************************
+ *
+ * @ingroup CHAMELEON_Complex64_t_Tile_Async
+ *
+ * @brief Restore the original precision of a given matrix that may have been
+ * used in reduced precision during some computations. See
+ * CHAMELEON_zgered_Tile() to introduce mixed-precision tiles into the matrix.
+ *
+ * This is the non-blocking equivalent of CHAMELEON_zgerst_Tile(). It
+ * operates on matrices stored by tiles with tiles of potentially different
+ * precisions.  All matrices are passed through descriptors.  All dimensions are
+ * taken from the descriptors. It may return before the computation is
+ * finished. This function allows for pipelining operations at runtime.
+ *
+ *******************************************************************************
+ *
+ * @param[in] sequence
+ *          Identifies the sequence of function calls that this call belongs to
+ *          (for completion checks and exception handling purposes).
+ *
+ * @param[out] request
+ *          Identifies this function call (for exception handling purposes).
+ *
+ *******************************************************************************
+ *
+ * @sa CHAMELEON_zgerst_Tile
+ * @sa CHAMELEON_zgered_Tile
+ * @sa CHAMELEON_zgered_Tile_Async
+ *
+ */
+int CHAMELEON_zgerst_Tile_Async( cham_uplo_t         uplo,
+                                  CHAM_desc_t        *A,
+                                  RUNTIME_sequence_t *sequence,
+                                  RUNTIME_request_t  *request )
+{
+    CHAM_context_t *chamctxt;
+
+    chamctxt = chameleon_context_self();
+    if (chamctxt == NULL) {
+        chameleon_fatal_error("CHAMELEON_zgerst_Tile_Async", "CHAMELEON not initialized");
+        return CHAMELEON_ERR_NOT_INITIALIZED;
+    }
+    if (sequence == NULL) {
+        chameleon_fatal_error("CHAMELEON_zgerst_Tile_Async", "NULL sequence");
+        return CHAMELEON_ERR_UNALLOCATED;
+    }
+    if (request == NULL) {
+        chameleon_fatal_error("CHAMELEON_zgerst_Tile_Async", "NULL request");
+        return CHAMELEON_ERR_UNALLOCATED;
+    }
+    /* Check sequence status */
+    if (sequence->status == CHAMELEON_SUCCESS) {
+        request->status = CHAMELEON_SUCCESS;
+    }
+    else {
+        return chameleon_request_fail(sequence, request, CHAMELEON_ERR_SEQUENCE_FLUSHED);
+    }
+
+    /* Check descriptors for correctness */
+    if (chameleon_desc_check(A) != CHAMELEON_SUCCESS) {
+        chameleon_error("CHAMELEON_zgerst_Tile_Async", "invalid descriptor");
+        return chameleon_request_fail(sequence, request, CHAMELEON_ERR_ILLEGAL_VALUE);
+    }
+
+    chameleon_pzgerst( uplo, A, sequence, request );
+
+    return CHAMELEON_SUCCESS;
+}
diff --git a/control/compute_z.h b/control/compute_z.h
index 5bcb8e0cd..06eae17b5 100644
--- a/control/compute_z.h
+++ b/control/compute_z.h
@@ -79,6 +79,8 @@ int chameleon_zshift(CHAM_context_t *chamctxt, int m, int n, CHAMELEON_Complex64
 #if defined(PRECISION_z) || defined(PRECISION_d)
 void chameleon_pzgered( cham_uplo_t uplo, double prec, CHAM_desc_t *A,
                         RUNTIME_sequence_t *sequence, RUNTIME_request_t *request );
+void chameleon_pzgerst( cham_uplo_t uplo, CHAM_desc_t *A,
+                        RUNTIME_sequence_t *sequence, RUNTIME_request_t *request );
 #endif
 int chameleon_pzgebrd( int genD, cham_job_t jobu, cham_job_t jobvt,
                        CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *D,
diff --git a/include/chameleon/chameleon_z.h b/include/chameleon/chameleon_z.h
index c60da7dae..fa5f069e6 100644
--- a/include/chameleon/chameleon_z.h
+++ b/include/chameleon/chameleon_z.h
@@ -170,6 +170,7 @@ int CHAMELEON_zposv_Tile(cham_uplo_t uplo, CHAM_desc_t *A, CHAM_desc_t *B);
 int CHAMELEON_zpotrf_Tile(cham_uplo_t uplo, CHAM_desc_t *A);
 #if defined(PRECISION_z) || defined(PRECISION_d)
 int CHAMELEON_zgered_Tile( cham_uplo_t uplo, double prec, CHAM_desc_t *A );
+int CHAMELEON_zgerst_Tile( cham_uplo_t uplo, CHAM_desc_t *A );
 #endif
 int CHAMELEON_zsytrf_Tile(cham_uplo_t uplo, CHAM_desc_t *A);
 int CHAMELEON_zpotri_Tile(cham_uplo_t uplo, CHAM_desc_t *A);
@@ -250,6 +251,7 @@ int CHAMELEON_zposv_Tile_Async(cham_uplo_t uplo, CHAM_desc_t *A, CHAM_desc_t *B,
 int CHAMELEON_zpotrf_Tile_Async(cham_uplo_t uplo, CHAM_desc_t *A, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
 #if defined(PRECISION_z) || defined(PRECISION_d)
 int CHAMELEON_zgered_Tile_Async(cham_uplo_t uplo, double prec, CHAM_desc_t *A, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
+int CHAMELEON_zgerst_Tile_Async( cham_uplo_t uplo, CHAM_desc_t *A, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request );
 #endif
 int CHAMELEON_zsytrf_Tile_Async(cham_uplo_t uplo, CHAM_desc_t *A, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
 int CHAMELEON_zpotri_Tile_Async(cham_uplo_t uplo, CHAM_desc_t *A, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
diff --git a/include/chameleon/tasks_z.h b/include/chameleon/tasks_z.h
index 794d1fe61..b58895aa4 100644
--- a/include/chameleon/tasks_z.h
+++ b/include/chameleon/tasks_z.h
@@ -81,6 +81,9 @@ void INSERT_TASK_zgeqrt( const RUNTIME_option_t *options,
 void INSERT_TASK_zgered( const RUNTIME_option_t *options,
                          double threshold, double Anorm, int m, int n,
                          const CHAM_desc_t *A, int Am, int An );
+void INSERT_TASK_zgerst( const RUNTIME_option_t *options,
+                         int m, int n,
+                         const CHAM_desc_t *A, int Am, int An );
 void INSERT_TASK_zgessm( const RUNTIME_option_t *options,
                          int m, int n, int k, int ib, int nb,
                          int *IPIV,
diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt
index d5f0e12ca..fc1aac335 100644
--- a/runtime/CMakeLists.txt
+++ b/runtime/CMakeLists.txt
@@ -118,6 +118,7 @@ set(CODELETS_ZSRC
     # Precision modification kernels
     ##################
     codelets/codelet_zgered.c
+    codelets/codelet_zgerst.c
     )
 
 set(CODELETS_SRC
diff --git a/runtime/starpu/CMakeLists.txt b/runtime/starpu/CMakeLists.txt
index abb4f79aa..30ea76045 100644
--- a/runtime/starpu/CMakeLists.txt
+++ b/runtime/starpu/CMakeLists.txt
@@ -243,7 +243,6 @@ set(ZSRC
   codelets/codelet_zcallback.c
   codelets/codelet_zccallback.c
   codelets/codelet_dlag2h.c
-  codelets/codelet_zgerst.c
   ${CODELETS_ZSRC}
   )
 
diff --git a/runtime/starpu/codelets/codelet_zgerst.c b/runtime/starpu/codelets/codelet_zgerst.c
new file mode 100644
index 000000000..68490f011
--- /dev/null
+++ b/runtime/starpu/codelets/codelet_zgerst.c
@@ -0,0 +1,118 @@
+/**
+ *
+ * @file starpu/codelet_zgerst.c
+ *
+ * @copyright 2012-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                      Univ. Bordeaux. All rights reserved.
+ *
+ ***
+ *
+ * @brief Chameleon zgerst StarPU codelet
+ *
+ * @version 1.3.0
+ * @author Mathieu Faverge
+ * @date 2023-07-06
+ * @precisions normal z -> d
+ *
+ */
+#include "chameleon_starpu.h"
+#include <coreblas/lapacke.h>
+#include "runtime_codelet_zc.h"
+#include "runtime_codelet_z.h"
+
+//#define CHAMELEON_DEBUG_GERST
+
+void INSERT_TASK_zgerst( const RUNTIME_option_t *options,
+                         int m, int n,
+                         const CHAM_desc_t *A, int Am, int An )
+{
+    CHAM_tile_t          *tileA;
+    int64_t               mm, nn;
+#if defined(CHAMELEON_USE_MPI)
+    int                   tag;
+#endif
+    starpu_data_handle_t *handleAin;
+    starpu_data_handle_t  handleAout;
+
+    CHAMELEON_BEGIN_ACCESS_DECLARATION;
+    CHAMELEON_ACCESS_RW(A, Am, An);
+    CHAMELEON_END_ACCESS_DECLARATION;
+
+    tileA = A->get_blktile( A, Am, An );
+    if ( tileA->flttype == ChamComplexDouble ) {
+        return;
+    }
+
+    /* Get the Input handle */
+    mm = Am + (A->i / A->mb);
+    nn = An + (A->j / A->nb);
+    handleAin = A->schedopt;
+    handleAin += ((int64_t)A->lmt) * nn + mm;
+
+    assert( *handleAin != NULL );
+
+#if defined(CHAMELEON_USE_MPI)
+    tag = starpu_mpi_data_get_tag( *handleAin );
+#endif /* defined(CHAMELEON_USE_MPI) */
+
+    starpu_cham_tile_register( &handleAout, -1, tileA, ChamComplexDouble );
+
+    switch( tileA->flttype ) {
+#if defined(CHAMELEON_USE_CUDA) && (CUDA_VERSION >= 7500)
+#if defined(PRECISION_d)
+    /*
+     * Restore from half precision
+     */
+    case ChamComplexHalf:
+#if defined(CHAMELEON_DEBUG_GERST)
+        fprintf( stderr,
+                 "[%2d] Convert back the tile ( %d, %d ) from half precision\n",
+                 A->myrank, Am, An );
+#endif
+        rt_starpu_insert_task(
+            &cl_hlag2d,
+            STARPU_VALUE,    &m,                 sizeof(int),
+            STARPU_VALUE,    &n,                 sizeof(int),
+            STARPU_R,        *handleAin,
+            STARPU_W,         handleAout,
+            STARPU_PRIORITY,  options->priority,
+            STARPU_EXECUTE_ON_WORKER, options->workerid,
+#if defined(CHAMELEON_CODELETS_HAVE_NAME)
+            STARPU_NAME, "hlag2d",
+#endif
+            0);
+        break;
+#endif
+#endif
+
+    case ChamComplexFloat:
+#if defined(CHAMELEON_DEBUG_GERST)
+        fprintf( stderr,
+                 "[%2d] Convert back the tile ( %d, %d ) from half precision\n",
+                 A->myrank, Am, An );
+#endif
+        rt_starpu_insert_task(
+            &cl_clag2z,
+            STARPU_VALUE,    &m,                 sizeof(int),
+            STARPU_VALUE,    &n,                 sizeof(int),
+            STARPU_R,        *handleAin,
+            STARPU_W,         handleAout,
+            STARPU_PRIORITY,  options->priority,
+            STARPU_EXECUTE_ON_WORKER, options->workerid,
+#if defined(CHAMELEON_CODELETS_HAVE_NAME)
+            STARPU_NAME, "clag2z",
+#endif
+            0);
+        break;
+
+    default:
+        fprintf( stderr, "ERROR: Unknonw input datatype" );
+    }
+
+    starpu_data_unregister_submit( *handleAin );
+    *handleAin = handleAout;
+    tileA->flttype = ChamComplexDouble;
+#if defined(CHAMELEON_USE_MPI)
+    starpu_mpi_data_register( handleAout, tag, tileA->rank );
+#endif
+}
-- 
GitLab