From a4624e9d4bd72049c978b9979fc31a50af5feec5 Mon Sep 17 00:00:00 2001
From: Alycia Lisito <alycia.lisito@inria.fr>
Date: Wed, 5 Mar 2025 11:22:20 +0100
Subject: [PATCH] starpu/codelet: Add new task submit to
 codelet_zlaswp_batched.c

---
 .../starpu/codelets/codelet_zlaswp_batched.c  | 102 ++++++++++++++----
 1 file changed, 80 insertions(+), 22 deletions(-)

diff --git a/runtime/starpu/codelets/codelet_zlaswp_batched.c b/runtime/starpu/codelets/codelet_zlaswp_batched.c
index f43a68947..8cc2a3adc 100644
--- a/runtime/starpu/codelets/codelet_zlaswp_batched.c
+++ b/runtime/starpu/codelets/codelet_zlaswp_batched.c
@@ -18,7 +18,7 @@
 #include "chameleon_starpu_internal.h"
 #include "runtime_codelet_z.h"
 
-struct cl_laswp_batched_args_t {
+struct cl_zlaswp_batched_args_s {
     int                      tasks_nbr;
     int                      minmn;
     int                      m0[CHAMELEON_BATCH_SIZE];
@@ -32,7 +32,7 @@ cl_zlaswp_batched_cpu_func( void *descr[],
 {
     int          i, m0, minmn, *perm, *invp;
     CHAM_tile_t *A, *U, *B;
-    struct cl_laswp_batched_args_t *clargs = ( struct cl_laswp_batched_args_t * ) cl_arg;
+    struct cl_zlaswp_batched_args_s *clargs = ( struct cl_zlaswp_batched_args_s * ) cl_arg;
 
     minmn = clargs->minmn;
     perm = (int *)STARPU_VECTOR_GET_PTR( descr[0] );
@@ -73,14 +73,13 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
 {
     int task_num   = 0;
     int batch_size = ((struct chameleon_pzgetrf_s *)ws)->batch_size_swap;
-    int nhandles;
-    struct cl_laswp_batched_args_t *clargs = *clargs_ptr;
+    struct cl_zlaswp_batched_args_s *clargs = *clargs_ptr;
     if ( Am->get_rankof( Am, Amm, Amn) != Am->myrank ) {
         return;
     }
 
     if( clargs == NULL ) {
-        clargs = malloc( sizeof( struct cl_laswp_batched_args_t ) ) ;
+        clargs = malloc( sizeof( struct cl_zlaswp_batched_args_s ) ) ;
         clargs->tasks_nbr = 0;
         clargs->minmn     = minmn;
         *clargs_ptr       = clargs;
@@ -93,24 +92,12 @@ void INSERT_TASK_zlaswp_batched( const RUNTIME_option_t *options,
     clargs->tasks_nbr ++;
 
     if ( clargs->tasks_nbr == batch_size ) {
-        nhandles = clargs->tasks_nbr;
-        rt_starpu_insert_task(
-            &cl_zlaswp_batched,
-            STARPU_CL_ARGS,             clargs, sizeof(struct cl_laswp_batched_args_t),
-            STARPU_R,                   RUNTIME_perm_getaddr( ipiv, ipivk ),
-            STARPU_R,                   RUNTIME_invp_getaddr( ipiv, ipivk ),
-            STARPU_RW | STARPU_COMMUTE, RTBLKADDR(U, ChamComplexDouble, Um, Un),
-            STARPU_R,                   RTBLKADDR(Ak, ChamComplexDouble, Akm, Akn),
-            STARPU_DATA_MODE_ARRAY,     clargs->handle_mode, nhandles,
-            STARPU_PRIORITY,            options->priority,
-            STARPU_EXECUTE_ON_WORKER,   options->workerid,
-            0 );
-
-        /* clargs is freed by starpu. */
-        *clargs_ptr = NULL;
+        INSERT_TASK_zlaswp_batched_flush( options, ipiv, ipivk, Ak, Akm, Akn, U, Um, Un, clargs_ptr );
     }
 }
 
+#if defined(CHAMELEON_STARPU_USE_INSERT)
+
 void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
                                        const CHAM_ipiv_t      *ipiv,
                                        int                     ipivk,
@@ -122,7 +109,7 @@ void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
                                        int                     Un,
                                        void                  **clargs_ptr )
 {
-    struct cl_laswp_batched_args_t *clargs   = *clargs_ptr;
+    struct cl_zlaswp_batched_args_s *clargs   = *clargs_ptr;
     int                             nhandles;
 
     if( clargs == NULL ) {
@@ -132,7 +119,7 @@ void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
     nhandles = clargs->tasks_nbr;
     rt_starpu_insert_task(
         &cl_zlaswp_batched,
-        STARPU_CL_ARGS,             clargs, sizeof(struct cl_laswp_batched_args_t),
+        STARPU_CL_ARGS,             clargs, sizeof(struct cl_zlaswp_batched_args_s),
         STARPU_R,                   RUNTIME_perm_getaddr( ipiv, ipivk ),
         STARPU_R,                   RUNTIME_invp_getaddr( ipiv, ipivk ),
         STARPU_RW | STARPU_COMMUTE, RTBLKADDR(U, ChamComplexDouble, Um, Un),
@@ -145,3 +132,74 @@ void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
     /* clargs is freed by starpu. */
     *clargs_ptr = NULL;
 }
+
+#else /* defined(CHAMELEON_STARPU_USE_INSERT) */
+
+void INSERT_TASK_zlaswp_batched_flush( const RUNTIME_option_t *options,
+                                       const CHAM_ipiv_t      *ipiv,
+                                       int                     ipivk,
+                                       const CHAM_desc_t      *Ak,
+                                       int                     Akm,
+                                       int                     Akn,
+                                       const CHAM_desc_t      *U,
+                                       int                     Um,
+                                       int                     Un,
+                                       void                  **clargs_ptr )
+{
+    int ret, k;
+    struct starpu_task *task;
+    struct cl_zlaswp_batched_args_s *myclargs = *clargs_ptr;
+
+    if( myclargs == NULL ) {
+        return;
+    }
+
+    INSERT_TASK_COMMON_PARAMETERS( zlaswp_batched, myclargs->tasks_nbr + 4 );
+
+    /*
+     * Register the data handles, might need to receive perm and invp
+     */
+    starpu_cham_exchange_init_params( options, &params, Ak->myrank );
+    starpu_cham_exchange_handle_before_execution( options, &params, &nbdata, descrs,
+                                                  RUNTIME_perm_getaddr( ipiv, ipivk ),
+                                                  STARPU_R );
+    starpu_cham_exchange_handle_before_execution( options, &params, &nbdata, descrs,
+                                                  RUNTIME_invp_getaddr( ipiv, ipivk ),
+                                                  STARPU_R );
+    starpu_cham_register_descr( &nbdata, descrs, RTBLKADDR( U, ChamComplexDouble, Um, Un ),
+                                STARPU_RW | STARPU_COMMUTE );
+    starpu_cham_register_descr( &nbdata, descrs, RTBLKADDR( Ak, ChamComplexDouble, Akm, Akn ), STARPU_R );
+    for ( k = 0; k < myclargs->tasks_nbr; k++ ) {
+        starpu_cham_register_descr( &nbdata, descrs, myclargs->handle_mode[ k ].handle, STARPU_RW );
+    }
+
+    task = starpu_task_create();
+    task->cl = cl;
+
+    /* Set codelet parameters */
+    task->cl_arg      = myclargs;
+    task->cl_arg_size = sizeof( struct cl_zlaswp_batched_args_s );
+    task->cl_arg_free = 1;
+
+    /* Set common parameters */
+    starpu_cham_task_set_options( options, task, nbdata, descrs, NULL );
+
+    /* Flops */
+    task->flops = 0.;
+
+    ret = starpu_task_submit( task );
+    if ( ret == -ENODEV ) {
+        task->destroy = 0;
+        starpu_task_destroy( task );
+        chameleon_error( "INSERT_TASK_zlaswp_batched", "Failed to submit the task to StarPU" );
+        return;
+    }
+    starpu_cham_task_exchange_data_after_execution( options, params, nbdata, descrs );
+
+    /* clargs is freed by starpu. */
+    *clargs_ptr = NULL;
+    (void)clargs;
+    (void)cl_name;
+}
+
+#endif /* defined(CHAMELEON_STARPU_USE_INSERT) */
-- 
GitLab