diff --git a/compute/pzgetrf.c b/compute/pzgetrf.c index adbce6c0c281a6a93a0ada8a909897ba21e77c2c..964097662c7381be1ce69fe30d543881d7eb81e9 100644 --- a/compute/pzgetrf.c +++ b/compute/pzgetrf.c @@ -22,6 +22,7 @@ * */ #include "control/common.h" +#include "include/chameleon/flops.h" #define A(m,n) A, m, n #define U(m,n) &(ws->U), m, n @@ -29,6 +30,31 @@ #define Wu(m,n) &(ws->Wu), m, n #define Wl(m,n) &(ws->Wl), m, n +static inline void +chameleon_pzgetrf_batch_size( struct chameleon_pzgetrf_s *ws, + int m, int nb, int j, int k ) +{ + if ( ws->batch_adaptive == 0 ) { + ws->batch_size = ( ( j % ws->ib ) != 0 ) ? ws->batch_size_blas2 : ws->batch_size_blas3; + return; + } + int task_left = m / nb - k; + int batch_max = chameleon_min( CHAMELEON_BATCH_SIZE, task_left ); + double flops = flops_zgetrf_blocked_offdiag( nb, nb, j, ws->ib ); + ws->batch_size = batch_max; + if ( j != 0 ) { + ws->batch_size = chameleon_min( chameleon_max( ws->flops_min / flops, 1 ), batch_max ); + } + + if ( task_left % ws->batch_size != 0 ) { + ws->batch_size = chameleon_min( chameleon_ceil( task_left, task_left / ws->batch_size ), batch_max ); + } + + // batch_mathieu = chameleon_ceil( task_left, + // chameleon_max( chameleon_ceil( task_left, batch_max ), + // ( task_left * flops ) / ws->flops_th ) ); +} + /* * All the functions below are panel factorization variant. * The parameters are: @@ -302,6 +328,7 @@ chameleon_pzgetrf_panel_facto_blocked_batched( struct chameleon_pzgetrf_s *ws, for ( h = 0; h < hmax; h++ ) { j = h + b * ws->ib; + chameleon_pzgetrf_batch_size( ws, A->m, A->nb, j, k ); for ( m = k; m < A->mt; m++ ) { tempmm = A->get_blkdim( A, m, DIM_m, A->m ); INSERT_TASK_zgetrf_panel_blocked_batched( options, tempmm, tempkn, j, m * A->mb, diff --git a/compute/zgetrf.c b/compute/zgetrf.c index d69abaecb3ea9de4f21b5b34c5789f44f4b1380f..2ea2b15d87569f58ac0f8f75bb525fbc550e76a1 100644 --- a/compute/zgetrf.c +++ b/compute/zgetrf.c @@ -114,6 +114,8 @@ CHAMELEON_zgetrf_WS_Alloc( const CHAM_desc_t *A ) chameleon_cleanenv( allreduce ); } + ws->batch_size = -1; + ws->batch_adaptive = chameleon_getenv_get_value_int( "CHAMELEON_GETRF_BATCH_ADAPTIVE", 0 ); batch_size = chameleon_getenv_get_value_int( "CHAMELEON_GETRF_BATCH_SIZE", 0 ); if ( batch_size > CHAMELEON_BATCH_SIZE ) { chameleon_warning( "CHAMELEON_BATCH_SIZE", "CHAMELEON_GETRF_BATCH_SIZE must be smaller than CHAMELEON_BATCH_SIZE, please recompile with the right CHAMELEON_BATCH_SIZE, or reduce the CHAMELEON_GETRF_BATCH_SIZE value\n" ); @@ -124,6 +126,7 @@ CHAMELEON_zgetrf_WS_Alloc( const CHAM_desc_t *A ) ws->batch_size_blas3 = ( ws->batch_size_blas3 > CHAMELEON_BATCH_SIZE ) ? CHAMELEON_BATCH_SIZE : ws->batch_size_blas3; ws->batch_size_swap = chameleon_getenv_get_value_int( "CHAMELEON_GETRF_BATCH_SIZE_SWAP", batch_size ); ws->batch_size_swap = ( ws->batch_size_swap > CHAMELEON_BATCH_SIZE ) ? CHAMELEON_BATCH_SIZE : ws->batch_size_swap; + ws->flops_min = chameleon_max( chameleon_getenv_get_value_int( "CHAMELEON_GETRF_FLOPS_MIN_BATCH", 26e6 ), 1 ); ws->ringswitch = chameleon_getenv_get_value_int( "CHAMELEON_GETRF_RINGSWITCH", INT_MAX ); diff --git a/control/compute_z.h b/control/compute_z.h index 812af3dce918e74926506810d38b2db8fc167e33..09de3d0fe02301b5b4e73da47d14751498d58d65 100644 --- a/control/compute_z.h +++ b/control/compute_z.h @@ -47,9 +47,12 @@ struct chameleon_pzgetrf_s { cham_getrf_t alg; cham_getrf_allreduce_t alg_allreduce; int ib; /**< Internal blocking parameter */ + int batch_adaptive; /**< Whether to use adaptative batch or not */ + int batch_size; /**< Batch size */ int batch_size_blas2; /**< Batch size for the blas 2 operations of the panel factorization */ int batch_size_blas3; /**< Batch size for the blas 3 operations of the panel factorization */ int batch_size_swap; /**< Batch size for the permutation */ + int flops_min; int ringswitch; /**< Define when to switch to ring bcast */ CHAM_desc_t U; CHAM_desc_t Up; /**< Workspace used for the panel factorization */ diff --git a/runtime/starpu/codelets/codelet_zgetrf_batched.c b/runtime/starpu/codelets/codelet_zgetrf_batched.c index 0ff4ed9854228109928e30ae4b34013338a32a5c..f8b88bc733ae549f819132044f4e42cbb44d9a76 100644 --- a/runtime/starpu/codelets/codelet_zgetrf_batched.c +++ b/runtime/starpu/codelets/codelet_zgetrf_batched.c @@ -320,7 +320,7 @@ INSERT_TASK_zgetrf_panel_blocked_batched( const RUNTIME_option_t *options, #endif struct chameleon_pzgetrf_s *tmp = (struct chameleon_pzgetrf_s *) ws; int ib = tmp->ib; - int batch_size = ( (h % ib) != 0 ) ? tmp->batch_size_blas2 : tmp->batch_size_blas3; + int batch_size = tmp->batch_size; int task_num = 0; struct cl_zgetrf_batched_args_s *clargs = *clargs_ptr;