From 2a75a75044287506046eeb85329764602dc9e85e Mon Sep 17 00:00:00 2001
From: Pierre Ramet <pierre.ramet@inria.fr>
Date: Mon, 19 Feb 2018 17:13:50 +0100
Subject: [PATCH] scal multi RHS

---
 spm.c | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 spm.h |  5 +++--
 2 files changed, 57 insertions(+), 4 deletions(-)

diff --git a/spm.c b/spm.c
index 95562345..4ddac602 100644
--- a/spm.c
+++ b/spm.c
@@ -26,6 +26,7 @@
 #include "p_spm.h"
 
 #include <cblas.h>
+#include <lapacke.h>
 
 #if !defined(DOXYGEN_SHOULD_SKIP_THIS)
 
@@ -1184,7 +1185,7 @@ spmCheckAxb( pastix_int_t nrhs,
  *
  *******************************************************************************/
 void
-spmScalMatrix(const double alpha, pastix_spm_t* spm)
+spmScalMatrix(double alpha, pastix_spm_t* spm)
 {
     switch(spm->flttype)
     {
@@ -1225,7 +1226,7 @@ spmScalMatrix(const double alpha, pastix_spm_t* spm)
  *
  *******************************************************************************/
 void
-spmScalVector(const double alpha, pastix_spm_t* spm, void *x)
+spmScalVector(double alpha, pastix_spm_t* spm, void *x)
 {
     switch(spm->flttype)
     {
@@ -1246,6 +1247,57 @@ spmScalVector(const double alpha, pastix_spm_t* spm, void *x)
     }
 }
 
+/**
+ *******************************************************************************
+ *
+ * @brief Scale a dense matrix corresponding to a set of RHS (wrapper to LAPACKE_xlascl)
+ *
+ * A = alpha * A
+ *
+ *******************************************************************************
+ *
+ * @param[in] flt
+ *          Datatype.
+ *
+ * @param[in] m
+ *          Number of rows of the matrix A.
+ *
+ * @param[in] n
+ *          Number of columns of the matrix A.
+ *
+ * @param[in] alpha
+ *           The scaling parameter.
+ *
+ * @param[inout] A
+ *          The matrix of RHS to scal.
+ *
+ * @param[in] lda
+ *          Defines the leading dimension of A when multiple right hand sides
+ *          are available. lda >= m.
+ *
+ *******************************************************************************/
+void
+spmScalRHS(pastix_coeftype_t flt, double alpha, pastix_int_t m, pastix_int_t n, void *A, pastix_int_t lda)
+{
+    switch(flt)
+    {
+    case PastixPattern:
+        break;
+    case PastixFloat:
+        LAPACKE_slascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda);
+        break;
+    case PastixComplex32:
+        LAPACKE_clascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda);
+        break;
+    case PastixComplex64:
+        LAPACKE_zlascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda);
+        break;
+    case PastixDouble:
+    default:
+        LAPACKE_dlascl_work(LAPACK_COL_MAJOR, 'G', 0, 0, 1., alpha, m, n, A, lda);
+    }
+}
+
 /**
  * @}
  */
diff --git a/spm.h b/spm.h
index 0f782dcb..91df2c91 100644
--- a/spm.h
+++ b/spm.h
@@ -92,8 +92,9 @@ void          spmGenFakeValues( pastix_spm_t *spm );
  */
 double        spmNorm( pastix_normtype_t ntype, const pastix_spm_t *spm );
 int           spmMatVec( pastix_trans_t trans, const void *alpha, const pastix_spm_t *spm, const void *x, const void *beta, void *y );
-void          spmScalMatrix( const double alpha, pastix_spm_t *spm );
-void          spmScalVector( const double alpha, pastix_spm_t *spm, void *x );
+void          spmScalMatrix( double alpha, pastix_spm_t *spm );
+void          spmScalVector( double alpha, pastix_spm_t *spm, void *x );
+void          spmScalRHS( pastix_coeftype_t flt, double alpha, pastix_int_t m, pastix_int_t n, void *A, pastix_int_t lda );
 
 /**
  * @}
-- 
GitLab