diff --git a/gpucublas/compute/CMakeLists.txt b/gpucublas/compute/CMakeLists.txt index 81b99814503b779db0e7b75d4c05e5aded4eec53..1691581e744da88cf6cc0e631937e2bf94f7ef29 100644 --- a/gpucublas/compute/CMakeLists.txt +++ b/gpucublas/compute/CMakeLists.txt @@ -292,6 +292,7 @@ precisions_rules_py( set(GPUCUBLAS_SRCS ${GPUCUBLAS_SRCS_GENERATED} cuda_hgemm.c + cuda_gemmex.c cudaglobal.c ) diff --git a/gpucublas/compute/cuda_gemmex.c b/gpucublas/compute/cuda_gemmex.c new file mode 100644 index 0000000000000000000000000000000000000000..c384018e900fc88a8939bac3f042c972f0dc367c --- /dev/null +++ b/gpucublas/compute/cuda_gemmex.c @@ -0,0 +1,43 @@ +/** + * + * @file cuda_gemmex.c + * + * @copyright 2023-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria, + * Univ. Bordeaux. All rights reserved. + * + *** + * + * @brief Chameleon cuda_gemmex GPU kernel + * + * @version 1.3.0 + * @author Mathieu Faverge + * @date 2023-07-04 + * + */ +#include "gpucublas.h" + +int +CUDA_gemmex( cham_trans_t transa, cham_trans_t transb, + int m, int n, int k, + const void *alpha, + const void *A, int lda, cham_flttype_t Atype, + const void *B, int ldb, cham_flttype_t Btype, + const void *beta, + void *C, int ldc, cham_flttype_t Ctype, + cublasHandle_t handle ) +{ + cublasStatus_t rc; + + rc = cublasGemmEx( handle, + chameleon_cublas_const(transa), chameleon_cublas_const(transb), + m, n, k, + CUBLAS_VALUE(alpha), A, lda, chameleon_cublas_dtype( Atype ), + B, ldb, chameleon_cublas_dtype( Btype ), + CUBLAS_VALUE(beta), C, ldc, chameleon_cublas_dtype( Ctype ), + chameleon_cublas_ctype( Ctype ), + CUBLAS_GEMM_DEFAULT ); + + assert( rc == CUBLAS_STATUS_SUCCESS ); + (void)rc; + return CHAMELEON_SUCCESS; +} diff --git a/gpucublas/include/gpucublas.h b/gpucublas/include/gpucublas.h index 29a7046a8c840faecc4fee7b2edb70de35700b98..8e7d4c3afa4d0165e2961c7791ae2dd2f12b3871 100644 --- a/gpucublas/include/gpucublas.h +++ b/gpucublas/include/gpucublas.h @@ -70,6 +70,45 @@ int CUDA_hgemm( cham_trans_t transa, cham_trans_t transb, CHAMELEON_Real16_t *C, int ldc, cublasHandle_t handle ); +int CUDA_gemmex( cham_trans_t transa, cham_trans_t transb, + int m, int n, int k, + const void *alpha, + const void *A, int lda, cham_flttype_t Atype, + const void *B, int ldb, cham_flttype_t Btype, + const void *beta, + void *C, int ldc, cham_flttype_t Ctype, + cublasHandle_t handle ); + +static inline cublasComputeType_t +chameleon_cublas_ctype( cham_flttype_t flttype ) { + + switch ( flttype ) { + case ChamRealHalf : return CUBLAS_COMPUTE_16F; + case ChamRealFloat : return CUBLAS_COMPUTE_32F; + case ChamRealDouble : return CUBLAS_COMPUTE_64F; + case ChamComplexFloat : return CUBLAS_COMPUTE_32F; + case ChamComplexDouble : return CUBLAS_COMPUTE_64F; + default: + fprintf( stderr, "chameleon_cublas_ctype(): Incorrect flttype\n" ); + exit(1); + } +} + +static inline cudaDataType_t +chameleon_cublas_dtype( cham_flttype_t flttype ) { + + switch ( flttype ) { + case ChamRealHalf : return CUDA_R_16F; + case ChamRealFloat : return CUDA_R_32F; + case ChamRealDouble : return CUDA_R_64F; + case ChamComplexFloat : return CUDA_C_32F; + case ChamComplexDouble : return CUDA_C_64F; + default: + fprintf( stderr, "chameleon_cublas_dtype(): Incorrect flttype\n" ); + exit(1); + } +} + END_C_DECLS /**