diff --git a/runtime/starpu/codelets/codelet_zperm_allreduce.c b/runtime/starpu/codelets/codelet_zperm_allreduce.c index c21490d02f42447b07e8516842b2b9b840850006..ab9cf702294f7a54348b8d7995f45aca5afc32e3 100644 --- a/runtime/starpu/codelets/codelet_zperm_allreduce.c +++ b/runtime/starpu/codelets/codelet_zperm_allreduce.c @@ -20,6 +20,110 @@ #include <coreblas/cblas_wrapper.h> #if defined(CHAMELEON_USE_MPI) +struct cl_redux_args_t { + int tempmm; + int n; + int p; + int q; + int p_first; + int me; + int shift; + int np_inv; +}; + +static void +cl_zperm_allreduce_cpu_func( void *descr[], void *cl_arg ) +{ + struct cl_redux_args_t *clargs = (struct cl_redux_args_t *) cl_arg; + const CHAM_tile_t *tileUinout = cti_interface_get( descr[0] ); + const CHAM_tile_t *tileUin = cti_interface_get( descr[1] ); + const int *perm = (int *)STARPU_VECTOR_GET_PTR( descr[2] ); + CHAMELEON_Complex64_t *Uinout = CHAM_tile_get_ptr( tileUinout ); + const CHAMELEON_Complex64_t *Uin = CHAM_tile_get_ptr( tileUin ); + + int tempmm = clargs->tempmm; + int n = clargs->n; + int p = clargs->p; + int q = clargs->q; + int p_first = clargs->p_first / q; + int shift = clargs->shift; + int np = clargs->np_inv; + int me = ( p <= np ) ? clargs->me / q : ( ( clargs->me / q ) - p_first + p ) % p; + int nb = tileUinout->n; + int mb = tileUinout->m; + int first = me - 2 * shift + 1; + int last = me - shift; + int i, m, ownerp; + + for ( i = 0; i < tempmm; i++ ) { + m = perm[ i ] / mb; + ownerp = ( p <= np ) ? ( (m % p) * q + (n % q) ) / q : ( ( (m % p) * q + (n % q) ) / q - p_first + p ) % p; + + if ( ( (first <= ownerp) && (ownerp <= last ) ) || + ( (first+np <= ownerp) && (ownerp <= last+np) ) ) + { + cblas_zcopy( nb, Uin + i, tileUin->ld, + Uinout + i, tileUinout->ld ); + } + } +} + +CODELETS_CPU( zperm_allreduce, cl_zperm_allreduce_cpu_func ) + +static void +INSERT_TASK_zperm_allreduce_send( const RUNTIME_option_t *options, + CHAM_desc_t *U, + int me, + int dst, + int n ) +{ + rt_starpu_insert_task( + NULL, + STARPU_EXECUTE_ON_NODE, dst, + STARPU_R, RTBLKADDR(U, CHAMELEON_Complex64_t, me, n), + STARPU_PRIORITY, options->priority, + 0 ); +} + +static void +INSERT_TASK_zperm_allreduce_recv( const RUNTIME_option_t *options, + CHAM_desc_t *U, + CHAM_ipiv_t *ipiv, + int ipivk, + int me, + int src, + int n, + int tempmm, + int p, + int q, + int shift, + int np, + int p_first ) +{ + struct cl_redux_args_t *clargs; + clargs = malloc( sizeof( struct cl_redux_args_t ) ); + clargs->tempmm = tempmm; + clargs->n = n; + clargs->p = p; + clargs->q = q; + clargs->p_first = p_first; + clargs->me = me; + clargs->shift = shift; + clargs->np_inv = np; + + rt_starpu_insert_task( + &cl_zperm_allreduce, + STARPU_CL_ARGS, clargs, sizeof(struct cl_redux_args_t), + STARPU_RW, RTBLKADDR(U, CHAMELEON_Complex64_t, me, n), + STARPU_R, RTBLKADDR(U, CHAMELEON_Complex64_t, src, n), + STARPU_R, RUNTIME_perm_getaddr( ipiv, ipivk ), + STARPU_EXECUTE_ON_NODE, me, + STARPU_EXECUTE_ON_WORKER, options->workerid, + STARPU_PRIORITY, options->priority, + 0 ); + starpu_mpi_cache_flush( options->sequence->comm, RTBLKADDR(U, CHAMELEON_Complex64_t, src, n) ); +} + void INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options, const CHAM_desc_t *A, @@ -33,6 +137,36 @@ INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options, void *ws ) { struct chameleon_pzgetrf_s *tmp = (struct chameleon_pzgetrf_s *)ws; + int *proc_involved = tmp->proc_involved; + int np_involved = chameleon_min( A->p, A->mt - k); + int np_iter = np_involved; + int p_recv, p_send, me, p_first; + int shift = 1; + + if ( np_involved == 1 ) { + assert( proc_involved[0] == A->myrank ); + } + else { + p_first = proc_involved[0]; + for( me = 0; me < np_involved; me++ ) { + if ( proc_involved[me] == A->myrank ) { + break; + } + } + assert( me < np_involved ); + while ( np_iter > 1 ) { + p_send = proc_involved[ ( me + shift ) % np_involved ]; + p_recv = proc_involved[ ( me - shift + np_involved ) % np_involved ]; + + INSERT_TASK_zperm_allreduce_send( options, U, A->myrank, p_send, n ); + INSERT_TASK_zperm_allreduce_recv( options, U, ipiv, ipivk, A->myrank, p_recv, + n, k == (A->mt-1) ? A->m - k * A->mb : A->mb, + A->p, A->q, shift, np_involved, p_first ); + + shift = shift << 1; + np_iter = chameleon_ceil( np_iter, 2 ); + } + } } void