BlockWP.hpp 4.17 KB
Newer Older
1 2
#ifndef FABULOUS_BLOCK_WP_HPP
#define FABULOUS_BLOCK_WP_HPP
3

4
namespace fabulous {

/* Forward declaration: lets the headers included afterwards refer to
 * BlockWP<S> before the full class definition is seen. */
template<class> class BlockWP;

} // end namespace fabulous

8
#include "fabulous/kernel/blas.hpp"
9
#include "fabulous/data/Block.hpp"
10

11
namespace fabulous {
12

13
/**
14 15 16
 * \brief Represent a block of datas containing P and W concatenated
 *
 * This use a cursor to know where P ends and W starts is needed.
17
 */
18 19
template< class S >
class BlockWP : public Block<S>
20
{
21
private:
22
    int _cursor; /**< number of vector in P part */
23

24
public:
25
    FABULOUS_INHERITS_BLOCK(S);
26

MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
27 28 29 30 31 32
    BlockWP():
        super{},
        _cursor(0)
    {
    }

33 34 35
    BlockWP(int dim, int nbRHS):
        super{dim, nbRHS},
        _cursor{0}
36
    {
37 38 39
    }

    /**
40
     * \brief W part (candidate for base expansion)
41
     */
42
    Block<S> get_W()
43
    {
44 45
        int nb_col = get_nb_col();
        return sub_block(0, _cursor, get_nb_row(), nb_col-_cursor);
46 47 48
    }

    /**
49
     * \brief P part (candidate part that take into account discarded directions)
50
     */
51
    Block<S> get_P()
52
    {
53
        return sub_block(0, 0, get_nb_row(), _cursor);
54 55
    }

56 57 58 59 60 61 62 63 64 65 66 67 68 69
    S *get_P_ptr()
    {
        return get_ptr();
    }

    S *get_W_ptr()
    {
        return get_ptr(0, _cursor);
    }

    int get_size_P() const { return _cursor; }
    int get_size_W() const { return get_nb_col()-_cursor; }
    void increase_size_P(int size) { _cursor += size; }

70
    /**
71 72
     * \brief Compute \f$C = P^{H} * W\f$ and then \f$W = W-C*P\f$
     * (no user reduction used in this function; cannot be used in distributed problem)
73
     *
74
     * \param[out] C the C block of the projected matrix to be filled
75
     */
76
    void compute_C(Block<S> &C)
77
    {
78 79 80
        if (!_cursor)
            return;

81
        lapacke::Tgemm( //  C := P^{H} * \widetilde{W}
82 83 84 85 86 87
            get_size_P(), get_size_W(), get_nb_row(),
            get_P_ptr(), get_leading_dim(),
            get_W_ptr(), get_leading_dim(),
            C.get_ptr(), C.get_leading_dim(),
            S{1.0}, S{0.0}
        );
88
        lapacke::gemm( // ~W  = ~W - P_j*C
89 90 91 92 93 94
            get_nb_row(), get_size_W(), get_size_P(),
            get_P_ptr(), get_leading_dim(),
            C.get_ptr(), C.get_leading_dim(),
            get_W_ptr(), get_leading_dim(),
            S{-1.0}, S{1.0}
        );
95 96 97
    }

    /**
98 99 100
     * \brief perform \f$ V_{j+1} = [P_{j-1},\widetilde{W}] * \mathbb{W}_1 \f$
     * \param[in] W1 \f$ \mathbb{W}_1 \f$
     * \param[out] V \f$ V_{j+1} \f$
101
     */
102
    void compute_V(const Block<S> &W1, Block<S> &V) const
103
    {
104
        int nb_kept_direction = W1.get_nb_col();
105
        lapacke::gemm( // V_{j+1} :=  [P_{j-1},~W] * \W_1
106 107 108 109 110 111
            get_nb_row(), nb_kept_direction, get_nb_col(),
            get_ptr(), get_leading_dim(),
            W1.get_ptr(), W1.get_leading_dim(),
            V.get_ptr(), V.get_leading_dim(),
            S{1.0}, S{0.0}
        );
112
    }
113

114 115 116 117 118
    /**
     * \brief perform \f$ P_{j} = [P_{j-1},\widetilde{W}] * \mathbb{W}_2 \f$
     * \param[in] W2 \f$ \mathbb{W}_2 \f$
     */
    void update_P(const Block<S> &W2)
119
    {
120 121
        int nb_discarded_direction = W2.get_nb_col();
        Block<S> tmp{get_nb_row(), nb_discarded_direction};
122
        lapacke::gemm( // P_{j} :=  [P_{j-1},~W] * \W_1
123 124 125 126 127 128 129 130 131 132
            get_nb_row(), nb_discarded_direction, get_nb_col(),
            get_ptr(), get_leading_dim(),
            W2.get_ptr(), W2.get_leading_dim(),
            tmp.get_ptr(), tmp.get_leading_dim(),
            S{1.0}, S{0.0}
        );

        _cursor = nb_discarded_direction;
        Block<S> P = get_P();
        P.copy(tmp);
133 134
    }

135 136 137 138 139 140 141 142 143
    /**
     * \brief Check othogonality between P and W
     */
    void check_ortho_WP(std::string name="WP")
    {
        if (!_cursor)
            return;

        super tmp{get_size_P(), get_size_W()};
144
        lapacke::Tgemm(
145 146 147 148 149 150 151 152 153 154
            get_size_P(), get_size_W(), get_nb_row(),
            get_P_ptr(), get_leading_dim(),
            get_W_ptr(), get_leading_dim(),
            tmp.get_ptr(), tmp.get_leading_dim(),
            S{1.0}, S{0.0}
        );
        auto MM = tmp.get_min_max_norm();
        FABULOUS_DEBUG("CHECK_ORTHO WP["<<Color::green<<name<<Color::reset<<"]: (B^H*B - I) "
                       <<"Min="<<MM.first<<" Max="<<MM.second);
    }
155

156
}; // end class BlockWP
157

158
}; // end namespace fabulous
159

160
#endif // FABULOUS_BLOCK_WP_HPP