Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 02fd1166 authored by hhakim's avatar hhakim
Browse files

Secure SVDTJ two-errors strategy in case of very large tol or zero matrix sum.

parent b68577ec
No related branches found
No related tags found
No related merge requests found
...@@ -395,6 +395,7 @@ namespace Faust ...@@ -395,6 +395,7 @@ namespace Faust
int k1 = 0, k2 = 0; // k counts the number of Givens built (t per factor) on each side (for U (W1) and V (W2)) int k1 = 0, k2 = 0; // k counts the number of Givens built (t per factor) on each side (for U (W1) and V (W2))
auto M_norm = dM.norm(); auto M_norm = dM.norm();
bool dM_adj = false; //cf. svdtj_compute_W1H_M_W2_meth1 bool dM_adj = false; //cf. svdtj_compute_W1H_M_W2_meth1
// cf. https://gitlab.inria.fr/faustgrp/faust/-/issues/318 for two-error strategy (approximate lest costly error and exact error)
Real<FPP> aerr = 1; // squared error approximate Real<FPP> aerr = 1; // squared error approximate
Real<FPP> terr = 1; // true error (absolute or relative) Real<FPP> terr = 1; // true error (absolute or relative)
bool terr_enabled = false; bool terr_enabled = false;
...@@ -448,6 +449,12 @@ namespace Faust ...@@ -448,6 +449,12 @@ namespace Faust
if(str_pinvtj_err) if(str_pinvtj_err)
pinvtj_err = bool(std::atoi(str_pinvtj_err)); pinvtj_err = bool(std::atoi(str_pinvtj_err));
if(! pinvtj_err && Faust::fabs(tol * tol) > Faust::fabs(E))
{
std::cerr << "warning: tol² > E, computing exact error from start." << std::endl;
terr_enabled = true;
}
while(loop) while(loop)
{ {
...@@ -565,15 +572,8 @@ namespace Faust ...@@ -565,15 +572,8 @@ namespace Faust
} }
else else
{ {
// SVDTJ error // SVDTJ error
if(! terr_enabled) if(terr_enabled)
{
auto Sd_norm = S.norm();
aerr = Faust::fabs(M_norm * M_norm - Sd_norm * Sd_norm);
terr_enabled = Faust::fabs(tol * tol - aerr) < Faust::fabs(E);
}
else
{ // true error enabled { // true error enabled
auto USV_prod = calc_USV_(algoW1, algoW2, better_W1 && better_S?-1:algoW1->nfacts() - err_period, better_W2 && better_S?-1:algoW2->nfacts() - err_period, S); auto USV_prod = calc_USV_(algoW1, algoW2, better_W1 && better_S?-1:algoW1->nfacts() - err_period, better_W2 && better_S?-1:algoW2->nfacts() - err_period, S);
USV_prod -= dM; USV_prod -= dM;
...@@ -590,6 +590,14 @@ namespace Faust ...@@ -590,6 +590,14 @@ namespace Faust
std::cout << terr << std::endl; std::cout << terr << std::endl;
} }
} }
else
{
// approximate error
auto Sd_norm = S.norm();
aerr = Faust::fabs(M_norm * M_norm - Sd_norm * Sd_norm);
// check if we're good to switch to the exact error
terr_enabled = Faust::fabs(tol * tol - aerr) < Faust::fabs(E);
}
if(verbosity) if(verbosity)
std::cout << "SVDTJ iteration: " << int(k1 / t1) << " singular values approximate square norm error: " << aerr << std::endl; std::cout << "SVDTJ iteration: " << int(k1 / t1) << " singular values approximate square norm error: " << aerr << std::endl;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment