Mentions légales du service

Skip to content
Snippets Groups Projects
Commit e229fd74 authored by Mathieu Faverge's avatar Mathieu Faverge
Browse files

Merge branch 'starpu/cublas_v2' into 'master'

Starpu/cublas v2

See merge request !36
parents f7e10e33 8617f58a
No related branches found
No related tags found
1 merge request!36Starpu/cublas v2
Pipeline #
Showing
with 225 additions and 113 deletions
...@@ -1003,6 +1003,12 @@ if( CHAMELEON_SCHED_QUARK ) ...@@ -1003,6 +1003,12 @@ if( CHAMELEON_SCHED_QUARK )
endif() endif()
# Add option to exploit cublas API v2
# -----------------------------------
cmake_dependent_option(CHAMELEON_USE_CUBLAS_V2
"Enable cublas API v2" ON
"CHAMELEON_USE_CUDA;CHAMELEON_SCHED_STARPU" OFF)
list(REMOVE_DUPLICATES CMAKE_EXE_LINKER_FLAGS) list(REMOVE_DUPLICATES CMAKE_EXE_LINKER_FLAGS)
string(REPLACE ";" " " CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS}") string(REPLACE ";" " " CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS}")
# Fix a problem on Mac OS X when building shared libraries # Fix a problem on Mac OS X when building shared libraries
......
...@@ -54,6 +54,7 @@ ...@@ -54,6 +54,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#if defined(CHAMELEON_USE_CUBLAS_V2) #if defined(CHAMELEON_USE_CUBLAS_V2)
#include <cublas.h>
#include <cublas_v2.h> #include <cublas_v2.h>
#else #else
#include <cublas.h> #include <cublas.h>
......
...@@ -65,12 +65,20 @@ if( CHAMELEON_USE_MAGMA ) ...@@ -65,12 +65,20 @@ if( CHAMELEON_USE_MAGMA )
) )
endif() endif()
precisions_rules_py(CUDABLAS_SRCS_GENERATED "${ZSRC}" precisions_rules_py(
PRECISIONS "${CHAMELEON_PRECISION}") CUDABLAS_SRCS_GENERATED "${ZSRC}"
PRECISIONS "${CHAMELEON_PRECISION}")
set(CUDABLAS_SRCS set(CUDABLAS_SRCS
${CUDABLAS_SRCS_GENERATED} ${CUDABLAS_SRCS_GENERATED}
)
if (CHAMELEON_USE_CUBLAS_V2)
set(CUDABLAS_SRCS
${CUDABLAS_SRCS}
cudaglobal.c
) )
endif (CHAMELEON_USE_CUBLAS_V2)
# Compile step # Compile step
# ------------ # ------------
......
...@@ -34,18 +34,13 @@ int CUDA_zgemm(MORSE_enum transa, MORSE_enum transb, ...@@ -34,18 +34,13 @@ int CUDA_zgemm(MORSE_enum transa, MORSE_enum transb,
cuDoubleComplex *C, int ldc, cuDoubleComplex *C, int ldc,
CUBLAS_STREAM_PARAM) CUBLAS_STREAM_PARAM)
{ {
#if !defined(CHAMELEON_USE_CUBLAS_V2)
cublasSetKernelStream( stream );
#endif
cublasZgemm(CUBLAS_HANDLE cublasZgemm(CUBLAS_HANDLE
morse_lapack_const(transa), morse_lapack_const(transb), morse_cublas_const(transa), morse_cublas_const(transb),
m, n, k, m, n, k,
CUBLAS_VALUE(alpha), A, lda, CUBLAS_VALUE(alpha), A, lda,
B, ldb, B, ldb,
CUBLAS_VALUE(beta), C, ldc); CUBLAS_VALUE(beta), C, ldc);
assert( CUBLAS_STATUS_SUCCESS == cublasGetError() ); assert( CUBLAS_STATUS_SUCCESS == cublasGetError() );
return MORSE_SUCCESS; return MORSE_SUCCESS;
} }
...@@ -34,18 +34,13 @@ int CUDA_zhemm(MORSE_enum side, MORSE_enum uplo, ...@@ -34,18 +34,13 @@ int CUDA_zhemm(MORSE_enum side, MORSE_enum uplo,
cuDoubleComplex *C, int ldc, cuDoubleComplex *C, int ldc,
CUBLAS_STREAM_PARAM) CUBLAS_STREAM_PARAM)
{ {
#if !defined(CHAMELEON_USE_CUBLAS_V2)
cublasSetKernelStream( stream );
#endif
cublasZhemm(CUBLAS_HANDLE cublasZhemm(CUBLAS_HANDLE
morse_lapack_const(side), morse_lapack_const(uplo), morse_cublas_const(side), morse_cublas_const(uplo),
m, n, m, n,
CUBLAS_VALUE(alpha), A, lda, CUBLAS_VALUE(alpha), A, lda,
B, ldb, B, ldb,
CUBLAS_VALUE(beta), C, ldc); CUBLAS_VALUE(beta), C, ldc);
assert( CUBLAS_STATUS_SUCCESS == cublasGetError() ); assert( CUBLAS_STATUS_SUCCESS == cublasGetError() );
return MORSE_SUCCESS; return MORSE_SUCCESS;
} }
...@@ -34,12 +34,8 @@ int CUDA_zher2k(MORSE_enum uplo, MORSE_enum trans, ...@@ -34,12 +34,8 @@ int CUDA_zher2k(MORSE_enum uplo, MORSE_enum trans,
cuDoubleComplex *C, int ldc, cuDoubleComplex *C, int ldc,
CUBLAS_STREAM_PARAM) CUBLAS_STREAM_PARAM)
{ {
#if !defined(CHAMELEON_USE_CUBLAS_V2)
cublasSetKernelStream( stream );
#endif
cublasZher2k(CUBLAS_HANDLE cublasZher2k(CUBLAS_HANDLE
morse_lapack_const(uplo), morse_lapack_const(trans), morse_cublas_const(uplo), morse_cublas_const(trans),
n, k, n, k,
CUBLAS_VALUE(alpha), A, lda, CUBLAS_VALUE(alpha), A, lda,
B, ldb, B, ldb,
......
...@@ -33,15 +33,11 @@ int CUDA_zherk( MORSE_enum uplo, MORSE_enum trans, ...@@ -33,15 +33,11 @@ int CUDA_zherk( MORSE_enum uplo, MORSE_enum trans,
cuDoubleComplex *B, int ldb, cuDoubleComplex *B, int ldb,
CUBLAS_STREAM_PARAM) CUBLAS_STREAM_PARAM)
{ {
#if !defined(CHAMELEON_USE_CUBLAS_V2) cublasZherk( CUBLAS_HANDLE
cublasSetKernelStream( stream ); morse_cublas_const(uplo), morse_cublas_const(trans),
#endif n, k,
CUBLAS_VALUE(alpha), A, lda,
cublasZherk( CUBLAS_VALUE(beta), B, ldb);
morse_lapack_const(uplo), morse_lapack_const(trans),
n, k,
*alpha, A, lda,
*beta, B, ldb);
assert( CUBLAS_STATUS_SUCCESS == cublasGetError() ); assert( CUBLAS_STATUS_SUCCESS == cublasGetError() );
......
...@@ -49,10 +49,6 @@ CUDA_zlarfb(MORSE_enum side, MORSE_enum trans, ...@@ -49,10 +49,6 @@ CUDA_zlarfb(MORSE_enum side, MORSE_enum trans,
MORSE_enum transT, uplo, notransV, transV; MORSE_enum transT, uplo, notransV, transV;
#if !defined(CHAMELEON_USE_CUBLAS_V2)
cublasSetKernelStream( stream );
#endif
/* Check input arguments */ /* Check input arguments */
if ((side != MorseLeft) && (side != MorseRight)) { if ((side != MorseLeft) && (side != MorseRight)) {
return -1; return -1;
...@@ -107,23 +103,22 @@ CUDA_zlarfb(MORSE_enum side, MORSE_enum trans, ...@@ -107,23 +103,22 @@ CUDA_zlarfb(MORSE_enum side, MORSE_enum trans,
// W = C^H V // W = C^H V
cublasZgemm( CUBLAS_HANDLE cublasZgemm( CUBLAS_HANDLE
morse_lapack_const(MorseConjTrans), morse_lapack_const(notransV), morse_cublas_const(MorseConjTrans), morse_cublas_const(notransV),
N, K, M, N, K, M,
CUBLAS_SADDR(zone), C, LDC, CUBLAS_SADDR(zone), C, LDC,
V, LDV, V, LDV,
CUBLAS_SADDR(zzero), WORK, LDWORK ); CUBLAS_SADDR(zzero), WORK, LDWORK );
// W = W T^H = C^H V T^H // W = W T^H = C^H V T^H
cublasZtrmm( CUBLAS_HANDLE CUDA_ztrmm( MorseRight, uplo, transT, MorseNonUnit,
morse_lapack_const(MorseRight), morse_lapack_const(uplo), N, K,
morse_lapack_const(transT), morse_lapack_const(MorseNonUnit), CUBLAS_SADDR(zone), T, LDT,
N, K, CUBLAS_SADDR(zone), WORK, LDWORK,
T, LDT, CUBLAS_STREAM_VALUE );
WORK, LDWORK);
// C = C - V W^H = C - V T V^H C = (I - V T V^H) C = H C // C = C - V W^H = C - V T V^H C = (I - V T V^H) C = H C
cublasZgemm( CUBLAS_HANDLE cublasZgemm( CUBLAS_HANDLE
morse_lapack_const(notransV), morse_lapack_const(MorseConjTrans), morse_cublas_const(notransV), morse_cublas_const(MorseConjTrans),
M, N, K, M, N, K,
CUBLAS_SADDR(mzone), V, LDV, CUBLAS_SADDR(mzone), V, LDV,
WORK, LDWORK, WORK, LDWORK,
...@@ -135,23 +130,22 @@ CUDA_zlarfb(MORSE_enum side, MORSE_enum trans, ...@@ -135,23 +130,22 @@ CUDA_zlarfb(MORSE_enum side, MORSE_enum trans,
// W = C V // W = C V
cublasZgemm( CUBLAS_HANDLE cublasZgemm( CUBLAS_HANDLE
morse_lapack_const(MorseNoTrans), morse_lapack_const(notransV), morse_cublas_const(MorseNoTrans), morse_cublas_const(notransV),
M, K, N, M, K, N,
CUBLAS_SADDR(zone), C, LDC, CUBLAS_SADDR(zone), C, LDC,
V, LDV, V, LDV,
CUBLAS_SADDR(zzero), WORK, LDWORK ); CUBLAS_SADDR(zzero), WORK, LDWORK );
// W = W T = C V T // W = W T = C V T
cublasZtrmm( CUBLAS_HANDLE CUDA_ztrmm( MorseRight, uplo, trans, MorseNonUnit,
morse_lapack_const(MorseRight), morse_lapack_const(uplo), M, K,
morse_lapack_const(trans), morse_lapack_const(MorseNonUnit), CUBLAS_SADDR(zone), T, LDT,
M, K, CUBLAS_SADDR(zone), WORK, LDWORK,
T, LDT, CUBLAS_STREAM_VALUE );
WORK, LDWORK);
// C = C - W V^H = C - C V T V^H = C (I - V T V^H) = C H // C = C - W V^H = C - C V T V^H = C (I - V T V^H) = C H
cublasZgemm( CUBLAS_HANDLE cublasZgemm( CUBLAS_HANDLE
morse_lapack_const(MorseNoTrans), morse_lapack_const(transV), morse_cublas_const(MorseNoTrans), morse_cublas_const(transV),
M, N, K, M, N, K,
CUBLAS_SADDR(mzone), WORK, LDWORK, CUBLAS_SADDR(mzone), WORK, LDWORK,
V, LDV, V, LDV,
......
...@@ -243,7 +243,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -243,7 +243,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
transA2 = storev == MorseColumnwise ? MorseNoTrans : MorseConjTrans; transA2 = storev == MorseColumnwise ? MorseNoTrans : MorseConjTrans;
cublasZgemm(CUBLAS_HANDLE cublasZgemm(CUBLAS_HANDLE
morse_lapack_const(transW), morse_lapack_const(MorseNoTrans), morse_cublas_const(transW), morse_cublas_const(MorseNoTrans),
K, N1, M2, K, N1, M2,
CUBLAS_SADDR(zone), CUBLAS_SADDR(zone),
V /* K*M2 */, LDV, V /* K*M2 */, LDV,
...@@ -253,14 +253,11 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -253,14 +253,11 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
if (WORKC == NULL) { if (WORKC == NULL) {
/* W = op(T) * W */ /* W = op(T) * W */
cublasZtrmm( CUBLAS_HANDLE CUDA_ztrmm( MorseLeft, MorseUpper, trans, MorseNonUnit,
morse_lapack_const(MorseLeft), morse_lapack_const(MorseUpper), K, N2,
morse_lapack_const(trans), morse_lapack_const(MorseNonUnit), CUBLAS_SADDR(zone), T, LDT,
K, N2, WORK, LDWORK,
CUBLAS_SADDR(zone), CUBLAS_STREAM_VALUE );
T, LDT,
WORK, LDWORK);
/* A1 = A1 - W = A1 - op(T) * W */ /* A1 = A1 - W = A1 - op(T) * W */
for(j = 0; j < N1; j++) { for(j = 0; j < N1; j++) {
...@@ -272,7 +269,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -272,7 +269,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
/* A2 = A2 - op(V) * W */ /* A2 = A2 - op(V) * W */
cublasZgemm(CUBLAS_HANDLE cublasZgemm(CUBLAS_HANDLE
morse_lapack_const(transA2), morse_lapack_const(MorseNoTrans), morse_cublas_const(transA2), morse_cublas_const(MorseNoTrans),
M2, N2, K, M2, N2, K,
CUBLAS_SADDR(mzone), V /* M2*K */, LDV, CUBLAS_SADDR(mzone), V /* M2*K */, LDV,
WORK /* K*N2 */, LDWORK, WORK /* K*N2 */, LDWORK,
...@@ -281,7 +278,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -281,7 +278,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
} else { } else {
/* Wc = V * op(T) */ /* Wc = V * op(T) */
cublasZgemm( CUBLAS_HANDLE cublasZgemm( CUBLAS_HANDLE
morse_lapack_const(transA2), morse_lapack_const(trans), morse_cublas_const(transA2), morse_cublas_const(trans),
M2, K, K, M2, K, K,
CUBLAS_SADDR(zone), V, LDV, CUBLAS_SADDR(zone), V, LDV,
T, LDT, T, LDT,
...@@ -289,7 +286,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -289,7 +286,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
/* A1 = A1 - opt(T) * W */ /* A1 = A1 - opt(T) * W */
cublasZgemm( CUBLAS_HANDLE cublasZgemm( CUBLAS_HANDLE
morse_lapack_const(trans), morse_lapack_const(MorseNoTrans), morse_cublas_const(trans), morse_cublas_const(MorseNoTrans),
K, N1, K, K, N1, K,
CUBLAS_SADDR(mzone), T, LDT, CUBLAS_SADDR(mzone), T, LDT,
WORK, LDWORK, WORK, LDWORK,
...@@ -297,7 +294,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -297,7 +294,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
/* A2 = A2 - Wc * W */ /* A2 = A2 - Wc * W */
cublasZgemm( CUBLAS_HANDLE cublasZgemm( CUBLAS_HANDLE
morse_lapack_const(MorseNoTrans), morse_lapack_const(MorseNoTrans), morse_cublas_const(MorseNoTrans), morse_cublas_const(MorseNoTrans),
M2, N2, K, M2, N2, K,
CUBLAS_SADDR(mzone), WORKC, LDWORKC, CUBLAS_SADDR(mzone), WORKC, LDWORKC,
WORK, LDWORK, WORK, LDWORK,
...@@ -328,7 +325,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -328,7 +325,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
transA2 = storev == MorseColumnwise ? MorseConjTrans : MorseNoTrans; transA2 = storev == MorseColumnwise ? MorseConjTrans : MorseNoTrans;
cublasZgemm(CUBLAS_HANDLE cublasZgemm(CUBLAS_HANDLE
morse_lapack_const(MorseNoTrans), morse_lapack_const(transW), morse_cublas_const(MorseNoTrans), morse_cublas_const(transW),
M1, K, N2, M1, K, N2,
CUBLAS_SADDR(zone), A2 /* M1*N2 */, LDA2, CUBLAS_SADDR(zone), A2 /* M1*N2 */, LDA2,
V /* N2*K */, LDV, V /* N2*K */, LDV,
...@@ -336,14 +333,11 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -336,14 +333,11 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
if (WORKC == NULL) { if (WORKC == NULL) {
/* W = W * op(T) */ /* W = W * op(T) */
cublasZtrmm( CUBLAS_HANDLE CUDA_ztrmm( MorseRight, MorseUpper, trans, MorseNonUnit,
morse_lapack_const(MorseRight), morse_lapack_const(MorseUpper), M2, K,
morse_lapack_const(trans), morse_lapack_const(MorseNonUnit), CUBLAS_SADDR(zone), T, LDT,
M2, K, WORK, LDWORK,
CUBLAS_SADDR(zone), CUBLAS_STREAM_VALUE );
T, LDT,
WORK, LDWORK);
/* A1 = A1 - W = A1 - W * op(T) */ /* A1 = A1 - W = A1 - W * op(T) */
for(j = 0; j < K; j++) { for(j = 0; j < K; j++) {
...@@ -355,7 +349,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -355,7 +349,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
/* A2 = A2 - W * op(V) */ /* A2 = A2 - W * op(V) */
cublasZgemm(CUBLAS_HANDLE cublasZgemm(CUBLAS_HANDLE
morse_lapack_const(MorseNoTrans), morse_lapack_const(transA2), morse_cublas_const(MorseNoTrans), morse_cublas_const(transA2),
M2, N2, K, M2, N2, K,
CUBLAS_SADDR(mzone), WORK /* M2*K */, LDWORK, CUBLAS_SADDR(mzone), WORK /* M2*K */, LDWORK,
V /* K*N2 */, LDV, V /* K*N2 */, LDV,
...@@ -364,7 +358,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -364,7 +358,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
} else { } else {
/* A1 = A1 - W * opt(T) */ /* A1 = A1 - W * opt(T) */
cublasZgemm( CUBLAS_HANDLE cublasZgemm( CUBLAS_HANDLE
morse_lapack_const(MorseNoTrans), morse_lapack_const(trans), morse_cublas_const(MorseNoTrans), morse_cublas_const(trans),
M1, K, K, M1, K, K,
CUBLAS_SADDR(mzone), WORK, LDWORK, CUBLAS_SADDR(mzone), WORK, LDWORK,
T, LDT, T, LDT,
...@@ -372,7 +366,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -372,7 +366,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
/* Wc = op(T) * V */ /* Wc = op(T) * V */
cublasZgemm( CUBLAS_HANDLE cublasZgemm( CUBLAS_HANDLE
morse_lapack_const(trans), morse_lapack_const(transA2), morse_cublas_const(trans), morse_cublas_const(transA2),
K, N2, K, K, N2, K,
CUBLAS_SADDR(zone), T, LDT, CUBLAS_SADDR(zone), T, LDT,
V, LDV, V, LDV,
...@@ -380,7 +374,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans, ...@@ -380,7 +374,7 @@ CUDA_zparfb(MORSE_enum side, MORSE_enum trans,
/* A2 = A2 - W * Wc */ /* A2 = A2 - W * Wc */
cublasZgemm( CUBLAS_HANDLE cublasZgemm( CUBLAS_HANDLE
morse_lapack_const(MorseNoTrans), morse_lapack_const(MorseNoTrans), morse_cublas_const(MorseNoTrans), morse_cublas_const(MorseNoTrans),
M2, N2, K, M2, N2, K,
CUBLAS_SADDR(mzone), WORK, LDWORK, CUBLAS_SADDR(mzone), WORK, LDWORK,
WORKC, LDWORKC, WORKC, LDWORKC,
......
...@@ -34,12 +34,8 @@ int CUDA_zsymm(MORSE_enum side, MORSE_enum uplo, ...@@ -34,12 +34,8 @@ int CUDA_zsymm(MORSE_enum side, MORSE_enum uplo,
cuDoubleComplex *C, int ldc, cuDoubleComplex *C, int ldc,
CUBLAS_STREAM_PARAM) CUBLAS_STREAM_PARAM)
{ {
#if !defined(CHAMELEON_USE_CUBLAS_V2)
cublasSetKernelStream( stream );
#endif
cublasZsymm(CUBLAS_HANDLE cublasZsymm(CUBLAS_HANDLE
morse_lapack_const(side), morse_lapack_const(uplo), morse_cublas_const(side), morse_cublas_const(uplo),
m, n, m, n,
CUBLAS_VALUE(alpha), A, lda, CUBLAS_VALUE(alpha), A, lda,
B, ldb, B, ldb,
......
...@@ -35,12 +35,8 @@ int CUDA_zsyr2k( ...@@ -35,12 +35,8 @@ int CUDA_zsyr2k(
cuDoubleComplex *C, int ldc, cuDoubleComplex *C, int ldc,
CUBLAS_STREAM_PARAM) CUBLAS_STREAM_PARAM)
{ {
#if !defined(CHAMELEON_USE_CUBLAS_V2)
cublasSetKernelStream( stream );
#endif
cublasZsyr2k(CUBLAS_HANDLE cublasZsyr2k(CUBLAS_HANDLE
morse_lapack_const(uplo), morse_lapack_const(trans), morse_cublas_const(uplo), morse_cublas_const(trans),
n, k, n, k,
CUBLAS_VALUE(alpha), A, lda, CUBLAS_VALUE(alpha), A, lda,
B, ldb, B, ldb,
......
...@@ -33,12 +33,8 @@ int CUDA_zsyrk(MORSE_enum uplo, MORSE_enum trans, ...@@ -33,12 +33,8 @@ int CUDA_zsyrk(MORSE_enum uplo, MORSE_enum trans,
cuDoubleComplex *C, int ldc, cuDoubleComplex *C, int ldc,
CUBLAS_STREAM_PARAM) CUBLAS_STREAM_PARAM)
{ {
#if !defined(CHAMELEON_USE_CUBLAS_V2)
cublasSetKernelStream( stream );
#endif
cublasZsyrk(CUBLAS_HANDLE cublasZsyrk(CUBLAS_HANDLE
morse_lapack_const(uplo), morse_lapack_const(trans), morse_cublas_const(uplo), morse_cublas_const(trans),
n, k, n, k,
CUBLAS_VALUE(alpha), A, lda, CUBLAS_VALUE(alpha), A, lda,
CUBLAS_VALUE(beta), C, ldc); CUBLAS_VALUE(beta), C, ldc);
......
...@@ -34,17 +34,29 @@ int CUDA_ztrmm( ...@@ -34,17 +34,29 @@ int CUDA_ztrmm(
cuDoubleComplex *B, int ldb, cuDoubleComplex *B, int ldb,
CUBLAS_STREAM_PARAM) CUBLAS_STREAM_PARAM)
{ {
#if !defined(CHAMELEON_USE_CUBLAS_V2)
cublasSetKernelStream( stream );
#endif
cublasZtrmm(CUBLAS_HANDLE #if defined(CHAMELEON_USE_CUBLAS_V2)
morse_lapack_const(side), morse_lapack_const(uplo),
morse_lapack_const(transa), morse_lapack_const(diag), cublasZtrmm(
CUBLAS_HANDLE
morse_cublas_const(side), morse_cublas_const(uplo),
morse_cublas_const(transa), morse_cublas_const(diag),
m, n, m, n,
CUBLAS_VALUE(alpha), A, lda, CUBLAS_VALUE(alpha), A, lda,
B, ldb,
B, ldb); B, ldb);
#else
cublasZtrmm(
CUBLAS_HANDLE
morse_cublas_const(side), morse_cublas_const(uplo),
morse_cublas_const(transa), morse_cublas_const(diag),
m, n,
CUBLAS_VALUE(alpha), A, lda,
B, ldb);
#endif
assert( CUBLAS_STATUS_SUCCESS == cublasGetError() ); assert( CUBLAS_STATUS_SUCCESS == cublasGetError() );
return MORSE_SUCCESS; return MORSE_SUCCESS;
......
...@@ -33,13 +33,9 @@ int CUDA_ztrsm(MORSE_enum side, MORSE_enum uplo, ...@@ -33,13 +33,9 @@ int CUDA_ztrsm(MORSE_enum side, MORSE_enum uplo,
cuDoubleComplex *B, int ldb, cuDoubleComplex *B, int ldb,
CUBLAS_STREAM_PARAM) CUBLAS_STREAM_PARAM)
{ {
#if !defined(CHAMELEON_USE_CUBLAS_V2)
cublasSetKernelStream( stream );
#endif
cublasZtrsm(CUBLAS_HANDLE cublasZtrsm(CUBLAS_HANDLE
morse_lapack_const(side), morse_lapack_const(uplo), morse_cublas_const(side), morse_cublas_const(uplo),
morse_lapack_const(transa), morse_lapack_const(diag), morse_cublas_const(transa), morse_cublas_const(diag),
m, n, m, n,
CUBLAS_VALUE(alpha), A, lda, CUBLAS_VALUE(alpha), A, lda,
B, ldb); B, ldb);
......
/**
*
* @copyright (c) 2009-2014 The University of Tennessee and The University of
* Tennessee Research Foundation. All rights reserved.
* @copyright (c) 2012-2017 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
* Univ. Bordeaux. All rights reserved.
*
**/
/**
*
* @file cudaglobal.c
*
* MORSE auxiliary routines
* MORSE is a software package provided by Univ. of Tennessee,
* Univ. of California Berkeley and Univ. of Colorado Denver
*
* @version 0.9.0
* @author Mathieu Faverge
* @date 2017-04-06
*
**/
#include "cudablas/include/cudablas.h"
/*******************************************************************************
 * CUBLAS constants
 *
 * Positional lookup table: the entry at index i holds the cuBLAS enum value
 * (operation, fill mode, diagonal type or side mode) associated with the
 * MORSE constant whose numeric value is i. The trailing comment on each
 * named row gives that index and the matching MORSE constant.
 *
 * Only the trans (111-113), uplo (121-122), diag (131-132) and side (141-142)
 * constants are given a cuBLAS value here; every other slot is left at 0.
 *
 * The table is positional, so the number of entries on each row matters:
 * adding or removing a single value shifts every mapping that follows it.
 **/
int morse_cublas_constants[] =
{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, // 100
0, // 101: MorseRowMajor
0, // 102: MorseColMajor
0, 0, 0, 0, 0, 0, 0, 0,
CUBLAS_OP_N, // 111: MorseNoTrans
CUBLAS_OP_T, // 112: MorseTrans
CUBLAS_OP_C, // 113: MorseConjTrans
0, 0, 0, 0, 0, 0, 0,
CUBLAS_FILL_MODE_UPPER, // 121: MorseUpper
CUBLAS_FILL_MODE_LOWER, // 122: MorseLower
0, // 123: MorseUpperLower
0, 0, 0, 0, 0, 0, 0,
CUBLAS_DIAG_NON_UNIT, // 131: MorseNonUnit
CUBLAS_DIAG_UNIT, // 132: MorseUnit
0, 0, 0, 0, 0, 0, 0, 0,
CUBLAS_SIDE_LEFT, // 141: MorseLeft
CUBLAS_SIDE_RIGHT, // 142: MorseRight
0, 0, 0, 0, 0, 0, 0, 0,
0, // 151:
0, // 152:
0, // 153:
0, // 154:
0, // 155:
0, // 156:
0, // 157: MorseEps
0, // 158:
0, // 159:
0, // 160:
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, // 171: MorseOneNorm
0, // 172: MorseRealOneNorm
0, // 173: MorseTwoNorm
0, // 174: MorseFrobeniusNorm
0, // 175: MorseInfNorm
0, // 176: MorseRealInfNorm
0, // 177: MorseMaxNorm
0, // 178: MorseRealMaxNorm
0, // 179
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, // 200
0, // 201: MorseDistUniform
0, // 202: MorseDistSymmetric
0, // 203: MorseDistNormal
0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, // 240
0, // 241 MorseHermGeev
0, // 242 MorseHermPoev
0, // 243 MorseNonsymPosv
0, // 244 MorseSymPosv
0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, // 290
0, // 291 MorseNoPacking
0, // 292 MorsePackSubdiag
0, // 293 MorsePackSupdiag
0, // 294 MorsePackColumn
0, // 295 MorsePackRow
0, // 296 MorsePackLowerBand
0, // 297 MorsePackUpeprBand
0, // 298 MorsePackAll
0, // 299
0, // 300
0, // 301 MorseNoVec
0, // 302 MorseVec
0, // 303 MorseIvec
0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, // 390
0, // 391
0, // 392
0, 0, 0, 0, 0, 0, 0, 0,
0, // 401
0, // 402
0, 0, 0, 0, 0, 0, 0, 0 // 403-410; remember to add a comma here when appending entries!
};
...@@ -41,7 +41,9 @@ ...@@ -41,7 +41,9 @@
#if defined(CHAMELEON_USE_CUBLAS_V2) #if defined(CHAMELEON_USE_CUBLAS_V2)
#include <cublas.h>
#include <cublas_v2.h> #include <cublas_v2.h>
#define CUBLAS_STREAM_PARAM cublasHandle_t handle #define CUBLAS_STREAM_PARAM cublasHandle_t handle
#define CUBLAS_STREAM_VALUE handle #define CUBLAS_STREAM_VALUE handle
#define CUBLAS_HANDLE handle, #define CUBLAS_HANDLE handle,
...@@ -96,4 +98,12 @@ ...@@ -96,4 +98,12 @@
extern char *morse_lapack_constants[]; extern char *morse_lapack_constants[];
#define morse_lapack_const(morse_const) morse_lapack_constants[morse_const][0] #define morse_lapack_const(morse_const) morse_lapack_constants[morse_const][0]
extern int morse_cublas_constants[];
#if defined(CHAMELEON_USE_CUBLAS_V2)
#define morse_cublas_const(morse_const) morse_cublas_constants[morse_const]
#else
#define morse_cublas_const(morse_const) morse_lapack_constants[morse_const][0]
#endif
#endif #endif
...@@ -166,7 +166,6 @@ static void cl_zgelqt_cuda_func(void *descr[], void *cl_arg) ...@@ -166,7 +166,6 @@ static void cl_zgelqt_cuda_func(void *descr[], void *cl_arg)
cuDoubleComplex *h_A, *h_T, *h_D, *h_W, *h_TAU; cuDoubleComplex *h_A, *h_T, *h_D, *h_W, *h_TAU;
cuDoubleComplex *d_A, *d_T, *d_D, *d_W; cuDoubleComplex *d_A, *d_T, *d_D, *d_W;
int lda, ldt; int lda, ldt;
CUstream stream;
starpu_codelet_unpack_args(cl_arg, &m, &n, &ib, &lda, &ldt, &h_work); starpu_codelet_unpack_args(cl_arg, &m, &n, &ib, &lda, &ldt, &h_work);
...@@ -186,15 +185,14 @@ static void cl_zgelqt_cuda_func(void *descr[], void *cl_arg) ...@@ -186,15 +185,14 @@ static void cl_zgelqt_cuda_func(void *descr[], void *cl_arg)
h_W = h_TAU + chameleon_max(m,n); h_W = h_TAU + chameleon_max(m,n);
h_D = h_W + ib*ib; h_D = h_W + ib*ib;
stream = starpu_cuda_get_local_stream(); RUNTIME_getStream(stream);
cublasSetKernelStream( stream );
CUDA_zgelqt( CUDA_zgelqt(
m, n, ib, m, n, ib,
d_A, lda, h_A, ib, d_A, lda, h_A, ib,
d_T, ldt, h_T, ib, d_T, ldt, h_T, ib,
d_D, h_D, ib, h_TAU, d_D, h_D, ib, h_TAU,
h_W, d_W, stream); h_W, d_W, stream );
cudaThreadSynchronize(); cudaThreadSynchronize();
} }
......
...@@ -148,14 +148,13 @@ static void cl_zgemm_cuda_func(void *descr[], void *cl_arg) ...@@ -148,14 +148,13 @@ static void cl_zgemm_cuda_func(void *descr[], void *cl_arg)
cuDoubleComplex beta; cuDoubleComplex beta;
cuDoubleComplex *C; cuDoubleComplex *C;
int ldc; int ldc;
CUstream stream;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]); A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
B = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]); B = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]); C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]);
starpu_codelet_unpack_args(cl_arg, &transA, &transB, &m, &n, &k, &alpha, &lda, &ldb, &beta, &ldc); starpu_codelet_unpack_args(cl_arg, &transA, &transB, &m, &n, &k, &alpha, &lda, &ldb, &beta, &ldc);
stream = starpu_cuda_get_local_stream(); RUNTIME_getStream( stream );
CUDA_zgemm( CUDA_zgemm(
transA, transB, transA, transB,
......
...@@ -166,7 +166,6 @@ static void cl_zgeqrt_cuda_func(void *descr[], void *cl_arg) ...@@ -166,7 +166,6 @@ static void cl_zgeqrt_cuda_func(void *descr[], void *cl_arg)
cuDoubleComplex *h_A, *h_T, *h_D, *h_W, *h_TAU; cuDoubleComplex *h_A, *h_T, *h_D, *h_W, *h_TAU;
cuDoubleComplex *d_A, *d_T, *d_D, *d_W; cuDoubleComplex *d_A, *d_T, *d_D, *d_W;
int lda, ldt; int lda, ldt;
CUstream stream;
starpu_codelet_unpack_args(cl_arg, &m, &n, &ib, &lda, &ldt, &h_work); starpu_codelet_unpack_args(cl_arg, &m, &n, &ib, &lda, &ldt, &h_work);
...@@ -186,8 +185,7 @@ static void cl_zgeqrt_cuda_func(void *descr[], void *cl_arg) ...@@ -186,8 +185,7 @@ static void cl_zgeqrt_cuda_func(void *descr[], void *cl_arg)
h_W = h_TAU + chameleon_max(m,n); h_W = h_TAU + chameleon_max(m,n);
h_D = h_W + ib*ib; h_D = h_W + ib*ib;
stream = starpu_cuda_get_local_stream(); RUNTIME_getStream(stream);
cublasSetKernelStream( stream );
CUDA_zgeqrt( CUDA_zgeqrt(
m, n, ib, m, n, ib,
......
...@@ -119,14 +119,13 @@ static void cl_zhemm_cuda_func(void *descr[], void *cl_arg) ...@@ -119,14 +119,13 @@ static void cl_zhemm_cuda_func(void *descr[], void *cl_arg)
cuDoubleComplex beta; cuDoubleComplex beta;
cuDoubleComplex *C; cuDoubleComplex *C;
int LDC; int LDC;
CUstream stream;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]); A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
B = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]); B = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]); C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]);
starpu_codelet_unpack_args(cl_arg, &side, &uplo, &M, &N, &alpha, &LDA, &LDB, &beta, &LDC); starpu_codelet_unpack_args(cl_arg, &side, &uplo, &M, &N, &alpha, &LDA, &LDB, &beta, &LDC);
stream = starpu_cuda_get_local_stream(); RUNTIME_getStream(stream);
CUDA_zhemm( CUDA_zhemm(
side, uplo, side, uplo,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment