From f08ed1d4d0c814a7bf4ff5b059610578a5178b47 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Tue, 18 Jun 2024 01:00:07 +0200
Subject: [PATCH] cmake/gpucublas: Move the cuda toolkit detection to better
 select the inclusion of extra kernels

---
 gpucublas/CMakeLists.txt         | 42 ++++++++++++++++++++++++++++++++
 gpucublas/compute/CMakeLists.txt | 40 +++++++++++++++---------------
 2 files changed, 63 insertions(+), 19 deletions(-)

diff --git a/gpucublas/CMakeLists.txt b/gpucublas/CMakeLists.txt
index aae427010..87584aa8a 100644
--- a/gpucublas/CMakeLists.txt
+++ b/gpucublas/CMakeLists.txt
@@ -26,6 +26,48 @@
 #
 ###
 
+# Add CUDA kernel if compiler and toolkit are available
+# -----------------------------------------------------
+include(CheckLanguage)
+check_language(CUDA)
+
+if(CMAKE_CUDA_COMPILER)
+  enable_language(CUDA)
+  find_package(CUDAToolkit)
+else()
+  message(STATUS "CUDA language is not supported")
+endif()
+
+if (CUDAToolkit_FOUND)
+  set(GPUCUBLAS_HAVE_CUDA_TOOLKIT ON CACHE INTERNAL "Indicate if cuda kernels are enabled or not" FORCE)
+
+  include(SetCMakeCudaArchitectures)
+
+  if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "7.5")
+    set(GPUCUBLAS_HAVE_CUDA_HALF ON CACHE INTERNAL "Indicate if half precision support is enabled or not" FORCE)
+  else()
+    set(GPUCUBLAS_HAVE_CUDA_HALF OFF CACHE INTERNAL "Indicate if half precision support is enabled or not" FORCE)
+  endif()
+else()
+  set(GPUCUBLAS_HAVE_CUDA_TOOLKIT OFF CACHE INTERNAL "Indicate if cuda kernels are enabled or not" FORCE)
+  set(GPUCUBLAS_HAVE_CUDA_HALF OFF CACHE INTERNAL "Indicate if half precision support is enabled or not" FORCE)
+endif()
+
+if ( GPUCUBLAS_HAVE_CUDA_HALF )
+  ##morse_cmake_required_set( CUBLAS )
+  set(CMAKE_REQUIRED_LIBRARIES CUDA::CUBLAS)
+
+  check_function_exists(cublasHgemm GPUCUBLAS_HAVE_CUBLASHGEMM)
+  if ( GPUCUBLAS_HAVE_CUBLASHGEMM )
+    message("-- ${Blue}Add definition HAVE_CUBLASHGEMM${ColourReset}")
+  endif()
+  check_function_exists(cublasGemmEx GPUCUBLAS_HAVE_CUBLASGEMMEX)
+  if ( GPUCUBLAS_HAVE_CUBLASGEMMEX )
+    message("-- ${Blue}Add definition HAVE_CUBLASGEMMEX${ColourReset}")
+  endif()
+  morse_cmake_required_unset()
+endif()
+
 add_subdirectory(include)
 add_subdirectory(compute)
 add_subdirectory(eztrace_module)
diff --git a/gpucublas/compute/CMakeLists.txt b/gpucublas/compute/CMakeLists.txt
index 80828b2e2..556e18827 100644
--- a/gpucublas/compute/CMakeLists.txt
+++ b/gpucublas/compute/CMakeLists.txt
@@ -56,24 +56,15 @@ set(ZSRC
     cuda_zunmqrt.c
     )
 
-# Add CUDA kernel if compiler and toolkit are available
-# -----------------------------------------------------
-include(CheckLanguage)
-check_language(CUDA)
-
-if(CMAKE_CUDA_COMPILER)
-  enable_language(CUDA)
-  find_package(CUDAToolkit)
-else()
-  message(STATUS "CUDA language is not supported")
-endif()
-
-if (CUDAToolkit_FOUND)
-  include(SetCMakeCudaArchitectures)
-
+if ( GPUCUBLAS_HAVE_CUDA_TOOLKIT )
   set(ZSRC
     ${ZSRC}
     cuda_zlag2c.cu
+  )
+endif()
+if ( GPUCUBLAS_HAVE_CUDA_HALF )
+  set(ZSRC
+    ${ZSRC}
     cuda_dlag2h.cu
   )
 endif()
@@ -102,13 +93,24 @@ precisions_rules_py(
 
 set(GPUCUBLAS_SRCS
   ${GPUCUBLAS_SRCS_GENERATED}
-  cuda_hgemm.c
-  cuda_gemmex.c
   cudaglobal.c
+)
+
+if (GPUCUBLAS_HAVE_CUBLASHGEMM)
+  set(GPUCUBLAS_SRCS
+    ${GPUCUBLAS_SRCS}
+    cuda_hgemm.c
   )
+  # Need to use CXX compiler to have the __half support and access to cublasHgemm()
+  set_source_files_properties( cuda_hgemm.c PROPERTIES LANGUAGE CXX )
+endif()
 
-# Need to use CXX compiler to have the __half support and access to cublasHgemm()
-set_source_files_properties( cuda_hgemm.c PROPERTIES LANGUAGE CXX )
+if (GPUCUBLAS_HAVE_CUBLASGEMMEX)
+  set(GPUCUBLAS_SRCS
+    ${GPUCUBLAS_SRCS}
+    cuda_gemmex.c
+  )
+endif()
 
 # Force generation of sources
 # ---------------------------
-- 
GitLab