Commit 1f014a72 authored by MIJIEUX Thomas's avatar MIJIEUX Thomas

add flops namespace for flops evaluation functions

parent 3f08f506
......@@ -53,3 +53,4 @@
CLOSED: [2017-05-30 Tue 15:53]
* TODO distributed hessenberg (with chameleon)
* TODO rework C++ api for multiple algorithm
* TODO Add parameter for setting maximum base extension in IB
......@@ -176,7 +176,7 @@ public:
const S *ptr = get_vect(k);
S res = S{0.0};
lapacke::dot(_m, ptr, 1, ptr, 1, &res);
_last_flops = lapacke::dot_flops<S>(_m);
_last_flops = lapacke::flops::dot<S>(_m);
return std::sqrt(fabulous::real(res));
}
......
......@@ -87,7 +87,8 @@ public:
C.get_ptr(), C.get_leading_dim(),
S{1.0}, S{0.0}
);
nb_flops += lapacke::Tgemm_flops<S>(get_size_P(), get_size_W(), get_nb_row());
namespace fps = lapacke::flops;
nb_flops += fps::gemm<S>(get_size_P(), get_size_W(), get_nb_row());
lapacke::gemm( // ~W = ~W - P_j*C
get_nb_row(), get_size_W(), get_size_P(),
......@@ -96,7 +97,7 @@ public:
get_W_ptr(), get_leading_dim(),
S{-1.0}, S{1.0}
);
nb_flops += lapacke::gemm_flops<S>(get_nb_row(), get_size_W(), get_size_P());
nb_flops += fps::gemm<S>(get_nb_row(), get_size_W(), get_size_P());
return nb_flops;
}
......
......@@ -184,7 +184,7 @@ private:
MORSE_desc_t *T = _tau[i].get();
int err = chameleon::ormqr<S>(trans, A, T, C, _seq.get());
flops += lapacke::ormqr_flops<S>(_nbRHS, _nbRHS, _nbRHS);
flops += lapacke::flops::left_ormqr<S>(_nbRHS, _nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "ormqr 'step' err="<<err);
}
......@@ -195,7 +195,7 @@ private:
MORSE_desc_t *A = get_sub_hess(k, k, 2, 1);
int err = chameleon::geqrf<S>(A, tau, _seq.get());
flops += lapacke::geqrf_flops<S>(_nbRHS, _nbRHS);
flops += lapacke::flops::geqrf<S>(_nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "geqrf 'last block' err="<<err);
}
......@@ -206,7 +206,7 @@ private:
/* STEP 3: Apply Q^H generated at step 2 to last block of RHS */
MORSE_desc_t *C = get_sub_rhs(k, 0, 2, 1);
err = chameleon::ormqr<S>(trans, A, tau.get(), C, _seq.get());
flops += lapacke::ormqr_flops<S>(_nbRHS, _nbRHS, _nbRHS);
flops += lapacke::flops::left_ormqr<S>(_nbRHS, _nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "ormqr 'RHS' err="<<err);
}
......@@ -302,7 +302,7 @@ public:
MORSE_Sequence_Wait(_seq.get());
MORSE_Tile_to_Lapack(YYd, _Y.get_ptr(), _Y.get_leading_dim());
_solution_computed = true;
int64_t flops = lapacke::trsm_flops<S>(_nb_vect, _nbRHS);
int64_t flops = lapacke::flops::left_trsm<S>(_nb_vect, _nbRHS);
_logger.notify_least_square_end(flops);
}
......
......@@ -171,7 +171,7 @@ private:
MORSE_desc_t *T = _tau[i].get();
int err = chameleon::ormqr<S>(trans, A.get(), T, C.get(), _seq.get());
flops += lapacke::ormqr_flops<S>(_nbRHS, _nbRHS, _nbRHS);
flops += lapacke::flops::left_ormqr<S>(_nbRHS, _nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "ormqr 'step' err="<<err);
}
......@@ -182,7 +182,7 @@ private:
MorseDesc2<S> A = get_sub_hess(k, k, 2, 1);
int err = chameleon::geqrf<S>(A.get(), tau, _seq.get());
flops += lapacke::geqrf_flops<S>(_nbRHS, _nbRHS);
flops += lapacke::flops::geqrf<S>(_nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "geqrf 'last block' err="<<err);
}
......@@ -193,7 +193,7 @@ private:
/* STEP 3: Apply Q^H generated at step 2 to last block of RHS */
MorseDesc2<S> C = get_sub_rhs(k, 0, 2, 1);
err = chameleon::ormqr<S>(trans, A.get(), tau.get(), C.get(), _seq.get());
flops += lapacke::ormqr_flops<S>(_nbRHS, _nbRHS, _nbRHS);
flops += lapacke::flops::left_ormqr<S>(_nbRHS, _nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "ormqr 'RHS' err="<<err);
......@@ -284,7 +284,7 @@ public:
// compute the solution
chameleon::trsm<S>(R.get(), Lambda_tmp.get(), _seq.get());
int flops = lapacke::trsm_flops<S>(_nb_vect, _nbRHS);
int flops = lapacke::flops::left_trsm<S>(_nb_vect, _nbRHS);
MORSE_Tile_to_Lapack(Lambda_tmp.get(), _Y.get_ptr(), _Y.get_leading_dim());
......
......@@ -136,7 +136,7 @@ public:
}
_solution_computed = true;
int64_t flops = lapacke::gels_flops<S>(M, N, _nbRHS);
int64_t flops = lapacke::flops::gels<S>(M, N, _nbRHS);
_logger.notify_least_square_end(flops);
}
......
......@@ -133,7 +133,7 @@ public:
}
_solution_computed = true;
int64_t flops = lapacke::gels_flops<S>(M, N, _nbRHS);
int64_t flops = lapacke::flops::gels<S>(M, N, _nbRHS);
_logger.notify_least_square_end(flops);
}
......
......@@ -275,7 +275,7 @@ public:
FABULOUS_THROW(Kernel, "gels (least square) err="<<err);
}
_solution_computed = true;
int64_t flops = lapacke::gels_flops<S>(M, N, _nbRHS);
int64_t flops = lapacke::flops::gels<S>(M, N, _nbRHS);
_logger.notify_least_square_end(flops);
}
......
......@@ -340,7 +340,7 @@ public:
F.get_ptr(), F.get_leading_dim(),
_YY.get_ptr(), _YY.get_leading_dim()
);
int64_t flops = lapacke::gels_flops<S>(M, N, _nbRHS);
int64_t flops = lapacke::flops::gels<S>(M, N, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "gels (least square) err="<<err);
}
......@@ -378,10 +378,7 @@ public:
_logger.notify_least_square_begin();
Block<S> residual = _YY.sub_block(_nb_vect, 0, _nbRHS, _nbRHS);
auto MinMaxConv = residual.check_precision(epsilon);
int64_t flops = 2L*_nbRHS;
if ( is_real_t<S>::value ) {
flops = flops * 4L;
}
int64_t flops = _nbRHS*lapacke::flops::dot<S>(_nbRHS);
_logger.notify_least_square_end(flops);
return MinMaxConv;
}
......
......@@ -80,7 +80,7 @@ private:
A.get_ptr(), A.get_leading_dim(), tau,
C.get_ptr(), C.get_leading_dim()
);
flops += lapacke::ormqr_flops<S>(_nbRHS, _nbRHS, _nbRHS);
flops += lapacke::flops::left_ormqr<S>(_nbRHS, _nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "ormqr 'step' err="<<err);
}
......@@ -94,7 +94,7 @@ private:
int err = lapacke::geqrf( A.get_nb_row(), A.get_nb_col(),
A.get_ptr(), A.get_leading_dim(), tau);
flops += lapacke::geqrf_flops<S>(_nbRHS, _nbRHS);
flops += lapacke::flops::geqrf<S>(_nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "geqrf 'last block' err="<<err);
}
......@@ -106,7 +106,7 @@ private:
A.get_ptr(), A.get_leading_dim(), tau.data(),
C.get_ptr(), C.get_leading_dim()
);
flops += lapacke::ormqr_flops<S>(_nbRHS, _nbRHS, _nbRHS);
flops += lapacke::flops::left_ormqr<S>(_nbRHS, _nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "ormqr 'RHS' err="<<err);
}
......@@ -188,7 +188,7 @@ public:
_Y.get_ptr(), _Y.get_leading_dim()
);
_solution_computed = true;
int64_t flops = lapacke::trsm_flops<S>(M, NRHS);
int64_t flops = lapacke::flops::left_trsm<S>(M, NRHS);
_logger.notify_least_square_end(flops);
}
......
......@@ -99,7 +99,7 @@ private:
A.get_ptr(), A.get_leading_dim(), tau,
C.get_ptr(), C.get_leading_dim()
);
flops += lapacke::ormqr_flops<S>(_nbRHS, _nbRHS, _nbRHS);
flops += lapacke::flops::left_ormqr<S>(_nbRHS, _nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "ormqr 'step' err="<<err);
}
......@@ -111,7 +111,7 @@ private:
Block<S> A = this->qr_sub_block(k, k);
int err = lapacke::geqrf( A.get_nb_row(), A.get_nb_col(),
A.get_ptr(), A.get_leading_dim(), tau);
flops += lapacke::geqrf_flops<S>(_nbRHS, _nbRHS);
flops += lapacke::flops::geqrf<S>(_nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "geqrf 'last block' err="<<err);
}
......@@ -123,7 +123,7 @@ private:
A.get_ptr(), A.get_leading_dim(), tau.data(),
C.get_ptr(), C.get_leading_dim()
);
flops += lapacke::ormqr_flops<S>(_nbRHS, _nbRHS, _nbRHS);
flops += lapacke::flops::left_ormqr<S>(_nbRHS, _nbRHS, _nbRHS);
if (err != 0) {
FABULOUS_THROW(Kernel, "ormqr 'RHS' err="<<err);
}
......@@ -232,7 +232,7 @@ public:
_Y.get_ptr(), _Y.get_leading_dim()
);
_solution_computed = true;
int64_t flops = lapacke::trsm_flops<S>(M, NRHS);
int64_t flops = lapacke::flops::left_trsm<S>(M, NRHS);
_logger.notify_least_square_end(flops);
}
......
......@@ -29,6 +29,8 @@ namespace qr {
template<class S, class Matrix>
int64_t InPlaceQRFactoMGS_User(Block<S> &Q, Block<S> &R, const Matrix &A)
{
namespace fps = lapacke::flops;
int M = Q.get_nb_row();
int N = Q.get_nb_col();
int LD = Q.get_leading_dim();
......@@ -48,7 +50,7 @@ int64_t InPlaceQRFactoMGS_User(Block<S> &Q, Block<S> &R, const Matrix &A)
// Qj = Qj - dot(Qj,Qi)*Qi
lapacke::axpy(M, -dot, Q.get_vect(i), 1, Q.get_vect(j), 1);
nb_flops += lapacke::axpy_flops<S>(M);
nb_flops += fps::axpy<S>(M);
if (R.get_nb_col() != 0)
R.at(i, j) = dot;
}
......@@ -70,7 +72,7 @@ int64_t InPlaceQRFactoMGS_User(Block<S> &Q, Block<S> &R, const Matrix &A)
#endif
}
lapacke::scal(M, S{1.0} / norm, Q.get_vect(j), 1);
nb_flops += lapacke::scal_flops<S>(M);
nb_flops += fps::scal<S>(M);
// Qj = (1/n)Qj
}
return nb_flops;
......@@ -102,15 +104,16 @@ int64_t InPlaceQRFacto(Block<S> &Q, Block<S> &R)
assert( R.get_nb_row() == Q.get_nb_col() );
}
namespace fps = lapacke::flops;
std::vector<S> tau;
lapacke::geqrf(M, N, Q.get_ptr(), LD, tau);
int64_t nb_flops = lapacke::geqrf_flops<S>(M, N);
int64_t nb_flops = fps::geqrf<S>(M, N);
if (R.get_nb_col() != 0) {
lapacke::lacpy( 'U', M, N, Q.get_ptr(), LD,
R.get_ptr(), R.get_leading_dim());
}
lapacke::orgqr(M, N, N, Q.get_ptr(), LD, tau.data());
nb_flops += lapacke::orgqr_flops<S>(M, N, N);
nb_flops += fps::orgqr<S>(M, N, N);
return nb_flops;
}
......
......@@ -7,132 +7,114 @@
namespace fabulous {
namespace lapacke {
namespace flops {
template<class S, class = enable_if_t<is_real_t<S>::value>>
int64_t Tgemm_flops(int64_t m, int64_t n, int64_t k)
{
return m*n*(2L*k + 2L);
}
template<class S, class = enable_if_t<is_complex_t<S>::value>, class=void>
int64_t Tgemm_flops(int64_t m, int64_t n, int64_t k)
{
return m*n*(8L*k + 12L);
}
template<class S, class = enable_if_t<is_real_t<S>::value>>
int64_t gemm_flops(int64_t m, int64_t n, int64_t k)
{
return m*n*(2L*k + 2L);
}
template<class S, class = enable_if_t<is_complex_t<S>::value>, class=void>
int64_t gemm_flops(int64_t m, int64_t n, int64_t k)
{
return m*n*(8L*k+12L);
}
template<class S, class = enable_if_t<is_real_t<S>::value>>
int64_t axpy_flops(int64_t m)
inline constexpr double flop_per_mul()
{
return 2L*m; /* m additions and m multiplications */
return 1.0;
}
template<class S, class = enable_if_t<is_complex_t<S>::value>, class=void>
int64_t axpy_flops(int64_t m)
inline constexpr double flop_per_mul()
{
return 8L*m; /* 2*m (complex additions) + 6*m (complex multiplication) */
return 6.0;
}
template<class S, class = enable_if_t<is_real_t<S>::value>>
int64_t scal_flops(int64_t m)
inline constexpr double flop_per_add()
{
return m; /* m multiplications */
return 1.0;
}
template<class S, class = enable_if_t<is_complex_t<S>::value>, class=void>
int64_t scal_flops(int64_t m)
{
return 6L*m; /* m complex multiplications */
}
template<class S, class = enable_if_t<is_real_t<S>::value>>
int64_t dot_flops(int64_t m)
{
return 2L*m; /* m additions and m multiplications */
}
template<class S, class = enable_if_t<is_complex_t<S>::value>, class=void>
int64_t dot_flops(int64_t m)
{
return 8L*m; /* 2*m (complex additions) + 6*m (complex multiplication) */
}
template<class S, class = enable_if_t<is_real_t<S>::value>>
int64_t geqrf_flops(int64_t m, int64_t n)
{
FABULOUS_ASSERT( m >= n );
return (2L*m*n*n) - (2.0/3.0) * (n*n*n);
}
template<class S, class = enable_if_t<is_complex_t<S>::value>, class=void>
int64_t geqrf_flops(int64_t m, int64_t n)
{
FABULOUS_ASSERT( m >= n );
return (8L*m*n*n) - (8.0/3.0) * (n*n*n);
}
template<class S, class = enable_if_t<is_real_t<S>::value>>
int64_t orgqr_flops(int64_t m, int64_t n, int64_t k)
{
FABULOUS_ASSERT( m >= n );
FABULOUS_ASSERT( n == k );
return (2L*m*n*k) - (2.0/3.0) * (k*k*k);
}
template<class S, class = enable_if_t<is_complex_t<S>::value>, class=void>
int64_t orgqr_flops(int64_t m, int64_t n, int64_t k)
{
FABULOUS_ASSERT( m >= n );
FABULOUS_ASSERT( n == k );
return (8L*m*n*k) - (8.0/3.0) * (k*k*k);
}
template<class S, class = enable_if_t<is_real_t<S>::value>>
int64_t ormqr_flops(int64_t m, int64_t n, int64_t k)
{
FABULOUS_ASSERT( m >= n );
FABULOUS_ASSERT( n == k );
return (2L*m*n*k) - (2.0/3.0) * (k*k*k);
}
template<class S, class = enable_if_t<is_complex_t<S>::value>, class=void>
int64_t ormqr_flops(int64_t m, int64_t n, int64_t k)
{
FABULOUS_ASSERT( m >= n );
FABULOUS_ASSERT( n == k );
return (8L*m*n*k) - (8.0/3.0) * (k*k*k);
}
template<class S, class = enable_if_t<is_real_t<S>::value>>
int64_t trsm_flops(int64_t m, int64_t nrhs)
{
return 2L*m*m*nrhs;
}
template<class S, class = enable_if_t<is_complex_t<S>::value>, class=void>
int64_t trsm_flops(int64_t m, int64_t nrhs)
{
return 8L*m*m*nrhs;
}
inline constexpr double flop_per_add()
{
return 2.0;
}
/**
 * Generate the flop-count functions of a kernel whose cost is a function of
 * three sizes named M, N and K.
 *
 * The expansion defines three functions:
 *  - kernel_##_add(M, N, K): returns the additions_ expression
 *  - kernel_##_mul(M, N, K): returns the multiplications_ expression
 *  - template<class S> kernel_(M, N, K): returns
 *    flop_per_mul<S>() * kernel_##_mul(M, N, K)
 *    + flop_per_add<S>() * kernel_##_add(M, N, K)
 *
 * additions_ and multiplications_ are expressions that may refer to the
 * parameters M, N and K of the generated functions.
 *
 * The final line of the definition carries no line continuation on purpose:
 * a trailing backslash would splice whatever line follows this macro into
 * its replacement list.
 */
#define FABULOUS_KERNEL_FLOPS_DEFINE_MNK(kernel_, additions_, multiplications_) \
    inline constexpr double kernel_##_add(double M, double N, double K)         \
    {                                                                           \
        return (additions_);                                                    \
    }                                                                           \
    inline constexpr double kernel_##_mul(double M, double N, double K)         \
    {                                                                           \
        return (multiplications_);                                              \
    }                                                                           \
    template<class S>                                                           \
    inline constexpr double kernel_(double M, double N, double K)               \
    {                                                                           \
        return flop_per_mul<S>() * kernel_##_mul(M, N, K)                       \
            + flop_per_add<S>() * kernel_##_add(M, N, K);                       \
    }
/**
 * Generate the flop-count functions of a kernel whose cost is a function of
 * two sizes named M and N.
 *
 * The expansion defines three functions:
 *  - kernel_##_add(M, N): returns the additions_ expression
 *  - kernel_##_mul(M, N): returns the multiplications_ expression
 *  - template<class S> kernel_(M, N): returns
 *    flop_per_mul<S>() * kernel_##_mul(M, N)
 *    + flop_per_add<S>() * kernel_##_add(M, N)
 *
 * additions_ and multiplications_ are expressions that may refer to the
 * parameters M and N of the generated functions.
 *
 * The final line of the definition carries no line continuation on purpose:
 * a trailing backslash would splice whatever line follows this macro into
 * its replacement list.
 */
#define FABULOUS_KERNEL_FLOPS_DEFINE_MN(kernel_, additions_, multiplications_) \
    inline constexpr double kernel_##_add(double M, double N)                  \
    {                                                                          \
        return (additions_);                                                   \
    }                                                                          \
    inline constexpr double kernel_##_mul(double M, double N)                  \
    {                                                                          \
        return (multiplications_);                                             \
    }                                                                          \
    template<class S>                                                          \
    inline constexpr double kernel_(double M, double N)                        \
    {                                                                          \
        return flop_per_mul<S>() * kernel_##_mul(M, N)                         \
            + flop_per_add<S>() * kernel_##_add(M, N);                         \
    }
/**
 * Generate the flop-count functions of a kernel whose cost is a function of
 * a single size named M.
 *
 * The expansion defines three functions:
 *  - kernel_##_add(M): returns the additions_ expression
 *  - kernel_##_mul(M): returns the multiplications_ expression
 *  - template<class S> kernel_(M): returns
 *    flop_per_mul<S>() * kernel_##_mul(M)
 *    + flop_per_add<S>() * kernel_##_add(M)
 *
 * additions_ and multiplications_ are expressions that may refer to the
 * parameter M of the generated functions.
 *
 * The final line of the definition carries no line continuation on purpose:
 * a trailing backslash would splice whatever line follows this macro into
 * its replacement list.
 */
#define FABULOUS_KERNEL_FLOPS_DEFINE_M(kernel_, additions_, multiplications_) \
    inline constexpr double kernel_##_add(double M)                           \
    {                                                                         \
        return (additions_);                                                  \
    }                                                                         \
    inline constexpr double kernel_##_mul(double M)                           \
    {                                                                         \
        return (multiplications_);                                            \
    }                                                                         \
    template<class S>                                                         \
    inline constexpr double kernel_(double M)                                 \
    {                                                                         \
        return flop_per_mul<S>() * kernel_##_mul(M)                           \
            + flop_per_add<S>() * kernel_##_add(M);                           \
    }
/*
 * Flop-count definitions for the individual kernels.
 *
 * Every invocation passes, in this order: the kernel name, the expression
 * counting additions, the expression counting multiplications (this is the
 * parameter order of the FABULOUS_KERNEL_FLOPS_DEFINE_* macros).
 * Each one generates <name>_add(), <name>_mul() and the template <name><S>().
 *
 * NOTE(review): the closed-form expressions below look like the operation
 * counts of LAPACK Working Note 41 -- confirm against that reference before
 * changing any coefficient.
 */
// gemm: M*N*K additions and M*N*K multiplications.
FABULOUS_KERNEL_FLOPS_DEFINE_MNK(gemm, M*N*K, M*N*K)
// axpy: M additions and M multiplications.
FABULOUS_KERNEL_FLOPS_DEFINE_M(axpy, M, M)
// dot: M additions and M multiplications.
FABULOUS_KERNEL_FLOPS_DEFINE_M(dot, M, M)
// scal: no additions, M multiplications; the additions expression evaluates
// to 0.0 and its (void) M only references the otherwise unused parameter.
FABULOUS_KERNEL_FLOPS_DEFINE_M(scal, ((void) M, 0.0), M)
// geqrf: both counts use one formula when M > N and another one otherwise.
FABULOUS_KERNEL_FLOPS_DEFINE_MN(
geqrf,
((M>N)
? (N * (N * ( 0.5-(1./3.)*N + M) + 5./6.))
: (M * (M * (-0.5-(1./3.)*M + N) + N + 5. / 6.))),
((M>N)
? (N * (N * ( 0.5-(1./3.) * N + M) + M + 23. / 6.))
: (M * (M * ( -0.5-(1./3.) * M + N) + 2.*N + 23. / 6.)))
)
// orgqr: additions expression first, multiplications expression second.
FABULOUS_KERNEL_FLOPS_DEFINE_MNK(
orgqr,
(K * (2.* M*N + N - M + 1./3. + K * ( 2./3. * K - (M + N) ))),
(K * (2.* M*N + 2. * N - 5./3. + K * ( 2./3. * K - (M + N) - 1.)))
)
// left_ormqr: presumably ormqr with Q applied from the left -- TODO confirm.
FABULOUS_KERNEL_FLOPS_DEFINE_MNK(
left_ormqr,
(K * N * (2.*M - K + 1.)),
(K * N * (2.*M - K + 2.))
)
// right_ormqr: presumably ormqr with Q applied from the right -- TODO confirm.
FABULOUS_KERNEL_FLOPS_DEFINE_MNK(
right_ormqr,
(K * (2.*M*N - M*K + M)),
(K * (2.*M*N - M*K + M + N - 0.5*K + 0.5))
)
// left_trsm: N*M*(M-1)/2 additions and N*M*(M+1)/2 multiplications.
FABULOUS_KERNEL_FLOPS_DEFINE_MN(
left_trsm,
(N*M*(M - 1.0) * 0.5 ),
(N*M*(M + 1.0) * 0.5 )
)
// right_trsm: same expressions as left_trsm with the roles of M and N swapped.
FABULOUS_KERNEL_FLOPS_DEFINE_MN(
right_trsm,
(M*N*(N - 1.0) * 0.5 ),
(M*N*(N + 1.0) * 0.5 )
)
template<class S>
int64_t gels_flops(int64_t m, int64_t n, int64_t nrhs)
inline constexpr double gels(double M, double N, double NRHS)
{
FABULOUS_ASSERT( m >= n );
return geqrf_flops<S>(m, n) + ormqr_flops<S>(m, nrhs, n) + trsm_flops<S>(n, nrhs);
return geqrf<S>(M, N) + left_ormqr<S>(M, NRHS, N) + left_trsm<S>(N, NRHS);
}
} // end namespace flops
} // end namespace lapacke
} // end namespace fabulous
......
......@@ -52,6 +52,7 @@ private:
if (nb_vect == 0) {
return;
}
namespace fps = lapacke::flops;
_nb_flops += A.DotProduct( // Theta = V^{T} * W
nb_vect, W_size,
......@@ -68,7 +69,7 @@ private:
W.get_ptr(), W.get_leading_dim(),
S{-1.0}, S{1.0}
);
_nb_flops += lapacke::gemm_flops<S>(dim, W_size, nb_vect);
_nb_flops += fps::gemm<S>(dim, W_size, nb_vect);
Block<S> Zj = Z.get_Vj();
FABULOUS_ASSERT( Z.get_ptr() != Zj.get_ptr() );
......@@ -79,7 +80,7 @@ private:
Zj.get_ptr(), Z.get_leading_dim(),
S{-1.0}, S{1.0}
);
_nb_flops += lapacke::gemm_flops<S>(dim, W_size, nb_vect);
_nb_flops += fps::gemm<S>(dim, W_size, nb_vect);
}
/**
......@@ -101,6 +102,8 @@ private:
int W_size = W.get_nb_col();
Block<S> buf{_nbrhs_alloc, _nbrhs_alloc};
namespace fps = lapacke::flops;
// Loop over different blocks
for (int k = 0; k < nb_block; ++k) {
int size_block = V.get_block_size(k);
......@@ -121,7 +124,7 @@ private:
W.get_ptr(), W.get_leading_dim(),
S{-1.0}, S{1.0}
);
_nb_flops += lapacke::gemm_flops<S>(dim, W_size, size_block);
_nb_flops += fps::gemm<S>(dim, W_size, size_block);
Block<S> Zj = Z.get_Vj();
FABULOUS_ASSERT( Z.get_block_ptr(k) != Zj.get_ptr() );
......@@ -132,7 +135,7 @@ private:
Zj.get_ptr(), Z.get_leading_dim(),
S{-1.0}, S{1.0}
);
_nb_flops += lapacke::gemm_flops<S>(dim, W_size, size_block);
_nb_flops += fps::gemm<S>(dim, W_size, size_block);
}
}
......
......@@ -41,9 +41,11 @@ private:
int nb_vect = base.get_nb_vect();
int W_size = W.get_nb_col();
assert( H.get_nb_col() == W.get_nb_col() );
assert( H.get_nb_row() == base.get_nb_vect() );
assert( W.get_nb_row() == base.get_nb_row() );
namespace fps = lapacke::flops;
FABULOUS_ASSERT( H.get_nb_col() == W.get_nb_col() );
FABULOUS_ASSERT( H.get_nb_row() == base.get_nb_vect() );
FABULOUS_ASSERT( W.get_nb_row() == base.get_nb_row() );
lapacke::Tgemm( // H = Vm^{t}*W
nb_vect, W_size, dim,
......@@ -52,7 +54,7 @@ private:
H.get_ptr(), H.get_leading_dim(),
S{1.0}, S{0.0}
);
_nb_flops += lapacke::Tgemm_flops<S>(nb_vect, W_size, dim);
_nb_flops += fps::gemm<S>(nb_vect, W_size, dim);
lapacke::gemm( // W = W - Vm*H
dim, W_size, nb_vect,
......@@ -61,7 +63,7 @@ private:
W.get_ptr(), W.get_leading_dim(),
S{-1.0}, S{1.0}
);
_nb_flops += lapacke::gemm_flops<S>(dim, W_size, nb_vect);
_nb_flops += fps::gemm<S>(dim, W_size, nb_vect);
}
/**
......@@ -78,8 +80,9 @@ private:
int nb_vect = base.get_nb_vect();
int W_size = W.get_nb_col();
assert( H.get_nb_col() == W.get_nb_col() );
assert( H.get_nb_row() == base.get_nb_vect() );
namespace fps = lapacke::flops;
FABULOUS_ASSERT( H.get_nb_col() == W.get_nb_col() );
FABULOUS_ASSERT( H.get_nb_row() == base.get_nb_vect() );
// Block tmp to be added to Hess
Block<S> tmp{nb_vect, W_size};
......@@ -91,7 +94,7 @@ private:
tmp.get_ptr(), tmp.get_leading_dim(),
S{1.0}, S{0.0}
);
_nb_flops += lapacke::Tgemm_flops<S>(nb_vect, W_size, dim);
_nb_flops += fps::gemm<S>(nb_vect, W_size, dim);
lapacke::gemm( // W = W - Vm*H
dim, W_size, nb_vect,
......@@ -100,12 +103,12 @@ private:
W.get_ptr(), W.get_leading_dim(),
S{-1.0}, S{1.0}
);
_nb_flops += lapacke::gemm_flops<S>(dim, W_size, nb_vect);
_nb_flops += fps::gemm<S>(dim, W_size, nb_vect);
// Add tmp to hess
for (int j = 0; j < tmp.get_nb_col(); ++j) { // H = H + tmp;
lapacke::axpy(nb_vect, S{1.0}, tmp.get_vect(j), 1, H.get_vect(j), 1);
_nb_flops += lapacke::axpy_flops<S>(nb_vect);
_nb_flops += fps::axpy<S>(nb_vect);
}
}
......@@ -123,8 +126,9 @@ private:
int W_size = W.get_nb_col();
int nb_block = base.get_nb_block();
assert( H.get_nb_col() == W.get_nb_col() );
assert( H.get_nb_row() == base.get_nb_vect() );
namespace fps = lapacke::flops;
FABULOUS_ASSERT( H.get_nb_col() == W.get_nb_col() );
FABULOUS_ASSERT( H.get_nb_row() == base.get_nb_vect() );
// Loop over different blocks
int size_block_sum = 0;
......@@ -140,7 +144,7 @@ private:
H_k.get_ptr(), H_k.get_leading_dim(),
S{1.0}, S{0.0}
);
_nb_flops += lapacke::Tgemm_flops<S>(size_block, W_size, dim);
_nb_flops += fps::gemm<S>(size_block, W_size, dim);
lapacke::gemm( // W = W - V_i * H_{ij}
dim, W_size, size_block,
......@@ -149,7 +153,7 @@ private:
W.get_ptr(), W.get_leading_dim(),
S{-1.0}, S{1.0}
);
_nb_flops += lapacke::gemm_flops<S>(dim, W_size, size_block);
_nb_flops += fps::gemm<S>(dim, W_size, size_block);
}
}
......@@ -170,8 +174,9 @@ private:
int W_size = W.get_nb_col();
int nb_block = base.get_nb_block();
assert( H.get_nb_col() == W.get_nb_col() );
assert( H.get_nb_row() == base.get_nb_vect() );
namespace fps = lapacke::flops;
FABULOUS_ASSERT( H.get_nb_col() == W.get_nb_col() );
FABULOUS_ASSERT( H.get_nb_row() == base.get_nb_vect() );
//FABULOUS_DEBUG("base.nb_vect="<<base.get_nb_vect());
// Loop over different blocks
......@@ -188,7 +193,7 @@ private:
tmp.get_ptr(), tmp.get_leading_dim(),
S{1.0}, S{0.0}
);
_nb_flops += lapacke::Tgemm_flops<S>(size_block, W_size, dim);
_nb_flops += fps::gemm<S>(size_block, W_size, dim);
lapacke::gemm( // W = W - V_i * tmp
dim, W_size, size_block,
......@@ -197,7 +202,7 @@ private:
W.get_ptr(), W.get_leading_dim(),
S{-1.0}, S{1.0}
);
_nb_flops += lapacke::gemm_flops<S>(dim, W_size, size_block);
_nb_flops += fps::gemm<S>(dim, W_size, size_block);
Block<S> H_k = H.sub_block(size_block_sum, 0, size_block, W_size);
size_block_sum += size_block;
......@@ -206,7 +211,7 @@ private:
for (int j = 0; j < tmp.get_nb_col(); ++j) { // H_k = H_k + tmp
lapacke::axpy(
size_block, S{1.0}, tmp.get_vect(j), 1, H_k.get_vect(j), 1);
_nb_flops += lapacke::axpy_flops<S>(size_block);
_nb_flops += fps::axpy<S>(size_block);
}
}
}
......
......@@ -56,6 +56,8 @@ private:
int ldv = base.get_leading_dim();
S *V = base.get_ptr();
namespace fps = lapacke::flops;
// Loop over vector in W block
for (int k = 0; k < W_size; ++k) {
S *W_k = W.get_vect(k);
......@@ -72,7 +74,7 @@ private:
W_k, ldw,