Mentions légales du service

Skip to content
Snippets Groups Projects
Commit ebd114f0 authored by Mathieu Faverge
Browse files

gpus/cublas: add the gemmex kernel

parent ae7bbb63
No related branches found
No related tags found
1 merge request: !395 "Introduce half-precision conversion and gemm kernels for GPUs"
......@@ -292,6 +292,7 @@ precisions_rules_py(
set(GPUCUBLAS_SRCS
${GPUCUBLAS_SRCS_GENERATED}
cuda_hgemm.c
cuda_gemmex.c
cudaglobal.c
)
......
/**
*
* @file cuda_gemmex.c
*
* @copyright 2023-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
* Univ. Bordeaux. All rights reserved.
*
***
*
* @brief Chameleon cuda_gemmex GPU kernel
*
* @version 1.3.0
* @author Mathieu Faverge
* @date 2023-07-04
*
*/
#include "gpucublas.h"
/**
 * @brief Mixed-type matrix-matrix product forwarded to cublasGemmEx().
 *
 * The Chameleon type tags of A, B and C are translated to CUDA data types
 * with chameleon_cublas_dtype(); the cuBLAS compute type is derived from
 * Ctype alone with chameleon_cublas_ctype(). The algorithm selector is
 * always CUBLAS_GEMM_DEFAULT.
 *
 * @param[in]     transa  Operation tag for A, translated by chameleon_cublas_const().
 * @param[in]     transb  Operation tag for B, translated by chameleon_cublas_const().
 * @param[in]     m       Forwarded unchanged as the m argument of cublasGemmEx().
 * @param[in]     n       Forwarded unchanged as the n argument of cublasGemmEx().
 * @param[in]     k       Forwarded unchanged as the k argument of cublasGemmEx().
 * @param[in]     alpha   Scalar, handed to cuBLAS through CUBLAS_VALUE().
 * @param[in]     A       Matrix A, with leading dimension lda and type tag Atype.
 * @param[in]     B       Matrix B, with leading dimension ldb and type tag Btype.
 * @param[in]     beta    Scalar, handed to cuBLAS through CUBLAS_VALUE().
 * @param[in,out] C       Matrix C, with leading dimension ldc and type tag Ctype.
 * @param[in]     handle  cuBLAS handle used for the call.
 *
 * @return CHAMELEON_SUCCESS in every case: the cuBLAS status is only verified
 *         through assert(), so it is not reported when assertions are
 *         compiled out.
 */
int
CUDA_gemmex( cham_trans_t transa, cham_trans_t transb,
             int m, int n, int k,
             const void *alpha,
             const void *A, int lda, cham_flttype_t Atype,
             const void *B, int ldb, cham_flttype_t Btype,
             const void *beta,
             void *C, int ldc, cham_flttype_t Ctype,
             cublasHandle_t handle )
{
    const cublasStatus_t status = cublasGemmEx(
        handle,
        chameleon_cublas_const( transa ),
        chameleon_cublas_const( transb ),
        m, n, k,
        CUBLAS_VALUE( alpha ),
        A, lda, chameleon_cublas_dtype( Atype ),
        B, ldb, chameleon_cublas_dtype( Btype ),
        CUBLAS_VALUE( beta ),
        C, ldc, chameleon_cublas_dtype( Ctype ),
        chameleon_cublas_ctype( Ctype ), /* compute type follows the type of C */
        CUBLAS_GEMM_DEFAULT );

    /* Debug-only check; the cast keeps the variable "used" when NDEBUG is set. */
    assert( status == CUBLAS_STATUS_SUCCESS );
    (void)status;

    return CHAMELEON_SUCCESS;
}
......@@ -70,6 +70,45 @@ int CUDA_hgemm( cham_trans_t transa, cham_trans_t transb,
CHAMELEON_Real16_t *C, int ldc,
cublasHandle_t handle );
/*
 * Mixed-type GEMM wrapper: A, B and C each come with their own leading
 * dimension and cham_flttype_t tag, alpha and beta are passed as untyped
 * pointers, and the call is issued on the given cuBLAS handle.
 * NOTE(review): alpha/beta presumably have to point to scalars matching the
 * compute type selected from Ctype -- confirm against the cublasGemmEx()
 * documentation.
 */
int CUDA_gemmex( cham_trans_t transa, cham_trans_t transb,
int m, int n, int k,
const void *alpha,
const void *A, int lda, cham_flttype_t Atype,
const void *B, int ldb, cham_flttype_t Btype,
const void *beta,
void *C, int ldc, cham_flttype_t Ctype,
cublasHandle_t handle );
/**
 * @brief Map a Chameleon floating-point type tag to a cuBLAS compute type.
 *
 * Real and complex variants of a given precision share the same compute
 * type. Any tag outside the five handled ones prints a message on stderr
 * and terminates the process with exit(1).
 *
 * @param[in] flttype  Chameleon type tag to translate.
 *
 * @return The matching CUBLAS_COMPUTE_* value.
 */
static inline cublasComputeType_t
chameleon_cublas_ctype( cham_flttype_t flttype ) {
    cublasComputeType_t ctype;

    switch ( flttype ) {
    case ChamRealHalf:
        ctype = CUBLAS_COMPUTE_16F;
        break;

    case ChamRealFloat:
    case ChamComplexFloat:
        ctype = CUBLAS_COMPUTE_32F;
        break;

    case ChamRealDouble:
    case ChamComplexDouble:
        ctype = CUBLAS_COMPUTE_64F;
        break;

    default:
        /* Unsupported tag: report and abort, nothing sensible can be returned. */
        fprintf( stderr, "chameleon_cublas_ctype(): Incorrect flttype\n" );
        exit(1);
    }

    return ctype;
}
/**
 * @brief Map a Chameleon floating-point type tag to a CUDA data type.
 *
 * Real tags map to CUDA_R_* values and complex tags to CUDA_C_* values.
 * Any tag outside the five handled ones prints a message on stderr and
 * terminates the process with exit(1).
 *
 * @param[in] flttype  Chameleon type tag to translate.
 *
 * @return The matching cudaDataType_t value.
 */
static inline cudaDataType_t
chameleon_cublas_dtype( cham_flttype_t flttype ) {
    cudaDataType_t dtype;

    switch ( flttype ) {
    /* Real types */
    case ChamRealHalf:
        dtype = CUDA_R_16F;
        break;
    case ChamRealFloat:
        dtype = CUDA_R_32F;
        break;
    case ChamRealDouble:
        dtype = CUDA_R_64F;
        break;

    /* Complex types */
    case ChamComplexFloat:
        dtype = CUDA_C_32F;
        break;
    case ChamComplexDouble:
        dtype = CUDA_C_64F;
        break;

    default:
        /* Unsupported tag: report and abort, nothing sensible can be returned. */
        fprintf( stderr, "chameleon_cublas_dtype(): Incorrect flttype\n" );
        exit(1);
    }

    return dtype;
}
END_C_DECLS
/**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment