Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 6b88697d authored by LISITO Alycia's avatar LISITO Alycia
Browse files

zgetrf: zperm distributed use Wu for update

parent dc30a4bb
No related branches found
No related tags found
1 merge request!496zgetrf: Allreduce perm
......@@ -26,6 +26,7 @@
#define A(m,n) A, m, n
#define U(m,n) &(ws->U), m, n
#define Up(m,n) &(ws->Up), m, n
#define Wu(m,n) &(ws->Wu), m, n
/*
* All the functions below are panel factorization variant.
......@@ -389,6 +390,19 @@ chameleon_pzgetrf_panel_permute( struct chameleon_pzgetrf_s *ws,
int m;
int tempkm, tempkn, tempnn, minmn;
chameleon_get_proc_involved_in_panelk_2dbc( A, k, n, ws );
if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
INSERT_TASK_zperm_allreduce_send_perm( options, ipiv, k, A->myrank, ws->np_involved, ws->proc_involved );
INSERT_TASK_zperm_allreduce_send_invp( options, ipiv, k, A, k, n );
}
if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
INSERT_TASK_zperm_allreduce_send_A( options, A, k, n, A->myrank, ws->np_involved, ws->proc_involved );
}
if ( !ws->involved ) {
return;
}
tempkm = k == A->mt-1 ? A->m-k*A->mb : A->mb;
tempkn = k == A->nt-1 ? A->n-k*A->nb : A->nb;
tempnn = n == A->nt-1 ? A->n-n*A->nb : A->nb;
......@@ -396,28 +410,26 @@ chameleon_pzgetrf_panel_permute( struct chameleon_pzgetrf_s *ws,
/* Extract selected rows into U */
INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
A(k, n), U(k, n) );
A(k, n), Wu(A->myrank, n) );
/*
* perm array is made of size tempkm for the first row especially.
* Otherwise, the final copy back to the tile may copy only a partial tile
*/
INSERT_TASK_zlaswp_get( options, k*A->mb, tempkm,
ipiv, k, A(k, n), U(k, n) );
ipiv, k, A(k, n), Wu(A->myrank, n) );
for(m=k+1; m<A->mt; m++){
/* Extract selected rows into A(k, n) */
INSERT_TASK_zlaswp_get( options, m*A->mb, minmn,
ipiv, k, A(m, n), U(k, n) );
ipiv, k, A(m, n), Wu(A->myrank, n) );
/* Copy rows from A(k,n) into their final position */
INSERT_TASK_zlaswp_set( options, m*A->mb, minmn,
ipiv, k, A(k, n), A(m, n) );
}
INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
U(k, n), A(k, n) );
RUNTIME_data_flush( options->sequence, U(k, n) );
INSERT_TASK_zperm_allreduce( options, A, ipiv, k, k, n,
Wu(A->myrank, n), ws );
}
break;
default:
......@@ -440,6 +452,20 @@ chameleon_pzgetrf_panel_permute_batched( struct chameleon_pzgetrf_s *ws,
{
int m;
int tempkm, tempkn, tempnn, minmn;
chameleon_get_proc_involved_in_panelk_2dbc( A, k, n, ws );
if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
INSERT_TASK_zperm_allreduce_send_perm( options, ipiv, k, A->myrank, ws->np_involved, ws->proc_involved );
INSERT_TASK_zperm_allreduce_send_invp( options, ipiv, k, A, k, n );
}
if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
INSERT_TASK_zperm_allreduce_send_A( options, A, k, n, A->myrank, ws->np_involved, ws->proc_involved );
}
if ( !ws->involved ) {
return;
}
void **clargs = malloc( sizeof(char *) );
*clargs = NULL;
......@@ -450,25 +476,23 @@ chameleon_pzgetrf_panel_permute_batched( struct chameleon_pzgetrf_s *ws,
/* Extract selected rows into U */
INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
A(k, n), U(k, n) );
A(k, n), Wu(A->myrank, n) );
/*
* perm array is made of size tempkm for the first row especially.
* Otherwise, the final copy back to the tile may copy only a partial tile
*/
INSERT_TASK_zlaswp_get( options, k*A->mb, tempkm,
ipiv, k, A(k, n), U(k, n) );
ipiv, k, A(k, n), Wu(A->myrank, n) );
for(m=k+1; m<A->mt; m++){
INSERT_TASK_zlaswp_batched( options, m*A->mb, minmn, k, m, n, (void *)ws,
ipiv, k, A, &(ws->U), clargs );
INSERT_TASK_zlaswp_batched( options, m*A->mb, minmn, (void *)ws, ipiv, k,
A(m, n), A(k, n), Wu(A->myrank, n), clargs );
}
INSERT_TASK_zlaswp_batched_flush( options, k, n, ipiv, k, A, &(ws->U), clargs );
INSERT_TASK_zlaswp_batched_flush( options, ipiv, k, A(k, n), Wu(A->myrank, n), clargs );
INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
U(k, n), A(k, n) );
INSERT_TASK_zperm_allreduce( options, A, ipiv, k, k, n, Wu(A->myrank, n), ws );
RUNTIME_data_flush( options->sequence, U(k, n) );
free( clargs );
}
break;
......@@ -488,7 +512,7 @@ chameleon_pzgetrf_panel_update( struct chameleon_pzgetrf_s *ws,
const CHAMELEON_Complex64_t zone = (CHAMELEON_Complex64_t) 1.0;
const CHAMELEON_Complex64_t mzone = (CHAMELEON_Complex64_t)-1.0;
int m, tempkm, tempmm, tempnn;
int m, tempkm, tempmm, tempnn, rankAmn, p;
tempkm = k == A->mt-1 ? A->m-k*A->mb : A->mb;
tempnn = n == A->nt-1 ? A->n-n*A->nb : A->nb;
......@@ -500,25 +524,44 @@ chameleon_pzgetrf_panel_update( struct chameleon_pzgetrf_s *ws,
chameleon_pzgetrf_panel_permute( ws, A, ipiv, k, n, options );
}
INSERT_TASK_ztrsm(
options,
ChamLeft, ChamLower, ChamNoTrans, ChamUnit,
tempkm, tempnn, A->mb,
zone, A(k, k),
A(k, n) );
if ( A->myrank == chameleon_getrankof_2d( A, k, k ) ) {
for ( p = 0; p < ws->np_involved; p++ ) {
INSERT_TASK_ztrsm(
options,
ChamLeft, ChamLower, ChamNoTrans, ChamUnit,
tempkm, tempnn, A->mb,
zone, A(k, k),
Wu(ws->proc_involved[p], n) );
}
}
else if ( ws->involved ) {
INSERT_TASK_ztrsm(
options,
ChamLeft, ChamLower, ChamNoTrans, ChamUnit,
tempkm, tempnn, A->mb,
zone, A(k, k),
Wu(A->myrank, n) );
}
for (m = k+1; m < A->mt; m++) {
tempmm = m == A->mt-1 ? A->m-m*A->mb : A->mb;
rankAmn = A->get_rankof( A, m, n );
INSERT_TASK_zgemm(
options,
ChamNoTrans, ChamNoTrans,
tempmm, tempnn, A->mb, A->mb,
mzone, A(m, k),
A(k, n),
Wu(rankAmn, n),
zone, A(m, n) );
}
if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
INSERT_TASK_zlacpy( options, ChamUpperLower, tempkm, tempnn,
Wu(A->myrank, n), A(k, n) );
}
RUNTIME_data_flush( options->sequence, Wu(A->myrank, n) );
RUNTIME_data_flush( options->sequence, A(k, n) );
}
......@@ -534,7 +577,7 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
CHAM_context_t *chamctxt;
RUNTIME_option_t options;
int k, m, n;
int k, m, n, tempkm, tempnn;
int min_mnt = chameleon_min( A->mt, A->nt );
chamctxt = chameleon_context_self();
......@@ -559,7 +602,11 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
for (n = k+1; n < A->nt; n++) {
options.priority = A->nt-n;
chameleon_pzgetrf_panel_update( ws, A, IPIV, k, n, &options );
if ( chameleon_involved_in_panelk_2dbc( A, k ) ||
chameleon_involved_in_panelk_2dbc( A, n ) )
{
chameleon_pzgetrf_panel_update( ws, A, IPIV, k, n, &options );
}
}
/* Flush panel k */
......@@ -574,7 +621,19 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
if ( ws->batch_size > 0 ) {
for (k = 1; k < min_mnt; k++) {
for (n = 0; n < k; n++) {
chameleon_pzgetrf_panel_permute_batched( ws, A, IPIV, k, n, &options );
if ( chameleon_involved_in_panelk_2dbc( A, k ) ||
chameleon_involved_in_panelk_2dbc( A, n ) )
{
chameleon_pzgetrf_panel_permute_batched( ws, A, IPIV, k, n, &options );
if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
tempkm = k == A->mt-1 ? A->m-k*A->mb : A->mb;
tempnn = n == A->nt-1 ? A->n-n*A->nb : A->nb;
INSERT_TASK_zlacpy( &options, ChamUpperLower, tempkm, tempnn,
Wu(A->myrank, n), A(k, n) );
RUNTIME_data_flush( sequence, A(k, n) );
}
}
RUNTIME_data_flush( sequence, Wu(A->myrank, n) );
}
RUNTIME_perm_flushk( sequence, IPIV, k );
}
......@@ -582,7 +641,19 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
else {
for (k = 1; k < min_mnt; k++) {
for (n = 0; n < k; n++) {
chameleon_pzgetrf_panel_permute( ws, A, IPIV, k, n, &options );
if ( chameleon_involved_in_panelk_2dbc( A, k ) ||
chameleon_involved_in_panelk_2dbc( A, n ) )
{
chameleon_pzgetrf_panel_permute( ws, A, IPIV, k, n, &options );
if ( A->myrank == chameleon_getrankof_2d( A, k, n ) ) {
tempkm = k == A->mt-1 ? A->m-k*A->mb : A->mb;
tempnn = n == A->nt-1 ? A->n-n*A->nb : A->nb;
INSERT_TASK_zlacpy( &options, ChamUpperLower, tempkm, tempnn,
Wu(A->myrank, n), A(k, n) );
RUNTIME_data_flush( sequence, A(k, n) );
}
}
RUNTIME_data_flush( sequence, Wu(A->myrank, n) );
}
RUNTIME_perm_flushk( sequence, IPIV, k );
}
......
......@@ -118,6 +118,11 @@ CHAMELEON_zgetrf_WS_Alloc( const CHAM_desc_t *A )
A->m, A->n, 0, 0,
A->m, A->n, A->p, A->q,
NULL, NULL, A->get_rankof_init, A->get_rankof_init_arg );
chameleon_desc_init( &(ws->Wu), CHAMELEON_MAT_ALLOC_TILE,
ChamComplexDouble, A->mb, A->nb, A->mb*A->nb,
A->mb * A->p * A->q, A->n, 0, 0,
A->mb * A->p * A->q, A->n, A->p * A->q, 1,
NULL, NULL, NULL, A->get_rankof_init_arg );
}
/* Set ib to 1 if per column algorithm */
......@@ -180,6 +185,11 @@ CHAMELEON_zgetrf_WS_Free( void *user_ws )
{
chameleon_desc_destroy( &(ws->Up) );
}
if ( ( ws->alg == ChamGetrfPPiv ) ||
( ws->alg == ChamGetrfPPivPerColumn ) )
{
chameleon_desc_destroy( &(ws->Wu) );
}
free( ws );
}
......
......@@ -43,13 +43,15 @@ struct chameleon_pzgemm_s {
* @brief Data structure to handle the GETRF workspaces with partial pivoting
*/
struct chameleon_pzgetrf_s {
cham_getrf_t alg;
int ib; /**< Internal blocking parameter */
int batch_size; /**< Batch size for the panel */
CHAM_desc_t U;
CHAM_desc_t Up;
int *proc_involved;
unsigned int involved:1;
cham_getrf_t alg;
int ib; /**< Internal blocking parameter */
int batch_size; /**< Batch size for the panel */
CHAM_desc_t U;
CHAM_desc_t Up; /**< Workspace used for the panel factorization */
CHAM_desc_t Wu; /**< Workspace used for the permutation and update */
int *proc_involved;
unsigned int involved;
int np_involved;
};
/**
......
......@@ -199,17 +199,17 @@ void INSERT_TASK_zlaswp_set( const RUNTIME_option_t *options,
const CHAM_desc_t *tileA, int tileAm, int tileAn,
const CHAM_desc_t *tileB, int tileBm, int tileBn );
void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
int m0, int minmn, int k, int m, int n,
int m0, int minmn,
void *ws,
const CHAM_ipiv_t *ipiv, int ipivk,
const CHAM_desc_t *A,
const CHAM_desc_t *U,
const CHAM_desc_t *Am, int Amm, int Amn,
const CHAM_desc_t *Ak, int Akm, int Akn,
const CHAM_desc_t *U, int Um, int Un,
void **clargs_ptr );
void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
int k, int n,
const CHAM_ipiv_t *ipiv, int ipivk,
const CHAM_desc_t *A,
const CHAM_desc_t *U,
const CHAM_desc_t *Ak, int Akm, int Akn,
const CHAM_desc_t *U, int Um, int Un,
void **clargs_ptr );
void INSERT_TASK_zlatro( const RUNTIME_option_t *options,
cham_uplo_t uplo, cham_trans_t trans, int m, int n, int mb,
......
......@@ -21,45 +21,57 @@
void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
int m0,
int minmn,
int k,
int m,
int n,
void *ws,
const CHAM_ipiv_t *ipiv,
int ipivk,
const CHAM_desc_t *A,
const CHAM_desc_t *Wu,
const CHAM_desc_t *Am,
int Amm,
int Amn,
const CHAM_desc_t *Ak,
int Akm,
int Akn,
const CHAM_desc_t *U,
int Um,
int Un,
void **clargs_ptr )
{
(void)options;
(void)m0;
(void)minmn;
(void)k;
(void)m;
(void)n;
(void)ws;
(void)ipiv;
(void)ipivk;
(void)A;
(void)Wu;
(void)Am;
(void)Amm;
(void)Amn;
(void)Ak;
(void)Akm;
(void)Akn;
(void)U;
(void)Um;
(void)Un;
(void)clargs_ptr;
}
void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
int k,
int n,
const CHAM_ipiv_t *ipiv,
int ipivk,
const CHAM_desc_t *A,
const CHAM_desc_t *Ak,
int Akm,
int Akn,
const CHAM_desc_t *U,
int Um,
int Un,
void **clargs_ptr )
{
(void)options;
(void)k;
(void)n;
(void)ipiv;
(void)ipivk;
(void)A;
(void)Ak;
(void)Akm;
(void)Akn;
(void)U;
(void)Um;
(void)Un;
(void)clargs_ptr;
}
......@@ -21,45 +21,57 @@
void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
int m0,
int minmn,
int k,
int m,
int n,
void *ws,
const CHAM_ipiv_t *ipiv,
int ipivk,
const CHAM_desc_t *A,
const CHAM_desc_t *Wu,
const CHAM_desc_t *Am,
int Amm,
int Amn,
const CHAM_desc_t *Ak,
int Akm,
int Akn,
const CHAM_desc_t *U,
int Um,
int Un,
void **clargs_ptr )
{
(void)options;
(void)m0;
(void)minmn;
(void)k;
(void)m;
(void)n;
(void)ws;
(void)ipiv;
(void)ipivk;
(void)A;
(void)Wu;
(void)Am;
(void)Amm;
(void)Amn;
(void)Ak;
(void)Akm;
(void)Akn;
(void)U;
(void)Um;
(void)Un;
(void)clargs_ptr;
}
void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
int k,
int n,
const CHAM_ipiv_t *ipiv,
int ipivk,
const CHAM_desc_t *A,
const CHAM_desc_t *Ak,
int Akm,
int Akn,
const CHAM_desc_t *U,
int Um,
int Un,
void **clargs_ptr )
{
(void)options;
(void)k;
(void)n;
(void)ipiv;
(void)ipivk;
(void)A;
(void)Ak;
(void)Akm;
(void)Akn;
(void)U;
(void)Um;
(void)Un;
(void)clargs_ptr;
}
......@@ -21,45 +21,57 @@
void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
int m0,
int minmn,
int k,
int m,
int n,
void *ws,
const CHAM_ipiv_t *ipiv,
int ipivk,
const CHAM_desc_t *A,
const CHAM_desc_t *Wu,
const CHAM_desc_t *Am,
int Amm,
int Amn,
const CHAM_desc_t *Ak,
int Akm,
int Akn,
const CHAM_desc_t *U,
int Um,
int Un,
void **clargs_ptr )
{
(void)options;
(void)m0;
(void)minmn;
(void)k;
(void)m;
(void)n;
(void)ws;
(void)ipiv;
(void)ipivk;
(void)A;
(void)Wu;
(void)Am;
(void)Amm;
(void)Amn;
(void)Ak;
(void)Akm;
(void)Akn;
(void)U;
(void)Um;
(void)Un;
(void)clargs_ptr;
}
void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
int k,
int n,
const CHAM_ipiv_t *ipiv,
int ipivk,
const CHAM_desc_t *A,
const CHAM_desc_t *Ak,
int Akm,
int Akn,
const CHAM_desc_t *U,
int Um,
int Un,
void **clargs_ptr )
{
(void)options;
(void)k;
(void)n;
(void)ipiv;
(void)ipivk;
(void)A;
(void)Ak;
(void)Akm;
(void)Akn;
(void)U;
(void)Um;
(void)Un;
(void)clargs_ptr;
}
......@@ -47,6 +47,9 @@ void INSERT_TASK_zlaswp_get( const RUNTIME_option_t *options,
const CHAM_desc_t *U, int Um, int Un )
{
struct starpu_codelet *codelet = &cl_zlaswp_get;
if ( A->get_rankof( A, Am, An) != A->myrank ) {
return;
}
//void (*callback)(void*) = options->profiling ? cl_zlaswp_get_callback : NULL;
......@@ -91,6 +94,9 @@ void INSERT_TASK_zlaswp_set( const RUNTIME_option_t *options,
const CHAM_desc_t *B, int Bm, int Bn )
{
struct starpu_codelet *codelet = &cl_zlaswp_set;
if ( A->get_rankof( B, Bm, Bn) != A->myrank ) {
return;
}
//void (*callback)(void*) = options->profiling ? cl_zlaswp_set_callback : NULL;
......
......@@ -57,21 +57,25 @@ CODELETS_CPU( zlaswp_batched, cl_zlaswp_batched_cpu_func )
void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
int m0,
int minmn,
int k,
int m,
int n,
void *ws,
const CHAM_ipiv_t *ipiv,
int ipivk,
const CHAM_desc_t *A,
const CHAM_desc_t *Wu,
const CHAM_desc_t *Am,
int Amm,
int Amn,
const CHAM_desc_t *Ak,
int Akm,
int Akn,
const CHAM_desc_t *U,
int Um,
int Un,
void **clargs_ptr )
{
int task_num = 0;
int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size;
int nhandles;
struct cl_laswp_batched_args_t *clargs = *clargs_ptr;
if ( A->get_rankof( A, m, n) != A->myrank ) {
if ( Am->get_rankof( Am, Amm, Amn) != Am->myrank ) {
return;
}
......@@ -84,7 +88,7 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
task_num = clargs->tasks_nbr;
clargs->m0[ task_num ] = m0;
clargs->handle_mode[ task_num ].handle = RTBLKADDR(A, CHAMELEON_Complex64_t, m, n);
clargs->handle_mode[ task_num ].handle = RTBLKADDR(Am, CHAMELEON_Complex64_t, Amm, Amn);
clargs->handle_mode[ task_num ].mode = STARPU_RW;
clargs->tasks_nbr ++;
......@@ -95,8 +99,8 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
STARPU_CL_ARGS, clargs, sizeof(struct cl_laswp_batched_args_t),
STARPU_R, RUNTIME_perm_getaddr( ipiv, ipivk ),
STARPU_R, RUNTIME_invp_getaddr( ipiv, ipivk ),
STARPU_RW | STARPU_COMMUTE, RTBLKADDR(Wu, ChamComplexDouble, A->myrank, n),
STARPU_R, RTBLKADDR(A, ChamComplexDouble, k, n),
STARPU_RW | STARPU_COMMUTE, RTBLKADDR(U, ChamComplexDouble, Um, Un),
STARPU_R, RTBLKADDR(Ak, ChamComplexDouble, Akm, Akn),
STARPU_DATA_MODE_ARRAY, clargs->handle_mode, nhandles,
STARPU_PRIORITY, options->priority,
STARPU_EXECUTE_ON_WORKER, options->workerid,
......@@ -108,12 +112,14 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
}
void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
int k,
int n,
const CHAM_ipiv_t *ipiv,
int ipivk,
const CHAM_desc_t *A,
const CHAM_desc_t *Ak,
int Akm,
int Akn,
const CHAM_desc_t *U,
int Um,
int Un,
void **clargs_ptr )
{
struct cl_laswp_batched_args_t *clargs = *clargs_ptr;
......@@ -129,8 +135,8 @@ void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
STARPU_CL_ARGS, clargs, sizeof(struct cl_laswp_batched_args_t),
STARPU_R, RUNTIME_perm_getaddr( ipiv, ipivk ),
STARPU_R, RUNTIME_invp_getaddr( ipiv, ipivk ),
STARPU_RW | STARPU_COMMUTE, RTBLKADDR(U, ChamComplexDouble, k, n),
STARPU_R, RTBLKADDR(A, ChamComplexDouble, k, n),
STARPU_RW | STARPU_COMMUTE, RTBLKADDR(U, ChamComplexDouble, Um, Un),
STARPU_R, RTBLKADDR(Ak, ChamComplexDouble, Akm, Akn),
STARPU_DATA_MODE_ARRAY, clargs->handle_mode, nhandles,
STARPU_PRIORITY, options->priority,
STARPU_EXECUTE_ON_WORKER, options->workerid,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment