diff --git a/compute/zgetrf.c b/compute/zgetrf.c
index ba2a98c480d757c2457cc3c796c81680e78a9fe7..0d5ad1b8f6e5e45c57be55134a0e8e1a0ecf1573 100644
--- a/compute/zgetrf.c
+++ b/compute/zgetrf.c
@@ -448,6 +448,13 @@ CHAMELEON_zgetrf_Tile( CHAM_desc_t *A, CHAM_ipiv_t *IPIV )
     CHAMELEON_Ipiv_Flush( IPIV, sequence );
     chameleon_sequence_wait( chamctxt, sequence );
+
+#if defined ( CHAMELEON_USE_MPI )
+    if ( ((struct chameleon_pzgetrf_s *)ws)->alg_allreduce == ChamStarPUMPITasks ) {
+        INSERT_TASK_zperm_allreduce_tag_free( );
+    }
+#endif
+
     CHAMELEON_zgetrf_WS_Free( ws );
 
     status = sequence->status;
diff --git a/include/chameleon/tasks_z.h b/include/chameleon/tasks_z.h
index bf3831af524b2cbea33377a4ff8ab4ce1e124bb6..7b550f8e60f769dbe9e0eb76ad9720870f32b64e 100644
--- a/include/chameleon/tasks_z.h
+++ b/include/chameleon/tasks_z.h
@@ -760,4 +760,6 @@ void INSERT_TASK_zperm_allreduce_send_invp( const RUNTIME_option_t *options,
                                             int k, int n );
 
+void INSERT_TASK_zperm_allreduce_tag_free( );
+
 #endif /* _chameleon_tasks_z_h_ */
diff --git a/runtime/openmp/codelets/codelet_zperm_allreduce.c b/runtime/openmp/codelets/codelet_zperm_allreduce.c
index 7aeb24faebda059ad96dec2819b8793d467eae05..e0a4d31c9f0c3e49510f1628cd343e9c43e599f2 100644
--- a/runtime/openmp/codelets/codelet_zperm_allreduce.c
+++ b/runtime/openmp/codelets/codelet_zperm_allreduce.c
@@ -91,3 +91,7 @@ INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
     (void)n;
     (void)ws;
 }
+
+void
+INSERT_TASK_zperm_allreduce_tag_free( )
+{ }
diff --git a/runtime/parsec/codelets/codelet_zperm_allreduce.c b/runtime/parsec/codelets/codelet_zperm_allreduce.c
index 5acfa4a2b099785e7397807309d104d5421c34fb..95eb785822097f3ae6a5483773ce52a29dc96c8a 100644
--- a/runtime/parsec/codelets/codelet_zperm_allreduce.c
+++ b/runtime/parsec/codelets/codelet_zperm_allreduce.c
@@ -91,3 +91,7 @@ INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
     (void)n;
     (void)ws;
 }
+
+void
+INSERT_TASK_zperm_allreduce_tag_free( )
+{ }
diff --git a/runtime/quark/codelets/codelet_zperm_allreduce.c b/runtime/quark/codelets/codelet_zperm_allreduce.c
index f6c5f98e6d59ed67db6ae9ca7dbe37abca31d617..608d12360b1207988d9abff8a9d422c0c1b92096 100644
--- a/runtime/quark/codelets/codelet_zperm_allreduce.c
+++ b/runtime/quark/codelets/codelet_zperm_allreduce.c
@@ -91,3 +91,7 @@ INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
     (void)n;
     (void)ws;
 }
+
+void
+INSERT_TASK_zperm_allreduce_tag_free( )
+{ }
diff --git a/runtime/starpu/codelets/codelet_zperm_allreduce.c b/runtime/starpu/codelets/codelet_zperm_allreduce.c
index 1c8d44164e9a97dec5427f2a9d775cb7f28b9315..c254296a3266958645077a5df8015728c1f6d79f 100644
--- a/runtime/starpu/codelets/codelet_zperm_allreduce.c
+++ b/runtime/starpu/codelets/codelet_zperm_allreduce.c
@@ -31,6 +31,98 @@ struct cl_redux_args_t {
     int np_inv;
 };
 
+struct cl_redux_mpi_args_t {
+    int      mb;
+    int      nt;
+    int      k;
+    int      n;
+    int      p;
+    int      q;
+    int      tag;
+    int      myrank;
+    int      np_involved;
+    MPI_Comm comm_panel;
+    int      proc_involved[1];
+};
+
+static void
+zperm_allreduce_buffer_send( int         *perm,
+                             CHAM_tile_t *tileU,
+                             int          myrank,
+                             int          mb_full,
+                             int          n,
+                             int          p,
+                             int          q,
+                             int         *idx,
+                             char        *send_data )
+{
+    int64_t i, m;
+    int64_t count = 0;
+    char   *send  = send_data;
+    int     mb    = tileU->m;
+    int     nb    = tileU->n;
+    CHAMELEON_Complex64_t *ptr_U = CHAM_tile_get_ptr( tileU );
+
+    for ( i = 0; i < mb; i++ ) {
+        m = perm[ i ] / mb_full;
+        if ( (m % p) * q + (n % q) == myrank ) {
+            count++;
+        }
+    }
+
+    memcpy( send_data, &count, sizeof(int64_t) );
+    send_data += sizeof(int64_t);
+    for ( i = 0; i < mb; i++ ) {
+        m = perm[ i ] / mb_full;
+        if ( (m % p) * q + (n % q) == myrank ) {
+            memcpy( send_data, &i, sizeof(int64_t) );
+            send_data += sizeof(int64_t);
+            cblas_zcopy( nb, ptr_U + i, tileU->ld, (void *)send_data, 1 );
+            send_data += nb * sizeof(CHAMELEON_Complex64_t);
+            idx[ i ] = 1;
+        }
+        else {
+            idx[ i ] = 0;
+        }
+    }
+    send += ( count + 1 ) * sizeof(int64_t) + count * nb * sizeof(CHAMELEON_Complex64_t);
+    assert( send == send_data );
+}
+
+static void
+zperm_allreduce_copy( char        *recv_rows,
+                      char        *send_data,
+                      CHAM_tile_t *tileU,
+                      int         *idx )
+{
+    int64_t i, j;
+    char   *recv       = recv_rows;
+    int64_t count_recv = *(int64_t *)recv_rows;
+    int64_t count_send = *(int64_t *)send_data;
+    int64_t new_count  = count_recv + count_send;
+    int     nb         = tileU->n;
+    CHAMELEON_Complex64_t *ptr_U = CHAM_tile_get_ptr( tileU );
+
+    recv_rows += sizeof(int64_t);
+    memcpy( send_data, &new_count, sizeof(int64_t) );
+    send_data += ( count_send + 1 ) * sizeof(int64_t) + count_send * nb * sizeof(CHAMELEON_Complex64_t);
+    for ( j = 0; j < count_recv; j++ ) {
+        i = *(int64_t *)recv_rows;
+        recv_rows += sizeof(int64_t);
+        if ( idx[ i ] == 0 ) {
+            cblas_zcopy( nb, (void *)recv_rows, 1, ptr_U + i, tileU->ld );
+            memcpy( send_data, &i, sizeof(int64_t) );
+            send_data += sizeof(int64_t);
+            memcpy( send_data, recv_rows, nb * sizeof(CHAMELEON_Complex64_t) );
+            send_data += nb * sizeof(CHAMELEON_Complex64_t);
+            idx[ i ] = 1;
+        }
+        recv_rows += nb * sizeof(CHAMELEON_Complex64_t);
+    }
+    recv += ( count_recv + 1 ) * sizeof(int64_t) + count_recv * nb * sizeof(CHAMELEON_Complex64_t);
+    assert( recv == recv_rows );
+}
+
 static void
 cl_zperm_allreduce_cpu_func( void *descr[], void *cl_arg )
 {
@@ -70,6 +162,72 @@ cl_zperm_allreduce_cpu_func( void *descr[], void *cl_arg )
 
 CODELETS_CPU( zperm_allreduce, cl_zperm_allreduce_cpu_func )
 
+static void
+cl_zperm_allreduce_mpi_cpu_func( void *descr[], void *cl_arg )
+{
+    struct cl_redux_mpi_args_t *clargs = (struct cl_redux_mpi_args_t *)cl_arg;
+    CHAM_tile_t *tileU = cti_interface_get( descr[0] );
+    int         *perm  = (int *)STARPU_VECTOR_GET_PTR( descr[1] );
+    int mb_full = clargs->mb;
+    int nb      = tileU->n;
+    int mb      = tileU->m;
+    int nt      = clargs->nt;
+    int k       = clargs->k;
+    int n       = clargs->n;
+    int p       = clargs->p;
+    int q       = clargs->q;
+    int tag     = clargs->tag;
+    int myrank  = clargs->myrank;
+    int np_involved = clargs->np_involved;
+    int np_iter     = clargs->np_involved;
+    int *proc_involved  = clargs->proc_involved;
+    MPI_Comm comm_panel = clargs->comm_panel;
+    int shift = 1;
+    int p_recv, p_send, me, size, size_max, count;
+    int tag_send, tag_recv;
+    int idx[ mb ];
+    char *send_data;
+    char *recv_data;
+    MPI_Request request;
+    MPI_Status  status;
+
+    for ( me = 0; me < np_involved; me++ ) {
+        if ( proc_involved[ me ] == myrank ) {
+            break;
+        }
+    }
+    assert( me < np_involved );
+
+    size_max  = sizeof( int64_t ) * ( mb + 1 ) + mb * nb * sizeof( CHAMELEON_Complex64_t );
+    send_data = malloc( size_max );
+    recv_data = malloc( size_max );
+    memset( send_data, 0, size_max );
+    memset( recv_data, 0, size_max );
+    zperm_allreduce_buffer_send( perm, tileU, myrank, mb_full, n, p, q, idx, send_data );
+    tag_send = myrank / q + tag + k * nt * p + n * p;
+
+    while ( np_iter > 1 ) {
+        p_send   = proc_involved[ ( me + shift ) % np_involved ] / q;
+        p_recv   = proc_involved[ ( me - shift + np_involved ) % np_involved ] / q;
+        tag_recv = p_recv + tag + k * nt * p + n * p;
+        count    = *((int64_t *)send_data);
+        size     = sizeof( int64_t ) * ( count + 1 ) + count * nb * sizeof( CHAMELEON_Complex64_t );
+
+        MPI_Isend( send_data, size, MPI_BYTE, p_send, tag_send, comm_panel, &request );
+        MPI_Recv( recv_data, size_max, MPI_BYTE, p_recv, tag_recv, comm_panel, MPI_STATUS_IGNORE );
+        MPI_Wait( &request, &status );
+        zperm_allreduce_copy( recv_data, send_data, tileU, idx );
+
+        shift   = shift << 1;
+        np_iter = chameleon_ceil( np_iter, 2 );
+    }
+
+    free( send_data );
+    free( recv_data );
+}
+
+CODELETS_CPU( zperm_allreduce_mpi, cl_zperm_allreduce_mpi_cpu_func )
+
 static void
 INSERT_TASK_zperm_allreduce_send( const RUNTIME_option_t *options,
                                   CHAM_desc_t *U,
@@ -168,6 +326,61 @@ zperm_allreduce_chameleon_starpu_task( const RUNTIME_option_t *options,
     }
 }
 
+void
+zperm_allreduce_chameleon_starpu_mpi_task( const RUNTIME_option_t *options,
+                                           const CHAM_desc_t *A,
+                                           CHAM_ipiv_t *ipiv,
+                                           int ipivk,
+                                           int k,
+                                           int n,
+                                           CHAM_desc_t *U,
+                                           int Um,
+                                           int Un,
+                                           struct chameleon_pzgetrf_s *ws )
+{
+    int np_involved = chameleon_min( chameleon_desc_datadist_get_iparam( A, 0 ), A->mt - k );
+    size_t size = ( np_involved - 1 ) * sizeof( int );
+    struct cl_redux_mpi_args_t *clargs;
+    int i;
+
+    if ( np_involved == 1 ) {
+        assert( ws->proc_involved[0] == A->myrank );
+        return;
+    }
+
+    clargs = malloc( sizeof( struct cl_redux_mpi_args_t ) + size );
+    clargs->mb  = A->mb;
+    clargs->nt  = A->nt;
+    clargs->k   = k;
+    clargs->n   = n;
+    clargs->p   = chameleon_desc_datadist_get_iparam( A, 0 );
+    clargs->q   = chameleon_desc_datadist_get_iparam( A, 1 );
+    clargs->tag = ( ( chameleon_min( A->mt, A->nt ) + 1 ) * A->nb + 1 ) * chameleon_desc_datadist_get_iparam( A, 0 );
+    clargs->myrank      = A->myrank;
+    clargs->np_involved = np_involved;
+    clargs->comm_panel  = ws->comm_panel;
+    for ( i = 0; i < np_involved; i++ ) {
+        clargs->proc_involved[i] = ws->proc_involved[i];
+    }
+
+    uint64_t tag = k * A->nt * chameleon_desc_datadist_get_iparam( A, 0 ) + n * chameleon_desc_datadist_get_iparam( A, 0 ) + A->myrank / chameleon_desc_datadist_get_iparam( A, 1 );
+    if ( ws->tag != -1 ) {
+        starpu_tag_declare_deps( (starpu_tag_t)tag, 1, (starpu_tag_t)ws->tag );
+    }
+    ws->tag = tag;
+
+    rt_starpu_insert_task(
+        &cl_zperm_allreduce_mpi,
+        STARPU_CL_ARGS,           clargs, sizeof(struct cl_redux_mpi_args_t) + size,
+        STARPU_RW,                RTBLKADDR(U, CHAMELEON_Complex64_t, Um, Un),
+        STARPU_R,                 RUNTIME_perm_getaddr( ipiv, ipivk ),
+        STARPU_TAG,               (starpu_tag_t)tag,
+        STARPU_EXECUTE_ON_NODE,   A->myrank,
+        STARPU_EXECUTE_ON_WORKER, options->workerid,
+        STARPU_PRIORITY,          options->priority,
+        0 );
+}
+
 void
 INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
                              const CHAM_desc_t *A,
@@ -183,6 +396,9 @@ INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
     struct chameleon_pzgetrf_s *tmp = (struct chameleon_pzgetrf_s *)ws;
     cham_getrf_allreduce_t      alg = tmp->alg_allreduce;
     switch( alg ) {
+    case ChamStarPUMPITasks:
+        zperm_allreduce_chameleon_starpu_mpi_task( options, A, ipiv, ipivk, k, n, U, Um, Un, tmp );
+        break;
     case ChamStarPUTasks:
     default:
         zperm_allreduce_chameleon_starpu_task( options, A, U, Um, Un, ipiv, ipivk, k, n, tmp );
@@ -250,6 +466,12 @@ INSERT_TASK_zperm_allreduce_send_invp( const RUNTIME_option_t *options,
                                        rank, NULL, NULL );
     }
 }
+
+void
+INSERT_TASK_zperm_allreduce_tag_free( )
+{
+    starpu_tag_clear();
+}
 #else
 void
 INSERT_TASK_zperm_allreduce_send_A( const RUNTIME_option_t *options,
@@ -324,4 +546,8 @@ INSERT_TASK_zperm_allreduce( const RUNTIME_option_t *options,
     (void)n;
     (void)ws;
 }
+
+void
+INSERT_TASK_zperm_allreduce_tag_free( )
+{}
 #endif
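Reviewer note (not part of the patch): zperm_allreduce_buffer_send() packs the locally owned rows of the tile into a byte stream laid out as one int64_t row count followed by `count` records of { int64_t row index, nb complex values }; zperm_allreduce_copy() then merges a received stream into both the tile and the tail of the outgoing stream, with idx[] guarding against merging a row twice. The standalone C sketch below illustrates only that layout, with plain doubles standing in for CHAMELEON_Complex64_t and a hypothetical pack_row() helper replacing the cblas_zcopy() calls.

/*
 * Standalone sketch of the packed send-buffer layout:
 * [ int64_t count | count x { int64_t row index, nb values } ].
 */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Append one { index, data } record at the write cursor. */
static char *pack_row( char *cursor, int64_t i, const double *row, int nb )
{
    memcpy( cursor, &i, sizeof(int64_t) );
    cursor += sizeof(int64_t);
    memcpy( cursor, row, nb * sizeof(double) );
    return cursor + nb * sizeof(double);
}

int main( void )
{
    enum { MB = 4, NB = 3 };
    double  U[MB][NB] = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 }, { 10, 11, 12 } };
    /* Worst-case size, as in the codelet: header + MB full records. */
    char    buffer[ sizeof(int64_t) * ( MB + 1 ) + MB * NB * sizeof(double) ];
    char   *cursor = buffer;
    int64_t count  = 2; /* pretend rows 1 and 3 belong to this rank */

    memcpy( cursor, &count, sizeof(int64_t) );
    cursor = pack_row( cursor + sizeof(int64_t), 1, U[1], NB );
    cursor = pack_row( cursor, 3, U[3], NB );

    /* Unpack: walk the records, as zperm_allreduce_copy() does on receive. */
    char   *rd = buffer;
    int64_t nrows;
    memcpy( &nrows, rd, sizeof(int64_t) );
    rd += sizeof(int64_t);
    for ( int64_t j = 0; j < nrows; j++ ) {
        int64_t i;
        memcpy( &i, rd, sizeof(int64_t) );
        rd += sizeof(int64_t);
        printf( "row %ld:", (long)i );
        for ( int c = 0; c < NB; c++ ) {
            double v;
            memcpy( &v, rd, sizeof(double) );
            rd += sizeof(double);
            printf( " %g", v );
        }
        printf( "\n" );
    }
    assert( rd == cursor ); /* same end-of-stream check as in the codelet */
    return 0;
}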
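Reviewer note (not part of the patch): cl_zperm_allreduce_mpi_cpu_func() runs a dissemination-style exchange: in each round, rank `me` sends its accumulated rows to proc_involved[(me + shift) % np] and merges the stream received from proc_involved[(me - shift + np) % np], then doubles `shift` and halves the remaining count with chameleon_ceil(np_iter, 2), so ceil(log2(np)) rounds suffice for every rank to hold all contributions. The standalone sketch below simulates that schedule with ownership bitmaps instead of MPI messages to check the convergence claim; NP is an arbitrary example size.

/*
 * Standalone simulation of the exchange schedule: have[r][s] == 1
 * iff rank r already holds the contribution of rank s.
 */
#include <stdio.h>
#include <string.h>

#define NP 5 /* arbitrary example size */

int main( void )
{
    int have[NP][NP] = { 0 };
    int next[NP][NP];
    int np_iter = NP;
    int shift   = 1;
    int rounds  = 0;

    for ( int r = 0; r < NP; r++ ) {
        have[r][r] = 1; /* each rank starts with its own rows */
    }

    while ( np_iter > 1 ) {
        memcpy( next, have, sizeof( have ) );
        for ( int me = 0; me < NP; me++ ) {
            int from = ( me - shift + NP ) % NP;
            /* Merge the set received from `from`; in the codelet,
             * idx[] filters out rows already present. */
            for ( int s = 0; s < NP; s++ ) {
                next[me][s] |= have[from][s];
            }
        }
        memcpy( have, next, sizeof( have ) );
        shift <<= 1;
        np_iter = ( np_iter + 1 ) / 2; /* chameleon_ceil( np_iter, 2 ) */
        rounds++;
    }

    for ( int r = 0; r < NP; r++ ) {
        for ( int s = 0; s < NP; s++ ) {
            if ( !have[r][s] ) {
                printf( "rank %d is missing contribution %d\n", r, s );
                return 1;
            }
        }
    }
    printf( "all %d ranks complete after %d rounds\n", NP, rounds );
    return 0;
}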