From 16a70a7eb1e400aff2f84f89af3fac8a908378ca Mon Sep 17 00:00:00 2001 From: Mathieu Faverge <mathieu.faverge@inria.fr> Date: Tue, 5 Nov 2019 00:10:03 +0100 Subject: [PATCH] Add a function to check if trans parameters are valide without being affected by the conversion --- compute/zgeadd.c | 4 +-- compute/zgemm.c | 8 ++--- compute/ztradd.c | 4 +-- compute/ztrmm.c | 29 +++++++--------- compute/ztrsm.c | 24 +++++++------- coreblas/compute/core_zgeadd.c | 31 ++++++++---------- coreblas/compute/core_zlatro.c | 2 +- coreblas/compute/core_zpemv.c | 2 +- coreblas/compute/core_ztradd.c | 60 +++++++++++++++------------------- include/chameleon/constants.h | 11 +++++++ 10 files changed, 85 insertions(+), 90 deletions(-) diff --git a/compute/zgeadd.c b/compute/zgeadd.c index c98f82c73..6b2c6aa5a 100644 --- a/compute/zgeadd.c +++ b/compute/zgeadd.c @@ -105,7 +105,7 @@ int CHAMELEON_zgeadd( cham_trans_t trans, int M, int N, } /* Check input arguments */ - if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) { + if ( !isValidTrans( trans ) ) { chameleon_error("CHAMELEON_zgeadd", "illegal value of trans"); return -1; } @@ -317,7 +317,7 @@ int CHAMELEON_zgeadd_Tile_Async( cham_trans_t trans, return chameleon_request_fail(sequence, request, CHAMELEON_ERR_ILLEGAL_VALUE); } /* Check input arguments */ - if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) { + if ( !isValidTrans( trans ) ) { chameleon_error("CHAMELEON_zgeadd_Tile_Async", "illegal value of trans"); return chameleon_request_fail(sequence, request, -1); } diff --git a/compute/zgemm.c b/compute/zgemm.c index 6de404567..8fc4e8f89 100644 --- a/compute/zgemm.c +++ b/compute/zgemm.c @@ -146,11 +146,11 @@ int CHAMELEON_zgemm( cham_trans_t transA, cham_trans_t transB, int M, int N, int } /* Check input arguments */ - if ((transA < ChamNoTrans) || (transA > ChamConjTrans)) { + if ( !isValidTrans( transA ) ) { chameleon_error("CHAMELEON_zgemm", "illegal value of transA"); return -1; } - if ((transB < ChamNoTrans) || (transB > ChamConjTrans)) { + if ( !isValidTrans( transB ) ) { chameleon_error("CHAMELEON_zgemm", "illegal value of transB"); return -2; } @@ -394,11 +394,11 @@ int CHAMELEON_zgemm_Tile_Async( cham_trans_t transA, cham_trans_t transB, return chameleon_request_fail(sequence, request, CHAMELEON_ERR_ILLEGAL_VALUE); } /* Check input arguments */ - if ((transA < ChamNoTrans) || (transA > ChamConjTrans)) { + if ( !isValidTrans( transA ) ) { chameleon_error("CHAMELEON_zgemm_Tile_Async", "illegal value of transA"); return chameleon_request_fail(sequence, request, -1); } - if ((transB < ChamNoTrans) || (transB > ChamConjTrans)) { + if ( !isValidTrans( transB ) ) { chameleon_error("CHAMELEON_zgemm_Tile_Async", "illegal value of transB"); return chameleon_request_fail(sequence, request, -2); } diff --git a/compute/ztradd.c b/compute/ztradd.c index 02cd629f3..0b1700a15 100644 --- a/compute/ztradd.c +++ b/compute/ztradd.c @@ -115,7 +115,7 @@ int CHAMELEON_ztradd( cham_uplo_t uplo, cham_trans_t trans, int M, int N, chameleon_error("CHAMELEON_ztradd", "illegal value of uplo"); return -1; } - if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) { + if ( !isValidTrans( trans ) ) { chameleon_error("CHAMELEON_ztradd", "illegal value of trans"); return -2; } @@ -333,7 +333,7 @@ int CHAMELEON_ztradd_Tile_Async( cham_uplo_t uplo, cham_trans_t trans, return chameleon_request_fail(sequence, request, CHAMELEON_ERR_ILLEGAL_VALUE); } /* Check input arguments */ - if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) { + if ( !isValidTrans( trans ) ) { chameleon_error("CHAMELEON_ztradd_Tile_Async", "illegal value of trans"); return chameleon_request_fail(sequence, request, -1); } diff --git a/compute/ztrmm.c b/compute/ztrmm.c index 8d1b203fb..09399a3b5 100644 --- a/compute/ztrmm.c +++ b/compute/ztrmm.c @@ -42,7 +42,7 @@ * = ChamUpper: Upper triangle of A is stored; * = ChamLower: Lower triangle of A is stored. * - * @param[in] transA + * @param[in] trans * Specifies whether the matrix A is transposed, not transposed or conjugate transposed: * = ChamNoTrans: A is transposed; * = ChamTrans: A is not transposed; @@ -97,7 +97,7 @@ * */ int CHAMELEON_ztrmm( cham_side_t side, cham_uplo_t uplo, - cham_trans_t transA, cham_diag_t diag, + cham_trans_t trans, cham_diag_t diag, int N, int NRHS, CHAMELEON_Complex64_t alpha, CHAMELEON_Complex64_t *A, int LDA, CHAMELEON_Complex64_t *B, int LDB ) @@ -131,13 +131,8 @@ int CHAMELEON_ztrmm( cham_side_t side, cham_uplo_t uplo, chameleon_error("CHAMELEON_ztrmm", "illegal value of uplo"); return -2; } - if ((transA != ChamNoTrans) && -#if defined(PRECISION_z) || defined(PRECISION_c) - (transA != ChamConjTrans) && -#endif - (transA != ChamTrans) ) - { - chameleon_error("CHAMELEON_ztrmm", "illegal value of transA"); + if ( !isValidTrans( trans ) ) { + chameleon_error("CHAMELEON_ztrmm", "illegal value of trans"); return -3; } if ((diag != ChamUnit) && (diag != ChamNonUnit)) { @@ -183,7 +178,7 @@ int CHAMELEON_ztrmm( cham_side_t side, cham_uplo_t uplo, B, NB, NB, LDB, NRHS, N, NRHS, sequence, &request ); /* Call the tile interface */ - CHAMELEON_ztrmm_Tile_Async( side, uplo, transA, diag, alpha, &descAt, &descBt, sequence, &request ); + CHAMELEON_ztrmm_Tile_Async( side, uplo, trans, diag, alpha, &descAt, &descBt, sequence, &request ); /* Submit the matrix conversion back */ chameleon_ztile2lap( chamctxt, &descAl, &descAt, @@ -224,7 +219,7 @@ int CHAMELEON_ztrmm( cham_side_t side, cham_uplo_t uplo, * = ChamUpper: Upper triangle of A is stored; * = ChamLower: Lower triangle of A is stored. * - * @param[in] transA + * @param[in] trans * Specifies whether the matrix A is transposed, not transposed or conjugate transposed: * = ChamNoTrans: A is transposed; * = ChamTrans: A is not transposed; @@ -264,7 +259,7 @@ int CHAMELEON_ztrmm( cham_side_t side, cham_uplo_t uplo, * */ int CHAMELEON_ztrmm_Tile( cham_side_t side, cham_uplo_t uplo, - cham_trans_t transA, cham_diag_t diag, + cham_trans_t trans, cham_diag_t diag, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B ) { CHAM_context_t *chamctxt; @@ -279,7 +274,7 @@ int CHAMELEON_ztrmm_Tile( cham_side_t side, cham_uplo_t uplo, } chameleon_sequence_create( chamctxt, &sequence ); - CHAMELEON_ztrmm_Tile_Async(side, uplo, transA, diag, alpha, A, B, sequence, &request ); + CHAMELEON_ztrmm_Tile_Async(side, uplo, trans, diag, alpha, A, B, sequence, &request ); CHAMELEON_Desc_Flush( A, sequence ); CHAMELEON_Desc_Flush( B, sequence ); @@ -319,7 +314,7 @@ int CHAMELEON_ztrmm_Tile( cham_side_t side, cham_uplo_t uplo, * */ int CHAMELEON_ztrmm_Tile_Async( cham_side_t side, cham_uplo_t uplo, - cham_trans_t transA, cham_diag_t diag, + cham_trans_t trans, cham_diag_t diag, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request ) { @@ -368,8 +363,8 @@ int CHAMELEON_ztrmm_Tile_Async( cham_side_t side, cham_uplo_t uplo, chameleon_error("CHAMELEON_ztrmm_Tile", "illegal value of uplo"); return chameleon_request_fail(sequence, request, -2); } - if ((transA < ChamNoTrans) || (transA > ChamConjTrans)) { - chameleon_error("CHAMELEON_ztrmm_Tile", "illegal value of transA"); + if ( !isValidTrans( trans ) ) { + chameleon_error("CHAMELEON_ztrmm_Tile", "illegal value of trans"); return chameleon_request_fail(sequence, request, -3); } if ((diag != ChamUnit) && (diag != ChamNonUnit)) { @@ -378,7 +373,7 @@ int CHAMELEON_ztrmm_Tile_Async( cham_side_t side, cham_uplo_t uplo, } /* Quick return */ - chameleon_pztrmm( side, uplo, transA, diag, alpha, A, B, sequence, request ); + chameleon_pztrmm( side, uplo, trans, diag, alpha, A, B, sequence, request ); return CHAMELEON_SUCCESS; } diff --git a/compute/ztrsm.c b/compute/ztrsm.c index a0989cef6..cb9edbf88 100644 --- a/compute/ztrsm.c +++ b/compute/ztrsm.c @@ -43,7 +43,7 @@ * = ChamUpper: Upper triangle of A is stored; * = ChamLower: Lower triangle of A is stored. * - * @param[in] transA + * @param[in] trans * Specifies whether the matrix A is transposed, not transposed or conjugate transposed: * = ChamNoTrans: A is transposed; * = ChamTrans: A is not transposed; @@ -98,7 +98,7 @@ * */ int CHAMELEON_ztrsm( cham_side_t side, cham_uplo_t uplo, - cham_trans_t transA, cham_diag_t diag, + cham_trans_t trans, cham_diag_t diag, int N, int NRHS, CHAMELEON_Complex64_t alpha, CHAMELEON_Complex64_t *A, int LDA, CHAMELEON_Complex64_t *B, int LDB ) @@ -132,8 +132,8 @@ int CHAMELEON_ztrsm( cham_side_t side, cham_uplo_t uplo, chameleon_error("CHAMELEON_ztrsm", "illegal value of uplo"); return -2; } - if (((transA < ChamNoTrans) || (transA > ChamConjTrans)) ) { - chameleon_error("CHAMELEON_ztrsm", "illegal value of transA"); + if ( !isValidTrans( trans ) ) { + chameleon_error("CHAMELEON_ztrsm", "illegal value of trans"); return -3; } if ((diag != ChamUnit) && (diag != ChamNonUnit)) { @@ -179,7 +179,7 @@ int CHAMELEON_ztrsm( cham_side_t side, cham_uplo_t uplo, B, NB, NB, LDB, NRHS, N, NRHS, sequence, &request ); /* Call the tile interface */ - CHAMELEON_ztrsm_Tile_Async( side, uplo, transA, diag, alpha, &descAt, &descBt, sequence, &request ); + CHAMELEON_ztrsm_Tile_Async( side, uplo, trans, diag, alpha, &descAt, &descBt, sequence, &request ); /* Submit the matrix conversion back */ chameleon_ztile2lap( chamctxt, &descAl, &descAt, @@ -220,7 +220,7 @@ int CHAMELEON_ztrsm( cham_side_t side, cham_uplo_t uplo, * = ChamUpper: Upper triangle of A is stored; * = ChamLower: Lower triangle of A is stored. * - * @param[in] transA + * @param[in] trans * Specifies whether the matrix A is transposed, not transposed or conjugate transposed: * = ChamNoTrans: A is transposed; * = ChamTrans: A is not transposed; @@ -260,7 +260,7 @@ int CHAMELEON_ztrsm( cham_side_t side, cham_uplo_t uplo, * */ int CHAMELEON_ztrsm_Tile( cham_side_t side, cham_uplo_t uplo, - cham_trans_t transA, cham_diag_t diag, + cham_trans_t trans, cham_diag_t diag, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B ) { CHAM_context_t *chamctxt; @@ -275,7 +275,7 @@ int CHAMELEON_ztrsm_Tile( cham_side_t side, cham_uplo_t uplo, } chameleon_sequence_create( chamctxt, &sequence ); - CHAMELEON_ztrsm_Tile_Async(side, uplo, transA, diag, alpha, A, B, sequence, &request ); + CHAMELEON_ztrsm_Tile_Async(side, uplo, trans, diag, alpha, A, B, sequence, &request ); CHAMELEON_Desc_Flush( A, sequence ); CHAMELEON_Desc_Flush( B, sequence ); @@ -315,7 +315,7 @@ int CHAMELEON_ztrsm_Tile( cham_side_t side, cham_uplo_t uplo, * */ int CHAMELEON_ztrsm_Tile_Async( cham_side_t side, cham_uplo_t uplo, - cham_trans_t transA, cham_diag_t diag, + cham_trans_t trans, cham_diag_t diag, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request ) { @@ -364,8 +364,8 @@ int CHAMELEON_ztrsm_Tile_Async( cham_side_t side, cham_uplo_t uplo, chameleon_error("CHAMELEON_ztrsm_Tile", "illegal value of uplo"); return chameleon_request_fail(sequence, request, -2); } - if ((transA < ChamNoTrans) || (transA > ChamConjTrans)) { - chameleon_error("CHAMELEON_ztrsm_Tile", "illegal value of transA"); + if ( !isValidTrans( trans ) ) { + chameleon_error("CHAMELEON_ztrsm_Tile", "illegal value of trans"); return chameleon_request_fail(sequence, request, -3); } if ((diag != ChamUnit) && (diag != ChamNonUnit)) { @@ -374,7 +374,7 @@ int CHAMELEON_ztrsm_Tile_Async( cham_side_t side, cham_uplo_t uplo, } /* Quick return */ - chameleon_pztrsm( side, uplo, transA, diag, alpha, A, B, sequence, request ); + chameleon_pztrsm( side, uplo, trans, diag, alpha, A, B, sequence, request ); return CHAMELEON_SUCCESS; } diff --git a/coreblas/compute/core_zgeadd.c b/coreblas/compute/core_zgeadd.c index 4882a79cf..a7def06bb 100644 --- a/coreblas/compute/core_zgeadd.c +++ b/coreblas/compute/core_zgeadd.c @@ -88,7 +88,7 @@ int CORE_zgeadd(cham_trans_t trans, int M, int N, { int i, j; - if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) + if ( !isValidTrans( trans ) ) { coreblas_error(1, "illegal value of trans"); return -1; @@ -122,36 +122,33 @@ int CORE_zgeadd(cham_trans_t trans, int M, int N, 0, 0, 1., beta, M, N, B, LDB ); } - switch( trans ) { + if( trans == ChamNoTrans ) { + for (j=0; j<N; j++) { + for(i=0; i<M; i++, B++, A++) { + *B += alpha * (*A); + } + A += LDA-M; + B += LDB-M; + } + } #if defined(PRECISION_z) || defined(PRECISION_c) - case ChamConjTrans: + else if ( trans == ChamConjTrans ) { for (j=0; j<N; j++, A++) { for(i=0; i<M; i++, B++) { *B += alpha * conj(A[LDA*i]); } B += LDB-M; } - break; + } #endif /* defined(PRECISION_z) || defined(PRECISION_c) */ - - case ChamTrans: + else { for (j=0; j<N; j++, A++) { for(i=0; i<M; i++, B++) { *B += alpha * A[LDA*i]; } B += LDB-M; } - break; - - case ChamNoTrans: - default: - for (j=0; j<N; j++) { - for(i=0; i<M; i++, B++, A++) { - *B += alpha * (*A); - } - A += LDA-M; - B += LDB-M; - } } + return CHAMELEON_SUCCESS; } diff --git a/coreblas/compute/core_zlatro.c b/coreblas/compute/core_zlatro.c index a139fd5d6..eb276e373 100644 --- a/coreblas/compute/core_zlatro.c +++ b/coreblas/compute/core_zlatro.c @@ -92,7 +92,7 @@ int CORE_zlatro(cham_uplo_t uplo, cham_trans_t trans, coreblas_error(1, "Illegal value of uplo"); return -1; } - if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) { + if ( !isValidTrans( trans ) ) { coreblas_error(2, "Illegal value of trans"); return -2; } diff --git a/coreblas/compute/core_zpemv.c b/coreblas/compute/core_zpemv.c index 0144020a7..236139d23 100644 --- a/coreblas/compute/core_zpemv.c +++ b/coreblas/compute/core_zpemv.c @@ -138,7 +138,7 @@ int CORE_zpemv(cham_trans_t trans, cham_store_t storev, /* Check input arguments */ - if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) { + if ( !isValidTrans( trans ) ) { coreblas_error(1, "Illegal value of trans"); return -1; } diff --git a/coreblas/compute/core_ztradd.c b/coreblas/compute/core_ztradd.c index 09af0f440..f87388f4d 100644 --- a/coreblas/compute/core_ztradd.c +++ b/coreblas/compute/core_ztradd.c @@ -113,7 +113,7 @@ int CORE_ztradd(cham_uplo_t uplo, cham_trans_t trans, int M, int N, return -1; } - if ((trans < ChamNoTrans) || (trans > ChamConjTrans)) + if ( !isValidTrans( trans ) ) { coreblas_error(2, "illegal value of trans"); return -2; @@ -153,45 +153,50 @@ int CORE_ztradd(cham_uplo_t uplo, cham_trans_t trans, int M, int N, * ChamLower */ if (uplo == ChamLower) { - switch( trans ) { + if( trans == ChamNoTrans ) { + for (j=0; j<minMN; j++) { + for(i=j; i<M; i++, B++, A++) { + *B += alpha * (*A); + } + B += LDB-M+j+1; + A += LDA-M+j+1; + } + } #if defined(PRECISION_z) || defined(PRECISION_c) - case ChamConjTrans: + else if ( trans == ChamConjTrans ) { for (j=0; j<minMN; j++, A++) { for(i=j; i<M; i++, B++) { *B += alpha * conj(A[LDA*i]); } B += LDB-M+j+1; } - break; + } #endif /* defined(PRECISION_z) || defined(PRECISION_c) */ - - case ChamTrans: + else { for (j=0; j<minMN; j++, A++) { for(i=j; i<M; i++, B++) { *B += alpha * A[LDA*i]; } B += LDB-M+j+1; } - break; - - case ChamNoTrans: - default: - for (j=0; j<minMN; j++) { - for(i=j; i<M; i++, B++, A++) { - *B += alpha * (*A); - } - B += LDB-M+j+1; - A += LDA-M+j+1; - } } } /** * ChamUpper */ else { - switch( trans ) { + if ( trans == ChamNoTrans ) { + for (j=0; j<N; j++) { + int mm = chameleon_min( j+1, M ); + for(i=0; i<mm; i++, B++, A++) { + *B += alpha * (*A); + } + B += LDB-mm; + A += LDA-mm; + } + } #if defined(PRECISION_z) || defined(PRECISION_c) - case ChamConjTrans: + else if ( trans == ChamConjTrans ) { for (j=0; j<N; j++, A++) { int mm = chameleon_min( j+1, M ); for(i=0; i<mm; i++, B++) { @@ -199,10 +204,9 @@ int CORE_ztradd(cham_uplo_t uplo, cham_trans_t trans, int M, int N, } B += LDB-mm; } - break; + } #endif /* defined(PRECISION_z) || defined(PRECISION_c) */ - - case ChamTrans: + else { for (j=0; j<N; j++, A++) { int mm = chameleon_min( j+1, M ); for(i=0; i<mm; i++, B++) { @@ -210,18 +214,6 @@ int CORE_ztradd(cham_uplo_t uplo, cham_trans_t trans, int M, int N, } B += LDB-mm; } - break; - - case ChamNoTrans: - default: - for (j=0; j<N; j++) { - int mm = chameleon_min( j+1, M ); - for(i=0; i<mm; i++, B++, A++) { - *B += alpha * (*A); - } - B += LDB-mm; - A += LDA-mm; - } } } return CHAMELEON_SUCCESS; diff --git a/include/chameleon/constants.h b/include/chameleon/constants.h index b58d4bcc6..f6c8b15e2 100644 --- a/include/chameleon/constants.h +++ b/include/chameleon/constants.h @@ -63,6 +63,17 @@ typedef enum chameleon_trans_e { ChamConjTrans = 113 /**< Use conj(A^t) */ } cham_trans_t; +static inline int +isValidTrans( cham_trans_t trans ) +{ + if ( (trans >= ChamNoTrans) && (trans <= ChamConjTrans) ) { + return 1; + } + else { + return 0; + } +} + /** * @brief Upper/Lower part */ -- GitLab