Mentions légales du service

Skip to content
Snippets Groups Projects
Commit ec1ce630 authored by hhakim's avatar hhakim
Browse files

Implement optional pseudoinverse error in C++ svdtj.

parent f02287d4
No related branches found
No related tags found
No related merge requests found
......@@ -123,6 +123,16 @@ namespace Faust
template<typename FPP, FDevice DEVICE>
MatDense<FPP,DEVICE> svdtj_compute_W1H_M_W2_meth3(const MatDense<FPP,DEVICE> &dM, const Transform<FPP, DEVICE> &tW1, const Transform<FPP, DEVICE> &tW2, MatDense<FPP, DEVICE> & prev_W1H_M_W2, const int err_period, const int k1, const int k2, const int t1, const int t2, const bool new_W1, const bool new_W2);
/**
* Computes the error of PINVTJ (pseudo-inverse) with W1_MW2 the precomputed product of U'MW2 approximate and S the approximation of singular values.
*
* \param relErr: true to compute relative error else otherwise error is computed.
* \param verbosity: true to print a message including the error.
*
*/
template<typename FPP, FDevice DEVICE>
Real<FPP> calc_err_pinvtj(const Vect<FPP, DEVICE> &S, MatDense<FPP, DEVICE> &W1_MW2, const bool relErr, const bool verbosity);
}
#include "faust_SVDTJ.hpp"
......
......@@ -319,6 +319,41 @@ namespace Faust
return USV_prod;
}
template<typename FPP, FDevice DEVICE>
Real<FPP> calc_err_pinvtj(const Vect<FPP, DEVICE> &S, MatDense<FPP, DEVICE> &W1_MW2, const bool relErr, const bool verbosity)
{
const size_t m = W1_MW2.getNbRow();
const size_t n = W1_MW2.getNbCol();
const size_t min_mn = m > n?n:m;
unsigned int *ids = new unsigned int[min_mn];
std::iota(ids, ids+min_mn, 0);
MatSparse<FPP, DEVICE> S_inv(ids, ids, S.getData(), n, m, min_mn);
for(int i=0;i<min_mn;i++)
if(S[i] != FPP(0))
S_inv.getValuePtr()[i] = FPP(1) / S[i];
else
S_inv.getValuePtr()[i] = FPP(0);
delete[] ids;
W1_MW2.multiplyRight(S_inv);
Real<FPP> terr = W1_MW2.norm();
terr *= terr;
terr -= min_mn;
if(relErr) terr /= min_mn;
terr = std::sqrt(std::abs(terr));
if(verbosity)
{
std::cout << "PINVTJ ";
if(relErr)
std::cout << "relative ";
else
std::cout << "absolute ";
std::cout << "error: ";
std::cout << terr << std::endl;
}
return terr;
}
template<typename FPP, FDevice DEVICE, typename FPP2>
void svdtj_core_gen_step(MatGeneric<FPP,DEVICE>* M, MatDense<FPP,DEVICE> &dM, MatDense<FPP,DEVICE> &dM_M, MatDense<FPP,DEVICE> &dMM_, int J1, int J2, int t1, int t2, FPP2 tol, unsigned int verbosity, bool relErr, int order, const bool enable_large_Faust, TransformHelper<FPP,DEVICE> ** U, TransformHelper<FPP,DEVICE> **V, Vect<FPP,DEVICE> ** S_, const int err_period/*=100*/)
{
......@@ -384,6 +419,8 @@ namespace Faust
bool new_W1 = true, new_W2 = true; // true if W1 (resp. W2) has grown up during the iteration
bool W1_too_long = false, W2_too_long = false; // used only if enable_large_Faust is false
//
bool pinvtj_err = false;
if(! enable_large_Faust)
{
W1_max_size = m * m / 4;
......@@ -401,10 +438,16 @@ namespace Faust
// environment variable for testing
// if set the exact error will be computed for all iterations
// TODO: it should be an argument instead of an env. var.
char* str_all_true_err = getenv("SVDTJ_ALL_TRUE_ERR");
if(str_all_true_err)
terr_enabled = bool(std::atoi(str_all_true_err));
// TODO: it should be an argument instead of an env. var.
char* str_pinvtj_err = getenv("PINVTJ_ERR");
if(str_pinvtj_err)
pinvtj_err = bool(std::atoi(str_pinvtj_err));
while(loop)
{
......@@ -514,33 +557,45 @@ namespace Faust
S[i] = W1_MW2(m - min_mn + i,n - min_mn + i);
// compute error
if(! terr_enabled)
if(pinvtj_err)
{
auto Sd_norm = S.norm();
aerr = Faust::fabs(M_norm * M_norm - Sd_norm * Sd_norm);
terr_enabled = Faust::fabs(tol * tol - aerr) < Faust::fabs(E);
assert(order < 0); // TODO: S_inv if order > 0
// the error for pinvtj
terr = calc_err_pinvtj(S, W1_MW2, relErr, verbosity);
}
else
{ // true error enabled
auto USV_prod = calc_USV_(algoW1, algoW2, better_W1 && better_S?-1:algoW1->nfacts() - err_period, better_W2 && better_S?-1:algoW2->nfacts() - err_period, S);
USV_prod -= dM;
terr = USV_prod.norm();
if(relErr) terr /= M_norm;
if(verbosity)
{
// SVDTJ error
if(! terr_enabled)
{
std::cout << "SVDTJ ";
if(relErr)
std::cout << "relative ";
else
std::cout << "absolute ";
std::cout << "error: ";
std::cout << terr << std::endl;
auto Sd_norm = S.norm();
aerr = Faust::fabs(M_norm * M_norm - Sd_norm * Sd_norm);
terr_enabled = Faust::fabs(tol * tol - aerr) < Faust::fabs(E);
}
else
{ // true error enabled
auto USV_prod = calc_USV_(algoW1, algoW2, better_W1 && better_S?-1:algoW1->nfacts() - err_period, better_W2 && better_S?-1:algoW2->nfacts() - err_period, S);
USV_prod -= dM;
terr = USV_prod.norm();
if(relErr) terr /= M_norm;
if(verbosity)
{
std::cout << "SVDTJ ";
if(relErr)
std::cout << "relative ";
else
std::cout << "absolute ";
std::cout << "error: ";
std::cout << terr << std::endl;
}
}
if(verbosity)
std::cout << "SVDTJ iteration: " << int(k1 / t1) << " singular values approximate square norm error: " << aerr << std::endl;
}
if(verbosity)
std::cout << "SVDTJ iteration: " << int(k1 / t1) << " singular values approximate square norm error: " << aerr << std::endl;
if(prev_aerr > 0 && aerr > prev_aerr && aerr / prev_aerr >= 10)
{
......@@ -585,6 +640,7 @@ namespace Faust
thW1->save_mat_file("/tmp/W1_cpp.mat");
thW2->save_mat_file("/tmp/W2_cpp.mat");
// warning if pinvtj_err is true, then it might be the W1_MW2 times S_inv matrix
W1_MW2.save_to_mat_file("/tmp/W1_MW2_cpp.mat", "W1_MW2_cpp");
MatDense<FPP, Cpu> S1_mat(S.size(), 1, S.getData());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment