Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 74c0fe07 authored by hhakim's avatar hhakim
Browse files

Add a test for pyfaust.tools.UpdateCholeskyFull() used by OMP-Chol.

Ensures we keep assertions along the updates: R'R == D[:,I].H*D[:,I] and R is upper triangular.
parent aebf2f37
Branches
Tags
No related merge requests found
......@@ -259,6 +259,7 @@ if (BUILD_WRAPPER_PYTHON)
add_test(NAME PYTHON${PY_VER}_FAUST_TIME COMMAND ${PYTHON_EXE} ${FAUST_SRC_TEST_SRC_PYTHON_DIR}/test_pyFaust_time.py ${FAUST_PYTHON_BIN_DIR} ${FAUST_BIN_TEST_FIG_DIR})
add_test(NAME PYTHON${PY_VER}_FAUST_DEMO_INSTALL COMMAND ${PYTHON_EXE} -c "import sys; sys.path += ['${FAUST_PYTHON_BIN_DIR}'];from pyfaust.demo import quickstart; quickstart.run()")
add_test(NAME PYTHON${PY_VER}_FAUST_UNIT_TESTS COMMAND ${PYTHON_EXE} ${FAUST_SRC_TEST_SRC_PYTHON_DIR}/test_FaustPy.py ${FAUST_PYTHON_BIN_DIR}) # second arg. is the FaustPy's dir. to add to PYTHONPATH
add_test(NAME PYTHON${PY_VER}_FAUST_UPDATE_CHOL COMMAND ${PYTHON_EXE} ${FAUST_SRC_TEST_SRC_PYTHON_DIR}/test_update_cholesky.py ${FAUST_PYTHON_BIN_DIR}) # second arg. is the FaustPy's dir. to add to PYTHONPATH
endif(PYTHON_MODULE_SCIPY)
endif()
endforeach()
......
from __future__ import print_function
import sys
import os
if(len(sys.argv) > 1):
sys.path.append(sys.argv[1])
from pyfaust import *
from numpy import empty, allclose, zeros, tril
from scipy.io import loadmat
from pyfaust.tools import UpdateCholeskyFull
from pyfaust.demo import get_data_dirpath
datap = os.path.join(get_data_dirpath(), 'faust_MEG_rcg_8.mat')
d = loadmat(datap)
facts = d['facts']
facts = [facts[0,i] for i in range(facts.shape[1]) ]
FD = Faust(facts)
D = FD.todense()
I = [125, 132, 1000, 155]
for P, Pt in [(lambda x: D*x, lambda x: D.H*x),
(lambda x: np.matrix(FD*x),
lambda x: np.matrix(FD.H*x))]:
R = empty((0,0))
for i in range(1, len(I)+1):
R = UpdateCholeskyFull(R[0:i,0:i], P, Pt, I[:i], 8193)
print(R, end='\n\n')
assert(allclose(D[:,I[:min(i,len(I))]].H*D[:,I[:min(i,len(I))]], R.H*R))
assert(allclose((R.H*R).H, R.H*R))
assert(allclose(tril(R, -1), zeros(R.shape)))
print(D[:,I].H*D[:,I])
print(R.H*R)
......@@ -156,6 +156,9 @@ def UpdateCholeskyFull(R,P,Pt,index,m):
cat((R,new_col),axis=1),
cat((np.zeros((1, R.shape[1])), R_ii),axis=1)
),axis=0)
#assert(np.allclose((R.H*R).H, R.H*R))
#D = P(np.eye(m,m))
#assert(np.allclose(D[:,index].H*D[:,index], R.H*R))
return R
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment