diff --git a/compute/pzgetrf.c b/compute/pzgetrf.c index ca7f4d1208ed81c0d613d822ee03293a1c4c0bdb..98f9d047097e646c3eb54287e678684e0218f586 100644 --- a/compute/pzgetrf.c +++ b/compute/pzgetrf.c @@ -350,7 +350,7 @@ chameleon_pzgetrf_panel_facto( struct chameleon_pzgetrf_s *ws, break; case ChamGetrfPPivPerColumn: - if ( ws->batch_size > 0 ) { + if ( ws->batch_size_blas2 > 0 ) { chameleon_pzgetrf_panel_facto_percol_batched( ws, A, ipiv, k, options ); } else { @@ -359,7 +359,7 @@ chameleon_pzgetrf_panel_facto( struct chameleon_pzgetrf_s *ws, break; case ChamGetrfPPiv: - if ( ws->batch_size > 0 ) { + if ( ws->batch_size_blas2 > 0 ) { chameleon_pzgetrf_panel_facto_blocked_batched( ws, A, ipiv, k, options ); } else { @@ -583,7 +583,7 @@ chameleon_pzgetrf_panel_update( struct chameleon_pzgetrf_s *ws, tempkm = A->get_blkdim( A, k, DIM_m, A->m ); tempnn = A->get_blkdim( A, n, DIM_n, A->n ); - if ( ws->batch_size > 0 ) { + if ( ws->batch_size_swap > 0 ) { chameleon_pzgetrf_panel_permute_batched( ws, A, ipiv, k, n, options ); } else { diff --git a/compute/zgetrf.c b/compute/zgetrf.c index 976ba2ad5e60bbeac851664e1dac8b2d991c4be7..254020a55c478dcb6982d5b002fff2d5e69c9902 100644 --- a/compute/zgetrf.c +++ b/compute/zgetrf.c @@ -113,10 +113,20 @@ CHAMELEON_zgetrf_WS_Alloc( const CHAM_desc_t *A ) chameleon_cleanenv( allreduce ); } - ws->batch_size = chameleon_getenv_get_value_int( "CHAMELEON_GETRF_BATCH_SIZE", 0 ); - if ( ws->batch_size > CHAMELEON_BATCH_SIZE ) { - chameleon_warning( "CHAMELEON_BATCH_SIZE", "CHAMELEON_GETRF_BATCH_SIZE must be smaller than CHAMELEON_BATCH_SIZE, please recompile with the right CHAMELEON_BATCH_SIZE, or reduce the CHAMELEON_GETRF_BATCH_SIZE value\n" ); - ws->batch_size = CHAMELEON_BATCH_SIZE; + ws->batch_size_blas2 = chameleon_getenv_get_value_int( "CHAMELEON_GETRF_BATCH_SIZE_BLAS2", 0 ); + if ( ws->batch_size_blas2 > CHAMELEON_BATCH_SIZE ) { + chameleon_warning( "CHAMELEON_BATCH_SIZE", "CHAMELEON_GETRF_BATCH_SIZE_BLAS2 must be smaller than CHAMELEON_BATCH_SIZE, please recompile with the right CHAMELEON_BATCH_SIZE, or reduce the CHAMELEON_GETRF_BATCH_SIZE_BLAS2 value\n" ); + ws->batch_size_blas2 = CHAMELEON_BATCH_SIZE; + } + ws->batch_size_blas3 = chameleon_getenv_get_value_int( "CHAMELEON_GETRF_BATCH_SIZE_BLAS3", 0 ); + if ( ws->batch_size_blas3 > CHAMELEON_BATCH_SIZE ) { + chameleon_warning( "CHAMELEON_BATCH_SIZE", "CHAMELEON_GETRF_BATCH_SIZE_BLAS3 must be smaller than CHAMELEON_BATCH_SIZE, please recompile with the right CHAMELEON_BATCH_SIZE, or reduce the CHAMELEON_GETRF_BATCH_SIZE_BLAS3 value\n" ); + ws->batch_size_blas3 = CHAMELEON_BATCH_SIZE; + } + ws->batch_size_swap = chameleon_getenv_get_value_int( "CHAMELEON_GETRF_BATCH_SIZE_SWAP", 0 ); + if ( ws->batch_size_swap > CHAMELEON_BATCH_SIZE ) { + chameleon_warning( "CHAMELEON_BATCH_SIZE", "CHAMELEON_GETRF_BATCH_SIZE_SWAP must be smaller than CHAMELEON_BATCH_SIZE, please recompile with the right CHAMELEON_BATCH_SIZE, or reduce the CHAMELEON_GETRF_BATCH_SIZE_SWAP value\n" ); + ws->batch_size_swap = CHAMELEON_BATCH_SIZE; } ws->ringswitch = chameleon_getenv_get_value_int( "CHAMELEON_GETRF_RINGSWITCH", INT_MAX ); diff --git a/control/compute_z.h b/control/compute_z.h index 855f820b16f0d5427ce9d00fba7557b5d47a3a86..1229a1797915be3d358cf018f3b779bca8aeefe2 100644 --- a/control/compute_z.h +++ b/control/compute_z.h @@ -46,7 +46,9 @@ struct chameleon_pzgetrf_s { cham_getrf_t alg; cham_getrf_allreduce_t alg_allreduce; int ib; /**< Internal blocking parameter */ - int batch_size; /**< Batch size for the panel */ + int batch_size_blas2; /**< Batch size for the blas 2 operations of the panel factorization */ + int batch_size_blas3; /**< Batch size for the blas 3 operations of the panel factorization */ + int batch_size_swap; /**< Batch size for the permutation */ int ringswitch; /**< Define when to switch to ring bcast */ CHAM_desc_t U; CHAM_desc_t Up; /**< Workspace used for the panel factorization */ diff --git a/runtime/starpu/codelets/codelet_zgetrf_batched.c b/runtime/starpu/codelets/codelet_zgetrf_batched.c index 2e04493df242f90fb18499abfab703724e90d197..011785aa2f459629adaabe457e77a908786ac14d 100644 --- a/runtime/starpu/codelets/codelet_zgetrf_batched.c +++ b/runtime/starpu/codelets/codelet_zgetrf_batched.c @@ -74,7 +74,7 @@ INSERT_TASK_zgetrf_panel_offdiag_batched( const RUNTIME_option_t *options, CHAM_ipiv_t *ipiv ) { int task_num = 0; - int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size; + int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size_blas2; void (*callback)(void*) = NULL; struct cl_getrf_batched_args_t *clargs = *clargs_ptr; int rankA = A->get_rankof( A, Am, An ); @@ -241,8 +241,9 @@ INSERT_TASK_zgetrf_panel_blocked_batched( const RUNTIME_option_t *options, void **clargs_ptr, CHAM_ipiv_t *ipiv ) { - int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size; - int ib = ((struct chameleon_pzgetrf_s *)ws)->ib; + struct chameleon_pzgetrf_s *tmp = (struct chameleon_pzgetrf_s *) ws; + int ib = tmp->ib; + int batch_size = ( (h % ib) != 0 ) ? tmp->batch_size_blas2 : tmp->batch_size_blas3; int task_num = 0; void (*callback)(void*) = NULL; int accessU, access_npiv, access_ipiv, access_ppiv; diff --git a/runtime/starpu/codelets/codelet_zlaswp_batched.c b/runtime/starpu/codelets/codelet_zlaswp_batched.c index b17f26a486dc87e5d8dcb807369bfa431e809b06..303e6a674b564a9fbe3833931a5190af9e8ed136 100644 --- a/runtime/starpu/codelets/codelet_zlaswp_batched.c +++ b/runtime/starpu/codelets/codelet_zlaswp_batched.c @@ -72,7 +72,7 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options, void **clargs_ptr ) { int task_num = 0; - int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size; + int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size_swap; int nhandles; struct cl_laswp_batched_args_t *clargs = *clargs_ptr; if ( Am->get_rankof( Am, Amm, Amn) != Am->myrank ) {