Commit acd2e9ee authored by MIJIEUX Thomas's avatar MIJIEUX Thomas

Fix 2 minor little DR bugs

parent 6745ec31
......@@ -90,6 +90,14 @@ private:
o << restart_param;
}
template <class Matrix >
void check_params(Matrix &A, const Block<S> &X, const Block<S> B)
{
FABULOUS_ASSERT( A.size() == X.get_nb_row() );
FABULOUS_ASSERT( A.size() == B.get_nb_row() );
FABULOUS_ASSERT( X.get_nb_col() == B.get_nb_col() );
}
public:
BGMRes(bool log_real_residual = false):
_logger{log_real_residual},
......@@ -155,28 +163,32 @@ public:
*/
template< class Algo, class RestartParam, class Matrix >
int solve( Matrix &A, Block<S> &X, Block<S> &B,
Algo &, const int maxMVP, const int max_krylov_space_size,
Algo&, const int max_mvp, const int max_krylov_space_size,
const std::vector<P> &epsilon,
Orthogonalizer ortho, RestartParam restart_param )
{
check_params(A, X, B);
const int nbRHS = B.get_nb_col();
const int dim = B.get_nb_row();
reset();
print_start_info( dim, nbRHS, maxMVP, max_krylov_space_size,
print_start_info( dim, nbRHS, max_mvp, max_krylov_space_size,
ortho, restart_param, epsilon, A );
std::vector<P> normB, inv_normB;
normB = B.compute_norms(A);
inv_normB = array_inverse(normB);
X.scale(inv_normB); B.scale(inv_normB);
const int nb_eigen_pair = restart_param.get_k();
Base<S> base{dim, max_krylov_space_size+nbRHS};
Restarter<RestartParam, S> restarter{restart_param, base};
bool convergence = false;
while (!convergence && _mvp < maxMVP) {
while (!convergence && _mvp < max_mvp) {
//Compute nb of mat vect product to give to Arnoldi procedure
int size_to_span = std::min(max_krylov_space_size, maxMVP-_mvp);
int size_to_span = std::min(max_krylov_space_size, max_mvp-_mvp);
size_to_span = std::max(size_to_span, nbRHS);
print_iteration_start_info(size_to_span);
if (nbRHS + nb_eigen_pair > size_to_span)
break;
using Arnoldi = typename Algo::template t3mpl4te<S>;
Arnoldi arnoldi{_logger, restarter, dim, nbRHS, size_to_span};
convergence = arnoldi.run( A, X, B, size_to_span, epsilon,
......
......@@ -23,7 +23,8 @@ namespace fabulous {
*
* This class support DeflatedRestarting
*/
template<template<class> class HESSENBERG, class S > class ArnoldiDR
template<template<class> class HESSENBERG, class S >
class ArnoldiDR
{
static_assert(
arnoldiXhessenberg<fabulous::ArnoldiDR, HESSENBERG>::value,
......
......@@ -20,7 +20,8 @@ namespace fabulous {
*
* \warning This class does NOT support DeflatedRestarting (not implemented)
*/
template < template<class> class HESSENBERG, class S > class ArnoldiIB
template < template<class> class HESSENBERG, class S >
class ArnoldiIB
{
static_assert(
arnoldiXhessenberg<fabulous::ArnoldiIB, HESSENBERG>::value,
......
......@@ -131,15 +131,14 @@ public:
{
int M = _nb_vect + _nbRHS;
int N = _nb_vect;
Block<S> Y = alloc_least_square_sol();
assert( _YY.get_nb_row() == N );
assert( Y.get_nb_row() == N );
assert( LS.get_nb_row() == M );
LS.copy(_R1);
assert(LS.get_nb_col() == _R1.get_nb_col());
assert(LS.get_nb_row() >= _R1.get_nb_row());
//FABULOUS_DEBUG("R1.nb_row()="<<_R1.get_nb_row());
Block<S> Y = alloc_least_square_sol();
solve_least_square(Y);
LapackKernI::gemm(
......
......@@ -337,7 +337,6 @@ private:
case OrthoScheme::ICGS: ICGS(A, base, H, C, D, P, W); break;
default: FABULOUS_FATAL_ERROR("Invalid orthogonalization scheme\n"); break;
}
std::cout<<"Arnoldi RUHE done\n";
}
};
......
......@@ -31,6 +31,7 @@ namespace fabulous {
struct ClassicRestart
{
ClassicRestart() = default;
int get_k() { return 0; }
};
/**
......
......@@ -82,6 +82,14 @@ inline const char *basename(const char *str)
#define FABULOUS_NOTE(errstr_) \
FABULOUS_CONCAT(errstr_, ::fabulous::note)
#define FABULOUS_ASSERT( cond_ ) \
do{ \
if (!(cond_)) { \
FABULOUS_FATAL_ERROR("ASSERT( "#cond_" ) FAILED!"); \
} \
}while(0) \
};
#endif // FABULOUS_ERROR_HPP
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment