From 2279f49a3b6e989f6ece8db633dc9fde168d2b90 Mon Sep 17 00:00:00 2001 From: Mathieu Faverge <mathieu.faverge@inria.fr> Date: Wed, 5 Apr 2017 17:11:12 +0200 Subject: [PATCH] Add cuda_zttmqr --- cudablas/compute/CMakeLists.txt | 1 + cudablas/compute/cuda_ztpmqrt.c | 16 ++-- cudablas/compute/cuda_zttmqr.c | 151 ++++++++++++++++++++++++++++++++ cudablas/include/cudablas_z.h | 1 + 4 files changed, 160 insertions(+), 9 deletions(-) create mode 100644 cudablas/compute/cuda_zttmqr.c diff --git a/cudablas/compute/CMakeLists.txt b/cudablas/compute/CMakeLists.txt index 40686c92f..06d25bd47 100644 --- a/cudablas/compute/CMakeLists.txt +++ b/cudablas/compute/CMakeLists.txt @@ -43,6 +43,7 @@ set(ZSRC cuda_ztrsm.c cuda_ztsmlq.c cuda_ztsmqr.c + cuda_zttmqr.c cuda_zunmlqt.c cuda_zunmqrt.c ) diff --git a/cudablas/compute/cuda_ztpmqrt.c b/cudablas/compute/cuda_ztpmqrt.c index b9f6afac4..640f9fdc8 100644 --- a/cudablas/compute/cuda_ztpmqrt.c +++ b/cudablas/compute/cuda_ztpmqrt.c @@ -48,14 +48,14 @@ CUDA_ztpmqrt( MORSE_enum side, MORSE_enum trans, n1 = N; ldwork = IB; ldworkc = M; - ws = K * n1; + ws = IB * n1; } else { m1 = M; n1 = K; ldwork = m1; ldworkc = IB; - ws = m1 * K; + ws = IB * m1; } /* TS case */ @@ -67,16 +67,14 @@ CUDA_ztpmqrt( MORSE_enum side, MORSE_enum trans, } /* TT case */ else if( L == M ) { - cudablas_error(-6, "TTMQRT not available on GPU yet\n" ); - return -6; - /* CUDA_zttmqr( side, trans, m1, n1, M, N, K, IB, */ - /* A, LDA, B, LDB, V, LDV, T, LDT, */ - /* WORK, ldwork ); */ + CUDA_zttmqr( side, trans, m1, n1, M, N, K, IB, + A, LDA, B, LDB, V, LDV, T, LDT, + WORK, ldwork, WORK + ws, ldworkc, + CUBLAS_STREAM_VALUE ); } else { - cudablas_error(-6, "TPMQRT not available on GPU yet\n" ); + cudablas_error(-6, "TPMQRT not available on GPU for general cases yet\n" ); return -6; - //LAPACKE_ztpmqrt_work( LAPACK_COL_MAJOR, M, N, K, L, IB, V, LDV, T, LDT, A, LDA, B, LDB, WORK ); } return MORSE_SUCCESS; diff --git a/cudablas/compute/cuda_zttmqr.c 
b/cudablas/compute/cuda_zttmqr.c
new file mode 100644
index 000000000..107c7dec3
--- /dev/null
+++ b/cudablas/compute/cuda_zttmqr.c
@@ -0,0 +1,171 @@
+/**
+ *
+ * @copyright (c) 2009-2014 The University of Tennessee and The University
+ *                          of Tennessee Research Foundation.
+ *                          All rights reserved.
+ * @copyright (c) 2012-2017 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                          Univ. Bordeaux. All rights reserved.
+ **/
+
+/**
+ *
+ * @file cuda_zttmqr.c
+ *
+ *  MORSE cudablas kernel
+ *  MORSE is a software package provided by Univ. of Tennessee,
+ *  Univ. of California Berkeley and Univ. of Colorado Denver,
+ *  and INRIA Bordeaux Sud-Ouest
+ *
+ * @author Florent Pruvost
+ * @author Mathieu Faverge
+ * @date 2015-09-16
+ * @precisions normal z -> c d s
+ *
+ **/
+#include "cudablas/include/cudablas.h"
+#include "cudablas/include/cudablas_z.h"
+
+/**
+ * Walk the K columns of V and T in panels of at most IB columns and call
+ * CUDA_zparfb on each panel to update A1 and A2 (apply H or H').
+ *
+ * With side == MorseLeft each panel offsets A1 by i elements (ic = i); with
+ * side == MorseRight it offsets A1 by LDA1*i elements (jc = i). A2, WORK and
+ * WORKC are always passed from their base address.
+ *
+ * Panels are visited in increasing order for (MorseLeft, trans != MorseNoTrans)
+ * and (MorseRight, MorseNoTrans), and in decreasing order otherwise.
+ *
+ * @return MORSE_SUCCESS on success, including the quick return taken when any
+ *         of M1, N1, M2, N2, K or IB is zero. Otherwise a negative code for the
+ *         first argument that fails validation: -1 side, -2 trans, -3 M1,
+ *         -4 N1, -5 M2, -6 N2, -7 K, -8 IB, -10 LDA1, -12 LDA2, -14 LDV,
+ *         -16 LDT, -18 LDWORK. LDWORKC is not validated here, and any value
+ *         returned by CUDA_zparfb is discarded.
+ **/
+int CUDA_zttmqr(
+    MORSE_enum side, MORSE_enum trans,
+    int M1, int N1,
+    int M2, int N2,
+    int K, int IB,
+    cuDoubleComplex *A1, int LDA1,
+    cuDoubleComplex *A2, int LDA2,
+    const cuDoubleComplex *V, int LDV,
+    const cuDoubleComplex *T, int LDT,
+    cuDoubleComplex *WORK, int LDWORK,
+    cuDoubleComplex *WORKC, int LDWORKC,
+    CUBLAS_STREAM_PARAM)
+{
+    int i, i1, i3;
+    int NQ, NW;
+    int kb, l;
+    int ic = 0;
+    int jc = 0;
+    int mi1 = M1;
+    int mi2 = M2;
+    int ni1 = N1;
+    int ni2 = N2;
+
+    /* Check input arguments */
+    if ((side != MorseLeft) && (side != MorseRight)) {
+        return -1;
+    }
+
+    /* NQ is the order of Q */
+    if (side == MorseLeft) {
+        NQ = M2;
+        NW = IB;
+    }
+    else {
+        NQ = N2;
+        NW = M1;
+    }
+
+    if ((trans != MorseNoTrans) && (trans != MorseConjTrans)) {
+        return -2;
+    }
+    if (M1 < 0) {
+        return -3;
+    }
+    if (N1 < 0) {
+        return -4;
+    }
+    if ( (M2 < 0) ||
+         ( (M2 != M1) && (side == MorseRight) ) ){
+        return -5;
+    }
+    if ( (N2 < 0) ||
+         ( (N2 != N1) && (side == MorseLeft) ) ){
+        return -6;
+    }
+    if ((K < 0) ||
+        ( (side == MorseLeft) && (K > M1) ) ||
+        ( (side == MorseRight) && (K > N1) ) ) {
+        return -7;
+    }
+    if (IB < 0) {
+        return -8;
+    }
+    if (LDA1 < chameleon_max(1,M1)){
+        return -10;
+    }
+    if (LDA2 < chameleon_max(1,M2)){
+        return -12;
+    }
+    if (LDV < chameleon_max(1,NQ)){
+        return -14;
+    }
+    if (LDT < chameleon_max(1,IB)){
+        return -16;
+    }
+    if (LDWORK < chameleon_max(1,NW)){
+        return -18;
+    }
+
+    /* Quick return */
+    if ((M1 == 0) || (N1 == 0) || (M2 == 0) || (N2 == 0) || (K == 0) || (IB == 0))
+        return MORSE_SUCCESS;
+
+    /* Choose the first panel (i1) and the signed step (i3) of the traversal */
+    if (((side == MorseLeft) && (trans != MorseNoTrans))
+        || ((side == MorseRight) && (trans == MorseNoTrans))) {
+        i1 = 0;
+        i3 = IB;
+    }
+    else {
+        i1 = ((K-1) / IB)*IB;
+        i3 = -IB;
+    }
+
+    for(i = i1; (i > -1) && (i < K); i += i3) {
+        /* Panel width: IB, clipped to the K-i columns that remain */
+        kb = chameleon_min(IB, K-i);
+
+        if (side == MorseLeft) {
+            mi1 = kb;
+            mi2 = chameleon_min(i+kb, M2);
+            l = chameleon_min(kb, chameleon_max(0, M2-i));
+            ic = i;
+        }
+        else {
+            ni1 = kb;
+            ni2 = chameleon_min(i+kb, N2);
+            l = chameleon_min(kb, chameleon_max(0, N2-i));
+            jc = i;
+        }
+
+        /*
+         * Apply H or H' (NOTE: CORE_zparfb used to be CORE_zttrfb)
+         */
+        CUDA_zparfb(
+            side, trans, MorseForward, MorseColumnwise,
+            mi1, ni1, mi2, ni2, kb, l,
+            A1 + LDA1*jc+ic, LDA1,
+            A2, LDA2,
+            V + LDV*i, LDV,
+            T + LDT*i, LDT,
+            WORK, LDWORK,
+            WORKC, LDWORKC, CUBLAS_STREAM_VALUE );
+    }
+    return MORSE_SUCCESS;
+}
diff --git a/cudablas/include/cudablas_z.h b/cudablas/include/cudablas_z.h
index e73912750..911d1ef93 100644
--- a/cudablas/include/cudablas_z.h
+++ b/cudablas/include/cudablas_z.h
@@ -51,6 +51,7 @@ int CUDA_ztrmm( MORSE_enum side, MORSE_enum uplo, MORSE_enum transa, MORSE_enum
 int CUDA_ztrsm( MORSE_enum side, MORSE_enum uplo, MORSE_enum transa, MORSE_enum diag, int m, int n, cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb, CUBLAS_STREAM_PARAM);
 int CUDA_ztsmlq( MORSE_enum side, MORSE_enum trans, int M1, int N1, int M2, int N2, int K, int IB, cuDoubleComplex *A1, int LDA1, cuDoubleComplex *A2, int LDA2, const cuDoubleComplex *V, int LDV, const cuDoubleComplex *T, int LDT, cuDoubleComplex *WORK,
int LDWORK, cuDoubleComplex *WORKC, int LDWORKC, CUBLAS_STREAM_PARAM); int CUDA_ztsmqr( MORSE_enum side, MORSE_enum trans, int M1, int N1, int M2, int N2, int K, int IB, cuDoubleComplex *A1, int LDA1, cuDoubleComplex *A2, int LDA2, const cuDoubleComplex *V, int LDV, const cuDoubleComplex *T, int LDT, cuDoubleComplex *WORK, int LDWORK, cuDoubleComplex *WORKC, int LDWORKC, CUBLAS_STREAM_PARAM); +int CUDA_zttmqr( MORSE_enum side, MORSE_enum trans, int M1, int N1, int M2, int N2, int K, int IB, cuDoubleComplex *A1, int LDA1, cuDoubleComplex *A2, int LDA2, const cuDoubleComplex *V, int LDV, const cuDoubleComplex *T, int LDT, cuDoubleComplex *WORK, int LDWORK, cuDoubleComplex *WORKC, int LDWORKC, CUBLAS_STREAM_PARAM); int CUDA_zunmlqt(MORSE_enum side, MORSE_enum trans, int M, int N, int K, int IB, const cuDoubleComplex *A, int LDA, const cuDoubleComplex *T, int LDT, cuDoubleComplex *C, int LDC, cuDoubleComplex *WORK, int LDWORK, CUBLAS_STREAM_PARAM ); int CUDA_zunmqrt(MORSE_enum side, MORSE_enum trans, int M, int N, int K, int IB, const cuDoubleComplex *A, int LDA, const cuDoubleComplex *T, int LDT, cuDoubleComplex *C, int LDC, cuDoubleComplex *WORK, int LDWORK, CUBLAS_STREAM_PARAM ); -- GitLab