Mentions légales du service

Skip to content
Snippets Groups Projects
Commit cde51f77 authored by Tony Delarue's avatar Tony Delarue Committed by Mathieu Faverge
Browse files

Hotfix/spm scatter

parent 0d1c5bf8
No related branches found
No related tags found
No related merge requests found
......@@ -84,22 +84,26 @@ spm_gather_check( const spmatrix_t *spm,
* the distribution.
*
* @param[in] spm
* The gathered spm
* The original scattered spm
*
* @param[in] colptr
* The pointer to the gathered copressed array (spm->colptr, or
* spm->rowptr)
* The pointer to the gathered compressed array (spm->colptr, or
* spm->rowptr) of the new gathered spm.
*
* @param[in] recvdispls
* The array of reception displacements.
* The array of reception displacements for the n values.
*
* @param[in] recvcounts
* The array of reception count in terms of nnz.
*/
static inline void
spm_gather_csx_update( const spmatrix_t *spm,
spm_int_t *colptr,
int *recvdispls )
int *recvdispls,
int *recvcounts )
{
int c;
spm_int_t to_add = 0;
/*
* We need to update the compressed array to match the gathered array
*
......@@ -112,11 +116,12 @@ spm_gather_csx_update( const spmatrix_t *spm,
for ( c=1; c<spm->clustnbr; c++ ) {
/* Let's start shifting the value after the first array */
spm_int_t shift = recvdispls[c];
spm_int_t end = ( c == spm->clustnbr-1 ) ? spm->n+1 : recvdispls[c+1];
spm_int_t end = ( c == spm->clustnbr-1 ) ? spm->gN+1 : recvdispls[c+1];
spm_int_t i;
to_add += recvcounts[c-1];
for ( i=shift; i<end; i++ ) {
colptr[i] += shift;
colptr[i] += to_add;
}
}
}
......@@ -155,6 +160,9 @@ spm_gather_csx_continuous( const spmatrix_t *oldspm,
int recv = ( root == -1 ) || ( root == oldspm->clustnum );
int c;
assert( ((newspm != NULL) && recv) ||
((newspm == NULL) && !recv) );
/*
* First step, let's gather the compressed array
*/
......@@ -195,7 +203,12 @@ spm_gather_csx_continuous( const spmatrix_t *oldspm,
}
if ( recv ) {
spm_gather_csx_update( newspm, newcol, recvdispls );
/* recvdispls : n, recvcnt : nnz */
recvcounts[0] = allcounts[1]; /* nnz */
for( c=1; c<oldspm->clustnbr; c++ ) {
recvcounts[c] = allcounts[ 3 * c + 1 ];
}
spm_gather_csx_update( oldspm, newcol, recvdispls, recvcounts );
}
}
......@@ -320,6 +333,9 @@ spm_gather_ijv( const spmatrix_t *oldspm,
int nnzexp = oldspm->nnzexp;
int recv = ( root == -1 ) || ( root == oldspm->clustnum );
assert( ((newspm != NULL) && recv) ||
((newspm == NULL) && !recv) );
if ( recv ) {
int c;
recvcounts = malloc( oldspm->clustnbr * sizeof(int) );
......@@ -356,7 +372,7 @@ spm_gather_ijv( const spmatrix_t *oldspm,
/* Gather the values */
if ( oldspm->flttype != SpmPattern ) {
/* Update recvcounts and recvdispls arrays if needed */
/* Update recvcounts and recvdispls arrays if needed to use nnzexp */
if ( recv && (oldspm->dof != 1) ) {
int c;
......@@ -374,7 +390,7 @@ spm_gather_ijv( const spmatrix_t *oldspm,
}
else {
char *newval = recv ? newspm->values : NULL;
MPI_Gatherv( oldspm->values, nnz, valtype,
MPI_Gatherv( oldspm->values, nnzexp, valtype,
newval, recvcounts, recvdispls, valtype,
root, oldspm->comm );
}
......@@ -383,11 +399,6 @@ spm_gather_ijv( const spmatrix_t *oldspm,
if ( recv ) {
free( recvcounts );
free( recvdispls );
/* Let's sort if we can */
if ( (newspm->dof == 1) || (newspm->flttype == SpmPattern) ) {
spmSort( newspm );
}
}
}
......
......@@ -157,9 +157,12 @@ spm_scatter_csx_get_locals( const spmatrix_t *oldspm,
spm_int_t dofj;
const spm_int_t *oldcol;
const spm_int_t *oldrow;
const spm_int_t *glob2loc = spm_get_glob2loc( newspm, baseval ) - baseval;
const spm_int_t *glob2loc = spm_get_glob2loc( newspm, baseval );
const spm_int_t *dofs;
/* Shift the pointer to avoid extra baseval computations */
glob2loc -= baseval;
if ( !allcounts ) {
spm_int_t counters[3];
......@@ -273,9 +276,12 @@ spm_scatter_ijv_get_locals( const spmatrix_t *oldspm,
spm_int_t dof2, dofi, dofj;
const spm_int_t *oldcol;
const spm_int_t *oldrow;
const spm_int_t *glob2loc = spm_get_glob2loc( newspm, baseval ) - baseval;
const spm_int_t *glob2loc = spm_get_glob2loc( newspm, baseval );
const spm_int_t *dofs;
/* Shift the pointer to avoid extra baseval computations */
glob2loc -= baseval;
if ( !allcounts ) {
spm_int_t counters[3];
......@@ -301,13 +307,13 @@ spm_scatter_ijv_get_locals( const spmatrix_t *oldspm,
{
ig = *oldrow;
jg = *oldcol;
c = - glob2loc[jg];
c = glob2loc[jg];
if ( c <= 0 ) {
if ( c >= 0 ) {
c = newspm->clustnum;
}
else {
c--;
c = (-c - 1);
}
if ( newspm->dof > 0 ) {
......@@ -1026,13 +1032,18 @@ spm_scatter_ijv_local( const spmatrix_t *oldspm,
spm_int_t *newrow = distByColumn ? newspm->rowptr : newspm->colptr;
char *newval = newspm->values;
size_t typesize = (newspm->flttype != SpmPattern) ? spm_size_of(newspm->flttype) : 1;
const spm_int_t *dofs = newspm->dofs - baseval;
const spm_int_t *glob2loc = newspm->glob2loc - baseval; /* It has normally already been initialized */
const spm_int_t *dofs = newspm->dofs;
const spm_int_t *glob2loc = newspm->glob2loc; /* It has normally already been initialized */
spm_int_t kl, kg, ig, jg, nnz;
spm_int_t vl, dof2, dofi, dofj;
assert( newspm->glob2loc );
/* Shift the pointers to avoid extra baseval computations */
glob2loc -= baseval;
dofs -= baseval;
dof2 = newspm->dof * newspm->dof;
vl = 0;
kl = 0;
......@@ -1067,6 +1078,7 @@ spm_scatter_ijv_local( const spmatrix_t *oldspm,
if ( newspm->flttype != SpmPattern ) {
memcpy( newval, oldval, nnz * typesize );
newval += nnz * typesize;
oldval += nnz * typesize;
}
vl += nnz;
}
......@@ -1106,13 +1118,17 @@ spm_scatter_ijv_remote( const spmatrix_t *oldspm,
spm_int_t *newrow = distByColumn ? newspm->rowptr : newspm->colptr;
char *newval = newspm->values;
size_t typesize = (newspm->flttype != SpmPattern) ? spm_size_of(newspm->flttype) : 1;
const spm_int_t *dofs = newspm->dofs - baseval;
const spm_int_t *glob2loc = newspm->glob2loc - baseval; /* Must be already initialized */
const spm_int_t *dofs = newspm->dofs;
const spm_int_t *glob2loc = newspm->glob2loc; /* Must be already initialized */
spm_int_t kl, kg, ig, jg, nnz;
spm_int_t vl, dof2, dofi, dofj;
assert( newspm->glob2loc );
/* Shift the pointers to avoid extra baseval computations */
glob2loc -= baseval;
dofs -= baseval;
dof2 = newspm->dof * newspm->dof;
vl = 0;
kl = 0;
......@@ -1147,6 +1163,7 @@ spm_scatter_ijv_remote( const spmatrix_t *oldspm,
if ( newspm->flttype != SpmPattern ) {
memcpy( newval, oldval, nnz * typesize );
newval += nnz * typesize;
oldval += nnz * typesize;
}
vl += nnz;
}
......@@ -1410,26 +1427,32 @@ spmScatter( const spmatrix_t *oldspm,
/* Check the initial conditions */
if ( local ) {
if ( loc2glob ) {
assert( n >= 0 );
MPI_Allreduce( &n, &gN, 1, SPM_MPI_INT,
MPI_SUM, comm );
}
if ( oldspm == NULL ) {
spm_print_warning( "[%02d] spmScatter: Missing input matrix\n", clustnum );
rc = 1;
goto reduce;
}
if ( loc2glob ) {
MPI_Allreduce( &n, &gN, 1, SPM_MPI_INT,
MPI_SUM, comm );
if ( gN != oldspm->gN )
{
spm_print_warning( "[%02d] spmScatter: Incorrect n sum (%ld != %ld)\n",
clustnum, (long)(oldspm->gN), (long)gN );
rc = 1;
}
if ( loc2glob && (gN != oldspm->gN) ) {
spm_print_warning( "[%02d] spmScatter: Incorrect n sum (%ld != %ld)\n",
clustnum, (long)(oldspm->gN), (long)gN );
rc = 1;
goto reduce;
}
if ( ( distByColumn && (oldspm->fmttype == SpmCSR) ) ||
((!distByColumn) && (oldspm->fmttype == SpmCSC) ) )
{
spm_print_warning( "[%02d] spmScatter: Does not support to scatter along the non compressed array in CSC/CSR formats\n",
clustnum );
rc = 1;
goto reduce;
}
if ( (oldspm->fmttype != SpmIJV) &&
......@@ -1440,6 +1463,7 @@ spmScatter( const spmatrix_t *oldspm,
clustnum );
rc = 1;
}
reduce:
MPI_Allreduce( MPI_IN_PLACE, &rc, 1, MPI_INT,
MPI_SUM, comm );
if ( rc != 0 ) {
......
......@@ -216,38 +216,58 @@ spmCompare( spmatrix_t *spm1,
if ( (spm1->loc2glob == NULL) &&
(spm2->loc2glob == NULL) )
{
spmSort( spm1 );
spmSort( spm2 );
spmatrix_t tmpspm1, tmpspm2;
spmatrix_t *spm1ptr, *spm2ptr;
switch( spm2->fmttype ) {
if ( spm1->fmttype == SpmIJV )
{
spmExpand( spm1, &tmpspm1 );
spmExpand( spm2, &tmpspm2 );
spmSort( &tmpspm1 );
spmSort( &tmpspm2 );
spm1ptr = &tmpspm1;
spm2ptr = &tmpspm2;
}
else {
spm1ptr = spm1;
spm2ptr = spm2;
}
switch( spm2ptr->fmttype ) {
case SpmCSC:
rc = spmCompareCSC( spm1, spm2 );
rc = spmCompareCSC( spm1ptr, spm2ptr );
break;
case SpmCSR:
rc = spmCompareCSR( spm1, spm2 );
rc = spmCompareCSR( spm1ptr, spm2ptr );
break;
case SpmIJV:
rc = spmCompareIJV( spm1, spm2 );
rc = spmCompareIJV( spm1ptr, spm2ptr );
}
if ( rc != 0 ) {
fprintf( stderr, "[%2d] Incorrect colptr/rowptr arrays\n",
spm2->clustnum );
spm2ptr->clustnum );
if ( spm1->fmttype == SpmIJV ) {
spmExit( &tmpspm1 );
spmExit( &tmpspm2 );
}
goto end;
}
switch( spm2->flttype ) {
switch( spm2ptr->flttype ) {
case SpmFloat:
rc = spmCompareFloatArray( spm1->nnzexp, spm1->values, spm2->values );
rc = spmCompareFloatArray( spm1ptr->nnzexp, spm1ptr->values, spm2ptr->values );
break;
case SpmDouble:
rc = spmCompareDoubleArray( spm1->nnzexp, spm1->values, spm2->values );
rc = spmCompareDoubleArray( spm1ptr->nnzexp, spm1ptr->values, spm2ptr->values );
break;
case SpmComplex32:
rc = spmCompareFloatArray( 2 * spm1->nnzexp, spm1->values, spm2->values );
rc = spmCompareFloatArray( 2 * spm1ptr->nnzexp, spm1ptr->values, spm2ptr->values );
break;
case SpmComplex64:
rc = spmCompareDoubleArray( 2 * spm1->nnzexp, spm1->values, spm2->values );
rc = spmCompareDoubleArray( 2 * spm1ptr->nnzexp, spm1ptr->values, spm2ptr->values );
break;
case SpmPattern:
default:
......@@ -256,10 +276,13 @@ spmCompare( spmatrix_t *spm1,
if ( rc != 0 ) {
fprintf( stderr, "[%2d] Incorrect values arrays\n",
spm2->clustnum );
spm2ptr->clustnum );
rc = 6;
assert(0);
goto end;
}
if ( spm1->fmttype == SpmIJV ) {
spmExit( &tmpspm1 );
spmExit( &tmpspm2 );
}
}
......@@ -267,7 +290,7 @@ spmCompare( spmatrix_t *spm1,
#if defined(SPM_WITH_MPI)
MPI_Allreduce( MPI_IN_PLACE, &rc, 1, MPI_INT,
MPI_MAX, spm2->comm );
MPI_MAX, MPI_COMM_WORLD );
#endif
return rc;
......
......@@ -75,16 +75,17 @@ spmdist_check_scatter_gather( spmatrix_t *original,
int root,
int clustnum )
{
const char *dofname[] = { "None", "Constant", "Variadic" };
const char *dofname[] = { "None", "Constant", "Variadic" };
const char *distname[] = { "Round-Robin", "Continuous "};
spmatrix_t *spms = NULL;
spmatrix_t *spmg = NULL;
int rc = 0;
int local = (root == -1) || (root == clustnum);
if ( clustnum == 0 ) {
fprintf( stdout, "type(%s) - dof(%s) - base(%d) - distByColumn(%d) - root(%d): ",
fprintf( stdout, "type(%s) - dof(%s) - base(%d) - distByColumn(%d) - root(%d) - loc2glob(%s): ",
fmtnames[fmttype], dofname[dof+1],
(int)baseval, distByColumn, root );
(int)baseval, distByColumn, root, distname[loc2glob == NULL] );
}
if ( local ) {
......@@ -136,6 +137,8 @@ spmdist_check_scatter_gather( spmatrix_t *original,
if ( spmdist_check( clustnum, spms == NULL,
"Failed to generate an spm on each node" ) )
{
spmExit( spms );
free( spms );
return 1;
}
......@@ -153,10 +156,17 @@ spmdist_check_scatter_gather( spmatrix_t *original,
* Check spmGather
*/
spmg = spmGather( spms, root );
if ( spms ) {
spmExit( spms );
free( spms );
}
/* Check non supported cases by Gather */
{
if ( (loc2glob != NULL) && (spms->fmttype != SpmIJV) )
if ( ( original != NULL ) &&
( original->clustnbr > 1 ) &&
( loc2glob != NULL ) &&
( original->fmttype != SpmIJV ) )
{
if ( spmg != NULL ) {
rc = 2; /* Error */
......@@ -168,10 +178,6 @@ spmdist_check_scatter_gather( spmatrix_t *original,
MPI_Allreduce( MPI_IN_PLACE, &rc, 1, MPI_INT,
MPI_MAX, MPI_COMM_WORLD );
if ( rc != 0 ) {
if ( spms ) {
spmExit( spms );
free( spms );
}
if ( spmg ) {
spmExit( spmg );
free( spmg );
......@@ -190,30 +196,35 @@ spmdist_check_scatter_gather( spmatrix_t *original,
}
}
}
rc = 0;
rc = ( ( local && (spmg == NULL)) ||
(!local && (spmg != NULL)) );
MPI_Allreduce( MPI_IN_PLACE, &rc, 1, MPI_INT,
MPI_MAX, MPI_COMM_WORLD );
/* Check the correct case */
if ( spmdist_check( clustnum,
( ( local && (spmg == NULL)) ||
(!local && (spmg != NULL)) ),
if ( spmdist_check( clustnum, rc,
"Failed to gather the spm correctly" ) )
{
if ( spmg ) {
spmExit( spmg );
free( spmg );
}
return 1;
}
/* Compare the matrices */
rc = spmCompare( spmg, spms );
rc = spmCompare( original, spmg );
if ( spmdist_check( clustnum, rc,
"The gathered spm does not match the original spm" ) )
{
if ( spmg ) {
spmExit( spmg );
free( spmg );
}
return 1;
}
/* Cleanup */
if ( spms ) {
spmExit( spms );
free( spms );
}
if ( spmg ) {
spmExit( spmg );
free( spmg );
......@@ -272,12 +283,11 @@ int main( int argc, char **argv )
* - The scattered matrix is gathered on all nodes and compared against the
* original one
*/
spm = &original;
for( fmttype=SpmCSC; fmttype<=SpmIJV; fmttype++ )
{
if ( spmConvert( fmttype, &original ) != SPM_SUCCESS ) {
fprintf( stderr, "Issue to convert to %d format\n", fmttype );
return EXIT_FAILURE;
fprintf( stderr, "Issue to convert to %s format\n", fmtnames[fmttype] );
continue;
}
for( dof=-1; dof<2; dof++ )
......@@ -289,6 +299,10 @@ int main( int argc, char **argv )
spm = &original;
}
if ( spm == NULL ) {
continue;
}
for( root=-1; root<clustnbr; root++ )
{
/* Make sure we don't give an input spm */
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment