Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 681cd4b1 authored by hhakim's avatar hhakim
Browse files

Patch palm4msa_py with an option compute_2norm_on_arrays that avoids the nan...

Patch palm4msa_py with an option compute_2norm_on_arrays that avoids the nan norm issue 131 (it occurs for example when trying to factorize a hadamard matrix of order 128 with hierarchica_py).
parent f62b54b9
Branches
No related tags found
No related merge requests found
...@@ -329,16 +329,18 @@ def _check_fact_mat(funcname, M): ...@@ -329,16 +329,18 @@ def _check_fact_mat(funcname, M):
# experimental block start # experimental block start
def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False, def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False,
is_fact_side_left=False): is_fact_side_left=False, compute_2norm_on_arrays=False):
S = Faust([A]) S = Faust([A])
l2_ = 1 l2_ = 1
for i in range(J-1): for i in range(J-1):
print("hierarchical_py factor", i+1)
if(is_fact_side_left): if(is_fact_side_left):
Si = S.factors(0) Si = S.factors(0)
else: else:
Si = S.factors(i) Si = S.factors(i)
Si, l_ = palm4msa_py(Si, 2, N, [fac_proxs[i], res_proxs[i]], is_update_way_R2L, Si, l_ = palm4msa_py(Si, 2, N, [fac_proxs[i], res_proxs[i]], is_update_way_R2L,
S='zero_and_ids', _lambda=1) S='zero_and_ids', _lambda=1,
compute_2norm_on_arrays=compute_2norm_on_arrays)
l2_ *= l_ l2_ *= l_
if i > 1: if i > 1:
S = S.left(i-1)*Si S = S.left(i-1)*Si
...@@ -350,21 +352,28 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False, ...@@ -350,21 +352,28 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False,
S,l2_ = palm4msa_py(A, S.numfactors(), N, [p for p in S,l2_ = palm4msa_py(A, S.numfactors(), N, [p for p in
[*fac_proxs[0:i+1], [*fac_proxs[0:i+1],
res_proxs[i]]], res_proxs[i]]],
is_update_way_R2L, S=S, _lambda=l2_) is_update_way_R2L, S=S, _lambda=l2_,
compute_2norm_on_arrays=compute_2norm_on_arrays)
S = S*1/l2_ S = S*1/l2_
S = S*l2_ S = S*l2_
return S return S
def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1): from numpy.linalg import norm
def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1,
compute_2norm_on_arrays=False):
dims = [(prox.constraint._num_rows, prox.constraint._num_cols) for prox in dims = [(prox.constraint._num_rows, prox.constraint._num_cols) for prox in
proxs ] proxs ]
# start Faust, identity factors A_H = A.T.conj()
if(not isinstance(A_H, np.ndarray)):
A_H = A_H.toarray()
if(S == 'zero_and_ids'): if(S == 'zero_and_ids'):
# start Faust, identity factors and one zero
if(is_update_way_R2L): if(is_update_way_R2L):
S = Faust([np.eye(dims[i][0],dims[i][1]) for i in range(J-1)]+[np.zeros((dims[J-1][0], dims[J-1][1]))]) S = Faust([np.eye(dims[i][0],dims[i][1]) for i in range(J-1)]+[np.zeros((dims[J-1][0], dims[J-1][1]))])
else: else:
S = Faust([np.zeros((dims[0][0],dims[0][1]))]+[np.eye(dims[i+1][0], dims[i+1][1]) for i in range(J-1)]) S = Faust([np.zeros((dims[0][0],dims[0][1]))]+[np.eye(dims[i+1][0], dims[i+1][1]) for i in range(J-1)])
elif(S == None): elif(S == None):
# start Faust, identity factors
S = Faust([np.eye(dims[i][0], dims[i][1]) for i in range(J)]) S = Faust([np.eye(dims[i][0], dims[i][1]) for i in range(J)])
lipschitz_multiplicator=1.001 lipschitz_multiplicator=1.001
for i in range(N): for i in range(N):
...@@ -389,22 +398,30 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1): ...@@ -389,22 +398,30 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1):
S_j = S.factors(j) S_j = S.factors(j)
if(not pyfaust.isFaust(L)): L = Faust(L) if(not pyfaust.isFaust(L)): L = Faust(L)
if(not pyfaust.isFaust(R)): R = Faust(R) if(not pyfaust.isFaust(R)): R = Faust(R)
c = \ if(compute_2norm_on_arrays):
lipschitz_multiplicator*_lambda**2*R.norm(2,max_num_its=1000, c = \
treshold=1e-16)**2 * \ lipschitz_multiplicator*_lambda**2*norm(R.toarray(),2)**2 * \
L.norm(2,max_num_its=1000, treshold=1e-16)**2 norm(L.toarray(),2)**2
# if c == 0:
# c = 1/1e-16 else:
#print("j=",j, c = \
# (S_j-_lambda*(L.H*(_lambda*L*(S_j*R)-A)*R.H)*1/c).shape) lipschitz_multiplicator*_lambda**2*R.norm(2,max_num_its=1000,
S_j = proxs[j](S_j-_lambda*(L.H*(_lambda*L*(S_j*R)-A)*R.H)*1/c) treshold=1e-16)**2 * \
L.norm(2,max_num_its=1000, treshold=1e-16)**2
if(np.isnan(c) or c == 0):
raise Exception("Failed to compute c (inverse of descent step),"
"it could be because of the Faust 2-norm error,"
"try option compute_2norm_on_arrays=True")
S_j = \
proxs[j](S_j-_lambda*(L.H*(_lambda*L*(S_j*R)-A)*R.H)*1/c)
#csr_matrix(proxs[j](S_j-_lambda*(L.H*(_lambda*L*(S_j*R)-A)*R.H)*1/c))
if(S.numfactors() > 2 and j > 0 and j < S.numfactors()-1): if(S.numfactors() > 2 and j > 0 and j < S.numfactors()-1):
S = L*Faust(S_j)*R S = L*Faust(S_j)*R
elif(j == 0): elif(j == 0):
S = Faust(S_j)*R S = Faust(S_j)*R
else: else:
S = L*Faust(S_j) S = L*Faust(S_j)
_lambda = np.trace(A.T.conj()*S).real/S.norm()**2 _lambda = np.trace(A_H*S).real/S.norm()**2
#print("lambda:", _lambda) #print("lambda:", _lambda)
S = _lambda*S S = _lambda*S
return S, _lambda return S, _lambda
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment