// tpqrt.hpp
#ifndef FABULOUS_TPQRT_HPP
#define FABULOUS_TPQRT_HPP

#include <vector>
#include "fabulous/data/Block.hpp"
#include "fabulous/utils/Arithmetic.hpp"
#include "fabulous/ext/lapacke.h"

namespace fabulous {
namespace lapacke {

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"

/****** TILED TRIANGULAR over PENTAGONAL QR FACTORIZATION (for QR+IB(+DR)) *********/

/**
MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
18
 * \brief TRIANGULAR-over-PENTAGONAL QR-Factorisation (full)
19 20
 * \param m number of lines
 * \param n number of col
MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
21 22 23 24 25 26 27 28 29
 * \param l the size of the trapezoidal part
 * \param nb tile size for the tiled factorization algorithm
 * \param[in,out] A upper triangular matrix N-by-N A
 * \param lda leading dim of A
 * \param[in,out] B pentagonal M-by-N matrix, overwritten with V
 * \param ldb leading dim of B
 * \param[out] T on output: extra factorization data (needed for ormqr and orgqr)
 *
 * Ref: https://software.intel.com/en-us/node/521034
30 31 32 33
 */
template< class S >
int tpqrt(int m, int n, int l, int nb, S *A, int lda, S *B, int ldb, Block<S> &T)
{
34
    FABULOUS_DISABLE_INVALID_ARITHMETIC(S);
35 36 37 38
    return -1;
}

/**
MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
39 40 41 42
 * \brief TRIANGULAR-over-PENTAGONAL QR UPDATE kernel
 *
 * Apply Q to a matrix B where Q is the
 * orthogonal part obtained from Lapack GEQRF call
43 44 45 46 47
 *
 * \param trans char Can be 'N' or 'T', to know if we need Q^{H}*C or Q*C.
 * \param m Integer number of rows in C
 * \param n Integer number of columns in C
 * \param k Integer Number of elementary reflectors in Q
MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
48 49 50 51 52 53 54 55 56
 * \param l the size of the trapezoidal part
 * \param nb same block size that was passed to the tpqrt kernel
 * \param V array returned by tpqrt
 * \param ldv leading dim of V
 * \param A array returned by tpqrt
 * \param lda leading dim of A
 * \param T Block as returned by tpqrt
 * \param B the MxN matrix to be multiplied
 * \param ldb leading dimension of B
57
 *
MIJIEUX Thomas's avatar
MIJIEUX Thomas committed
58
 * Ref: https://software.intel.com/en-us/node/521035
59 60 61 62 63 64 65
 */
template<class S>
int tpmqrt(char trans, int m, int n, int k, int l, int nb,
           const S *V, int ldv,
           const Block<S> &T, S *A, int lda, S *B, int ldb)

{
66
    FABULOUS_DISABLE_INVALID_ARITHMETIC(S);
67 68 69 70 71 72 73 74 75 76 77 78 79
    return -1;
}

#pragma GCC diagnostic pop

/****************************************/
/********  IMPLEMENTATIONS **************/
/****************************************/

#ifdef FABULOUS_LAPACKE_NANCHECK

/********  TPQRT **************/

/*
 * NANCHECK variant: forwards to the high-level LAPACKE_?tpqrt driver, which
 * allocates its own workspace and, depending on how LAPACKE was built, may
 * scan the input matrices for NaNs before calling LAPACK.
 *
 * One explicit specialization of tpqrt<S_> is generated per entry of
 * FABULOUS_ARITHMETIC_LIST; prefix_ is token-pasted into the LAPACKE symbol
 * name (presumably s/d/c/z).
 */
#define FABULOUS_SPECIALIZE_TPQRT(_1, S_, P_, _4, _5, _6, prefix_, ...) \
    template<>                                                          \
    int tpqrt(int m, int n, int l, int nb,                              \
              S_ *A, int lda, S_ *B, int ldb, Block<S_> &T)             \
    {                                                                   \
        /* ?tpqrt needs T to be LDT-by-N with LDT >= NB */              \
        T = Block<S_>{nb, n};                                           \
        return LAPACKE_##prefix_##tpqrt(                                \
            LAPACK_COL_MAJOR, m, n, l, nb, A, lda, B, ldb,              \
            T.get_ptr(), T.get_leading_dim()                            \
        );                                                              \
    }

FABULOUS_ARITHMETIC_LIST(FABULOUS_SPECIALIZE_TPQRT);

/************* TPMQRT ********************/

/*
 * NANCHECK variant: forwards to the high-level LAPACKE_?tpmqrt driver, which
 * allocates its own workspace. Q is always applied from the left
 * (side = 'L'), i.e. to C = [A; B].
 */
#define FABULOUS_SPECIALIZE_TPMQRT(_1, S_, P_, _4, _5, _6, prefix_, ...) \
    template<>                                                          \
    int tpmqrt(char trans, int m, int n, int k, int l, int nb,          \
               const S_ *V, int ldv, const Block<S_> &T,                \
               S_ *A, int lda, S_ *B, int ldb)                          \
    {                                                                   \
        /* LAPACKE_?tpmqrt had a buggy workspace allocation */          \
        /* in netlib lapack (fixed May 4 2017). */                      \
        /* You may want to use the -DFABULOUS_LAPACKE_NANCHECK=OFF */   \
        /* cmake switch to use the other implementation if problems */  \
        /* do appear with your lapack distribution. */                  \
        return LAPACKE_##prefix_##tpmqrt(                               \
            LAPACK_COL_MAJOR, 'L', trans, m, n, k, l, nb,               \
            V, ldv, T.get_ptr(), T.get_leading_dim(),                   \
            A, lda, B, ldb                                              \
        );                                                              \
    }

FABULOUS_ARITHMETIC_LIST(FABULOUS_SPECIALIZE_TPMQRT);

#else // FABULOUS_LAPACKE_NANCHECK

/********  TPQRT **************/

/*
 * Non-NANCHECK variant: calls the middle-level LAPACKE_?tpqrt_work entry
 * point directly with a workspace allocated here, bypassing the high-level
 * driver. One explicit specialization of tpqrt<S_> is generated per entry of
 * FABULOUS_ARITHMETIC_LIST.
 */
#define FABULOUS_SPECIALIZE_TPQRT(_1, S_, P_, _4, _5, _6, prefix_, ...) \
    template<>                                                          \
    int tpqrt(int m, int n, int l, int nb,                              \
              S_ *A, int lda, S_ *B, int ldb, Block<S_> &T)             \
    {                                                                   \
        /* ?tpqrt needs T to be LDT-by-N with LDT >= NB */              \
        T = Block<S_>{nb, n};                                           \
        /* ?tpqrt requires a workspace of NB*N elements */              \
        S_ *work = static_cast<S_*>(LAPACKE_malloc(                     \
            sizeof(S_) * FABULOUS_MAX(1,nb) * FABULOUS_MAX(1,n) ));     \
        if (work == nullptr)                                            \
            return LAPACK_WORK_MEMORY_ERROR;                            \
        const lapack_int info = LAPACKE_##prefix_##tpqrt_work(          \
            LAPACK_COL_MAJOR, m, n, l, nb, A, lda, B, ldb,              \
            T.get_ptr(), T.get_leading_dim(), work                      \
        );                                                              \
        LAPACKE_free( work );                                           \
        return info;                                                    \
    }

FABULOUS_ARITHMETIC_LIST(FABULOUS_SPECIALIZE_TPQRT);

/********  TPMQRT **************/

/*
 * Non-NANCHECK variant: calls the middle-level LAPACKE_?tpmqrt_work entry
 * point directly with a workspace allocated here, which avoids the workspace
 * sizing bug of older netlib LAPACKE_?tpmqrt drivers. Q is always applied
 * from the left (side = 'L'), i.e. to C = [A; B].
 */
#define FABULOUS_SPECIALIZE_TPMQRT(_1, S_, P_, _4, _5, _6, prefix_, ...) \
    template<>                                                          \
    int tpmqrt(char trans, int m, int n, int k, int l, int nb,          \
               const S_ *V, int ldv, const Block<S_> &T,                \
               S_ *A, int lda, S_ *B, int ldb)                          \
    {                                                                   \
        /* with side = 'L', ?tpmqrt needs N*NB workspace elements */    \
        S_ *work = static_cast<S_*>(LAPACKE_malloc(                     \
            sizeof(S_) * FABULOUS_MAX(1,n) * FABULOUS_MAX(1,nb) ));     \
        if (work == nullptr)                                            \
            return LAPACK_WORK_MEMORY_ERROR;                            \
        const lapack_int info = LAPACKE_##prefix_##tpmqrt_work(         \
            LAPACK_COL_MAJOR, 'L', trans, m, n, k, l, nb,               \
            V, ldv, T.get_ptr(), T.get_leading_dim(),                   \
            A, lda, B, ldb, work                                        \
        );                                                              \
        LAPACKE_free( work );                                           \
        return info;                                                    \
    }

FABULOUS_ARITHMETIC_LIST(FABULOUS_SPECIALIZE_TPMQRT);

#endif // FABULOUS_LAPACKE_NANCHECK

}; // end namespace lapacke
}; // end namespace fabulous

#endif // FABULOUS_TPQRT_HPP