From 88bc45f538245f9574f6643cf63c332f0a55c0dc Mon Sep 17 00:00:00 2001
From: tdelarue <tony.delarue@inria.fr>
Date: Wed, 1 Jul 2020 16:14:36 +0200
Subject: [PATCH] spmm and spmv now support multidof

---
 include/spm.h              |   1 -
 include/spm/const.h        |   6 +
 src/common.h               |   1 +
 src/spm.c                  | 141 ++++++-------
 src/z_spm.h                |   2 +-
 src/z_spm_convert_to_csc.c |  25 +--
 src/z_spm_convert_to_csr.c |  23 +--
 src/z_spm_matrixvector.c   | 402 +++++++++++++++++++++++++++----------
 src/z_spm_reduce_rhs.c     |   8 +-
 9 files changed, 387 insertions(+), 222 deletions(-)

diff --git a/include/spm.h b/include/spm.h
index 5d52866c..ed42fa9f 100644
--- a/include/spm.h
+++ b/include/spm.h
@@ -116,7 +116,6 @@ spmatrix_t *spmScatter( const spmatrix_t *spm,
                         SPM_Comm          comm );
 spmatrix_t *spmGather ( const spmatrix_t *spm,
                               int         root );
-int spmGetDistribution( const spmatrix_t *spm );
 
 /**
  * @}
diff --git a/include/spm/const.h b/include/spm/const.h
index 01f9de53..16815f39 100644
--- a/include/spm/const.h
+++ b/include/spm/const.h
@@ -32,6 +32,12 @@ BEGIN_C_DECLS
 #define CBLAS_SADDR( a_ ) (&(a_))
 #endif
 
+/**
+ * @brief Distribution of the matrix storage
+ */
+#define SpmDistByColumn (0x1 << 0) /**< Storage in column distributed */
+#define SpmDistByRow    (0x1 << 1) /**< Storage in row distributed    */
+
 /**
  * @brief Verbose modes
  */
diff --git a/src/common.h b/src/common.h
index 4abadf79..bce115b6 100644
--- a/src/common.h
+++ b/src/common.h
@@ -54,6 +54,7 @@ spm_get_datatype( const spmatrix_t *spm )
 #endif
 
 spm_int_t *spm_get_glob2loc( spmatrix_t *spm, spm_int_t baseval );
+int        spm_get_distribution( const spmatrix_t *spm );
 
 /********************************************************************
  * Conjuguate/Id functions
diff --git a/src/spm.c b/src/spm.c
index 31ab2b36..b42538ca 100644
--- a/src/spm.c
+++ b/src/spm.c
@@ -128,70 +128,6 @@ spmInitDist( spmatrix_t *spm, SPM_Comm comm )
 #endif /* defined(SPM_WITH_MPI) */
 }
 
