Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 33b0445d authored by hhakim's avatar hhakim
Browse files

Fix experimental pyfaust hierarchical_py/palm4msa_py to handle the...

Fix experimental pyfaust hierarchical_py/palm4msa_py to handle the is_fact_side_left=True (typically for the MEG matrix) that wasn't managed before.
parent 864074a6
No related branches found
No related tags found
No related merge requests found
...@@ -360,9 +360,11 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False, ...@@ -360,9 +360,11 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False,
print("hierarchical_py factor", i+1) print("hierarchical_py factor", i+1)
if(is_fact_side_left): if(is_fact_side_left):
Si = S.factors(0) Si = S.factors(0)
split_proxs = [res_proxs[i], fac_proxs[i]]
else: else:
Si = S.factors(i) Si = S.factors(i)
Si, l_ = palm4msa_py(Si, 2, N, [fac_proxs[i], res_proxs[i]], is_update_way_R2L, split_proxs = [fac_proxs[i], res_proxs[i]]
Si, l_ = palm4msa_py(Si, 2, N, split_proxs, is_update_way_R2L,
S='zero_and_ids', _lambda=1, S='zero_and_ids', _lambda=1,
compute_2norm_on_arrays=compute_2norm_on_arrays_, compute_2norm_on_arrays=compute_2norm_on_arrays_,
norm2_max_iter=norm2_max_iter, norm2_max_iter=norm2_max_iter,
...@@ -370,15 +372,26 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False, ...@@ -370,15 +372,26 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False,
dev=dev) dev=dev)
l2_ *= l_ l2_ *= l_
if i > 1: if i > 1:
S = S.left(i-1)@Si if is_fact_side_left:
S = Si@S.right(1)
else:
S = S.left(i-1)@Si
elif i > 0: elif i > 0:
S = Faust(S.left(i-1), dev=dev)@Si if is_fact_side_left:
S = Si@Faust(S.right(1), dev=dev)
else:
S = Faust(S.left(i-1), dev=dev)@Si
else: # i == 0 else: # i == 0
S = Si S = Si
S = S*1/l_ S = S*1/l_
S,l2_ = palm4msa_py(A, S.numfactors(), N, [p for p in if is_fact_side_left:
[*fac_proxs[0:i+1], fp = [*fac_proxs[0:i+1]]
res_proxs[i]]], fp = list(reversed(fp))
n_proxs = [res_proxs[i], *fp]
else:
n_proxs = [*fac_proxs[0:i+1],
res_proxs[i]]
S,l2_ = palm4msa_py(A, S.numfactors(), N, n_proxs,
is_update_way_R2L, S=S, _lambda=l2_, is_update_way_R2L, S=S, _lambda=l2_,
compute_2norm_on_arrays=compute_2norm_on_arrays_, compute_2norm_on_arrays=compute_2norm_on_arrays_,
norm2_max_iter=norm2_max_iter, norm2_max_iter=norm2_max_iter,
...@@ -422,13 +435,13 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1, ...@@ -422,13 +435,13 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1,
#print("S=", S) #print("S=", S)
#if(2 == S.numfactors()): #if(2 == S.numfactors()):
if(j == 0): if(j == 0):
L = np.eye(dims[0][0],dims[0][0], dtype=A.dtype)
S_j = S.factors(j) S_j = S.factors(j)
R = S.right(j+1) R = S.right(j+1)
L = np.eye(S_j.shape[0], S_j.shape[0], dtype=A.dtype)
elif(j == S.numfactors()-1): elif(j == S.numfactors()-1):
L = S.left(j-1)
S_j = S.factors(j) S_j = S.factors(j)
R = np.eye(dims[j][1], dims[j][1], dtype=A.dtype) R = np.eye(S_j.shape[1], S_j.shape[1], dtype=A.dtype)
L = S.left(j-1)
else: else:
L = S.left(j-1) L = S.left(j-1)
R = S.right(j+1) R = S.right(j+1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment