From 6ae3b3ac63ba5876d22f3d5d31242c738d4eaf0f Mon Sep 17 00:00:00 2001 From: Florent Pruvost <florent.pruvost@inria.fr> Date: Fri, 24 Nov 2023 09:15:34 +0100 Subject: [PATCH] Fix --forcegpu option to be able to force starpu to execute kernels on gpus, RUNTIME_CUDA was not well defined and introduce RUNTIME_HIP. --- include/chameleon/runtime_struct.h | 3 ++- runtime/starpu/control/runtime_zlocality.c | 26 +++++++++++++++++++++- testing/chameleon_ztesting.c | 13 ++++++++++- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/include/chameleon/runtime_struct.h b/include/chameleon/runtime_struct.h index 4b1727825..4d345a0fd 100644 --- a/include/chameleon/runtime_struct.h +++ b/include/chameleon/runtime_struct.h @@ -39,7 +39,8 @@ typedef enum runtime_id_e { * @brief Ids of the worker type */ #define RUNTIME_CPU ((1ULL)<<1) -#define RUNTIME_CUDA ((1ULL)<<3) +#define RUNTIME_CUDA ((1ULL)<<2) +#define RUNTIME_HIP ((1ULL)<<3) /** * @brief RUNTIME request structure diff --git a/runtime/starpu/control/runtime_zlocality.c b/runtime/starpu/control/runtime_zlocality.c index 3ee9881d4..216c00679 100644 --- a/runtime/starpu/control/runtime_zlocality.c +++ b/runtime/starpu/control/runtime_zlocality.c @@ -23,11 +23,32 @@ #include "chameleon_starpu.h" #include "runtime_codelet_z.h" -#ifdef CHAMELEON_USE_CUDA +#if defined(CHAMELEON_USE_CUDA) || defined(CHAMELEON_USE_HIP) + +/* Convert worker id from Chameleon to Starpu */ +static uint32_t cham_to_starpu_where( uint32_t where ) +{ + int32_t starpu_where = 0; + + if ( where & RUNTIME_CPU ) { + starpu_where |= STARPU_CPU; + } + if ( where & RUNTIME_CUDA ) { + starpu_where |= STARPU_CUDA; + } + if ( where & RUNTIME_HIP ) { + starpu_where |= STARPU_HIP; + } + return starpu_where; +} + /* Only codelets with multiple choices are present here */ void RUNTIME_zlocality_allrestrict( uint32_t where ) { + /* Convert worker id from Chameleon to Starpu */ + where = cham_to_starpu_where( where ); + /* Blas 3 */ cl_zgemm_restrict_where( where ); #if defined(PRECISION_z) || defined(PRECISION_c) @@ -73,6 +94,9 @@ void RUNTIME_zlocality_allrestrict( uint32_t where ) void RUNTIME_zlocality_onerestrict( cham_tasktype_t kernel, uint32_t where ) { + /* Convert worker id from Chameleon to Starpu */ + where = cham_to_starpu_where( where ); + switch( kernel ) { /* Blas 3 */ case TASK_GEMM: cl_zgemm_restrict_where( where ); break; diff --git a/testing/chameleon_ztesting.c b/testing/chameleon_ztesting.c index 4d8317ed3..a5ce3ecd6 100644 --- a/testing/chameleon_ztesting.c +++ b/testing/chameleon_ztesting.c @@ -208,7 +208,18 @@ int main (int argc, char **argv) { info = 1; goto end; } - RUNTIME_zlocality_allrestrict( RUNTIME_CUDA ); +#if defined(CHAMELEON_SCHED_STARPU) + int restriction = 0; +#if defined(CHAMELEON_USE_CUDA) + restriction |= RUNTIME_CUDA; +#endif +#if defined(CHAMELEON_USE_HIP) + restriction |= RUNTIME_HIP; +#endif + if ( restriction != 0 ) { + RUNTIME_zlocality_allrestrict( restriction ); + } +#endif } /* Warmup */ -- GitLab