-/**
- *******************************************************************************
- *
- * @brief Search the distribution pattern used in the spm structure.
- *
- *******************************************************************************
- *
- * @param[in] spm
- *          The sparse matrix structure.
- *
- ********************************************************************************
- *
- * @return  1 if the distribution is column based.
- *          0 otherwise.
- *
- *******************************************************************************/
-int
-spmGetDistribution( const spmatrix_t *spm )
-{
-    int distribution = 1;
-
-    if( spm->fmttype == SpmCSC ){
-        distribution = 1;
-    }
-    else if ( spm->fmttype == SpmCSR ) {
-        distribution = 0;
-    }
-    else {
-        spm_int_t  i, baseval;
-        spm_int_t *colptr   = spm->colptr;
-        spm_int_t *glob2loc = spm->glob2loc;
-
-        baseval = spmFindBase( spm );
-        assert( glob2loc != NULL );
-        for ( i = 0; i < spm->nnz; i++, colptr++ )
-        {
-            /*
-             * If the global index is not in the local colptr
-             * -> row distribution
-             */
-            if( glob2loc[ *colptr - baseval  ] < 0 ) {
-                distribution = 0;
-                break;
-            }
-        }
-
-#if defined(SPM_WITH_MPI)
-        {
-            int check = 0;
-            MPI_Allreduce( &distribution, &check, 1, MPI_INT,
-                           MPI_SUM, spm->comm );
-            if( distribution == 0) {
-                assert( check == 0 );
-            }
-            else {
-                assert( check == spm->clustnbr );
-            }
-        }
-#endif
-    }
-
-    return distribution;
-}
-
 /**
  *******************************************************************************
  *
@@ -1308,10 +1244,6 @@ spmMatMat(       spm_trans_t trans,
     spmatrix_t *espm = (spmatrix_t*)A;
     int rc = SPM_SUCCESS;
 
-    if ( A->dof != 1 ) {
-        espm = malloc( sizeof(spmatrix_t) );
-        spmExpand( A, espm );
-    }
     switch (A->flttype) {
     case SpmFloat:
         rc = spm_sspmm( SpmLeft, trans, SpmNoTrans, n, alpha, espm, B, ldb, beta, C, ldc );
@@ -1682,3 +1614,76 @@ spm_get_glob2loc( spmatrix_t *spm,
     (void) baseval;
     return spm->glob2loc;
 }
+
+/**
+ *******************************************************************************
+ *
+ * @ingroup spm_mpi_dev
+ *
+ * @brief Search the distribution pattern used in the spm structure.
+ *
+ *******************************************************************************
+ *
+ * @param[in] spm
+ *          The sparse matrix structure.
+ *
+ ********************************************************************************
+ *
+ * @return  SpmDistByColumn if the distribution is column based.
+ *          SpmDistByRow if the distribution is row based.
+ *          (SpmDistByColumn|SpmDistByRow) if the matrix is not distributed.
+ *
+ *******************************************************************************/
+int
+spm_get_distribution( const spmatrix_t *spm )
+{
+    int distribution = 0;
+
+    if( (spm->loc2glob == NULL) || (spm->n == spm->gN) ) {
+        distribution = ( SpmDistByColumn | SpmDistByRow );
+    }
+    else {
+        if( spm->fmttype == SpmCSC ){
+            distribution = SpmDistByColumn;
+        }
+        else if ( spm->fmttype == SpmCSR ) {
+            distribution = SpmDistByRow;
+        }
+        else {
+            spm_int_t  i, baseval;
+            spm_int_t *colptr   = spm->colptr;
+            spm_int_t *glob2loc = spm->glob2loc;
+
+            baseval = spmFindBase( spm );
+            distribution = 1;
+            assert( glob2loc != NULL );
+            for ( i = 0; i < spm->nnz; i++, colptr++ )
+            {
+                /*
+                * If the global index is not in the local colptr
+                * -> row distribution
+                */
+                if( glob2loc[ *colptr - baseval  ] < 0 ) {
+                    distribution = SpmDistByRow;
+                    break;
+                }
+            }
+
+    #if defined(SPM_WITH_MPI)
+            {
+                int check = 0;
+                MPI_Allreduce( &distribution, &check, 1, MPI_INT,
+                               MPI_BOR, spm->comm );
+                /*
+                 * If a matrix is distributed
+                 * it cannot be distributed by row AND column
+                 */
+                assert( check != ( SpmDistByColumn | SpmDistByRow ) );
+                assert( distribution == check );
+            }
+    #endif
+        }
+    }
+    assert(distribution > 0);
+    return distribution;
+}
diff --git a/src/z_spm.h b/src/z_spm.h
index 5e0f6860..cb2089c2 100644
--- a/src/z_spm.h
+++ b/src/z_spm.h
@@ -76,7 +76,7 @@ spm_int_t z_spmSymmetrize( spmatrix_t *spm );
 int              z_spmGenRHS(spm_rhstype_t type, int nrhs, const spmatrix_t *spm, void *x, int ldx, void *b, int ldb );
 int              z_spmCheckAxb( spm_fixdbl_t eps, int nrhs, const spmatrix_t *spm, void *x0, int ldx0, void *b, int ldb, const void *x, int ldx );
 spm_complex64_t *z_spmGatherRHS( const spmatrix_t *spm, int nrhs, const spm_complex64_t *x, spm_int_t ldx, int root );
-void             z_spmReduceRhs( const spmatrix_t *spm, int nrhs, spm_complex64_t *bglob, spm_complex64_t *b, spm_int_t ldb );
+void             z_spmReduceRHS( const spmatrix_t *spm, int nrhs, spm_complex64_t *bglob, spm_int_t ldbglob, spm_complex64_t *b, spm_int_t ldb );
 
 /**
  * Output routines
diff --git a/src/z_spm_convert_to_csc.c b/src/z_spm_convert_to_csc.c
index 9ea0d28e..d4efbc3f 100644
--- a/src/z_spm_convert_to_csc.c
+++ b/src/z_spm_convert_to_csc.c
@@ -65,30 +65,13 @@ z_spmConvertIJV2CSC( spmatrix_t *spm )
 
 #if defined(SPM_WITH_MPI)
     if ( spm->loc2glob != NULL ) {
-        /*
-         * Check if the distribution is by column or row by exploiting the fact
-         * that the array is sorted.
-         * This is not completely safe, but that avoids going through the full
-         * matrix.
-         */
         const spm_int_t *glob2loc;
-        spm_int_t m = spm->rowptr[spm->nnz-1] - spm->rowptr[0] + 1; /* This may be not correct */
-        spm_int_t n = spm->colptr[spm->nnz-1] - spm->colptr[0] + 1;
         spm_int_t jg;
-        int distribution = 0;
+        int distribution = spm_get_distribution( spm );
 
