diff --git a/runtime/starpu/codelets/codelet_zher2k.c b/runtime/starpu/codelets/codelet_zher2k.c index 587f8cb87e2f739e6a378a97d9c59f60c6fa1710..7176bbe393513f4d9db74f48b19c7b7d63dabbef 100644 --- a/runtime/starpu/codelets/codelet_zher2k.c +++ b/runtime/starpu/codelets/codelet_zher2k.c @@ -27,53 +27,50 @@ #include "chameleon_starpu_internal.h" #include "runtime_codelet_z.h" +struct cl_zher2k_args_s { + cham_uplo_t uplo; + cham_trans_t trans; + int n; + int k; + CHAMELEON_Complex64_t alpha; + double beta; +}; + #if !defined(CHAMELEON_SIMULATION) static void cl_zher2k_cpu_func(void *descr[], void *cl_arg) { - cham_uplo_t uplo; - cham_trans_t trans; - int n; - int k; - CHAMELEON_Complex64_t alpha; + struct cl_zher2k_args_s *clargs = (struct cl_zher2k_args_s *)cl_arg; CHAM_tile_t *tileA; CHAM_tile_t *tileB; - double beta; CHAM_tile_t *tileC; tileA = cti_interface_get(descr[0]); tileB = cti_interface_get(descr[1]); tileC = cti_interface_get(descr[2]); - starpu_codelet_unpack_args(cl_arg, &uplo, &trans, &n, &k, &alpha, &beta); - TCORE_zher2k(uplo, trans, - n, k, alpha, tileA, tileB, beta, tileC); + TCORE_zher2k( clargs->uplo, clargs->trans, + clargs->n, clargs->k, clargs->alpha, + tileA, tileB, clargs->beta, tileC ); } #if defined(CHAMELEON_USE_CUDA) static void cl_zher2k_cuda_func(void *descr[], void *cl_arg) { cublasHandle_t handle = starpu_cublas_get_local_handle(); - cham_uplo_t uplo; - cham_trans_t trans; - int n; - int k; - cuDoubleComplex alpha; + struct cl_zher2k_args_s *clargs = (struct cl_zher2k_args_s *)cl_arg; CHAM_tile_t *tileA; CHAM_tile_t *tileB; - double beta; CHAM_tile_t *tileC; tileA = cti_interface_get(descr[0]); tileB = cti_interface_get(descr[1]); tileC = cti_interface_get(descr[2]); - starpu_codelet_unpack_args(cl_arg, &uplo, &trans, &n, &k, &alpha, &beta); - - CUDA_zher2k( uplo, trans, - n, k, - &alpha, tileA->mat, tileA->ld, - tileB->mat, tileB->ld, - &beta, tileC->mat, tileC->ld, + CUDA_zher2k( clargs->uplo, clargs->trans, + clargs->n, clargs->k, + (cuDoubleComplex*)&(clargs->alpha), tileA->mat, tileA->ld, + tileB->mat, tileB->ld, + &(clargs->beta), tileC->mat, tileC->ld, handle ); } #endif /* defined(CHAMELEON_USE_CUDA) */ @@ -82,28 +79,21 @@ static void cl_zher2k_cuda_func(void *descr[], void *cl_arg) static void cl_zher2k_hip_func(void *descr[], void *cl_arg) { hipblasHandle_t handle = starpu_hipblas_get_local_handle(); - cham_uplo_t uplo; - cham_trans_t trans; - int n; - int k; - hipblasDoubleComplex alpha; + struct cl_zher2k_args_s *clargs = (struct cl_zher2k_args_s *)cl_arg; CHAM_tile_t *tileA; CHAM_tile_t *tileB; - double beta; CHAM_tile_t *tileC; tileA = cti_interface_get(descr[0]); tileB = cti_interface_get(descr[1]); tileC = cti_interface_get(descr[2]); - starpu_codelet_unpack_args(cl_arg, &uplo, &trans, &n, &k, &alpha, &beta); - - HIP_zher2k( uplo, trans, - n, k, - &alpha, tileA->mat, tileA->ld, - tileB->mat, tileB->ld, - &beta, tileC->mat, tileC->ld, - handle ); + HIP_zher2k( clargs->uplo, clargs->trans, + clargs->n, clargs->k, + (hipblasDoubleComplex*)&(clargs->alpha), tileA->mat, tileA->ld, + tileB->mat, tileB->ld, + &(clargs->beta), tileC->mat, tileC->ld, + handle ); } #endif /* defined(CHAMELEON_USE_HIP) */ #endif /* !defined(CHAMELEON_SIMULATION) */ @@ -117,18 +107,13 @@ CODELETS_GPU( zher2k, cl_zher2k_cpu_func, cl_zher2k_hip_func, STARPU_HIP_ASYNC ) CODELETS( zher2k, cl_zher2k_cpu_func, cl_zher2k_cuda_func, STARPU_CUDA_ASYNC ) #endif -/** - * - * @ingroup INSERT_TASK_Complex64_t - * - */ -void -INSERT_TASK_zher2k( const RUNTIME_option_t *options, - cham_uplo_t uplo, cham_trans_t trans, - int n, int k, int nb, - CHAMELEON_Complex64_t alpha, const CHAM_desc_t *A, int Am, int An, - const CHAM_desc_t *B, int Bm, int Bn, - double beta, const CHAM_desc_t *C, int Cm, int Cn ) +#if defined(CHAMELEON_STARPU_USE_INSERT) +void INSERT_TASK_zher2k( const RUNTIME_option_t *options, + cham_uplo_t uplo, cham_trans_t trans, + int n, int k, int nb, + CHAMELEON_Complex64_t alpha, const CHAM_desc_t *A, int Am, int An, + const CHAM_desc_t *B, int Bm, int Bn, + double beta, const CHAM_desc_t *C, int Cm, int Cn ) { if ( alpha == 0. ) { INSERT_TASK_zlascal( options, uplo, n, n, nb, @@ -136,30 +121,139 @@ INSERT_TASK_zher2k( const RUNTIME_option_t *options, return; } - (void)nb; - struct starpu_codelet *codelet = &cl_zher2k; - void (*callback)(void*) = options->profiling ? cl_zher2k_callback : NULL; - int accessC = ( beta == 0. ) ? STARPU_W : STARPU_RW; + void (*callback)(void*); + struct cl_zher2k_args_s *clargs = NULL; + int exec = 0; + const char *cl_name = "zher2k"; + int accessC; CHAMELEON_BEGIN_ACCESS_DECLARATION; CHAMELEON_ACCESS_R(A, Am, An); CHAMELEON_ACCESS_R(B, Bm, Bn); CHAMELEON_ACCESS_RW(C, Cm, Cn); + exec = __chameleon_need_exec; CHAMELEON_END_ACCESS_DECLARATION; + if ( exec ) { + clargs = malloc( sizeof( struct cl_zher2k_args_s ) ); + clargs->uplo = uplo; + clargs->trans = trans; + clargs->n = n; + clargs->k = k; + clargs->alpha = alpha; + clargs->beta = beta; + } + + /* Callback fro profiling information */ + callback = options->profiling ? cl_zher2k_callback : NULL; + + /* Reduce the C access if needed */ + accessC = ( beta == 0. ) ? STARPU_W : STARPU_RW; + + /* Refine name */ + cl_name = chameleon_codelet_name( cl_name, 3, + A->get_blktile( A, Am, An ), + B->get_blktile( B, Bm, Bn ), + C->get_blktile( C, Cm, Cn ) ); + rt_starpu_insert_task( - codelet, - STARPU_VALUE, &uplo, sizeof(int), - STARPU_VALUE, &trans, sizeof(int), - STARPU_VALUE, &n, sizeof(int), - STARPU_VALUE, &k, sizeof(int), - STARPU_VALUE, &alpha, sizeof(CHAMELEON_Complex64_t), - STARPU_R, RTBLKADDR(A, ChamComplexDouble, Am, An), - STARPU_R, RTBLKADDR(B, ChamComplexDouble, Bm, Bn), - STARPU_VALUE, &beta, sizeof(double), - accessC, RTBLKADDR(C, ChamComplexDouble, Cm, Cn), - STARPU_PRIORITY, options->priority, - STARPU_CALLBACK, callback, + &cl_zher2k, + /* Task codelet arguments */ + STARPU_CL_ARGS, clargs, sizeof(struct cl_zher2k_args_s), + STARPU_R, RTBLKADDR(A, ChamComplexDouble, Am, An), + STARPU_R, RTBLKADDR(B, ChamComplexDouble, Bm, Bn), + accessC, RTBLKADDR(C, ChamComplexDouble, Cm, Cn), + + /* Common task arguments */ + STARPU_PRIORITY, options->priority, + STARPU_CALLBACK, callback, STARPU_EXECUTE_ON_WORKER, options->workerid, + STARPU_POSSIBLY_PARALLEL, options->parallel, + STARPU_NAME, cl_name, 0 ); + + (void)nb; +} + +#else + +void INSERT_TASK_zher2k( const RUNTIME_option_t *options, + cham_uplo_t uplo, cham_trans_t trans, + int n, int k, int nb, + CHAMELEON_Complex64_t alpha, const CHAM_desc_t *A, int Am, int An, + const CHAM_desc_t *B, int Bm, int Bn, + double beta, const CHAM_desc_t *C, int Cm, int Cn ) +{ + if ( alpha == 0. ) { + INSERT_TASK_zlascal( options, uplo, n, n, nb, + beta, C, Cm, Cn ); + return; + } + + INSERT_TASK_COMMON_PARAMETERS( zher2k, 3 ); + int accessC; + + /* Reduce the C access if needed */ + accessC = ( beta == (double)0. ) ? STARPU_W : STARPU_RW; + + /* + * Set the data handles and initialize exchanges if needed + */ + starpu_cham_exchange_init_params( options, ¶ms, C->get_rankof( C, Cm, Cn ) ); + starpu_cham_exchange_data_before_execution( options, params, &nbdata, descrs, A, Am, An, STARPU_R ); + starpu_cham_exchange_data_before_execution( options, params, &nbdata, descrs, B, Bm, Bn, STARPU_R ); + starpu_cham_exchange_data_before_execution( options, params, &nbdata, descrs, C, Cm, Cn, accessC ); + + /* + * Not involved, let's return + */ + if ( nbdata == 0 ) { + return; + } + + if ( params.do_execute ) + { + int ret; + struct starpu_task *task = starpu_task_create(); + task->cl = cl; + + /* Set codelet parameters */ + clargs = malloc( sizeof( struct cl_zher2k_args_s ) ); + clargs->uplo = uplo; + clargs->trans = trans; + clargs->n = n; + clargs->k = k; + clargs->alpha = alpha; + clargs->beta = beta; + + task->cl_arg = clargs; + task->cl_arg_size = sizeof( struct cl_zher2k_args_s ); + task->cl_arg_free = 1; + + /* Set common parameters */ + starpu_cham_task_set_options( options, task, nbdata, descrs, cl_zher2k_callback ); + + /* Flops */ + task->flops = flops_zher2k( k, n ); + + /* Refine name */ + task->name = chameleon_codelet_name( cl_name, 3, + A->get_blktile( A, Am, An ), + B->get_blktile( B, Bm, Bn ), + C->get_blktile( C, Cm, Cn ) ); + + ret = starpu_task_submit( task ); + if ( ret == -ENODEV ) { + task->destroy = 0; + starpu_task_destroy( task ); + chameleon_error( "INSERT_TASK_zher2k", "Failed to submit the task to StarPU" ); + return; + } + } + + starpu_cham_task_exchange_data_after_execution( options, params, nbdata, descrs ); + + (void)nb; } + +#endif