diff --git a/runtime/starpu/codelets/codelet_gemmex.c b/runtime/starpu/codelets/codelet_gemmex.c index c59af2cb0d334cd6c036e631732d5dff7a7fb033..cf99b426e0e3bb018ccd879848f7584f49e4c8d0 100644 --- a/runtime/starpu/codelets/codelet_gemmex.c +++ b/runtime/starpu/codelets/codelet_gemmex.c @@ -43,43 +43,52 @@ cl_gemmex_cuda_func( void *descr[], void *cl_arg ) CHAM_tile_t *tileC; void *ptrAlpha, *ptrBeta; + CHAMELEON_Real16_t halpha = clargs->alpha; + CHAMELEON_Real16_t hbeta = clargs->beta; + float salpha = clargs->alpha; + float sbeta = clargs->beta; + double dalpha = clargs->alpha; + double dbeta = clargs->beta; + CHAMELEON_Complex32_t calpha = clargs->alpha; + CHAMELEON_Complex32_t cbeta = clargs->beta; + CHAMELEON_Complex64_t zalpha = clargs->alpha; + CHAMELEON_Complex64_t zbeta = clargs->beta; + + tileA = cti_interface_get(descr[0]); + tileB = cti_interface_get(descr[1]); + tileC = cti_interface_get(descr[2]); + + assert( tileA->format & CHAMELEON_TILE_FULLRANK ); + assert( tileB->format & CHAMELEON_TILE_FULLRANK ); + assert( tileC->format & CHAMELEON_TILE_FULLRANK ); + switch( tileC->flttype ) { case ChamRealHalf: { - CHAMELEON_Real16_t halpha = clargs->alpha; - CHAMELEON_Real16_t hbeta = clargs->beta; ptrAlpha = &halpha; ptrBeta = &hbeta; } break; case ChamRealFloat: { - float salpha = clargs->alpha; - float sbeta = clargs->beta; ptrAlpha = &salpha; ptrBeta = &sbeta; } break; case ChamRealDouble: { - double dalpha = clargs->alpha; - double dbeta = clargs->beta; ptrAlpha = &dalpha; ptrBeta = &dbeta; } break; case ChamComplexFloat: { - CHAMELEON_Complex32_t calpha = clargs->alpha; - CHAMELEON_Complex32_t cbeta = clargs->beta; ptrAlpha = &calpha; ptrBeta = &cbeta; } break; case ChamComplexDouble: { - CHAMELEON_Complex64_t zalpha = clargs->alpha; - CHAMELEON_Complex64_t zbeta = clargs->beta; ptrAlpha = &zalpha; ptrBeta = &zbeta; } @@ -89,14 +98,6 @@ cl_gemmex_cuda_func( void *descr[], void *cl_arg ) return; } - tileA = cti_interface_get(descr[0]); - tileB = cti_interface_get(descr[1]); - tileC = cti_interface_get(descr[2]); - - assert( tileA->format & CHAMELEON_TILE_FULLRANK ); - assert( tileB->format & CHAMELEON_TILE_FULLRANK ); - assert( tileC->format & CHAMELEON_TILE_FULLRANK ); - CUDA_gemmex( clargs->transA, clargs->transB, clargs->m, clargs->n, clargs->k,