Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 7c64228e authored by hhakim's avatar hhakim
Browse files

Implement pyfaust.fact.hierarchical_py (pure py version).

parent f29315d2
Branches
Tags
No related merge requests found
...@@ -174,21 +174,20 @@ def svdtj(M, nGivens=None, tol=0, order='ascend', relerr=True, ...@@ -174,21 +174,20 @@ def svdtj(M, nGivens=None, tol=0, order='ascend', relerr=True,
def eigtj(M, nGivens=None, tol=0, order='ascend', relerr=True, def eigtj(M, nGivens=None, tol=0, order='ascend', relerr=True,
nGivens_per_fac=None, verbosity=0, enable_large_Faust=False): nGivens_per_fac=None, verbosity=0, enable_large_Faust=False):
""" """
Performs an approximate eigendecomposition of M and returns the eigenvalues in W along with the corresponding left eigenvectors (as the columns of the Faust object V). Performs an approximate eigendecomposition of M and returns the eigenvalues
in W along with the corresponding right eigenvectors (as the columns of the Faust object V).
The output is such that V*numpy.diag(W)*V.H approximates M. V is a product The output is such that V*numpy.diag(W)*V.H approximates M. V is a product
of Givens rotations obtained by truncating the Jacobi algorithm. of Givens rotations obtained by truncating the Jacobi algorithm.
The trade-off between accuracy and sparsity can be set through the The trade-off between accuracy and complexity of V can be set through the
parameters nGivens and nGivens_per_fac or concurrently with the arguments parameters nGivens and tol that define number of Givens rotations and targeted error.
tol and relerr that define the targeted error.
Args: Args:
M: (numpy.ndarray or csr_matrix) the matrix to diagonalize. Must be M: (numpy.ndarray or csr_matrix) the matrix to diagonalize. Must be
real and symmetric, or complex hermitian. Can be in dense or sparse format. real and symmetric, or complex hermitian. Can be in dense or sparse format.
nGivens: (int) the number of Givens rotations that can be computed in nGivens: (int) targeted number of Givens (this argument is optional
eigenvector transform V. only if tol is set).
The number of rotations per factor of V is defined by nGivens_per_fac.
tol: (float) the tolerance error at which the algorithm stops. By default, tol: (float) the tolerance error at which the algorithm stops. By default,
it's zero for not stopping on an error criterion. Note that the error it's zero for not stopping on an error criterion. Note that the error
reaching is not guaranteed (in particular, if the error starts to reaching is not guaranteed (in particular, if the error starts to
...@@ -312,15 +311,51 @@ def _check_fact_mat(funcname, M): ...@@ -312,15 +311,51 @@ def _check_fact_mat(funcname, M):
# experimental block start # experimental block start
def palm4msa_py(A, J, N, proxs): def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False,
is_fact_side_left=False):
S = Faust([A])
l2_ = 1
for i in range(J-1):
if(is_fact_side_left):
Si = S.factors(0)
else:
Si = S.factors(i)
Si, l_ = palm4msa_py(Si, 2, N, [fac_proxs[i], res_proxs[i]], is_update_way_R2L,
S='zero_and_ids', _lambda=1)
l2_ *= l_
if i > 1:
S = S.left(i-1)*Si
elif i > 0:
S = Faust(S.left(i-1))*Si
else: # i == 0
S = Si
S = S*1/l_
S,l2_ = palm4msa_py(A, S.numfactors(), N, [p for p in
[*fac_proxs[0:i+1],
res_proxs[i]]],
is_update_way_R2L, S=S, _lambda=l2_)
S = S*1/l2_
S = S*l2_
return S
def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1):
dims = [(prox.constraint._num_rows, prox.constraint._num_cols) for prox in dims = [(prox.constraint._num_rows, prox.constraint._num_cols) for prox in
proxs ] proxs ]
# start Faust, identity factors # start Faust, identity factors
S = Faust([np.eye(dims[i][0], dims[i][1]) for i in range(J)]) if(S == 'zero_and_ids'):
_lambda = 1 if(is_update_way_R2L):
S = Faust([np.eye(dims[i][0],dims[i][1]) for i in range(J-1)]+[np.zeros((dims[J-1][0], dims[J-1][1]))])
else:
S = Faust([np.zeros((dims[0][0],dims[0][1]))]+[np.eye(dims[i+1][0], dims[i+1][1]) for i in range(J-1)])
elif(S == None):
S = Faust([np.eye(dims[i][0], dims[i][1]) for i in range(J)])
lipschitz_multiplicator=1.001 lipschitz_multiplicator=1.001
for i in range(N): for i in range(N):
for j in range(0,J): if(is_update_way_R2L):
iter_ = reversed(range(J))
else:
iter_ = range(J)
for j in iter_:
#print("S=", S) #print("S=", S)
#if(2 == S.numfactors()): #if(2 == S.numfactors()):
if(j == 0): if(j == 0):
...@@ -341,6 +376,8 @@ def palm4msa_py(A, J, N, proxs): ...@@ -341,6 +376,8 @@ def palm4msa_py(A, J, N, proxs):
lipschitz_multiplicator*_lambda**2*R.norm(2,max_num_its=1000, lipschitz_multiplicator*_lambda**2*R.norm(2,max_num_its=1000,
treshold=1e-16)**2 * \ treshold=1e-16)**2 * \
L.norm(2,max_num_its=1000, treshold=1e-16)**2 L.norm(2,max_num_its=1000, treshold=1e-16)**2
# if c == 0:
# c = 1/1e-16
#print("j=",j, #print("j=",j,
# (S_j-_lambda*(L.H*(_lambda*L*(S_j*R)-A)*R.H)*1/c).shape) # (S_j-_lambda*(L.H*(_lambda*L*(S_j*R)-A)*R.H)*1/c).shape)
S_j = proxs[j](S_j-_lambda*(L.H*(_lambda*L*(S_j*R)-A)*R.H)*1/c) S_j = proxs[j](S_j-_lambda*(L.H*(_lambda*L*(S_j*R)-A)*R.H)*1/c)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment