diff --git a/compute/pzgemm.c b/compute/pzgemm.c index a67977fa4d91d4657e1fe68032441faab9035261..65a15037e17b26e88f1baba17809cc6e0f82d401 100644 --- a/compute/pzgemm.c +++ b/compute/pzgemm.c @@ -26,8 +26,8 @@ #define A(m, n) A, m, n #define B(m, n) B, m, n #define C(m, n) C, m, n -#define WA(m, n) &WA, m, n -#define WB(m, n) &WB, m, n +#define WA(m, n) WA, m, n +#define WB(m, n) WB, m, n /** * Parallel tile matrix-matrix multiplication @@ -37,6 +37,7 @@ static inline void chameleon_pzgemm_summa( CHAM_context_t *chamctxt, cham_trans_t transA, cham_trans_t transB, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, CHAMELEON_Complex64_t beta, CHAM_desc_t *C, + CHAM_desc_t *WA, CHAM_desc_t *WB, RUNTIME_option_t *options ) { RUNTIME_sequence_t *sequence = options->sequence; @@ -46,19 +47,8 @@ chameleon_pzgemm_summa( CHAM_context_t *chamctxt, cham_trans_t transA, cham_tran CHAMELEON_Complex64_t zbeta; CHAMELEON_Complex64_t zone = (CHAMELEON_Complex64_t)1.0; - CHAM_desc_t WA, WB; lookahead = chamctxt->lookahead; - chameleon_desc_init( &WA, CHAMELEON_MAT_ALLOC_TILE, - ChamComplexDouble, C->mb, C->nb, (C->mb * C->nb), - C->mt * C->mb, C->nb * C->q * lookahead, 0, 0, - C->mt * C->mb, C->nb * C->q * lookahead, C->p, C->q, - NULL, NULL, NULL ); - chameleon_desc_init( &WB, CHAMELEON_MAT_ALLOC_TILE, - ChamComplexDouble, C->mb, C->nb, (C->mb * C->nb), - C->mb * C->p * lookahead, C->nt * C->nb, 0, 0, - C->mb * C->p * lookahead, C->nt * C->nb, C->p, C->q, - NULL, NULL, NULL ); KT = transA == ChamNoTrans ? A->nt : A->mt; K = transA == ChamNoTrans ? A->n : A->m; @@ -171,12 +161,8 @@ chameleon_pzgemm_summa( CHAM_context_t *chamctxt, cham_trans_t transA, cham_tran } } - RUNTIME_desc_flush( &WA, sequence ); - RUNTIME_desc_flush( &WB, sequence ); - RUNTIME_desc_flush( C, sequence ); - chameleon_sequence_wait( chamctxt, sequence ); - chameleon_desc_destroy( &WA ); - chameleon_desc_destroy( &WB ); + CHAMELEON_Desc_Flush( WA, sequence ); + CHAMELEON_Desc_Flush( WB, sequence ); } /** @@ -286,10 +272,11 @@ chameleon_pzgemm_generic( CHAM_context_t *chamctxt, cham_trans_t transA, cham_tr } /** - * Parallel tile matrix-matrix multiplication. wrapper. + * Parallel tile matrix-matrix multiplication wrapper. */ void -chameleon_pzgemm( cham_trans_t transA, cham_trans_t transB, +chameleon_pzgemm( struct chameleon_pzgemm_s *ws, + cham_trans_t transA, cham_trans_t transB, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, CHAMELEON_Complex64_t beta, CHAM_desc_t *C, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request ) @@ -303,11 +290,10 @@ chameleon_pzgemm( cham_trans_t transA, cham_trans_t transB, } RUNTIME_options_init( &options, chamctxt, sequence, request ); - if ( ((C->p > 1) || (C->q > 1)) && - (C->get_rankof == chameleon_getrankof_2d) && - (chamctxt->generic_enabled != CHAMELEON_TRUE) ) + if ( ws->summa ) { - chameleon_pzgemm_summa( chamctxt, transA, transB, alpha, A, B, beta, C, &options ); + chameleon_pzgemm_summa( chamctxt, transA, transB, alpha, A, B, beta, C, + &(ws->WA), &(ws->WB), &options ); } else { chameleon_pzgemm_generic( chamctxt, transA, transB, alpha, A, B, beta, C, &options ); diff --git a/compute/zgemm.c b/compute/zgemm.c index fed78d6f466c139250a290aed208674028e4c2b7..4c964a0eb609bbd1bc4c3e7d8638f41ea9c79921 100644 --- a/compute/zgemm.c +++ b/compute/zgemm.c @@ -236,9 +236,9 @@ void CHAMELEON_zgemm_WS_Free( void *user_ws ) * */ int CHAMELEON_zgemm( cham_trans_t transA, cham_trans_t transB, int M, int N, int K, - CHAMELEON_Complex64_t alpha, CHAMELEON_Complex64_t *A, int LDA, - CHAMELEON_Complex64_t *B, int LDB, - CHAMELEON_Complex64_t beta, CHAMELEON_Complex64_t *C, int LDC ) + CHAMELEON_Complex64_t alpha, CHAMELEON_Complex64_t *A, int LDA, + CHAMELEON_Complex64_t *B, int LDB, + CHAMELEON_Complex64_t beta, CHAMELEON_Complex64_t *C, int LDC ) { int NB; int Am, An, Bm, Bn; @@ -249,6 +249,7 @@ int CHAMELEON_zgemm( cham_trans_t transA, cham_trans_t transB, int M, int N, int CHAM_context_t *chamctxt; RUNTIME_sequence_t *sequence = NULL; RUNTIME_request_t request = RUNTIME_REQUEST_INITIALIZER; + void *ws; chamctxt = chameleon_context_self(); if (chamctxt == NULL) { @@ -319,26 +320,28 @@ int CHAMELEON_zgemm( cham_trans_t transA, cham_trans_t transB, int M, int N, int /* Submit the matrix conversion */ chameleon_zlap2tile( chamctxt, &descAl, &descAt, ChamDescInput, ChamUpperLower, - A, NB, NB, LDA, An, Am, An, sequence, &request ); + A, NB, NB, LDA, An, Am, An, sequence, &request ); chameleon_zlap2tile( chamctxt, &descBl, &descBt, ChamDescInput, ChamUpperLower, - B, NB, NB, LDB, Bn, Bm, Bn, sequence, &request ); + B, NB, NB, LDB, Bn, Bm, Bn, sequence, &request ); chameleon_zlap2tile( chamctxt, &descCl, &descCt, ChamDescInout, ChamUpperLower, - C, NB, NB, LDC, N, M, N, sequence, &request ); + C, NB, NB, LDC, N, M, N, sequence, &request ); /* Call the tile interface */ - CHAMELEON_zgemm_Tile_Async( transA, transB, alpha, &descAt, &descBt, beta, &descCt, sequence, &request ); + ws = CHAMELEON_zgemm_WS_Alloc( transA, transB, &descAt, &descBt, &descCt ); + CHAMELEON_zgemm_Tile_Async( transA, transB, alpha, &descAt, &descBt, beta, &descCt, ws, sequence, &request ); /* Submit the matrix conversion back */ chameleon_ztile2lap( chamctxt, &descAl, &descAt, - ChamDescInput, ChamUpperLower, sequence, &request ); + ChamDescInput, ChamUpperLower, sequence, &request ); chameleon_ztile2lap( chamctxt, &descBl, &descBt, - ChamDescInput, ChamUpperLower, sequence, &request ); + ChamDescInput, ChamUpperLower, sequence, &request ); chameleon_ztile2lap( chamctxt, &descCl, &descCt, - ChamDescInout, ChamUpperLower, sequence, &request ); + ChamDescInout, ChamUpperLower, sequence, &request ); chameleon_sequence_wait( chamctxt, sequence ); /* Cleanup the temporary data */ + CHAMELEON_zgemm_WS_Free( ws ); chameleon_ztile2lap_cleanup( chamctxt, &descAl, &descAt ); chameleon_ztile2lap_cleanup( chamctxt, &descBl, &descBt ); chameleon_ztile2lap_cleanup( chamctxt, &descCl, &descCt ); @@ -405,13 +408,14 @@ int CHAMELEON_zgemm( cham_trans_t transA, cham_trans_t transB, int M, int N, int * */ int CHAMELEON_zgemm_Tile( cham_trans_t transA, cham_trans_t transB, - CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, - CHAMELEON_Complex64_t beta, CHAM_desc_t *C ) + CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, + CHAMELEON_Complex64_t beta, CHAM_desc_t *C ) { CHAM_context_t *chamctxt; RUNTIME_sequence_t *sequence = NULL; RUNTIME_request_t request = RUNTIME_REQUEST_INITIALIZER; int status; + void *ws; chamctxt = chameleon_context_self(); if (chamctxt == NULL) { @@ -420,13 +424,17 @@ int CHAMELEON_zgemm_Tile( cham_trans_t transA, cham_trans_t transB, } chameleon_sequence_create( chamctxt, &sequence ); - CHAMELEON_zgemm_Tile_Async( transA, transB, alpha, A, B, beta, C, sequence, &request ); + ws = CHAMELEON_zgemm_WS_Alloc( transA, transB, A, B, C ); + CHAMELEON_zgemm_Tile_Async( transA, transB, alpha, A, B, beta, C, ws, sequence, &request ); CHAMELEON_Desc_Flush( A, sequence ); CHAMELEON_Desc_Flush( B, sequence ); CHAMELEON_Desc_Flush( C, sequence ); chameleon_sequence_wait( chamctxt, sequence ); + + CHAMELEON_zgemm_WS_Free( ws ); + status = sequence->status; chameleon_sequence_destroy( chamctxt, sequence ); return status; @@ -461,11 +469,13 @@ int CHAMELEON_zgemm_Tile( cham_trans_t transA, cham_trans_t transB, * */ int CHAMELEON_zgemm_Tile_Async( cham_trans_t transA, cham_trans_t transB, - CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, - CHAMELEON_Complex64_t beta, CHAM_desc_t *C, - RUNTIME_sequence_t *sequence, RUNTIME_request_t *request ) + CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, + CHAMELEON_Complex64_t beta, CHAM_desc_t *C, + void *user_ws, + RUNTIME_sequence_t *sequence, RUNTIME_request_t *request ) { CHAM_context_t *chamctxt; + struct chameleon_pzgemm_s *ws; int M, N, K; int Am, An, Ai, Aj, Amb, Anb; int Bm, Bn, Bi, Bj, Bmb, Bnb; @@ -570,7 +580,21 @@ int CHAMELEON_zgemm_Tile_Async( cham_trans_t transA, cham_trans_t transB, return CHAMELEON_SUCCESS; } - chameleon_pzgemm( transA, transB, alpha, A, B, beta, C, sequence, request ); + if ( user_ws == NULL ) { + ws = CHAMELEON_zgemm_WS_Alloc( transA, transB, A, B, C ); + } + else { + ws = user_ws; + } + + chameleon_pzgemm( ws, transA, transB, alpha, A, B, beta, C, sequence, request ); + if ( user_ws == NULL ) { + CHAMELEON_Desc_Flush( A, sequence ); + CHAMELEON_Desc_Flush( B, sequence ); + CHAMELEON_Desc_Flush( C, sequence ); + chameleon_sequence_wait( chamctxt, sequence ); + CHAMELEON_zgemm_WS_Free( ws ); + } return CHAMELEON_SUCCESS; } diff --git a/control/compute_z.h b/control/compute_z.h index cf3a1d0ae5ae99a5959800b46da3f035e26efc86..cb688350ba89261715d0249214ad8eb98ac5b72d 100644 --- a/control/compute_z.h +++ b/control/compute_z.h @@ -25,6 +25,15 @@ #ifndef _compute_z_h_ #define _compute_z_h_ +/** + * @brief Data structure to handle the GEMM workspaces + */ +struct chameleon_pzgemm_s { + int summa; + CHAM_desc_t WA; + CHAM_desc_t WB; +}; + /** * Declarations of internal sequential functions */ @@ -40,7 +49,7 @@ void chameleon_pzgebrd_ge2gb( int genD, CHAM_desc_t *A, CHAM_desc_t *T, CHAM_des void chameleon_pzgelqf( int genD, CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *D, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request); void chameleon_pzgelqfrh( int genD, int BS, CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *D, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request); void chameleon_pzgenm2( double tol, const CHAM_desc_t *A, double *result, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request ); -void chameleon_pzgemm(cham_trans_t transA, cham_trans_t transB, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, CHAMELEON_Complex64_t beta, CHAM_desc_t *C, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request); +void chameleon_pzgemm( struct chameleon_pzgemm_s *options, cham_trans_t transA, cham_trans_t transB, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, CHAMELEON_Complex64_t beta, CHAM_desc_t *C, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request); void chameleon_pzgepdf_qdwh( cham_mtxtype_t trans, CHAM_desc_t *descU, CHAM_desc_t *descH, gepdf_info_t *info, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request ); void chameleon_pzgepdf_qr( int genD, int doqr, int optid, const libhqr_tree_t *qrtreeT, const libhqr_tree_t *qrtreeB, CHAM_desc_t *A1, CHAM_desc_t *TS1, CHAM_desc_t *TT1, CHAM_desc_t *D1, CHAM_desc_t *Q1, CHAM_desc_t *A2, CHAM_desc_t *TS2, CHAM_desc_t *TT2, CHAM_desc_t *D2, CHAM_desc_t *Q2, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request ); void chameleon_pzgeqrf( int genD, CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *D, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request); diff --git a/include/chameleon/chameleon_z.h b/include/chameleon/chameleon_z.h index 682217abc9a2625199f424ffff45dcddd76e183b..4ae7e3ef66111777f6176227f63a8cdd9504ac9e 100644 --- a/include/chameleon/chameleon_z.h +++ b/include/chameleon/chameleon_z.h @@ -192,7 +192,7 @@ int CHAMELEON_zgelqf_Tile_Async(CHAM_desc_t *A, CHAM_desc_t *T, RUNTIME_sequence int CHAMELEON_zgelqs_Tile_Async(CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *B, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request); int CHAMELEON_zgels_Tile_Async(cham_trans_t trans, CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *B, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request); int CHAMELEON_zgenm2_Tile_Async( double tol, CHAM_desc_t *A, double *value, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request ); -int CHAMELEON_zgemm_Tile_Async(cham_trans_t transA, cham_trans_t transB, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, CHAMELEON_Complex64_t beta, CHAM_desc_t *C, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request); +int CHAMELEON_zgemm_Tile_Async(cham_trans_t transA, cham_trans_t transB, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, CHAMELEON_Complex64_t beta, CHAM_desc_t *C, void *ws, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request); int CHAMELEON_zgepdf_qdwh_Tile_Async( CHAM_desc_t *A, CHAM_desc_t *H, gepdf_info_t *info, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request ); int CHAMELEON_zgeqrf_Tile_Async(CHAM_desc_t *A, CHAM_desc_t *T, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request); int CHAMELEON_zgeqrs_Tile_Async(CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *B, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);