codelet_zunmqr.c 8.22 KB
Newer Older
1
2
/**
 *
3
4
 * @copyright (c) 2009-2014 The University of Tennessee and The University
 *                          of Tennessee Research Foundation.
5
 *                          All rights reserved.
6
 * @copyright (c) 2012-2016 Inria. All rights reserved.
7
 * @copyright (c) 2012-2014, 2016 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria, Univ. Bordeaux. All rights reserved.
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
 *
 **/

/**
 *
 * @file codelet_zunmqr.c
 *
 *  MORSE codelets kernel
 *  MORSE is a software package provided by Univ. of Tennessee,
 *  Univ. of California Berkeley and Univ. of Colorado Denver
 *
 * @version 2.5.0
 * @comment This file has been automatically generated
 *          from Plasma 2.5.0 for MORSE 1.0.0
 * @author Hatem Ltaief
 * @author Jakub Kurzak
 * @author Mathieu Faverge
 * @author Emmanuel Agullo
 * @author Cedric Castagnede
 * @date 2010-11-15
 * @precisions normal z -> c d s
 *
 **/
PRUVOST Florent's avatar
PRUVOST Florent committed
31

32
#include "runtime/starpu/include/morse_starpu.h"
33
#include "runtime/starpu/include/runtime_codelet_z.h"
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119

/**
 *
 * @ingroup CORE_MORSE_Complex64_t
 *
 *  CORE_zunmqr overwrites the general complex M-by-N tile C with
 *
 *                    SIDE = 'L'     SIDE = 'R'
 *    TRANS = 'N':      Q * C          C * Q
 *    TRANS = 'C':      Q**H * C       C * Q**H
 *
 *  where Q is a complex unitary matrix defined as the product of k
 *  elementary reflectors
 *
 *    Q = H(1) H(2) . . . H(k)
 *
 *  as returned by CORE_zgeqrt. Q is of order M if SIDE = 'L' and of order N
 *  if SIDE = 'R'.
 *
 *******************************************************************************
 *
 * @param[in] side
 *         @arg MorseLeft  : apply Q or Q**H from the Left;
 *         @arg MorseRight : apply Q or Q**H from the Right.
 *
 * @param[in] trans
 *         @arg MorseNoTrans   :  No transpose, apply Q;
 *         @arg MorseConjTrans :  Transpose, apply Q**H.
 *
 * @param[in] M
 *         The number of rows of the tile C.  M >= 0.
 *
 * @param[in] N
 *         The number of columns of the tile C.  N >= 0.
 *
 * @param[in] K
 *         The number of elementary reflectors whose product defines
 *         the matrix Q.
 *         If SIDE = MorseLeft,  M >= K >= 0;
 *         if SIDE = MorseRight, N >= K >= 0.
 *
 * @param[in] IB
 *         The inner-blocking size.  IB >= 0.
 *
 * @param[in] A
 *         Dimension:  (LDA,K)
 *         The i-th column must contain the vector which defines the
 *         elementary reflector H(i), for i = 1,2,...,k, as returned by
 *         CORE_zgeqrt in the first k columns of its array argument A.
 *
 * @param[in] LDA
 *         The leading dimension of the array A.
 *         If SIDE = MorseLeft,  LDA >= max(1,M);
 *         if SIDE = MorseRight, LDA >= max(1,N).
 *
 * @param[in] T
 *         The IB-by-K triangular factor T of the block reflector.
 *         T is upper triangular by block (economic storage);
 *         The rest of the array is not referenced.
 *
 * @param[in] LDT
 *         The leading dimension of the array T. LDT >= IB.
 *
 * @param[in,out] C
 *         On entry, the M-by-N tile C.
 *         On exit, C is overwritten by Q*C or Q**T*C or C*Q**T or C*Q.
 *
 * @param[in] LDC
 *         The leading dimension of the array C. LDC >= max(1,M).
 *
 * @param[in,out] WORK
 *         On exit, if INFO = 0, WORK(1) returns the optimal LDWORK.
 *
 * @param[in] LDWORK
 *         The dimension of the array WORK.
 *         If SIDE = MorseLeft,  LDWORK >= max(1,N);
 *         if SIDE = MorseRight, LDWORK >= max(1,M).
 *
 *******************************************************************************
 *
 * @return
 *          \retval MORSE_SUCCESS successful exit
 *          \retval <0 if -i, the i-th argument had an illegal value
 *
 ******************************************************************************/

120
void MORSE_TASK_zunmqr(const MORSE_option_t *options,
121
122
                       MORSE_enum side, MORSE_enum trans,
                       int m, int n, int k, int ib, int nb,
123
124
125
                       const MORSE_desc_t *A, int Am, int An, int lda,
                       const MORSE_desc_t *T, int Tm, int Tn, int ldt,
                       const MORSE_desc_t *C, int Cm, int Cn, int ldc)
126
127
128
{
    struct starpu_codelet *codelet = &cl_zunmqr;
    void (*callback)(void*) = options->profiling ? cl_zunmqr_callback : NULL;
129
130
131
132
133
    int sizeA = lda*k;
    int sizeT = ldt*n;
    int sizeC = ldc*n;
    int execution_rank = C->get_rankof( C, Cm, Cn );
    int rank_changed=0;
134
    (void)execution_rank;
135

136
    /*  force execution on the rank owning the largest data (tile) */
137
138
139
140
141
142
143
144
145
146
147
148
149
    int threshold;
    char* env = getenv("MORSE_COMM_FACTOR_THRESHOLD");
    if (env != NULL)
        threshold = (unsigned)atoi(env);
    else
        threshold = 10;
    if ( sizeA > threshold*sizeC ){
        execution_rank = A->get_rankof( A, Am, An );
        rank_changed = 1;
    }else if( sizeT > threshold*sizeC ){
        execution_rank = T->get_rankof( T, Tm, Tn );
        rank_changed = 1;
    }
150
151
152

    if ( morse_desc_islocal( A, Am, An ) ||
         morse_desc_islocal( T, Tm, Tn ) ||
153
154
         morse_desc_islocal( C, Cm, Cn ) ||
         rank_changed )
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
    {
        starpu_insert_task(
            codelet,
            STARPU_VALUE,    &side,              sizeof(MORSE_enum),
            STARPU_VALUE,    &trans,             sizeof(MORSE_enum),
            STARPU_VALUE,    &m,                 sizeof(int),
            STARPU_VALUE,    &n,                 sizeof(int),
            STARPU_VALUE,    &k,                 sizeof(int),
            STARPU_VALUE,    &ib,                sizeof(int),
            STARPU_R,         RTBLKADDR(A, MORSE_Complex64_t, Am, An),
            STARPU_VALUE,    &lda,               sizeof(int),
            STARPU_R,         RTBLKADDR(T, MORSE_Complex64_t, Tm, Tn),
            STARPU_VALUE,    &ldt,               sizeof(int),
            STARPU_RW,        RTBLKADDR(C, MORSE_Complex64_t, Cm, Cn),
            STARPU_VALUE,    &ldc,               sizeof(int),
            /* ib * nb */
            STARPU_SCRATCH,   options->ws_worker,
            STARPU_VALUE,    &nb,                sizeof(int),
            STARPU_PRIORITY,  options->priority,
            STARPU_CALLBACK,  callback,
175
            STARPU_NAME,      "zunmqr",
176
177
178
#if defined(CHAMELEON_USE_MPI)
            STARPU_EXECUTE_ON_NODE, execution_rank,
#endif
179
180
181
182
183
            0);
    }
}


184
#if !defined(CHAMELEON_SIMULATION)
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
static void cl_zunmqr_cpu_func(void *descr[], void *cl_arg)
{
    MORSE_enum side;
    MORSE_enum trans;
    int m;
    int n;
    int k;
    int ib;
    MORSE_Complex64_t *A;
    int lda;
    MORSE_Complex64_t *T;
    int ldt;
    MORSE_Complex64_t *C;
    int ldc;
    MORSE_Complex64_t *WORK;
    int ldwork;

    A    = (MORSE_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[0]);
    T    = (MORSE_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[1]);
    C    = (MORSE_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[2]);
    WORK = (MORSE_Complex64_t *)STARPU_MATRIX_GET_PTR(descr[3]); /* ib * nb */

    starpu_codelet_unpack_args(cl_arg, &side, &trans, &m, &n, &k, &ib,
                               &lda, &ldt, &ldc, &ldwork);

    CORE_zunmqr(side, trans, m, n, k, ib,
                A, lda, T, ldt, C, ldc, WORK, ldwork);
}

214
#if defined(CHAMELEON_USE_CUDA)
215
216
217
218
219
220
221
222
223
224
225
226
static void cl_zunmqr_cuda_func(void *descr[], void *cl_arg)
{
    MORSE_starpu_ws_t *d_work;
    MORSE_enum side;
    MORSE_enum trans;
    int m;
    int n;
    int k;
    int ib;
    cuDoubleComplex *A, *T, *C, *WORK;
    int lda, ldt, ldc, ldwork;
    int info = 0;
Mathieu Faverge's avatar
Mathieu Faverge committed
227
    CUstream stream;
228
229
230
231
232
233
234
235
236

    starpu_codelet_unpack_args(cl_arg, &side, &trans, &m, &n, &k, &ib,
                               &lda, &ldt, &ldc, &ldwork);

    A    = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[0]);
    T    = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[1]);
    C    = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[2]);
    WORK = (cuDoubleComplex *)STARPU_MATRIX_GET_PTR(descr[3]); /* ib * nb */

237
238
239
    stream = starpu_cuda_get_local_stream();
    cublasSetKernelStream( stream );

240
241
    CUDA_zunmqrt(
            side, trans, m, n, k, ib,
242
            A, lda, T, ldt, C, ldc, WORK, ldwork, stream );
243

244
245
#ifndef STARPU_CUDA_ASYNC
    cudaStreamSynchronize( stream );
246
#endif
247
248
}
#endif /* defined(CHAMELEON_USE_CUDA) */
249
#endif /* !defined(CHAMELEON_SIMULATION) */
250
251
252
253

/*
 * Codelet definition
 */
254
#if defined(CHAMELEON_USE_CUDA)
255
256
257
258
CODELETS(zunmqr, 4, cl_zunmqr_cpu_func, cl_zunmqr_cuda_func, 0)
#else
CODELETS_CPU(zunmqr, 4, cl_zunmqr_cpu_func)
#endif