Commit 847b6b7b authored by RAMET Pierre's avatar RAMET Pierre
Browse files

solve smp v0

parent 0e1b5718
......@@ -179,7 +179,7 @@ isched_parallel_section(isched_thread_t *ctx)
switch (action) {
case ISCHED_ACT_PARALLEL:
isched->pfunc( ctx->rank, isched->pargs );
isched->pfunc( ctx, isched->pargs );
break;
case ISCHED_ACT_FINALIZE:
return isched_thread_destroy( ctx );
......
......@@ -83,7 +83,7 @@ int isched_topo_unbind();
int isched_topo_world_size();
static inline void
isched_parallel_call( isched_t *isched, void (*func)(int, void*), void *args )
isched_parallel_call( isched_t *isched, void (*func)(isched_thread_t*, void*), void *args )
{
pthread_mutex_lock(&isched->statuslock);
isched->pfunc = func;
......@@ -93,7 +93,7 @@ isched_parallel_call( isched_t *isched, void (*func)(int, void*), void *args )
pthread_cond_broadcast(&isched->statuscond);
isched_barrier_wait( &(isched->barrier) );
isched->status = ISCHED_ACT_STAND_BY;
func( isched->master->rank, args );
func( isched->master, args );
isched_barrier_wait( &(isched->barrier) );
}
......
......@@ -658,7 +658,8 @@ void solve_ztrsmsp( int side, int uplo, int trans, int diag,
A = (pastix_complex64_t*)(fcbk->ucoeftab);
lda = (fcbk->cblktype & CBLK_SPLIT) ? tempn : fcbk->stride;
if ( ! (fcbk->cblktype & CBLK_DENSE)) {
pastix_cblk_lock( fcbk );
if ( !(fcbk->cblktype & CBLK_DENSE) ) {
lrA = blok->LRblock + 1;
switch (lrA->rk){
......@@ -699,6 +700,8 @@ void solve_ztrsmsp( int side, int uplo, int trans, int diag,
b + cblk->lcolidx + blok->frownum - cblk->fcolnum, ldb,
CBLAS_SADDR(zone), b + fcbk->lcolidx, ldb );
}
pastix_cblk_unlock( fcbk );
pastix_atomic_dec_32b( &(fcbk->ctrbcnt) );
}
}
/*
......@@ -739,12 +742,16 @@ void solve_ztrsmsp( int side, int uplo, int trans, int diag,
lda = (cblk->cblktype & CBLK_SPLIT) ? tempm : cblk->stride;
pastix_cblk_lock( fcbk );
cblas_zgemm(
CblasColMajor, CblasNoTrans, CblasNoTrans,
tempm, nrhs, tempn,
CBLAS_SADDR(mzone), A + blok->coefind, lda,
b + cblk->lcolidx, ldb,
CBLAS_SADDR(zone), b + fcbk->lcolidx + blok->frownum - fcbk->fcolnum, ldb );
pastix_cblk_unlock( fcbk );
pastix_atomic_dec_32b( &(fcbk->ctrbcnt) );
}
} else {
......@@ -767,7 +774,7 @@ void solve_ztrsmsp( int side, int uplo, int trans, int diag,
assert( blok->frownum >= fcbk->fcolnum );
assert( tempm <= (fcbk->lcolnum - fcbk->fcolnum + 1));
pastix_cblk_lock( fcbk );
switch (lrA->rk){
case 0:
break;
......@@ -799,6 +806,8 @@ void solve_ztrsmsp( int side, int uplo, int trans, int diag,
memFree_null(tmp);
}
pastix_cblk_unlock( fcbk );
pastix_atomic_dec_32b( &(fcbk->ctrbcnt) );
}
}
}
......@@ -831,6 +840,7 @@ void solve_ztrsmsp( int side, int uplo, int trans, int diag,
A = (pastix_complex64_t*)(fcbk->lcoeftab);
lda = (fcbk->cblktype & CBLK_SPLIT) ? tempn : fcbk->stride;
pastix_cblk_lock( fcbk );
if ( ! (fcbk->cblktype & CBLK_DENSE)) {
lrA = blok->LRblock;
......@@ -872,6 +882,8 @@ void solve_ztrsmsp( int side, int uplo, int trans, int diag,
b + cblk->lcolidx + blok->frownum - cblk->fcolnum, ldb,
CBLAS_SADDR(zone), b + fcbk->lcolidx, ldb );
}
pastix_cblk_unlock( fcbk );
pastix_atomic_dec_32b( &(fcbk->ctrbcnt) );
}
}
}
......
......@@ -65,7 +65,7 @@ struct coeftabinit_s {
* assigned to each thread.)
*/
void
pcoeftabInit( int rank, void *args )
pcoeftabInit( isched_thread_t *ctx, void *args )
{
struct coeftabinit_s *ciargs = (struct coeftabinit_s*)args;
const SolverMatrix *datacode = ciargs->datacode;
......@@ -74,6 +74,7 @@ pcoeftabInit( int rank, void *args )
int factoLU = ciargs->factoLU;
pastix_int_t i, itercblk;
pastix_int_t task;
int rank = ctx->rank;
void (*initfunc)(const SolverMatrix*,
const pastix_bcsc_t*,
pastix_int_t,
......
......@@ -62,7 +62,7 @@ sequential_zgetrf( pastix_data_t *pastix_data,
}
void
thread_pzgetrf( int rank, void *args )
thread_pzgetrf( isched_thread_t *ctx, void *args )
{
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
......@@ -71,6 +71,7 @@ thread_pzgetrf( int rank, void *args )
pastix_complex64_t *work;
pastix_int_t i, ii;
pastix_int_t tasknbr, *tasktab;
int rank = ctx->rank;
MALLOC_INTERN( work, datacode->gemmmax, pastix_complex64_t );
......@@ -90,11 +91,11 @@ thread_pzgetrf( int rank, void *args )
}
#if defined(PASTIX_DEBUG_FACTO) && 0
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
isched_barrier_wait( &(ctx->global_ctx->barrier) );
if (rank == 0) {
coeftab_zdump( datacode, "getrf_L.txt" );
}
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
isched_barrier_wait( &(ctx->global_ctx->barrier) );
#endif
memFree_null( work );
......
......@@ -59,7 +59,7 @@ sequential_zhetrf( pastix_data_t *pastix_data,
}
void
thread_pzhetrf( int rank, void *args )
thread_pzhetrf( isched_thread_t *ctx, void *args )
{
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
......@@ -68,6 +68,7 @@ thread_pzhetrf( int rank, void *args )
pastix_complex64_t *work1, *work2;
pastix_int_t i, ii;
pastix_int_t tasknbr, *tasktab;
int rank = ctx->rank;
MALLOC_INTERN( work1, pastix_imax(datacode->gemmmax, datacode->diagmax),
pastix_complex64_t );
......@@ -90,11 +91,11 @@ thread_pzhetrf( int rank, void *args )
}
#if defined(PASTIX_DEBUG_FACTO) && 0
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
isched_barrier_wait( &(ctx->global_ctx->barrier) );
if (rank == 0) {
coeftab_zdump( datacode, "hetrf_L.txt" );
}
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
isched_barrier_wait( &(ctx->global_ctx->barrier) );
#endif
memFree_null( work1 );
......
......@@ -59,7 +59,7 @@ sequential_zpotrf( pastix_data_t *pastix_data,
}
void
thread_pzpotrf( int rank, void *args )
thread_pzpotrf( isched_thread_t *ctx, void *args )
{
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
......@@ -68,6 +68,7 @@ thread_pzpotrf( int rank, void *args )
pastix_complex64_t *work;
pastix_int_t i, ii;
pastix_int_t tasknbr, *tasktab;
int rank = ctx->rank;
MALLOC_INTERN( work, datacode->gemmmax, pastix_complex64_t );
......@@ -87,11 +88,11 @@ thread_pzpotrf( int rank, void *args )
}
#if defined(PASTIX_DEBUG_FACTO) && 0
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
isched_barrier_wait( &(ctx->global_ctx->barrier) );
if (rank == 0) {
coeftab_zdump( datacode, "potrf_L.txt" );
}
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
isched_barrier_wait( &(ctx->global_ctx->barrier) );
#endif
memFree_null( work );
......
......@@ -59,7 +59,7 @@ sequential_zsytrf( pastix_data_t *pastix_data,
}
void
thread_pzsytrf( int rank, void *args )
thread_pzsytrf( isched_thread_t *ctx, void *args )
{
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
......@@ -68,6 +68,7 @@ thread_pzsytrf( int rank, void *args )
pastix_complex64_t *work1, *work2;
pastix_int_t i, ii;
pastix_int_t tasknbr, *tasktab;
int rank = ctx->rank;
MALLOC_INTERN( work1, pastix_imax(datacode->gemmmax, datacode->diagmax),
pastix_complex64_t );
......@@ -90,11 +91,11 @@ thread_pzsytrf( int rank, void *args )
}
#if defined(PASTIX_DEBUG_FACTO) && 0
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
isched_barrier_wait( &(ctx->global_ctx->barrier) );
if (rank == 0) {
coeftab_zdump( datacode, "sytrf_L.txt" );
}
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
isched_barrier_wait( &(ctx->global_ctx->barrier) );
#endif
memFree_null( work1 );
......
......@@ -45,47 +45,30 @@ sequential_ztrsm( pastix_data_t *pastix_data, int side, int uplo, int trans, int
pastix_int_t i;
(void)pastix_data;
if (side == PastixLeft) {
if (uplo == PastixUpper) {
/*
* Left / Upper / NoTrans
*/
if (trans == PastixNoTrans) {
cblk = datacode->cblktab + datacode->cblknbr - 1;
for (i=0; i<datacode->cblknbr; i++, cblk--){
solve_ztrsmsp( side, uplo, trans, diag,
datacode, cblk, nrhs, b, ldb );
}
}
/* We store U^t, so we swap uplo and trans */
}
else {
/*
* Left / Lower / NoTrans
*/
if (trans == PastixNoTrans) {
cblk = datacode->cblktab;
for (i=0; i<datacode->cblknbr; i++, cblk++){
solve_ztrsmsp( side, uplo, trans, diag,
datacode, cblk, nrhs, b, ldb );
}
}
/*
* Left / Lower / [Conj]Trans
*/
else {
cblk = datacode->cblktab + datacode->cblknbr - 1;
for (i=0; i<datacode->cblknbr; i++, cblk--){
solve_ztrsmsp( side, uplo, trans, diag,
datacode, cblk, nrhs, b, ldb );
}
}
if ( ( (side == PastixLeft) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) ||
( (side == PastixLeft) && (uplo == PastixLower) && (trans != PastixNoTrans) ) ||
( (side == PastixRight) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) ||
( (side == PastixRight) && (uplo == PastixLower) && (trans == PastixNoTrans) ) )
{
cblk = datacode->cblktab + datacode->cblknbr - 1;
for (i=0; i<datacode->cblknbr; i++, cblk--){
solve_ztrsmsp( side, uplo, trans, diag,
datacode, cblk, nrhs, b, ldb );
}
}
/**
* Right
*/
else {
else
/**
* ( (side == PastixRight) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) ||
* ( (side == PastixRight) && (uplo == PastixLower) && (trans != PastixNoTrans) ) ||
* ( (side == PastixLeft) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) ||
* ( (side == PastixLeft) && (uplo == PastixLower) && (trans == PastixNoTrans) )
*/
{
cblk = datacode->cblktab;
for (i=0; i<datacode->cblknbr; i++, cblk++){
solve_ztrsmsp( side, uplo, trans, diag,
datacode, cblk, nrhs, b, ldb );
}
}
}
......@@ -99,7 +82,7 @@ struct args_ztrsm_t
};
void
thread_pztrsm( int rank, void *args )
thread_pztrsm( isched_thread_t *ctx, void *args )
{
struct args_ztrsm_t *arg = (struct args_ztrsm_t*)args;
sopalin_data_t *sopalin_data = arg->sopalin_data;
......@@ -115,13 +98,68 @@ thread_pztrsm( int rank, void *args )
Task *t;
pastix_int_t i,ii;
pastix_int_t tasknbr, *tasktab;
int rank = ctx->rank;
tasknbr = datacode->ttsknbr[rank];
tasktab = datacode->ttsktab[rank];
/* try in sequential */
if (!rank)
sequential_ztrsm(NULL, side, uplo, trans, diag, sopalin_data, nrhs, b, ldb);
/* Backward like */
if ( ( (side == PastixLeft) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) ||
( (side == PastixLeft) && (uplo == PastixLower) && (trans != PastixNoTrans) ) ||
( (side == PastixRight) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) ||
( (side == PastixRight) && (uplo == PastixLower) && (trans == PastixNoTrans) ) )
{
/* Init ctrbcnt in parallel */
for (ii=0; ii<tasknbr; ii++) {
i = tasktab[ii];
t = datacode->tasktab + i;
cblk = datacode->cblktab + t->cblknum;
cblk->ctrbcnt = (cblk[1].fblokptr-cblk[0].fblokptr)-1;
}
isched_barrier_wait( &(ctx->global_ctx->barrier) );
for (ii=tasknbr-1; ii>=0; ii--) {
i = tasktab[ii];
t = datacode->tasktab + i;
cblk = datacode->cblktab + t->cblknum;
/* Wait */
do {} while( cblk->ctrbcnt );
solve_ztrsmsp( side, uplo, trans, diag,
datacode, cblk, nrhs, b, ldb );
}
}
/* Forward like */
else
/**
* ( (side == PastixRight) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) ||
* ( (side == PastixRight) && (uplo == PastixLower) && (trans != PastixNoTrans) ) ||
* ( (side == PastixLeft) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) ||
* ( (side == PastixLeft) && (uplo == PastixLower) && (trans == PastixNoTrans) )
*/
{
/* Init ctrbcnt in parallel */
for (ii=0; ii<tasknbr; ii++) {
i = tasktab[ii];
t = datacode->tasktab + i;
cblk = datacode->cblktab + t->cblknum;
cblk->ctrbcnt = cblk[1].brownum - cblk[0].brownum;
}
isched_barrier_wait( &(ctx->global_ctx->barrier) );
for (ii=0; ii<tasknbr; ii++) {
i = tasktab[ii];
t = datacode->tasktab + i;
cblk = datacode->cblktab + t->cblknum;
/* Wait */
do {} while( cblk->ctrbcnt );
solve_ztrsmsp( side, uplo, trans, diag,
datacode, cblk, nrhs, b, ldb );
}
}
}
void
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment