From 20929b52450da275467af5a875fac39a3a2c1f23 Mon Sep 17 00:00:00 2001 From: Mathieu Faverge <mathieu.faverge@inria.fr> Date: Tue, 25 Oct 2016 13:57:24 +0200 Subject: [PATCH] Move spmExpand to spm.c --- spm.c | 77 +++++++++++++++++++++++++++++++++++++++++++++++--- spm_expand.c | 21 -------------- z_spm_expand.c | 2 +- 3 files changed, 74 insertions(+), 26 deletions(-) diff --git a/spm.c b/spm.c index 54e3afc7..d8f954ef 100644 --- a/spm.c +++ b/spm.c @@ -85,7 +85,7 @@ void spmInit( pastix_spm_t *spm ) { spm->mtxtype = PastixGeneral; - spm->flttype = PastixComplex64; + spm->flttype = PastixDouble; spm->fmttype = PastixCSC; spm->gN = 0; @@ -183,7 +183,9 @@ spmBase( pastix_spm_t *spm, return; n = spm->n; - nnz = spm->colptr[n] - spm->colptr[0]; + nnz = spm->nnz; + + assert( nnz == (spm->colptr[n] - spm->colptr[0]) ); for (i = 0; i <= n; i++) { spm->colptr[i] += baseadj; @@ -272,6 +274,9 @@ spmFindBase( const pastix_spm_t *spm ) int spmConvert( int ofmttype, pastix_spm_t *spm ) { + if ( spm->dof != 1 ) { + spm = spmExpand( spm ); + } if ( conversionTable[spm->fmttype][ofmttype][spm->flttype] ) { return conversionTable[spm->fmttype][ofmttype][spm->flttype]( spm ); } @@ -331,6 +336,10 @@ spmNorm( int ntype, { double tmp; + if ( spm->dof != 1 ) { + fprintf(stderr, "WARNING: spm expanded due to non implemented norm for non-expanded spm\n"); + spm = spmExpand( spm ); + } switch (spm->flttype) { case PastixFloat: tmp = (double)s_spmNorm( ntype, spm ); @@ -381,6 +390,10 @@ spmNorm( int ntype, int spmSort( pastix_spm_t *spm ) { + if ( spm->dof != 1 ) { + fprintf(stderr, "WARNING: spm expanded due to non implemented sort for non-expanded spm\n"); + spm = spmExpand( spm ); + } switch (spm->flttype) { case PastixPattern: p_spmSort( spm ); @@ -433,6 +446,10 @@ spmSort( pastix_spm_t *spm ) pastix_int_t spmMergeDuplicate( pastix_spm_t *spm ) { + if ( spm->dof != 1 ) { + fprintf(stderr, "WARNING: spm expanded due to non implemented merge for non-expanded spm\n"); + spm = spmExpand( spm ); + } switch (spm->flttype) { case PastixPattern: return p_spmMergeDuplicate( spm ); @@ -484,6 +501,10 @@ spmMergeDuplicate( pastix_spm_t *spm ) pastix_int_t spmSymmetrize( pastix_spm_t *spm ) { + if ( spm->dof != 1 ) { + fprintf(stderr, "WARNING: spm expanded due to non implemented symmetrize for non-expanded spm\n"); + spm = spmExpand( spm ); + } switch (spm->flttype) { case PastixPattern: return p_spmSymmetrize( spm ); @@ -551,6 +572,11 @@ spmCheckAndCorrect( pastix_spm_t *spm ) /* PaStiX works on CSC matrices */ spmConvert( PastixCSC, newspm ); + if ( newspm->dof != 1 ) { + fprintf(stderr, "WARNING: newspm expanded due to missing check functions implementations\n"); + newspm = spmExpand( newspm ); + } + /* Sort the rowptr for each column */ spmSort( newspm ); @@ -580,7 +606,7 @@ spmCheckAndCorrect( pastix_spm_t *spm ) * have been made */ if (( spm->fmttype != newspm->fmttype ) || - ( spm->nnz != newspm->nnz ) ) + ( spm->nnzexp != newspm->nnzexp ) ) { return newspm; } @@ -633,13 +659,56 @@ spmCopy( const pastix_spm_t *spm ) memcpy( newspm->loc2glob, spm->loc2glob, spm->n * sizeof(pastix_int_t)); } if(spm->values != NULL) { - size_t valsize = spm->nnz * pastix_size_of( spm->flttype ); + size_t valsize = spm->nnzexp * pastix_size_of( spm->flttype ); newspm->values = malloc(valsize); memcpy( newspm->values, spm->values, valsize); } return newspm; } +/** + ******************************************************************************* + * + * @ingroup pastix_spm + * + * @brief Expand a multi-dof spm matrix into an spm with constant dof to 1. + * + * Duplicate the spm data structure given as parameter. All new arrays are + * allocated and copied from the original matrix. Both matrices need to be + * freed. + * + ******************************************************************************* + * + * @param[in] spm + * The sparse matrix to copy. + * + ******************************************************************************* + * + * @return + * The copy of the sparse matrix. + * + *******************************************************************************/ +void +spmExpand(pastix_spm_t* spm) +{ + switch(spm->flttype) + { + case PastixFloat: + s_spmExpand(spm); + break; + case PastixComplex32: + c_spmExpand(spm); + break; + case PastixComplex64: + z_spmExpand(spm); + break; + case PastixDouble: + default: + d_spmExpand(spm); + break; + } +} + /** ******************************************************************************* * diff --git a/spm_expand.c b/spm_expand.c index ac66de23..2957feb9 100644 --- a/spm_expand.c +++ b/spm_expand.c @@ -22,27 +22,6 @@ #include "s_spm.h" #include "p_spm.h" -void -spmExpand(pastix_spm_t* spm) -{ - switch(spm->flttype) - { - case PastixFloat: - s_spmExpand(spm); - break; - case PastixComplex32: - c_spmExpand(spm); - break; - case PastixComplex64: - z_spmExpand(spm); - break; - case PastixDouble: - default: - d_spmExpand(spm); - break; - } -} - void print_tab_int(pastix_int_t* tab, pastix_int_t size) { int i; diff --git a/z_spm_expand.c b/z_spm_expand.c index 3e461691..672b9222 100644 --- a/z_spm_expand.c +++ b/z_spm_expand.c @@ -21,7 +21,7 @@ * TODO: This function is incorrect */ int -z_spmExpand(pastix_spm_t* spm) +z_spmExpand(pastix_spm_t *spm) { pastix_int_t i, col, row, cpt, dofj, dofi, baseval; pastix_complex64_t *oldvalptr; -- GitLab