From bec932d132646b9bedb2331aaa76950dcb534ae3 Mon Sep 17 00:00:00 2001 From: Alycia Lisito <alycia.lisito@inria.fr> Date: Thu, 6 Mar 2025 07:55:29 +0100 Subject: [PATCH] starpu/codelet: Add new task submit to codelet_zgetrf_batched --- .../starpu/codelets/codelet_zgetrf_batched.c | 287 ++++++++++++------ 1 file changed, 193 insertions(+), 94 deletions(-) diff --git a/runtime/starpu/codelets/codelet_zgetrf_batched.c b/runtime/starpu/codelets/codelet_zgetrf_batched.c index 4bb70d4b4..40a1b443c 100644 --- a/runtime/starpu/codelets/codelet_zgetrf_batched.c +++ b/runtime/starpu/codelets/codelet_zgetrf_batched.c @@ -24,7 +24,7 @@ #include "chameleon_starpu_internal.h" #include "runtime_codelet_z.h" -struct cl_getrf_batched_args_t { +struct cl_zgetrf_batched_args_s { const char *cl_name; int tasks_nbr; int diag; @@ -41,11 +41,11 @@ static void cl_zgetrf_panel_offdiag_batched_cpu_func( void *descr[], void *cl_arg ) { - struct cl_getrf_batched_args_t *clargs = (struct cl_getrf_batched_args_t *) cl_arg; - cppi_interface_t *nextpiv = (cppi_interface_t*) descr[ clargs->tasks_nbr ]; - cppi_interface_t *prevpiv = (cppi_interface_t*) descr[ clargs->tasks_nbr + 1 ]; - int i, m, n, h, m0, lda; - CHAM_tile_t *tileA; + struct cl_zgetrf_batched_args_s *clargs = (struct cl_zgetrf_batched_args_s *) cl_arg; + cppi_interface_t *nextpiv = (cppi_interface_t*) descr[ clargs->tasks_nbr ]; + cppi_interface_t *prevpiv = (cppi_interface_t*) descr[ clargs->tasks_nbr + 1 ]; + int i, m, n, h, m0, lda; + CHAM_tile_t *tileA; nextpiv->h = clargs->h; nextpiv->has_diag = chameleon_max( -1, nextpiv->has_diag ); @@ -73,19 +73,18 @@ INSERT_TASK_zgetrf_panel_offdiag_batched( const RUNTIME_option_t *options, void **clargs_ptr, CHAM_ipiv_t *ipiv ) { - int task_num = 0; - int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size_blas2; - void (*callback)(void*) = NULL; - struct cl_getrf_batched_args_t *clargs = *clargs_ptr; - int rankA = A->get_rankof( A, Am, An ); - if ( rankA != A->myrank ) { - return; - } #if !defined(HAVE_STARPU_NONE_NONZERO) /* STARPU_NONE can't be equal to 0 */ - fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" ); + fprintf( stderr, "INSERT_TASK_zgetrf_percol_offdiag_batched: STARPU_NONE can not be equal to 0\n" ); assert( 0 ); #endif + int task_num = 0; + int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size_blas2; + struct cl_zgetrf_batched_args_s *clargs = *clargs_ptr; + int rankA = A->get_rankof( A, Am, An ); + if ( rankA != A->myrank ) { + return; + } /* Handle cache */ CHAMELEON_BEGIN_ACCESS_DECLARATION; @@ -93,8 +92,8 @@ INSERT_TASK_zgetrf_panel_offdiag_batched( const RUNTIME_option_t *options, CHAMELEON_END_ACCESS_DECLARATION; if ( clargs == NULL ) { - clargs = malloc( sizeof( struct cl_getrf_batched_args_t ) ) ; - memset( clargs, 0, sizeof( struct cl_getrf_batched_args_t ) ); + clargs = malloc( sizeof( struct cl_zgetrf_batched_args_s ) ) ; + memset( clargs, 0, sizeof( struct cl_zgetrf_batched_args_s ) ); clargs->tasks_nbr = 0; clargs->h = h; clargs->cl_name = "zgetrf_panel_offdiag_batched"; @@ -114,39 +113,26 @@ INSERT_TASK_zgetrf_panel_offdiag_batched( const RUNTIME_option_t *options, A->get_blktile( A, Am, An ) ); if ( clargs->tasks_nbr == batch_size ) { - int access_npiv = ( h == ipiv->n ) ? STARPU_R : STARPU_REDUX; - int access_ppiv = ( h == 0 ) ? STARPU_NONE : STARPU_R; - rt_starpu_insert_task( - &cl_zgetrf_panel_offdiag_batched, - /* Task codelet arguments */ - STARPU_CL_ARGS, clargs, sizeof(struct cl_getrf_batched_args_t), - STARPU_DATA_MODE_ARRAY, clargs->handle_mode, clargs->tasks_nbr, - access_npiv, RUNTIME_pivot_getaddr( ipiv, rankA, An, h ), - access_ppiv, RUNTIME_pivot_getaddr( ipiv, rankA, An, h-1 ), - STARPU_PRIORITY, options->priority, - STARPU_CALLBACK, callback, - STARPU_EXECUTE_ON_WORKER, options->workerid, - 0 ); - - /* clargs is freed by starpu. */ - *clargs_ptr = NULL; + INSERT_TASK_zgetrf_panel_offdiag_batched_flush( options, A, An, clargs_ptr, ipiv ); } } +#if defined(CHAMELEON_STARPU_USE_INSERT) + void INSERT_TASK_zgetrf_panel_offdiag_batched_flush( const RUNTIME_option_t *options, CHAM_desc_t *A, int An, void **clargs_ptr, CHAM_ipiv_t *ipiv ) { - void (*callback)(void*) = NULL; - struct cl_getrf_batched_args_t *clargs = *clargs_ptr; - int rankA = A->myrank; #if !defined(HAVE_STARPU_NONE_NONZERO) /* STARPU_NONE can't be equal to 0 */ - fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" ); + fprintf( stderr, "INSERT_TASK_zgetrf_percol_offdiag_batched: STARPU_NONE can not be equal to 0\n" ); assert( 0 ); #endif + void (*callback)(void*) = NULL; + struct cl_zgetrf_batched_args_s *clargs = *clargs_ptr; + int rankA = A->myrank; if ( clargs == NULL ) { return; @@ -157,7 +143,7 @@ INSERT_TASK_zgetrf_panel_offdiag_batched_flush( const RUNTIME_option_t *options, rt_starpu_insert_task( &cl_zgetrf_panel_offdiag_batched, /* Task codelet arguments */ - STARPU_CL_ARGS, clargs, sizeof(struct cl_getrf_batched_args_t), + STARPU_CL_ARGS, clargs, sizeof(struct cl_zgetrf_batched_args_s), STARPU_DATA_MODE_ARRAY, clargs->handle_mode, clargs->tasks_nbr, access_npiv, RUNTIME_pivot_getaddr( ipiv, rankA, An, clargs->h ), access_ppiv, RUNTIME_pivot_getaddr( ipiv, rankA, An, clargs->h-1 ), @@ -171,12 +157,75 @@ INSERT_TASK_zgetrf_panel_offdiag_batched_flush( const RUNTIME_option_t *options, *clargs_ptr = NULL; } +#else /* defined(CHAMELEON_STARPU_USE_INSERT) */ + +void +INSERT_TASK_zgetrf_panel_offdiag_batched_flush( const RUNTIME_option_t *options, + CHAM_desc_t *A, int An, + void **clargs_ptr, + CHAM_ipiv_t *ipiv ) +{ + struct cl_zgetrf_batched_args_s *myclargs = *clargs_ptr; + int rankA = A->myrank; + int k, ret, access_npiv, access_ppiv; + struct starpu_task *task; + + if ( myclargs == NULL ) { + return; + } + + INSERT_TASK_COMMON_PARAMETERS_EXTENDED( zgetrf_panel_percol_offdiag_batched, zgetrf_panel_offdiag_batched, zgetrf_batched, myclargs->tasks_nbr + 2 ); + + access_npiv = ( myclargs->h == ipiv->n ) ? STARPU_R : STARPU_REDUX; + access_ppiv = ( myclargs->h == 0 ) ? STARPU_NONE : STARPU_R; + + /* + * Register the data handles, no exchange needed + */ + starpu_cham_exchange_init_params( options, ¶ms, rankA ); + for ( k = 0; k < myclargs->tasks_nbr; k++ ) { + starpu_cham_register_descr( &nbdata, descrs, myclargs->handle_mode[ k ].handle, STARPU_RW ); + } + starpu_cham_register_descr( &nbdata, descrs, RUNTIME_pivot_getaddr( ipiv, rankA, An, myclargs->h ), access_npiv ); + starpu_cham_register_descr( &nbdata, descrs, RUNTIME_pivot_getaddr( ipiv, rankA, An, myclargs->h-1 ), access_ppiv ); + + task = starpu_task_create(); + task->cl = cl; + + /* Set codelet parameters */ + task->cl_arg = myclargs; + task->cl_arg_size = sizeof( struct cl_zgetrf_batched_args_s ); + task->cl_arg_free = 1; + + /* Set common parameters */ + starpu_cham_task_set_options( options, task, nbdata, descrs, NULL ); + + /* Flops */ + // task->flops = TODO; + + ret = starpu_task_submit( task ); + if ( ret == -ENODEV ) { + task->destroy = 0; + starpu_task_destroy( task ); + chameleon_error( "INSERT_TASK_zgetrf_percol_diag", "Failed to submit the task to StarPU" ); + return; + } + starpu_cham_task_exchange_data_after_execution( options, params, nbdata, descrs ); + + /* clargs is freed by starpu. */ + *clargs_ptr = NULL; + (void)clargs; + (void)cl_name; +} + +#endif /* defined(CHAMELEON_STARPU_USE_INSERT) */ + #if !defined(CHAMELEON_SIMULATION) static void cl_zgetrf_panel_blocked_batched_cpu_func( void *descr[], void *cl_arg ) { - struct cl_getrf_batched_args_t *clargs = ( struct cl_getrf_batched_args_t * ) cl_arg; + struct cl_zgetrf_batched_args_s *clargs = ( struct cl_zgetrf_batched_args_s * ) cl_arg; int *ipiv; cppi_interface_t *nextpiv = (cppi_interface_t*) descr[clargs->tasks_nbr ]; cppi_interface_t *prevpiv = (cppi_interface_t*) descr[clargs->tasks_nbr + 1]; @@ -241,21 +290,19 @@ INSERT_TASK_zgetrf_panel_blocked_batched( const RUNTIME_option_t *options, void **clargs_ptr, CHAM_ipiv_t *ipiv ) { - struct chameleon_pzgetrf_s *tmp = (struct chameleon_pzgetrf_s *) ws; - int ib = tmp->ib; - int batch_size = ( (h % ib) != 0 ) ? tmp->batch_size_blas2 : tmp->batch_size_blas3; - int task_num = 0; - void (*callback)(void*) = NULL; - int accessU, access_npiv, access_ipiv, access_ppiv; - struct cl_getrf_batched_args_t *clargs = *clargs_ptr; - int rankA = A->get_rankof(A, Am, An); #if !defined(HAVE_STARPU_NONE_NONZERO) /* STARPU_NONE can't be equal to 0 */ - fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" ); + fprintf( stderr, "INSERT_TASK_zgetrf_panel_blocked_batched: STARPU_NONE can not be equal to 0\n" ); assert( 0 ); #endif + struct chameleon_pzgetrf_s *tmp = (struct chameleon_pzgetrf_s *) ws; + int ib = tmp->ib; + int batch_size = ( (h % ib) != 0 ) ? tmp->batch_size_blas2 : tmp->batch_size_blas3; + int task_num = 0; + struct cl_zgetrf_batched_args_s *clargs = *clargs_ptr; #if defined ( CHAMELEON_USE_MPI ) + int rankA = A->get_rankof(A, Am, An); if ( ( Am == An ) && ( h % ib == 0 ) && ( h > 0 ) ) { starpu_mpi_cache_flush( options->sequence->comm, RTBLKADDR(U, CHAMELEON_Complex64_t, Um, Un) ); @@ -277,8 +324,8 @@ INSERT_TASK_zgetrf_panel_blocked_batched( const RUNTIME_option_t *options, CHAMELEON_END_ACCESS_DECLARATION; if ( clargs == NULL ) { - clargs = malloc( sizeof( struct cl_getrf_batched_args_t ) ); - memset( clargs, 0, sizeof( struct cl_getrf_batched_args_t ) ); + clargs = malloc( sizeof( struct cl_zgetrf_batched_args_s ) ); + memset( clargs, 0, sizeof( struct cl_zgetrf_batched_args_s ) ); clargs->tasks_nbr = 0; clargs->diag = ( Am == An ); clargs->ib = ib; @@ -300,47 +347,12 @@ INSERT_TASK_zgetrf_panel_blocked_batched( const RUNTIME_option_t *options, A->get_blktile( A, Am, An ) ); if ( clargs->tasks_nbr == batch_size ) { - access_npiv = ( clargs->h == ipiv->n ) ? STARPU_R : STARPU_REDUX; - access_ipiv = STARPU_RW; - access_ppiv = STARPU_R; - accessU = STARPU_RW; - if ( clargs->h == 0 ) { - access_ipiv = STARPU_W; - access_ppiv = STARPU_NONE; - accessU = STARPU_NONE; - } - else if ( clargs->h % clargs->ib == 0 ) { - accessU = STARPU_R; - } - else if ( clargs->h % clargs->ib == 1 ) { - accessU = STARPU_W; - } - /* If there isn't a diag task then use offdiag access */ - if ( clargs->diag == 0 ) { - accessU = ((h%ib == 0) && (h > 0)) ? STARPU_R : STARPU_NONE; - access_ipiv = STARPU_NONE; - } - - rt_starpu_insert_task( - &cl_zgetrf_panel_blocked_batched, - /* Task codelet arguments */ - STARPU_CL_ARGS, clargs, sizeof(struct cl_getrf_batched_args_t), - STARPU_DATA_MODE_ARRAY, clargs->handle_mode, clargs->tasks_nbr, - access_npiv, RUNTIME_pivot_getaddr( ipiv, rankA, An, h ), - access_ppiv, RUNTIME_pivot_getaddr( ipiv, rankA, An, h-1 ), - access_ipiv, RUNTIME_ipiv_getaddr( ipiv, An ), - accessU, RTBLKADDR(U, CHAMELEON_Complex64_t, Um, Un ), - STARPU_PRIORITY, options->priority, - STARPU_CALLBACK, callback, - STARPU_EXECUTE_ON_WORKER, options->workerid, - STARPU_NAME, clargs->cl_name, - 0 ); - - /* clargs is freed by starpu. */ - *clargs_ptr = NULL; + INSERT_TASK_zgetrf_panel_blocked_batched_flush( options, A, An, U, Um, Un, clargs_ptr, ipiv ); } } +#if defined(CHAMELEON_STARPU_USE_INSERT) + void INSERT_TASK_zgetrf_panel_blocked_batched_flush( const RUNTIME_option_t *options, CHAM_desc_t *A, int An, @@ -348,15 +360,15 @@ INSERT_TASK_zgetrf_panel_blocked_batched_flush( const RUNTIME_option_t *options, void **clargs_ptr, CHAM_ipiv_t *ipiv ) { - int accessU, access_npiv, access_ipiv, access_ppiv; - void (*callback)(void*) = NULL; - struct cl_getrf_batched_args_t *clargs = *clargs_ptr; - int rankA = A->myrank; #if !defined(HAVE_STARPU_NONE_NONZERO) /* STARPU_NONE can't be equal to 0 */ - fprintf( stderr, "INSERT_TASK_zgetrf_percol_diag: STARPU_NONE can not be equal to 0\n" ); + fprintf( stderr, "INSERT_TASK_zgetrf_panel_blocked_batched: STARPU_NONE can not be equal to 0\n" ); assert( 0 ); #endif + int accessU, access_npiv, access_ipiv, access_ppiv; + void (*callback)(void*) = NULL; + struct cl_zgetrf_batched_args_s *clargs = *clargs_ptr; + int rankA = A->myrank; if ( clargs == NULL ) { return; @@ -386,7 +398,7 @@ INSERT_TASK_zgetrf_panel_blocked_batched_flush( const RUNTIME_option_t *options, rt_starpu_insert_task( &cl_zgetrf_panel_blocked_batched, /* Task codelet arguments */ - STARPU_CL_ARGS, clargs, sizeof(struct cl_getrf_batched_args_t), + STARPU_CL_ARGS, clargs, sizeof(struct cl_zgetrf_batched_args_s), STARPU_DATA_MODE_ARRAY, clargs->handle_mode, clargs->tasks_nbr, access_npiv, RUNTIME_pivot_getaddr( ipiv, rankA, An, clargs->h ), access_ppiv, RUNTIME_pivot_getaddr( ipiv, rankA, An, clargs->h - 1 ), @@ -401,3 +413,90 @@ INSERT_TASK_zgetrf_panel_blocked_batched_flush( const RUNTIME_option_t *options, /* clargs is freed by starpu. */ *clargs_ptr = NULL; } + +#else /* defined(CHAMELEON_STARPU_USE_INSERT) */ + +void +INSERT_TASK_zgetrf_panel_blocked_batched_flush( const RUNTIME_option_t *options, + CHAM_desc_t *A, int An, + CHAM_desc_t *U, int Um, int Un, + void **clargs_ptr, + CHAM_ipiv_t *ipiv ) +{ + struct cl_zgetrf_batched_args_s *myclargs = *clargs_ptr; + int rankA = A->myrank; + int accessU, access_npiv, access_ipiv, access_ppiv, k; + int ret; + struct starpu_task *task; + + if ( myclargs == NULL ) { + return; + } + + INSERT_TASK_COMMON_PARAMETERS_EXTENDED( zgetrf_panel_blocked_batched, zgetrf_panel_blocked_batched, zgetrf_batched, myclargs->tasks_nbr + 4 ); + + access_npiv = ( myclargs->h == ipiv->n ) ? STARPU_R : STARPU_REDUX; + access_ipiv = STARPU_RW; + access_ppiv = STARPU_R; + accessU = STARPU_RW; + if ( myclargs->h == 0 ) { + access_ipiv = STARPU_W; + access_ppiv = STARPU_NONE; + accessU = STARPU_NONE; + } + else if ( myclargs->h % myclargs->ib == 0 ) { + accessU = STARPU_R; + } + else if ( myclargs->h % myclargs->ib == 1 ) { + accessU = STARPU_W; + } + /* If there isn't a diag task then use offdiag access */ + if ( myclargs->diag == 0 ) { + accessU = ((myclargs->h%myclargs->ib == 0) && (myclargs->h > 0)) ? STARPU_R : STARPU_NONE; + access_ipiv = STARPU_NONE; + } + + /* + * Register the data handles, exchange needed only for U + */ + starpu_cham_exchange_init_params( options, ¶ms, rankA ); + for ( k = 0; k < myclargs->tasks_nbr; k++ ) { + starpu_cham_register_descr( &nbdata, descrs, myclargs->handle_mode[ k ].handle, STARPU_RW ); + } + starpu_cham_register_descr( &nbdata, descrs, RUNTIME_pivot_getaddr( ipiv, rankA, An, myclargs->h ), access_npiv ); + starpu_cham_register_descr( &nbdata, descrs, RUNTIME_pivot_getaddr( ipiv, rankA, An, myclargs->h-1 ), access_ppiv ); + starpu_cham_register_descr( &nbdata, descrs, RUNTIME_ipiv_getaddr( ipiv, An), access_ipiv ); + starpu_cham_exchange_handle_before_execution( options, ¶ms, &nbdata, descrs, + RTBLKADDR(U, CHAMELEON_Complex64_t, Um, Un), + accessU ); + + task = starpu_task_create(); + task->cl = cl; + + /* Set codelet parameters */ + task->cl_arg = myclargs; + task->cl_arg_size = sizeof( struct cl_zgetrf_batched_args_s ); + task->cl_arg_free = 1; + + /* Set common parameters */ + starpu_cham_task_set_options( options, task, nbdata, descrs, NULL ); + + /* Flops */ + // task->flops = TODO; + + ret = starpu_task_submit( task ); + if ( ret == -ENODEV ) { + task->destroy = 0; + starpu_task_destroy( task ); + chameleon_error( "INSERT_TASK_zgetrf_panel_blocked_batched", "Failed to submit the task to StarPU" ); + return; + } + starpu_cham_task_exchange_data_after_execution( options, params, nbdata, descrs ); + + /* clargs is freed by starpu. */ + *clargs_ptr = NULL; + (void)clargs; + (void)cl_name; +} + +#endif /* defined(CHAMELEON_STARPU_USE_INSERT) */ -- GitLab