Commit 39f71e00 authored by Mathieu Faverge's avatar Mathieu Faverge
Browse files

Move the workspace allocation to the higher level to avoid repeated allocation/free

parent 7281362c
......@@ -240,7 +240,7 @@ int core_zgemdm(int transA, int transB,
if ( transA == CblasNoTrans )
{
/* WORK = A * D */
for (j=0; j<K; j++, wD++) {
for (j=0; j<K; j++, wD++) {
delta = *wD;
cblas_zcopy(M, &A[LDA*j], 1, &w[M*j], 1);
cblas_zscal(M, CBLAS_SADDR(delta), &w[M*j], 1);
......
......@@ -592,13 +592,13 @@ void core_zgetrfsp1d_gemm( SolverCblk *cblk,
*
*******************************************************************************/
int
core_zgetrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria )
core_zgetrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria,
pastix_complex64_t *work)
{
pastix_complex64_t *L = cblk->lcoeftab;
pastix_complex64_t *U = cblk->ucoeftab;
pastix_complex64_t *work = NULL;
SolverCblk *fcblk;
SolverBlok *blok, *lblk;
pastix_int_t nbpivot;
......@@ -608,15 +608,6 @@ core_zgetrfsp1d( SolverMatrix *solvmtx,
blok = cblk->fblokptr + 1; /* this diagonal block */
lblk = cblk[1].fblokptr; /* the next diagonal block */
if ( blok < lblk ) {
pastix_int_t maxarea = 0;
for( ; blok < lblk; blok++ )
{
maxarea = pastix_imax( maxarea, blok_rownbr( blok ) * cblk->stride );
}
MALLOC_INTERN( work, maxarea, pastix_complex64_t );
}
/* if there are off-diagonal supernodes in the column */
blok = cblk->fblokptr+1;
for( ; blok < lblk; blok++ )
......@@ -627,7 +618,6 @@ core_zgetrfsp1d( SolverMatrix *solvmtx,
L, U, fcblk->lcoeftab, fcblk->ucoeftab, work );
}
memFree_null( work );
return nbpivot;
}
......@@ -508,37 +508,23 @@ int core_zhetrfsp1d_panel( SolverCblk *cblk,
*
*******************************************************************************/
int
core_zhetrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria )
core_zhetrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria,
pastix_complex64_t *work1,
pastix_complex64_t *work2)
{
pastix_complex64_t *L = cblk->lcoeftab;
pastix_complex64_t *work1 = NULL;
pastix_complex64_t *work2 = NULL;
SolverCblk *fcblk;
SolverBlok *blok, *lblk;
pastix_int_t nbpivot;
pastix_int_t maxarea;
blok = cblk->fblokptr; /* this diagonal block */
lblk = cblk[1].fblokptr; /* the next diagonal block */
maxarea = blok_rownbr( blok ) * blok_rownbr( blok );
blok++;
if ( blok < lblk ) {
for( ; blok < lblk; blok++ )
{
maxarea = pastix_imax( maxarea, (blok_rownbr( blok )+1) * cblk->stride );
}
}
MALLOC_INTERN( work1, maxarea, pastix_complex64_t );
MALLOC_INTERN( work2, cblk->stride * cblk_colnbr(cblk), pastix_complex64_t );
/* if there are off-diagonal supernodes in the column */
nbpivot = core_zhetrfsp1d_hetrf(cblk, L, criteria, work1);
core_zhetrfsp1d_trsm(cblk, L);
blok = cblk->fblokptr+1;
blok = cblk->fblokptr+1; /* this diagonal block */
lblk = cblk[1].fblokptr; /* the next diagonal block */
for( ; blok < lblk; blok++ )
{
fcblk = (solvmtx->cblktab + blok->fcblknm);
......@@ -548,8 +534,6 @@ core_zhetrfsp1d( SolverMatrix *solvmtx,
work1, work2 );
}
memFree_null( work1 );
memFree_null( work2 );
return nbpivot;
}
......@@ -475,12 +475,12 @@ int core_zpotrfsp1d_panel( SolverCblk *cblk,
*
*******************************************************************************/
int
core_zpotrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria )
core_zpotrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria,
pastix_complex64_t *work)
{
pastix_complex64_t *L = cblk->lcoeftab;
pastix_complex64_t *work = NULL;
SolverCblk *fcblk;
SolverBlok *blok, *lblk;
pastix_int_t nbpivot;
......@@ -491,17 +491,7 @@ core_zpotrfsp1d( SolverMatrix *solvmtx,
blok = cblk->fblokptr + 1; /* this diagonal block */
lblk = cblk[1].fblokptr; /* the next diagonal block */
if ( blok < lblk ) {
pastix_int_t maxarea = 0;
for( ; blok < lblk; blok++ )
{
maxarea = pastix_imax( maxarea, blok_rownbr( blok ) * cblk->stride );
}
MALLOC_INTERN( work, maxarea, pastix_complex64_t );
}
/* if there are off-diagonal supernodes in the column */
blok = cblk->fblokptr+1;
for( ; blok < lblk; blok++ )
{
fcblk = (solvmtx->cblktab + blok->fcblknm);
......@@ -510,7 +500,6 @@ core_zpotrfsp1d( SolverMatrix *solvmtx,
L, fcblk->lcoeftab, work );
}
memFree_null( work );
return nbpivot;
}
......@@ -514,37 +514,23 @@ int core_zsytrfsp1d_panel( SolverCblk *cblk,
*
*******************************************************************************/
int
core_zsytrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria )
core_zsytrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria,
pastix_complex64_t *work1,
pastix_complex64_t *work2)
{
pastix_complex64_t *L = cblk->lcoeftab;
pastix_complex64_t *work1 = NULL;
pastix_complex64_t *work2 = NULL;
SolverCblk *fcblk;
SolverBlok *blok, *lblk;
pastix_int_t nbpivot;
pastix_int_t maxarea;
blok = cblk->fblokptr; /* this diagonal block */
lblk = cblk[1].fblokptr; /* the next diagonal block */
maxarea = blok_rownbr( blok ) * blok_rownbr( blok );
blok++;
if ( blok < lblk ) {
for( ; blok < lblk; blok++ )
{
maxarea = pastix_imax( maxarea, (blok_rownbr( blok )+1) * cblk->stride );
}
}
MALLOC_INTERN( work1, maxarea, pastix_complex64_t );
MALLOC_INTERN( work2, cblk->stride * cblk_colnbr(cblk), pastix_complex64_t );
/* if there are off-diagonal supernodes in the column */
nbpivot = core_zsytrfsp1d_sytrf(cblk, L, criteria, work1);
core_zsytrfsp1d_trsm(cblk, L);
blok = cblk->fblokptr+1;
blok = cblk->fblokptr+1; /* this diagonal block */
lblk = cblk[1].fblokptr; /* the next diagonal block */
for( ; blok < lblk; blok++ )
{
fcblk = (solvmtx->cblktab + blok->fcblknm);
......@@ -554,8 +540,6 @@ core_zsytrfsp1d( SolverMatrix *solvmtx,
work1, work2 );
}
memFree_null( work1 );
memFree_null( work2 );
return nbpivot;
}
......@@ -59,9 +59,10 @@ void core_zgetrfsp1d_gemm( SolverCblk *cblk,
pastix_complex64_t *Cu,
pastix_complex64_t *work );
int core_zgetrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria );
int core_zgetrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria,
pastix_complex64_t *work );
#if defined(PRECISION_z) || defined(PRECISION_c)
int core_zhetrfsp1d_hetrf( SolverCblk *cblk,
......@@ -85,9 +86,11 @@ void core_zhetrfsp1d_gemm( SolverCblk *cblk,
pastix_complex64_t *work1,
pastix_complex64_t *work2 );
int core_zhetrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria );
int core_zhetrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria,
pastix_complex64_t *work1,
pastix_complex64_t *work2 );
#endif /* defined(PRECISION_z) || defined(PRECISION_c) */
......@@ -109,9 +112,10 @@ void core_zpotrfsp1d_gemm(SolverCblk *cblk,
pastix_complex64_t *C,
pastix_complex64_t *work);
int core_zpotrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria );
int core_zpotrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria,
pastix_complex64_t *work );
int core_zsytrfsp1d_sytrf( SolverCblk *cblk,
pastix_complex64_t *L,
......@@ -134,8 +138,10 @@ void core_zsytrfsp1d_gemm( SolverCblk *cblk,
pastix_complex64_t *work1,
pastix_complex64_t *work2);
int core_zsytrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria );
int core_zsytrfsp1d( SolverMatrix *solvmtx,
SolverCblk *cblk,
double criteria,
pastix_complex64_t *work1,
pastix_complex64_t *work2 );
#endif /* _CORE_Z_H_ */
......@@ -32,33 +32,41 @@ void
sequential_zgetrf( pastix_data_t *pastix_data,
sopalin_data_t *sopalin_data )
{
SolverMatrix *datacode = pastix_data->solvmatr;
SolverCblk *cblk;
double threshold = sopalin_data->diagthreshold;
SolverMatrix *datacode = pastix_data->solvmatr;
SolverCblk *cblk;
double threshold = sopalin_data->diagthreshold;
pastix_complex64_t *work;
pastix_int_t i;
(void)pastix_data;
MALLOC_INTERN( work, datacode->gemmmax, pastix_complex64_t );
cblk = datacode->cblktab;
for (i=0; i<datacode->cblknbr; i++, cblk++){
/* Compute */
core_zgetrfsp1d( datacode, cblk, threshold );
core_zgetrfsp1d( datacode, cblk, threshold, work );
}
#if defined(PASTIX_DEBUG_FACTO)
coeftab_zdump( datacode, "getrf_L.txt" );
#endif
memFree_null( work );
}
void
thread_pzgetrf( int rank, void *args )
{
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
SolverCblk *cblk;
Task *t;
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
SolverCblk *cblk;
Task *t;
pastix_complex64_t *work;
pastix_int_t i, ii;
pastix_int_t tasknbr, *tasktab;
MALLOC_INTERN( work, datacode->gemmmax, pastix_complex64_t );
tasknbr = datacode->ttsknbr[rank];
tasktab = datacode->ttsktab[rank];
......@@ -68,7 +76,7 @@ thread_pzgetrf( int rank, void *args )
cblk = datacode->cblktab + t->cblknum;
/* Compute */
core_zgetrfsp1d( datacode, cblk, sopalin_data->diagthreshold );
core_zgetrfsp1d( datacode, cblk, sopalin_data->diagthreshold, work );
}
#if defined(PASTIX_DEBUG_FACTO)
......@@ -78,6 +86,8 @@ thread_pzgetrf( int rank, void *args )
}
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
#endif
memFree_null( work );
}
void
......
......@@ -32,33 +32,47 @@ void
sequential_zhetrf( pastix_data_t *pastix_data,
sopalin_data_t *sopalin_data )
{
SolverMatrix *datacode = pastix_data->solvmatr;
SolverCblk *cblk;
double threshold = sopalin_data->diagthreshold;
SolverMatrix *datacode = pastix_data->solvmatr;
SolverCblk *cblk;
double threshold = sopalin_data->diagthreshold;
pastix_complex64_t *work1, *work2;
pastix_int_t i;
(void)pastix_data;
MALLOC_INTERN( work1, pastix_imax(datacode->gemmmax, datacode->diagmax),
pastix_complex64_t );
MALLOC_INTERN( work2, datacode->gemmmax, pastix_complex64_t );
cblk = datacode->cblktab;
for (i=0; i<datacode->cblknbr; i++, cblk++){
/* Compute */
core_zhetrfsp1d( datacode, cblk, threshold );
core_zhetrfsp1d( datacode, cblk, threshold,
work1, work2 );
}
#if defined(PASTIX_DEBUG_FACTO)
coeftab_zdump( datacode, "hetrf_L.txt" );
#endif
memFree_null( work1 );
memFree_null( work2 );
}
void
thread_pzhetrf( int rank, void *args )
{
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
SolverCblk *cblk;
Task *t;
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
SolverCblk *cblk;
Task *t;
pastix_complex64_t *work1, *work2;
pastix_int_t i, ii;
pastix_int_t tasknbr, *tasktab;
MALLOC_INTERN( work1, pastix_imax(datacode->gemmmax, datacode->diagmax),
pastix_complex64_t );
MALLOC_INTERN( work2, datacode->gemmmax, pastix_complex64_t );
tasknbr = datacode->ttsknbr[rank];
tasktab = datacode->ttsktab[rank];
......@@ -68,7 +82,8 @@ thread_pzhetrf( int rank, void *args )
cblk = datacode->cblktab + t->cblknum;
/* Compute */
core_zhetrfsp1d( datacode, cblk, sopalin_data->diagthreshold );
core_zhetrfsp1d( datacode, cblk, sopalin_data->diagthreshold,
work1, work2 );
}
#if defined(PASTIX_DEBUG_FACTO)
......@@ -78,6 +93,9 @@ thread_pzhetrf( int rank, void *args )
}
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
#endif
memFree_null( work1 );
memFree_null( work2 );
}
......
......@@ -32,33 +32,41 @@ void
sequential_zpotrf( pastix_data_t *pastix_data,
sopalin_data_t *sopalin_data )
{
SolverMatrix *datacode = pastix_data->solvmatr;
SolverCblk *cblk;
double threshold = sopalin_data->diagthreshold;
SolverMatrix *datacode = pastix_data->solvmatr;
SolverCblk *cblk;
double threshold = sopalin_data->diagthreshold;
pastix_complex64_t *work;
pastix_int_t i;
(void)pastix_data;
MALLOC_INTERN( work, datacode->gemmmax, pastix_complex64_t );
cblk = datacode->cblktab;
for (i=0; i<datacode->cblknbr; i++, cblk++){
/* Compute */
core_zpotrfsp1d( datacode, cblk, threshold );
core_zpotrfsp1d( datacode, cblk, threshold, work );
}
#if defined(PASTIX_DEBUG_FACTO)
coeftab_zdump( datacode, "potrf_L.txt" );
#endif
memFree_null( work );
}
void
thread_pzpotrf( int rank, void *args )
{
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
SolverCblk *cblk;
Task *t;
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
SolverCblk *cblk;
Task *t;
pastix_complex64_t *work;
pastix_int_t i, ii;
pastix_int_t tasknbr, *tasktab;
MALLOC_INTERN( work, datacode->gemmmax, pastix_complex64_t );
tasknbr = datacode->ttsknbr[rank];
tasktab = datacode->ttsktab[rank];
......@@ -68,7 +76,7 @@ thread_pzpotrf( int rank, void *args )
cblk = datacode->cblktab + t->cblknum;
/* Compute */
core_zpotrfsp1d( datacode, cblk, sopalin_data->diagthreshold );
core_zpotrfsp1d( datacode, cblk, sopalin_data->diagthreshold, work );
}
#if defined(PASTIX_DEBUG_FACTO)
......@@ -78,6 +86,8 @@ thread_pzpotrf( int rank, void *args )
}
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
#endif
memFree_null( work );
}
void
......
......@@ -32,33 +32,47 @@ void
sequential_zsytrf( pastix_data_t *pastix_data,
sopalin_data_t *sopalin_data )
{
SolverMatrix *datacode = pastix_data->solvmatr;
SolverCblk *cblk;
double threshold = sopalin_data->diagthreshold;
SolverMatrix *datacode = pastix_data->solvmatr;
SolverCblk *cblk;
double threshold = sopalin_data->diagthreshold;
pastix_complex64_t *work1, *work2;
pastix_int_t i;
(void)pastix_data;
MALLOC_INTERN( work1, pastix_imax(datacode->gemmmax, datacode->diagmax),
pastix_complex64_t );
MALLOC_INTERN( work2, datacode->gemmmax, pastix_complex64_t );
cblk = datacode->cblktab;
for (i=0; i<datacode->cblknbr; i++, cblk++){
/* Compute */
core_zsytrfsp1d( datacode, cblk, threshold );
core_zsytrfsp1d( datacode, cblk, threshold,
work1, work2 );
}
#if defined(PASTIX_DEBUG_FACTO)
coeftab_zdump( datacode, "sytrf_L.txt" );
#endif
memFree_null( work1 );
memFree_null( work2 );
}
void
thread_pzsytrf( int rank, void *args )
{
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
SolverCblk *cblk;
Task *t;
sopalin_data_t *sopalin_data = (sopalin_data_t*)args;
SolverMatrix *datacode = sopalin_data->solvmtx;
SolverCblk *cblk;
Task *t;
pastix_complex64_t *work1, *work2;
pastix_int_t i, ii;
pastix_int_t tasknbr, *tasktab;
MALLOC_INTERN( work1, pastix_imax(datacode->gemmmax, datacode->diagmax),
pastix_complex64_t );
MALLOC_INTERN( work2, datacode->gemmmax, pastix_complex64_t );
tasknbr = datacode->ttsknbr[rank];
tasktab = datacode->ttsktab[rank];
......@@ -68,7 +82,8 @@ thread_pzsytrf( int rank, void *args )
cblk = datacode->cblktab + t->cblknum;
/* Compute */
core_zsytrfsp1d( datacode, cblk, sopalin_data->diagthreshold );
core_zsytrfsp1d( datacode, cblk, sopalin_data->diagthreshold,
work1, work2 );
}
#if defined(PASTIX_DEBUG_FACTO)
......@@ -78,6 +93,9 @@ thread_pzsytrf( int rank, void *args )
}
isched_barrier_wait( &(((isched_t*)(sopalin_data->sched))->barrier) );
#endif
memFree_null( work1 );
memFree_null( work2 );
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment