Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 0c1d3ecd authored by hhakim's avatar hhakim
Browse files

Update step-by-step_palm4msa.py since it's now possible for the PALM4MSA 2020...

Update step-by-step_palm4msa.py since it's now possible for the PALM4MSA 2020 impl. to use ParamPalm4MSA.init_facts (f78c96a3).
parent 8c26b422
No related branches found
No related tags found
No related merge requests found
...@@ -2,11 +2,13 @@ from pyfaust.fact import palm4msa ...@@ -2,11 +2,13 @@ from pyfaust.fact import palm4msa
from pyfaust.factparams import ParamsPalm4MSA, StoppingCriterion from pyfaust.factparams import ParamsPalm4MSA, StoppingCriterion
from pyfaust.proj import splin, spcol from pyfaust.proj import splin, spcol
import pyfaust import pyfaust
from pyfaust import Faust
from numpy.linalg import norm from numpy.linalg import norm
import numpy as np import numpy as np
num_its = 100 # the total number of iterations to run PALM4MSA num_its = 100 # the total number of iterations to run PALM4MSA
backend = 2016 # or 2020
use_csr = True
# Generate a matrix to factorize # Generate a matrix to factorize
d = 64 d = 64
S = pyfaust.rand(d, d, num_factors=2, density=0.1, per_row=True) S = pyfaust.rand(d, d, num_factors=2, density=0.1, per_row=True)
...@@ -20,42 +22,55 @@ stop_crit = StoppingCriterion(num_its=1) ...@@ -20,42 +22,55 @@ stop_crit = StoppingCriterion(num_its=1)
# pack all these parameters into a ParamsPalm4MSA # pack all these parameters into a ParamsPalm4MSA
p = ParamsPalm4MSA(projs, stop_crit) p = ParamsPalm4MSA(projs, stop_crit)
p.use_csr = use_csr
rel_errs = [] # relative errors for all iterations rel_errs = [] # relative errors for all iterations
# Runs one iteration at a time and compute the relative error # Runs one iteration at a time and compute the relative error
for i in range(num_its): for i in range(num_its):
F, scale = palm4msa(M, p, ret_lambda=True) F, scale = palm4msa(M, p, ret_lambda=True, backend=backend)
# backup the error # backup the error
rel_errs += [(F-M).norm()/np.linalg.norm(M)] rel_errs += [(F-M).norm()/np.linalg.norm(M)]
# retrieve the factors from the Faust F obtained after the one-iteration execution # retrieve the factors from the Faust F obtained after the one-iteration execution
# it's needed to convert them explicitely as Fortran (that is column major # it's needed to convert them explicitely as Fortran (that is column major
# order) numpy array, otherwise it would fail the next iteration # order) numpy array, otherwise it would fail the next iteration
facts = [np.asfortranarray(F.factors(0).toarray()), if isinstance(F.factors(0), np.ndarray):
np.asfortranarray(F.factors(1).toarray())] facts = [np.asfortranarray(F.factors(0)),
np.asfortranarray(F.factors(1))]
elif backend == 2016:
facts = [np.asfortranarray(F.factors(0).toarray()),
np.asfortranarray(F.factors(1).toarray())]
else:
facts = [F.factors(0), F.factors(1)]
# don't bother with np.ndarray/csr_matrix cases, use a Faust
if Faust(facts[0]).norm() > Faust(facts[1]).norm():
facts[0] /= scale
else:
facts[1] /= scale
# F is multiplied by lambda, we need to undo that before continuing to the # F is multiplied by lambda, we need to undo that before continuing to the
# next iteration, the factor which was multiplied by scale can be any of # next iteration, the factor which was multiplied by scale can be any of
# them but because all the others are normalized, it's easy to deduce which # them but because all the others are normalized, it's easy to deduce which
# one it is # one it is
# NOTE: remember we want to set the factors exactly as they were at the end of # NOTE: remember we want to set the factors exactly as they were at the end of
# iteration 0, that's why all must be normalized (by dividind by scale) # iteration 0, that's why all must be normalized (by dividind by scale)
if norm(facts[0]) > norm(facts[1]):
facts[0] /= scale
else:
facts[1] /= scale
# all is ready to pack the parameters and initial state (the factors and # all is ready to pack the parameters and initial state (the factors and
# the scale) in a new ParamsPalm4MSA instance # the scale) in a new ParamsPalm4MSA instance
p = ParamsPalm4MSA(projs, stop_crit, p = ParamsPalm4MSA(projs, stop_crit,
init_facts=facts, init_facts=facts,
init_lambda=scale) init_lambda=scale)
p.use_csr = use_csr
print("relative errors along the iterations (step-by-step mode):", rel_errs) print("relative errors along the iterations (step-by-step mode):", rel_errs)
stop_crit = StoppingCriterion(num_its=num_its) stop_crit = StoppingCriterion(num_its=num_its)
p = ParamsPalm4MSA(projs, stop_crit) p = ParamsPalm4MSA(projs, stop_crit)
G = palm4msa(M, p) p.use_csr = use_csr
G = palm4msa(M, p, backend=backend)
print("Relative error when running all PALM4MSA iterations at once: ", (G-M).norm() / norm(M)) print("Relative error when running all PALM4MSA iterations at once: ", (G-M).norm() / norm(M))
print("Last relative error obtained when running PALM4MSA " print("Last relative error obtained when running PALM4MSA "
"iteration-by-iteration: ", rel_errs[-1]) "iteration-by-iteration: ", rel_errs[-1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment