From ae250b9c2a28f62bc781df50ce500ad4d29be11f Mon Sep 17 00:00:00 2001
From: Florent Pruvost <florent.pruvost@inria.fr>
Date: Mon, 2 Aug 2021 17:06:13 +0200
Subject: [PATCH] Update cesca to be able to compute a Correspondance Analysis
 (COA) pre-treatment. It depends on the same kernels and has the same
 dependency pattern as cesca, no need to create a new chameleon algorithm for
 it.

---
 compute/zcesca.c                |  89 ++++++++++++++++++++--------
 coreblas/compute/core_zcesca.c  | 101 ++++++++++++++++++++------------
 include/chameleon/chameleon_z.h |   4 +-
 testing/testing_zcesca.c        |   2 +-
 4 files changed, 131 insertions(+), 65 deletions(-)

diff --git a/compute/zcesca.c b/compute/zcesca.c
index 446951edc..35f2839a7 100644
--- a/compute/zcesca.c
+++ b/compute/zcesca.c
@@ -126,7 +126,9 @@ void CHAMELEON_zcesca_WS_Free( void *user_ws )
  *
  * @ingroup CHAMELEON_Complex64_t
  *
- *  CHAMELEON_zcesca replace a general matrix by the Centered-Scaled matrix inplace
+ *  CHAMELEON_zcesca replace a general matrix by the Centered-Scaled matrix inplace.
+ *  This algorithm is used as a pretreatment of a Principal Component Algorithm
+ *  (PCA) or a Correspondence analysis (COA).
  *
  *  Considering a matrix A of size m x n, \f[A = (a_{i,j})_{1 \leq i \leq m, 1 \leq j \leq n}\f]
  *  Lets
@@ -147,6 +149,22 @@ void CHAMELEON_zcesca_WS_Free( void *user_ws )
  * A scaled columnwise gives \f[A' = (a_{i,j}')_{1 \leq i \leq m, 1 \leq j \leq n}\f] such that
  * \f[ a_{*j}' = \frac{a_{*j}}{d_j} \f]
  *
+ * This function can also be used to compute a pretreatment of a Correspondence analysis (COA).
+ * To use it set center = 1, scale = 1, axis = ChamEltwise.
+ * A on entry is a contingency table. For this pre-treatment we need to work on
+ * the frequencies table (each value of A is divided by the global sum)
+ * In this case lets
+ * \f[r_i = \sum_j a_{ij} \\
+ *    c_j = \sum_i a_{ij} \\
+ *    sg  = \sum_{i,j} a_{ij} \f]
+ *
+ * A transformed gives \f[\bar{A} = (\bar{a}_{i,j})_{1 \leq i \leq m, 1 \leq j \leq n}\f] such that
+ * \f[ \bar{a}_{i,j} = \frac{a_{i,j}-r_i*c_j/sg}{ \sqrt{r_i*c_j} } \f]
+ *
+ * It also gives \f[ r_i \f] and \f[ c_j \f]
+ * in vectors SR and SC useful in the post-treatment of the COA.
+ * SR and SC must be already allocated.
+ *
  *******************************************************************************
  *
  * @param[in] center
@@ -159,7 +177,9 @@ void CHAMELEON_zcesca_WS_Free( void *user_ws )
  *          Specifies the axis over which to center and or scale.
  *            = ChamColumnwise: centered column-wise
  *            = ChamRowwise: centered row-wise
- *            = ChamEltwise: bi-centered (only compatible if center=1 and scale=0)
+ *            = ChamEltwise:
+ *              bi-centered if center=1 and scale=0
+ *              COA if center=1 and scale=1
  *
  * @param[in] M
  *          The number of rows of the overall matrix.  M >= 0.
@@ -167,12 +187,18 @@ void CHAMELEON_zcesca_WS_Free( void *user_ws )
  * @param[in] N
  *         The number of columns of the overall matrix.  N >= 0.
  *
- * @param[in] A
+ * @param[in,out] A
  *          The M-by-N matrix A.
  *
  * @param[in] LDA
  *          The leading dimension of the array A. LDA >= max(1,M).
  *
+ * @param[out] SR
+ *          The vector of size M containing the \f[ r_i \f]
+ *
+ * @param[out] SC
+ *          The vector of size N containing the \f[ c_j \f]
+ *
  *******************************************************************************
  *
 * @retval CHAMELEON_SUCCESS successful exit
@@ -184,7 +210,9 @@ void CHAMELEON_zcesca_WS_Free( void *user_ws )
  * @sa CHAMELEON_scesca
  *
  */
-int CHAMELEON_zcesca(int center, int scale, cham_store_t axis, int M, int N, CHAMELEON_Complex64_t *A, int LDA )
+int CHAMELEON_zcesca(int center, int scale, cham_store_t axis,
+                     int M, int N, CHAMELEON_Complex64_t *A, int LDA,
+                     CHAMELEON_Complex64_t *SR, CHAMELEON_Complex64_t *SC)
 {
     int NB;
     int status;
@@ -192,7 +220,7 @@ int CHAMELEON_zcesca(int center, int scale, cham_store_t axis, int M, int N, CHA
     RUNTIME_sequence_t *sequence = NULL;
     RUNTIME_request_t request = RUNTIME_REQUEST_INITIALIZER;
     CHAM_desc_t descAl, descAt;
-    void *ws;
+    struct chameleon_pzcesca_s *ws;
 
     chamctxt = chameleon_context_self();
     if (chamctxt == NULL) {
@@ -212,10 +240,6 @@ int CHAMELEON_zcesca(int center, int scale, cham_store_t axis, int M, int N, CHA
         chameleon_error("CHAMELEON_zcesca", "Illegal value of axis");
         return -3;
     }
-    if ( (axis == ChamEltwise) && (center == 1) && (scale == 1) ) {
-        chameleon_error("CHAMELEON_zcesca", "Illegal value of axis and/or scale, center=1 and axis=ChamEltwise (i.e. bi-centered) must not be used with scale=1");
-        return -3;
-    }
     if (M < 0) {
         chameleon_error("CHAMELEON_zcesca", "Illegal value of M");
         return -4;
@@ -246,8 +270,8 @@ int CHAMELEON_zcesca(int center, int scale, cham_store_t axis, int M, int N, CHA
     chameleon_sequence_create( chamctxt, &sequence );
 
     /* Submit the matrix conversion */
-    chameleon_zlap2tile( chamctxt, &descAl, &descAt, ChamDescInout, ChamUpperLower,
-                         A, NB, NB, LDA, N, N, N, sequence, &request );
+    chameleon_zlap2tile( chamctxt, &descAl, &descAt, ChamDescInput, ChamUpperLower,
+                         A, NB, NB, LDA, N, M, N, sequence, &request );
 
     /* Call the tile interface */
     ws = CHAMELEON_zcesca_WS_Alloc( &descAt );
@@ -255,10 +279,18 @@ int CHAMELEON_zcesca(int center, int scale, cham_store_t axis, int M, int N, CHA
 
     /* Submit the matrix conversion back */
     chameleon_ztile2lap( chamctxt, &descAl, &descAt,
-                         ChamDescInout, ChamUpperLower, sequence, &request );
+                         ChamDescInput, ChamUpperLower, sequence, &request );
 
     chameleon_sequence_wait( chamctxt, sequence );
 
+    /* pre-coa case : save the sums over rows and columns */
+    if ( (center == 1) && (scale == 1) && (axis == ChamEltwise) ) {
+        CHAM_desc_t *descSR = chameleon_desc_submatrix( &(ws->Wgrow), 0, 0, M, 1 );
+        CHAM_desc_t *descSC = chameleon_desc_submatrix( &(ws->Wgcol), 0, 0, 1, N );
+        CHAMELEON_zDesc2Lap( ChamUpperLower, descSR, SR, M );
+        CHAMELEON_zDesc2Lap( ChamUpperLower, descSC, SC, 1 );
+    }
+
     /* Cleanup the temporary data */
     CHAMELEON_zcesca_WS_Free( ws );
     chameleon_ztile2lap_cleanup( chamctxt, &descAl, &descAt );
@@ -290,11 +322,19 @@ int CHAMELEON_zcesca(int center, int scale, cham_store_t axis, int M, int N, CHA
  *          Specifies the axis over which to center and or scale.
  *            = ChamColumnwise: centered column-wise
  *            = ChamRowwise: centered row-wise
- *            = ChamEltwise: bi-centered (only compatible if center=1 and scale=0)
+ *            = ChamEltwise:
+ *              bi-centered if center=1 and scale=0
+ *              pre-coa if center=1 and scale=1
  *
- * @param[in] A
+ * @param[in,out] A
  *          The M-by-N matrix A.
  *
+ * @param[out] SR
+ *          The vector of size M containing the \f[ r_i \f]
+ *
+ * @param[out] SC
+ *          The vector of size N containing the \f[ c_j \f]
+ *
  *******************************************************************************
  *
  * @retval CHAMELEON_SUCCESS successful exit
@@ -306,13 +346,14 @@ int CHAMELEON_zcesca(int center, int scale, cham_store_t axis, int M, int N, CHA
  * @sa CHAMELEON_scesca_Tile
  *
  */
-int CHAMELEON_zcesca_Tile( int center, int scale, cham_store_t axis, CHAM_desc_t *A )
+int CHAMELEON_zcesca_Tile( int center, int scale, cham_store_t axis, CHAM_desc_t *A,
+                           CHAMELEON_Complex64_t *SR, CHAMELEON_Complex64_t *SC)
 {
     CHAM_context_t *chamctxt;
     RUNTIME_sequence_t *sequence = NULL;
     RUNTIME_request_t request = RUNTIME_REQUEST_INITIALIZER;
     int status;
-    void *ws;
+    struct chameleon_pzcesca_s *ws;
 
     chamctxt = chameleon_context_self();
     if (chamctxt == NULL) {
@@ -333,10 +374,6 @@ int CHAMELEON_zcesca_Tile( int center, int scale, cham_store_t axis, CHAM_desc_t
         chameleon_error("CHAMELEON_zcesca_Tile", "Illegal value of axis");
         return -3;
     }
-    if ( (axis == ChamEltwise) && (center == 1) && (scale == 1) ) {
-        chameleon_error("CHAMELEON_zcesca_Tile", "Illegal value of axis and/or scale, center=1 and axis=ChamEltwise (i.e. bi-centered) must not be used with scale=1");
-        return -3;
-    }
 
     chameleon_sequence_create( chamctxt, &sequence );
 
@@ -347,6 +384,14 @@ int CHAMELEON_zcesca_Tile( int center, int scale, cham_store_t axis, CHAM_desc_t
 
     chameleon_sequence_wait( chamctxt, sequence );
 
+    /* pre-coa case : save the sums over rows and columns */
+    if ( (center == 1) && (scale == 1) && (axis == ChamEltwise) ) {
+        CHAM_desc_t *descSR = chameleon_desc_submatrix( &(ws->Wgrow), 0, 0, A->lm, 1 );
+        CHAM_desc_t *descSC = chameleon_desc_submatrix( &(ws->Wgcol), 0, 0, 1, A->ln );
+        CHAMELEON_zDesc2Lap( ChamUpperLower, descSR, SR, A->lm );
+        CHAMELEON_zDesc2Lap( ChamUpperLower, descSC, SC, 1 );
+    }
+
     CHAMELEON_zcesca_WS_Free( ws );
 
     status = sequence->status;
@@ -403,10 +448,6 @@ int CHAMELEON_zcesca_Tile_Async( int center, int scale, cham_store_t axis, CHAM_
         chameleon_error("CHAMELEON_zcesca_Tile_Async", "Illegal value of axis");
         return -3;
     }
-    if ( (axis == ChamEltwise) && (center == 1) && (scale == 1) ) {
-        chameleon_error("CHAMELEON_zcesca_Tile_Async", "Illegal value of axis and/or scale, center=1 and axis=ChamEltwise (i.e. bi-centered) must not be used with scale=1");
-        return -3;
-    }
     if (sequence == NULL) {
         chameleon_fatal_error("CHAMELEON_zcesca_Tile_Async", "NULL sequence");
         return CHAMELEON_ERR_UNALLOCATED;
diff --git a/coreblas/compute/core_zcesca.c b/coreblas/compute/core_zcesca.c
index 5d95d2eef..8bdc87dac 100644
--- a/coreblas/compute/core_zcesca.c
+++ b/coreblas/compute/core_zcesca.c
@@ -43,6 +43,15 @@
  * A scaled columnwise gives \f[A' = (a_{i,j}')_{1 \leq i \leq m, 1 \leq j \leq n}\f] such that
  * \f[ a_{*j}' = \frac{a_{*j}}{d_j} \f]
  *
+ * This function can also be used to compute a pretreatment of a Correspondence analysis (COA).
+ * To use it set center = 1, scale = 1, axis = ChamEltwise.
+ * In this case lets
+ * \f[r_i = \sum_j a_{ij} \\
+ *    c_j = \sum_i a_{ij} \\
+ *    sg  = \sum_{i,j} a_{ij}\f]
+ * A transformed gives \f[\bar{A} = (\bar{a}_{i,j})_{1 \leq i \leq m, 1 \leq j \leq n}\f] such that
+ * \f[ \bar{a}_{i,j} = \frac{a_{i,j}-r_i*c_j/sg}{ \sqrt{r_i*c_j} } \f]
+ *
  *******************************************************************************
  *
  * @param[in] center
@@ -55,7 +64,9 @@
  *          Specifies the axis over which to center and or scale.
  *            = ChamColumnwise: centered column-wise
  *            = ChamRowwise: centered row-wise
- *            = ChamEltwise: bi-centered (only compatible if center=1 and scale=0)
+ *            = ChamEltwise:
+ *              bi-centered if center=1 and scale=0
+ *              COA if center=1 and scale=1
  *
  * @param[in] M
  *          The number of rows of the overall matrix.  M >= 0.
@@ -122,7 +133,7 @@ int CORE_zcesca( int center, int scale,
                  CHAMELEON_Complex64_t *A, int LDA )
 {
     int i, j;
-    CHAMELEON_Complex64_t gi, gj, g;
+    CHAMELEON_Complex64_t gi, gj, g, rc, sqrc;
     double di, dj;
 
     /* Check input arguments */
@@ -138,10 +149,6 @@ int CORE_zcesca( int center, int scale,
         coreblas_error(3, "Illegal value of axis");
         return -3;
     }
-    if ( (axis == ChamEltwise) && (center == 1) && (scale == 1) ) {
-        coreblas_error(3, "Illegal value of axis and/or scale, center=1 and axis=ChamEltwise (i.e. bi-centered) must not be used with scale=1");
-        return -3;
-    }
     if (M < 0) {
         coreblas_error(4, "Illegal value of M");
         return -4;
@@ -187,46 +194,64 @@ int CORE_zcesca( int center, int scale,
         return CHAMELEON_SUCCESS;
     }
 
-    if ( (center == 1) && (axis == ChamEltwise) ) {
-        /* overall mean of values */
-        g =  G[0] / ( (double)M * (double)N );
-    }
+    if ( !( (center == 1) && (scale == 1) && (axis == ChamEltwise) ) ) {
+        /* PCA case i.e. centered-scaled or bi-centering */
 
-    for(j = 0; j < Nt; j++) {
-        if ( (center == 1) && ( (axis == ChamColumnwise) || (axis == ChamEltwise) ) ) {
-            /* mean of values of the column */
-            gj = Gj[j*LDGJ] / ((double)M);
+        if ( (center == 1) && (axis == ChamEltwise) ) {
+            /* overall mean of values */
+            g =  G[0] / ( (double)M * (double)N );
         }
-        if ( (scale == 1) && (axis == ChamColumnwise) ) {
-            /* norm 2 of the column */
-            dj = Dj[j*LDDJ];
-        }
-        for(i = 0; i < Mt; i++) {
-            if ( (center == 1) && ( (axis == ChamRowwise) || (axis == ChamEltwise) ) ) {
-                /* mean of values of the row */
-                gi = Gi[i] / ((double)N);
-                /* compute centered matrix factor */
-                A[j*LDA+i] -= gi;
-            }
+
+        for(j = 0; j < Nt; j++) {
             if ( (center == 1) && ( (axis == ChamColumnwise) || (axis == ChamEltwise) ) ) {
-                /* compute centered matrix factor */
-                A[j*LDA+i] -= gj;
-            }
-            if ( (center == 1) && (axis == ChamEltwise) ) {
-                /* compute centered matrix factor */
-                A[j*LDA+i] += g;
+                /* mean of values of the column */
+                gj = Gj[j*LDGJ] / ((double)M);
             }
             if ( (scale == 1) && (axis == ChamColumnwise) ) {
-                /* compute scaled matrix factor */
-                A[j*LDA+i] /= dj;
+                /* norm 2 of the column */
+                dj = Dj[j*LDDJ];
+            }
+            for(i = 0; i < Mt; i++) {
+                if ( (center == 1) && ( (axis == ChamRowwise) || (axis == ChamEltwise) ) ) {
+                    /* mean of values of the row */
+                    gi = Gi[i] / ((double)N);
+                    /* compute centered matrix factor */
+                    A[j*LDA+i] -= gi;
+                }
+                if ( (center == 1) && ( (axis == ChamColumnwise) || (axis == ChamEltwise) ) ) {
+                    /* compute centered matrix factor */
+                    A[j*LDA+i] -= gj;
+                }
+                if ( (center == 1) && (axis == ChamEltwise) ) {
+                    /* compute centered matrix factor */
+                    A[j*LDA+i] += g;
+                }
+                if ( (scale == 1) && (axis == ChamColumnwise) ) {
+                    /* compute scaled matrix factor */
+                    A[j*LDA+i] /= dj;
+                }
+                if ( (scale == 1) && (axis == ChamRowwise) ) {
+                    /* norm 2 of the row */
+                    di = Di[i];
+                    /* compute scaled matrix factor */
+                    A[j*LDA+i] /= di;
+                }
             }
-            if ( (scale == 1) && (axis == ChamRowwise) ) {
-                /* norm 2 of the row */
-                di = Di[i];
-                /* compute scaled matrix factor */
-                A[j*LDA+i] /= di;
+        }
+
+    } else {
+        /* COA case */
+
+        /* update the matrix */
+        for(j = 0; j < Nt; j++) {
+            for(i = 0; i < Mt; i++) {
+                rc = Gi[i] * Gj[j];
+                sqrc = sqrt(rc);
+                A[j*LDA + i] -= Gi[i] * Gj[j] / G[0];
+                A[j*LDA + i] /= sqrc;
             }
         }
+
     }
 
     return CHAMELEON_SUCCESS;
diff --git a/include/chameleon/chameleon_z.h b/include/chameleon/chameleon_z.h
index 119bb25bf..5ed871c02 100644
--- a/include/chameleon/chameleon_z.h
+++ b/include/chameleon/chameleon_z.h
@@ -357,8 +357,8 @@ int CHAMELEON_zbuild_Tile_Async(cham_uplo_t uplo, CHAM_desc_t *A, void *user_dat
 /**
  * Centered-Scaled function prototypes
  */
-int CHAMELEON_zcesca(int center, int scale, cham_store_t axis, int M, int N, CHAMELEON_Complex64_t *A, int LDA );
-int CHAMELEON_zcesca_Tile( int center, int scale, cham_store_t axis, CHAM_desc_t *A );
+int CHAMELEON_zcesca(int center, int scale, cham_store_t axis, int M, int N, CHAMELEON_Complex64_t *A, int LDA, CHAMELEON_Complex64_t *SR, CHAMELEON_Complex64_t *SC );
+int CHAMELEON_zcesca_Tile( int center, int scale, cham_store_t axis, CHAM_desc_t *A, CHAMELEON_Complex64_t *SR, CHAMELEON_Complex64_t *SC );
 int CHAMELEON_zcesca_Tile_Async( int center, int scale, cham_store_t axis, CHAM_desc_t *A, void *user_ws, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request );
 /**
  * Gram function prototypes
diff --git a/testing/testing_zcesca.c b/testing/testing_zcesca.c
index b252476e8..edc42ce8d 100644
--- a/testing/testing_zcesca.c
+++ b/testing/testing_zcesca.c
@@ -63,7 +63,7 @@ testing_zcesca( run_arg_list_t *args, int check )
 
     /* Compute the centered-scaled matrix transformation */
     START_TIMING( t );
-    hres = CHAMELEON_zcesca_Tile( 1, 1, ChamColumnwise, descA );
+    hres = CHAMELEON_zcesca_Tile( 1, 1, ChamColumnwise, descA, NULL, NULL );
     STOP_TIMING( t );
     gflops = flops * 1.e-9 / t;
     run_arg_add_fixdbl( args, "time", t );
-- 
GitLab