Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 4f8dc333 authored by hhakim's avatar hhakim
Browse files

Add dev argument to experimental 'pure' python hierarchical_py/palm4msa_py to...

Add dev argument to experimental 'pure' python hierarchical_py/palm4msa_py to possibily use gpu wrapper (not tested).
parent 0bff26ba
No related branches found
No related tags found
No related merge requests found
Pipeline #833779 skipped
...@@ -295,8 +295,14 @@ def _check_fact_mat(funcname, M): ...@@ -295,8 +295,14 @@ def _check_fact_mat(funcname, M):
# experimental block start # experimental block start
def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False, def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False,
is_fact_side_left=False, compute_2norm_on_arrays=False, is_fact_side_left=False, compute_2norm_on_arrays=False,
norm2_max_iter=100, norm2_threshold=1e-6, use_csr=True): norm2_max_iter=100, norm2_threshold=1e-6, use_csr=True,
S = Faust([A]) dev='cpu'):
"""
Args:
J: number of factors.
N: number of iterations.
"""
S = Faust([A], dev=dev)
l2_ = 1 l2_ = 1
compute_2norm_on_arrays_ = compute_2norm_on_arrays compute_2norm_on_arrays_ = compute_2norm_on_arrays
for i in range(J-1): for i in range(J-1):
...@@ -311,22 +317,24 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False, ...@@ -311,22 +317,24 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False,
S='zero_and_ids', _lambda=1, S='zero_and_ids', _lambda=1,
compute_2norm_on_arrays=compute_2norm_on_arrays_, compute_2norm_on_arrays=compute_2norm_on_arrays_,
norm2_max_iter=norm2_max_iter, norm2_max_iter=norm2_max_iter,
norm2_threshold=norm2_threshold, use_csr=use_csr) norm2_threshold=norm2_threshold, use_csr=use_csr,
dev=dev)
l2_ *= l_ l2_ *= l_
if i > 1: if i > 1:
S = S.left(i-1)*Si S = S.left(i-1)@Si
elif i > 0: elif i > 0:
S = Faust(S.left(i-1))*Si S = Faust(S.left(i-1), dev=dev)@Si
else: # i == 0 else: # i == 0
S = Si S = Si
S = S*1/l_ S = S*1/l_
S,l2_ = palm4msa_py(A, S.numfactors(), N, [p for p in S,l2_ = palm4msa_py(A, S.numfactors(), N, [p for p in
[*fac_proxs[0:i+1], [*fac_proxs[0:i+1],
res_proxs[i]]], res_proxs[i]]],
is_update_way_R2L, S=S, _lambda=l2_, is_update_way_R2L, S=S, _lambda=l2_,
compute_2norm_on_arrays=compute_2norm_on_arrays_, compute_2norm_on_arrays=compute_2norm_on_arrays_,
norm2_max_iter=norm2_max_iter, norm2_max_iter=norm2_max_iter,
norm2_threshold=norm2_threshold, use_csr=use_csr) norm2_threshold=norm2_threshold, use_csr=use_csr,
dev=dev)
S = S*1/l2_ S = S*1/l2_
S = S*l2_ S = S*l2_
return S return S
...@@ -334,7 +342,7 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False, ...@@ -334,7 +342,7 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False,
from numpy.linalg import norm from numpy.linalg import norm
def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1, def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1,
compute_2norm_on_arrays=False, norm2_max_iter=100, compute_2norm_on_arrays=False, norm2_max_iter=100,
norm2_threshold=1e-6, use_csr=True): norm2_threshold=1e-6, use_csr=True, dev='cpu'):
dims = [(prox.constraint._num_rows, prox.constraint._num_cols) for prox in dims = [(prox.constraint._num_rows, prox.constraint._num_cols) for prox in
proxs ] proxs ]
A_H = A.T.conj() A_H = A.T.conj()
...@@ -343,12 +351,17 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1, ...@@ -343,12 +351,17 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1,
if(S == 'zero_and_ids'): if(S == 'zero_and_ids'):
# start Faust, identity factors and one zero # start Faust, identity factors and one zero
if(is_update_way_R2L): if(is_update_way_R2L):
S = Faust([np.eye(dims[i][0],dims[i][1]) for i in range(J-1)]+[np.zeros((dims[J-1][0], dims[J-1][1]))]) S = Faust([np.eye(dims[i][0],dims[i][1]) for i in
range(J-1)]+[np.zeros((dims[J-1][0], dims[J-1][1]))],
dev=dev)
else: else:
S = Faust([np.zeros((dims[0][0],dims[0][1]))]+[np.eye(dims[i+1][0], dims[i+1][1]) for i in range(J-1)]) S = Faust([np.zeros((dims[0][0],dims[0][1]))]+[np.eye(dims[i+1][0],
dims[i+1][1])
for i in
range(J-1)], dev=dev)
elif(S == None): elif(S == None):
# start Faust, identity factors # start Faust, identity factors
S = Faust([np.eye(dims[i][0], dims[i][1]) for i in range(J)]) S = Faust([np.eye(dims[i][0], dims[i][1]) for i in range(J)], dev=dev)
lipschitz_multiplicator=1.001 lipschitz_multiplicator=1.001
for i in range(N): for i in range(N):
if(is_update_way_R2L): if(is_update_way_R2L):
...@@ -370,8 +383,8 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1, ...@@ -370,8 +383,8 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1,
L = S.left(j-1) L = S.left(j-1)
R = S.right(j+1) R = S.right(j+1)
S_j = S.factors(j) S_j = S.factors(j)
if(not pyfaust.isFaust(L)): L = Faust(L) if(not pyfaust.isFaust(L)): L = Faust(L, dev=dev)
if(not pyfaust.isFaust(R)): R = Faust(R) if(not pyfaust.isFaust(R)): R = Faust(R, dev=dev)
if(compute_2norm_on_arrays): if(compute_2norm_on_arrays):
c = \ c = \
lipschitz_multiplicator*_lambda**2*norm(R.toarray(),2)**2 * \ lipschitz_multiplicator*_lambda**2*norm(R.toarray(),2)**2 * \
...@@ -389,19 +402,19 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1, ...@@ -389,19 +402,19 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1,
if(not isinstance(S_j, np.ndarray)): # equiv. to use_csr except if(not isinstance(S_j, np.ndarray)): # equiv. to use_csr except
# maybe for the first iteration # maybe for the first iteration
S_j = S_j.toarray() S_j = S_j.toarray()
D = S_j-_lambda*(L.H*(_lambda*L*(S_j*R)-A)*R.H)*1/c D = S_j-_lambda*(L.H@(_lambda*L@(S_j@R)-A)@R.H)*1/c
if(not isinstance(D, np.ndarray)): if(not isinstance(D, np.ndarray)):
D = D.toarray() D = D.toarray()
S_j = proxs[j](D) S_j = proxs[j](D)
if(use_csr): if(use_csr):
S_j = csr_matrix(S_j) S_j = csr_matrix(S_j)
if(S.numfactors() > 2 and j > 0 and j < S.numfactors()-1): if(S.numfactors() > 2 and j > 0 and j < S.numfactors()-1):
S = L*Faust(S_j)*R S = L@Faust(S_j, dev=dev)@R
elif(j == 0): elif(j == 0):
S = Faust(S_j)*R S = Faust(S_j, dev=dev)@R
else: else:
S = L*Faust(S_j) S = L@Faust(S_j, dev=dev)
_lambda = np.trace(A_H*S).real/S.norm()**2 _lambda = np.trace(A_H@S).real/S.norm()**2
#print("lambda:", _lambda) #print("lambda:", _lambda)
S = _lambda*S S = _lambda*S
return S, _lambda return S, _lambda
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment