From ebd114f0b2d33e0337051af35ccee8b6eb4e7478 Mon Sep 17 00:00:00 2001
From: Mathieu Faverge <mathieu.faverge@inria.fr>
Date: Wed, 24 May 2023 09:20:06 -0400
Subject: [PATCH] gpus/cublas: add the gemmex kernel

---
Note: this copy differs from upstream commit ebd114f0: CUDA_gemmex() reports
cublasGemmEx() failures in NDEBUG builds, and the new functions are documented.
The blob ids on the "index" lines below are those of the upstream commit.

 gpucublas/compute/CMakeLists.txt |  1 +
 gpucublas/compute/cuda_gemmex.c  | 75 ++++++++++++++++++++++++++++++++
 gpucublas/include/gpucublas.h    | 50 +++++++++++++++++++++
 3 files changed, 126 insertions(+)
 create mode 100644 gpucublas/compute/cuda_gemmex.c

diff --git a/gpucublas/compute/CMakeLists.txt b/gpucublas/compute/CMakeLists.txt
index 81b998145..1691581e7 100644
--- a/gpucublas/compute/CMakeLists.txt
+++ b/gpucublas/compute/CMakeLists.txt
@@ -292,6 +292,7 @@ precisions_rules_py(
 set(GPUCUBLAS_SRCS
   ${GPUCUBLAS_SRCS_GENERATED}
   cuda_hgemm.c
+  cuda_gemmex.c
   cudaglobal.c
   )
 
diff --git a/gpucublas/compute/cuda_gemmex.c b/gpucublas/compute/cuda_gemmex.c
new file mode 100644
index 000000000..c384018e9
--- /dev/null
+++ b/gpucublas/compute/cuda_gemmex.c
@@ -0,0 +1,75 @@
+/**
+ *
+ * @file cuda_gemmex.c
+ *
+ * @copyright 2023-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
+ *                      Univ. Bordeaux. All rights reserved.
+ *
+ ***
+ *
+ * @brief Chameleon cuda_gemmex GPU kernel
+ *
+ * @version 1.3.0
+ * @author Mathieu Faverge
+ * @date 2023-07-04
+ *
+ */
+#include "gpucublas.h"
+
+/**
+ * @brief Mixed-precision matrix-matrix product forwarded to cublasGemmEx().
+ *
+ * The Chameleon transposition flags and floating point types are translated
+ * into their cuBLAS counterparts, and the compute type is derived from Ctype.
+ *
+ * @param[in]     transa  Operation applied to A.
+ * @param[in]     transb  Operation applied to B.
+ * @param[in]     m       First dimension forwarded to cublasGemmEx().
+ * @param[in]     n       Second dimension forwarded to cublasGemmEx().
+ * @param[in]     k       Third dimension forwarded to cublasGemmEx().
+ * @param[in]     alpha   Pointer to the scalar alpha.
+ * @param[in]     A       Matrix A.
+ * @param[in]     lda     Leading dimension of A.
+ * @param[in]     Atype   Floating point type of the elements of A.
+ * @param[in]     B       Matrix B.
+ * @param[in]     ldb     Leading dimension of B.
+ * @param[in]     Btype   Floating point type of the elements of B.
+ * @param[in]     beta    Pointer to the scalar beta.
+ * @param[in,out] C       Matrix C.
+ * @param[in]     ldc     Leading dimension of C.
+ * @param[in]     Ctype   Floating point type of the elements of C.
+ * @param[in]     handle  cuBLAS handle given to cublasGemmEx().
+ *
+ * @retval CHAMELEON_SUCCESS on success.
+ * @retval CHAMELEON_ERR_UNEXPECTED if cublasGemmEx() returns an error status.
+ */
+int
+CUDA_gemmex( cham_trans_t transa, cham_trans_t transb,
+             int m, int n, int k,
+             const void *alpha,
+             const void *A, int lda, cham_flttype_t Atype,
+             const void *B, int ldb, cham_flttype_t Btype,
+             const void *beta,
+             void *C, int ldc, cham_flttype_t Ctype,
+             cublasHandle_t handle )
+{
+    cublasStatus_t rc;
+
+    rc = cublasGemmEx( handle,
+                       chameleon_cublas_const(transa), chameleon_cublas_const(transb),
+                       m, n, k,
+                       CUBLAS_VALUE(alpha), A, lda, chameleon_cublas_dtype( Atype ),
+                                            B, ldb, chameleon_cublas_dtype( Btype ),
+                       CUBLAS_VALUE(beta),  C, ldc, chameleon_cublas_dtype( Ctype ),
+                       chameleon_cublas_ctype( Ctype ),
+                       CUBLAS_GEMM_DEFAULT );
+
+    /* Debug builds abort here; NDEBUG builds report the failure below. */
+    assert( rc == CUBLAS_STATUS_SUCCESS );
+    if ( rc != CUBLAS_STATUS_SUCCESS ) {
+        fprintf( stderr, "CUDA_gemmex(): cublasGemmEx() failed with status %d\n", (int)rc );
+        return CHAMELEON_ERR_UNEXPECTED;
+    }
+
+    return CHAMELEON_SUCCESS;
+}
diff --git a/gpucublas/include/gpucublas.h b/gpucublas/include/gpucublas.h
index 29a7046a8..8e7d4c3af 100644
--- a/gpucublas/include/gpucublas.h
+++ b/gpucublas/include/gpucublas.h
@@ -70,6 +70,56 @@ int CUDA_hgemm( cham_trans_t transa, cham_trans_t transb,
                 CHAMELEON_Real16_t *C, int ldc,
                 cublasHandle_t handle );
 
+int CUDA_gemmex( cham_trans_t transa, cham_trans_t transb,
+                 int m, int n, int k,
+                 const void *alpha,
+                 const void *A, int lda, cham_flttype_t Atype,
+                 const void *B, int ldb, cham_flttype_t Btype,
+                 const void *beta,
+                 void *C, int ldc, cham_flttype_t Ctype,
+                 cublasHandle_t handle );
+
+/**
+ * @brief Returns the cuBLAS compute type matching a Chameleon floating point
+ *        type. Complex types share the compute type of the same precision.
+ *
+ * Prints an error on stderr and exits the program on any other value.
+ */
+static inline cublasComputeType_t
+chameleon_cublas_ctype( cham_flttype_t flttype ) {
+
+    switch ( flttype ) {
+    case ChamRealHalf      : return CUBLAS_COMPUTE_16F;
+    case ChamRealFloat     : return CUBLAS_COMPUTE_32F;
+    case ChamRealDouble    : return CUBLAS_COMPUTE_64F;
+    case ChamComplexFloat  : return CUBLAS_COMPUTE_32F;
+    case ChamComplexDouble : return CUBLAS_COMPUTE_64F;
+    default:
+        fprintf( stderr, "chameleon_cublas_ctype(): Incorrect flttype\n" );
+        exit(1);
+    }
+}
+
+/**
+ * @brief Returns the cudaDataType_t matching a Chameleon floating point type.
+ *
+ * Prints an error on stderr and exits the program on any other value.
+ */
+static inline cudaDataType_t
+chameleon_cublas_dtype( cham_flttype_t flttype ) {
+
+    switch ( flttype ) {
+    case ChamRealHalf      : return CUDA_R_16F;
+    case ChamRealFloat     : return CUDA_R_32F;
+    case ChamRealDouble    : return CUDA_R_64F;
+    case ChamComplexFloat  : return CUDA_C_32F;
+    case ChamComplexDouble : return CUDA_C_64F;
+    default:
+        fprintf( stderr, "chameleon_cublas_dtype(): Incorrect flttype\n" );
+        exit(1);
+    }
+}
+
 END_C_DECLS
 
 /**
-- 
GitLab