From a4624e9d4bd72049c978b9979fc31a50af5feec5 Mon Sep 17 00:00:00 2001 From: Alycia Lisito <alycia.lisito@inria.fr> Date: Wed, 5 Mar 2025 11:22:20 +0100 Subject: [PATCH] starpu/codelet: Add new task submit to codelet_zlaswp_batched.c --- .../starpu/codelets/codelet_zlaswp_batched.c | 102 ++++++++++++++---- 1 file changed, 80 insertions(+), 22 deletions(-) diff --git a/runtime/starpu/codelets/codelet_zlaswp_batched.c b/runtime/starpu/codelets/codelet_zlaswp_batched.c index f43a68947..8cc2a3adc 100644 --- a/runtime/starpu/codelets/codelet_zlaswp_batched.c +++ b/runtime/starpu/codelets/codelet_zlaswp_batched.c @@ -18,7 +18,7 @@ #include "chameleon_starpu_internal.h" #include "runtime_codelet_z.h" -struct cl_laswp_batched_args_t { +struct cl_zlaswp_batched_args_s { int tasks_nbr; int minmn; int m0[CHAMELEON_BATCH_SIZE]; @@ -32,7 +32,7 @@ cl_zlaswp_batched_cpu_func( void *descr[], { int i, m0, minmn, *perm, *invp; CHAM_tile_t *A, *U, *B; - struct cl_laswp_batched_args_t *clargs = ( struct cl_laswp_batched_args_t * ) cl_arg; + struct cl_zlaswp_batched_args_s *clargs = ( struct cl_zlaswp_batched_args_s * ) cl_arg; minmn = clargs->minmn; perm = (int *)STARPU_VECTOR_GET_PTR( descr[0] ); @@ -73,14 +73,13 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options, { int task_num = 0; int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size_swap; - int nhandles; - struct cl_laswp_batched_args_t *clargs = *clargs_ptr; + struct cl_zlaswp_batched_args_s *clargs = *clargs_ptr; if ( Am->get_rankof( Am, Amm, Amn) != Am->myrank ) { return; } if( clargs == NULL ) { - clargs = malloc( sizeof( struct cl_laswp_batched_args_t ) ) ; + clargs = malloc( sizeof( struct cl_zlaswp_batched_args_s ) ) ; clargs->tasks_nbr = 0; clargs->minmn = minmn; *clargs_ptr = clargs; @@ -93,24 +92,12 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options, clargs->tasks_nbr ++; if ( clargs->tasks_nbr == batch_size ) { - nhandles = clargs->tasks_nbr; - rt_starpu_insert_task( - &cl_zlaswp_batched, - STARPU_CL_ARGS, clargs, sizeof(struct cl_laswp_batched_args_t), - STARPU_R, RUNTIME_perm_getaddr( ipiv, ipivk ), - STARPU_R, RUNTIME_invp_getaddr( ipiv, ipivk ), - STARPU_RW | STARPU_COMMUTE, RTBLKADDR(U, ChamComplexDouble, Um, Un), - STARPU_R, RTBLKADDR(Ak, ChamComplexDouble, Akm, Akn), - STARPU_DATA_MODE_ARRAY, clargs->handle_mode, nhandles, - STARPU_PRIORITY, options->priority, - STARPU_EXECUTE_ON_WORKER, options->workerid, - 0 ); - - /* clargs is freed by starpu. */ - *clargs_ptr = NULL; + INSERT_TASK_zlaswp_batched_flush( options, ipiv, ipivk, Ak, Akm, Akn, U, Um, Un, clargs_ptr ); } } +#if defined(CHAMELEON_STARPU_USE_INSERT) + void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options, const CHAM_ipiv_t *ipiv, int ipivk, @@ -122,7 +109,7 @@ void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options, int Un, void **clargs_ptr ) { - struct cl_laswp_batched_args_t *clargs = *clargs_ptr; + struct cl_zlaswp_batched_args_s *clargs = *clargs_ptr; int nhandles; if( clargs == NULL ) { @@ -132,7 +119,7 @@ void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options, nhandles = clargs->tasks_nbr; rt_starpu_insert_task( &cl_zlaswp_batched, - STARPU_CL_ARGS, clargs, sizeof(struct cl_laswp_batched_args_t), + STARPU_CL_ARGS, clargs, sizeof(struct cl_zlaswp_batched_args_s), STARPU_R, RUNTIME_perm_getaddr( ipiv, ipivk ), STARPU_R, RUNTIME_invp_getaddr( ipiv, ipivk ), STARPU_RW | STARPU_COMMUTE, RTBLKADDR(U, ChamComplexDouble, Um, Un), @@ -145,3 +132,74 @@ void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options, /* clargs is freed by starpu. */ *clargs_ptr = NULL; } + +#else /* defined(CHAMELEON_STARPU_USE_INSERT) */ + +void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options, + const CHAM_ipiv_t *ipiv, + int ipivk, + const CHAM_desc_t *Ak, + int Akm, + int Akn, + const CHAM_desc_t *U, + int Um, + int Un, + void **clargs_ptr ) +{ + int ret, k; + struct starpu_task *task; + struct cl_zlaswp_batched_args_s *myclargs = *clargs_ptr; + + if( myclargs == NULL ) { + return; + } + + INSERT_TASK_COMMON_PARAMETERS( zlaswp_batched, myclargs->tasks_nbr + 4 ); + + /* + * Register the data handles, might need to receive perm and invp + */ + starpu_cham_exchange_init_params( options, ¶ms, Ak->myrank ); + starpu_cham_exchange_handle_before_execution( options, ¶ms, &nbdata, descrs, + RUNTIME_perm_getaddr( ipiv, ipivk ), + STARPU_R ); + starpu_cham_exchange_handle_before_execution( options, ¶ms, &nbdata, descrs, + RUNTIME_invp_getaddr( ipiv, ipivk ), + STARPU_R ); + starpu_cham_register_descr( &nbdata, descrs, RTBLKADDR( U, ChamComplexDouble, Um, Un ), + STARPU_RW | STARPU_COMMUTE ); + starpu_cham_register_descr( &nbdata, descrs, RTBLKADDR( Ak, ChamComplexDouble, Akm, Akn ), STARPU_R ); + for ( k = 0; k < myclargs->tasks_nbr; k++ ) { + starpu_cham_register_descr( &nbdata, descrs, myclargs->handle_mode[ k ].handle, STARPU_RW ); + } + + task = starpu_task_create(); + task->cl = cl; + + /* Set codelet parameters */ + task->cl_arg = myclargs; + task->cl_arg_size = sizeof( struct cl_zlaswp_batched_args_s ); + task->cl_arg_free = 1; + + /* Set common parameters */ + starpu_cham_task_set_options( options, task, nbdata, descrs, NULL ); + + /* Flops */ + task->flops = 0.; + + ret = starpu_task_submit( task ); + if ( ret == -ENODEV ) { + task->destroy = 0; + starpu_task_destroy( task ); + chameleon_error( "INSERT_TASK_zlaswp_batched", "Failed to submit the task to StarPU" ); + return; + } + starpu_cham_task_exchange_data_after_execution( options, params, nbdata, descrs ); + + /* clargs is freed by starpu. */ + *clargs_ptr = NULL; + (void)clargs; + (void)cl_name; +} + +#endif /* defined(CHAMELEON_STARPU_USE_INSERT) */ -- GitLab