From e455a279983378e4d7525a4bd37d3874bf7f6211 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Tue, 16 Mar 2021 14:33:53 +0100
Subject: [PATCH] Update GEPDF_QDWH

---
 compute/pzgepdf_qdwh.c | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/compute/pzgepdf_qdwh.c b/compute/pzgepdf_qdwh.c
index 219a3f2e4..8169d774c 100644
--- a/compute/pzgepdf_qdwh.c
+++ b/compute/pzgepdf_qdwh.c
@@ -99,13 +99,17 @@ chameleon_pzgepdf_parameters( double Li, double *params )
  * @param[out] B2, TS2, TT2, Q2, D2
  *        Set of workspace that match the size of H for the QR iteration
  *
+ * @param[out] gemm_ws
+ *        On exit, the workspace allocated by CHAMELEON_zgemm_WS_Alloc() for the GEMM operations
+ *
  */
 static inline void
 chameleon_pzgepdf_qdwh_init( const CHAM_desc_t *U, const CHAM_desc_t *H,
                              libhqr_tree_t *qrtreeT, libhqr_tree_t *qrtreeB,
                              CHAM_desc_t *A,  CHAM_desc_t *Ut,
                              CHAM_desc_t *B1, CHAM_desc_t *TS1, CHAM_desc_t *TT1, CHAM_desc_t *Q1, CHAM_desc_t *D1,
-                             CHAM_desc_t *B2, CHAM_desc_t *TS2, CHAM_desc_t *TT2, CHAM_desc_t *Q2, CHAM_desc_t *D2 )
+                             CHAM_desc_t *B2, CHAM_desc_t *TS2, CHAM_desc_t *TT2, CHAM_desc_t *Q2, CHAM_desc_t *D2,
+                             void **gemm_ws )
 {
     CHAM_context_t *chamctxt;
     int ib, nb = U->nb;
@@ -203,6 +207,12 @@ chameleon_pzgepdf_qdwh_init( const CHAM_desc_t *U, const CHAM_desc_t *H,
                          U->n, U->m, 0, 0,
                          U->n, U->m, U->p, U->q,
                          NULL, NULL, NULL );
+
+    /*
+     * Allocate the workspace used by the GEMM operations
+     */
+    *gemm_ws = CHAMELEON_zgemm_WS_Alloc( ChamNoTrans, ChamNoTrans, NULL, NULL, U );
+
     return;
 }
 
@@ -425,7 +435,7 @@ chameleon_pzgeqdwh_condest_qr( CHAM_context_t *chamctxt,
 static inline void
 chameleon_pzgepdf_qdwh_qrstep( int do_qr, int last, double *params,
                                const libhqr_tree_t *qrtreeT, const libhqr_tree_t *qrtreeB,
-                               CHAM_desc_t *U,
+                               CHAM_desc_t *U, void *gemm_ws,
                                CHAM_desc_t *B1, CHAM_desc_t *TS1, CHAM_desc_t *TT1, CHAM_desc_t *Q1, CHAM_desc_t *D1,
                                CHAM_desc_t *B2, CHAM_desc_t *TS2, CHAM_desc_t *TT2, CHAM_desc_t *Q2, CHAM_desc_t *D2,
                                RUNTIME_sequence_t *sequence, RUNTIME_request_t *request )
@@ -474,7 +484,7 @@ chameleon_pzgepdf_qdwh_qrstep( int do_qr, int last, double *params,
      */
     beta  = b / c;
     alpha = ( a - beta ) / sqrt_c;
-    chameleon_pzgemm( ChamNoTrans, ChamConjTrans,
+    chameleon_pzgemm( gemm_ws, ChamNoTrans, ChamConjTrans,
                       alpha, Q1, Q2, beta, U,
                       sequence, request );
 
@@ -621,6 +631,7 @@ chameleon_pzgepdf_qdwh( cham_mtxtype_t mtxtype, CHAM_desc_t *descU, CHAM_desc_t
     CHAM_desc_t descB1, descTS1, descTT1, descD1, descQ1, *D1ptr;
     CHAM_desc_t descB2, descTS2, descTT2, descD2, descQ2, *D2ptr;
     libhqr_tree_t qrtreeT, qrtreeB;
+    void *gemm_ws;
 
     double conv = 100.;
     double Li, params[3];
@@ -657,7 +668,7 @@ chameleon_pzgepdf_qdwh( cham_mtxtype_t mtxtype, CHAM_desc_t *descU, CHAM_desc_t
                                  &qrtreeT, &qrtreeB,
                                  &descA, &descUt,
                                  &descB1, &descTS1, &descTT1, &descQ1, &descD1,
-                                 &descB2, &descTS2, &descTT2, &descQ2, &descD2 );
+                                 &descB2, &descTS2, &descTT2, &descQ2, &descD2, &gemm_ws );
     if ( _zgepdf_qdwh_opt_genD ) {
         D1ptr = &descD1;
         D2ptr = &descD2;
@@ -762,7 +773,7 @@ chameleon_pzgepdf_qdwh( cham_mtxtype_t mtxtype, CHAM_desc_t *descU, CHAM_desc_t
             }
 
             chameleon_pzgepdf_qdwh_qrstep( do_qr, last, params,
-                                           &qrtreeT, &qrtreeB, descU,
+                                           &qrtreeT, &qrtreeB, descU, gemm_ws,
                                            &descB1, &descTS1, &descTT1, &descQ1, D1ptr,
                                            &descB2, &descTS2, &descTT2, &descQ2, D2ptr,
                                            sequence_it, request_it );
@@ -859,9 +870,10 @@ chameleon_pzgepdf_qdwh( cham_mtxtype_t mtxtype, CHAM_desc_t *descU, CHAM_desc_t
         }
         break;
     default:
-        chameleon_pzgemm( ChamConjTrans, ChamNoTrans,
+        chameleon_pzgemm( gemm_ws, ChamConjTrans, ChamNoTrans,
                           1., descU, &descA,
-                          0., descH, sequence, request );
+                          0., descH,
+                          sequence, request );
         if ( info ) {
             info->flops += flops_zgemm( descH->m, descH->n, descA.m );
         }
-- 
GitLab