Commit 7dbb1f1f authored by MIJIEUX Thomas's avatar MIJIEUX Thomas

Factorize, rename, and comment

parent b90b54ac
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "../src/Arnoldi_Ortho.hpp" #include "../src/Arnoldi_Ortho.hpp"
#include "../src/Arnoldi_IB.hpp" #include "../src/Arnoldi_IB.hpp"
#include "../src/BGMRes.hpp" #include "../src/BGMRes.hpp"
#include "../src/BGMResDR.hpp"
#include "fabulous.h" #include "fabulous.h"
#include "MatrixApi.hpp" #include "MatrixApi.hpp"
...@@ -15,169 +16,165 @@ ...@@ -15,169 +16,165 @@
namespace fabulous { namespace fabulous {
/** /**
* @brief This class is used to dereference the handle given by the * \brief This interface is used to operate on the handle given by the user
* user without knowing in which arithmetic we are. * without knowing which arithmetic is used.
*/ */
struct ApiEngineI struct ApiEngineI
{ {
virtual ~ApiEngineI(){} virtual ~ApiEngineI(){}
virtual void set_rightprecond( virtual void set_mvp(fabulous_mvp_t user_mvp) = 0;
fabulous_rightprecond_t user_rpc, void *user_rpc_data) = 0; virtual void set_rightprecond(fabulous_rightprecond_t user_rpc) = 0;
virtual void set_usergemm(fabulous_mvp_t user_mvp) = 0; virtual void set_dot_product(fabulous_dot_t user_dot) = 0;
virtual void set_parameters( virtual void set_parameters(
int max_mvp, int max_krylov_space_size, void *tolerance, int nb_tol) = 0; int max_mvp, int max_krylov_space_size, void *tolerance, int nb_tol) = 0;
virtual void set_dotProduct(fabulous_dot_t user_dot) = 0;
virtual void set_ortho_process(fabulous_orthoproc orthoproc) = 0; virtual void set_ortho_process(fabulous_orthoproc orthoproc) = 0;
virtual int solve(int dim, void *RHS, void *X0) = 0;
virtual int solve_IB(int dim, void *RHS, void *X0) = 0; virtual int solve(int nrhs, void *B, int ldb, void *X, int ldx) = 0;
virtual int solve_QR(int dim, void *RHS, void *X0) = 0; virtual int solve_QR(int nrhs, void *B, int ldb, void *X, int ldx) = 0;
virtual void *get_results() = 0; virtual int solve_IB(int nrhs, void *B, int ldb, void *X, int ldx) = 0;
virtual int solve_DR(int nrhs, void *B, int ldb, void *X, int ldx,
int nb_eigen_pair, void *target) = 0;
virtual const double *get_logs(int *size) = 0; virtual const double *get_logs(int *size) = 0;
}; };
/** /**
* @brief Main Engine of the Library. Each method of the API except for * \brief Main Engine of the Library. Each method of the API except for
* the init function is a member function of both this class and its * the init function is a member function of both this class and its
* parent's one. * parent's one.
*/ */
template< class U > class ApiEngine : public ApiEngineI template< class S > class ApiEngine : public ApiEngineI
{ {
public: public:
typedef typename Arithmetik<U>::value_type value_type; using value_type = typename Arithmetik<S>::value_type;
typedef typename Arithmetik<U>::primary_type primary_type; using primary_type = typename Arithmetik<S>::primary_type;
using P = primary_type;
private: private:
MatrixApi<U> matrix; int _dim;
MatrixApi<S> _matrix;
// Solve parameters int _maxMVP;
int max_mvp; std::vector<P> _tolerance;
std::vector<primary_type> tolerance; int _max_krylov_space_size;
int max_krylov_space_size; OrthoScheme _ortho_scheme;
OrthoScheme ortho;
Block<U> sol; // Solution storage
Logger<primary_type> log;
public: public:
ApiEngine(void *user_mvp_data, int dim, void *user_env): ApiEngine(int dim, void *user_env):
matrix{user_env, user_mvp_data, dim}, _dim{dim},
max_mvp(0), _matrix{dim, user_env},
tolerance(0), _maxMVP{0},
max_krylov_space_size(0), _tolerance{},
ortho(OrthoScheme::MGS) _max_krylov_space_size{0},
_ortho_scheme(OrthoScheme::MGS)
{ {
} }
/** /**
* @brief This function set the user Dot product * \brief This function set the user Dot product
*/ */
void set_dotProduct(fabulous_dot_t user_dot) override void set_dot_product(fabulous_dot_t user_dot) override
{ {
matrix.setDotProduct(user_dot); _matrix.set_dot_product(user_dot);
} }
/** /**
* @brief This function allow to set a right pre conditionner * \brief This function allow to set a right pre conditionner
* *
*/ */
void set_rightprecond(fabulous_rightprecond_t user_rpc, void *user_rpc_data) override void set_rightprecond(fabulous_rightprecond_t user_rpc) override
{ {
matrix.setRightPreCond(user_rpc, user_rpc_data); _matrix.set_rpc(user_rpc);
} }
void set_usergemm(fabulous_mvp_t user_mvp) override void set_mvp(fabulous_mvp_t user_mvp) override
{ {
matrix.setMatBlockVect(user_mvp); _matrix.set_mvp(user_mvp);
} }
void set_parameters(int inMaxIte, int inMax_Krylov_Space_Size, void *in_tolerance, int nb_tol) override void set_parameters(int maxMVP, int max_krylov_space_size,
void *tolerance, int nb_tolerance) override
{ {
primary_type *p_tol = reinterpret_cast<primary_type*>(in_tolerance); P *pTolerance = reinterpret_cast<P*>(tolerance);
tolerance.clear(); _tolerance.clear();
tolerance.reserve(nb_tol); _tolerance.reserve(nb_tolerance);
tolerance.assign(p_tol, p_tol+nb_tol); _tolerance.assign(pTolerance, pTolerance+nb_tolerance);
max_mvp = inMaxIte; _maxMVP = maxMVP;
max_krylov_space_size = inMax_Krylov_Space_Size; _max_krylov_space_size = max_krylov_space_size;
} }
void set_ortho_process(fabulous_orthoproc orthoproc) override void set_ortho_process(fabulous_orthoproc orthoproc) override
{ {
switch (orthoproc) { switch (orthoproc) {
case FABULOUS_MGS: ortho = OrthoScheme::MGS; break; case FABULOUS_MGS: _ortho_scheme = OrthoScheme::MGS; break;
case FABULOUS_CGS: ortho = OrthoScheme::CGS; break; case FABULOUS_CGS: _ortho_scheme = OrthoScheme::CGS; break;
case FABULOUS_IMGS: ortho = OrthoScheme::IMGS; break; case FABULOUS_IMGS: _ortho_scheme = OrthoScheme::IMGS; break;
case FABULOUS_ICGS: ortho = OrthoScheme::ICGS; break; case FABULOUS_ICGS: _ortho_scheme = OrthoScheme::ICGS; break;
default: default:
::fabulous::warning("Value for Ortho is not part of {0,1,2,3}\n" ::fabulous::warning("Value for Ortho is not part of {0,1,2,3}\n"
"Ortho process used will be MGS\n"); "Ortho process used will be MGS\n");
ortho = OrthoScheme::MGS; _ortho_scheme = OrthoScheme::MGS;
break; break;
} }
} }
template< class BLOCK > template< template<class> class ARNOLDI >
void ConvertStructures(void *RHS, void *X0, BLOCK &B, BLOCK &X_init) int call_solve(int nrhs, S *B, int ldb, S *X, int ldx)
{ {
X_init.InitBlock(reinterpret_cast<U*>(X0)); BGMRes<S> bgmres;
B.InitBlock(reinterpret_cast<U*>(RHS)); return bgmres.template solve<ARNOLDI>(
_matrix, _dim, nrhs, B, ldb, X, ldx,
_maxMVP, _max_krylov_space_size, _tolerance,
_ortho_scheme, OrthoType::RUHE
);
} }
template<class ARNOLDI, class BLOCK, int solve(int nrhs, void *B, int ldb, void *X, int ldx) override
class P = typename BLOCK::primary_type >
int call_solve(BLOCK &B, BLOCK &X0, BLOCK &sol, Logger<P> &log)
{ {
return BGMRes<ARNOLDI>(matrix, B, X0, S *B_ = reinterpret_cast<S*>(B);
max_mvp, max_krylov_space_size, S *X_ = reinterpret_cast<S*>(X);
sol, log, tolerance, return call_solve<Arnoldi>(nrhs, B_, ldb, X_, ldx);
ortho, OrthoType::RUHE );
} }
// Solve method : RHS and X0 have the same size int solve_QR(int nrhs, void *B, int ldb, void *X, int ldx) override
int solve(int nbRHS, void *RHS, void *X0) override
{ {
// Convert raw data to Block S *B_ = reinterpret_cast<S*>(B);
Block<U> X_init{nbRHS, matrix.size()}; S *X_ = reinterpret_cast<S*>(X);
Block<U> B{nbRHS, matrix.size()}; return call_solve<Arnoldi_QRInc>(nrhs, B_, ldb, X_, ldx);
ConvertStructures(RHS, X0, B, X_init);
// Init an empty block to store solution
sol.initData(nbRHS, matrix.size());
return call_solve<Arnoldi>(B, X_init, sol, log);
} }
// Solve method : RHS and X0 have the same size int solve_IB(int nrhs, void *B, int ldb, void *X, int ldx) override
int solve_QR(int nbRHS, void *RHS, void *X0) override
{ {
// Convert raw data to Block S *B_ = reinterpret_cast<S*>(B);
Block<U> X_init{nbRHS, matrix.size()}; S *X_ = reinterpret_cast<S*>(X);
Block<U> B{nbRHS, matrix.size()}; return call_solve<Arnoldi_IB>(nrhs, B_, ldb, X_, ldx);
ConvertStructures(RHS, X0, B, X_init);
// Init an empty block to store solution
sol.initData(nbRHS, matrix.size());
return call_solve<Arnoldi_QRInc>(B, X_init, sol, log);
}
int solve_IB(int nbRHS, void *RHS, void *X0) override
{
Block<U> X_init{nbRHS, matrix.size()}; // Convert raw data to Block
Block<U> B{nbRHS, matrix.size()};
ConvertStructures(RHS, X0, B, X_init);
// Init an empty block to store solution
sol.initData(nbRHS, matrix.size());
return call_solve<Arnoldi_IB>(B, X_init, sol, log);
} }
void *get_results() override int solve_DR(int nrhs, void *B, int ldb, void *X, int ldx,
int nb_eigen_pair, void *target) override
{ {
return sol.getPtr(); S *B_ = reinterpret_cast<S*>(B);
S *X_ = reinterpret_cast<S*>(X);
P Target;
if (target != nullptr)
Target = *reinterpret_cast<P*>(target);
else
Target = P{0.0};
BGMResDR<S> bgmres;
return bgmres.template solve<Arnoldi>(
_matrix, _dim, nrhs, B_, ldb, X_, ldx,
_maxMVP, _max_krylov_space_size, _tolerance,
nb_eigen_pair, Target,
_ortho_scheme, OrthoType::RUHE
);
} }
const double *get_logs(int *size) override const double *get_logs(int *size) override
{ {
*size = log.getNbIteLogged(); Logger<P> _logger;
return log.writeDownArray(); *size = _logger.get_nb_iterations();
return _logger.write_down_array();
} }
}; };
......
...@@ -4,125 +4,112 @@ ...@@ -4,125 +4,112 @@
#include <cassert> #include <cassert>
#include "fabulous.h" #include "fabulous.h"
#include "Block.hpp" #include "../../src/Block.hpp"
#include "../../src/Algorithm.hpp"
namespace fabulous { namespace fabulous {
/** /**
* @brief This class is a object wrapper over the callback from user. * \brief Object Oriented wrapper over the user's callbacks.
* The MatrixVectorProduct will be set here
* *
* Note: the matrix must be square * The MatBlockVect, DotProduct, and PrecondBlockVect will be set here
*
* \note the matrix must be square
*/ */
template< class U > class MatrixApi template< class S > class MatrixApi
{ {
public: public:
typedef typename Arithmetik<U>::value_type value_type; typedef typename Arithmetik<S>::value_type value_type;
typedef typename Arithmetik<U>::primary_type primary_type; typedef typename Arithmetik<S>::primary_type primary_type;
int _dim;
void *_user_env;
fabulous_mvp_t _user_mvp; fabulous_mvp_t _user_mvp;
fabulous_rightprecond_t _user_rpc; fabulous_rightprecond_t _user_rpc;
fabulous_dot_t _user_dot; fabulous_dot_t _user_dot;
int _dim;
void *_user_mvp_data;
void *_user_rpc_data;
void *_user_env;
bool _use_rightprecond;
public: public:
MatrixApi(): MatrixApi(int dim, void *user_env):
_dim(dim),
_user_env(user_env),
_user_mvp(nullptr), _user_mvp(nullptr),
_user_rpc(nullptr), _user_rpc(nullptr),
_user_dot(nullptr), _user_dot(nullptr)
_dim(0),
_user_mvp_data(nullptr),
_user_rpc_data(nullptr),
_user_env(nullptr),
_use_rightprecond(false)
{ {
} }
int size() const { return _dim; } // Return dimension of matrix (square matrix) int size() const { return _dim; } // Return dimension of matrix (square matrix)
MatrixApi(void *user_env, void *user_mvp_data, int dim): S at(int, int) { FABULOUS_FATAL_ERROR("should not be reached"); return S{0.0}; }
_user_mvp(nullptr),
_user_rpc(nullptr),
_user_dot(nullptr),
_dim(dim),
_user_mvp_data(user_mvp_data),
_user_rpc_data(nullptr),
_user_env(user_env),
_use_rightprecond(false)
{
}
U at(int, int) { FABULOUS_FATAL_ERROR("should not be reached"); return U{0.0}; }
void MatBlockVect(const Block<S> &input, Block<S> &output,
void MatBlockVect(Block<U> &input, Block<U> &output, int idxToWrite=0) S alpha = S{1.0}, S beta = S{0.0}) const
{ {
void* toWrite = output.getPtr(idxToWrite); if ( _user_mvp == nullptr ) {
std::cout<<"MatBlockVect : Input is "<<input.getSizeBlock()<<" x " FABULOUS_NOTE("User matrix vector product is not set!");
<<input.getLeadingDim()<<"\n"; FABULOUS_NOTE("Have you called fabulous_set_mvp() ?!");
std::cout<<" : Output is "<<output.getSizeBlock()<<" x " FABULOUS_FATAL_ERROR("missing user matrix product; cannot recover");
<<output.getLeadingDim()<<"\n";
assert( _user_mvp != nullptr );
_user_mvp(_user_mvp_data,
input.getSizeBlock(), input.getPtr(), &toWrite,
_user_env );
} }
_user_mvp(
void MatBaseProduct(U *ptrToRead, int nbRHS, void *ptrToWrite) _user_env, input.get_nb_col(),
{ input.get_ptr(), input.get_leading_dim(),
assert( _user_mvp != nullptr ); output.get_ptr(), output.get_leading_dim(),
_user_mvp(_user_mvp_data, nbRHS, ptrToRead, &ptrToWrite, _user_env); &alpha, &beta
);
} }
void setMatBlockVect(fabulous_mvp_t user_mvp) void set_mvp(fabulous_mvp_t user_mvp)
{ {
_user_mvp = user_mvp; _user_mvp = user_mvp;
} }
void setRightPreCond(fabulous_rightprecond_t user_rpc, void *user_rpc_data) void set_rpc(fabulous_rightprecond_t user_rpc)
{ {
_user_rpc = user_rpc; _user_rpc = user_rpc;
_user_rpc_data = user_rpc_data;
_use_rightprecond = true;
} }
void setDotProduct(fabulous_dot_t user_dot) void set_dot_product(fabulous_dot_t user_dot)
{ {
_user_dot = user_dot; _user_dot = user_dot;
} }
bool useRightPreCond() const { return _use_rightprecond; } bool useRightPrecond() const { return _user_rpc != nullptr; }
template< class BLOCK > template< class BLOCK >
void preCondBlockVect(BLOCK &input, BLOCK &output) void PrecondBlockVect(const BLOCK &input, BLOCK &output) const
{ {
void *toWrite = output.getPtr(); assert ( _user_rpc != nullptr );
assert( _user_rpc != nullptr ); _user_rpc(
_user_rpc(_user_rpc_data, _user_env, input.get_nb_col(),
input.getSizeBlock(), input.getPtr(), &toWrite, input.get_ptr(), input.get_leading_dim(),
_user_env ); output.get_ptr(), output.get_leading_dim()
);
} }
template< class BLOCK > void DotProduct(int M, int N,
void preCondBaseProduct(U *ptrToRead, BLOCK &output) const S *A, int lda,
const S *B, int ldb,
S *C, int ldc) const
{
if ( _user_dot == nullptr) {
FABULOUS_NOTE("User dot product is not set!");
FABULOUS_NOTE("Have you called fabulous_set_dot_product() ?!");
FABULOUS_FATAL_ERROR("missing user dot product; cannot recover");
}
_user_dot( _user_env, M, N, A, lda, B, ldb, C, ldc );
}
void DotProduct(const Block<S> &A, const Block<S> &B, Block<S> &C) const
{ {
void *toWrite = output.getPtr(); DotProduct( A.get_nb_col(), B.get_nb_col(),
assert( _user_rpc != nullptr ); A.get_ptr(), A.get_leading_dim(),
_user_rpc(_user_rpc_data, B.get_ptr(), B.get_leading_dim(),
output.getSizeBlock(), ptrToRead, &toWrite, C.get_tpr(), C.get_leading_dim() );
_user_env );
} }
void DotProduct(int size, int nbVect, U *vectA, U *vectB, U *res) void QRFacto(Block<S> &Q, Block<S> &R) const
{ {
assert( _user_dot != nullptr ); Algorithm::InPlaceQRFactoMGS_User(Q, R, *this);
_user_dot(size, nbVect, vectA, vectB, res, _user_env);
} }
}; };
......
...@@ -10,26 +10,22 @@ FABULOUS_BEGIN_C_DECL ...@@ -10,26 +10,22 @@ FABULOUS_BEGIN_C_DECL
#define FABULOUS_HANDLE(engine_) \ #define FABULOUS_HANDLE(engine_) \
reinterpret_cast<fabulous_handle>(engine_) reinterpret_cast<fabulous_handle>(engine_)
/** fabulous_handle fabulous_create(
* Implement Init function by creating an instance of the Engine fabulous_arithmetic ari, int dim, void *user_env)
* class. The ptr to the instance is the Handle.
*/
fabulous_handle fabulous_init(
fabulous_arithmetic ari, void *mvp_data, int dim, void *userEnv)
{ {
ApiEngineI *engine = nullptr; ApiEngineI *engine = nullptr;
switch (ari) { switch (ari) {
case FABULOUS_FLOAT: case FABULOUS_REAL_FLOAT:
engine = new ApiEngine<float>(mvp_data, dim, userEnv); engine = new ApiEngine<float>(dim, user_env);
break; break;
case FABULOUS_DOUBLE: case FABULOUS_REAL_DOUBLE:
engine = new ApiEngine<double>(mvp_data, dim, userEnv); engine = new ApiEngine<double>(dim, user_env);
break; break;
case FABULOUS_COMPLEX_FLOAT: case FABULOUS_COMPLEX_FLOAT:
engine = new ApiEngine<std::complex<float>>(mvp_data, dim, userEnv); engine = new ApiEngine<std::complex<float>>(dim, user_env);
break; break;
case FABULOUS_COMPLEX_DOUBLE: case FABULOUS_COMPLEX_DOUBLE:
engine = new ApiEngine<std::complex<double>>(mvp_data, dim, userEnv); engine = new ApiEngine<std::complex<double>>(dim, user_env);
break; break;
default: default:
FABULOUS_FATAL_ERROR( FABULOUS_FATAL_ERROR(
...@@ -41,24 +37,23 @@ fabulous_handle fabulous_init( ...@@ -41,24 +37,23 @@ fabulous_handle fabulous_init(
return FABULOUS_HANDLE(engine); return FABULOUS_HANDLE(engine);
} }
void fabulous_set_usergemm(fabulous_mvp_t user_mvp, fabulous_handle handle) void fabulous_set_mvp(fabulous_mvp_t user_mvp, fabulous_handle handle)
{ {
ApiEngineI *disp = FABULOUS_API_ENGINE(handle); ApiEngineI *disp = FABULOUS_API_ENGINE(handle);
disp->set_usergemm(user_mvp); disp->set_mvp(user_mvp);
} }
void fabulous_set_rightprecond(fabulous_rightprecond_t user_rpc, void fabulous_set_rightprecond(fabulous_rightprecond_t user_rpc,
void *user_rpc_data,
fabulous_handle handle) fabulous_handle handle)
{ {
ApiEngineI *disp = FABULOUS_API_ENGINE(handle); ApiEngineI *disp = FABULOUS_API_ENGINE(handle);
disp->set_rightprecond(user_rpc, user_rpc_data); disp->set_rightprecond(user_rpc);
} }
void fabulous_set_dot_product(fabulous_dot_t user_dot, fabulous_handle handle) void fabulous_set_dot_product(fabulous_dot_t user_dot, fabulous_handle handle)
{ {
ApiEngineI *disp = FABULOUS_API_ENGINE(handle); ApiEngineI *disp = FABULOUS_API_ENGINE(handle);
disp->set_dotProduct(user_dot); disp->set_dot_product(user_dot);
} }
void fabulous_set_parameters(int max_mvp, int max_space_size, void fabulous_set_parameters(int max_mvp, int max_space_size,
...@@ -76,40 +71,44 @@ void fabulous_set_ortho_process(fabulous_orthoproc orthoproc, ...@@ -76,40 +71,44 @@ void fabulous_set_ortho_process(fabulous_orthoproc orthoproc,
disp->set_ortho_process(orthoproc); disp->set_ortho_process(orthoproc);
} }
int fabulous_solve(int nbRHS, void *RHS, void *X0, fabulous_handle handle) int fabulous_solve(int nrhs, void *B, int ldb, void *X, int ldx, fabulous_handle handle)
{ {
ApiEngineI *disp = FABULOUS_API_ENGINE(handle); ApiEngineI *disp = FABULOUS_API_ENGINE(handle);
return disp->solve(nbRHS, RHS, X0); return disp->solve(nrhs, B, ldb, X, ldx);
} }
int fabulous_solve_QR(int nbRHS, void *RHS, void *X0, fabulous_handle handle) int fabulous_solve_QR(int nrhs, void *B, int ldb, void *X, int ldx,
fabulous_handle handle)
{ {
ApiEngineI *disp = FABULOUS_API_ENGINE(handle); ApiEngineI *disp = FABULOUS_API_ENGINE(handle);
return disp->solve_QR(nbRHS, RHS, X0); return disp->solve_QR(nrhs, B, ldb, X, ldx);
} }
int fabulous_solve_IB(int nbRHS, void *RHS, void *X0, fabulous_handle handle) int fabulous_solve_IB(int nrhs, void *B, int ldb, void *X, int ldx,
fabulous_handle handle)
{ {
ApiEngineI *disp = FABULOUS_API_ENGINE(handle); ApiEngineI *disp = FABULOUS_API_ENGINE(handle);
return disp->solve_IB(nbRHS, RHS, X0); return disp->solve_IB(nrhs, B, ldb, X, ldx);
} }
void fabulous_dealloc(fabulous_handle handle) int fabulous_solve_DR(int nrhs, void *B, int ldb, void *X, int ldx,
int nb_eigen_pair, void *target,
fabulous_handle handle)
{ {
ApiEngineI *disp = FABULOUS_API_ENGINE(handle); ApiEngineI *disp = FABULOUS_API_ENGINE(handle);
delete disp; return disp->solve_DR(nrhs, B, ldb, X, ldx, nb_eigen_pair, target);
} }
void *fabulous_get_results(fabulous_handle handle) const double *fabulous_get_logs(int *size, fabulous_handle handle)
{ {
ApiEngineI *disp = FABULOUS_API_ENGINE(handle); ApiEngineI *disp = FABULOUS_API_ENGINE(handle);
return disp->get_results(); return disp->get_logs(size);
} }
const double *fabulous_get_logs(int *size,