Commit fd4a7453 authored by MIJIEUX Thomas's avatar MIJIEUX Thomas

add namespaces and better syntactic sugar for templates

parent 0363f035
......@@ -4,34 +4,27 @@ message(STATUS "Building the API")
#Define what's needed to Build the API
set(IBGMRES_API_HEADERS
src/Ib-GMRes-Dr.h
set(FABULOUS_HEADERS
src/fabulous.h
)
set(IBGMRES_API_SRC
src/IbGMResDrAPI.cpp
src/IbGMResDrEngine.hpp
src/UserMatrix.hpp
src/Ib-GMRes-Dr.h
set(FABULOUS_SRC
src/fabulous.cpp
src/ApiEngine.hpp
src/MatrixApi.hpp
${FABULOUS_HEADERS}
)
include_directories(src)
#define the test of the API
set(IBGMRES_API_TEST
tests/testApiGMres.c
tests/testApiGMresUserCase.c
)
if(BUILD_SHARED_LIBS)
add_library(ibgmresdr SHARED ${IBGMRES_API_SRC})
add_library(fabulous SHARED ${FABULOUS_SRC})
else()
add_library(ibgmresdr STATIC ${IBGMRES_API_SRC})
add_library(fabulous STATIC ${FABULOUS_SRC})
endif()
set(LIBS_FOR_API ibgmresdr)
list(APPEND LIBS_FOR_API
target_link_libraries(
fabulous
${LAPACKE_LIBRARIES}
${CBLAS_LIBRARIES}
${LAPACK_LIBRARIES}
......@@ -40,11 +33,17 @@ list(APPEND LIBS_FOR_API
${CHAMELEON_LIBRARIES_DEP}
)
#Pour chacun des tests au dessus, on fait des trucs...
foreach(_test ${IBGMRES_API_TEST})
#define the test of the API
set(FABULOUS_API_TEST
tests/testApiGMres.c
tests/testApiGMresUserCase.c
)
# create a target for each test:
foreach(_test ${FABULOUS_API_TEST})
get_filename_component(_name_exe ${_test} NAME_WE)
add_executable(${_name_exe} ${_test} ${IBGMRES_API_HEADERS})
target_link_libraries(${_name_exe} ${LIBS_FOR_API})
add_executable(${_name_exe} ${_test})
target_link_libraries(${_name_exe} fabulous)
install(
TARGETS ${_name_exe}
DESTINATION ${CMAKE_INSTALL_PREFIX}/bin/Api
......@@ -53,10 +52,11 @@ endforeach()
# # Rules for installing the project
install(
TARGETS ibgmresdr
TARGETS fabulous
DESTINATION lib
)
install(
FILES ${IBGMRES_API_HEADERS}
FILES ${FABULOUS_HEADERS}
DESTINATION include
)
#ifndef FABULOUS_API_ENGINE_HPP
#define FABULOUS_API_ENGINE_HPP
#include "../src/Utils.hpp"
#include "../src/Logger.hpp"
#include "../src/Arnoldi.hpp"
#include "../src/Arnoldi_QRInc.hpp"
#include "../src/Arnoldi_Ortho.hpp"
#include "../src/Arnoldi_IB.hpp"
#include "../src/BGMRes.hpp"
#include "fabulous.h"
#include "MatrixApi.hpp"
namespace fabulous {
/**
* @brief This class is used to dereference the handle given by the
* user without knowing in which arithmetic we are.
*/
struct ApiEngineI
{
virtual ~ApiEngineI(){}
virtual void set_rightprecond(
fabulous_rightprecond_t user_rpc, void *user_rpc_data) = 0;
virtual void set_usergemm(fabulous_mvp_t user_mvp) = 0;
virtual void set_parameters(
int max_mvp, int max_krylov_space_size, void *tolerance, int nb_tol) = 0;
virtual void set_dotProduct(fabulous_dot_t user_dot) = 0;
virtual void set_ortho_process(fabulous_orthoproc orthoproc) = 0;
virtual int solve(int dim, void *RHS, void *X0) = 0;
virtual int solve_IB(int dim, void *RHS, void *X0) = 0;
virtual int solve_QR(int dim, void *RHS, void *X0) = 0;
virtual void *get_results() = 0;
virtual const double *get_logs(int *size) = 0;
};
/**
* @brief Main Engine of the Library. Each method of the API except for
* the init function is a member function of both this class and its
* parent's one.
*/
template< class U > class ApiEngine : public ApiEngineI
{
public:
typedef typename Arithmetik<U>::value_type value_type;
typedef typename Arithmetik<U>::primary_type primary_type;
private:
MatrixApi<U> matrix;
// Solve parameters
int max_mvp;
std::vector<primary_type> tolerance;
int max_krylov_space_size;
OrthoChoice ortho;
Block<U> sol; // Solution storage
Logger<primary_type> log;
public:
ApiEngine(void *user_mvp_data, int dim, void *user_env):
matrix{user_env, user_mvp_data, dim},
max_mvp(0),
tolerance(0),
max_krylov_space_size(0),
ortho(OrthoChoice::MGS)
{
}
/**
* @brief This function set the user Dot product
*/
void set_dotProduct(fabulous_dot_t user_dot) override
{
matrix.setDotProduct(user_dot);
}
/**
* @brief This function allow to set a right pre conditionner
*
*/
void set_rightprecond(fabulous_rightprecond_t user_rpc, void *user_rpc_data) override
{
matrix.setRightPreCond(user_rpc, user_rpc_data);
}
void set_usergemm(fabulous_mvp_t user_mvp) override
{
matrix.setMatBlockVect(user_mvp);
}
void set_parameters(int inMaxIte, int inMax_Krylov_Space_Size, void *in_tolerance, int nb_tol) override
{
primary_type *p_tol = reinterpret_cast<primary_type*>(in_tolerance);
tolerance.clear();
tolerance.reserve(nb_tol);
tolerance.assign(p_tol, p_tol+nb_tol);
max_mvp = inMaxIte;
max_krylov_space_size = inMax_Krylov_Space_Size;
}
void set_ortho_process(fabulous_orthoproc orthoproc) override
{
switch (orthoproc) {
case FABULOUS_MGS: ortho = OrthoChoice::MGS; break;
case FABULOUS_CGS: ortho = OrthoChoice::CGS; break;
case FABULOUS_IMGS: ortho = OrthoChoice::IMGS; break;
case FABULOUS_ICGS: ortho = OrthoChoice::ICGS; break;
default:
::fabulous::warning("Value for Ortho is not part of {0,1,2,3}\n"
"Ortho process used will be MGS\n");
ortho = OrthoChoice::MGS;
break;
}
}
template< class BLOCK >
void ConvertStructures(void *RHS, void *X0, BLOCK &B, BLOCK &X_init)
{
X_init.InitBlock(reinterpret_cast<U*>(X0));
B.InitBlock(reinterpret_cast<U*>(RHS));
}
template<class ARNOLDI, class BLOCK,
class P = typename BLOCK::primary_type >
int call_solve(BLOCK &B, BLOCK &X0, BLOCK &sol, Logger<P> &log)
{
return BGMRes<ARNOLDI>(matrix, B, X0,
max_mvp, max_krylov_space_size,
sol, log, tolerance,
ortho, ArnOrtho::RUHE );
}
// Solve method : RHS and X0 have the same size
int solve(int nbRHS, void *RHS, void *X0) override
{
// Convert raw data to Block
Block<U> X_init{nbRHS, matrix.size()};
Block<U> B{nbRHS, matrix.size()};
ConvertStructures(RHS, X0, B, X_init);
// Init an empty block to store solution
sol.initData(nbRHS, matrix.size());
return call_solve<Arnoldi>(B, X_init, sol, log);
}
// Solve method : RHS and X0 have the same size
int solve_QR(int nbRHS, void *RHS, void *X0) override
{
// Convert raw data to Block
Block<U> X_init{nbRHS, matrix.size()};
Block<U> B{nbRHS, matrix.size()};
ConvertStructures(RHS, X0, B, X_init);
// Init an empty block to store solution
sol.initData(nbRHS, matrix.size());
return call_solve<Arnoldi_QRInc>(B, X_init, sol, log);
}
int solve_IB(int nbRHS, void *RHS, void *X0) override
{
Block<U> X_init{nbRHS, matrix.size()}; // Convert raw data to Block
Block<U> B{nbRHS, matrix.size()};
ConvertStructures(RHS, X0, B, X_init);
// Init an empty block to store solution
sol.initData(nbRHS, matrix.size());
return call_solve<Arnoldi_IB>(B, X_init, sol, log);
}
void *get_results() override
{
return sol.getPtr();
}
const double *get_logs(int *size) override
{
*size = log.getNbIteLogged();
return log.writeDownArray();
}
};
}; // namespace fabulous
#endif // FABULOUS_API_ENGINE_HPP
#include "IbGMResDrEngine.hpp"
/**
* Implement Init function by creating an instance of the Engine
* class. The ptr to the instance is the Handle.
*/
FABULOUS_BEGIN_C_DECL
Fabulous_handle fabulous_init(fabulous_Arithmetic ari,
void *Matrix,
int dim,
void *userEnv)
{
switch (ari) {
case FABULOUS_FLOAT:{
IbGMResDrEngine<float>* Engine =
new IbGMResDrEngine<float>(Matrix, dim, userEnv);
return Engine;
break;
}
case FABULOUS_DOUBLE:{
IbGMResDrEngine<double>* Engine =
new IbGMResDrEngine<double>(Matrix, dim, userEnv);
return Engine;
break;
}
case FABULOUS_COMPLEX_FLOAT:{
IbGMResDrEngine<std::complex<float>,float>* Engine =
new IbGMResDrEngine<std::complex<float>,float>(Matrix, dim, userEnv);
return Engine;
break;
}
case FABULOUS_COMPLEX_DOUBLE:{
IbGMResDrEngine<std::complex<double>,double>* Engine =
new IbGMResDrEngine<std::complex<double>,double>(Matrix, dim, userEnv);
return Engine;
break;
}
default:
std::cout<<"Arithmetic not recognised, please see the documentation"
<< "in order to choose something we propose\n"
<< "Exiting\n";
exit(0);
break;
}
}
FABULOUS_END_C_DECL
#ifndef IBGMRESENGINE_HPP
#define IBGMRESENGINE_HPP
#include "../src/Utils.hpp"
#include "../src/Logger.hpp"
#include "../src/Arnoldi.hpp"
#include "../src/Arnoldi_QRInc.hpp"
#include "../src/Arnoldi_Ortho.hpp"
#include "../src/Arnoldi_IB.hpp"
#include "../src/BGMRes.hpp"
#include "Ib-GMRes-Dr.h"
#include "UserMatrix.hpp"
/**
* @brief This class is used to dereference the handle given by the
* user without knowing in which arithmetic we are.
*/
class EngineDispatcher {
public:
EngineDispatcher(){}
EngineDispatcher(void*, int, void*){}
virtual ~EngineDispatcher(){}
virtual void set_rightprecond(Callback_RightCond /*rightPreCond*/,void * /*ptr*/){}
virtual void set_usergemm(Callback_MatBlockVect /*userProduct*/){}
virtual void set_parameters(int /*inMaxIte*/,int /*restart*/,
void* /*inTolerance*/, int /*nbTol*/){}
virtual void set_dotProduct(Callback_DotProduct /*dotProduct*/){}
virtual void set_ortho_process(int /*value*/){}
virtual int solve(int, void *, void *) { return 0; }
virtual int solve_IB(int, void*, void*) { return 0; }
virtual int solve_QR(int, void*, void*) { return 0; }
virtual void* get_results() { return nullptr; }
virtual double* get_logs(int* /*size*/) { return nullptr; }
};
/**
* @brief Main Engine of the Library. Each method of the API exept for
* the init function is a memeber function of both this class and its
* parent's one.
*/
template<typename S,typename P=S>
class IbGMResDrEngine : public EngineDispatcher {
private:
//User parameters
UserMatrix<S,P> *matrix;
void *userEnvPtr;
//Solve parameters
int max_iter;
std::vector<P> tolerance;
int restart;
OrthoChoice ortho;
//Solution storage
Block<S,P> *sol;
Logger<P> *log;
public:
IbGMResDrEngine(void *userMatrix, int dim, void *inUserEnv) :
matrix(nullptr),
userEnvPtr(inUserEnv),
max_iter(0),
tolerance(0),
restart(0),
ortho(OrthoChoice::MGS),
sol(nullptr),
log(nullptr)
{
matrix = new UserMatrix<S,P>(userEnvPtr,userMatrix,dim);
}
~IbGMResDrEngine()
{
delete matrix;
if (sol) delete sol;
if (log) delete log;
}
/**
* @brief This function set the user Dot product
*/
void set_dotProduct(Callback_DotProduct dotProduct) override
{
matrix->setDotProduct(dotProduct);
}
/**
* @brief This function allow to set a right pre conditionner
*
*/
void set_rightprecond(Callback_RightCond rightPreCond,
void* usrPtrRightPreCond) override
{
matrix->setRightPreCond(rightPreCond,usrPtrRightPreCond);
}
void set_usergemm(Callback_MatBlockVect userProduct) override
{
matrix->setMatBlockVect(userProduct);
}
void set_parameters(int inMaxIte, int inRestart,
void* inTolerance, int nbTol) override
{
P* inCastTol = (reinterpret_cast<P*>(inTolerance));
for (int i=0; i<nbTol; ++i)
tolerance.push_back(inCastTol[i]);
max_iter = inMaxIte;
restart = inRestart;
}
void set_ortho_process(int value){
switch(value){
case 0:
ortho = OrthoChoice::MGS;
break;
case 1:
ortho = OrthoChoice::CGS;
break;
case 2:
ortho = OrthoChoice::IMGS;
break;
case 3:
ortho = OrthoChoice::ICGS;
break;
default:
std::cout<<"Value for Ortho is not part of {0,1,2,3}\n"
<<"Ortho process used will be MGS\n";
}
}
//Solve method : RHS and X0 have the same size
int solve(int nbRHS, void *RHS , void *X0) override
{
//Convert raw data to Block
Block<S,P> X_init{nbRHS,matrix->size()};
Block<S,P> B{nbRHS,matrix->size()};
ConvertStructures(RHS, X0, B, X_init);
//Init an empty block to store solution
sol = new Block<S,P>(nbRHS, matrix->size());
log = new Logger<P>();
//Compute
int res = call_solve<Arnoldi>(B, X_init, *sol, *log);
return res;
}
//Solve method : RHS and X0 have the same size
int solve_QR(int nbRHS, void *RHS, void *X0) override
{
//Convert raw data to Block
Block<S,P> X_init(nbRHS,matrix->size());
Block<S,P> B(nbRHS,matrix->size());
ConvertStructures(RHS,X0,B,X_init);
//Init an empty block to store solution
sol = new Block<S,P>(nbRHS,matrix->size());
log = new Logger<P>();
//Compute
int res = call_solve<Arnoldi_QRInc>(B,X_init,*sol,*log);
return res;
}
int solve_IB(int nbRHS,void *RHS,void *X0) override
{
//Convert raw data to Block
Block<S,P> X_init(nbRHS,matrix->size());
Block<S,P> B(nbRHS,matrix->size());
ConvertStructures(RHS,X0,B,X_init);
//Init an empty block to store solution
sol = new Block<S,P>(nbRHS,matrix->size());
log = new Logger<P>();
int res = call_solve<Arnoldi_IB>(B,X_init,*sol,*log);
return res;
}
void ConvertStructures(void *RHS,
void *X0,
Block<S,P>& B,
Block<S,P>& X_init)
{
X_init.InitBlock(reinterpret_cast<S*>(X0));
B.InitBlock(reinterpret_cast<S*>(RHS));
}
template<class ARNOLDI,
class Block,
class Primary = typename Block::primary_type
>
int call_solve(Block &B,
Block &X0,
Block &sol,
Logger<Primary>& log)
{
return BGMRes<ARNOLDI>(*matrix, B, X0,
max_iter, restart,
sol, log, tolerance,
ortho, ArnOrtho::RUHE );
}
void* get_results() override
{
if (sol)
return sol->getPtr();
else
return nullptr;
}
double* get_logs(int *size) override
{
if (log) {
*size = log->getNbIteLogged();
double *res = log->writeDownArray();
return res;
} else {
*size = 0;
return nullptr;
}
}
};
FABULOUS_BEGIN_C_DECL
//Here, write the implementation of API functions
void fabulous_set_usergemm(Callback_MatBlockVect userProduct,
Fabulous_handle handle)
{
(reinterpret_cast<EngineDispatcher*>(handle))->set_usergemm(userProduct);
}
void fabulous_set_rightprecond(Callback_RightCond rightPreCond,
void *ptrRightPreCond,
Fabulous_handle handle)
{
(reinterpret_cast<EngineDispatcher*>(handle))->set_rightprecond(rightPreCond,
ptrRightPreCond);
}
void fabulous_set_dotProduct(Callback_DotProduct dotProduct,
Fabulous_handle handle)
{
(reinterpret_cast<EngineDispatcher*>(handle))->set_dotProduct(dotProduct);
}
void fabulous_set_parameters(int max_ite, int restart, void *tolerance,
int nbTol,
Fabulous_handle handle)
{
(reinterpret_cast<EngineDispatcher*>(handle))->set_parameters(max_ite,restart,
tolerance, nbTol);
}
void fabulous_set_ortho_process(enum fabulous_ortho_process value,
Fabulous_handle handle)
{
(reinterpret_cast<EngineDispatcher*>(handle))->set_ortho_process(value);
}
int fabulous_solve(int nbRHS,void *RHS,void *X0,
Fabulous_handle handle)
{
return (reinterpret_cast<EngineDispatcher*>(handle))->solve(nbRHS,RHS,X0);
}
int fabulous_solve_QR(int nbRHS,void *RHS,void *X0,
Fabulous_handle handle)
{
return (reinterpret_cast<EngineDispatcher*>(handle))->solve_QR(nbRHS,RHS,X0);
}
int fabulous_solve_IB(int nbRHS,void *RHS,void *X0,
Fabulous_handle handle)
{
return (reinterpret_cast<EngineDispatcher*>(handle))->solve_IB(nbRHS,RHS,X0);
}
void fabulous_dealloc(Fabulous_handle handle)
{
delete (reinterpret_cast<EngineDispatcher*>(handle));
}
void* fabulous_get_results(Fabulous_handle handle)
{
return (reinterpret_cast<EngineDispatcher*>(handle))->get_results();
}
double* fabulous_get_logs(int *size, Fabulous_handle handle)
{
return (reinterpret_cast<EngineDispatcher*>(handle))->get_logs(size);
}
FABULOUS_END_C_DECL
#endif // IBGMRESENGINE_HPP
#ifndef FABULOUS_MATRIX_API_HPP
#define FABULOUS_MATRIX_API_HPP
#include <cassert>
#include "fabulous.h"
#include "Block.hpp"
namespace fabulous {
/**
* @brief This class is a object wrapper over the callback from user.
* The MatrixVectorProduct will be set here
*
* Note: the matrix must be square
*/
template< class U > class MatrixApi
{
public:
typedef typename Arithmetik<U>::value_type value_type;
typedef typename Arithmetik<U>::primary_type primary_type;
fabulous_mvp_t _user_mvp;
fabulous_rightprecond_t _user_rpc;
fabulous_dot_t _user_dot;
int _dim;
void *_user_mvp_data;
void *_user_rpc_data;
void *_user_env;
bool _use_rightprecond;
public:
MatrixApi():
_user_mvp(nullptr),
_user_rpc(nullptr),
_user_dot(nullptr),
_dim(0),
_user_mvp_data(nullptr),
_user_rpc_data(nullptr),
_user_env(nullptr),
_use_rightprecond(false)
{
}
int size() const { return _dim; } // Return dimension of matrix (square matrix)
MatrixApi(void *user_env, void *user_mvp_data, int dim):
_user_mvp(nullptr),
_user_rpc(nullptr),
_user_dot(nullptr),
_dim(dim),
_user_mvp_data(user_mvp_data),
_user_rpc_data(nullptr),
_user_env(user_env),
_use_rightprecond(false)
{
}
U at(int, int) { FABULOUS_FATAL_ERROR("should not be reached"); return U{0.0}; }
void MatBlockVect(Block<U> &input, Block<U> &output, int idxToWrite=0)
{
void* toWrite = output.getPtr(idxToWrite);
std::cout<<"MatBlockVect : Input is "<<input.getSizeBlock()<<" x "
<<input.getLeadingDim()<<"\n";
std::cout<<" : Output is "<<output.getSizeBlock()<<" x "
<<output.getLeadingDim()<<"\n";
assert( _user_mvp != nullptr );
_user_mvp(_user_mvp_data,
input.getSizeBlock(), input.getPtr(), &toWrite,
_user_env );
}
void MatBaseProduct(U *ptrToRead, int nbRHS, void *ptrToWrite)
{
assert( _user_mvp != nullptr );
_user_mvp(_user_mvp_data, nbRHS, ptrToRead, &ptrToWrite, _user_env);
}
void setMatBlockVect(fabulous_mvp_t user_mvp)
{
_user_mvp = user_mvp;
}
void setRightPreCond(fabulous_rightprecond_t user_rpc, void *user_rpc_data)
{
_user_rpc = user_rpc;