diff --git a/compute/pzgetrf_incpiv.c b/compute/pzgetrf_incpiv.c index 092d258faf2ec47c7c630c9d1767bcbec1bb3774..b5c23f6da8b8e085a5d1c09856b029fc4d0beb5e 100644 --- a/compute/pzgetrf_incpiv.c +++ b/compute/pzgetrf_incpiv.c @@ -35,9 +35,9 @@ #define A(_m_,_n_) A, _m_, _n_ #if defined(CHAMELEON_COPY_DIAG) -#define DIAG(_k_) DIAG, _k_, 0 +#define D(k) D, k, 0 #else -#define DIAG(_k_) A, _k_, _k_ +#define D(k) A, k, k #endif #define L(_m_,_n_) L, _m_, _n_ #define IPIV(_m_,_n_) &(IPIV[(int64_t)A->mb*((int64_t)(_m_)+(int64_t)A->mt*(int64_t)(_n_))]) @@ -45,13 +45,13 @@ /******************************************************************************* * Parallel tile LU factorization - dynamic scheduling **/ -void morse_pzgetrf_incpiv(MORSE_desc_t *A, MORSE_desc_t *L, int *IPIV, - MORSE_sequence_t *sequence, MORSE_request_t *request) +void morse_pzgetrf_incpiv( MORSE_desc_t *A, MORSE_desc_t *L, MORSE_desc_t *D, int *IPIV, + MORSE_sequence_t *sequence, MORSE_request_t *request ) { - MORSE_desc_t *DIAG = NULL; MORSE_context_t *morse; MORSE_option_t options; - size_t h_work_size, d_work_size; + size_t ws_worker = 0; + size_t ws_host = 0; int k, m, n; int ldak, ldam; @@ -65,14 +65,19 @@ void morse_pzgetrf_incpiv(MORSE_desc_t *A, MORSE_desc_t *L, int *IPIV, RUNTIME_options_init(&options, morse, sequence, request); ib = MORSE_IB; - h_work_size = sizeof(MORSE_Complex64_t)*( ib*L->nb ); - d_work_size = 0; - RUNTIME_options_ws_alloc( &options, h_work_size, d_work_size ); + /* + * zgetrf_incpiv = 0 + * zgessm = 0 + * ztstrf = A->mb * ib + * zssssm = 0 + */ + ws_worker = A->mb * ib; - /* necessary to avoid dependencies between tasks regarding the diag tile */ - DIAG = (MORSE_desc_t*)malloc(sizeof(MORSE_desc_t)); - morse_zdesc_alloc_diag(*DIAG, A->mb, A->nb, chameleon_min(A->m, A->n), A->nb, 0, 0, chameleon_min(A->m, A->n), A->nb, A->p, A->q); + ws_worker *= sizeof(MORSE_Complex64_t); + ws_host *= sizeof(MORSE_Complex64_t); + + RUNTIME_options_ws_alloc( &options, ws_worker, ws_host ); for (k = 0; k < minMNT; k++) { RUNTIME_iteration_push(morse, k); @@ -94,7 +99,7 @@ void morse_pzgetrf_incpiv(MORSE_desc_t *A, MORSE_desc_t *L, int *IPIV, &options, MorseUpperLower, tempkm, tempkn, A->nb, A(k, k), ldak, - DIAG(k), ldak); + D(k), ldak); #endif } @@ -105,7 +110,7 @@ void morse_pzgetrf_incpiv(MORSE_desc_t *A, MORSE_desc_t *L, int *IPIV, tempkm, tempnn, tempkm, ib, L->nb, IPIV(k, k), L(k, k), L->mb, - DIAG(k), ldak, + D(k), ldak, A(k, n), ldak); } for (m = k+1; m < A->mt; m++) { @@ -138,7 +143,5 @@ void morse_pzgetrf_incpiv(MORSE_desc_t *A, MORSE_desc_t *L, int *IPIV, RUNTIME_options_ws_free(&options); RUNTIME_options_finalize(&options, morse); - - morse_desc_mat_free(DIAG); - free(DIAG); + (void)D; } diff --git a/compute/pztrsmpl.c b/compute/pztrsmpl.c index 7314ad60fe731751f16e02b72e1987b5a2f78191..a96db6292f8a7bcb71224a3b88e6e55d2f409b54 100644 --- a/compute/pztrsmpl.c +++ b/compute/pztrsmpl.c @@ -37,8 +37,8 @@ /******************************************************************************* * Parallel forward substitution for tile LU - dynamic scheduling **/ -void morse_pztrsmpl(MORSE_desc_t *A, MORSE_desc_t *B, MORSE_desc_t *L, int *IPIV, - MORSE_sequence_t *sequence, MORSE_request_t *request) +void morse_pztrsmpl( MORSE_desc_t *A, MORSE_desc_t *B, MORSE_desc_t *L, int *IPIV, + MORSE_sequence_t *sequence, MORSE_request_t *request ) { MORSE_context_t *morse; MORSE_option_t options; diff --git a/compute/zgesv_incpiv.c b/compute/zgesv_incpiv.c index 700f6b02e960c8723090f567f33107b467f6dddc..9808488ccec1af2196c74221bbc77c1fced063a3 100644 --- a/compute/zgesv_incpiv.c +++ b/compute/zgesv_incpiv.c @@ -267,6 +267,7 @@ int MORSE_zgesv_incpiv_Tile_Async( MORSE_desc_t *A, MORSE_desc_t *L, int *IPIV, MORSE_sequence_t *sequence, MORSE_request_t *request ) { MORSE_context_t *morse; + MORSE_desc_t D, *Dptr = NULL; morse = morse_context_self(); if (morse == NULL) { @@ -313,11 +314,28 @@ int MORSE_zgesv_incpiv_Tile_Async( MORSE_desc_t *A, MORSE_desc_t *L, int *IPIV, return MORSE_SUCCESS; */ - morse_pzgetrf_incpiv( A, L, IPIV, sequence, request ); +#if defined(CHAMELEON_COPY_DIAG) + { + int n = chameleon_min(A->mt, A->nt) * A->nb; + morse_zdesc_alloc(D, A->mb, A->nb, A->m, n, 0, 0, A->m, n, ); + Dptr = &D; + } +#endif + + morse_pzgetrf_incpiv( A, L, Dptr, IPIV, sequence, request ); morse_pztrsmpl( A, B, L, IPIV, sequence, request ); morse_pztrsm( MorseLeft, MorseUpper, MorseNoTrans, MorseNonUnit, (MORSE_Complex64_t)1.0, A, B, sequence, request ); + if (Dptr != NULL) { + MORSE_Desc_Flush( A, sequence ); + MORSE_Desc_Flush( L, sequence ); + MORSE_Desc_Flush( Dptr, sequence ); + MORSE_Desc_Flush( B, sequence ); + morse_sequence_wait( morse, sequence ); + morse_desc_mat_free( Dptr ); + } + (void)D; return MORSE_SUCCESS; } diff --git a/compute/zgetrf_incpiv.c b/compute/zgetrf_incpiv.c index a86915d081d5b92b07269f534348dc985c023f7d..63ead5685f62b5e6c955245aa9ab4cde94bb570d 100644 --- a/compute/zgetrf_incpiv.c +++ b/compute/zgetrf_incpiv.c @@ -244,6 +244,7 @@ int MORSE_zgetrf_incpiv_Tile_Async( MORSE_desc_t *A, MORSE_desc_t *L, int *IPIV, MORSE_sequence_t *sequence, MORSE_request_t *request ) { MORSE_context_t *morse; + MORSE_desc_t D, *Dptr = NULL; morse = morse_context_self(); if (morse == NULL) { @@ -286,7 +287,23 @@ int MORSE_zgetrf_incpiv_Tile_Async( MORSE_desc_t *A, MORSE_desc_t *L, int *IPIV, return MORSE_SUCCESS; */ - morse_pzgetrf_incpiv( A, L, IPIV, sequence, request ); +#if defined(CHAMELEON_COPY_DIAG) + { + int n = chameleon_min(A->mt, A->nt) * A->nb; + morse_zdesc_alloc(D, A->mb, A->nb, A->m, n, 0, 0, A->m, n, ); + Dptr = &D; + } +#endif + + morse_pzgetrf_incpiv( A, L, Dptr, IPIV, sequence, request ); + if (Dptr != NULL) { + MORSE_Desc_Flush( A, sequence ); + MORSE_Desc_Flush( L, sequence ); + MORSE_Desc_Flush( Dptr, sequence ); + morse_sequence_wait( morse, sequence ); + morse_desc_mat_free( Dptr ); + } + (void)D; return MORSE_SUCCESS; }