-        if ( m <= spm->n ) { /* By row */
-            distribution |= 1;
-        }
-        if ( n <= spm->n ) { /* By column */
-            distribution |= 2;
-        }
-        MPI_Allreduce( MPI_IN_PLACE, &distribution, 1, MPI_INT,
-                       MPI_BAND, spm->comm );
-
-        if ( !(distribution & 2) ) {
-            //fprintf( stderr, "spmConvert: Conversion of column distributed matrices to CSC is not yet implemented\n");
-            return SPM_ERR_NOTIMPLEMENTED;
+        if ( !(distribution & SpmDistByColumn) ) {
+            fprintf( stderr, "spmConvert: Conversion of non column distributed matrices to CSC is not yet implemented\n");
+            return SPM_ERR_BADPARAMETER;
         }
 
         /* Allocate and compute the glob2loc array */
diff --git a/src/z_spm_convert_to_csr.c b/src/z_spm_convert_to_csr.c
index 1f5f390c..3d7a247f 100644
--- a/src/z_spm_convert_to_csr.c
+++ b/src/z_spm_convert_to_csr.c
@@ -70,29 +70,12 @@ z_spmConvertIJV2CSR( spmatrix_t *spm )
 
 #if defined(SPM_WITH_MPI)
     if ( spm->loc2glob != NULL ) {
-        /*
-         * Check if the distribution is by column or row by exploiting the fact
-         * that the array is sorted.
-         * This is not completely safe, but that avoids going through the full
-         * matrix.
-         */
         const spm_int_t *glob2loc;
-        spm_int_t m = spm->rowptr[spm->nnz-1] - spm->rowptr[0] + 1; /* This may be not correct */
-        spm_int_t n = spm->colptr[spm->nnz-1] - spm->colptr[0] + 1;
         spm_int_t ig;
-        int distribution = 0;
+        int distribution = spm_get_distribution( spm );
 
-        if ( m <= spm->n ) { /* By row */
-            distribution |= 1;
-        }
-        if ( n <= spm->n ) { /* By column */
-            distribution |= 2;
-        }
-        MPI_Allreduce( MPI_IN_PLACE, &distribution, 1, MPI_INT,
-                       MPI_BAND, spm->comm );
-
-        if ( !(distribution & 1) ) {
-            //fprintf( stderr, "spmConvert: Conversion of column distributed matrices to CSC is not yet implemented\n");
+        if ( !(distribution & SpmDistByRow) ) {
+            fprintf( stderr, "spmConvert: Conversion of non row distributed matrices to CSR is not yet implemented\n");
             return SPM_ERR_NOTIMPLEMENTED;
         }
 
diff --git a/src/z_spm_matrixvector.c b/src/z_spm_matrixvector.c
index d8a05c28..32c766ea 100644
--- a/src/z_spm_matrixvector.c
+++ b/src/z_spm_matrixvector.c
@@ -37,10 +37,14 @@ __fct_conj( spm_complex64_t val ) {
 }
 #endif
 
+/**
+ * @brief Store all the data necessary to do a matrix-matrix product
+ *        for all cases.
+ */
 struct __spm_zmatvec_s {
     int                    follow_x;
 
-    spm_int_t              baseval, n, nnz;
+    spm_int_t              baseval, n, nnz, gN;
 
     spm_complex64_t        alpha;
     const spm_int_t       *rowptr;
@@ -62,44 +66,77 @@ struct __spm_zmatvec_s {
     __loop_fct_t           loop_fct;
 };
 
-static inline int
-__spm_zmatvec_sy_csr( const __spm_zmatvec_t *args )
+/**
+ * @brief Compute the dof loop for a general block
+ */
+static inline void
+__spm_zmatvec_dof_loop(       spm_int_t        row, spm_int_t dofi,
+                              spm_int_t        col, spm_int_t dofj,
+                              spm_complex64_t *y,   spm_int_t incy,
+                        const spm_complex64_t *x,   spm_int_t incx,
+                        const spm_complex64_t *values,
+                        const __conj_fct_t     conjA_fct,
+                              spm_complex64_t  alpha )
 {
-    spm_int_t              baseval    = args->baseval;
-    spm_int_t              n          = args->n;
-    spm_complex64_t        alpha      = args->alpha;
-    const spm_int_t       *rowptr     = args->rowptr;
-    const spm_int_t       *colptr     = args->colptr;
-    const spm_complex64_t *values     = args->values;
-    const spm_int_t       *loc2glob   = args->loc2glob;
-    const spm_complex64_t *x          = args->x;
-    spm_int_t              incx       = args->incx;
-    spm_complex64_t       *y          = args->y;
-    spm_int_t              incy       = args->incy;
-    const __conj_fct_t     conjA_fct  = args->conjA_fct;
-    const __conj_fct_t     conjAt_fct = args->conjAt_fct;
-    spm_int_t              col, gcol, grow, i;
+    spm_int_t ii, jj;
 
-    for( col=0; col<n; col++, colptr++ )
+    for(jj=0; jj<dofj; jj++)
     {
-        gcol = (loc2glob == NULL) ? col : loc2glob[col] - baseval ;
-        for( i=colptr[0]; i<colptr[1]; i++, rowptr++, values++ )
+        for(ii=0; ii<dofi; ii++, values++)
         {
-            grow = *rowptr - baseval;
-            if ( grow != gcol ) {
-                y[ grow * incy ] += alpha * conjAt_fct( *values ) * x[ gcol * incx ];
-                y[ gcol * incy ] += alpha *  conjA_fct( *values ) * x[ grow * incx ];
-            }
-            else {
-                y[ gcol * incy ] += alpha *  conjA_fct( *values ) * x[ grow * incx ];
-            }
+            y[ row + (ii * incy) ] += alpha * conjA_fct( *values ) * x[ col +(jj * incx) ];
         }
     }
-    return SPM_SUCCESS;
 }
 
+/**
+ * @brief Compute the dof loop for a symmetric off diagonal block
+ */
+static inline void
+__spm_zmatvec_dof_loop_sy(       spm_int_t        row, spm_int_t dofi,
+                                 spm_int_t        col, spm_int_t dofj,
+                                 spm_complex64_t *y,   spm_int_t incy,
+                           const spm_complex64_t *x,   spm_int_t incx,
+                           const spm_complex64_t *values,
+                           const __conj_fct_t     conjA_fct,
+                           const __conj_fct_t     conjAt_fct,
+                                 spm_complex64_t  alpha )
+{
+    spm_int_t ii, jj;
+
+    for(jj=0; jj<dofj; jj++)
+    {
+        for(ii=0; ii<dofi; ii++, values++)
+        {
+            y[ row + (ii * incy) ] += alpha * conjA_fct( *values )  * x[ col +(jj * incx) ];
+            y[ col + (jj * incy) ] += alpha * conjAt_fct( *values ) * x[ row +(ii * incx) ];
+        }
+    }
+}
+
+/**
+ * @brief Compute the dof loop for a symmetric CSR matrix
+ *        Allow code factorization.
+ */
+static inline void
+__spm_zmatvec_dof_loop_sy_csr(       spm_int_t        row, spm_int_t dofi,
+                                     spm_int_t        col, spm_int_t dofj,
+                                     spm_complex64_t *y,   spm_int_t incy,
+                               const spm_complex64_t *x,   spm_int_t incx,
+                               const spm_complex64_t *values,
+                               const __conj_fct_t     conjA_fct,
+                               const __conj_fct_t     conjAt_fct,
+                                     spm_complex64_t  alpha )
+{
+    return __spm_zmatvec_dof_loop_sy( row, dofi, col, dofj, y, incy, x, incx, values, conjAt_fct, conjA_fct, alpha );
+}
+
+/**
+ * @brief Compute A*x[i:, j] = y[i:, j]
+ *        for a CSX symmetric matrix
+ */
 static inline int
-__spm_zmatvec_sy_csc( const __spm_zmatvec_t *args )
+__spm_zmatvec_sy_csx( const __spm_zmatvec_t *args )
 {
     spm_int_t              baseval    = args->baseval;
     spm_int_t              n          = args->n;
@@ -108,35 +145,52 @@ __spm_zmatvec_sy_csc( const __spm_zmatvec_t *args )
     const spm_int_t       *colptr     = args->colptr;
     const spm_complex64_t *values     = args->values;
     const spm_int_t       *loc2glob   = args->loc2glob;
+    const spm_int_t       *dofs       = args->dofs;
+    spm_int_t              dof        = args->dof;
     const spm_complex64_t *x          = args->x;
     spm_int_t              incx       = args->incx;
     spm_complex64_t       *y          = args->y;
     spm_int_t              incy       = args->incy;
     const __conj_fct_t     conjA_fct  = args->conjA_fct;
     const __conj_fct_t     conjAt_fct = args->conjAt_fct;
-    spm_int_t              col, gcol, grow, i;
+    spm_int_t              row, col, dofj, dofi;
+    spm_int_t              i, ig, j, jg;
 
+    /* If(args->follow_x) -> CSR. We need to change exchange the conj functions in the symmetric dof loop */
+    void (*dof_loop_sy)( spm_int_t, spm_int_t, spm_int_t, spm_int_t,
+                         spm_complex64_t *, spm_int_t,
+                         const spm_complex64_t *, spm_int_t, const spm_complex64_t *,
+                         const __conj_fct_t, const __conj_fct_t, spm_complex64_t )
+                        = ( args->follow_x ) ? __spm_zmatvec_dof_loop_sy_csr : __spm_zmatvec_dof_loop_sy;
 
-    for( col=0; col<n; col++, colptr++ )
+    for( j=0; j<n; j++, colptr++ )
     {
-        gcol = (loc2glob == NULL) ? col : loc2glob[col] - baseval ;
-        for( i=colptr[0]; i<colptr[1]; i++, rowptr++, values++ )
+        jg   = (loc2glob == NULL) ? j : loc2glob[j] - baseval ;
+        dofj = ( dof > 0 ) ? dof      : dofs[jg+1] - dofs[jg];
+        col  = ( dof > 0 ) ? dof * jg : dofs[jg] - baseval;
+        for( i=colptr[0]; i<colptr[1]; i++, rowptr++ )
         {
-            grow = *rowptr - baseval;
-            if ( grow != gcol ) {
-                y[ grow * incy ] += alpha *  conjA_fct( *values ) * x[ gcol * incx ];
-                y[ gcol * incy ] += alpha * conjAt_fct( *values ) * x[ grow * incx ];
+            ig   = *rowptr - baseval;
+            dofi = ( dof > 0 ) ? dof      : dofs[ig+1] - dofs[ig];
+            row  = ( dof > 0 ) ? dof * ig : dofs[ig] - baseval;
+            if ( row != col ) {
+                dof_loop_sy( row, dofi, col, dofj, y, incy, x, incx, values, conjA_fct, conjAt_fct, alpha );
             }
             else {
-                y[ gcol * incy ] += alpha *  conjA_fct( *values ) * x[ gcol * incx ];
+                __spm_zmatvec_dof_loop( col, dofj, row, dofi, y, incy, x, incx, values, conjA_fct, alpha );
             }
+            values += dofi*dofj;
         }
     }
     return SPM_SUCCESS;
 }
 
+/**
+ * @brief Compute A*x[i:, j] = y[i:, j]
+ *        for a CSC/CSR general matrix
+ */
 static inline int
-__spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
+__spm_zmatvec_ge_csx( const __spm_zmatvec_t *args )
 {
     spm_int_t              baseval   = args->baseval;
     spm_int_t              n         = args->n;
@@ -153,7 +207,7 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
     spm_int_t              incy      = args->incy;
     const __conj_fct_t     conjA_fct = args->conjA_fct;
     spm_int_t              row, dofj, dofi;
-    spm_int_t              i, ii, ig, j, jj, jg,;
+    spm_int_t              i, ig, j, jg;
 
     if ( args->follow_x ) {
         for( j = 0; j < n; j++, colptr++ )
@@ -165,16 +219,10 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
                 ig   = *rowptr - baseval;
                 dofi = ( dof > 0 ) ? dof      : dofs[ig+1] - dofs[ig];
                 row  = ( dof > 0 ) ? dof * ig : dofs[ig] - baseval;
-
-                for(jj=0; jj<dofj; jj++)
-                   {
-                       for(ii=0; ii<dofi; ii++, values++)
-                       {
-                            y[ row + (ii * incy) ] += alpha * conjA_fct( *values ) * x[ jj ];
-                       }
-                   }
+                __spm_zmatvec_dof_loop( row, dofi, 0, dofj, y, incy, x, 1, values, conjA_fct, alpha );
+                values += dofi * dofj;
             }
-            x += (dofj * incx);
+            x += dofj * incx;
         }
     }
     else {
@@ -187,21 +235,19 @@ __spm_zmatvec_ge_csc( const __spm_zmatvec_t *args )
                 ig   = *rowptr - baseval;
                 dofi = ( dof > 0 ) ? dof      : dofs[ig+1] - dofs[ig];
                 row  = ( dof > 0 ) ? dof * ig : dofs[ig] - baseval;
-
-                for ( jj = 0; jj < dofj; jj++)
-                {
-                    for ( ii = 0; ii < dofi; ii++, values++ )
-                    {
-                        y[jj] += alpha * conjA_fct( *values ) * x[ (row + ii) * incx ];
-                    }
-                }
+                __spm_zmatvec_dof_loop( 0, dofj, row, dofi, y, 1, x, incx, values, conjA_fct, alpha );
+                values += dofi * dofj;
             }
-            y += (dofj * incy);
+            y += dofj * incy;
         }
     }
     return SPM_SUCCESS;
 }
 
+/**
+ * @brief Compute A*x[i:, j] = y[i:, j]
+ *        for a IJV symmetric matrix
+ */
 static inline int
 __spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args )
 {
@@ -211,30 +257,68 @@ __spm_zmatvec_sy_ijv( const __spm_zmatvec_t *args )
     const spm_int_t       *rowptr     = args->rowptr;
     const spm_int_t       *colptr     = args->colptr;
     const spm_complex64_t *values     = args->values;
+    const spm_int_t       *dofs       = args->dofs;
+    spm_int_t              dof        = args->dof;
     const spm_complex64_t *x          = args->x;
     spm_int_t              incx       = args->incx;
     spm_complex64_t       *y          = args->y;
     spm_int_t              incy       = args->incy;
     const __conj_fct_t     conjA_fct  = args->conjA_fct;
     const __conj_fct_t     conjAt_fct = args->conjAt_fct;
-    spm_int_t              col, row, i;
+    spm_int_t              row, col, dofj, dofi;
+    spm_int_t              i, ig, jg;
 
-    for( i=0; i<nnz; i++, colptr++, rowptr++, values++ )
+    for( i=0; i<nnz; i++, colptr++, rowptr++ )
     {
-        row = *rowptr - baseval;
-        col = *colptr - baseval;
+        ig = *rowptr - baseval;
+        jg = *colptr - baseval;
+
+        dofj = ( dof > 0 ) ? dof : dofs[jg+1] - dofs[jg];
+        dofi = ( dof > 0 ) ? dof : dofs[ig+1] - dofs[ig];
+
+        row = ( dof > 0 ) ? dof * ig : dofs[ig] - baseval;
+        col = ( dof > 0 ) ? dof * jg : dofs[jg] - baseval;
 
         if ( row != col ) {
-            y[ row * incy ] += alpha *  conjA_fct( *values ) * x[ col * incx ];
-            y[ col * incy ] += alpha * conjAt_fct( *values ) * x[ row * incx ];
+            __spm_zmatvec_dof_loop_sy( row, dofi, col, dofj, y, incy, x, incx, values, conjA_fct, conjAt_fct, alpha );
         }
         else {
-            y[ row * incy ] += alpha *  conjA_fct( *values ) * x[ col * incx ];
+            __spm_zmatvec_dof_loop( row, dofi, col, dofj, y, incy, x, incx, values, conjA_fct, alpha );
         }
+        values += dofi*dofj;
     }
     return SPM_SUCCESS;
 }
 
