// BlockWP.hpp
#ifndef FABULOUS_BLOCK_WP_HPP
#define FABULOUS_BLOCK_WP_HPP
namespace fabulous {
// Forward declaration of BlockWP, placed ahead of the #includes below --
// presumably so those headers can refer to BlockWP before its definition
// (confirm against fabulous/data/Block.hpp).
template<class> class BlockWP;
}
#include "fabulous/kernel/blas.hpp"
#include "fabulous/kernel/flops.hpp"
#include "fabulous/data/Block.hpp"
namespace fabulous {
13

14
/**
15 16 17
 * \brief Represent a block of datas containing P and W concatenated
 *
 * This use a cursor to know where P ends and W starts is needed.
18
 */
19
template<class S>
20
class BlockWP : public Block<S>
21
{
22
private:
23
    int _cursor; /**< number of vector in P part */
24

25
public:
26
    FABULOUS_INHERITS_BLOCK(S);
27

MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
28 29 30 31 32 33
    BlockWP():
        super{},
        _cursor(0)
    {
    }

34 35 36
    BlockWP(int dim, int nbRHS):
        super{dim, nbRHS},
        _cursor{0}
37
    {
38 39 40
    }

    /**
41
     * \brief W part (candidate for base expansion)
42
     */
43
    Block<S> get_W()
44
    {
45 46
        int nb_col = get_nb_col();
        return sub_block(0, _cursor, get_nb_row(), nb_col-_cursor);
47 48 49
    }

    /**
50
     * \brief P part (candidate part that take into account discarded directions)
51
     */
52
    Block<S> get_P()
53
    {
54
        return sub_block(0, 0, get_nb_row(), _cursor);
55 56
    }

57 58 59 60 61 62 63 64 65 66 67 68 69 70
    S *get_P_ptr()
    {
        return get_ptr();
    }

    S *get_W_ptr()
    {
        return get_ptr(0, _cursor);
    }

    int get_size_P() const { return _cursor; }
    int get_size_W() const { return get_nb_col()-_cursor; }
    void increase_size_P(int size) { _cursor += size; }

71
    /**
72 73
     * \brief Compute \f$C = P^{H} * W\f$ and then \f$W = W-C*P\f$
     * (no user reduction used in this function; cannot be used in distributed problem)
74
     *
75
     * \param[out] C the C block of the projected matrix to be filled
76
     * \param[in] A the user callback matrix
77
     */
78 79
    template<class Matrix>
    int64_t compute_C(Block<S> &C, Matrix &A)
80
    {
81
        int64_t nb_flops = 0;
82
        if (!_cursor)
83
            return nb_flops;
84

85 86
        nb_flops += A.DotProduct( //  C := P^{H} * \widetilde{W}
            get_size_P(), get_size_W(),
87 88
            get_P_ptr(), get_leading_dim(),
            get_W_ptr(), get_leading_dim(),
89
            C.get_ptr(), C.get_leading_dim()
90
        );
91

92
        lapacke::gemm( // ~W  = ~W - P_j*C
93 94 95 96 97 98
            get_nb_row(), get_size_W(), get_size_P(),
            get_P_ptr(), get_leading_dim(),
            C.get_ptr(), C.get_leading_dim(),
            get_W_ptr(), get_leading_dim(),
            S{-1.0}, S{1.0}
        );
99
        namespace fps = lapacke::flops;
100
        nb_flops += fps::gemm<S>(get_nb_row(), get_size_W(), get_size_P());
101
        return nb_flops;
102 103 104
    }

    /**
105 106 107
     * \brief perform \f$ V_{j+1} = [P_{j-1},\widetilde{W}] * \mathbb{W}_1 \f$
     * \param[in] W1 \f$ \mathbb{W}_1 \f$
     * \param[out] V \f$ V_{j+1} \f$
108
     */
109
    void compute_V(const Block<S> &W1, Block<S> &V) const
110
    {
111
        int nb_kept_direction = W1.get_nb_col();
112
        lapacke::gemm( // V_{j+1} :=  [P_{j-1},~W] * \W_1
113 114 115 116 117 118
            get_nb_row(), nb_kept_direction, get_nb_col(),
            get_ptr(), get_leading_dim(),
            W1.get_ptr(), W1.get_leading_dim(),
            V.get_ptr(), V.get_leading_dim(),
            S{1.0}, S{0.0}
        );
119
    }
120

121 122 123 124
    /**
     * \brief perform \f$ P_{j} = [P_{j-1},\widetilde{W}] * \mathbb{W}_2 \f$
     * \param[in] W2 \f$ \mathbb{W}_2 \f$
     */
125
    int64_t update_P(const Block<S> &W2)
126
    {
127 128 129 130 131
        int M = this->get_nb_row();
        int N = W2.get_nb_col(); /* nb_discarded_direction */
        int K = this->get_nb_col();
        Block<S> tmp{M, N};

132
        lapacke::gemm( // P_{j} :=  [P_{j-1},~W] * \W_1
133 134
            M, N, K,
            this->get_ptr(), this->get_leading_dim(),
135 136 137 138 139
            W2.get_ptr(), W2.get_leading_dim(),
            tmp.get_ptr(), tmp.get_leading_dim(),
            S{1.0}, S{0.0}
        );

140 141
        _cursor = N;
        Block<S> P = this->get_P();
142
        P.copy(tmp);
143 144

        return lapacke::flops::gemm<S>(M, N, K);
145 146
    }

147 148 149 150 151 152 153 154 155
    /**
     * \brief Check othogonality between P and W
     */
    void check_ortho_WP(std::string name="WP")
    {
        if (!_cursor)
            return;

        super tmp{get_size_P(), get_size_W()};
156
        lapacke::Tgemm(
157 158 159 160 161 162 163 164 165 166
            get_size_P(), get_size_W(), get_nb_row(),
            get_P_ptr(), get_leading_dim(),
            get_W_ptr(), get_leading_dim(),
            tmp.get_ptr(), tmp.get_leading_dim(),
            S{1.0}, S{0.0}
        );
        auto MM = tmp.get_min_max_norm();
        FABULOUS_DEBUG("CHECK_ORTHO WP["<<Color::green<<name<<Color::reset<<"]: (B^H*B - I) "
                       <<"Min="<<MM.first<<" Max="<<MM.second);
    }
167

168
}; // end class BlockWP
169

170
} // end namespace fabulous
#endif // FABULOUS_BLOCK_WP_HPP