From b4c9ffb30d9b41a6cc1f0a3fcdf5871da27f7111 Mon Sep 17 00:00:00 2001
From: Samuel Thibault <samuel.thibault@ens-lyon.org>
Date: Tue, 26 Jan 2021 23:33:24 +0100
Subject: [PATCH] starpu: Use starpu_mpi_interface_datatype_node_register when
 available

starpu_mpi_interface_datatype_node_register will allow StarPU-MPI to use
NUMA buffers and GPUDirect.
---
 runtime/starpu/CMakeLists.txt                 |  5 ++++-
 runtime/starpu/include/chameleon_starpu.h.in  |  3 ++-
 .../starpu/interface/cham_tile_interface.c    | 20 ++++++++++++++++---
 3 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/runtime/starpu/CMakeLists.txt b/runtime/starpu/CMakeLists.txt
index 2fdfe5266..e073ad6f9 100644
--- a/runtime/starpu/CMakeLists.txt
+++ b/runtime/starpu/CMakeLists.txt
@@ -84,8 +84,11 @@ if ( STARPU_FOUND )
       set(CHAMELEON_USE_MIGRATE "OFF")
       message("-- ${Blue}CHAMELEON_USE_MIGRATE is turned OFF because starpu_mpi_data_migrate not found${ColourReset}")
     endif()
+    check_function_exists(starpu_mpi_interface_datatype_node_register HAVE_STARPU_MPI_INTERFACE_DATATYPE_NODE_REGISTER)
     check_function_exists(starpu_mpi_interface_datatype_register HAVE_STARPU_MPI_INTERFACE_DATATYPE_REGISTER)
-    if ( HAVE_STARPU_MPI_INTERFACE_DATATYPE_REGISTER )
+    if ( HAVE_STARPU_MPI_INTERFACE_DATATYPE_NODE_REGISTER )
+      message("-- ${Blue}Add definition HAVE_STARPU_MPI_INTERFACE_DATATYPE_NODE_REGISTER${ColourReset}")
+    elseif ( HAVE_STARPU_MPI_INTERFACE_DATATYPE_REGISTER )
       message("-- ${Blue}Add definition HAVE_STARPU_MPI_INTERFACE_DATATYPE_REGISTER${ColourReset}")
     else()
       if( CHAMELEON_USE_MPI_DATATYPES )
diff --git a/runtime/starpu/include/chameleon_starpu.h.in b/runtime/starpu/include/chameleon_starpu.h.in
index 89e6fc784..848f5c175 100644
--- a/runtime/starpu/include/chameleon_starpu.h.in
+++ b/runtime/starpu/include/chameleon_starpu.h.in
@@ -38,9 +38,10 @@
 #cmakedefine HAVE_STARPU_MPI_COMM_GET_ATTR
 #cmakedefine HAVE_STARPU_MPI_INIT_CONF
 #cmakedefine HAVE_STARPU_MPI_WAIT_FOR_ALL
+#cmakedefine HAVE_STARPU_MPI_INTERFACE_DATATYPE_NODE_REGISTER
 #cmakedefine HAVE_STARPU_MPI_INTERFACE_DATATYPE_REGISTER
 
-#if !defined(HAVE_STARPU_MPI_INTERFACE_DATATYPE_REGISTER) && defined(CHAMELEON_USE_MPI_DATATYPES)
+#if (!defined(HAVE_STARPU_MPI_INTERFACE_DATATYPE_NODE_REGISTER) && !defined(HAVE_STARPU_MPI_INTERFACE_DATATYPE_REGISTER)) && defined(CHAMELEON_USE_MPI_DATATYPES)
 #error "This version of StarPU does not support MPI datatypes (Please compile with -DCHAMELEON_USE_MPI_DATATYPES=OFF)"
 #endif
 
diff --git a/runtime/starpu/interface/cham_tile_interface.c b/runtime/starpu/interface/cham_tile_interface.c
index ea335c640..019cc78c1 100644
--- a/runtime/starpu/interface/cham_tile_interface.c
+++ b/runtime/starpu/interface/cham_tile_interface.c
@@ -500,13 +500,14 @@ cti_handle_get_allocsize( starpu_data_handle_t handle )
 
 #if defined(CHAMELEON_USE_MPI_DATATYPES)
 int
-cti_allocate_datatype( starpu_data_handle_t handle,
-                       MPI_Datatype        *datatype )
+cti_allocate_datatype_node( starpu_data_handle_t handle,
+                            unsigned             node,
+                            MPI_Datatype        *datatype )
 {
     int ret;
 
     starpu_cham_tile_interface_t *cham_tile_interface = (starpu_cham_tile_interface_t *)
-        starpu_data_get_interface_on_node( handle, STARPU_MAIN_RAM );
+        starpu_data_get_interface_on_node( handle, node );
 
     size_t m  = cham_tile_interface->tile.m;
     size_t n  = cham_tile_interface->tile.n;
@@ -522,6 +523,13 @@ cti_allocate_datatype( starpu_data_handle_t handle,
     return 0;
 }
 
+int
+cti_allocate_datatype( starpu_data_handle_t handle,
+                       MPI_Datatype        *datatype )
+{
+    return cti_allocate_datatype_node( handle, STARPU_MAIN_RAM, datatype );
+}
+
 void
 cti_free_datatype( MPI_Datatype *datatype )
 {
@@ -536,9 +544,15 @@ starpu_cham_tile_interface_init()
     {
         starpu_interface_cham_tile_ops.interfaceid = starpu_data_interface_get_next_id();
 #if defined(CHAMELEON_USE_MPI_DATATYPES)
+  #if defined(HAVE_STARPU_MPI_INTERFACE_DATATYPE_NODE_REGISTER)
+        starpu_mpi_interface_datatype_node_register( starpu_interface_cham_tile_ops.interfaceid,
+                                                    cti_allocate_datatype_node,
+                                                    cti_free_datatype );
+  #else
         starpu_mpi_interface_datatype_register( starpu_interface_cham_tile_ops.interfaceid,
                                                 cti_allocate_datatype,
                                                 cti_free_datatype );
+  #endif
 #endif
     }
 }
-- 
GitLab