From a1c35a9a578afba9ed6d664102e1f5cf143d16c3 Mon Sep 17 00:00:00 2001
From: Alycia Lisito <alycia.lisito@inria.fr>
Date: Tue, 3 Sep 2024 09:23:15 +0200
Subject: [PATCH] zgetrf: Lookahead function

---
 compute/pzgetrf.c                        |  6 +++-
 include/chameleon/runtime.h              | 23 ++++++++++++++++
 runtime/openmp/control/runtime_control.c | 12 ++++++++
 runtime/parsec/control/runtime_control.c | 13 +++++++++
 runtime/quark/control/runtime_control.c  | 13 +++++++++
 runtime/starpu/control/runtime_control.c | 35 ++++++++++++++++++++++++
 6 files changed, 101 insertions(+), 1 deletion(-)

diff --git a/compute/pzgetrf.c b/compute/pzgetrf.c
index 635bbbb84..4428dc638 100644
--- a/compute/pzgetrf.c
+++ b/compute/pzgetrf.c
@@ -722,7 +722,8 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
     RUNTIME_option_t options;
 
     int k, m, n, tempkm, tempnn;
-    int min_mnt = chameleon_min( A->mt, A->nt );
+    int min_mnt  = chameleon_min( A->mt, A->nt );
+    int nb_tasks = 0;
 
     chamctxt = chameleon_context_self();
     if (sequence->status != CHAMELEON_SUCCESS) {
@@ -761,6 +762,9 @@ void chameleon_pzgetrf( struct chameleon_pzgetrf_s *ws,
         }
         RUNTIME_data_flush( sequence, Wu(A->myrank, k) );
 
+        if ( chamctxt->lookahead > 0 ) {
+            nb_tasks = RUNTIME_lookahead( chamctxt, k, nb_tasks );
+        }
         RUNTIME_iteration_pop( chamctxt );
     }
     CHAMELEON_Desc_Flush( &(ws->Wl), sequence );
diff --git a/include/chameleon/runtime.h b/include/chameleon/runtime.h
index 52993c9a6..82cf97773 100644
--- a/include/chameleon/runtime.h
+++ b/include/chameleon/runtime.h
@@ -182,6 +182,29 @@ RUNTIME_progress( CHAM_context_t *ctxt );
 int
 RUNTIME_thread_rank( CHAM_context_t *ctxt );
 
+/**
+ * @brief Lookahead based on the first iterations.
+ * Counts how many tasks are sumbitted at the first n iterations (with n the lookahead)
+ * and sets the sum of that as the limit.
+ * Pauses the task submission if the number of tasks submitted if greater than the
+ * limit.
+ *
+ * @param[in] chamctxt
+ *            The runtime context for which the thread rank is asked.
+ *
+ * @param[in] k
+ *            The iteration.
+ *
+ * @param[in] nb_tasks
+ *            The limit number of tasks.
+ *
+ * @retval The limit number of tasks.
+ */
+int
+RUNTIME_lookahead( CHAM_context_t *chamctxt,
+                   int             k,
+                   int             nb_tasks );
+
 /**
  * @brief Get the number of CPU workers of the runtime.
  *
diff --git a/runtime/openmp/control/runtime_control.c b/runtime/openmp/control/runtime_control.c
index e42395663..908bc52d3 100644
--- a/runtime/openmp/control/runtime_control.c
+++ b/runtime/openmp/control/runtime_control.c
@@ -92,6 +92,18 @@ void RUNTIME_progress( CHAM_context_t *chamctxt )
     return;
 }
 
+/**
+ *  Lookahead
+ */
+int RUNTIME_lookahead( CHAM_context_t *chamctxt,
+                       int             k,
+                       int             nb_tasks )
+{
+    (void)chamctxt;
+    (void)k;
+    (void)nb_tasks;
+    return;
+}
 
 /**
  * Thread rank.
diff --git a/runtime/parsec/control/runtime_control.c b/runtime/parsec/control/runtime_control.c
index fef8b4173..1d1415063 100644
--- a/runtime/parsec/control/runtime_control.c
+++ b/runtime/parsec/control/runtime_control.c
@@ -107,6 +107,19 @@ void RUNTIME_progress( CHAM_context_t *chamctxt )
     return;
 }
 
+/**
+ *  Lookahead
+ */
+int RUNTIME_lookahead( CHAM_context_t *chamctxt,
+                       int             k,
+                       int             nb_tasks )
+{
+    (void)chamctxt;
+    (void)k;
+    (void)nb_tasks;
+    return;
+}
+
 /**
  * Thread rank.
  */
diff --git a/runtime/quark/control/runtime_control.c b/runtime/quark/control/runtime_control.c
index a7c61ba27..cf7f42088 100644
--- a/runtime/quark/control/runtime_control.c
+++ b/runtime/quark/control/runtime_control.c
@@ -97,6 +97,19 @@ void RUNTIME_progress( CHAM_context_t *chamctxt )
     return;
 }
 
+/**
+ *  Lookahead
+ */
+int RUNTIME_lookahead( CHAM_context_t *chamctxt,
+                       int             k,
+                       int             nb_tasks )
+{
+    (void)chamctxt;
+    (void)k;
+    (void)nb_tasks;
+    return;
+}
+
 /**
  * Thread rank.
  */
diff --git a/runtime/starpu/control/runtime_control.c b/runtime/starpu/control/runtime_control.c
index c2cb79397..8ed654e06 100644
--- a/runtime/starpu/control/runtime_control.c
+++ b/runtime/starpu/control/runtime_control.c
@@ -353,6 +353,41 @@ void RUNTIME_progress( CHAM_context_t *chamctxt )
     return;
 }
 
+/**
+ * Lookahead based on the first iterations
+ * Counts how many tasks are sumbitted at the first n iterations (with n the lookahead)
+ * and sets the sum of that as the limit.
+ * Pauses the task submission if the number of tasks submitted if greater than the
+ * limit.
+ */
+int RUNTIME_lookahead( CHAM_context_t *chamctxt,
+                       int             k,
+                       int             nb_tasks )
+{
+    int tasks_submit, lookahead;
+    int max = 0;
+
+    /* Get how many tasks are currently submitted */
+    tasks_submit = starpu_task_nsubmitted();
+    if ( tasks_submit == 0 ) {
+        return chameleon_max( nb_tasks, chamctxt->nworkers * 20 );
+    }
+
+    lookahead = chamctxt->lookahead;
+
+    /* Add the number of tasks currently submitted to the previous task count */
+    if ( k < lookahead ) {
+        return nb_tasks + chameleon_max( tasks_submit, chamctxt->nworkers * 20 );
+    }
+
+    /* Wait until the number of tasks submitted is smaller than the limit */
+    while ( starpu_task_nsubmitted() > nb_tasks ) {
+        usleep(10000);
+    }
+
+    return nb_tasks;
+}
+
 /**
  * Thread rank.
  */
-- 
GitLab