From 375cb991cdf952cb6784b14c0761db03814cbaa7 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Thu, 10 Feb 2022 21:40:29 +0100
Subject: [PATCH] starpu: add the reuse_data_on_node function to the cham_tile
 interface

---
 .../starpu/interface/cham_tile_interface.c    | 43 ++++++++++++++-----
 1 file changed, 32 insertions(+), 11 deletions(-)

diff --git a/runtime/starpu/interface/cham_tile_interface.c b/runtime/starpu/interface/cham_tile_interface.c
index 0065c9901..30348afe1 100644
--- a/runtime/starpu/interface/cham_tile_interface.c
+++ b/runtime/starpu/interface/cham_tile_interface.c
@@ -115,7 +115,7 @@ cti_init( void *data_interface )
 
 static void
 cti_register_data_handle( starpu_data_handle_t  handle,
-                          unsigned              home_node,
+                          int                   home_node,
                           void                 *data_interface )
 {
     starpu_cham_tile_interface_t *cham_tile_interface = (starpu_cham_tile_interface_t *) data_interface;
@@ -194,6 +194,23 @@ cti_free_data_on_node( void *data_interface, unsigned node )
     cham_tile_interface->dev_handle = 0;
 }
 
+#if defined(HAVE_STARPU_REUSE_DATA_ON_NODE)
+/* Make dst_data_interface reuse the buffer described by cached_interface. */
+static void
+cti_reuse_data_on_node( void *dst_data_interface, const void *cached_interface, unsigned node )
+{
+    starpu_cham_tile_interface_t       *dst_cham_tile    = dst_data_interface;
+    const starpu_cham_tile_interface_t *cached_cham_tile = cached_interface;
+
+    /* Adopt the cached buffer; ld is reset to the destination's own m */
+    dst_cham_tile->tile.mat   = cached_cham_tile->tile.mat;
+    dst_cham_tile->tile.ld    = dst_cham_tile->tile.m;
+    dst_cham_tile->dev_handle = cached_cham_tile->dev_handle;
+
+    (void)node; /* unused */
+}
+#endif
+
 static void *
 cti_to_pointer( void *data_interface, unsigned node )
 {
@@ -565,6 +582,9 @@ static int cti_copy_any_to_any( void *src_interface, unsigned src_node,
     void *src_mat = CHAM_tile_get_ptr( &(cham_tile_src->tile) );
     void *dst_mat = CHAM_tile_get_ptr( &(cham_tile_dst->tile) );
 
+    assert( ld_src >= m );
+    assert( ld_dst >= m );
+
 #if defined(CHAMELEON_KERNELS_TRACE)
     fprintf( stderr,
              "[ANY->ANY] src(%s, type:%s, m=%d, n=%d, ld=%d, ptr:%p) dest(%s, type:%s, m=%d, n=%d, ld=%d, ptr:%p)\n",
@@ -574,12 +594,13 @@ static int cti_copy_any_to_any( void *src_interface, unsigned src_node,
              cham_tile_dst->tile.m, cham_tile_dst->tile.n, cham_tile_dst->tile.ld, dst_mat );
 #endif
 
+    m      = m      * elemsize;
+    ld_src = ld_src * elemsize;
+    ld_dst = ld_dst * elemsize;
 #if defined(HAVE_STARPU_INTERFACE_COPY2D)
-    ld_src *= elemsize;
-    ld_dst *= elemsize;
     if (starpu_interface_copy2d( (uintptr_t) src_mat, 0, src_node,
                                  (uintptr_t) dst_mat, 0, dst_node,
-                                 m * elemsize, n, ld_src, ld_dst, async_data ) ) {
+                                 m, n, ld_src, ld_dst, async_data ) ) {
         ret = -EAGAIN;
     }
 #else
@@ -588,7 +609,7 @@ static int cti_copy_any_to_any( void *src_interface, unsigned src_node,
         /* Optimize unpartitioned and y-partitioned cases */
         if ( starpu_interface_copy( (uintptr_t) src_mat, 0, src_node,
                                     (uintptr_t) dst_mat, 0, dst_node,
-                                    m * n * elemsize, async_data ) )
+                                    m * n, async_data ) )
         {
             ret = -EAGAIN;
 	}
@@ -596,9 +617,6 @@ static int cti_copy_any_to_any( void *src_interface, unsigned src_node,
     else
     {
         unsigned y;
-        ld_src *= elemsize;
-        ld_dst *= elemsize;
-
         for (y = 0; y < n; y++)
         {
             uint32_t src_offset = y * ld_src;
@@ -606,7 +624,7 @@ static int cti_copy_any_to_any( void *src_interface, unsigned src_node,
 
             if ( starpu_interface_copy( (uintptr_t) src_mat, src_offset, src_node,
                                         (uintptr_t) dst_mat, dst_offset, dst_node,
-                                        m * elemsize, async_data ) )
+                                        m, async_data ) )
             {
                 ret = -EAGAIN;
             }
@@ -614,7 +632,7 @@ static int cti_copy_any_to_any( void *src_interface, unsigned src_node,
     }
 #endif
 
-    starpu_interface_data_copy( src_node, dst_node, (size_t) n*m*elemsize );
+    starpu_interface_data_copy( src_node, dst_node, m * n );
 
     return ret;
 }
@@ -630,6 +648,10 @@ struct starpu_data_interface_ops starpu_interface_cham_tile_ops =
     .register_data_handle  = cti_register_data_handle,
     .allocate_data_on_node = cti_allocate_data_on_node,
     .free_data_on_node     = cti_free_data_on_node,
+#if defined(HAVE_STARPU_REUSE_DATA_ON_NODE)
+    .reuse_data_on_node    = cti_reuse_data_on_node,
+    .alloc_compare         = cti_alloc_compare,
+#endif
     .to_pointer            = cti_to_pointer,
     .pointer_is_inside     = cti_pointer_is_inside,
     .get_size              = cti_get_size,
@@ -637,7 +659,6 @@ struct starpu_data_interface_ops starpu_interface_cham_tile_ops =
     .footprint             = cti_footprint,
     .alloc_footprint       = cti_alloc_footprint,
     .compare               = cti_compare,
-    .alloc_compare         = cti_alloc_compare,
     .display               = cti_display,
     .pack_data             = cti_pack_data,
 #if defined(HAVE_STARPU_DATA_PEEK)
-- 
GitLab