From 2279f49a3b6e989f6ece8db633dc9fde168d2b90 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Wed, 5 Apr 2017 17:11:12 +0200
Subject: [PATCH] Add cuda_zttmqr

---
 cudablas/compute/CMakeLists.txt |   1 +
 cudablas/compute/cuda_ztpmqrt.c |  16 ++--
 cudablas/compute/cuda_zttmqr.c  | 151 ++++++++++++++++++++++++++++++++
 cudablas/include/cudablas_z.h   |   1 +
 4 files changed, 160 insertions(+), 9 deletions(-)
 create mode 100644 cudablas/compute/cuda_zttmqr.c

diff --git a/cudablas/compute/CMakeLists.txt b/cudablas/compute/CMakeLists.txt
index 40686c92f..06d25bd47 100644
--- a/cudablas/compute/CMakeLists.txt
+++ b/cudablas/compute/CMakeLists.txt
@@ -43,6 +43,7 @@ set(ZSRC
     cuda_ztrsm.c
     cuda_ztsmlq.c
     cuda_ztsmqr.c
+    cuda_zttmqr.c
     cuda_zunmlqt.c
     cuda_zunmqrt.c
     )
diff --git a/cudablas/compute/cuda_ztpmqrt.c b/cudablas/compute/cuda_ztpmqrt.c
index b9f6afac4..640f9fdc8 100644
--- a/cudablas/compute/cuda_ztpmqrt.c
+++ b/cudablas/compute/cuda_ztpmqrt.c
@@ -48,14 +48,14 @@ CUDA_ztpmqrt( MORSE_enum side, MORSE_enum trans,
         n1 = N;
         ldwork  = IB;
         ldworkc = M;
-        ws = K * n1;
+        ws = IB * n1;
     }
     else {
         m1 = M;
         n1 = K;
         ldwork  = m1;
         ldworkc = IB;
-        ws = m1 * K;
+        ws = IB * m1;
     }
 
     /* TS case */
@@ -67,16 +67,14 @@ CUDA_ztpmqrt( MORSE_enum side, MORSE_enum trans,
     }
     /* TT case */
     else  if( L == M ) {
-        cudablas_error(-6, "TTMQRT not available on GPU yet\n" );
-        return -6;
-        /* CUDA_zttmqr( side, trans, m1, n1, M, N, K, IB, */
-        /*              A, LDA, B, LDB, V, LDV, T, LDT, */
-        /*              WORK, ldwork ); */
+        CUDA_zttmqr( side, trans, m1, n1, M, N, K, IB,
+                     A, LDA, B, LDB, V, LDV, T, LDT,
+                     WORK, ldwork, WORK + ws, ldworkc,
+                     CUBLAS_STREAM_VALUE );
     }
     else {
-        cudablas_error(-6, "TPMQRT not available on GPU yet\n" );
+        cudablas_error(-6, "TPMQRT not available on GPU for general cases yet\n" );
         return -6;
-        //LAPACKE_ztpmqrt_work( LAPACK_COL_MAJOR, M, N, K, L, IB, V, LDV, T, LDT, A, LDA, B, LDB, WORK );
     }
 
     return MORSE_SUCCESS;