+/**
+ * @brief Build a local dofs array which corresponds
+ *        to the local numerotation of a global index for
+ *        variadic dofs.
+ */
+static inline spm_int_t *
+__spm_zmatvec_dofs_local( const spm_int_t *dofs,
+                          const spm_int_t *glob2loc,
+                          spm_int_t gN)
+{
+    spm_int_t  i, acc = 0;
+    spm_int_t *result, *resptr;
+
+    result = calloc( gN , sizeof(spm_int_t) );
+    resptr = result;
+    for ( i = 0; i < gN; i++, glob2loc++, resptr++ )
+    {
+        if( *glob2loc >= 0 ) {
+            *resptr = acc;
+            acc += dofs[i+1] - dofs[i];
+        }
+    }
+    return result;
+}
+
+/**
+ * @brief Compute A*x[i:, j] = y[i:, j]
+ *        for a IJV general matrix
+ */
 static inline int
 __spm_zmatvec_ge_ijv( const __spm_zmatvec_t *args )
 {
@@ -245,32 +329,67 @@ __spm_zmatvec_ge_ijv( const __spm_zmatvec_t *args )
     const spm_int_t       *colptr    = args->colptr;
     const spm_complex64_t *values    = args->values;
     const spm_int_t       *glob2loc  = args->loc2glob;
+    const spm_int_t       *dofs      = args->dofs;
+    spm_int_t              dof       = args->dof;
     const spm_complex64_t *x         = args->x;
     spm_int_t              incx      = args->incx;
     spm_complex64_t       *y         = args->y;
     spm_int_t              incy      = args->incy;
     const __conj_fct_t     conjA_fct = args->conjA_fct;
-    spm_int_t              col, row, i;
+    spm_int_t              row, col, dofj, dofi;
+    spm_int_t              i, ig, jg;
+
+    spm_int_t *dof_local = NULL;
+
+    if( (dofs != NULL) && (glob2loc != NULL) ) {
+        dof_local = __spm_zmatvec_dofs_local( dofs, glob2loc, args->gN );
+    }
 
     if( args->follow_x ) {
-        for( i=0; i<nnz; i++, colptr++, rowptr++, values++ )
+        for( i=0; i<nnz; i++, colptr++, rowptr++ )
         {
-            row = *rowptr - baseval;
-            col = (glob2loc == NULL) ? *colptr - baseval : glob2loc[ *colptr - baseval ];
+            ig = *rowptr - baseval;
+            jg = *colptr - baseval;
 
-            y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ];
+            dofj = ( dof > 0 ) ? dof : dofs[jg+1] - dofs[jg];
+            dofi = ( dof > 0 ) ? dof : dofs[ig+1] - dofs[ig];
+
+            row  = ( dof > 0 ) ? dof * ig : dofs[ig] - baseval;
+            if (glob2loc == NULL) {
+                col = ( dof > 0 ) ? dof * jg : dofs[jg] - baseval;
+            }
+            else {
+                col = ( dof > 0 ) ? dof * glob2loc[jg] : dof_local[jg];
+            }
+            __spm_zmatvec_dof_loop( row, dofi, col, dofj, y, incy, x, incx, values, conjA_fct, alpha );
+            values += dofi*dofj;
         }
     }
     else {
-        for( i=0; i<nnz; i++, colptr++, rowptr++, values++ )
+        for( i=0; i<nnz; i++, colptr++, rowptr++ )
         {
-            row = (glob2loc == NULL) ? *rowptr - baseval : glob2loc[ *rowptr - baseval ];
-            col = *colptr - baseval;
+            ig = *rowptr - baseval;
+            jg = *colptr - baseval;
+
+            dofj = ( dof > 0 ) ? dof : dofs[jg+1] - dofs[jg];
+            dofi = ( dof > 0 ) ? dof : dofs[ig+1] - dofs[ig];
 
-            y[ row * incy ] += alpha * conjA_fct( *values ) * x[ col * incx ];
+            col = ( dof > 0 ) ? dof * jg : dofs[jg] - baseval;
+            if ( glob2loc == NULL ) {
+                row  = ( dof > 0 ) ? dof * ig : dofs[ig] - baseval;
+            }
+            else {
+                row = ( dof > 0 ) ? dof * glob2loc[ig] : dof_local[ig];
+            }
+            __spm_zmatvec_dof_loop( row, dofi, col, dofj, y, incy, x, incx, values, conjA_fct, alpha );
+            values += dofi*dofj;
         }
     }
 
