From c19790a3f2af77a44518b65cda60c28f212be961 Mon Sep 17 00:00:00 2001
From: Alycia Lisito <alycia.lisito@inria.fr>
Date: Wed, 14 Feb 2024 16:17:52 +0100
Subject: [PATCH] zgetrf batched: add batched blocked algorithm

---
 compute/pzgetrf.c        | 75 +++++++++++++++++++++++++++++++++++++++-
 testing/CTestLists.cmake |  7 ++++
 2 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/compute/pzgetrf.c b/compute/pzgetrf.c
index 000209624..e6f3d107a 100644
--- a/compute/pzgetrf.c
+++ b/compute/pzgetrf.c
@@ -272,6 +272,74 @@ chameleon_pzgetrf_panel_facto_blocked( struct chameleon_pzgetrf_s *ws,
     RUNTIME_ipiv_flushk( options->sequence, ipiv, k );
 }
 
+/* Factorization of panel k - dynamic scheduling - batched blocked version.
+ * For each column of the panel, inserts one panel task per tile (A(k,k) then A(m,k), m > k) followed by a flush task;
+ * a trsm task on Up(k,k) is inserted at the end of every block of ws->ib columns except the last one. */
+static inline void
+chameleon_pzgetrf_panel_facto_blocked_batched( struct chameleon_pzgetrf_s *ws,
+                                               CHAM_desc_t                *A,
+                                               CHAM_ipiv_t                *ipiv,
+                                               int                         k,
+                                               RUNTIME_option_t           *options )
+{
+    int m, h, b, nbblock, hmax, j;
+    int tempkm, tempkn, tempmm, minmn;
+    void **clargs = malloc( sizeof(char *) * A->p ); /* A->p pointer slots; NOTE(review): result is not checked before the memset below */
+    memset( clargs, 0, sizeof(char *) * A->p );
+
+    tempkm = k == A->mt-1 ? A->m-k*A->mb : A->mb; /* Rows of tile (k,k): the last tile row may be partial */
+    tempkn = k == A->nt-1 ? A->n-k*A->nb : A->nb; /* Columns of tile (k,k): the last tile column may be partial */
+    minmn  = chameleon_min( tempkm, tempkn );
+
+    /* Update the number of columns stored in ipiv, and compute the number of blocks of ws->ib columns */
+    ipiv->n = minmn;
+    nbblock = chameleon_ceil( minmn, ws->ib );
+
+    /*
+     * Algorithm per column with pivoting (no recursion)
+     */
+    /* Iterate on the panel columns by blocks of ws->ib: b is the block index, h the column index inside the block */
+    /* Since index h scales column h-1, we need to iterate up to minmn (included): the last block gets one extra iteration */
+    for ( b = 0; b < nbblock; b++ ) {
+        hmax = b == nbblock-1 ? minmn + 1 - b * ws->ib : ws->ib;
+
+        for ( h = 0; h < hmax; h++ ) {
+            j =  h + b * ws->ib; /* Index of the column within the whole panel */
+
+            INSERT_TASK_zgetrf_panel_blocked_batched( options, tempkm, tempkn, j, k * A->mb, (void *)ws,
+                                                      A(k, k), Up(k, k), clargs, ipiv ); /* Diagonal tile first */
+
+            for ( m = k + 1; m < A->mt; m++ ) { /* Then every tile below the diagonal in panel k */
+                tempmm = (m == (A->mt - 1)) ? A->m - m * A->mb : A->mb;
+                INSERT_TASK_zgetrf_panel_blocked_batched( options, tempmm, tempkn, j, m * A->mb,
+                                                          (void *)ws, A(m, k), Up(k, k), clargs, ipiv );
+            }
+            INSERT_TASK_zgetrf_panel_blocked_batched_flush( options, A, k,
+                                                            Up(k, k), clargs, ipiv ); /* NOTE(review): presumably submits what was accumulated in clargs - confirm */
+
+            if ( (b < (nbblock-1)) && (h == hmax-1) ) { /* Last column of any block but the final one */
+                INSERT_TASK_zgetrf_blocked_trsm(
+                    options,
+                    ws->ib, tempkn, b * ws->ib + hmax, ws->ib,
+                    Up(k, k),
+                    ipiv );
+            }
+
+            assert( j <= minmn );
+            if ( j < minmn ) { /* Skipped on the extra iteration j == minmn */
+                /* Reduce globally (between MPI processes) */
+                INSERT_TASK_ipiv_reducek( options, ipiv, k, j );
+            }
+        }
+    }
+
+    free( clargs ); /* NOTE(review): freed right after insertion; assumes no inserted task still references this array - confirm */
+
+    /* Flush temporary data used for the pivoting */
+    INSERT_TASK_ipiv_to_perm( options, k * A->mb, tempkm, minmn, ipiv, k );
+    RUNTIME_ipiv_flushk( options->sequence, ipiv, k );
+}
+
 static inline void
 chameleon_pzgetrf_panel_facto( struct chameleon_pzgetrf_s *ws,
                                CHAM_desc_t                *A,
@@ -295,7 +363,12 @@ chameleon_pzgetrf_panel_facto( struct chameleon_pzgetrf_s *ws,
         break;
 
     case ChamGetrfPPiv:
-        chameleon_pzgetrf_panel_facto_blocked( ws, A, ipiv, k, options );
+        if ( ws->batch_size > 1 ) {
+            chameleon_pzgetrf_panel_facto_blocked_batched( ws, A, ipiv, k, options );
+        }
+        else {
+            chameleon_pzgetrf_panel_facto_blocked( ws, A, ipiv, k, options );
+        }
         break;
 
     case ChamGetrfNoPiv:
diff --git a/testing/CTestLists.cmake b/testing/CTestLists.cmake
index 98bdb1939..a1b637f68 100644
--- a/testing/CTestLists.cmake
+++ b/testing/CTestLists.cmake
@@ -103,6 +103,13 @@ if (NOT CHAMELEON_SIMULATION)
             add_test( test_${cat}_${prec}getrf_ppiv ${PREFIX} ${CMD} -c -t ${THREADS} -g ${gpus} -P 1 -f input/getrf.in )
             set_tests_properties( test_${cat}_${prec}getrf_ppiv
                                 PROPERTIES ENVIRONMENT "CHAMELEON_GETRF_ALGO=ppiv;CHAMELEON_GETRF_BATCH_SIZE=1" )
+
+            if ( ${cat} STREQUAL "shm" )
+                add_test( test_${cat}_${prec}getrf_ppiv_batch ${PREFIX} ${CMD} -c -t ${THREADS} -g ${gpus} -P 1 -f input/getrf.in )
+                set_tests_properties( test_${cat}_${prec}getrf_ppiv_batch
+                                      PROPERTIES ENVIRONMENT "CHAMELEON_GETRF_ALGO=ppiv;CHAMELEON_GETRF_BATCH_SIZE=6" )
+            endif()
+
         endif()
 
         list( REMOVE_ITEM TESTSTMP print gepdf_qr )
-- 
GitLab