Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 251d0961 authored by hhakim's avatar hhakim
Browse files

Test also UpdateCholeskySparse and take it into account in coverage calculation.

parent 3ccde7e7
No related branches found
No related tags found
No related merge requests found
Pipeline #885883 passed
...@@ -33,6 +33,7 @@ CE conda install -c conda-forge -y coverage ...@@ -33,6 +33,7 @@ CE conda install -c conda-forge -y coverage
CE coverage erase # just in case CE coverage erase # just in case
PYFAUST_DIR=$(dirname $(CE python -c "import pyfaust as pf; print(pf.__file__)")) PYFAUST_DIR=$(dirname $(CE python -c "import pyfaust as pf; print(pf.__file__)"))
CE coverage run --source $PYFAUST_DIR misc/test/src/Python/test_FaustPy.py CE coverage run --source $PYFAUST_DIR misc/test/src/Python/test_FaustPy.py
CE coverage run -a --source $PYFAUST_DIR misc/test/src/Python/test_update_cholesky.py
CE coverage run -a --source $PYFAUST_DIR $PYFAUST_DIR/datadl.py /tmp/myfaustdata CE coverage run -a --source $PYFAUST_DIR $PYFAUST_DIR/datadl.py /tmp/myfaustdata
rm -Rf /tmp/myfaustdata rm -Rf /tmp/myfaustdata
#coverage run -a --source $PYFAUST_DIR $PYFAUST_DIR/tests/run.py # only real and cpu, all tests are ran below #coverage run -a --source $PYFAUST_DIR $PYFAUST_DIR/tests/run.py # only real and cpu, all tests are ran below
......
...@@ -7,29 +7,38 @@ if(len(sys.argv) > 1): ...@@ -7,29 +7,38 @@ if(len(sys.argv) > 1):
from pyfaust import * from pyfaust import *
from numpy import empty, allclose, zeros, tril from numpy import empty, allclose, zeros, tril
from scipy.io import loadmat from scipy.io import loadmat
from pyfaust.tools import UpdateCholeskyFull from pyfaust.tools import (UpdateCholeskyFull, UpdateCholesky,
UpdateCholeskySparse)
from pyfaust.demo import get_data_dirpath from pyfaust.demo import get_data_dirpath
from scipy.sparse import issparse
datap = os.path.join(get_data_dirpath(), 'faust_MEG_rcg_8.mat') if __name__ == '__main__':
datap = os.path.join(get_data_dirpath(), 'faust_MEG_rcg_8.mat')
d = loadmat(datap) d = loadmat(datap)
facts = d['facts'] facts = d['facts']
facts = [facts[0,i] for i in range(facts.shape[1]) ] facts = [facts[0,i] for i in range(facts.shape[1]) ]
FD = Faust(facts) FD = Faust(facts)
D = FD.todense() D = FD.toarray()
I = [125, 132, 1000, 155] I = [125, 132, 1000, 155]
for P, Pt in [(lambda x: D*x, lambda x: D.H*x), for P, Pt in [(lambda x: D@x, lambda x: D.T.conj()@x),
(lambda x: np.matrix(FD@x), (lambda x: np.matrix(FD@x),
lambda x: np.matrix(FD.H@x))]: lambda x: np.matrix(FD.H@x))]:
R = empty((0,0)) for s, R in enumerate([empty((0, 0)), csr_matrix((0, 0))]):
for i in range(1, len(I)+1): for i in range(1, len(I)+1):
R = UpdateCholeskyFull(R[0:i,0:i], P, Pt, I[:i], 8193) R = UpdateCholesky(R[:i,:i], P, Pt, I[:i], 8193)
print(R, end='\n\n') print("R:", R.shape, issparse(R))
assert(allclose(D[:,I[:min(i,len(I))]].H*D[:,I[:min(i,len(I))]], R.H*R)) print(R, end='\n\n')
assert(allclose((R.H*R).H, R.H*R)) if s and issparse(R):
assert(allclose(tril(R, -1), zeros(R.shape))) R = R.toarray()
print(D[:,I].H*D[:,I]) assert allclose(D[:,I[:min(i,len(I))]].T.conj()@D[:,I[:min(i,len(I))]],
print(R.H*R) R.T.conj()@R)
assert allclose((R.T.conj()@R).T.conj(), R.T.conj()@R)
assert allclose(tril(R, -1), zeros(R.shape))
print(D[:,I].T.conj()@D[:,I])
print(R.T.conj()*R)
if s:
R = csr_matrix(R)
...@@ -85,6 +85,8 @@ def omp(y, D, maxiter=None, tol=0, relerr=True, verbose=False): ...@@ -85,6 +85,8 @@ def omp(y, D, maxiter=None, tol=0, relerr=True, verbose=False):
tolerr = tol**2 tolerr = tol**2
# try if enough memory # try if enough memory
# TODO: maybe it should be done directly on empty()
# below
try: try:
R = zeros(maxiter) R = zeros(maxiter)
except: except:
...@@ -96,6 +98,8 @@ def omp(y, D, maxiter=None, tol=0, relerr=True, verbose=False): ...@@ -96,6 +98,8 @@ def omp(y, D, maxiter=None, tol=0, relerr=True, verbose=False):
residual = y residual = y
s = s_initial s = s_initial
R = empty((maxiter+1,maxiter+1)).astype(np.complex128) R = empty((maxiter+1,maxiter+1)).astype(np.complex128)
# TODO: why do we use complex, we should do in function of D, y
# likewise H should be T in real case
oldErr = y.T.conj()@y oldErr = y.T.conj()@y
# err_mse = [] # err_mse = []
...@@ -114,9 +118,9 @@ def omp(y, D, maxiter=None, tol=0, relerr=True, verbose=False): ...@@ -114,9 +118,9 @@ def omp(y, D, maxiter=None, tol=0, relerr=True, verbose=False):
I = argmax(abs(DR)) I = argmax(abs(DR))
IN += [I] IN += [I]
# update R # update R
R[0:r_count+1, 0:r_count+1] = UpdateCholeskyFull(R[0:r_count, R[0:r_count+1, 0:r_count+1] = UpdateCholesky(R[0:r_count,
0:r_count], P, Pt, 0:r_count], P, Pt,
IN, m) IN, m)
r_count+=1 r_count+=1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment