Commit 3d6612b1 authored by Antoine Jego's avatar Antoine Jego
Browse files

fixed prune-handles (always create handle)

parent 0ec80708
......@@ -23,6 +23,7 @@ Matrix* alloc_matrix(int mb, int nb, int b, int p, int q, starpu_mpi_tag_t* tag)
X->blocks[i*nb+j].ld = b;
X->blocks[i*nb+j].tag= *tag;
X->blocks[i*nb+j].registered = 0;
X->blocks[i*nb+j].hdl = NULL;
if (X->blocks[i*nb+j].owner == comm_rank)
X->blocks[i*nb+j].c = malloc(b*b*sizeof(double));
else
......@@ -69,19 +70,19 @@ void register_matrix(Matrix* X, int mb, int nb, int datatype, int prune_handles,
for (b_col = 0; b_col < nb; b_col++)
{
Xij = & X->blocks[b_row*nb+b_col];
Xij->hdl = malloc(sizeof(starpu_data_handle_t));
// printf("[%d] X_%d,%d | tag:%d\n",comm_rank,b_row,b_col,Xij->tag);
if (Xij->owner == comm_rank)
{
Xij->hdl = malloc(sizeof(starpu_data_handle_t));
if (datatype) {
starpu_tile_register( &Xij->hdl, STARPU_MAIN_RAM, Xij );
starpu_tile_register( Xij->hdl, STARPU_MAIN_RAM, Xij );
} else {
starpu_matrix_data_register( &Xij->hdl, STARPU_MAIN_RAM,
starpu_matrix_data_register( Xij->hdl, STARPU_MAIN_RAM,
(uintptr_t) Xij->c, Xij->m, Xij->n, Xij->ld,
sizeof(double));
}
starpu_mpi_data_register(Xij->hdl, Xij->tag, Xij->owner);
// printf("[%d] X_%d,%d | mpi_data_register %p\n",comm_rank,b_row,b_col,Xij->hdl);
starpu_mpi_data_register(*Xij->hdl, Xij->tag, Xij->owner);
// printf("[%d] X_%d,%d | mpi_data_register %p\n",comm_rank,b_row,b_col,*Xij->hdl);
Xij->registered = 1;
} else if (!delay && (!prune_handles || (row && proc_row == b_row % p) ||
(col && proc_col == b_col % q) ||
......@@ -107,8 +108,10 @@ void unregister_matrix(Matrix* X, int mb, int nb)
// assuming we flush, we do not need to unregister everywhere
if (X->blocks[b_row*nb+b_col].owner == comm_rank) {
// printf("[%d] unregistering X_%d,%d\n", comm_rank, b_row, b_col);
starpu_data_unregister(X->blocks[b_row*nb+b_col].hdl);
starpu_data_unregister(*X->blocks[b_row*nb+b_col].hdl);
}
free(X->blocks[b_row*nb+b_col].hdl);
X->blocks[b_row*nb+b_col].registered = 0;
}
}
}
......@@ -136,17 +139,17 @@ void print_matrix(Matrix* X, char* name) {
void block_starpu_register(Block* Xij, int datatype) {
if (!Xij->registered) {
Xij->hdl = malloc(sizeof(starpu_data_handle_t));
// printf("[%d] X_block | mpi_data_register %p\n",comm_rank,Xij->hdl);
starpu_mpi_comm_rank(MPI_COMM_WORLD, &comm_rank);
// Xij->hdl = malloc(sizeof(starpu_data_handle_t));
if (datatype) {
starpu_tile_register( &Xij->hdl, -1, Xij );
starpu_tile_register( Xij->hdl, -1, Xij );
} else {
starpu_matrix_data_register( &Xij->hdl, -1,
starpu_matrix_data_register( Xij->hdl, -1,
(uintptr_t) NULL, Xij->m, Xij->n, Xij->ld,
sizeof(double));
}
starpu_mpi_data_register(Xij->hdl, Xij->tag, Xij->owner);
starpu_mpi_data_register(*Xij->hdl, Xij->tag, Xij->owner);
// printf("[%d] X_block | mpi_data_register %p\n",comm_rank,*Xij->hdl);
Xij->registered = 1;
} else {
// printf("[%d] X_block | already registered\n");
......
......@@ -6,7 +6,7 @@ typedef struct Blocks
double* c;
int m,n,ld;
int owner;
starpu_data_handle_t hdl;
starpu_data_handle_t* hdl;
starpu_mpi_tag_t tag;
int registered;
} Block;
......
......@@ -235,6 +235,7 @@ static void unregister_matrices()
if (datatype) {
starpu_tile_interface_register();
}
if (verbose) printf( "[%d] Unregistered matrices\n", comm_rank);
}
struct cl_zgemm_args_s {
......@@ -427,7 +428,7 @@ static void init_matrix(Matrix* X, int mb, int nb)
{
// printf("[%d] fill X_%d,%d %p\n",comm_rank,row,col, X->blocks[row*nb+col].hdl);
starpu_mpi_task_insert(MPI_COMM_WORLD, &fill_cl,
STARPU_W, X->blocks[row*nb+col].hdl, 0);
STARPU_W, *X->blocks[row*nb+col].hdl, 0);
// printf("[%d] filled X_%d,%d\n",comm_rank,row,col);
}
}
......@@ -615,15 +616,19 @@ int main(int argc, char *argv[])
int a_local, b_local, c_local;
int b_row,b_col,b_aisle;
Block *Ail, *Blj, *Cij;
for (b_row = 0; b_row < MB; b_row++)
{
for (b_col = 0; b_col < NB; b_col++)
{
for (b_aisle=0;b_aisle<KB;b_aisle++)
{
a_local = A->blocks[b_row*KB+b_aisle].owner == comm_rank;
b_local = B->blocks[b_aisle*NB+b_col].owner == comm_rank;
c_local = C->blocks[ b_row*NB+b_col].owner == comm_rank;
Ail = & A->blocks[b_row*KB + b_aisle];
Blj = & B->blocks[b_aisle*NB + b_col];
Cij = & C->blocks[b_col * NB + b_row];
a_local = Ail->owner == comm_rank;
b_local = Blj->owner == comm_rank;
c_local = Cij->owner == comm_rank;
// when prune and/or prune_handles are allowed needs to be clarified
//if ((!prune && !prune_handles) || (A->blocks[b_row*KB+b_aisle].owner == comm_rank || B->blocks[b_aisle*NB+b_col].owner == comm_rank || C->blocks[b_row*NB+b_col].owner == comm_rank)) {
// TODO : logic might be written more clearly (a/b/c_local may be redundant)
......@@ -632,10 +637,10 @@ int main(int argc, char *argv[])
if (delay) {
// printf("[%d] late registration i,j,l %d,%d,%d\n",comm_rank,b_row,b_col,b_aisle);
if (!prune_handles || c_local) {
block_starpu_register(& (A->blocks[b_row*KB+b_aisle]),datatype);
block_starpu_register(& (B->blocks[b_aisle*NB+b_col]),datatype);
block_starpu_register(Ail,datatype);
block_starpu_register(Blj,datatype);
}
block_starpu_register(& (C->blocks[b_row*NB+b_col]), datatype);
block_starpu_register(Cij, datatype);
}
struct cl_zgemm_args_s *clargs = NULL;
if (c_local) {
......@@ -646,13 +651,11 @@ int main(int argc, char *argv[])
}
starpu_mpi_task_insert(MPI_COMM_WORLD, &gemm_cl,
STARPU_CL_ARGS, clargs, sizeof(struct cl_zgemm_args_s),
//STARPU_R, *A->blocks[b_row*KB+b_aisle].hdl,
STARPU_R, A->blocks[b_row*KB+b_aisle].hdl,
STARPU_R, B->blocks[b_aisle*NB+b_col].hdl,
STARPU_RW,C->blocks[b_row * NB+b_col].hdl, 0);
//printf("[%d] inserted C_%d,%d += A_%d,%d B_%d,%d\n",comm_rank, b_row,b_col, b_row,b_aisle, b_aisle,b_col);
STARPU_R, *Ail->hdl,
STARPU_R, *Blj->hdl,
STARPU_RW,*Cij->hdl, 0);
} else {
//printf("[%d] NOT inserted C_%d,%d += A_%d,%d B_%d,%d\n",comm_rank, b_row,b_col, b_row,b_aisle, b_aisle,b_col);
// printf("[%d] NOT inserted C_%d,%d += A_%d,%d B_%d,%d\n",comm_rank, b_row,b_col, b_row,b_aisle, b_aisle,b_col);
}
}
}
......@@ -660,11 +663,11 @@ int main(int argc, char *argv[])
for (b_aisle=0;b_aisle<KB;b_aisle++)
{
if (A->blocks[b_row*KB+b_aisle].registered)
starpu_mpi_cache_flush(MPI_COMM_WORLD, A->blocks[b_row*KB+b_aisle].hdl);
starpu_mpi_cache_flush(MPI_COMM_WORLD, *Ail->hdl);
}
}
}
//printf("[%d] finished submission\n",comm_rank);
// printf("[%d] finished submission\n",comm_rank);
starpu_mpi_wait_for_all(MPI_COMM_WORLD);
barrier_ret = starpu_mpi_barrier(MPI_COMM_WORLD);
stop = starpu_timing_now();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment