RUHE.hpp 11.1 KB
Newer Older
MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
1 2
#ifndef FABULOUS_ORTHO_RUHE_HPP
#define FABULOUS_ORTHO_RUHE_HPP
3

4 5 6 7
// Forward declaration of the orthogonalizer, placed before the project
// headers included below (presumably to break an include cycle -- verify).
namespace fabulous {
template<class HESS> class OrthogonalizerRuheSTD;
} // namespace fabulous

MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
8
#include "fabulous/utils/Traits.hpp"
MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
9
#include "fabulous/data/Base.hpp"
MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
10 11
#include "fabulous/data/Block.hpp"
#include "fabulous/orthogonalization/OrthoParam.hpp"
12 13
#include "fabulous/kernel/blas.hpp"
#include "fabulous/kernel/flops.hpp"
14 15 16

namespace fabulous {

17 18 19
/* **************************** RUHE ******************************** */

/**
20
 * \brief Orthogonalization methods for RUHE variant WITHOUT Inexact Breakdown
21
 */
22
template<class HESS>
MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
23
class OrthogonalizerRuheSTD : public OrthoParam
24
{
25
private:
26
    int64_t _nb_flops;
27
    friend class Orthogonalizer;
28

MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
29
    OrthogonalizerRuheSTD(const OrthoParam &param):
30 31
        OrthoParam{param},
        _nb_flops{0}
32
    {
33 34
    }

35 36 37 38 39 40 41 42 43 44 45
    /**
     * \brief Arnoldi Ruhe version with CGS ortho
     */
    template< class Matrix, class S >
    void CGS(Matrix &A, Base<S> &base,
             Block<S> &H, Block<S> &R, // hessenberg
             Block<S> &W) // candidate
    {
        int W_size = W.get_nb_col();
        int dim = base.get_nb_row();
        int nb_vect_in_base = base.get_nb_vect();
46 47
        int block_size = base.get_block_size(base.get_nb_block()-1);
        FABULOUS_ASSERT(W_size == block_size);
48 49 50 51 52

        int ldh = H.get_leading_dim();
        int ldr = R.get_leading_dim();
        int ldw = W.get_leading_dim();
        int ldv = base.get_leading_dim();
53
        S *V = base.get_ptr();
54 55 56 57 58 59 60 61

        // Loop over vector in W block
        for (int k = 0; k < W_size; ++k) {
            S *W_k = W.get_vect(k);
            S *H_k = H.get_vect(k);
            S *R_k = R.get_vect(k);

            { // Ortho against Base
62 63
                _nb_flops += A.DotProduct(nb_vect_in_base, 1, V, ldv,
                                          W_k, ldw, H_k, ldh);
64
                lapacke::gemm( // W_k = W_k - V * H_k
65 66 67 68 69 70
                    dim, 1, nb_vect_in_base,
                    V, ldv,
                    H_k, ldh,
                    W_k, ldw,
                    S{-1.0}, S{1.0}
                );
71
                _nb_flops += lapacke::gemm_flops<S>(dim, 1, nb_vect_in_base);
72 73 74
            }

            { // Ortho against already processed W vectors
75 76
                _nb_flops += A.DotProduct(k, 1, W.get_ptr(), ldw,
                                          W_k, ldw, R_k, ldr);
77
                lapacke::gemm( // W_k = W_k - W_{0->k-1} * R_k
78 79
                    dim, 1, k,
                    W.get_ptr(), ldw,
80 81 82 83
                    R_k, ldr,
                    W_k, ldw,
                    S{-1.0}, S{1.0}
                );
84
                _nb_flops += lapacke::gemm_flops<S>(dim, 1, k);
85 86
            }
            auto n = W.get_norm(k, A);
87
            _nb_flops += W.get_last_flops();
88 89 90 91 92
            if (n == 0.0) {
                FABULOUS_FATAL_ERROR(
                    "Rank loss in block candidate for extending the Krylov Space" );
            }
            R_k[k] = n; // Hess[k+1,k] = norm(wj_k)
93
            lapacke::scal(dim, S{1.0} / n, W_k, 1); // W[k] /= norm(W[k])
94
            _nb_flops += lapacke::scal_flops<S>(dim);
95 96 97 98 99 100
        }
    }

    /**
     * \brief Arnoldi Ruhe version with CGS ortho
     */
101
    template<class Matrix, class S>
102 103 104 105 106 107 108 109
    void ICGS( Matrix &A, Base<S> &base,
               Block<S> &H, Block<S> &R, // hessenberg
               Block<S> &W) // candidate
    {
        int W_size = W.get_nb_col();
        int dim = base.get_nb_row();
        int nb_vect_in_base = base.get_nb_vect();
        int block_size = base.get_block_size(base.get_nb_block()-1);
110
        FABULOUS_ASSERT(W_size == block_size);
111 112 113

        int ldw = W.get_leading_dim();
        int ldv = base.get_leading_dim();
114
        S *V = base.get_ptr();
115 116 117 118 119 120 121 122 123 124

        // Loop over vector in W block
        for (int k = 0; k < W_size; ++k) {
            S *H_k = H.get_vect(k);
            S *W_k = W.get_vect(k);
            S *R_k = R.get_vect(k);

            for (int t = 0; t < _nb_iteration; ++t) { // Double ortho against base
                Block<S> tmp{nb_vect_in_base, 1};
                int ld = tmp.get_leading_dim();
125 126 127
                // tmp = V^{H} * Wj[k]
                _nb_flops += A.DotProduct(nb_vect_in_base, 1, V, ldv,
                                          W_k, ldw, tmp.get_ptr(), ld);
128
                lapacke::gemm( // W = W - Vm*H_k
129 130
                    dim, 1, nb_vect_in_base,
                    V, ldv,
131 132 133 134
                    tmp.get_ptr(), tmp.get_leading_dim(),
                    W_k, ldw,
                    S{-1.0}, S{1.0}
                );
135
                _nb_flops += lapacke::gemm_flops<S>(dim, 1, nb_vect_in_base);
136
                // H_k += tmp
137 138
                lapacke::axpy(nb_vect_in_base, S{1.0}, tmp.get_ptr(), 1, H_k, 1);
                _nb_flops += lapacke::axpy_flops<S>(nb_vect_in_base);
139 140
            }
            for (int t = 0; t < _nb_iteration; ++t) {
141
                // Double ortho against W_j{0->k-1} (already orthogonalized part of W)
142 143 144
                if (k != 0) {
                    Block<S> tmp{k, 1};
                    int ld = tmp.get_leading_dim();
145 146
                    _nb_flops += A.DotProduct(k, 1, W.get_ptr(), ldw,
                                              W_k, ldw, tmp.get_ptr(), ld);
147
                    lapacke::gemm(
148 149 150 151 152 153
                        dim, 1, k,
                        W.get_ptr(), ldw,
                        tmp.get_ptr(), tmp.get_leading_dim(),
                        W_k, ldw,
                        S{-1.0}, S{1.0}
                    );
154 155 156 157
                    _nb_flops += lapacke::gemm_flops<S>(dim, 1, k);
                    // R_k += tmp
                    lapacke::axpy(k, S{1.0}, tmp.get_ptr(), 1, R_k, 1);
                    _nb_flops += lapacke::axpy_flops<S>(k);
158
                }
159
            }
160

161
            auto n = W.get_norm(k, A);
162
            _nb_flops += W.get_last_flops();
163 164 165 166 167
            if (n == 0.0) {
                FABULOUS_FATAL_ERROR(
                    "Rank loss in block candidate for extending the Krylov Space" );
            }
            R_k[k] = n; // Hess[...] = norm(wj_k)
168
            lapacke::scal(dim, S{1.0} / n, W_k, 1); // W[k] /= norm(W[k])
169
            _nb_flops += lapacke::scal_flops<S>(dim);
170
        }
171 172
    }

173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
    /**
     * \brief Arnoldi Ruhe version with MGS ortho
     */
    template< class Matrix, class S >
    void MGS(Matrix &A, Base<S> &base,
             Block<S> &H, Block<S> &R, // hessenberg
             Block<S> &W) // candidate
    {
        int W_size = W.get_nb_col();
        int dim = base.get_nb_row();
        int nb_vect_in_base = base.get_nb_vect();
        int ldw = W.get_leading_dim();
        int ldv = base.get_leading_dim();

        // Loop over vector in W block
        for (int k = 0; k < W_size; ++k) {
            S *H_k = H.get_vect(k);
            S *W_k = W.get_vect(k);
            S *R_k = R.get_vect(k);

            for (int n = 0; n < nb_vect_in_base; ++n) { //Ortho against base
                // H_k[n] = dot( V_m[n], W[k] ) (compute orthogonalization coefficient)
195
                _nb_flops += A.DotProduct(1, 1, base.get_vect(n), ldv, W_k, ldw, H_k+n, 1);
196
                // W[k] = W[k] - H_k[n] * V_m[n]  ( apply orthogonalization to W[k] )
197
                lapacke::axpy(dim, -H_k[n], base.get_vect(n), 1, W_k, 1);
198
                _nb_flops += lapacke::axpy_flops<S>(dim);
199 200 201 202
            }
            for (int n = 0; n < k; ++n) { // Ortho against already computed vectors of W
                // H_k[nb_vect_in_base + n] = dot( W[n], W[k] )
                S *W_n = W.get_vect(n);
203
                _nb_flops += A.DotProduct(1, 1, W_n, ldw, W_k, ldw, R_k+n, 1);
204
                lapacke::axpy(dim, -R_k[n], W_n, 1, W_k, 1);
205
                _nb_flops += lapacke::axpy_flops<S>(dim);
206 207
            }
            auto n = W.get_norm(k, A);
208
            _nb_flops += W.get_last_flops();
209 210 211 212 213
            if (n == 0.0) {
                FABULOUS_FATAL_ERROR(
                    "Rank loss in block candidate for extending the Krylov Space" );
            }
            R_k[k] = n; // Hess[k+1,k] = norm(wj_k)
214
            lapacke::scal(dim, S{1.0} / n, W_k, 1); // W[k] /= norm(W[k])
215
            _nb_flops += lapacke::scal_flops<S>(dim);
216
        }
217 218
    }

219 220 221
    /**
     * \brief Arnoldi Ruhe version with IMGS ortho
     */
222
    template<class Matrix, class S>
223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
    void IMGS(Matrix &A, Base<S> &base,
              Block<S> &H, Block<S> &R, // hessenberg
              Block<S> &W) // candidate
    {
        int W_size = W.get_nb_col();
        int nb_vect_in_base = base.get_nb_vect();
        int ldw = W.get_leading_dim();
        int ldv = base.get_leading_dim();
        int dim = base.get_nb_row();

        // Loop over vector in W block
        for (int k = 0; k < W_size; ++k) {
            S *H_k = H.get_vect(k);
            S *W_k = W.get_vect(k);
            S *R_k = R.get_vect(k);

            for (int n = 0; n < nb_vect_in_base; ++n) { //Ortho against Base
                H_k[n] = S{0.0};
                for (int t = 0; t < _nb_iteration; ++t) {
                    S tmp = S{0.0};
                    // H_k[n] = V[n]^{H} * Wj[k]
244 245
                    _nb_flops += A.DotProduct(1, 1, base.get_vect(n), ldv,
                                              W_k, ldw, &tmp, 1);
246
                    // Wj[k] = Wj[k] - H_k[n] * V[n]
247
                    lapacke::axpy(dim, -tmp, base.get_vect(n), 1, W_k, 1);
248
                    _nb_flops += lapacke::axpy_flops<S>(dim);
249 250
                    H_k[n] += tmp;
                }
251
            }
252 253 254 255
            for (int n = 0; n < k; ++n) {
                H_k[nb_vect_in_base + n] = S{0.0};
                for (int t = 0; t < _nb_iteration; ++t) {
                    S tmp{0.0};
256 257
                    _nb_flops += A.DotProduct(1, 1, W.get_vect(n), ldw,
                                              W_k, ldw, &tmp, 1);
258
                    lapacke::axpy(dim, -tmp, W.get_vect(n), 1, W_k, 1);
259
                    _nb_flops += lapacke::axpy_flops<S>(dim);
260 261
                    R_k[n] += tmp;
                }
262
            }
263
            auto n = W.get_norm(k, A);
264
            _nb_flops += W.get_last_flops();
265 266 267 268 269
            if (n == 0.0) {
                FABULOUS_FATAL_ERROR(
                    "Rank loss in block candidate for extending the Krylov Space" );
            }
            R_k[k] = n; // Hess[k+1,k] = norm(wj_k)
270
            lapacke::scal(dim, S{1.0} / n, W_k, 1); // W[k] /= norm(W[k])
271
            _nb_flops += lapacke::scal_flops<S>(dim);
272
        }
273 274
    }

275 276
    /* *************************** WRAPPER ************************* */
    template<class Matrix, class S>
277
    int64_t run(HESS &hess, Base<S> &base, Block<S> &W, Matrix &A)
278
    {
279
        _nb_flops = 0;
280 281 282 283 284 285 286 287 288 289 290 291
        hess.increase(W.get_nb_col());

        Block<S> H = hess.get_H();
        Block<S> R = hess.get_R();

        switch (_scheme) {
        case OrthoScheme::CGS:  CGS(  A, base, H, R, W);                   break;
        case OrthoScheme::ICGS: ICGS( A, base, H, R, W);                   break;
        case OrthoScheme::MGS:  MGS(  A, base, H, R, W);                   break;
        case OrthoScheme::IMGS: IMGS( A, base, H, R, W);                   break;
        default: FABULOUS_FATAL_ERROR("Invalid orthogonalization scheme"); break;
        }
292
        return _nb_flops;
293
    }
294

295
}; // end class OrthogonalizerRuheSTD
296

297
}; // end namespace fabulous;
298

MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
299
#endif // FABULOUS_ORTHO_RUHE_HPP