diff --git a/include/chameleon/runtime_struct.h b/include/chameleon/runtime_struct.h index 4b172782507b5b43e17fb4f3672b6aebf13a7c9d..4d345a0fdbc2be301a6af7d91a4da28be8bd4def 100644 --- a/include/chameleon/runtime_struct.h +++ b/include/chameleon/runtime_struct.h @@ -39,7 +39,8 @@ typedef enum runtime_id_e { * @brief Ids of the worker type */ #define RUNTIME_CPU ((1ULL)<<1) -#define RUNTIME_CUDA ((1ULL)<<3) +#define RUNTIME_CUDA ((1ULL)<<2) +#define RUNTIME_HIP ((1ULL)<<3) /** * @brief RUNTIME request structure diff --git a/runtime/starpu/control/runtime_zlocality.c b/runtime/starpu/control/runtime_zlocality.c index 3ee9881d40a2b380ac15068c8035bbd89fc8aff3..216c0067929acd7fb3c3847ad6cb4ce3390e6cff 100644 --- a/runtime/starpu/control/runtime_zlocality.c +++ b/runtime/starpu/control/runtime_zlocality.c @@ -23,11 +23,32 @@ #include "chameleon_starpu.h" #include "runtime_codelet_z.h" -#ifdef CHAMELEON_USE_CUDA +#if defined(CHAMELEON_USE_CUDA) || defined(CHAMELEON_USE_HIP) + +/* Convert worker id from Chameleon to Starpu */ +static uint32_t cham_to_starpu_where( uint32_t where ) +{ + int32_t starpu_where = 0; + + if ( where & RUNTIME_CPU ) { + starpu_where |= STARPU_CPU; + } + if ( where & RUNTIME_CUDA ) { + starpu_where |= STARPU_CUDA; + } + if ( where & RUNTIME_HIP ) { + starpu_where |= STARPU_HIP; + } + return starpu_where; +} + /* Only codelets with multiple choices are present here */ void RUNTIME_zlocality_allrestrict( uint32_t where ) { + /* Convert worker id from Chameleon to Starpu */ + where = cham_to_starpu_where( where ); + /* Blas 3 */ cl_zgemm_restrict_where( where ); #if defined(PRECISION_z) || defined(PRECISION_c) @@ -73,6 +94,9 @@ void RUNTIME_zlocality_allrestrict( uint32_t where ) void RUNTIME_zlocality_onerestrict( cham_tasktype_t kernel, uint32_t where ) { + /* Convert worker id from Chameleon to Starpu */ + where = cham_to_starpu_where( where ); + switch( kernel ) { /* Blas 3 */ case TASK_GEMM: cl_zgemm_restrict_where( where ); break; diff --git a/testing/chameleon_ztesting.c b/testing/chameleon_ztesting.c index 4d8317ed37f1939ab2135059616ce766071ae12d..a5ce3ecd63cf095b33d3a7defdd880f8d61c0ecd 100644 --- a/testing/chameleon_ztesting.c +++ b/testing/chameleon_ztesting.c @@ -208,7 +208,18 @@ int main (int argc, char **argv) { info = 1; goto end; } - RUNTIME_zlocality_allrestrict( RUNTIME_CUDA ); +#if defined(CHAMELEON_SCHED_STARPU) + int restriction = 0; +#if defined(CHAMELEON_USE_CUDA) + restriction |= RUNTIME_CUDA; +#endif +#if defined(CHAMELEON_USE_HIP) + restriction |= RUNTIME_HIP; +#endif + if ( restriction != 0 ) { + RUNTIME_zlocality_allrestrict( restriction ); + } +#endif } /* Warmup */