Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 0a1627c4 authored by hhakim's avatar hhakim
Browse files

Update butterfly matfaust and pyfaust wrappers to handle the bbtree kind of...

Update butterfly matfaust and pyfaust wrappers to handle the bbtree kind of factorization + documentation.

Minor changes/fixes in pyfaust.fact.hierarchical_py too.
parent bb5d2640
Branches
Tags
No related merge requests found
%========================================================================== %==========================================================================
%> @brief Factorizes the matrix M according to a butterfly support. %> @brief Factorizes the matrix M according to a butterfly support.
%>
%> @param 'dir', str: the direction of factorization 'right'ward or 'left'ward %> @param 'type', str: the type of factorization 'right'ward, 'left'ward or 'bbtree'.
%> (more precisely: at each stage of the factorization the most right factor or %> More precisely: if 'left' (resp. 'right') is used then at each stage of the
%> the most left factor is split in two). %> factorization the most left factor (resp. the most left factor) is split in two.
%> If 'bbtree' is used then the matrix is factorized according to a Balanced
%> Binary Tree (which is faster as it allows parallelization).
%> %>
%> @retval F the Faust which is an approximate of M according to a butterfly support. %> @retval F the Faust which is an approximate of M according to a butterfly support.
%>
%> @b Example:
%> @code
%> >> import matfaust.wht
%> >> import matfaust.dft
%> >> import matfaust.fact.butterfly
%> >> M = full(wht(32)); % it works with dft too!
%> >> F = butterfly(M, 'type', 'bbtree');
%> >> err = norm(full(F)-M)/norm(M)
%> err =
%>
%> 1.4311e-15
%> @endcode
%========================================================================== %==========================================================================
function F = butterfly(M, varargin) function F = butterfly(M, varargin)
import matfaust.Faust import matfaust.Faust
nargin = length(varargin); nargin = length(varargin);
dir = 'right'; type = 'right';
if(nargin > 0) if(nargin > 0)
for i=1:nargin for i=1:nargin
switch(varargin{i}) switch(varargin{i})
case 'dir' case 'type'
if(nargin < i+1 || ~ any(strcmp(varargin{i+1}, {'right', 'left'}))) if(nargin < i+1 || ~ any(strcmp(varargin{i+1}, {'right', 'left', 'bbtree'})))
error('keyword argument ''dir'' must be followed by ''left'' or ''right''') error('keyword argument ''type'' must be followed by ''left'' or ''right'' or ''bbtree''')
else else
dir = varargin{i+1}; type = varargin{i+1};
end end
end end
end end
end end
if(strcmp(dir, 'right')) if(strcmp(type, 'right'))
dir = 1; type = 1;
elseif(strcmp(dir, 'left')) elseif(strcmp(type, 'left'))
dir = 0; type = 0;
elseif(strcmp(type, 'bbtree'))
type = 2;
end end
if(strcmp(class(M), 'single')) if(strcmp(class(M), 'single'))
core_obj = mexButterflyRealFloat(M, dir); core_obj = mexButterflyRealFloat(M, type);
F = Faust(core_obj, isreal(M), 'cpu', 'float'); F = Faust(core_obj, isreal(M), 'cpu', 'float');
else else
if(isreal(M)) if(isreal(M))
core_obj = mexButterflyReal(M, dir); core_obj = mexButterflyReal(M, type);
else else
core_obj = mexButterflyCplx(M, dir); core_obj = mexButterflyCplx(M, type);
end end
F = Faust(core_obj, isreal(M)); F = Faust(core_obj, isreal(M));
end end
......
...@@ -351,7 +351,7 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False, ...@@ -351,7 +351,7 @@ def hierarchical_py(A, J, N, res_proxs, fac_proxs, is_update_way_R2L=False,
J: number of factors. J: number of factors.
N: number of iterations. N: number of iterations.
""" """
S = Faust([A], dev=dev) S = Faust([A], dev=dev, dtype=A.dtype)
l2_ = 1 l2_ = 1
compute_2norm_on_arrays_ = compute_2norm_on_arrays compute_2norm_on_arrays_ = compute_2norm_on_arrays
for i in range(J-1): for i in range(J-1):
...@@ -400,17 +400,18 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1, ...@@ -400,17 +400,18 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1,
if(S == 'zero_and_ids'): if(S == 'zero_and_ids'):
# start Faust, identity factors and one zero # start Faust, identity factors and one zero
if(is_update_way_R2L): if(is_update_way_R2L):
S = Faust([np.eye(dims[i][0],dims[i][1]) for i in S = Faust([np.eye(dims[i][0],dims[i][1], dtype=A.dtype) for i in
range(J-1)]+[np.zeros((dims[J-1][0], dims[J-1][1]))], range(J-1)]+[np.zeros((dims[J-1][0], dims[J-1][1]), dtype=A.dtype)],
dev=dev) dev=dev)
else: else:
S = Faust([np.zeros((dims[0][0],dims[0][1]))]+[np.eye(dims[i+1][0], S = Faust([np.zeros((dims[0][0],dims[0][1]), dtype=A.dtype)]+[np.eye(dims[i+1][0],
dims[i+1][1]) dims[i+1][1], dtype=A.dtype)
for i in for i in
range(J-1)], dev=dev) range(J-1)],
dev=dev)
elif(S == None): elif(S == None):
# start Faust, identity factors # start Faust, identity factors
S = Faust([np.eye(dims[i][0], dims[i][1]) for i in range(J)], dev=dev) S = Faust([np.eye(dims[i][0], dims[i][1], dtype=A.dtype) for i in range(J)], dev=dev)
lipschitz_multiplicator=1.001 lipschitz_multiplicator=1.001
for i in range(N): for i in range(N):
if(is_update_way_R2L): if(is_update_way_R2L):
...@@ -421,19 +422,19 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1, ...@@ -421,19 +422,19 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1,
#print("S=", S) #print("S=", S)
#if(2 == S.numfactors()): #if(2 == S.numfactors()):
if(j == 0): if(j == 0):
L = np.eye(dims[0][0],dims[0][0]) L = np.eye(dims[0][0],dims[0][0], dtype=A.dtype)
S_j = S.factors(j) S_j = S.factors(j)
R = S.right(j+1) R = S.right(j+1)
elif(j == S.numfactors()-1): elif(j == S.numfactors()-1):
L = S.left(j-1) L = S.left(j-1)
S_j = S.factors(j) S_j = S.factors(j)
R = np.eye(dims[j][1], dims[j][1]) R = np.eye(dims[j][1], dims[j][1], dtype=A.dtype)
else: else:
L = S.left(j-1) L = S.left(j-1)
R = S.right(j+1) R = S.right(j+1)
S_j = S.factors(j) S_j = S.factors(j)
if(not pyfaust.isFaust(L)): L = Faust(L, dev=dev) if(not pyfaust.isFaust(L)): L = Faust(L, dev=dev, dtype=A.dtype)
if(not pyfaust.isFaust(R)): R = Faust(R, dev=dev) if(not pyfaust.isFaust(R)): R = Faust(R, dev=dev, dtype=A.dtype)
if(compute_2norm_on_arrays): if(compute_2norm_on_arrays):
c = \ c = \
lipschitz_multiplicator*_lambda**2*norm(R.toarray(),2)**2 * \ lipschitz_multiplicator*_lambda**2*norm(R.toarray(),2)**2 * \
...@@ -458,11 +459,11 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1, ...@@ -458,11 +459,11 @@ def palm4msa_py(A, J, N, proxs, is_update_way_R2L=False, S=None, _lambda=1,
if(use_csr): if(use_csr):
S_j = csr_matrix(S_j) S_j = csr_matrix(S_j)
if(S.numfactors() > 2 and j > 0 and j < S.numfactors()-1): if(S.numfactors() > 2 and j > 0 and j < S.numfactors()-1):
S = L@Faust(S_j, dev=dev)@R S = L@Faust(S_j, dev=dev, dtype=A.dtype)@R
elif(j == 0): elif(j == 0):
S = Faust(S_j, dev=dev)@R S = Faust(S_j, dev=dev, dtype=A.dtype)@R
else: else:
S = L@Faust(S_j, dev=dev) S = L@Faust(S_j, dev=dev, dtype=A.dtype)
_lambda = np.trace(A_H@S).real/S.norm()**2 _lambda = np.trace(A_H@S).real/S.norm()**2
#print("lambda:", _lambda) #print("lambda:", _lambda)
S = _lambda*S S = _lambda*S
...@@ -1186,27 +1187,38 @@ def fgft_palm(U, Lap, p, init_D=None, ret_lambda=False, ret_params=False): ...@@ -1186,27 +1187,38 @@ def fgft_palm(U, Lap, p, init_D=None, ret_lambda=False, ret_params=False):
# experimental block end # experimental block end
def butterfly(M, dir="right"): def butterfly(M, type="right"):
""" """
Factorizes M according to a butterfly support. Factorizes M according to a butterfly support.
Args: Args:
M: the numpy ndarray. The dtype must be float32, float64 M: the numpy ndarray. The dtype must be float32, float64
or complex128 (the dtype might have a large impact on performance). or complex128 (the dtype might have a large impact on performance).
dir: (str) the direction of factorization 'right'ward or 'left'ward type: (str) the type of factorization 'right'ward, 'left'ward or
(more precisely: at each stage of the factorization the most right factor or 'bbtree'. More precisely: if 'left' (resp. 'right') is used then at each stage of the
the most left factor is split in two). factorization the most left factor (resp. the most left factor) is split in two.
If 'bbtree' is used then the matrix is factorized according to a Balanced
Binary Tree (which is faster as it allows parallelization).
Returns: Returns:
The Faust which is an approximate of M according to a butterfly support. The Faust which is an approximate of M according to a butterfly support.
Example:
>>> from pyfaust import Faust, wht, dft
>>> from pyfaust.fact import butterfly
>>> H = wht(32).toarray() # it works with dft too!
>>> F = butterfly(H, dir='bbtree')
>>> (F-M).norm()/Faust(M).norm()
1.0272844187006565e-15
""" """
is_real = np.empty((1,)) is_real = np.empty((1,))
M = _check_fact_mat('butterfly()', M, is_real) M = _check_fact_mat('butterfly()', M, is_real)
if is_real: if is_real:
is_float = M.dtype == 'float32' is_float = M.dtype == 'float32'
if is_float: if is_float:
return Faust(core_obj=_FaustCorePy.FaustAlgoGenFlt.butterfly_hierarchical(M, dir)) return Faust(core_obj=_FaustCorePy.FaustAlgoGenFlt.butterfly_hierarchical(M, type))
else: else:
return Faust(core_obj=_FaustCorePy.FaustAlgoGenDbl.butterfly_hierarchical(M, dir)) return Faust(core_obj=_FaustCorePy.FaustAlgoGenDbl.butterfly_hierarchical(M, type))
else: else:
return Faust(core_obj=_FaustCorePy.FaustAlgoGenCplxDbl.butterfly_hierarchical(M, dir)) return Faust(core_obj=_FaustCorePy.FaustAlgoGenCplxDbl.butterfly_hierarchical(M, type))
...@@ -567,6 +567,8 @@ cdef class FaustAlgoGen@TYPE_NAME@: ...@@ -567,6 +567,8 @@ cdef class FaustAlgoGen@TYPE_NAME@:
dir = 1 dir = 1
elif dir == "left": elif dir == "left":
dir = 0 dir = 0
elif dir == "bbtree":
dir = 2
else: else:
raise ValueError("dir argument must be 'right' or 'left'.") raise ValueError("dir argument must be 'right' or 'left'.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment