From 96d9c0fbfe9bbd36a652c8a0fcbcbdc3cad6f073 Mon Sep 17 00:00:00 2001
From: KUHN Matthieu <matthieu.kuhn@inria.fr>
Date: Thu, 5 Apr 2018 14:49:35 +0200
Subject: [PATCH] Matrix vector product for all Spm storage formats using
 function pointers

---
 include/spm.h            |   1 +
 src/spm.c                |   2 +-
 src/z_spm_matrixvector.c | 374 ++++++++++++++++++++++++++++++++++++---
 tests/spm_matvec_tests.c |  87 ++++-----
 4 files changed, 392 insertions(+), 72 deletions(-)

diff --git a/include/spm.h b/include/spm.h
index f74efe7e..a5a1aa85 100644
--- a/include/spm.h
+++ b/include/spm.h
@@ -131,6 +131,7 @@ spm_int_t * spmIntConvert(   spm_int_t n, int *input );
 void        spmIntSort1Asc1( void * const pbase, const spm_int_t n );
 void        spmIntSort2Asc1( void * const pbase, const spm_int_t n );
 void        spmIntSort2Asc2( void * const pbase, const spm_int_t n );
+void        spmIntMSortIntAsc(void ** const pbase, const spm_int_t n);
 
 void        spmIntMSortIntAsc(void ** const pbase, const spm_int_t n);
 
diff --git a/src/spm.c b/src/spm.c
index b2737556..e49114d7 100644
--- a/src/spm.c
+++ b/src/spm.c
@@ -1015,7 +1015,7 @@ spmMatVec(       spm_trans_t trans,
     spmatrix_t *espm = (spmatrix_t*)spm;
     int rc = SPM_SUCCESS;
 
-    if ( spm->fmttype != SpmCSC ) {
+    if ( spm->fmttype != SpmCSC && spm->fmttype != SpmCSR && spm->fmttype != SpmIJV ) {
         return SPM_ERR_BADPARAMETER;
     }
 
diff --git a/src/z_spm_matrixvector.c b/src/z_spm_matrixvector.c
index f0fc6c63..dd2a9776 100644
--- a/src/z_spm_matrixvector.c
+++ b/src/z_spm_matrixvector.c
@@ -17,6 +17,317 @@
 #include "common.h"
 #include "z_spm.h"
 
+typedef void (*z_vectorUpdater_t)(const spm_complex64_t alpha,
+                                  const spm_int_t baseval,
+                                  const spm_int_t pos,
+                                  const spm_int_t row,
+                                  const spm_int_t col,
+                                  const spm_complex64_t *x,
+                                  const spm_complex64_t *val,
+                                  spm_complex64_t *y);
+
+spm_complex64_t z_idFunc(spm_complex64_t val)
+{
+    return val;
+}
+
+void z_updateVectCore(const spm_complex64_t alpha,
+                    const spm_int_t baseval,
+                    const spm_int_t pos,
+                    const spm_int_t idy,
+                    const spm_int_t idx,
+                    const spm_complex64_t *x,
+                    const spm_complex64_t *val,
+                    spm_complex64_t *y,
+                    spm_complex64_t (*conj_func)(spm_complex64_t))
+{
+    y[idy] += alpha * conj_func(val[pos-baseval]) * x[idx];
+}
+
+
+void z_updateVectNoTrans(const spm_complex64_t alpha,
+                       const spm_int_t baseval,
+                       const spm_int_t pos,
+                       const spm_int_t row,
+                       const spm_int_t col,
+                       const spm_complex64_t *x,
+                       const spm_complex64_t *val,
+                       spm_complex64_t *y)
+{
+    z_updateVectCore(alpha,baseval,pos,row,col,x,val,y,z_idFunc);
+}
+
+void z_updateVectTrans(const spm_complex64_t alpha,
+                       const spm_int_t baseval,
+                       const spm_int_t pos,
+                       const spm_int_t row,
+                       const spm_int_t col,
+                       const spm_complex64_t *x,
+                       const spm_complex64_t *val,
+                       spm_complex64_t *y)
+{
+    z_updateVectCore(alpha,baseval,pos,col,row,x,val,y,z_idFunc);
+}
+
+#if defined(PRECISION_c) || defined(PRECISION_z)
+void z_updateVectConjTrans(const spm_complex64_t alpha,
+                           const spm_int_t baseval,
+                           const spm_int_t pos,
+                           const spm_int_t row,
+                           const spm_int_t col,
+                           const spm_complex64_t *x,
+                           const spm_complex64_t *val,
+                           spm_complex64_t *y)
+{
+    z_updateVectCore(alpha,baseval,pos,col,row,x,val,y,conj);
+}
+#endif
+
+void z_updateVectSy(const spm_complex64_t alpha,
+                       const spm_int_t baseval,
+                       const spm_int_t pos,
+                       const spm_int_t row,
+                       const spm_int_t col,
+                       const spm_complex64_t *x,
+                       const spm_complex64_t *val,
+                       spm_complex64_t *y)
+{
+    z_updateVectCore(alpha,baseval,pos,row,col,x,val,y,z_idFunc);
+    if( col != row )
+    {
+        z_updateVectCore(alpha,baseval,pos,col,row,x,val,y,z_idFunc);
+    }
+}
+
+#if defined(PRECISION_c) || defined(PRECISION_z)
+void z_updateVectHe(const spm_complex64_t alpha,
+                           const spm_int_t baseval,
+                           const spm_int_t pos,
+                           const spm_int_t row,
+                           const spm_int_t col,
+                           const spm_complex64_t *x,
+                           const spm_complex64_t *val,
+                           spm_complex64_t *y)
+{
+    if( col != row )
+    {
+        z_updateVectCore(alpha,baseval,pos,row,col,x,val,y,z_idFunc);
+        z_updateVectCore(alpha,baseval,pos,col,row,x,val,y,conj);
+    }
+    else
+    {
+        z_updateVectCore(alpha,baseval,pos,row,col,x,val,y,conj);
+    }
+}
+#endif
+
+int z_loopMatCSC(const spm_int_t       baseval,
+                 const spm_complex64_t alpha,
+                 const spmatrix_t      *spm,
+                 const spm_complex64_t *x,
+                 spm_complex64_t       *yptr,
+                 z_vectorUpdater_t updateVect)
+{
+    const spm_complex64_t *valptr = (spm_complex64_t*)(spm->values);
+    const spm_complex64_t *xptr   = (const spm_complex64_t*)x;
+    spm_int_t col, row, i;
+
+    for( col=0; col < spm->gN; col++ )
+    {
+        for( i=spm->colptr[col]; i<spm->colptr[col+1]; i++ )
+        {
+            row = spm->rowptr[i-baseval]-baseval;
+            updateVect(alpha,baseval,i,row,col,xptr,valptr,yptr);
+        }
+    }
+    return SPM_SUCCESS;
+}
+
+int z_loopMatCSR(const spm_int_t       baseval,
+                 const spm_complex64_t alpha,
+                 const spmatrix_t      *spm,
+                 const spm_complex64_t *x,
+                 spm_complex64_t       *yptr,
+                 z_vectorUpdater_t updateVect)
+{
+    const spm_complex64_t *valptr = (spm_complex64_t*)(spm->values);
+    const spm_complex64_t *xptr   = (const spm_complex64_t*)x;
+    spm_int_t col, row, i;
+
+    for( row=0; row < spm->gN; row++ )
+    {
+        for( i=spm->rowptr[row]; i<spm->rowptr[row+1]; i++ )
+        {
+            col = spm->colptr[i-baseval]-baseval;
+            updateVect(alpha,baseval,i,row,col,xptr,valptr,yptr);
+        }
+    }
+    return SPM_SUCCESS;
+}
+
+
+int z_loopMatIJV(const spm_int_t       baseval,
+                 const spm_complex64_t alpha,
+                 const spmatrix_t      *spm,
+                 const spm_complex64_t *x,
+                 spm_complex64_t       *yptr,
+                 z_vectorUpdater_t updateVect)
+{
+    const spm_complex64_t *valptr = (spm_complex64_t*)(spm->values);
+    const spm_complex64_t *xptr   = (const spm_complex64_t*)x;
+    spm_int_t col, row, i, upperBound;
+
+    upperBound = spm->gnnz+baseval;
+    for( i=baseval; i < upperBound; i++ )
+    {
+        row = spm->rowptr[i-baseval]-baseval;
+        col = spm->colptr[i-baseval]-baseval;
+        updateVect(alpha,baseval,i,row,col,xptr,valptr,yptr);
+    }
+    return SPM_SUCCESS;
+}
+
+/**
+ *******************************************************************************
+ *
+ * @ingroup spm_dev_matvec
+ *
+ * @brief compute the matrix-vector product:
+ *          y = alpha * op( A ) * x + beta * y
+ *
+ * A is a SpmGeneral spm, where op( X ) is one of
+ *
+ *    op( X ) = X  or op( X ) = X' or op( X ) = conjg( X' )
+ *
+ *  alpha and beta are scalars, and x and y are vectors.
+ *
+ *******************************************************************************
+ *
+ * @param[in] trans
+ *          Specifies whether the matrix spm is transposed, not transposed or
+ *          conjugate transposed:
+ *          = SpmNoTrans:   A is not transposed;
+ *          = SpmTrans:     A is transposed;
+ *          = SpmConjTrans: A is conjugate transposed.
+ *
+ * @param[in] alpha
+ *          alpha specifies the scalar alpha
+ *
+ * @param[in] spm
+ *          The SpmGeneral spm.
+ *
+ * @param[in] x
+ *          The vector x.
+ *
+ * @param[in] beta
+ *          beta specifies the scalar beta
+ *
+ * @param[inout] y
+ *          The vector y.
+ *
+ *******************************************************************************
+ *
+ * @retval SPM_SUCCESS if the y vector has been computed succesfully,
+ * @retval SPM_ERR_BADPARAMETER otherwise.
+ *
+ *******************************************************************************/
+int
+z_spmv(const spm_trans_t      trans,
+                  spm_complex64_t  alpha,
+            const spmatrix_t       *spm,
+            const spm_complex64_t *x,
+                  spm_complex64_t  beta,
+                  spm_complex64_t *y )
+{
+    spm_complex64_t *yptr = (spm_complex64_t*)y;
+    spm_int_t baseval, i;
+    spm_int_t (*getRow(spm_int_t,spmatrix_t));
+    spm_int_t (*getCol(spm_int_t,spmatrix_t));
+    const spm_fmttype_t fmt = spm->fmttype;
+    const spm_mtxtype_t mtxtype = spm->mtxtype;
+    z_vectorUpdater_t updateVect;
+
+    if ( (spm == NULL) || (x == NULL) || (y == NULL ) )
+    {
+        return SPM_ERR_BADPARAMETER;
+    }
+
+
+    if( mtxtype == SpmGeneral )
+    {
+        /**
+         * Select the appropriate vector updater
+         */
+        if( trans == SpmNoTrans )
+        {
+            updateVect=&z_updateVectNoTrans;
+        }
+        /**
+         * SpmTrans
+         */
+        else if( trans == SpmTrans )
+        {
+            updateVect=&z_updateVectTrans;
+        }
+#if defined(PRECISION_c) || defined(PRECISION_z)
+        /**
+         * SpmConjTrans
+         */
+        else if( trans == SpmConjTrans )
+        {
+            updateVect=&z_updateVectConjTrans;
+        }
+#endif
+        else
+        {
+            return SPM_ERR_BADPARAMETER;
+        }
+    }
+    else if( mtxtype == SpmSymmetric )
+    {
+        updateVect=&z_updateVectSy;
+    }
+#if defined(PRECISION_z) || defined(PRECISION_c)
+    else if( mtxtype == SpmHermitian )
+    {
+        updateVect=&z_updateVectHe;
+    }
+#endif
+    else
+        return SPM_ERR_BADPARAMETER;
+
+    /* first, y = beta*y */
+    if( beta == 0. ) {
+        memset( yptr, 0, spm->gN * sizeof(spm_complex64_t) );
+    }
+    else {
+        for( i=0; i<spm->gN; i++, yptr++ ) {
+            (*yptr) *= beta;
+        }
+        yptr = y;
+    }
+
+    baseval = spmFindBase( spm );
+    if( alpha != 0. ) {
+        /**
+         * Select the appropriate matrix looper
+         */
+        if( fmt == SpmCSC)
+        {
+            return z_loopMatCSC(baseval, alpha, spm, x, yptr, updateVect);
+        }
+        else if( fmt == SpmCSR )
+        {
+            return z_loopMatCSR(baseval, alpha, spm, x, yptr, updateVect);
+        }
+        else if( fmt == SpmIJV )
+        {
+            return z_loopMatIJV(baseval, alpha, spm, x, yptr, updateVect);
+        }
+    }
+    return SPM_ERR_BADPARAMETER;
+}
+
 /**
  *******************************************************************************
  *
@@ -148,6 +459,7 @@ z_spmGeCSCv(const spm_trans_t      trans,
     return SPM_SUCCESS;
 }
 
+
 /**
  *******************************************************************************
  *
@@ -378,17 +690,18 @@ z_spmCSCMatVec(const spm_trans_t  trans,
     alpha = *((const spm_complex64_t *)alphaptr);
     beta  = *((const spm_complex64_t *)betaptr);
 
-    switch (spm->mtxtype) {
-#if defined(PRECISION_z) || defined(PRECISION_c)
-    case SpmHermitian:
-        return z_spmHeCSCv( alpha, spm, x, beta, y );
-#endif
-    case SpmSymmetric:
-        return z_spmSyCSCv( alpha, spm, x, beta, y );
-    case SpmGeneral:
-    default:
-        return z_spmGeCSCv( trans, alpha, spm, x, beta, y );
-    }
+//    switch (spm->mtxtype) {
+//#if defined(PRECISION_z) || defined(PRECISION_c)
+//    case SpmHermitian:
+//        return z_spmHeCSCv( alpha, spm, x, beta, y );
+//#endif
+//    case SpmSymmetric:
+//        return z_spmSyCSCv( alpha, spm, x, beta, y );
+//    case SpmGeneral:
+//    default:
+//        return z_spmGeCSCv( trans, alpha, spm, x, beta, y );
+//    }
+    return z_spmv( trans, alpha, spm, x, beta, y);
 }
 
 /**
@@ -463,24 +776,27 @@ z_spmCSCMatMat(const spm_trans_t trans,
     alpha = *((const spm_complex64_t *)alphaptr);
     beta  = *((const spm_complex64_t *)betaptr);
 
-    switch (A->mtxtype) {
-#if defined(PRECISION_z) || defined(PRECISION_c)
-    case SpmHermitian:
-        for( i=0; i<n; i++ ){
-            rc = z_spmHeCSCv( alpha, A, B + i * ldb, beta, C + i *ldc );
-        }
-        break;
-#endif
-    case SpmSymmetric:
-        for( i=0; i<n; i++ ){
-            rc = z_spmSyCSCv( alpha, A, B + i * ldb, beta, C + i *ldc );
-        }
-        break;
-    case SpmGeneral:
-    default:
-        for( i=0; i<n; i++ ){
-            rc = z_spmGeCSCv( trans, alpha, A, B + i * ldb, beta, C + i *ldc );
-        }
+//    switch (A->mtxtype) {
+//#if defined(PRECISION_z) || defined(PRECISION_c)
+//    case SpmHermitian:
+//        for( i=0; i<n; i++ ){
+//            rc = z_spmHeCSCv( alpha, A, B + i * ldb, beta, C + i *ldc );
+//        }
+//        break;
+//#endif
+//    case SpmSymmetric:
+//        for( i=0; i<n; i++ ){
+//            rc = z_spmSyCSCv( alpha, A, B + i * ldb, beta, C + i *ldc );
+//        }
+//        break;
+//    case SpmGeneral:
+//    default:
+//        for( i=0; i<n; i++ ){
+//            rc = z_spmGeCSCv( trans, alpha, A, B + i * ldb, beta, C + i *ldc );
+//        }
+//    }
+    for( i=0; i<n; i++ ){
+        rc = z_spmv( trans, alpha, A, B + i * ldb, beta, C + i *ldc );
     }
     return rc;
 }
diff --git a/tests/spm_matvec_tests.c b/tests/spm_matvec_tests.c
index 7ad94e56..9a1f894f 100644
--- a/tests/spm_matvec_tests.c
+++ b/tests/spm_matvec_tests.c
@@ -39,13 +39,14 @@ int s_spm_matvec_check( int trans, const spmatrix_t *spm );
 char* fltnames[] = { "Pattern", "", "Float", "Double", "Complex32", "Complex64" };
 char* transnames[] = { "NoTrans", "Trans", "ConjTrans" };
 char* mtxnames[] = { "General", "Symmetric", "Hermitian" };
+char* mtxfmts[] = { "CSC", "CSR", "IJV" };
 
 int main (int argc, char **argv)
 {
     spmatrix_t    spm;
     spm_driver_t driver;
     char *filename;
-    int t,spmtype, mtxtype, baseval;
+    int t,spmtype, mtxtype, mtxfmt, baseval;
     int rc = SPM_SUCCESS;
     int err = 0;
 
@@ -67,7 +68,6 @@ int main (int argc, char **argv)
     /**
      * Only CSC is supported for now
      */
-    spmConvert( SpmCSC, &spm );
 
     spmtype = spm.mtxtype;
     printf(" -- SPM Matrix-Vector Test --\n");
@@ -77,56 +77,59 @@ int main (int argc, char **argv)
     {
         printf(" Baseval : %d\n", baseval );
         spmBase( &spm, baseval );
-
-        for( mtxtype=SpmGeneral; mtxtype<=SpmHermitian; mtxtype++ )
+        for( mtxfmt=SpmCSC; mtxfmt<=SpmIJV; mtxfmt++ )
         {
-            if ( (mtxtype == SpmHermitian) &&
-                 ( ((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)) ||
-                   (spmtype != SpmHermitian) ) )
-            {
-                continue;
-            }
-            if ( (mtxtype != SpmGeneral) &&
-                 (spmtype == SpmGeneral) )
-            {
-                continue;
-            }
-            spm.mtxtype = mtxtype;
-
-            for( t=SpmNoTrans; t<=SpmConjTrans; t++ )
+            spmConvert( mtxfmt, &spm );
+            for( mtxtype=SpmGeneral; mtxtype<=SpmHermitian; mtxtype++ )
             {
-                if ( (t == SpmConjTrans) &&
-                     ((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)))
+                if ( (mtxtype == SpmHermitian) &&
+                     ( ((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)) ||
+                       (spmtype != SpmHermitian) ) )
                 {
                     continue;
                 }
-                if ( (spm.mtxtype != SpmGeneral) && (t != SpmNoTrans) )
+                if ( (mtxtype != SpmGeneral) &&
+                     (spmtype == SpmGeneral) )
                 {
                     continue;
                 }
+                spm.mtxtype = mtxtype;
 
-                printf("   Case %s - %d - %s:\n",
-                       mtxnames[mtxtype - SpmGeneral], baseval,
-                       transnames[t - SpmNoTrans] );
-
-                switch( spm.flttype ){
-                case SpmComplex64:
-                    rc = z_spm_matvec_check( t, &spm );
-                    break;
-
-                case SpmComplex32:
-                    rc = c_spm_matvec_check( t, &spm );
-                break;
-
-                case SpmFloat:
-                    rc = s_spm_matvec_check( t, &spm );
-                    break;
-
-                case SpmDouble:
-                default:
-                    rc = d_spm_matvec_check( t, &spm );
+                for( t=SpmNoTrans; t<=SpmConjTrans; t++ )
+                {
+                    if ( (t == SpmConjTrans) &&
+                         ((spm.flttype != SpmComplex64) && (spm.flttype != SpmComplex32)))
+                    {
+                        continue;
+                    }
+                    if ( (spm.mtxtype != SpmGeneral) && (t != SpmNoTrans) )
+                    {
+                        continue;
+                    }
+
+                    printf("   Case %s - %s - %d - %s:\n",
+                           mtxnames[mtxtype - SpmGeneral], mtxfmts[mtxfmt - SpmCSC],
+                           baseval, transnames[t - SpmNoTrans] );
+
+                    switch( spm.flttype ){
+                    case SpmComplex64:
+                        rc = z_spm_matvec_check( t, &spm );
+                        break;
+
+                    case SpmComplex32:
+                        rc = c_spm_matvec_check( t, &spm );
+                        break;
+
+                    case SpmFloat:
+                        rc = s_spm_matvec_check( t, &spm );
+                        break;
+
+                    case SpmDouble:
+                    default:
+                        rc = d_spm_matvec_check( t, &spm );
+                    }
+                    PRINT_RES(rc);
                 }
-                PRINT_RES(rc);
             }
         }
     }
-- 
GitLab