From 16a70a7eb1e400aff2f84f89af3fac8a908378ca Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Tue, 5 Nov 2019 00:10:03 +0100
Subject: [PATCH] Add a function to check if trans parameters are valide
 without being affected by the conversion

---
 compute/zgeadd.c               |  4 +--
 compute/zgemm.c                |  8 ++---
 compute/ztradd.c               |  4 +--
 compute/ztrmm.c                | 29 +++++++---------
 compute/ztrsm.c                | 24 +++++++-------
 coreblas/compute/core_zgeadd.c | 31 ++++++++----------
 coreblas/compute/core_zlatro.c |  2 +-
 coreblas/compute/core_zpemv.c  |  2 +-
 coreblas/compute/core_ztradd.c | 60 +++++++++++++++-------------------
 include/chameleon/constants.h  | 11 +++++++
 10 files changed, 85 insertions(+), 90 deletions(-)

diff --git a/compute/zgeadd.c b/compute/zgeadd.c
index c98f82c73..6b2c6aa5a 100644
--- a/compute/zgeadd.c
+++ b/compute/zgeadd.c
@@ -105,7 +105,7 @@ int CHAMELEON_zgeadd( cham_trans_t trans, int M, int N,
     }
 
     /* Check input arguments */
-    if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) {
+    if ( !isValidTrans( trans ) ) {
         chameleon_error("CHAMELEON_zgeadd", "illegal value of trans");
         return -1;
     }
@@ -317,7 +317,7 @@ int CHAMELEON_zgeadd_Tile_Async( cham_trans_t trans,
         return chameleon_request_fail(sequence, request, CHAMELEON_ERR_ILLEGAL_VALUE);
     }
     /* Check input arguments */
-    if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) {
+    if ( !isValidTrans( trans ) ) {
         chameleon_error("CHAMELEON_zgeadd_Tile_Async", "illegal value of trans");
         return chameleon_request_fail(sequence, request, -1);
     }
diff --git a/compute/zgemm.c b/compute/zgemm.c
index 6de404567..8fc4e8f89 100644
--- a/compute/zgemm.c
+++ b/compute/zgemm.c
@@ -146,11 +146,11 @@ int CHAMELEON_zgemm( cham_trans_t transA, cham_trans_t transB, int M, int N, int
     }
 
     /* Check input arguments */
-    if ((transA < ChamNoTrans) || (transA > ChamConjTrans)) {
+    if ( !isValidTrans( transA ) ) {
         chameleon_error("CHAMELEON_zgemm", "illegal value of transA");
         return -1;
     }
-    if ((transB < ChamNoTrans) || (transB > ChamConjTrans)) {
+    if ( !isValidTrans( transB ) ) {
         chameleon_error("CHAMELEON_zgemm", "illegal value of transB");
         return -2;
     }
@@ -394,11 +394,11 @@ int CHAMELEON_zgemm_Tile_Async( cham_trans_t transA, cham_trans_t transB,
         return chameleon_request_fail(sequence, request, CHAMELEON_ERR_ILLEGAL_VALUE);
     }
     /* Check input arguments */
-    if ((transA < ChamNoTrans) || (transA > ChamConjTrans)) {
+    if ( !isValidTrans( transA ) ) {
         chameleon_error("CHAMELEON_zgemm_Tile_Async", "illegal value of transA");
         return chameleon_request_fail(sequence, request, -1);
     }
-    if ((transB < ChamNoTrans) || (transB > ChamConjTrans)) {
+    if ( !isValidTrans( transB ) ) {
         chameleon_error("CHAMELEON_zgemm_Tile_Async", "illegal value of transB");
         return chameleon_request_fail(sequence, request, -2);
     }
diff --git a/compute/ztradd.c b/compute/ztradd.c
index 02cd629f3..0b1700a15 100644
--- a/compute/ztradd.c
+++ b/compute/ztradd.c
@@ -115,7 +115,7 @@ int CHAMELEON_ztradd( cham_uplo_t uplo, cham_trans_t trans, int M, int N,
         chameleon_error("CHAMELEON_ztradd", "illegal value of uplo");
         return -1;
     }
-    if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) {
+    if ( !isValidTrans( trans ) ) {
         chameleon_error("CHAMELEON_ztradd", "illegal value of trans");
         return -2;
     }
@@ -333,7 +333,7 @@ int CHAMELEON_ztradd_Tile_Async( cham_uplo_t uplo, cham_trans_t trans,
         return chameleon_request_fail(sequence, request, CHAMELEON_ERR_ILLEGAL_VALUE);
     }
     /* Check input arguments */
-    if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) {
+    if ( !isValidTrans( trans ) ) {
         chameleon_error("CHAMELEON_ztradd_Tile_Async", "illegal value of trans");
         return chameleon_request_fail(sequence, request, -1);
     }
diff --git a/compute/ztrmm.c b/compute/ztrmm.c
index 8d1b203fb..09399a3b5 100644
--- a/compute/ztrmm.c
+++ b/compute/ztrmm.c
@@ -42,7 +42,7 @@
  *          = ChamUpper: Upper triangle of A is stored;
  *          = ChamLower: Lower triangle of A is stored.
  *
- * @param[in] transA
+ * @param[in] trans
  *          Specifies whether the matrix A is transposed, not transposed or conjugate transposed:
  *          = ChamNoTrans:   A is transposed;
  *          = ChamTrans:     A is not transposed;
@@ -97,7 +97,7 @@
  *
  */
 int CHAMELEON_ztrmm( cham_side_t side, cham_uplo_t uplo,
-                 cham_trans_t transA, cham_diag_t diag,
+                 cham_trans_t trans, cham_diag_t diag,
                  int N, int NRHS, CHAMELEON_Complex64_t alpha,
                  CHAMELEON_Complex64_t *A, int LDA,
                  CHAMELEON_Complex64_t *B, int LDB )
@@ -131,13 +131,8 @@ int CHAMELEON_ztrmm( cham_side_t side, cham_uplo_t uplo,
         chameleon_error("CHAMELEON_ztrmm", "illegal value of uplo");
         return -2;
     }
-    if ((transA != ChamNoTrans)   &&
-#if defined(PRECISION_z) || defined(PRECISION_c)
-        (transA != ChamConjTrans) &&
-#endif
-        (transA != ChamTrans) )
-    {
-        chameleon_error("CHAMELEON_ztrmm", "illegal value of transA");
+    if ( !isValidTrans( trans ) ) {
+        chameleon_error("CHAMELEON_ztrmm", "illegal value of trans");
         return -3;
     }
     if ((diag != ChamUnit) && (diag != ChamNonUnit)) {
@@ -183,7 +178,7 @@ int CHAMELEON_ztrmm( cham_side_t side, cham_uplo_t uplo,
                      B, NB, NB, LDB, NRHS, N, NRHS, sequence, &request );
 
     /* Call the tile interface */
-    CHAMELEON_ztrmm_Tile_Async(  side, uplo, transA, diag, alpha, &descAt, &descBt, sequence, &request );
+    CHAMELEON_ztrmm_Tile_Async(  side, uplo, trans, diag, alpha, &descAt, &descBt, sequence, &request );
 
     /* Submit the matrix conversion back */
     chameleon_ztile2lap( chamctxt, &descAl, &descAt,
@@ -224,7 +219,7 @@ int CHAMELEON_ztrmm( cham_side_t side, cham_uplo_t uplo,
  *          = ChamUpper: Upper triangle of A is stored;
  *          = ChamLower: Lower triangle of A is stored.
  *
- * @param[in] transA
+ * @param[in] trans
  *          Specifies whether the matrix A is transposed, not transposed or conjugate transposed:
  *          = ChamNoTrans:   A is transposed;
  *          = ChamTrans:     A is not transposed;
@@ -264,7 +259,7 @@ int CHAMELEON_ztrmm( cham_side_t side, cham_uplo_t uplo,
  *
  */
 int CHAMELEON_ztrmm_Tile( cham_side_t side, cham_uplo_t uplo,
-                      cham_trans_t transA, cham_diag_t diag,
+                      cham_trans_t trans, cham_diag_t diag,
                       CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B )
 {
     CHAM_context_t *chamctxt;
@@ -279,7 +274,7 @@ int CHAMELEON_ztrmm_Tile( cham_side_t side, cham_uplo_t uplo,
     }
     chameleon_sequence_create( chamctxt, &sequence );
 
-    CHAMELEON_ztrmm_Tile_Async(side, uplo, transA, diag, alpha, A, B, sequence, &request );
+    CHAMELEON_ztrmm_Tile_Async(side, uplo, trans, diag, alpha, A, B, sequence, &request );
 
     CHAMELEON_Desc_Flush( A, sequence );
     CHAMELEON_Desc_Flush( B, sequence );
@@ -319,7 +314,7 @@ int CHAMELEON_ztrmm_Tile( cham_side_t side, cham_uplo_t uplo,
  *
  */
 int CHAMELEON_ztrmm_Tile_Async( cham_side_t side, cham_uplo_t uplo,
-                            cham_trans_t transA, cham_diag_t diag,
+                            cham_trans_t trans, cham_diag_t diag,
                             CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B,
                             RUNTIME_sequence_t *sequence, RUNTIME_request_t *request )
 {
@@ -368,8 +363,8 @@ int CHAMELEON_ztrmm_Tile_Async( cham_side_t side, cham_uplo_t uplo,
         chameleon_error("CHAMELEON_ztrmm_Tile", "illegal value of uplo");
         return chameleon_request_fail(sequence, request, -2);
     }
-    if ((transA < ChamNoTrans) || (transA > ChamConjTrans)) {
-        chameleon_error("CHAMELEON_ztrmm_Tile", "illegal value of transA");
+    if ( !isValidTrans( trans ) ) {
+        chameleon_error("CHAMELEON_ztrmm_Tile", "illegal value of trans");
         return chameleon_request_fail(sequence, request, -3);
     }
     if ((diag != ChamUnit) && (diag != ChamNonUnit)) {
@@ -378,7 +373,7 @@ int CHAMELEON_ztrmm_Tile_Async( cham_side_t side, cham_uplo_t uplo,
     }
 
     /* Quick return */
-    chameleon_pztrmm( side, uplo, transA, diag, alpha, A, B, sequence, request );
+    chameleon_pztrmm( side, uplo, trans, diag, alpha, A, B, sequence, request );
 
     return CHAMELEON_SUCCESS;
 }
diff --git a/compute/ztrsm.c b/compute/ztrsm.c
index a0989cef6..cb9edbf88 100644
--- a/compute/ztrsm.c
+++ b/compute/ztrsm.c
@@ -43,7 +43,7 @@
  *          = ChamUpper: Upper triangle of A is stored;
  *          = ChamLower: Lower triangle of A is stored.
  *
- * @param[in] transA
+ * @param[in] trans
  *          Specifies whether the matrix A is transposed, not transposed or conjugate transposed:
  *          = ChamNoTrans:   A is transposed;
  *          = ChamTrans:     A is not transposed;
@@ -98,7 +98,7 @@
  *
  */
 int CHAMELEON_ztrsm( cham_side_t side, cham_uplo_t uplo,
-                 cham_trans_t transA, cham_diag_t diag,
+                 cham_trans_t trans, cham_diag_t diag,
                  int N, int NRHS, CHAMELEON_Complex64_t alpha,
                  CHAMELEON_Complex64_t *A, int LDA,
                  CHAMELEON_Complex64_t *B, int LDB )
@@ -132,8 +132,8 @@ int CHAMELEON_ztrsm( cham_side_t side, cham_uplo_t uplo,
         chameleon_error("CHAMELEON_ztrsm", "illegal value of uplo");
         return -2;
     }
-    if (((transA < ChamNoTrans) || (transA > ChamConjTrans)) ) {
-        chameleon_error("CHAMELEON_ztrsm", "illegal value of transA");
+    if ( !isValidTrans( trans ) ) {
+        chameleon_error("CHAMELEON_ztrsm", "illegal value of trans");
         return -3;
     }
     if ((diag != ChamUnit) && (diag != ChamNonUnit)) {
@@ -179,7 +179,7 @@ int CHAMELEON_ztrsm( cham_side_t side, cham_uplo_t uplo,
                      B, NB, NB, LDB, NRHS, N,  NRHS, sequence, &request );
 
     /* Call the tile interface */
-    CHAMELEON_ztrsm_Tile_Async(  side, uplo, transA, diag, alpha, &descAt, &descBt, sequence, &request );
+    CHAMELEON_ztrsm_Tile_Async(  side, uplo, trans, diag, alpha, &descAt, &descBt, sequence, &request );
 
     /* Submit the matrix conversion back */
     chameleon_ztile2lap( chamctxt, &descAl, &descAt,
@@ -220,7 +220,7 @@ int CHAMELEON_ztrsm( cham_side_t side, cham_uplo_t uplo,
  *          = ChamUpper: Upper triangle of A is stored;
  *          = ChamLower: Lower triangle of A is stored.
  *
- * @param[in] transA
+ * @param[in] trans
  *          Specifies whether the matrix A is transposed, not transposed or conjugate transposed:
  *          = ChamNoTrans:   A is transposed;
  *          = ChamTrans:     A is not transposed;
@@ -260,7 +260,7 @@ int CHAMELEON_ztrsm( cham_side_t side, cham_uplo_t uplo,
  *
  */
 int CHAMELEON_ztrsm_Tile( cham_side_t side, cham_uplo_t uplo,
-                      cham_trans_t transA, cham_diag_t diag,
+                      cham_trans_t trans, cham_diag_t diag,
                       CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B )
 {
     CHAM_context_t *chamctxt;
@@ -275,7 +275,7 @@ int CHAMELEON_ztrsm_Tile( cham_side_t side, cham_uplo_t uplo,
     }
     chameleon_sequence_create( chamctxt, &sequence );
 
-    CHAMELEON_ztrsm_Tile_Async(side, uplo, transA, diag, alpha, A, B, sequence, &request );
+    CHAMELEON_ztrsm_Tile_Async(side, uplo, trans, diag, alpha, A, B, sequence, &request );
 
     CHAMELEON_Desc_Flush( A, sequence );
     CHAMELEON_Desc_Flush( B, sequence );
@@ -315,7 +315,7 @@ int CHAMELEON_ztrsm_Tile( cham_side_t side, cham_uplo_t uplo,
  *
  */
 int CHAMELEON_ztrsm_Tile_Async( cham_side_t side, cham_uplo_t uplo,
-                            cham_trans_t transA, cham_diag_t diag,
+                            cham_trans_t trans, cham_diag_t diag,
                             CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B,
                             RUNTIME_sequence_t *sequence, RUNTIME_request_t *request )
 {
@@ -364,8 +364,8 @@ int CHAMELEON_ztrsm_Tile_Async( cham_side_t side, cham_uplo_t uplo,
         chameleon_error("CHAMELEON_ztrsm_Tile", "illegal value of uplo");
         return chameleon_request_fail(sequence, request, -2);
     }
-    if ((transA < ChamNoTrans) || (transA > ChamConjTrans)) {
-        chameleon_error("CHAMELEON_ztrsm_Tile", "illegal value of transA");
+    if ( !isValidTrans( trans ) ) {
+        chameleon_error("CHAMELEON_ztrsm_Tile", "illegal value of trans");
         return chameleon_request_fail(sequence, request, -3);
     }
     if ((diag != ChamUnit) && (diag != ChamNonUnit)) {
@@ -374,7 +374,7 @@ int CHAMELEON_ztrsm_Tile_Async( cham_side_t side, cham_uplo_t uplo,
     }
 
     /* Quick return */
-    chameleon_pztrsm( side, uplo, transA, diag, alpha, A, B, sequence, request );
+    chameleon_pztrsm( side, uplo, trans, diag, alpha, A, B, sequence, request );
 
     return CHAMELEON_SUCCESS;
 }
diff --git a/coreblas/compute/core_zgeadd.c b/coreblas/compute/core_zgeadd.c
index 4882a79cf..a7def06bb 100644
--- a/coreblas/compute/core_zgeadd.c
+++ b/coreblas/compute/core_zgeadd.c
@@ -88,7 +88,7 @@ int CORE_zgeadd(cham_trans_t trans, int M, int N,
 {
     int i, j;
 
-    if ((trans < ChamNoTrans) || (trans > ChamConjTrans))
+    if ( !isValidTrans( trans ) )
     {
         coreblas_error(1, "illegal value of trans");
         return -1;
@@ -122,36 +122,33 @@ int CORE_zgeadd(cham_trans_t trans, int M, int N,
                              0, 0, 1., beta, M, N, B, LDB );
     }
 
-    switch( trans ) {
+    if( trans == ChamNoTrans ) {
+        for (j=0; j<N; j++) {
+            for(i=0; i<M; i++, B++, A++) {
+                *B += alpha * (*A);
+            }
+            A += LDA-M;
+            B += LDB-M;
+        }
+    }
 #if defined(PRECISION_z) || defined(PRECISION_c)
-    case ChamConjTrans:
+    else if ( trans == ChamConjTrans ) {
         for (j=0; j<N; j++, A++) {
             for(i=0; i<M; i++, B++) {
                 *B += alpha * conj(A[LDA*i]);
             }
             B += LDB-M;
         }
-        break;
+    }
 #endif /* defined(PRECISION_z) || defined(PRECISION_c) */
-
-    case ChamTrans:
+    else {
         for (j=0; j<N; j++, A++) {
             for(i=0; i<M; i++, B++) {
                 *B += alpha * A[LDA*i];
             }
             B += LDB-M;
         }
-        break;
-
-    case ChamNoTrans:
-    default:
-        for (j=0; j<N; j++) {
-            for(i=0; i<M; i++, B++, A++) {
-                *B += alpha * (*A);
-            }
-            A += LDA-M;
-            B += LDB-M;
-        }
     }
+
     return CHAMELEON_SUCCESS;
 }
diff --git a/coreblas/compute/core_zlatro.c b/coreblas/compute/core_zlatro.c
index a139fd5d6..eb276e373 100644
--- a/coreblas/compute/core_zlatro.c
+++ b/coreblas/compute/core_zlatro.c
@@ -92,7 +92,7 @@ int CORE_zlatro(cham_uplo_t uplo, cham_trans_t trans,
         coreblas_error(1, "Illegal value of uplo");
         return -1;
     }
-    if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) {
+    if ( !isValidTrans( trans ) ) {
         coreblas_error(2, "Illegal value of trans");
         return -2;
     }
diff --git a/coreblas/compute/core_zpemv.c b/coreblas/compute/core_zpemv.c
index 0144020a7..236139d23 100644
--- a/coreblas/compute/core_zpemv.c
+++ b/coreblas/compute/core_zpemv.c
@@ -138,7 +138,7 @@ int CORE_zpemv(cham_trans_t trans, cham_store_t storev,
 
 
     /* Check input arguments */
-    if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) {
+    if ( !isValidTrans( trans ) ) {
         coreblas_error(1, "Illegal value of trans");
         return -1;
     }
diff --git a/coreblas/compute/core_ztradd.c b/coreblas/compute/core_ztradd.c
index 09af0f440..f87388f4d 100644
--- a/coreblas/compute/core_ztradd.c
+++ b/coreblas/compute/core_ztradd.c
@@ -113,7 +113,7 @@ int CORE_ztradd(cham_uplo_t uplo, cham_trans_t trans, int M, int N,
         return -1;
     }
 
-    if ((trans < ChamNoTrans) || (trans > ChamConjTrans))
+    if ( !isValidTrans( trans ) )
     {
         coreblas_error(2, "illegal value of trans");
         return -2;
@@ -153,45 +153,50 @@ int CORE_ztradd(cham_uplo_t uplo, cham_trans_t trans, int M, int N,
      * ChamLower
      */
     if (uplo == ChamLower) {
-        switch( trans ) {
+        if( trans == ChamNoTrans ) {
+            for (j=0; j<minMN; j++) {
+                for(i=j; i<M; i++, B++, A++) {
+                    *B += alpha * (*A);
+                }
+                B += LDB-M+j+1;
+                A += LDA-M+j+1;
+            }
+        }
 #if defined(PRECISION_z) || defined(PRECISION_c)
-        case ChamConjTrans:
+        else if ( trans == ChamConjTrans ) {
             for (j=0; j<minMN; j++, A++) {
                 for(i=j; i<M; i++, B++) {
                     *B += alpha * conj(A[LDA*i]);
                 }
                 B += LDB-M+j+1;
             }
-            break;
+        }
 #endif /* defined(PRECISION_z) || defined(PRECISION_c) */
-
-        case ChamTrans:
+        else {
             for (j=0; j<minMN; j++, A++) {
                 for(i=j; i<M; i++, B++) {
                     *B += alpha * A[LDA*i];
                 }
                 B += LDB-M+j+1;
             }
-            break;
-
-        case ChamNoTrans:
-        default:
-            for (j=0; j<minMN; j++) {
-                for(i=j; i<M; i++, B++, A++) {
-                    *B += alpha * (*A);
-                }
-                B += LDB-M+j+1;
-                A += LDA-M+j+1;
-            }
         }
     }
     /**
      * ChamUpper
      */
     else {
-        switch( trans ) {
+        if ( trans == ChamNoTrans ) {
+            for (j=0; j<N; j++) {
+                int mm = chameleon_min( j+1, M );
+                for(i=0; i<mm; i++, B++, A++) {
+                    *B += alpha * (*A);
+                }
+                B += LDB-mm;
+                A += LDA-mm;
+            }
+        }
 #if defined(PRECISION_z) || defined(PRECISION_c)
-        case ChamConjTrans:
+        else if ( trans == ChamConjTrans ) {
             for (j=0; j<N; j++, A++) {
                 int mm = chameleon_min( j+1, M );
                 for(i=0; i<mm; i++, B++) {
@@ -199,10 +204,9 @@ int CORE_ztradd(cham_uplo_t uplo, cham_trans_t trans, int M, int N,
                 }
                 B += LDB-mm;
             }
-            break;
+        }
 #endif /* defined(PRECISION_z) || defined(PRECISION_c) */
-
-        case ChamTrans:
+        else {
             for (j=0; j<N; j++, A++) {
                 int mm = chameleon_min( j+1, M );
                 for(i=0; i<mm; i++, B++) {
@@ -210,18 +214,6 @@ int CORE_ztradd(cham_uplo_t uplo, cham_trans_t trans, int M, int N,
                 }
                 B += LDB-mm;
             }
-            break;
-
-        case ChamNoTrans:
-        default:
-            for (j=0; j<N; j++) {
-                int mm = chameleon_min( j+1, M );
-                for(i=0; i<mm; i++, B++, A++) {
-                    *B += alpha * (*A);
-                }
-                B += LDB-mm;
-                A += LDA-mm;
-            }
         }
     }
     return CHAMELEON_SUCCESS;
diff --git a/include/chameleon/constants.h b/include/chameleon/constants.h
index b58d4bcc6..f6c8b15e2 100644
--- a/include/chameleon/constants.h
+++ b/include/chameleon/constants.h
@@ -63,6 +63,17 @@ typedef enum chameleon_trans_e {
     ChamConjTrans = 113  /**< Use conj(A^t) */
 } cham_trans_t;
 
+static inline int
+isValidTrans( cham_trans_t trans )
+{
+    if ( (trans >= ChamNoTrans) && (trans <= ChamConjTrans) ) {
+        return 1;
+    }
+    else {
+        return 0;
+    }
+}
+
 /**
  * @brief Upper/Lower part
  */
-- 
GitLab