Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 90d85dad authored by PRUVOST Florent, committed by Mathieu Faverge
Browse files

Make gemm_async interface really asynchronous

parent ce934b36
No related branches found
No related tags found
No related merge requests found
......@@ -26,8 +26,8 @@
#define A(m, n) A, m, n
#define B(m, n) B, m, n
#define C(m, n) C, m, n
#define WA(m, n) &WA, m, n
#define WB(m, n) &WB, m, n
#define WA(m, n) WA, m, n
#define WB(m, n) WB, m, n
/**
* Parallel tile matrix-matrix multiplication
......@@ -37,6 +37,7 @@ static inline void
chameleon_pzgemm_summa( CHAM_context_t *chamctxt, cham_trans_t transA, cham_trans_t transB,
CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B,
CHAMELEON_Complex64_t beta, CHAM_desc_t *C,
CHAM_desc_t *WA, CHAM_desc_t *WB,
RUNTIME_option_t *options )
{
RUNTIME_sequence_t *sequence = options->sequence;
......@@ -46,19 +47,8 @@ chameleon_pzgemm_summa( CHAM_context_t *chamctxt, cham_trans_t transA, cham_tran
CHAMELEON_Complex64_t zbeta;
CHAMELEON_Complex64_t zone = (CHAMELEON_Complex64_t)1.0;
CHAM_desc_t WA, WB;
lookahead = chamctxt->lookahead;
chameleon_desc_init( &WA, CHAMELEON_MAT_ALLOC_TILE,
ChamComplexDouble, C->mb, C->nb, (C->mb * C->nb),
C->mt * C->mb, C->nb * C->q * lookahead, 0, 0,
C->mt * C->mb, C->nb * C->q * lookahead, C->p, C->q,
NULL, NULL, NULL );
chameleon_desc_init( &WB, CHAMELEON_MAT_ALLOC_TILE,
ChamComplexDouble, C->mb, C->nb, (C->mb * C->nb),
C->mb * C->p * lookahead, C->nt * C->nb, 0, 0,
C->mb * C->p * lookahead, C->nt * C->nb, C->p, C->q,
NULL, NULL, NULL );
KT = transA == ChamNoTrans ? A->nt : A->mt;
K = transA == ChamNoTrans ? A->n : A->m;
......@@ -171,12 +161,8 @@ chameleon_pzgemm_summa( CHAM_context_t *chamctxt, cham_trans_t transA, cham_tran
}
}
RUNTIME_desc_flush( &WA, sequence );
RUNTIME_desc_flush( &WB, sequence );
RUNTIME_desc_flush( C, sequence );
chameleon_sequence_wait( chamctxt, sequence );
chameleon_desc_destroy( &WA );
chameleon_desc_destroy( &WB );
CHAMELEON_Desc_Flush( WA, sequence );
CHAMELEON_Desc_Flush( WB, sequence );
}
/**
......@@ -286,10 +272,11 @@ chameleon_pzgemm_generic( CHAM_context_t *chamctxt, cham_trans_t transA, cham_tr
}
/**
* Parallel tile matrix-matrix multiplication. wrapper.
* Parallel tile matrix-matrix multiplication wrapper.
*/
void
chameleon_pzgemm( cham_trans_t transA, cham_trans_t transB,
chameleon_pzgemm( struct chameleon_pzgemm_s *ws,
cham_trans_t transA, cham_trans_t transB,
CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B,
CHAMELEON_Complex64_t beta, CHAM_desc_t *C,
RUNTIME_sequence_t *sequence, RUNTIME_request_t *request )
......@@ -303,11 +290,10 @@ chameleon_pzgemm( cham_trans_t transA, cham_trans_t transB,
}
RUNTIME_options_init( &options, chamctxt, sequence, request );
if ( ((C->p > 1) || (C->q > 1)) &&
(C->get_rankof == chameleon_getrankof_2d) &&
(chamctxt->generic_enabled != CHAMELEON_TRUE) )
if ( ws->summa )
{
chameleon_pzgemm_summa( chamctxt, transA, transB, alpha, A, B, beta, C, &options );
chameleon_pzgemm_summa( chamctxt, transA, transB, alpha, A, B, beta, C,
&(ws->WA), &(ws->WB), &options );
}
else {
chameleon_pzgemm_generic( chamctxt, transA, transB, alpha, A, B, beta, C, &options );
......
......@@ -236,9 +236,9 @@ void CHAMELEON_zgemm_WS_Free( void *user_ws )
*
*/
int CHAMELEON_zgemm( cham_trans_t transA, cham_trans_t transB, int M, int N, int K,
CHAMELEON_Complex64_t alpha, CHAMELEON_Complex64_t *A, int LDA,
CHAMELEON_Complex64_t *B, int LDB,
CHAMELEON_Complex64_t beta, CHAMELEON_Complex64_t *C, int LDC )
CHAMELEON_Complex64_t alpha, CHAMELEON_Complex64_t *A, int LDA,
CHAMELEON_Complex64_t *B, int LDB,
CHAMELEON_Complex64_t beta, CHAMELEON_Complex64_t *C, int LDC )
{
int NB;
int Am, An, Bm, Bn;
......@@ -249,6 +249,7 @@ int CHAMELEON_zgemm( cham_trans_t transA, cham_trans_t transB, int M, int N, int
CHAM_context_t *chamctxt;
RUNTIME_sequence_t *sequence = NULL;
RUNTIME_request_t request = RUNTIME_REQUEST_INITIALIZER;
void *ws;
chamctxt = chameleon_context_self();
if (chamctxt == NULL) {
......@@ -319,26 +320,28 @@ int CHAMELEON_zgemm( cham_trans_t transA, cham_trans_t transB, int M, int N, int
/* Submit the matrix conversion */
chameleon_zlap2tile( chamctxt, &descAl, &descAt, ChamDescInput, ChamUpperLower,
A, NB, NB, LDA, An, Am, An, sequence, &request );
A, NB, NB, LDA, An, Am, An, sequence, &request );
chameleon_zlap2tile( chamctxt, &descBl, &descBt, ChamDescInput, ChamUpperLower,
B, NB, NB, LDB, Bn, Bm, Bn, sequence, &request );
B, NB, NB, LDB, Bn, Bm, Bn, sequence, &request );
chameleon_zlap2tile( chamctxt, &descCl, &descCt, ChamDescInout, ChamUpperLower,
C, NB, NB, LDC, N, M, N, sequence, &request );
C, NB, NB, LDC, N, M, N, sequence, &request );
/* Call the tile interface */
CHAMELEON_zgemm_Tile_Async( transA, transB, alpha, &descAt, &descBt, beta, &descCt, sequence, &request );
ws = CHAMELEON_zgemm_WS_Alloc( transA, transB, &descAt, &descBt, &descCt );
CHAMELEON_zgemm_Tile_Async( transA, transB, alpha, &descAt, &descBt, beta, &descCt, ws, sequence, &request );
/* Submit the matrix conversion back */
chameleon_ztile2lap( chamctxt, &descAl, &descAt,
ChamDescInput, ChamUpperLower, sequence, &request );
ChamDescInput, ChamUpperLower, sequence, &request );
chameleon_ztile2lap( chamctxt, &descBl, &descBt,
ChamDescInput, ChamUpperLower, sequence, &request );
ChamDescInput, ChamUpperLower, sequence, &request );
chameleon_ztile2lap( chamctxt, &descCl, &descCt,
ChamDescInout, ChamUpperLower, sequence, &request );
ChamDescInout, ChamUpperLower, sequence, &request );
chameleon_sequence_wait( chamctxt, sequence );
/* Cleanup the temporary data */
CHAMELEON_zgemm_WS_Free( ws );
chameleon_ztile2lap_cleanup( chamctxt, &descAl, &descAt );
chameleon_ztile2lap_cleanup( chamctxt, &descBl, &descBt );
chameleon_ztile2lap_cleanup( chamctxt, &descCl, &descCt );
......@@ -405,13 +408,14 @@ int CHAMELEON_zgemm( cham_trans_t transA, cham_trans_t transB, int M, int N, int
*
*/
int CHAMELEON_zgemm_Tile( cham_trans_t transA, cham_trans_t transB,
CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B,
CHAMELEON_Complex64_t beta, CHAM_desc_t *C )
CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B,
CHAMELEON_Complex64_t beta, CHAM_desc_t *C )
{
CHAM_context_t *chamctxt;
RUNTIME_sequence_t *sequence = NULL;
RUNTIME_request_t request = RUNTIME_REQUEST_INITIALIZER;
int status;
void *ws;
chamctxt = chameleon_context_self();
if (chamctxt == NULL) {
......@@ -420,13 +424,17 @@ int CHAMELEON_zgemm_Tile( cham_trans_t transA, cham_trans_t transB,
}
chameleon_sequence_create( chamctxt, &sequence );
CHAMELEON_zgemm_Tile_Async( transA, transB, alpha, A, B, beta, C, sequence, &request );
ws = CHAMELEON_zgemm_WS_Alloc( transA, transB, A, B, C );
CHAMELEON_zgemm_Tile_Async( transA, transB, alpha, A, B, beta, C, ws, sequence, &request );
CHAMELEON_Desc_Flush( A, sequence );
CHAMELEON_Desc_Flush( B, sequence );
CHAMELEON_Desc_Flush( C, sequence );
chameleon_sequence_wait( chamctxt, sequence );
CHAMELEON_zgemm_WS_Free( ws );
status = sequence->status;
chameleon_sequence_destroy( chamctxt, sequence );
return status;
......@@ -461,11 +469,13 @@ int CHAMELEON_zgemm_Tile( cham_trans_t transA, cham_trans_t transB,
*
*/
int CHAMELEON_zgemm_Tile_Async( cham_trans_t transA, cham_trans_t transB,
CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B,
CHAMELEON_Complex64_t beta, CHAM_desc_t *C,
RUNTIME_sequence_t *sequence, RUNTIME_request_t *request )
CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B,
CHAMELEON_Complex64_t beta, CHAM_desc_t *C,
void *user_ws,
RUNTIME_sequence_t *sequence, RUNTIME_request_t *request )
{
CHAM_context_t *chamctxt;
struct chameleon_pzgemm_s *ws;
int M, N, K;
int Am, An, Ai, Aj, Amb, Anb;
int Bm, Bn, Bi, Bj, Bmb, Bnb;
......@@ -570,7 +580,21 @@ int CHAMELEON_zgemm_Tile_Async( cham_trans_t transA, cham_trans_t transB,
return CHAMELEON_SUCCESS;
}
chameleon_pzgemm( transA, transB, alpha, A, B, beta, C, sequence, request );
if ( user_ws == NULL ) {
ws = CHAMELEON_zgemm_WS_Alloc( transA, transB, A, B, C );
}
else {
ws = user_ws;
}
chameleon_pzgemm( ws, transA, transB, alpha, A, B, beta, C, sequence, request );
if ( user_ws == NULL ) {
CHAMELEON_Desc_Flush( A, sequence );
CHAMELEON_Desc_Flush( B, sequence );
CHAMELEON_Desc_Flush( C, sequence );
chameleon_sequence_wait( chamctxt, sequence );
CHAMELEON_zgemm_WS_Free( ws );
}
return CHAMELEON_SUCCESS;
}
......@@ -25,6 +25,15 @@
#ifndef _compute_z_h_
#define _compute_z_h_
/**
* @brief Data structure to handle the GEMM workspaces
*
* Passed as the first argument of chameleon_pzgemm(), which reads it to pick
* the algorithm variant and to obtain the workspace descriptors.
* NOTE(review): presumably filled by CHAMELEON_zgemm_WS_Alloc() and released by
* CHAMELEON_zgemm_WS_Free(); neither body is visible here - confirm.
*/
struct chameleon_pzgemm_s {
/* Non-zero: chameleon_pzgemm() calls chameleon_pzgemm_summa(); zero: it calls chameleon_pzgemm_generic(). */
int summa;
/* Workspace descriptor handed to chameleon_pzgemm_summa() as its WA argument. */
CHAM_desc_t WA;
/* Workspace descriptor handed to chameleon_pzgemm_summa() as its WB argument. */
CHAM_desc_t WB;
};
/**
* Declarations of internal sequential functions
*/
......@@ -40,7 +49,7 @@ void chameleon_pzgebrd_ge2gb( int genD, CHAM_desc_t *A, CHAM_desc_t *T, CHAM_des
void chameleon_pzgelqf( int genD, CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *D, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
void chameleon_pzgelqfrh( int genD, int BS, CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *D, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
void chameleon_pzgenm2( double tol, const CHAM_desc_t *A, double *result, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request );
void chameleon_pzgemm(cham_trans_t transA, cham_trans_t transB, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, CHAMELEON_Complex64_t beta, CHAM_desc_t *C, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
void chameleon_pzgemm( struct chameleon_pzgemm_s *options, cham_trans_t transA, cham_trans_t transB, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, CHAMELEON_Complex64_t beta, CHAM_desc_t *C, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
void chameleon_pzgepdf_qdwh( cham_mtxtype_t trans, CHAM_desc_t *descU, CHAM_desc_t *descH, gepdf_info_t *info, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request );
void chameleon_pzgepdf_qr( int genD, int doqr, int optid, const libhqr_tree_t *qrtreeT, const libhqr_tree_t *qrtreeB, CHAM_desc_t *A1, CHAM_desc_t *TS1, CHAM_desc_t *TT1, CHAM_desc_t *D1, CHAM_desc_t *Q1, CHAM_desc_t *A2, CHAM_desc_t *TS2, CHAM_desc_t *TT2, CHAM_desc_t *D2, CHAM_desc_t *Q2, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request );
void chameleon_pzgeqrf( int genD, CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *D, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
......
......@@ -192,7 +192,7 @@ int CHAMELEON_zgelqf_Tile_Async(CHAM_desc_t *A, CHAM_desc_t *T, RUNTIME_sequence
int CHAMELEON_zgelqs_Tile_Async(CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *B, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
int CHAMELEON_zgels_Tile_Async(cham_trans_t trans, CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *B, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
int CHAMELEON_zgenm2_Tile_Async( double tol, CHAM_desc_t *A, double *value, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request );
int CHAMELEON_zgemm_Tile_Async(cham_trans_t transA, cham_trans_t transB, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, CHAMELEON_Complex64_t beta, CHAM_desc_t *C, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
int CHAMELEON_zgemm_Tile_Async(cham_trans_t transA, cham_trans_t transB, CHAMELEON_Complex64_t alpha, CHAM_desc_t *A, CHAM_desc_t *B, CHAMELEON_Complex64_t beta, CHAM_desc_t *C, void *ws, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
int CHAMELEON_zgepdf_qdwh_Tile_Async( CHAM_desc_t *A, CHAM_desc_t *H, gepdf_info_t *info, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request );
int CHAMELEON_zgeqrf_Tile_Async(CHAM_desc_t *A, CHAM_desc_t *T, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
int CHAMELEON_zgeqrs_Tile_Async(CHAM_desc_t *A, CHAM_desc_t *T, CHAM_desc_t *B, RUNTIME_sequence_t *sequence, RUNTIME_request_t *request);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment