Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 73117bfb authored by PRUVOST Florent's avatar PRUVOST Florent
Browse files

update to cublas interface v2

parent 1b4504ef
No related branches found
No related tags found
No related merge requests found
...@@ -50,13 +50,12 @@ ...@@ -50,13 +50,12 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
//#if defined(CHAMELEON_USE_CUBLAS_V2) #if defined(CHAMELEON_USE_CUBLAS_V2)
//#include <cublas_v2.h> #include <cublas_v2.h>
//#else #else
//#include <cublas.h>
//#endif
#include <cublas.h> #include <cublas.h>
#endif #endif
#endif
#if defined(CHAMELEON_USE_OPENCL) #if defined(CHAMELEON_USE_OPENCL)
#include <OpenCL/cl.h> #include <OpenCL/cl.h>
......
...@@ -105,6 +105,84 @@ static void cl_zgemm_cpu_func(void *descr[], void *cl_arg) ...@@ -105,6 +105,84 @@ static void cl_zgemm_cpu_func(void *descr[], void *cl_arg)
} }
#ifdef CHAMELEON_USE_CUDA #ifdef CHAMELEON_USE_CUDA
#if defined(CHAMELEON_USE_CUBLAS_V2)
static void cl_zgemm_cuda_func(void *descr[], void *cl_arg)
{
MORSE_enum transA;
MORSE_enum transB;
int m;
int n;
int k;
cuDoubleComplex alpha;
const cuDoubleComplex *A;
int lda;
const cuDoubleComplex *B;
int ldb;
cuDoubleComplex beta;
cuDoubleComplex *C;
int ldc;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
B = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]);
starpu_codelet_unpack_args(cl_arg, &transA, &transB, &m, &n, &k, &alpha, &lda, &ldb, &beta, &ldc);
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
CUstream stream = starpu_cuda_get_local_stream();
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasOperation_t cublasTransA;
if (transA == MorseNoTrans){
cublasTransA = CUBLAS_OP_N;
}else if(transA == MorseTrans){
cublasTransA = CUBLAS_OP_T;
}else if(transA == MorseConjTrans){
cublasTransA = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in cl_zgemm_cuda_func: bad transA parameter %d\n", transA);
}
cublasOperation_t cublasTransB;
if (transB == MorseNoTrans){
cublasTransB = CUBLAS_OP_N;
}else if(transB == MorseTrans){
cublasTransB = CUBLAS_OP_T;
}else if(transB == MorseConjTrans){
cublasTransB = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in cl_zgemm_cuda_func: bad transB parameter %d\n", transB);
}
stat = cublasZgemm(handle,
cublasTransA, cublasTransB,
m, n, k,
(const cuDoubleComplex *) &alpha, A, lda,
B, ldb,
(const cuDoubleComplex *) &beta, C, ldc);
if (stat != CUBLAS_STATUS_SUCCESS){
printf ("cublasZgemm failed");
cublasDestroy(handle);
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasDestroy(handle);
#ifndef STARPU_CUDA_ASYNC
cudaStreamSynchronize( stream );
#endif
return;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
static void cl_zgemm_cuda_func(void *descr[], void *cl_arg) static void cl_zgemm_cuda_func(void *descr[], void *cl_arg)
{ {
MORSE_enum transA; MORSE_enum transA;
...@@ -135,6 +213,7 @@ static void cl_zgemm_cuda_func(void *descr[], void *cl_arg) ...@@ -135,6 +213,7 @@ static void cl_zgemm_cuda_func(void *descr[], void *cl_arg)
alpha, A, lda, alpha, A, lda,
B, ldb, B, ldb,
beta, C, ldc); beta, C, ldc);
assert( CUBLAS_STATUS_SUCCESS == cublasGetError() ); assert( CUBLAS_STATUS_SUCCESS == cublasGetError() );
#ifndef STARPU_CUDA_ASYNC #ifndef STARPU_CUDA_ASYNC
...@@ -143,7 +222,8 @@ static void cl_zgemm_cuda_func(void *descr[], void *cl_arg) ...@@ -143,7 +222,8 @@ static void cl_zgemm_cuda_func(void *descr[], void *cl_arg)
return; return;
} }
#endif #endif /* CHAMELEON_USE_CUBLAS_V2 */
#endif /* CHAMELEON_USE_CUDA */
/* /*
* Codelet definition * Codelet definition
......
...@@ -153,6 +153,61 @@ static void cl_zgeqrt_cpu_func(void *descr[], void *cl_arg) ...@@ -153,6 +153,61 @@ static void cl_zgeqrt_cpu_func(void *descr[], void *cl_arg)
#if defined(CHAMELEON_USE_MAGMA) #if defined(CHAMELEON_USE_MAGMA)
#if defined(CHAMELEON_USE_CUBLAS_V2)
magma_int_t
magma_zgemerge_gpu(magma_side_t side, magma_diag_t diag,
magma_int_t M, magma_int_t N,
magmaDoubleComplex *A, magma_int_t LDA,
magmaDoubleComplex *B, magma_int_t LDB)
{
int i, j;
magmaDoubleComplex *cola, *colb;
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
CUstream stream = starpu_cuda_get_local_stream();
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
if (M < 0) {
return -1;
}
if (N < 0) {
return -2;
}
if ( (LDA < max(1,M)) && (M > 0) ) {
return -5;
}
if ( (LDB < max(1,M)) && (M > 0) ) {
return -7;
}
if (side == MagmaLeft){
for(i=0; i<N; i++){
cola = A + i*LDA;
colb = B + i*LDB;
cublasZcopy(handle, i+1, cola, 1, colb, 1);
}
}else{
for(i=0; i<N; i++){
cola = A + i*LDA;
colb = B + i*LDB;
cublasZcopy(handle, M-i, cola + i, 1, colb + i, 1);
}
}
cublasDestroy(handle);
return MAGMA_SUCCESS;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
magma_int_t magma_int_t
magma_zgemerge_gpu(magma_side_t side, magma_diag_t diag, magma_zgemerge_gpu(magma_side_t side, magma_diag_t diag,
magma_int_t M, magma_int_t N, magma_int_t M, magma_int_t N,
...@@ -191,7 +246,7 @@ magma_zgemerge_gpu(magma_side_t side, magma_diag_t diag, ...@@ -191,7 +246,7 @@ magma_zgemerge_gpu(magma_side_t side, magma_diag_t diag,
return MAGMA_SUCCESS; return MAGMA_SUCCESS;
} }
#endif /* CHAMELEON_USE_CUBLAS_V2 */
magma_int_t magma_int_t
magma_zgeqrt_gpu( magma_int_t m, magma_int_t n, magma_int_t nb, magma_zgeqrt_gpu( magma_int_t m, magma_int_t n, magma_int_t nb,
......
...@@ -102,6 +102,81 @@ static void cl_zhemm_cpu_func(void *descr[], void *cl_arg) ...@@ -102,6 +102,81 @@ static void cl_zhemm_cpu_func(void *descr[], void *cl_arg)
} }
#ifdef CHAMELEON_USE_CUDA #ifdef CHAMELEON_USE_CUDA
#if defined(CHAMELEON_USE_CUBLAS_V2)
static void cl_zhemm_cuda_func(void *descr[], void *cl_arg)
{
MORSE_enum side;
MORSE_enum uplo;
int M;
int N;
cuDoubleComplex alpha;
const cuDoubleComplex *A;
int LDA;
const cuDoubleComplex *B;
int LDB;
cuDoubleComplex beta;
cuDoubleComplex *C;
int LDC;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
B = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]);
starpu_codelet_unpack_args(cl_arg, &side, &uplo, &M, &N, &alpha, &LDA, &LDB, &beta, &LDC);
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
CUstream stream = starpu_cuda_get_local_stream();
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasSideMode_t cublasSide;
if (side == MorseLeft){
cublasSide = CUBLAS_SIDE_LEFT;
}else if (side == MorseRight){
cublasSide = CUBLAS_SIDE_RIGHT;
}else{
fprintf(stderr, "Error in cl_zhemm_cuda_func: bad side parameter %d\n", side);
}
cublasFillMode_t cublasUplo;
if (uplo == MorseUpper){
cublasUplo = CUBLAS_FILL_MODE_UPPER;
}else if(uplo == MorseLower){
cublasUplo = CUBLAS_FILL_MODE_LOWER;
}else if(uplo == MorseUpperLower){
cublasUplo = 0;
}else{
fprintf(stderr, "Error in cl_zhemm_cuda_func: bad uplo parameter %d\n", uplo);
}
stat = cublasZhemm(handle,
cublasSide, cublasUplo,
M, N,
(const cuDoubleComplex *) &alpha, A, LDA,
B, LDB,
(const cuDoubleComplex *) &beta, C, LDC);
if (stat != CUBLAS_STATUS_SUCCESS){
printf ("cublasZhemm failed");
cublasDestroy(handle);
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasDestroy(handle);
#ifndef STARPU_CUDA_ASYNC
cudaStreamSynchronize( stream );
#endif
return;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
static void cl_zhemm_cuda_func(void *descr[], void *cl_arg) static void cl_zhemm_cuda_func(void *descr[], void *cl_arg)
{ {
MORSE_enum side; MORSE_enum side;
...@@ -138,7 +213,8 @@ static void cl_zhemm_cuda_func(void *descr[], void *cl_arg) ...@@ -138,7 +213,8 @@ static void cl_zhemm_cuda_func(void *descr[], void *cl_arg)
return; return;
} }
#endif #endif /* CHAMELEON_USE_CUBLAS_V2 */
#endif /* CHAMELEON_USE_CUDA */
/* /*
* Codelet definition * Codelet definition
......
...@@ -97,6 +97,81 @@ static void cl_zher2k_cpu_func(void *descr[], void *cl_arg) ...@@ -97,6 +97,81 @@ static void cl_zher2k_cpu_func(void *descr[], void *cl_arg)
} }
#ifdef CHAMELEON_USE_CUDA #ifdef CHAMELEON_USE_CUDA
#if defined(CHAMELEON_USE_CUBLAS_V2)
static void cl_zher2k_cuda_func(void *descr[], void *cl_arg)
{
MORSE_enum uplo;
MORSE_enum trans;
int n;
int k;
cuDoubleComplex alpha;
const cuDoubleComplex *A;
int lda;
const cuDoubleComplex *B;
int ldb;
double beta;
cuDoubleComplex *C;
int ldc;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
B = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]);
starpu_codelet_unpack_args(cl_arg, &uplo, &trans, &n, &k, &alpha, &lda, &ldb, &beta, &ldc);
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
CUstream stream = starpu_cuda_get_local_stream();
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasFillMode_t cublasUplo;
if (uplo == MorseUpper){
cublasUplo = CUBLAS_FILL_MODE_UPPER;
}else if(uplo == MorseLower){
cublasUplo = CUBLAS_FILL_MODE_LOWER;
}else if(uplo == MorseUpperLower){
cublasUplo = 0;
}else{
fprintf(stderr, "Error in cl_zher2k_cuda_func: bad uplo parameter %d\n", uplo);
}
cublasOperation_t cublasTrans;
if (trans == MorseNoTrans){
cublasTrans = CUBLAS_OP_N;
}else if(trans == MorseTrans){
cublasTrans = CUBLAS_OP_T;
}else if(trans == MorseConjTrans){
cublasTrans = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in cl_zher2k_cuda_func: bad trans parameter %d\n", trans);
}
stat = cublasZher2k( handle, cublasUplo, cublasTrans,
n, k, (const cuDoubleComplex *) &alpha, A, lda, B, ldb,
(const double *) &beta, C, ldc);
if (stat != CUBLAS_STATUS_SUCCESS){
printf ("cublasZher2k failed");
cublasDestroy(handle);
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasDestroy(handle);
#ifndef STARPU_CUDA_ASYNC
cudaStreamSynchronize( stream );
#endif
return;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
static void cl_zher2k_cuda_func(void *descr[], void *cl_arg) static void cl_zher2k_cuda_func(void *descr[], void *cl_arg)
{ {
MORSE_enum uplo; MORSE_enum uplo;
...@@ -129,7 +204,8 @@ static void cl_zher2k_cuda_func(void *descr[], void *cl_arg) ...@@ -129,7 +204,8 @@ static void cl_zher2k_cuda_func(void *descr[], void *cl_arg)
return; return;
} }
#endif #endif /* CHAMELEON_USE_CUBLAS_V2 */
#endif /* CHAMELEON_USE_CUDA */
/* /*
* Codelet definition * Codelet definition
......
...@@ -93,6 +93,80 @@ static void cl_zherk_cpu_func(void *descr[], void *cl_arg) ...@@ -93,6 +93,80 @@ static void cl_zherk_cpu_func(void *descr[], void *cl_arg)
} }
#ifdef CHAMELEON_USE_CUDA #ifdef CHAMELEON_USE_CUDA
#if defined(CHAMELEON_USE_CUBLAS_V2)
static void cl_zherk_cuda_func(void *descr[], void *cl_arg)
{
MORSE_enum uplo;
MORSE_enum trans;
int n;
int k;
double alpha;
const cuDoubleComplex *A;
int lda;
double beta;
cuDoubleComplex *C;
int ldc;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
starpu_codelet_unpack_args(cl_arg, &uplo, &trans, &n, &k, &alpha, &lda, &beta, &ldc);
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
CUstream stream = starpu_cuda_get_local_stream();
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasFillMode_t cublasUplo;
if (uplo == MorseUpper){
cublasUplo = CUBLAS_FILL_MODE_UPPER;
}else if(uplo == MorseLower){
cublasUplo = CUBLAS_FILL_MODE_LOWER;
}else if(uplo == MorseUpperLower){
cublasUplo = 0;
}else{
fprintf(stderr, "Error in cl_zherk_cuda_func: bad uplo parameter %d\n", uplo);
}
cublasOperation_t cublasTrans;
if (trans == MorseNoTrans){
cublasTrans = CUBLAS_OP_N;
}else if(trans == MorseTrans){
cublasTrans = CUBLAS_OP_T;
}else if(trans == MorseConjTrans){
cublasTrans = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in cl_zherk_cuda_func: bad trans parameter %d\n", trans);
}
stat = cublasZherk(handle,
cublasUplo, cublasTrans,
n, k,
(const double *) &alpha, A, lda,
(const double *) &beta, C, ldc);
if (stat != CUBLAS_STATUS_SUCCESS){
printf ("cublasZherk failed");
cublasDestroy(handle);
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasDestroy(handle);
#ifndef STARPU_CUDA_ASYNC
cudaStreamSynchronize( stream );
#endif
return;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
static void cl_zherk_cuda_func(void *descr[], void *cl_arg) static void cl_zherk_cuda_func(void *descr[], void *cl_arg)
{ {
MORSE_enum uplo; MORSE_enum uplo;
...@@ -125,7 +199,8 @@ static void cl_zherk_cuda_func(void *descr[], void *cl_arg) ...@@ -125,7 +199,8 @@ static void cl_zherk_cuda_func(void *descr[], void *cl_arg)
return; return;
} }
#endif #endif /* CHAMELEON_USE_CUBLAS_V2 */
#endif /* CHAMELEON_USE_CUDA */
/* /*
* Codelet definition * Codelet definition
......
...@@ -102,6 +102,81 @@ static void cl_zsymm_cpu_func(void *descr[], void *cl_arg) ...@@ -102,6 +102,81 @@ static void cl_zsymm_cpu_func(void *descr[], void *cl_arg)
} }
#ifdef CHAMELEON_USE_CUDA #ifdef CHAMELEON_USE_CUDA
#if defined(CHAMELEON_USE_CUBLAS_V2)
static void cl_zsymm_cuda_func(void *descr[], void *cl_arg)
{
MORSE_enum side;
MORSE_enum uplo;
int M;
int N;
cuDoubleComplex alpha;
const cuDoubleComplex *A;
int LDA;
const cuDoubleComplex *B;
int LDB;
cuDoubleComplex beta;
cuDoubleComplex *C;
int LDC;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
B = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]);
starpu_codelet_unpack_args(cl_arg, &side, &uplo, &M, &N, &alpha, &LDA, &LDB, &beta, &LDC);
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
CUstream stream = starpu_cuda_get_local_stream();
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasSideMode_t cublasSide;
if (side == MorseLeft){
cublasSide = CUBLAS_SIDE_LEFT;
}else if (side == MorseRight){
cublasSide = CUBLAS_SIDE_RIGHT;
}else{
fprintf(stderr, "Error in cl_zsymm_cuda_func: bad side parameter %d\n", side);
}
cublasFillMode_t cublasUplo;
if (uplo == MorseUpper){
cublasUplo = CUBLAS_FILL_MODE_UPPER;
}else if(uplo == MorseLower){
cublasUplo = CUBLAS_FILL_MODE_LOWER;
}else if(uplo == MorseUpperLower){
cublasUplo = 0;
}else{
fprintf(stderr, "Error in cl_zsymm_cuda_func: bad uplo parameter %d\n", uplo);
}
stat = cublasZsymm(handle,
cublasSide, cublasUplo,
M, N,
(const cuDoubleComplex *) &alpha, A, LDA,
B, LDB,
(const cuDoubleComplex *) &beta, C, LDC);
if (stat != CUBLAS_STATUS_SUCCESS){
printf ("cublasZsymm failed");
cublasDestroy(handle);
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasDestroy(handle);
#ifndef STARPU_CUDA_ASYNC
cudaStreamSynchronize( stream );
#endif
return;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
static void cl_zsymm_cuda_func(void *descr[], void *cl_arg) static void cl_zsymm_cuda_func(void *descr[], void *cl_arg)
{ {
MORSE_enum side; MORSE_enum side;
...@@ -138,7 +213,8 @@ static void cl_zsymm_cuda_func(void *descr[], void *cl_arg) ...@@ -138,7 +213,8 @@ static void cl_zsymm_cuda_func(void *descr[], void *cl_arg)
return; return;
} }
#endif #endif /* CHAMELEON_USE_CUBLAS_V2 */
#endif /* CHAMELEON_USE_CUDA */
/* /*
* Codelet definition * Codelet definition
......
...@@ -97,6 +97,81 @@ static void cl_zsyr2k_cpu_func(void *descr[], void *cl_arg) ...@@ -97,6 +97,81 @@ static void cl_zsyr2k_cpu_func(void *descr[], void *cl_arg)
} }
#ifdef CHAMELEON_USE_CUDA #ifdef CHAMELEON_USE_CUDA
#if defined(CHAMELEON_USE_CUBLAS_V2)
static void cl_zsyr2k_cuda_func(void *descr[], void *cl_arg)
{
MORSE_enum uplo;
MORSE_enum trans;
int n;
int k;
cuDoubleComplex alpha;
const cuDoubleComplex *A;
int lda;
const cuDoubleComplex *B;
int ldb;
cuDoubleComplex beta;
cuDoubleComplex *C;
int ldc;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
B = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]);
starpu_codelet_unpack_args(cl_arg, &uplo, &trans, &n, &k, &alpha, &lda, &ldb, &beta, &ldc);
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
CUstream stream = starpu_cuda_get_local_stream();
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasFillMode_t cublasUplo;
if (uplo == MorseUpper){
cublasUplo = CUBLAS_FILL_MODE_UPPER;
}else if(uplo == MorseLower){
cublasUplo = CUBLAS_FILL_MODE_LOWER;
}else if(uplo == MorseUpperLower){
cublasUplo = 0;
}else{
fprintf(stderr, "Error in cl_zsyr2k_cuda_func: bad uplo parameter %d\n", uplo);
}
cublasOperation_t cublasTrans;
if (trans == MorseNoTrans){
cublasTrans = CUBLAS_OP_N;
}else if(trans == MorseTrans){
cublasTrans = CUBLAS_OP_T;
}else if(trans == MorseConjTrans){
cublasTrans = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in cl_zsyr2k_cuda_func: bad trans parameter %d\n", trans);
}
stat = cublasZsyr2k( handle, cublasUplo, cublasTrans,
n, k, (const cuDoubleComplex *) &alpha, A, lda, B, ldb,
(const cuDoubleComplex *) &beta, C, ldc);
if (stat != CUBLAS_STATUS_SUCCESS){
printf ("cublasZsyr2k failed");
cublasDestroy(handle);
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasDestroy(handle);
#ifndef STARPU_CUDA_ASYNC
cudaStreamSynchronize( stream );
#endif
return;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
static void cl_zsyr2k_cuda_func(void *descr[], void *cl_arg) static void cl_zsyr2k_cuda_func(void *descr[], void *cl_arg)
{ {
MORSE_enum uplo; MORSE_enum uplo;
...@@ -129,7 +204,8 @@ static void cl_zsyr2k_cuda_func(void *descr[], void *cl_arg) ...@@ -129,7 +204,8 @@ static void cl_zsyr2k_cuda_func(void *descr[], void *cl_arg)
return; return;
} }
#endif #endif /* CHAMELEON_USE_CUBLAS_V2 */
#endif /* CHAMELEON_USE_CUDA */
/* /*
* Codelet definition * Codelet definition
......
...@@ -94,6 +94,80 @@ static void cl_zsyrk_cpu_func(void *descr[], void *cl_arg) ...@@ -94,6 +94,80 @@ static void cl_zsyrk_cpu_func(void *descr[], void *cl_arg)
} }
#ifdef CHAMELEON_USE_CUDA #ifdef CHAMELEON_USE_CUDA
#if defined(CHAMELEON_USE_CUBLAS_V2)
static void cl_zsyrk_cuda_func(void *descr[], void *cl_arg)
{
MORSE_enum uplo;
MORSE_enum trans;
int n;
int k;
cuDoubleComplex alpha;
const cuDoubleComplex *A;
int lda;
cuDoubleComplex beta;
cuDoubleComplex *C;
int ldc;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
C = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
starpu_codelet_unpack_args(cl_arg, &uplo, &trans, &n, &k, &alpha, &lda, &beta, &ldc);
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
CUstream stream = starpu_cuda_get_local_stream();
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasFillMode_t cublasUplo;
if (uplo == MorseUpper){
cublasUplo = CUBLAS_FILL_MODE_UPPER;
}else if(uplo == MorseLower){
cublasUplo = CUBLAS_FILL_MODE_LOWER;
}else if(uplo == MorseUpperLower){
cublasUplo = 0;
}else{
fprintf(stderr, "Error in cl_zsyrk_cuda_func: bad uplo parameter %d\n", uplo);
}
cublasOperation_t cublasTrans;
if (trans == MorseNoTrans){
cublasTrans = CUBLAS_OP_N;
}else if(trans == MorseTrans){
cublasTrans = CUBLAS_OP_T;
}else if(trans == MorseConjTrans){
cublasTrans = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in cl_zsyrk_cuda_func: bad trans parameter %d\n", trans);
}
stat = cublasZsyrk(handle,
cublasUplo, cublasTrans,
n, k,
(const cuDoubleComplex *) &alpha, A, lda,
(const cuDoubleComplex *) &beta, C, ldc);
if (stat != CUBLAS_STATUS_SUCCESS){
printf ("cublasZsyrk failed");
cublasDestroy(handle);
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasDestroy(handle);
#ifndef STARPU_CUDA_ASYNC
cudaStreamSynchronize( stream );
#endif
return;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
static void cl_zsyrk_cuda_func(void *descr[], void *cl_arg) static void cl_zsyrk_cuda_func(void *descr[], void *cl_arg)
{ {
MORSE_enum uplo; MORSE_enum uplo;
...@@ -126,7 +200,8 @@ static void cl_zsyrk_cuda_func(void *descr[], void *cl_arg) ...@@ -126,7 +200,8 @@ static void cl_zsyrk_cuda_func(void *descr[], void *cl_arg)
return; return;
} }
#endif #endif /* CHAMELEON_USE_CUBLAS_V2 */
#endif /* CHAMELEON_USE_CUDA */
/* /*
* Codelet definition * Codelet definition
......
...@@ -97,6 +97,96 @@ static void cl_ztrmm_cpu_func(void *descr[], void *cl_arg) ...@@ -97,6 +97,96 @@ static void cl_ztrmm_cpu_func(void *descr[], void *cl_arg)
} }
#ifdef CHAMELEON_USE_CUDA #ifdef CHAMELEON_USE_CUDA
#if defined(CHAMELEON_USE_CUBLAS_V2)
static void cl_ztrmm_cuda_func(void *descr[], void *cl_arg)
{
MORSE_enum side;
MORSE_enum uplo;
MORSE_enum transA;
MORSE_enum diag;
int M;
int N;
cuDoubleComplex alpha;
const cuDoubleComplex *A;
int LDA;
cuDoubleComplex *B;
int LDB;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
B = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
starpu_codelet_unpack_args(cl_arg, &side, &uplo, &transA, &diag, &M, &N, &alpha, &LDA, &LDB);
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
CUstream stream = starpu_cuda_get_local_stream();
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasSideMode_t cublasSide;
if (side == MorseLeft){
cublasSide = CUBLAS_SIDE_LEFT;
}else if (side == MorseRight){
cublasSide = CUBLAS_SIDE_RIGHT;
}else{
fprintf(stderr, "Error in cl_ztrmm_cuda_func: bad side parameter %d\n", side);
}
cublasFillMode_t cublasUplo;
if (uplo == MorseUpper){
cublasUplo = CUBLAS_FILL_MODE_UPPER;
}else if(uplo == MorseLower){
cublasUplo = CUBLAS_FILL_MODE_LOWER;
}else if(uplo == MorseUpperLower){
cublasUplo = 0;
}else{
fprintf(stderr, "Error in cl_ztrmm_cuda_func: bad uplo parameter %d\n", uplo);
}
cublasOperation_t cublasTransA;
if (transA == MorseNoTrans){
cublasTransA = CUBLAS_OP_N;
}else if(transA == MorseTrans){
cublasTransA = CUBLAS_OP_T;
}else if(transA == MorseConjTrans){
cublasTransA = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in cl_ztrmm_cuda_func: bad transA parameter %d\n", transA);
}
cublasDiagType_t cublasDiag;
if (diag == MorseNonUnit){
cublasDiag = CUBLAS_DIAG_NON_UNIT;
}else if(diag == MorseUnit){
cublasDiag = CUBLAS_DIAG_UNIT;
}else{
fprintf(stderr, "Error in cl_ztrmm_cuda_func: bad diag parameter %d\n", diag);
}
stat = cublasZtrmm( handle,
cublasSide, cublasUplo, cublasTransA, cublasDiag,
M, N,
(const cuDoubleComplex *) &alpha, A, LDA,
B, LDB, B, LDB);
if (stat != CUBLAS_STATUS_SUCCESS){
printf ("cublasZtrmm failed");
cublasDestroy(handle);
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasDestroy(handle);
#ifndef STARPU_CUDA_ASYNC
cudaStreamSynchronize( stream );
#endif
return;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
static void cl_ztrmm_cuda_func(void *descr[], void *cl_arg) static void cl_ztrmm_cuda_func(void *descr[], void *cl_arg)
{ {
MORSE_enum side; MORSE_enum side;
...@@ -131,8 +221,8 @@ static void cl_ztrmm_cuda_func(void *descr[], void *cl_arg) ...@@ -131,8 +221,8 @@ static void cl_ztrmm_cuda_func(void *descr[], void *cl_arg)
return; return;
} }
#endif #endif /* CHAMELEON_USE_CUBLAS_V2 */
#endif /* CHAMELEON_USE_CUDA */
/* /*
......
...@@ -97,6 +97,96 @@ static void cl_ztrsm_cpu_func(void *descr[], void *cl_arg) ...@@ -97,6 +97,96 @@ static void cl_ztrsm_cpu_func(void *descr[], void *cl_arg)
} }
#ifdef CHAMELEON_USE_CUDA #ifdef CHAMELEON_USE_CUDA
#if defined(CHAMELEON_USE_CUBLAS_V2)
static void cl_ztrsm_cuda_func(void *descr[], void *cl_arg)
{
MORSE_enum side;
MORSE_enum uplo;
MORSE_enum transA;
MORSE_enum diag;
int m;
int n;
cuDoubleComplex alpha;
const cuDoubleComplex *A;
int lda;
cuDoubleComplex *B;
int ldb;
A = (const cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
B = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
starpu_codelet_unpack_args(cl_arg, &side, &uplo, &transA, &diag, &m, &n, &alpha, &lda, &ldb);
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
CUstream stream = starpu_cuda_get_local_stream();
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasSideMode_t cublasSide;
if (side == MorseLeft){
cublasSide = CUBLAS_SIDE_LEFT;
}else if (side == MorseRight){
cublasSide = CUBLAS_SIDE_RIGHT;
}else{
fprintf(stderr, "Error in cl_ztrsm_cuda_func: bad side parameter %d\n", side);
}
cublasFillMode_t cublasUplo;
if (uplo == MorseUpper){
cublasUplo = CUBLAS_FILL_MODE_UPPER;
}else if(uplo == MorseLower){
cublasUplo = CUBLAS_FILL_MODE_LOWER;
}else if(uplo == MorseUpperLower){
cublasUplo = 0;
}else{
fprintf(stderr, "Error in cl_ztrsm_cuda_func: bad uplo parameter %d\n", uplo);
}
cublasOperation_t cublasTransA;
if (transA == MorseNoTrans){
cublasTransA = CUBLAS_OP_N;
}else if(transA == MorseTrans){
cublasTransA = CUBLAS_OP_T;
}else if(transA == MorseConjTrans){
cublasTransA = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in cl_ztrsm_cuda_func: bad transA parameter %d\n", transA);
}
cublasDiagType_t cublasDiag;
if (diag == MorseNonUnit){
cublasDiag = CUBLAS_DIAG_NON_UNIT;
}else if(diag == MorseUnit){
cublasDiag = CUBLAS_DIAG_UNIT;
}else{
fprintf(stderr, "Error in cl_ztrsm_cuda_func: bad diag parameter %d\n", diag);
}
stat = cublasZtrsm( handle,
cublasSide, cublasUplo, cublasTransA, cublasDiag,
m, n,
(const cuDoubleComplex *) &alpha, A, lda,
B, ldb);
if (stat != CUBLAS_STATUS_SUCCESS){
printf ("cublasZtrsm failed");
cublasDestroy(handle);
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasDestroy(handle);
#ifndef STARPU_CUDA_ASYNC
cudaStreamSynchronize( stream );
#endif
return;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
static void cl_ztrsm_cuda_func(void *descr[], void *cl_arg) static void cl_ztrsm_cuda_func(void *descr[], void *cl_arg)
{ {
MORSE_enum side; MORSE_enum side;
...@@ -131,7 +221,8 @@ static void cl_ztrsm_cuda_func(void *descr[], void *cl_arg) ...@@ -131,7 +221,8 @@ static void cl_ztrsm_cuda_func(void *descr[], void *cl_arg)
return; return;
} }
#endif #endif /* CHAMELEON_USE_CUBLAS_V2 */
#endif /* CHAMELEON_USE_CUDA */
/* /*
* Codelet definition * Codelet definition
......
...@@ -215,6 +215,7 @@ static void cl_ztsmqr_cpu_func(void *descr[], void *cl_arg) ...@@ -215,6 +215,7 @@ static void cl_ztsmqr_cpu_func(void *descr[], void *cl_arg)
#if defined(CHAMELEON_USE_MAGMA) #if defined(CHAMELEON_USE_MAGMA)
#if defined(CHAMELEON_USE_CUBLAS_V2)
magma_int_t magma_int_t
magma_zparfb_gpu(magma_side_t side, magma_trans_t trans, magma_zparfb_gpu(magma_side_t side, magma_trans_t trans,
magma_direct_t direct, magma_storev_t storev, magma_direct_t direct, magma_storev_t storev,
...@@ -229,6 +230,325 @@ magma_zparfb_gpu(magma_side_t side, magma_trans_t trans, ...@@ -229,6 +230,325 @@ magma_zparfb_gpu(magma_side_t side, magma_trans_t trans,
magmaDoubleComplex *WORKC, magma_int_t LDWORKC, magmaDoubleComplex *WORKC, magma_int_t LDWORKC,
CUstream stream) CUstream stream)
{
#if defined(PRECISION_z) || defined(PRECISION_c)
cuDoubleComplex zzero = make_cuDoubleComplex(0.0, 0.0);
cuDoubleComplex zone = make_cuDoubleComplex(1.0, 0.0);
cuDoubleComplex mzone = make_cuDoubleComplex(-1.0, 0.0);
#else
double zzero = 0.0;
double zone = 1.0;
double mzone = -1.0;
#endif /* defined(PRECISION_z) || defined(PRECISION_c) */
int j;
magma_trans_t transW;
magma_trans_t transA2;
cublasHandle_t handle;
cublasStatus_t stat = cublasCreate(&handle);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("CUBLAS initialization failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
stat = cublasSetStream(handle, stream);
if (stat != CUBLAS_STATUS_SUCCESS) {
printf ("cublasSetStream failed\n");
assert( stat == CUBLAS_STATUS_SUCCESS );
}
cublasOperation_t cublasTrans;
if (trans == MagmaNoTrans){
cublasTrans = CUBLAS_OP_N;
}else if(trans == MagmaTrans){
cublasTrans = CUBLAS_OP_T;
}else if(trans == MagmaConjTrans){
cublasTrans = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in magma_zparfb_gpu: bad trans parameter %d\n", trans);
}
/* Check input arguments */
if ((side != MagmaLeft) && (side != MagmaRight)) {
return -1;
}
if ((trans != MagmaNoTrans) && (trans != MagmaConjTrans)) {
return -2;
}
if ((direct != MagmaForward) && (direct != MagmaBackward)) {
return -3;
}
if ((storev != MagmaColumnwise) && (storev != MagmaRowwise)) {
return -4;
}
if (M1 < 0) {
return -5;
}
if (N1 < 0) {
return -6;
}
if ((M2 < 0) ||
( (side == MagmaRight) && (M1 != M2) ) ) {
return -7;
}
if ((N2 < 0) ||
( (side == MagmaLeft) && (N1 != N2) ) ) {
return -8;
}
if (K < 0) {
return -9;
}
/* Quick return */
if ((M1 == 0) || (N1 == 0) || (M2 == 0) || (N2 == 0) || (K == 0))
return MAGMA_SUCCESS;
if (direct == MagmaForward) {
if (side == MagmaLeft) {
/*
* Column or Rowwise / Forward / Left
* ----------------------------------
*
* Form H * A or H' * A where A = ( A1 )
* ( A2 )
*/
/*
* W = A1 + V' * A2:
* W = A1
* W = W + V' * A2
*
*/
cudaMemcpy2DAsync( WORK, LDWORK * sizeof(cuDoubleComplex),
A1, LDA1 * sizeof(cuDoubleComplex),
K * sizeof(cuDoubleComplex), N1,
cudaMemcpyDeviceToDevice, stream );
transW = storev == MorseColumnwise ? MagmaConjTrans : MagmaNoTrans;
transA2 = storev == MorseColumnwise ? MagmaNoTrans : MagmaConjTrans;
cublasOperation_t cublasTransW;
if (transW == MagmaNoTrans){
cublasTransW = CUBLAS_OP_N;
}else if(transW == MagmaTrans){
cublasTransW = CUBLAS_OP_T;
}else if(transW == MagmaConjTrans){
cublasTransW = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in magma_zparfb_gpu: bad transW parameter %d\n", transW);
}
cublasOperation_t cublasTransA2;
if (transA2 == MagmaNoTrans){
cublasTransA2 = CUBLAS_OP_N;
}else if(transA2 == MagmaTrans){
cublasTransA2 = CUBLAS_OP_T;
}else if(transA2 == MagmaConjTrans){
cublasTransA2 = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in magma_zparfb_gpu: bad transA2 parameter %d\n", transA2);
}
cublasZgemm(handle, cublasTransW, CUBLAS_OP_N,
K, N1, M2,
(const cuDoubleComplex *) &zone,
(const cuDoubleComplex*)V /* K*M2 */, LDV,
(const cuDoubleComplex*)A2 /* M2*N1 */, LDA2,
(const cuDoubleComplex *) &zone,
(cuDoubleComplex*)WORK /* K*N1 */, LDWORK);
WORKC = NULL;
if (WORKC == NULL) {
/* W = op(T) * W */
cublasZtrmm( handle,
CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_UPPER,
cublasTrans, CUBLAS_DIAG_NON_UNIT,
K, N2,
(const cuDoubleComplex *) &zone,
(const cuDoubleComplex*)T, LDT,
(cuDoubleComplex*)WORK, LDWORK,
(cuDoubleComplex*)WORK, LDWORK);
/* A1 = A1 - W = A1 - op(T) * W */
for(j = 0; j < N1; j++) {
cublasZaxpy(handle, K, (const cuDoubleComplex *) &mzone,
(const cuDoubleComplex*)(WORK + LDWORK*j), 1,
(cuDoubleComplex*)(A1 + LDA1*j), 1);
}
/* A2 = A2 - op(V) * W */
cublasZgemm(handle, cublasTransA2, CUBLAS_OP_N,
M2, N2, K,
(const cuDoubleComplex *) &mzone,
(const cuDoubleComplex*)V /* M2*K */, LDV,
(const cuDoubleComplex*)WORK /* K*N2 */, LDWORK,
(const cuDoubleComplex *) &zone,
(cuDoubleComplex*)A2 /* m2*N2 */, LDA2);
} else {
/* Wc = V * op(T) */
cublasZgemm( handle, cublasTransA2, cublasTrans,
M2, K, K,
(const cuDoubleComplex *) &zone, V, LDV,
T, LDT,
(const cuDoubleComplex *) &zzero, WORKC, LDWORKC );
/* A1 = A1 - opt(T) * W */
cublasZgemm( handle, cublasTrans, CUBLAS_OP_N,
K, N1, K,
(const cuDoubleComplex *) &mzone,
(const cuDoubleComplex *)T, LDT,
(const cuDoubleComplex *)WORK, LDWORK,
(const cuDoubleComplex *) &zone,
(cuDoubleComplex*)A1, LDA1 );
/* A2 = A2 - Wc * W */
cublasZgemm( handle, CUBLAS_OP_N, CUBLAS_OP_N,
M2, N2, K,
(const cuDoubleComplex *) &mzone,
(const cuDoubleComplex *)WORKC, LDWORKC,
(const cuDoubleComplex *)WORK, LDWORK,
(const cuDoubleComplex *) &zone,
(cuDoubleComplex *)A2, LDA2 );
}
}
else {
/*
* Column or Rowwise / Forward / Right
* -----------------------------------
*
* Form H * A or H' * A where A = ( A1 A2 )
*
*/
/*
* W = A1 + A2 * V':
* W = A1
* W = W + A2 * V'
*
*/
cudaMemcpy2DAsync( WORK, LDWORK * sizeof(cuDoubleComplex),
A1, LDA1 * sizeof(cuDoubleComplex),
M1 * sizeof(cuDoubleComplex), K,
cudaMemcpyDeviceToDevice, stream );
transW = storev == MorseColumnwise ? MagmaNoTrans : MagmaConjTrans;
transA2 = storev == MorseColumnwise ? MagmaConjTrans : MagmaNoTrans;
cublasOperation_t cublasTransW;
if (transW == MagmaNoTrans){
cublasTransW = CUBLAS_OP_N;
}else if(transW == MagmaTrans){
cublasTransW = CUBLAS_OP_T;
}else if(transW == MagmaConjTrans){
cublasTransW = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in magma_zparfb_gpu: bad transW parameter %d\n", transW);
}
cublasOperation_t cublasTransA2;
if (transA2 == MagmaNoTrans){
cublasTransA2 = CUBLAS_OP_N;
}else if(transA2 == MagmaTrans){
cublasTransA2 = CUBLAS_OP_T;
}else if(transA2 == MagmaConjTrans){
cublasTransA2 = CUBLAS_OP_C;
}else{
fprintf(stderr, "Error in magma_zparfb_gpu: bad transA2 parameter %d\n", transA2);
}
cublasZgemm(handle, CUBLAS_OP_N, cublasTransW,
M1, K, N2,
(const cuDoubleComplex *) &zone,
(const cuDoubleComplex*)A2 /* M1*N2 */, LDA2,
(const cuDoubleComplex*)V /* N2*K */, LDV,
(const cuDoubleComplex *) &zone,
(cuDoubleComplex*)WORK /* M1*K */, LDWORK);
WORKC = NULL;
if (WORKC == NULL) {
/* W = W * op(T) */
cublasZtrmm( handle,
CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER,
cublasTrans, CUBLAS_DIAG_NON_UNIT,
M2, K,
(const cuDoubleComplex *) &zone,
(const cuDoubleComplex*)T, LDT,
(cuDoubleComplex*)WORK, LDWORK,
(cuDoubleComplex*)WORK, LDWORK);
/* A1 = A1 - W = A1 - W * op(T) */
for(j = 0; j < K; j++) {
cublasZaxpy(handle, M1, (const cuDoubleComplex *) &mzone,
(const cuDoubleComplex*)(WORK + LDWORK*j), 1,
(cuDoubleComplex*)(A1 + LDA1*j), 1);
}
/* A2 = A2 - W * op(V) */
cublasZgemm(handle, CUBLAS_OP_N, cublasTransA2,
M2, N2, K,
(const cuDoubleComplex *) &mzone,
(const cuDoubleComplex*)WORK /* M2*K */, LDWORK,
(const cuDoubleComplex*)V /* K*N2 */, LDV,
(const cuDoubleComplex *) &zone,
(cuDoubleComplex*)A2 /* M2*N2 */, LDA2);
} else {
/* A1 = A1 - W * opt(T) */
cublasZgemm( handle, CUBLAS_OP_N, cublasTrans,
M1, K, K,
(const cuDoubleComplex *) &mzone,
(const cuDoubleComplex *)WORK, LDWORK,
(const cuDoubleComplex *)T, LDT,
(const cuDoubleComplex *) &zone,
(cuDoubleComplex *)A1, LDA1 );
/* Wc = op(T) * V */
cublasZgemm( handle, cublasTrans, cublasTransA2,
K, N2, K,
(const cuDoubleComplex *) &zone,
(const cuDoubleComplex *)T, LDT,
(const cuDoubleComplex *)V, LDV,
(const cuDoubleComplex *) &zzero,
(cuDoubleComplex *)WORKC, LDWORKC );
/* A2 = A2 - W * Wc */
cublasZgemm( handle, CUBLAS_OP_N, CUBLAS_OP_N,
M2, N2, K,
(const cuDoubleComplex *) &mzone,
(const cuDoubleComplex *)WORK, LDWORK,
(const cuDoubleComplex *)WORKC, LDWORKC,
(const cuDoubleComplex *) &zone,
(cuDoubleComplex *)A2, LDA2 );
}
}
}
else {
fprintf(stderr, "Not implemented (Backward / Left or Right)");
return MAGMA_ERR_NOT_SUPPORTED;
}
cublasDestroy(handle);
return MAGMA_SUCCESS;
}
#else /* CHAMELEON_USE_CUBLAS_V2 */
magma_int_t
magma_zparfb_gpu(magma_side_t side, magma_trans_t trans,
magma_direct_t direct, magma_storev_t storev,
magma_int_t M1, magma_int_t N1,
magma_int_t M2, magma_int_t N2,
magma_int_t K, magma_int_t L,
magmaDoubleComplex *A1, magma_int_t LDA1,
magmaDoubleComplex *A2, magma_int_t LDA2,
const magmaDoubleComplex *V, magma_int_t LDV,
const magmaDoubleComplex *T, magma_int_t LDT,
magmaDoubleComplex *WORK, magma_int_t LDWORK,
magmaDoubleComplex *WORKC, magma_int_t LDWORKC,
CUstream stream)
{ {
#if defined(PRECISION_z) || defined(PRECISION_c) #if defined(PRECISION_z) || defined(PRECISION_c)
cuDoubleComplex zzero = make_cuDoubleComplex(0.0, 0.0); cuDoubleComplex zzero = make_cuDoubleComplex(0.0, 0.0);
...@@ -452,6 +772,7 @@ magma_zparfb_gpu(magma_side_t side, magma_trans_t trans, ...@@ -452,6 +772,7 @@ magma_zparfb_gpu(magma_side_t side, magma_trans_t trans,
return MAGMA_SUCCESS; return MAGMA_SUCCESS;
} }
#endif /* CHAMELEON_USE_CUBLAS_V2 */
magma_int_t magma_int_t
magma_ztsmqr_gpu( magma_side_t side, magma_trans_t trans, magma_ztsmqr_gpu( magma_side_t side, magma_trans_t trans,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment