From f0dfcdb7eb91a1baab8e2d2c4c792ea73dba1037 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Sat, 27 May 2023 11:40:38 -0400
Subject: [PATCH] starpu/codelets/gemmex: add gemmex codelet

---
 include/chameleon/tasks.h                 |   7 +
 runtime/starpu/codelets/codelet_gemmex.c  | 278 ++++++++++++++++++++++
 runtime/starpu/include/runtime_codelets.h |   1 +
 3 files changed, 286 insertions(+)
 create mode 100644 runtime/starpu/codelets/codelet_gemmex.c

diff --git a/include/chameleon/tasks.h b/include/chameleon/tasks.h
index 01234dabf..bc7a59e6f 100644
--- a/include/chameleon/tasks.h
+++ b/include/chameleon/tasks.h
@@ -107,6 +107,13 @@ void INSERT_TASK_gemm( const RUNTIME_option_t *options,
                        const CHAM_desc_t *B, int Bm, int Bn,
                        double beta, const CHAM_desc_t *C, int Cm, int Cn );
 
+void INSERT_TASK_gemmex( const RUNTIME_option_t *options,
+                         cham_trans_t transA, cham_trans_t transB,
+                         int m, int n, int k, int nb,
+                         double alpha, const CHAM_desc_t *A, int Am, int An,
+                         const CHAM_desc_t *B, int Bm, int Bn,
+                         double beta, const CHAM_desc_t *C, int Cm, int Cn );
+
 void INSERT_TASK_hgemm( const RUNTIME_option_t *options,
                         cham_trans_t transA, cham_trans_t transB,
                         int m, int n, int k, int nb,
diff --git a/runtime/starpu/codelets/codelet_gemmex.c b/runtime/starpu/codelets/codelet_gemmex.c
new file mode 100644
index 000000000..efc33dab3
--- /dev/null
+++ b/runtime/starpu/codelets/codelet_gemmex.c
@@ -0,0 +1,278 @@
+/**
+ *
+ * @file starpu/codelet_gemmex.c
+ *
+ * @copyright 2009-2014 The University of Tennessee and The University of
+ *                      Tennessee Research Foundation. All rights reserved.
+ * @copyright 2012-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                      Univ. Bordeaux. All rights reserved.
+ *
+ ***
+ *
+ * @brief Chameleon gemmex StarPU codelet
+ *
+ * @version 1.3.0
+ * @author Mathieu Faverge
+ * @date 2023-07-06
+ *
+ */
+#include "chameleon_starpu.h"
+#include "runtime_codelets.h"
+
+CHAMELEON_CL_CB( gemmex, cti_handle_get_m(task->handles[2]), cti_handle_get_n(task->handles[2]), cti_handle_get_n(task->handles[0]), 2. *M*N*K) /* If A^t, computation is wrong */
+
+struct cl_gemmex_args_s {
+    cham_trans_t transA;
+    cham_trans_t transB;
+    int m;
+    int n;
+    int k;
+    double alpha;
+    double beta;
+};
+
+#if !defined(CHAMELEON_SIMULATION)
+#if defined(CHAMELEON_USE_CUDA)
+static void
+cl_gemmex_cuda_func( void *descr[], void *cl_arg )
+{
+    struct cl_gemmex_args_s *clargs = (struct cl_gemmex_args_s *)cl_arg;
+    cublasHandle_t          handle = starpu_cublas_get_local_handle();
+    CHAM_tile_t *tileA;
+    CHAM_tile_t *tileB;
+    CHAM_tile_t *tileC;
+    void *ptrAlpha, *ptrBeta;
+
+    switch( tileC->flttype ) {
+    case ChamRealHalf:
+    {
+        CHAMELEON_Real16_t halpha = clargs->alpha;
+        CHAMELEON_Real16_t hbeta  = clargs->beta;
+        ptrAlpha = &halpha;
+        ptrBeta  = &hbeta;
+    }
+    break;
+    case ChamRealFloat:
+    {
+        float salpha = clargs->alpha;
+        float sbeta  = clargs->beta;
+        ptrAlpha = &salpha;
+        ptrBeta  = &sbeta;
+    }
+    break;
+    case ChamRealDouble:
+    {
+        double dalpha = clargs->alpha;
+        double dbeta  = clargs->beta;
+        ptrAlpha = &dalpha;
+        ptrBeta  = &dbeta;
+    }
+    break;
+    case ChamComplexFloat:
+    {
+        CHAMELEON_Complex32_t calpha = clargs->alpha;
+        CHAMELEON_Complex32_t cbeta  = clargs->beta;
+        ptrAlpha = &calpha;
+        ptrBeta  = &cbeta;
+    }
+    break;
+    case ChamComplexDouble:
+    {
+        CHAMELEON_Complex64_t zalpha = clargs->alpha;
+        CHAMELEON_Complex64_t zbeta  = clargs->beta;
+        ptrAlpha = &zalpha;
+        ptrBeta  = &zbeta;
+    }
+    break;
+    default:
+        fprintf( stderr, "cl_gemmex: Unknown C datatype\n" );
+        return;
+    }
+
+    tileA = cti_interface_get(descr[0]);
+    tileB = cti_interface_get(descr[1]);
+    tileC = cti_interface_get(descr[2]);
+
+    assert( tileA->format & CHAMELEON_TILE_FULLRANK );
+    assert( tileB->format & CHAMELEON_TILE_FULLRANK );
+    assert( tileC->format & CHAMELEON_TILE_FULLRANK );
+
+    CUDA_gemmex(
+        clargs->transA, clargs->transB,
+        clargs->m, clargs->n, clargs->k,
+        ptrAlpha,
+        tileA->mat, tileA->ld, tileA->flttype,
+        tileB->mat, tileB->ld, tileB->flttype,
+        ptrBeta,
+        tileC->mat, tileC->ld, tileC->flttype,
+        handle );
+}
+#endif /* defined(CHAMELEON_USE_CUDA) */
+#endif /* !defined(CHAMELEON_SIMULATION) */
+
+/*
+ * Codelet definition
+ */
+CODELETS( gemmex, NULL, cl_gemmex_cuda_func, STARPU_CUDA_ASYNC )
+
+void INSERT_TASK_gemmex_Astat( const RUNTIME_option_t *options,
+                               cham_trans_t transA, cham_trans_t transB,
+                               int m, int n, int k, int nb,
+                               double alpha, const CHAM_desc_t *A, int Am, int An,
+                                             const CHAM_desc_t *B, int Bm, int Bn,
+                               double beta,  const CHAM_desc_t *C, int Cm, int Cn )
+{
+    /* if ( alpha == 0. ) { */
+    /*     INSERT_TASK_hlascal( options, ChamUpperLower, m, n, nb, */
+    /*                          beta, C, Cm, Cn ); */
+    /*     return; */
+    /* } */
+
+    struct cl_gemmex_args_s  *clargs = NULL;
+    void (*callback)(void*);
+    int                      accessC;
+    int                      exec    = 0;
+    char                    *cl_name = "gemmex_Astat";
+
+    /* Handle cache */
+    CHAMELEON_BEGIN_ACCESS_DECLARATION;
+     /* Check A as write, since it will be the owner of the computation */
+    CHAMELEON_ACCESS_W(A, Am, An);
+    CHAMELEON_ACCESS_R(B, Bm, Bn);
+     /* Check C as read, since it will be used in a reduction */
+    CHAMELEON_ACCESS_R(C, Cm, Cn);
+    exec = __chameleon_need_exec;
+    CHAMELEON_END_ACCESS_DECLARATION;
+
+    if ( exec ) {
+        clargs = malloc( sizeof( struct cl_gemmex_args_s ) );
+        clargs->transA = transA;
+        clargs->transB = transB;
+        clargs->m      = m;
+        clargs->n      = n;
+        clargs->k      = k;
+        clargs->alpha  = alpha;
+        clargs->beta   = beta;
+    }
+
+    /* Callback for profiling information */
+    callback = options->profiling ? cl_gemmex_callback : NULL;
+
+    /* Reduce the C access if needed */
+    if ( beta == 0. ) {
+        accessC = STARPU_W;
+    }
+#if defined(HAVE_STARPU_MPI_REDUX)
+    else if ( beta == 1. ) {
+        accessC = STARPU_MPI_REDUX;
+    }
+#endif
+    else {
+        accessC = STARPU_RW;
+    }
+
+    /* Refine name */
+    cl_name = chameleon_codelet_name( cl_name, 3,
+                                      A->get_blktile( A, Am, An ),
+                                      B->get_blktile( B, Bm, Bn ),
+                                      C->get_blktile( C, Cm, Cn ) );
+
+    /* Insert the task */
+    rt_starpu_insert_task(
+        &cl_gemmex,
+        /* Task codelet arguments */
+        STARPU_CL_ARGS, clargs, sizeof(struct cl_gemmex_args_s),
+
+        /* Task handles */
+        STARPU_R, RTBLKADDR(A, ChamRealHalf, Am, An),
+        STARPU_R, RTBLKADDR(B, ChamRealHalf, Bm, Bn),
+        accessC,  RTBLKADDR(C, ChamRealHalf, Cm, Cn),
+
+        /* Common task arguments */
+        STARPU_PRIORITY,          options->priority,
+        STARPU_CALLBACK,          callback,
+        STARPU_EXECUTE_ON_NODE,   A->get_rankof(A, Am, An),
+#if defined(CHAMELEON_CODELETS_HAVE_NAME)
+        STARPU_NAME,              cl_name,
+#endif
+        0 );
+}
+
+void INSERT_TASK_gemmex( const RUNTIME_option_t *options,
+                         cham_trans_t transA, cham_trans_t transB,
+                         int m, int n, int k, int nb,
+                         double alpha, const CHAM_desc_t *A, int Am, int An,
+                                       const CHAM_desc_t *B, int Bm, int Bn,
+                         double beta,  const CHAM_desc_t *C, int Cm, int Cn )
+{
+    /* if ( alpha == 0. ) { */
+    /*     INSERT_TASK_zlascal( options, ChamUpperLower, m, n, nb, */
+    /*                          beta, C, Cm, Cn ); */
+    /*     return; */
+    /* } */
+
+    struct cl_gemmex_args_s  *clargs = NULL;
+    void (*callback)(void*);
+    int                      accessC;
+    int                      exec = 0;
+    char                    *cl_name = "gemmex";
+
+    if ( !(options->withcuda) ) {
+        /* Fallback to cpu version */
+        INSERT_TASK_gemm( options, transA, transB, m, n, k, nb,
+                          alpha, A, Am, An, B, Bm, Bn, beta, C, Cm, Cn );
+        return;
+    }
+
+    /* Handle cache */
+    CHAMELEON_BEGIN_ACCESS_DECLARATION;
+    CHAMELEON_ACCESS_R(A, Am, An);
+    CHAMELEON_ACCESS_R(B, Bm, Bn);
+    CHAMELEON_ACCESS_RW(C, Cm, Cn);
+    exec = __chameleon_need_exec;
+    CHAMELEON_END_ACCESS_DECLARATION;
+
+    if ( exec ) {
+        clargs = malloc( sizeof( struct cl_gemmex_args_s ) );
+        clargs->transA = transA;
+        clargs->transB = transB;
+        clargs->m      = m;
+        clargs->n      = n;
+        clargs->k      = k;
+        clargs->alpha  = alpha;
+        clargs->beta   = beta;
+    }
+
+    /* Callback for profiling information */
+    callback = options->profiling ? cl_gemmex_callback : NULL;
+
+    /* Reduce the C access if needed */
+    accessC = ( beta == 0. ) ? STARPU_W : (STARPU_RW | ((beta == 1.) ? STARPU_COMMUTE : 0));
+
+    /* Refine name */
+    cl_name = chameleon_codelet_name( cl_name, 3,
+                                      A->get_blktile( A, Am, An ),
+                                      B->get_blktile( B, Bm, Bn ),
+                                      C->get_blktile( C, Cm, Cn ) );
+
+    /* Insert the task */
+    rt_starpu_insert_task(
+        &cl_gemmex,
+        /* Task codelet arguments */
+        STARPU_CL_ARGS, clargs, sizeof(struct cl_gemmex_args_s),
+
+        /* Task handles */
+        STARPU_R, RTBLKADDR(A, ChamRealHalf, Am, An),
+        STARPU_R, RTBLKADDR(B, ChamRealHalf, Bm, Bn),
+        accessC,  RTBLKADDR(C, ChamRealHalf, Cm, Cn),
+
+        /* Common task arguments */
+        STARPU_PRIORITY,          options->priority,
+        STARPU_CALLBACK,          callback,
+        STARPU_EXECUTE_ON_WORKER, options->workerid,
+        STARPU_POSSIBLY_PARALLEL, options->parallel,
+#if defined(CHAMELEON_CODELETS_HAVE_NAME)
+        STARPU_NAME,              cl_name,
+#endif
+        0 );
+}
diff --git a/runtime/starpu/include/runtime_codelets.h b/runtime/starpu/include/runtime_codelets.h
index 9c9af1b6b..c27d6b913 100644
--- a/runtime/starpu/include/runtime_codelets.h
+++ b/runtime/starpu/include/runtime_codelets.h
@@ -152,6 +152,7 @@
 
 CODELETS_HEADER(map);
 CODELETS_HEADER(hgemm);
+CODELETS_HEADER(gemmex);
 
 struct cl_hgemm_args_s {
     cham_trans_t transA;
-- 
GitLab