+    if(dof_local != NULL) {
+        free(dof_local);
+    }
+
     return SPM_SUCCESS;
 }
 
@@ -324,6 +443,7 @@ __spm_zmatvec_args_init( __spm_zmatvec_t       *args,
     args->baseval    = spmFindBase( A );
     args->n          = A->n;
     args->nnz        = A->nnz;
+    args->gN         = A->gN;
     args->alpha      = alpha;
     args->rowptr     = A->rowptr;
     args->colptr     = A->colptr;
@@ -363,31 +483,39 @@ __spm_zmatvec_args_init( __spm_zmatvec_t       *args,
     case SpmCSC:
     {
         /* Switch pointers and side to get the correct behaviour */
-        if ( ((side == SpmLeft)  && (transA == SpmNoTrans)) ||
-             ((side == SpmRight) && (transA != SpmNoTrans)) )
-        {
-            args->follow_x = 1;
-        }
-        else {
-            args->follow_x = 0;
+        if( A->mtxtype == SpmGeneral ) {
+            if ( ((side == SpmLeft)  && (transA == SpmNoTrans)) ||
+                 ((side == SpmRight) && (transA != SpmNoTrans)) )
+            {
+                args->follow_x = 1;
+            }
+            else {
+                args->follow_x = 0;
+            }
         }
-        args->loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csc;
+        args->loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csx : __spm_zmatvec_sy_csx;
     }
     break;
     case SpmCSR:
     {
         /* Switch pointers and side to get the correct behaviour */
-        if ( ((side == SpmLeft)  && (transA != SpmNoTrans)) ||
-             ((side == SpmRight) && (transA == SpmNoTrans)) )
-        {
-            args->follow_x = 1;
+        if( A->mtxtype == SpmGeneral ) {
+            if ( ((side == SpmLeft)  && (transA != SpmNoTrans)) ||
+                 ((side == SpmRight) && (transA == SpmNoTrans)) )
+            {
+                args->follow_x = 1;
+            }
+            else {
+                args->follow_x = 0;
+            }
         }
         else {
-            args->follow_x = 0;
+            args->follow_x = 1;
         }
+
         args->colptr = A->rowptr;
         args->rowptr = A->colptr;
-        args->loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csc : __spm_zmatvec_sy_csr;
+        args->loop_fct = (A->mtxtype == SpmGeneral) ? __spm_zmatvec_ge_csx : __spm_zmatvec_sy_csx;
     }
     break;
     case SpmIJV:
@@ -415,6 +543,33 @@ __spm_zmatvec_args_init( __spm_zmatvec_t       *args,
     return SPM_SUCCESS;
 }
 
+/**
+ *******************************************************************************
+ *
+ * @ingroup spm_dev_matvec
+ *
+ * @brief Build a global C RHS, set to 0 for remote datas.
+ *
+ *******************************************************************************
+ *
+ * @param[in] spm
+ *          The pointer to the sparse matrix structure.
+ *
+ * @param[in] Cloc
+ *          The local C vector.
+ *
+ * @param[inout] ldc
+ *          The leading dimension of the local C vector.
+ *          Will be updated to corresponds to the global one.
+ *
+ * @param[in] nrhs
+ *          The number of RHS.
+ *
+ *******************************************************************************
+ *
+ * @return A global C vector which stores local datas and set remote datas to 0.
+ *
+ *******************************************************************************/
 static inline spm_complex64_t *
 z_spmm_build_Ctmp( const spmatrix_t      *spm,
                    const spm_complex64_t *Cloc,
@@ -422,6 +577,7 @@ z_spmm_build_Ctmp( const spmatrix_t      *spm,
                          int              nrhs )
 {
     spm_complex64_t *Ctmp;
+    spm_complex64_t *Cptr = (spm_complex64_t *)Cloc;
     spm_int_t i, j, idx;
     spm_int_t ig, dof, baseval, *loc2glob;
 
@@ -429,22 +585,50 @@ z_spmm_build_Ctmp( const spmatrix_t      *spm,
     *ldc = spm->gNexp;
 
     baseval = spmFindBase(spm);
-    for ( j = 0; j < nrhs; j++)
+    for ( j = 0; j < nrhs; j++ )
     {
         loc2glob = spm->loc2glob;
-        for ( i = 0; i < spm->n; i++, loc2glob++)
+        for ( i = 0; i < spm->n; i++, loc2glob++ )
         {
             ig  = *loc2glob - baseval;
             dof = (spm->dof > 0) ? spm->dof : spm->dofs[ig+1] - spm->dofs[ig];
             idx = (spm->dof > 0) ? spm->dof * ig : spm->dofs[ig] - baseval;
             memcpy( (Ctmp + j * spm->gNexp + idx),
-                    (Cloc + j * spm->nexp  +   i),
-                    dof * sizeof(spm_complex64_t) );
+                     Cptr,
+                     dof * sizeof(spm_complex64_t) );
+            Cptr += dof;
         }
     }
     return Ctmp;
 }
 
+/**
+ *******************************************************************************
+ *
+ * @ingroup spm_dev_matvec
+ *
+ * @brief Build a global B vector by gathering datas from all nodes.
+ *
+ *******************************************************************************
+ *
+ * @param[in] spm
+ *          The pointer to the sparse matrix structure.
+ *
+ * @param[in] Bloc
+ *          The local B vector.
+ *
+ * @param[inout] ldb
+ *          The leading dimension of the local B vector.
+ *          Will be updated to corresponds to the global one.
+ *
+ * @param[in] nrhs
+ *          The number of RHS.
+ *
+ *******************************************************************************
+ *
+ * @return The gathered Btmp vector.
+ *
+ *******************************************************************************/
 static inline spm_complex64_t *
 z_spmm_build_Btmp( const spmatrix_t      *spm,
                    const spm_complex64_t *Bloc,
@@ -555,6 +739,7 @@ spm_zspmm( spm_side_t             side,
            spm_int_t              ldc )
 {
     int rc = SPM_SUCCESS;
+    int distribution;
     spm_int_t M, N, ldx, ldy, r;
     __spm_zmatvec_t args;
     spm_complex64_t *Ctmp, *Btmp;
@@ -593,20 +778,20 @@ spm_zspmm( spm_side_t             side,
 
     Btmp = (spm_complex64_t*)B;
     Ctmp = C;
-    if ( A->loc2glob != NULL ) {
-        int distByCol = spmGetDistribution(A);
+    distribution = spm_get_distribution(A);
+    if ( distribution != ( SpmDistByColumn | SpmDistByRow ) ) {
 
         if ( A->mtxtype != SpmGeneral ) {
             Btmp = z_spmm_build_Btmp( A, B, &ldb, N );
             Ctmp = z_spmm_build_Ctmp( A, C, &ldc, N );
         }
         else {
-            if( ( (transA != SpmNoTrans) && (distByCol == 1) ) ||
-                ( (transA == SpmNoTrans) && (distByCol == 0) ) ) {
+            if( ( (transA != SpmNoTrans) && (distribution == 1) ) ||
+                ( (transA == SpmNoTrans) && (distribution == 2) ) ) {
                 Btmp = z_spmm_build_Btmp( A, B, &ldb, N );
             }
-            if( ( (transA == SpmNoTrans) && (distByCol == 1) ) ||
-                ( (transA != SpmNoTrans) && (distByCol == 0) ) ) {
+            if( ( (transA == SpmNoTrans) && (distribution == 1) ) ||
+                ( (transA != SpmNoTrans) && (distribution == 2) ) ) {
                 Ctmp = z_spmm_build_Ctmp( A, C, &ldc, N );
             }
         }
@@ -622,7 +807,7 @@ spm_zspmm( spm_side_t             side,
     }
 
     if ( Ctmp != C ) {
-        z_spmReduceRhs( A, N, Ctmp, C, ldc );
+        z_spmReduceRHS( A, N, Ctmp, A->gNexp, C, A->nexp );
         free( Ctmp );
     }
 
@@ -687,6 +872,7 @@ spm_zspmv( spm_trans_t            trans,
            spm_int_t              incy )
 {
     int rc = SPM_SUCCESS;
+    int distribution;
     __spm_zmatvec_t args;
     spm_complex64_t *ytmp, *xtmp;
 
@@ -703,20 +889,20 @@ spm_zspmv( spm_trans_t            trans,
 
     xtmp = (spm_complex64_t*)x;
     ytmp = y;
-    if ( A->loc2glob != NULL ){
-        int distByCol = spmGetDistribution(A);
+    distribution = spm_get_distribution(A);
+    if ( distribution != ( SpmDistByColumn | SpmDistByRow ) ){
 
         if ( A->mtxtype != SpmGeneral ) {
             xtmp = z_spmm_build_Btmp( A, x, &incx, 1 );
             ytmp = z_spmm_build_Ctmp( A, y, &incy, 1 );
         }
         else {
-            if( ( (trans != SpmNoTrans) && (distByCol == 1) ) ||
-                ( (trans == SpmNoTrans) && (distByCol == 0) ) ) {
+            if( ( (trans != SpmNoTrans) && (distribution == 1) ) ||
+                ( (trans == SpmNoTrans) && (distribution == 2) ) ) {
                 xtmp = z_spmm_build_Btmp( A, x, &incx, 1 );
             }
-            if( ( (trans == SpmNoTrans) && (distByCol == 1) ) ||
-                ( (trans != SpmNoTrans) && (distByCol == 0) ) ) {
+            if( ( (trans == SpmNoTrans) && (distribution == 1) ) ||
+                ( (trans != SpmNoTrans) && (distribution == 2) ) ) {
                 ytmp = z_spmm_build_Ctmp( A, y, &incy, 1 );
             }
         }
@@ -727,8 +913,8 @@ spm_zspmv( spm_trans_t            trans,
     rc = args.loop_fct( &args );
 
     if ( ytmp != y ) {
-        z_spmReduceRhs( A, 1, ytmp, y, incy );
-        free(ytmp);
+        z_spmReduceRHS( A, 1, ytmp, A->gNexp, y, A->nexp );
+        free( ytmp );
     }
 
     if ( xtmp != x ) {
diff --git a/src/z_spm_reduce_rhs.c b/src/z_spm_reduce_rhs.c
index f2a7f964..2a331d97 100644
--- a/src/z_spm_reduce_rhs.c
+++ b/src/z_spm_reduce_rhs.c
@@ -41,9 +41,10 @@
  *
  *******************************************************************************/
 void
-z_spmReduceRhs( const spmatrix_t      *spm,
+z_spmReduceRHS( const spmatrix_t      *spm,
                       int              nrhs,
                       spm_complex64_t *bglob,
+                      spm_int_t        ldbglob,
                       spm_complex64_t *b,
                       spm_int_t        ldb )
 {
@@ -57,7 +58,7 @@ z_spmReduceRhs( const spmatrix_t      *spm,
         return;
     }
 
-    MPI_Allreduce( MPI_IN_PLACE, bglob, ldb * nrhs, SPM_MPI_COMPLEX64, MPI_SUM, spm->comm );
+    MPI_Allreduce( MPI_IN_PLACE, bglob, ldbglob * nrhs, SPM_MPI_COMPLEX64, MPI_SUM, spm->comm );
 
     baseval  = spmFindBase( spm );
     loc2glob = spm->loc2glob;
@@ -67,7 +68,7 @@ z_spmReduceRhs( const spmatrix_t      *spm,
         row  = ( spm->dof > 0 ) ? spm->dof * ig : spm->dofs[ig] - baseval;
         for( j=0; j<nrhs; j++ ) {
             for( k=0; k<dofi; k++ ) {
-                rhs[ j * spm->nexp + k ] = bglob[ row + j * ldb + k ];
+                rhs[ j * ldb + k ] = bglob[ row + j * ldbglob + k ];
             }
         }
         rhs += dofi;
@@ -76,6 +77,7 @@ z_spmReduceRhs( const spmatrix_t      *spm,
     (void)spm;
     (void)nrhs;
     (void)bglob;
+    (void)ldbglob;
     (void)b;
     (void)ldb;
 #endif
-- 
GitLab