diff --git a/cudablas/compute/cuda_zttmqr.c b/cudablas/compute/cuda_zttmqr.c
new file mode 100644
index 000000000..107c7dec3
--- /dev/null
+++ b/cudablas/compute/cuda_zttmqr.c
@@ -0,0 +1,151 @@
+/**
+ *
+ * @copyright (c) 2009-2014 The University of Tennessee and The University
+ *                          of Tennessee Research Foundation.
+ *                          All rights reserved.
+ * @copyright (c) 2012-2017 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                          Univ. Bordeaux. All rights reserved.
+ **/
+
+/**
+ *
+ * @file cuda_zttmqr.c
+ *
+ *  MORSE cudablas kernel
+ *  MORSE is a software package provided by Univ. of Tennessee,
+ *  Univ. of California Berkeley and Univ. of Colorado Denver,
+ *  and INRIA Bordeaux Sud-Ouest
+ *
+ * @author Florent Pruvost
+ * @author Mathieu Faverge
+ * @date 2015-09-16
+ * @precisions normal z -> c d s
+ *
+ **/
+#include "cudablas/include/cudablas.h"
+#include "cudablas/include/cudablas_z.h"
+
+int CUDA_zttmqr(
+        MORSE_enum side, MORSE_enum trans,
+        int M1, int N1,
+        int M2, int N2,
+        int K, int IB,
+              cuDoubleComplex *A1,    int LDA1,
+              cuDoubleComplex *A2,    int LDA2,
+        const cuDoubleComplex *V,     int LDV,
+        const cuDoubleComplex *T,     int LDT,
+              cuDoubleComplex *WORK,  int LDWORK,
+              cuDoubleComplex *WORKC, int LDWORKC,
+        CUBLAS_STREAM_PARAM)
+{
+    int i, i1, i3;
+    int NQ, NW;
+    int kb;
+    int ic = 0;
+    int jc = 0;
+    int mi1 = M1;
+    int mi2 = M2;
+    int ni1 = N1;
+    int ni2 = N2;
+
+    /* Check input arguments */
+    if ((side != MorseLeft) && (side != MorseRight)) {
+        return -1;
+    }
+
+    /* NQ is the order of Q */
+    if (side == MorseLeft) {
+        NQ = M2;
+        NW = IB;
+    }
+    else {
+        NQ = N2;
+        NW = M1;
+    }
+
+    if ((trans != MorseNoTrans) && (trans != MorseConjTrans)) {
+        return -2;
+    }
+    if (M1 < 0) {
+        return -3;
+    }
+    if (N1 < 0) {
+        return -4;
+    }
+    if ( (M2 < 0) ||
+         ( (M2 != M1) && (side == MorseRight) ) ){
+        return -5;
+    }
+    if ( (N2 < 0) ||
+         ( (N2 != N1) && (side == MorseLeft) ) ){
+        return -6;
+    }
+    if ((K < 0) ||
+        ( (side == MorseLeft)  && (K > M1) ) ||
+        ( (side == MorseRight) && (K > N1) ) ) {
+        return -7;
+    }
+    if (IB < 0) {
+        return -8;
+    }
+    if (LDA1 < chameleon_max(1,M1)){
+        return -10;
+    }
+    if (LDA2 < chameleon_max(1,M2)){
+        return -12;
+    }
+    if (LDV < chameleon_max(1,NQ)){
+        return -14;
+    }
+    if (LDT < chameleon_max(1,IB)){
+        return -16;
+    }
+    if (LDWORK < chameleon_max(1,NW)){
+        return -18;
+    }
+
+    /* Quick return */
+    if ((M1 == 0) || (N1 == 0) || (M2 == 0) || (N2 == 0) || (K == 0) || (IB == 0))
+        return MORSE_SUCCESS;
+
+    if (((side == MorseLeft)  && (trans != MorseNoTrans))
+        || ((side == MorseRight) && (trans == MorseNoTrans))) {
+        i1 = 0;
+        i3 = IB;
+    }
+    else {
+        i1 = ((K-1) / IB)*IB;
+        i3 = -IB;
+    }
+
+    for(i = i1; (i > -1) && (i < K); i += i3) {
+        kb = chameleon_min(IB, K-i);
+
+        if (side == MorseLeft) {
+            mi1 = kb;
+            mi2 = chameleon_min(i+kb, M2);
+            l   = chameleon_min(kb, chameleon_max(0, M2-i));
+            ic  = i;
+        }
+        else {
+            ni1 = kb;
+            ni2 = chameleon_min(i+kb, N2);
+            l   = chameleon_min(kb, chameleon_max(0, N2-i));
+            jc  = i;
+        }
+
+        /*
+         * Apply H or H' (NOTE: CORE_zparfb used to be CORE_zttrfb)
+         */
+        CUDA_zparfb(
+            side, trans, MorseForward, MorseColumnwise,
+            mi1, ni1, mi2, ni2, kb, l,
+            A1 + LDA1*jc+ic, LDA1,
+            A2, LDA2,
+            V + LDV*i, LDV,
+            T + LDT*i, LDT,
+            WORK, LDWORK,
+            WORKC, LDWORKC, CUBLAS_STREAM_VALUE );
+    }
+    return MORSE_SUCCESS;
+}
diff --git a/cudablas/include/cudablas_z.h b/cudablas/include/cudablas_z.h
index e73912750..911d1ef93 100644
--- a/cudablas/include/cudablas_z.h
+++ b/cudablas/include/cudablas_z.h
@@ -51,6 +51,7 @@ int CUDA_ztrmm(  MORSE_enum side, MORSE_enum uplo, MORSE_enum transa, MORSE_enum
 int CUDA_ztrsm(  MORSE_enum side, MORSE_enum uplo, MORSE_enum transa, MORSE_enum diag, int m, int n, cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb, CUBLAS_STREAM_PARAM);
 int CUDA_ztsmlq( MORSE_enum side, MORSE_enum trans, int M1, int N1, int M2, int N2, int K, int IB, cuDoubleComplex *A1, int LDA1, cuDoubleComplex *A2, int LDA2, const cuDoubleComplex *V, int LDV, const cuDoubleComplex *T, int LDT, cuDoubleComplex *WORK, int LDWORK, cuDoubleComplex *WORKC, int LDWORKC, CUBLAS_STREAM_PARAM);
 int CUDA_ztsmqr( MORSE_enum side, MORSE_enum trans, int M1, int N1, int M2, int N2, int K, int IB, cuDoubleComplex *A1, int LDA1, cuDoubleComplex *A2, int LDA2, const cuDoubleComplex *V, int LDV, const cuDoubleComplex *T, int LDT, cuDoubleComplex *WORK, int LDWORK, cuDoubleComplex *WORKC, int LDWORKC, CUBLAS_STREAM_PARAM);
+int CUDA_zttmqr( MORSE_enum side, MORSE_enum trans, int M1, int N1, int M2, int N2, int K, int IB, cuDoubleComplex *A1, int LDA1, cuDoubleComplex *A2, int LDA2, const cuDoubleComplex *V, int LDV, const cuDoubleComplex *T, int LDT, cuDoubleComplex *WORK, int LDWORK, cuDoubleComplex *WORKC, int LDWORKC, CUBLAS_STREAM_PARAM);
 int CUDA_zunmlqt(MORSE_enum side, MORSE_enum trans, int M, int N, int K, int IB, const cuDoubleComplex *A,    int LDA, const cuDoubleComplex *T,    int LDT, cuDoubleComplex *C,    int LDC, cuDoubleComplex *WORK, int LDWORK, CUBLAS_STREAM_PARAM );
 int CUDA_zunmqrt(MORSE_enum side, MORSE_enum trans, int M, int N, int K, int IB, const cuDoubleComplex *A,    int LDA, const cuDoubleComplex *T,    int LDT, cuDoubleComplex *C,    int LDC, cuDoubleComplex *WORK, int LDWORK, CUBLAS_STREAM_PARAM );
 
-- 
GitLab