diff --git a/include/chameleon/tasks.h b/include/chameleon/tasks.h
index 39e741a66ba4d6b10582acfdaf1fad8aa85f34c7..01234dabf284796110ade07e8d9304debabf480a 100644
--- a/include/chameleon/tasks.h
+++ b/include/chameleon/tasks.h
@@ -100,6 +100,13 @@ void INSERT_TASK_map( const RUNTIME_option_t *options,
                       cham_access_t accessA, cham_uplo_t uplo, const CHAM_desc_t *A, int Am, int An,
                       cham_unary_operator_t op_fct, void *op_args );
 
+void INSERT_TASK_gemm( const RUNTIME_option_t *options,
+                       cham_trans_t transA, cham_trans_t transB,
+                       int m, int n, int k, int nb,
+                       double alpha, const CHAM_desc_t *A, int Am, int An,
+                       const CHAM_desc_t *B, int Bm, int Bn,
+                       double beta, const CHAM_desc_t *C, int Cm, int Cn );
+
 void INSERT_TASK_hgemm( const RUNTIME_option_t *options,
                         cham_trans_t transA, cham_trans_t transB,
                         int m, int n, int k, int nb,
diff --git a/runtime/starpu/CMakeLists.txt b/runtime/starpu/CMakeLists.txt
index c6e600f17bdd10e26f41fbb2592627dbefb66cf1..30ea76045131884c610a0f4ee430393f732f6360 100644
--- a/runtime/starpu/CMakeLists.txt
+++ b/runtime/starpu/CMakeLists.txt
@@ -253,6 +253,7 @@ precisions_rules_py(RUNTIME_SRCS_GENERATED "${ZSRC}"
 set(CODELETS_SRC
   codelets/codelet_convert.c
   codelets/codelet_hgemm.c
+  codelets/codelet_gemm.c
   ${CODELETS_SRC}
   )
 
diff --git a/runtime/starpu/codelets/codelet_gemm.c b/runtime/starpu/codelets/codelet_gemm.c
new file mode 100644
index 0000000000000000000000000000000000000000..69cd67a294d2fcdf116e28326072702527a8d699
--- /dev/null
+++ b/runtime/starpu/codelets/codelet_gemm.c
@@ -0,0 +1,241 @@
+/**
+ *
+ * @file starpu/codelet_gemm.c
+ *
+ * @copyright 2009-2014 The University of Tennessee and The University of
+ *                      Tennessee Research Foundation. All rights reserved.
+ * @copyright 2012-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                      Univ. Bordeaux. All rights reserved.
+ *
+ ***
+ *
+ * @brief Chameleon gemm StarPU codelet
+ *
+ * @version 1.3.0
+ * @author Mathieu Faverge
+ * @date 2023-07-06
+ *
+ */
+#include "chameleon_starpu.h"
+#include "runtime_codelets.h"
+#define PRECISION_z
+#include "runtime_codelet_z.h"
+#undef PRECISION_z
+#define PRECISION_d
+#include "runtime_codelet_d.h"
+#undef PRECISION_d
+#define PRECISION_c
+#include "runtime_codelet_c.h"
+#undef PRECISION_c
+#define PRECISION_s
+#include "runtime_codelet_s.h"
+#undef PRECISION_s
+
+/**
+ * @brief Insert a gemm task whose codelet is selected from the C tile type.
+ *
+ * The codelet, its profiling callback and its argument structure are chosen
+ * from the flttype of the tile C(Cm, Cn). The handles of A, B and C are all
+ * requested through RUNTIME_data_getaddr_withconversion() with that flttype.
+ * If the flttype matches none of the compiled-in cases, a message is printed
+ * on stderr and no task is inserted.
+ *
+ * @param[in] options  Runtime options (priority, profiling, workerid and
+ *                     parallel are read here).
+ * @param[in] transA   Stored in the codelet arguments.
+ * @param[in] transB   Stored in the codelet arguments.
+ * @param[in] m        Stored in the codelet arguments.
+ * @param[in] n        Stored in the codelet arguments.
+ * @param[in] k        Stored in the codelet arguments.
+ * @param[in] nb       Not used by this function.
+ * @param[in] alpha    Received as double, stored in the alpha field of the
+ *                     selected argument structure.
+ * @param[in] A        Descriptor of A; tile (Am, An) is read.
+ * @param[in] Am       First tile coordinate in A.
+ * @param[in] An       Second tile coordinate in A.
+ * @param[in] B        Descriptor of B; tile (Bm, Bn) is read.
+ * @param[in] Bm       First tile coordinate in B.
+ * @param[in] Bn       Second tile coordinate in B.
+ * @param[in] beta     Received as double, stored in the beta field of the
+ *                     selected argument structure; also selects the access
+ *                     mode requested on C.
+ * @param[in] C        Descriptor of C; its tile (Cm, Cn) drives the dispatch.
+ * @param[in] Cm       First tile coordinate in C.
+ * @param[in] Cn       Second tile coordinate in C.
+ */
+void
+INSERT_TASK_gemm( const RUNTIME_option_t *options,
+                  cham_trans_t transA, cham_trans_t transB,
+                  int m, int n, int k, int nb,
+                  double alpha, const CHAM_desc_t *A, int Am, int An,
+                                const CHAM_desc_t *B, int Bm, int Bn,
+                  double beta,  const CHAM_desc_t *C, int Cm, int Cn )
+{
+    struct starpu_codelet *codelet = NULL;
+    void (*callback)(void*) = NULL;
+
+    /* if ( alpha == 0. ) { */
+    /*     INSERT_TASK_zlascal( options, ChamUpperLower, m, n, nb, */
+    /*                          beta, C, Cm, Cn ); */
+    /*     return; */
+    /* } */
+
+    void           *clargs = NULL;
+    int             accessC;
+    int             exec = 0;
+    size_t          argssize = 0;
+    char           *cl_name = "Xgemm";
+    CHAM_tile_t    *tileC;
+    cham_flttype_t  Cflttype;
+
+    /* Handle cache */
+    CHAMELEON_BEGIN_ACCESS_DECLARATION;
+    CHAMELEON_ACCESS_R(A, Am, An);
+    CHAMELEON_ACCESS_R(B, Bm, Bn);
+    CHAMELEON_ACCESS_RW(C, Cm, Cn);
+    exec = __chameleon_need_exec;
+    CHAMELEON_END_ACCESS_DECLARATION;
+
+    /* beta == 0: C is only written; beta == 1: RW flagged STARPU_COMMUTE;
+     * any other beta: plain RW. */
+    accessC = ( beta == 0. ) ? STARPU_W : (STARPU_RW | ((beta == 1.) ? STARPU_COMMUTE : 0));
+
+    tileC    = C->get_blktile( C, Cm, Cn );
+    Cflttype = tileC->flttype;
+
+    switch( Cflttype ) {
+#if defined(CHAMELEON_PREC_Z)
+    case ChamComplexDouble:
+        codelet  = &cl_zgemm;
+        callback = cl_zgemm_callback;
+        if ( exec ) {
+            struct cl_zgemm_args_s *cl_zargs;
+            cl_zargs = malloc( sizeof( struct cl_zgemm_args_s ) );
+            cl_zargs->transA = transA;
+            cl_zargs->transB = transB;
+            cl_zargs->m      = m;
+            cl_zargs->n      = n;
+            cl_zargs->k      = k;
+            cl_zargs->alpha  = alpha;
+            cl_zargs->beta   = beta;
+            clargs   = (void*)cl_zargs;
+            argssize = sizeof( struct cl_zgemm_args_s );
+        }
+        break;
+#endif
+#if defined(CHAMELEON_PREC_C)
+    case ChamComplexSingle:
+        codelet  = &cl_cgemm;
+        callback = cl_cgemm_callback;
+        if ( exec ) {
+            struct cl_cgemm_args_s *cl_cargs;
+            cl_cargs = malloc( sizeof( struct cl_cgemm_args_s ) );
+            cl_cargs->transA = transA;
+            cl_cargs->transB = transB;
+            cl_cargs->m      = m;
+            cl_cargs->n      = n;
+            cl_cargs->k      = k;
+            cl_cargs->alpha  = alpha;
+            cl_cargs->beta   = beta;
+            clargs   = (void*)cl_cargs;
+            argssize = sizeof( struct cl_cgemm_args_s );
+        }
+        break;
+#endif
+#if defined(CHAMELEON_PREC_D)
+    case ChamRealDouble:
+        codelet  = &cl_dgemm;
+        callback = cl_dgemm_callback;
+        if ( exec ) {
+            struct cl_dgemm_args_s *cl_dargs;
+            cl_dargs = malloc( sizeof( struct cl_dgemm_args_s ) );
+            cl_dargs->transA = transA;
+            cl_dargs->transB = transB;
+            cl_dargs->m      = m;
+            cl_dargs->n      = n;
+            cl_dargs->k      = k;
+            cl_dargs->alpha  = alpha;
+            cl_dargs->beta   = beta;
+            clargs   = (void*)cl_dargs;
+            argssize = sizeof( struct cl_dgemm_args_s );
+        }
+        break;
+#endif
+#if defined(CHAMELEON_PREC_S)
+    case ChamRealSingle:
+        codelet  = &cl_sgemm;
+        callback = cl_sgemm_callback;
+        if ( exec ) {
+            struct cl_sgemm_args_s *cl_sargs;
+            cl_sargs = malloc( sizeof( struct cl_sgemm_args_s ) );
+            cl_sargs->transA = transA;
+            cl_sargs->transB = transB;
+            cl_sargs->m      = m;
+            cl_sargs->n      = n;
+            cl_sargs->k      = k;
+            cl_sargs->alpha  = alpha;
+            cl_sargs->beta   = beta;
+            clargs   = (void*)cl_sargs;
+            argssize = sizeof( struct cl_sgemm_args_s );
+        }
+        break;
+#endif
+#if (defined(CHAMELEON_PREC_D) || defined(CHAMELEON_PREC_S)) && defined(CHAMELEON_USE_CUDA)
+    case ChamRealHalf:
+        codelet  = &cl_hgemm;
+        callback = cl_hgemm_callback;
+        if ( exec ) {
+            struct cl_hgemm_args_s *cl_hargs;
+            cl_hargs = malloc( sizeof( struct cl_hgemm_args_s ) );
+            cl_hargs->transA = transA;
+            cl_hargs->transB = transB;
+            cl_hargs->m      = m;
+            cl_hargs->n      = n;
+            cl_hargs->k      = k;
+            cl_hargs->alpha  = alpha;
+            cl_hargs->beta   = beta;
+            clargs   = (void*)cl_hargs;
+            argssize = sizeof( struct cl_hgemm_args_s );
+        }
+        break;
+#endif
+    default:
+        fprintf( stderr, "INSERT_TASK_gemm: Unknown datatype %d (Mixed=%3s, Type=%d, Size=%d)\n",
+                 Cflttype, cham_is_mixed(Cflttype) ? "Yes" : "No",
+                 cham_get_ftype(Cflttype), cham_get_arith(Cflttype) );
+        return;
+    }
+
+    /* Refine name */
+    cl_name = chameleon_codelet_name( cl_name, 3,
+                                      A->get_blktile( A, Am, An ),
+                                      B->get_blktile( B, Bm, Bn ),
+                                      C->get_blktile( C, Cm, Cn ) );
+
+    /* Callback for profiling information */
+    callback = options->profiling ? callback : NULL;
+
+    /* Insert the task */
+    rt_starpu_insert_task(
+        codelet,
+        /* Task codelet arguments */
+        STARPU_CL_ARGS, clargs, argssize,
+
+        /* Task handles */
+        STARPU_R, RUNTIME_data_getaddr_withconversion( options, STARPU_R, Cflttype, A, Am, An ),
+        STARPU_R, RUNTIME_data_getaddr_withconversion( options, STARPU_R, Cflttype, B, Bm, Bn ),
+        accessC,  RUNTIME_data_getaddr_withconversion( options, accessC,  Cflttype, C, Cm, Cn ),
+
+        /* Common task arguments */
+        STARPU_PRIORITY,          options->priority,
+        STARPU_CALLBACK,          callback,
+        STARPU_EXECUTE_ON_WORKER, options->workerid,
+        STARPU_POSSIBLY_PARALLEL, options->parallel,
+#if defined(CHAMELEON_CODELETS_HAVE_NAME)
+        STARPU_NAME,              cl_name,
+#endif
+        0 );
+
+    (void)nb;
+    return;
+}
diff --git a/runtime/starpu/codelets/codelet_zgemm.c b/runtime/starpu/codelets/codelet_zgemm.c
index 5ac2f76cf17d9c0c1c8ed0a1d686b9252388f824..4b47627258ba618751e82a33945f3fbeed27c8bc 100644
--- a/runtime/starpu/codelets/codelet_zgemm.c
+++ b/runtime/starpu/codelets/codelet_zgemm.c
@@ -31,16 +31,6 @@
 #include "chameleon_starpu.h"
 #include "runtime_codelet_z.h"
 
-struct cl_zgemm_args_s {
-    cham_trans_t transA;
-    cham_trans_t transB;
-    int m;
-    int n;
-    int k;
-    CHAMELEON_Complex64_t alpha;
-    CHAMELEON_Complex64_t beta;
-};
-
 #if !defined(CHAMELEON_SIMULATION)
 static void
 cl_zgemm_cpu_func( void *descr[], void *cl_arg )
diff --git a/runtime/starpu/include/runtime_codelet_z.h b/runtime/starpu/include/runtime_codelet_z.h
index 03f2dee935938ebf0202b09bcfe53442950f70e6..f3f89af41aab00bc870fa663f2d8aff4e0dea39b 100644
--- a/runtime/starpu/include/runtime_codelet_z.h
+++ b/runtime/starpu/include/runtime_codelet_z.h
@@ -139,4 +139,14 @@ CODELETS_HEADER(dlag2h);
 CODELETS_HEADER(hlag2d);
 #endif
 
+struct cl_zgemm_args_s {
+    cham_trans_t transA;
+    cham_trans_t transB;
+    int m;
+    int n;
+    int k;
+    CHAMELEON_Complex64_t alpha;
+    CHAMELEON_Complex64_t beta;
+};
+
 #endif /* _runtime_codelet_z_h_ */