diff --git a/src/spm_gather.c b/src/spm_gather.c index e158384df3c7dca0e102db9268345ce06023cdc5..91f231842c5d930029d0fb454272ed57a7a9115d 100644 --- a/src/spm_gather.c +++ b/src/spm_gather.c @@ -84,22 +84,26 @@ spm_gather_check( const spmatrix_t *spm, * the distribution. * * @param[in] spm - * The gathered spm + * The original scattered spm * * @param[in] colptr - * The pointer to the gathered copressed array (spm->colptr, or - * spm->rowptr) + * The pointer to the gathered compressed array (spm->colptr, or + * spm->rowptr) of the new gathered spm. * * @param[in] recvdispls - * The array of reception displacements. + * The array of reception displacements for the n values. + * + * @param[in] recvcounts + * The array of reception count in terms of nnz. */ static inline void spm_gather_csx_update( const spmatrix_t *spm, spm_int_t *colptr, - int *recvdispls ) + int *recvdispls, + int *recvcounts ) { int c; - + spm_int_t to_add = 0; /* * We need to update the compressed array to match the gathered array * @@ -112,11 +116,12 @@ spm_gather_csx_update( const spmatrix_t *spm, for ( c=1; c<spm->clustnbr; c++ ) { /* Let's start shifting the value after the first array */ spm_int_t shift = recvdispls[c]; - spm_int_t end = ( c == spm->clustnbr-1 ) ? spm->n+1 : recvdispls[c+1]; + spm_int_t end = ( c == spm->clustnbr-1 ) ? spm->gN+1 : recvdispls[c+1]; spm_int_t i; + to_add += recvcounts[c-1]; for ( i=shift; i<end; i++ ) { - colptr[i] += shift; + colptr[i] += to_add; } } } @@ -155,6 +160,9 @@ spm_gather_csx_continuous( const spmatrix_t *oldspm, int recv = ( root == -1 ) || ( root == oldspm->clustnum ); int c; + assert( ((newspm != NULL) && recv) || + ((newspm == NULL) && !recv) ); + /* * First step, let's gather the compressed array */ @@ -195,7 +203,12 @@ spm_gather_csx_continuous( const spmatrix_t *oldspm, } if ( recv ) { - spm_gather_csx_update( newspm, newcol, recvdispls ); + /* recvdispls : n, recvcnt : nnz */ + recvcounts[0] = allcounts[1]; /* nnz */ + for( c=1; c<oldspm->clustnbr; c++ ) { + recvcounts[c] = allcounts[ 3 * c + 1 ]; + } + spm_gather_csx_update( oldspm, newcol, recvdispls, recvcounts ); } } @@ -320,6 +333,9 @@ spm_gather_ijv( const spmatrix_t *oldspm, int nnzexp = oldspm->nnzexp; int recv = ( root == -1 ) || ( root == oldspm->clustnum ); + assert( ((newspm != NULL) && recv) || + ((newspm == NULL) && !recv) ); + if ( recv ) { int c; recvcounts = malloc( oldspm->clustnbr * sizeof(int) ); @@ -356,7 +372,7 @@ spm_gather_ijv( const spmatrix_t *oldspm, /* Gather the values */ if ( oldspm->flttype != SpmPattern ) { - /* Update recvcounts and recvdispls arrays if needed */ + /* Update recvcounts and recvdispls arrays if needed to use nnzexp */ if ( recv && (oldspm->dof != 1) ) { int c; @@ -374,7 +390,7 @@ spm_gather_ijv( const spmatrix_t *oldspm, } else { char *newval = recv ? newspm->values : NULL; - MPI_Gatherv( oldspm->values, nnz, valtype, + MPI_Gatherv( oldspm->values, nnzexp, valtype, newval, recvcounts, recvdispls, valtype, root, oldspm->comm ); } @@ -383,11 +399,6 @@ spm_gather_ijv( const spmatrix_t *oldspm, if ( recv ) { free( recvcounts ); free( recvdispls ); - - /* Let's sort if we can */ - if ( (newspm->dof == 1) || (newspm->flttype == SpmPattern) ) { - spmSort( newspm ); - } } } diff --git a/src/spm_scatter.c b/src/spm_scatter.c index 83e94314168db6ca4532596aa78d1fb54b84a222..e1a479fd116285cf755135b909b4078912fdd23c 100644 --- a/src/spm_scatter.c +++ b/src/spm_scatter.c @@ -157,9 +157,12 @@ spm_scatter_csx_get_locals( const spmatrix_t *oldspm, spm_int_t dofj; const spm_int_t *oldcol; const spm_int_t *oldrow; - const spm_int_t *glob2loc = spm_get_glob2loc( newspm, baseval ) - baseval; + const spm_int_t *glob2loc = spm_get_glob2loc( newspm, baseval ); const spm_int_t *dofs; + /* Shift the pointer to avoid extra baseval computations */ + glob2loc -= baseval; + if ( !allcounts ) { spm_int_t counters[3]; @@ -273,9 +276,12 @@ spm_scatter_ijv_get_locals( const spmatrix_t *oldspm, spm_int_t dof2, dofi, dofj; const spm_int_t *oldcol; const spm_int_t *oldrow; - const spm_int_t *glob2loc = spm_get_glob2loc( newspm, baseval ) - baseval; + const spm_int_t *glob2loc = spm_get_glob2loc( newspm, baseval ); const spm_int_t *dofs; + /* Shift the pointer to avoid extra baseval computations */ + glob2loc -= baseval; + if ( !allcounts ) { spm_int_t counters[3]; @@ -301,13 +307,13 @@ spm_scatter_ijv_get_locals( const spmatrix_t *oldspm, { ig = *oldrow; jg = *oldcol; - c = - glob2loc[jg]; + c = glob2loc[jg]; - if ( c <= 0 ) { + if ( c >= 0 ) { c = newspm->clustnum; } else { - c--; + c = (-c - 1); } if ( newspm->dof > 0 ) { @@ -1026,13 +1032,18 @@ spm_scatter_ijv_local( const spmatrix_t *oldspm, spm_int_t *newrow = distByColumn ? newspm->rowptr : newspm->colptr; char *newval = newspm->values; size_t typesize = (newspm->flttype != SpmPattern) ? spm_size_of(newspm->flttype) : 1; - const spm_int_t *dofs = newspm->dofs - baseval; - const spm_int_t *glob2loc = newspm->glob2loc - baseval; /* It has normally already been initialized */ + const spm_int_t *dofs = newspm->dofs; + const spm_int_t *glob2loc = newspm->glob2loc; /* It has normally already been initialized */ spm_int_t kl, kg, ig, jg, nnz; spm_int_t vl, dof2, dofi, dofj; assert( newspm->glob2loc ); + + /* Shift the pointers to avoid extra baseval computations */ + glob2loc -= baseval; + dofs -= baseval; + dof2 = newspm->dof * newspm->dof; vl = 0; kl = 0; @@ -1067,6 +1078,7 @@ spm_scatter_ijv_local( const spmatrix_t *oldspm, if ( newspm->flttype != SpmPattern ) { memcpy( newval, oldval, nnz * typesize ); newval += nnz * typesize; + oldval += nnz * typesize; } vl += nnz; } @@ -1106,13 +1118,17 @@ spm_scatter_ijv_remote( const spmatrix_t *oldspm, spm_int_t *newrow = distByColumn ? newspm->rowptr : newspm->colptr; char *newval = newspm->values; size_t typesize = (newspm->flttype != SpmPattern) ? spm_size_of(newspm->flttype) : 1; - const spm_int_t *dofs = newspm->dofs - baseval; - const spm_int_t *glob2loc = newspm->glob2loc - baseval; /* Must be already initialized */ + const spm_int_t *dofs = newspm->dofs; + const spm_int_t *glob2loc = newspm->glob2loc; /* Must be already initialized */ spm_int_t kl, kg, ig, jg, nnz; spm_int_t vl, dof2, dofi, dofj; assert( newspm->glob2loc ); + /* Shift the pointers to avoid extra baseval computations */ + glob2loc -= baseval; + dofs -= baseval; + dof2 = newspm->dof * newspm->dof; vl = 0; kl = 0; @@ -1147,6 +1163,7 @@ spm_scatter_ijv_remote( const spmatrix_t *oldspm, if ( newspm->flttype != SpmPattern ) { memcpy( newval, oldval, nnz * typesize ); newval += nnz * typesize; + oldval += nnz * typesize; } vl += nnz; } @@ -1410,26 +1427,32 @@ spmScatter( const spmatrix_t *oldspm, /* Check the initial conditions */ if ( local ) { + if ( loc2glob ) { + assert( n >= 0 ); + MPI_Allreduce( &n, &gN, 1, SPM_MPI_INT, + MPI_SUM, comm ); + } + if ( oldspm == NULL ) { spm_print_warning( "[%02d] spmScatter: Missing input matrix\n", clustnum ); rc = 1; + goto reduce; } - if ( loc2glob ) { - MPI_Allreduce( &n, &gN, 1, SPM_MPI_INT, - MPI_SUM, comm ); - if ( gN != oldspm->gN ) - { - spm_print_warning( "[%02d] spmScatter: Incorrect n sum (%ld != %ld)\n", - clustnum, (long)(oldspm->gN), (long)gN ); - rc = 1; - } + + if ( loc2glob && (gN != oldspm->gN) ) { + spm_print_warning( "[%02d] spmScatter: Incorrect n sum (%ld != %ld)\n", + clustnum, (long)(oldspm->gN), (long)gN ); + rc = 1; + goto reduce; } + if ( ( distByColumn && (oldspm->fmttype == SpmCSR) ) || ((!distByColumn) && (oldspm->fmttype == SpmCSC) ) ) { spm_print_warning( "[%02d] spmScatter: Does not support to scatter along the non compressed array in CSC/CSR formats\n", clustnum ); rc = 1; + goto reduce; } if ( (oldspm->fmttype != SpmIJV) && @@ -1440,6 +1463,7 @@ spmScatter( const spmatrix_t *oldspm, clustnum ); rc = 1; } + reduce: MPI_Allreduce( MPI_IN_PLACE, &rc, 1, MPI_INT, MPI_SUM, comm ); if ( rc != 0 ) { diff --git a/tests/spm_compare.c b/tests/spm_compare.c index cde2ec6388c1f113cf177382d291febca554cae0..6ef899ef1ada2b8a821f050889d6c1dd12ffc326 100644 --- a/tests/spm_compare.c +++ b/tests/spm_compare.c @@ -216,38 +216,58 @@ spmCompare( spmatrix_t *spm1, if ( (spm1->loc2glob == NULL) && (spm2->loc2glob == NULL) ) { - spmSort( spm1 ); - spmSort( spm2 ); + spmatrix_t tmpspm1, tmpspm2; + spmatrix_t *spm1ptr, *spm2ptr; - switch( spm2->fmttype ) { + if ( spm1->fmttype == SpmIJV ) + { + spmExpand( spm1, &tmpspm1 ); + spmExpand( spm2, &tmpspm2 ); + spmSort( &tmpspm1 ); + spmSort( &tmpspm2 ); + + spm1ptr = &tmpspm1; + spm2ptr = &tmpspm2; + } + else { + spm1ptr = spm1; + spm2ptr = spm2; + } + + switch( spm2ptr->fmttype ) { case SpmCSC: - rc = spmCompareCSC( spm1, spm2 ); + rc = spmCompareCSC( spm1ptr, spm2ptr ); break; case SpmCSR: - rc = spmCompareCSR( spm1, spm2 ); + rc = spmCompareCSR( spm1ptr, spm2ptr ); break; case SpmIJV: - rc = spmCompareIJV( spm1, spm2 ); + rc = spmCompareIJV( spm1ptr, spm2ptr ); } if ( rc != 0 ) { fprintf( stderr, "[%2d] Incorrect colptr/rowptr arrays\n", - spm2->clustnum ); + spm2ptr->clustnum ); + + if ( spm1->fmttype == SpmIJV ) { + spmExit( &tmpspm1 ); + spmExit( &tmpspm2 ); + } goto end; } - switch( spm2->flttype ) { + switch( spm2ptr->flttype ) { case SpmFloat: - rc = spmCompareFloatArray( spm1->nnzexp, spm1->values, spm2->values ); + rc = spmCompareFloatArray( spm1ptr->nnzexp, spm1ptr->values, spm2ptr->values ); break; case SpmDouble: - rc = spmCompareDoubleArray( spm1->nnzexp, spm1->values, spm2->values ); + rc = spmCompareDoubleArray( spm1ptr->nnzexp, spm1ptr->values, spm2ptr->values ); break; case SpmComplex32: - rc = spmCompareFloatArray( 2 * spm1->nnzexp, spm1->values, spm2->values ); + rc = spmCompareFloatArray( 2 * spm1ptr->nnzexp, spm1ptr->values, spm2ptr->values ); break; case SpmComplex64: - rc = spmCompareDoubleArray( 2 * spm1->nnzexp, spm1->values, spm2->values ); + rc = spmCompareDoubleArray( 2 * spm1ptr->nnzexp, spm1ptr->values, spm2ptr->values ); break; case SpmPattern: default: @@ -256,10 +276,13 @@ spmCompare( spmatrix_t *spm1, if ( rc != 0 ) { fprintf( stderr, "[%2d] Incorrect values arrays\n", - spm2->clustnum ); + spm2ptr->clustnum ); rc = 6; - assert(0); - goto end; + } + + if ( spm1->fmttype == SpmIJV ) { + spmExit( &tmpspm1 ); + spmExit( &tmpspm2 ); } } @@ -267,7 +290,7 @@ spmCompare( spmatrix_t *spm1, #if defined(SPM_WITH_MPI) MPI_Allreduce( MPI_IN_PLACE, &rc, 1, MPI_INT, - MPI_MAX, spm2->comm ); + MPI_MAX, MPI_COMM_WORLD ); #endif return rc; diff --git a/tests/spm_scatter_gather_tests.c b/tests/spm_scatter_gather_tests.c index e4d9df031f1db0a25c0b59d57817e257bd89fa61..f9c4ec407a060c0fcd939263cc21d882555269d9 100644 --- a/tests/spm_scatter_gather_tests.c +++ b/tests/spm_scatter_gather_tests.c @@ -75,16 +75,17 @@ spmdist_check_scatter_gather( spmatrix_t *original, int root, int clustnum ) { - const char *dofname[] = { "None", "Constant", "Variadic" }; + const char *dofname[] = { "None", "Constant", "Variadic" }; + const char *distname[] = { "Round-Robin", "Continuous "}; spmatrix_t *spms = NULL; spmatrix_t *spmg = NULL; int rc = 0; int local = (root == -1) || (root == clustnum); if ( clustnum == 0 ) { - fprintf( stdout, "type(%s) - dof(%s) - base(%d) - distByColumn(%d) - root(%d): ", + fprintf( stdout, "type(%s) - dof(%s) - base(%d) - distByColumn(%d) - root(%d) - loc2glob(%s): ", fmtnames[fmttype], dofname[dof+1], - (int)baseval, distByColumn, root ); + (int)baseval, distByColumn, root, distname[loc2glob == NULL] ); } if ( local ) { @@ -136,6 +137,8 @@ spmdist_check_scatter_gather( spmatrix_t *original, if ( spmdist_check( clustnum, spms == NULL, "Failed to generate an spm on each node" ) ) { + spmExit( spms ); + free( spms ); return 1; } @@ -153,10 +156,17 @@ spmdist_check_scatter_gather( spmatrix_t *original, * Check spmGather */ spmg = spmGather( spms, root ); + if ( spms ) { + spmExit( spms ); + free( spms ); + } /* Check non supported cases by Gather */ { - if ( (loc2glob != NULL) && (spms->fmttype != SpmIJV) ) + if ( ( original != NULL ) && + ( original->clustnbr > 1 ) && + ( loc2glob != NULL ) && + ( original->fmttype != SpmIJV ) ) { if ( spmg != NULL ) { rc = 2; /* Error */ @@ -168,10 +178,6 @@ spmdist_check_scatter_gather( spmatrix_t *original, MPI_Allreduce( MPI_IN_PLACE, &rc, 1, MPI_INT, MPI_MAX, MPI_COMM_WORLD ); if ( rc != 0 ) { - if ( spms ) { - spmExit( spms ); - free( spms ); - } if ( spmg ) { spmExit( spmg ); free( spmg ); @@ -190,30 +196,35 @@ spmdist_check_scatter_gather( spmatrix_t *original, } } } - rc = 0; + rc = ( ( local && (spmg == NULL)) || + (!local && (spmg != NULL)) ); + MPI_Allreduce( MPI_IN_PLACE, &rc, 1, MPI_INT, + MPI_MAX, MPI_COMM_WORLD ); /* Check the correct case */ - if ( spmdist_check( clustnum, - ( ( local && (spmg == NULL)) || - (!local && (spmg != NULL)) ), + if ( spmdist_check( clustnum, rc, "Failed to gather the spm correctly" ) ) { + if ( spmg ) { + spmExit( spmg ); + free( spmg ); + } return 1; } /* Compare the matrices */ - rc = spmCompare( spmg, spms ); + rc = spmCompare( original, spmg ); if ( spmdist_check( clustnum, rc, "The gathered spm does not match the original spm" ) ) { + if ( spmg ) { + spmExit( spmg ); + free( spmg ); + } return 1; } /* Cleanup */ - if ( spms ) { - spmExit( spms ); - free( spms ); - } if ( spmg ) { spmExit( spmg ); free( spmg ); @@ -272,12 +283,11 @@ int main( int argc, char **argv ) * - The scattered matrix is gathered on all nodes and compared against the * original one */ - spm = &original; for( fmttype=SpmCSC; fmttype<=SpmIJV; fmttype++ ) { if ( spmConvert( fmttype, &original ) != SPM_SUCCESS ) { - fprintf( stderr, "Issue to convert to %d format\n", fmttype ); - return EXIT_FAILURE; + fprintf( stderr, "Issue to convert to %s format\n", fmtnames[fmttype] ); + continue; } for( dof=-1; dof<2; dof++ ) @@ -289,6 +299,10 @@ int main( int argc, char **argv ) spm = &original; } + if ( spm == NULL ) { + continue; + } + for( root=-1; root<clustnbr; root++ ) { /* Make sure we don't give an input spm */