Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 38b10b94 authored by hhakim's avatar hhakim
Browse files

Add a unit test for pyfaust.factparams.ParamsPalm4MSA (issue #81).

parent c2a291f8
Branches
Tags
No related merge requests found
...@@ -818,6 +818,112 @@ class TestFaustFactory(unittest.TestCase): ...@@ -818,6 +818,112 @@ class TestFaustFactory(unittest.TestCase):
# misc/test/src/C++/test_palm4MSA.cpp # misc/test/src/C++/test_palm4MSA.cpp
self.assertAlmostEqual(norm(E,"fro")/norm(M,"fro"), 0.270954109668, places=4) self.assertAlmostEqual(norm(E,"fro")/norm(M,"fro"), 0.270954109668, places=4)
def testParamsPalm4MSA(self):
# call Palm4MSA specifying params
from os import dup2, pipe # for
from pyfaust.fact import palm4msa
from pyfaust.factparams import (ParamsPalm4MSA, ConstraintList,
StoppingCriterion,
ConstraintInt,
ConstraintReal)
import numpy as np
from tempfile import gettempdir
from os.path import join
M = np.random.rand(500, 32)
cons = ConstraintList('splin', 5, 500, 32, 'normcol', 1.0, 32, 32)
# or alternatively using pyfaust.proj
# from pyfaust.proj import splin, normcol
# cons = [ splin((500,32), 5), normcol((32,32), 1.0)]
stop_crit = StoppingCriterion(num_its=200)
param = ParamsPalm4MSA(cons, stop_crit)
param.is_verbose = True
param.grad_calc_opt_mode = 1
tmp_dir = gettempdir()
tmp_file = join(tmp_dir, "verbose_output_of_palm4msa_test")
f = open(tmp_file, 'w')
dup2(1,4)
dup2(f.fileno(), 1)
F = palm4msa(M, param)
print()
f.close()
dup2(4,1)
# retrieve the params effectively used from C++ core output
# reconstruct a ParamsPalm4MSA from the values found
param_test = ParamsPalm4MSA(cons, stop_crit)
param_test.constraints = []
for line in open(tmp_file, 'r').readlines():
print(line, end='')
if(line.startswith('NFACTS')):
param_test.num_facts = int(line.split(':')[-1].strip())
if(line.startswith('VERBOSE')):
param_test.is_verbose = bool(int(line.split(':')[-1].strip()))
if(line.startswith('UPDATEWAY')):
param_test.is_update_way_R2L = \
bool(int(line.split(':')[-1].strip()))
if(line.startswith('INIT_LAMBDA')):
param_test.init_lambda = float(line.split(':')[-1].strip())
if(line.startswith('ISCONSTANTSTEPSIZE')):
param_test.constant_step_size = \
bool(int(line.split(':')[-1].strip()))
if(line.startswith('step_size')):
param_test.step_size = float(line.split(':')[-1].strip())
if(line.startswith('gradCalcOptMode')):
param_test.grad_calc_opt_mode = int(line.split(':')[-1].strip())
if(line.startswith('use_csr')):
param_test.use_csr = bool(int(line.split(':')[-1].strip()))
if(line.startswith('packing_RL')):
param_test.packing_RL = bool(int(line.split(':')[-1].strip()))
if(line.startswith('errorThreshold')):
param_test.stop_crit.tol = float(line.split(':')[-1].strip())
param_test.stop_crit._is_criterion_error = True
if(line.startswith('nb_it')):
param_test.stop_crit.num_its = \
bool(int(line.split(':')[-1].strip()))
if(line.startswith('maxIteration')):
param_test.stop_crit.maxiter = \
bool(int(line.split(':')[-1].strip()))
if(line.startswith('type_cont')):
colon_fields = line.split(':')
const_type_name = colon_fields[1].strip()
if(const_type_name.startswith('INT CONSTRAINT_NAME_SPLIN')):
nrows = int(colon_fields[2].split(' ')[0].strip())
ncols = int(colon_fields[3].split(' ')[0].strip())
cons_val = int(colon_fields[-1].strip())
cons = ConstraintInt('splin', nrows, ncols, cons_val)
param_test.constraints += [cons]
if(const_type_name.startswith('FAUST_REAL CONSTRAINT_NAME_NORMCOL')):
nrows = int(colon_fields[2].split(' ')[0].strip())
ncols = int(colon_fields[3].split(' ')[0].strip())
cons_val = float(colon_fields[-1].strip())
cons = ConstraintReal('normcol', nrows, ncols, cons_val)
param_test.constraints += [cons]
# compare original param instance and the reconstructed one
self.assertEqual(param_test.num_facts, param.num_facts)
self.assertEqual(param_test.is_verbose, param.is_verbose)
self.assertEqual(param_test.is_update_way_R2L, param.is_update_way_R2L)
self.assertEqual(param_test.init_lambda, param.init_lambda)
self.assertEqual(param_test.constant_step_size,
param.constant_step_size)
self.assertEqual(param_test.step_size, param.step_size)
self.assertEqual(param_test.grad_calc_opt_mode, param.grad_calc_opt_mode)
self.assertEqual(param_test.use_csr, param.use_csr)
self.assertEqual(param_test.packing_RL, param.packing_RL)
self.assertEqual(param_test.stop_crit.maxiter, param.stop_crit.maxiter)
self.assertEqual(param_test.stop_crit._is_criterion_error,
param.stop_crit._is_criterion_error)
self.assertEqual(param_test.stop_crit.num_its,
param.stop_crit.num_its)
self.assertEqual(len(param_test.constraints), len(param.constraints))
for i in range(len(param_test.constraints)):
self.assertEqual(param_test.constraints[i]._num_rows,
(param.constraints[i]._num_rows))
self.assertEqual(param_test.constraints[i]._num_cols,
(param.constraints[i]._num_cols))
self.assertEqual(param_test.constraints[i].name,
param_test.constraints[i].name)
self.assertEqual(param_test.constraints[i]._cons_value,
param_test.constraints[i]._cons_value)
def testFactPalm4MSA2020(self): def testFactPalm4MSA2020(self):
from pyfaust.fact import palm4msa from pyfaust.fact import palm4msa
print("Test pyfaust.fact.palm4msa2020()") print("Test pyfaust.fact.palm4msa2020()")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment