Mentions légales du service

Skip to content
Snippets Groups Projects
Commit f2d31429 authored by Nicolas Bellot's avatar Nicolas Bellot Committed by hhakim
Browse files

legere modification wrapper matlab et demo

parent b6d4e003
Branches
Tags
No related merge requests found
......@@ -81,12 +81,16 @@ X = X_fixed;
%X_norm = X;
X_norm = X./repmat(sqrt(sum(X.^2,1)),size(X,1),1);
%Loading of the MEG matrix approximations
%% Loading of the MEG matrix approximations
% wrapper C++ faust
MEG_faustS=cell(1,nb_approx_MEG);
faustS_mult=cell(1,nb_approx_MEG);
trans_faustS_mult=cell(1,nb_approx_MEG);
% matlab faust (no C++ Faust Toolbox code is running)
matlab_faustS_mult=cell(1,nb_approx_MEG);
matlab_trans_faustS_mult=cell(1,nb_approx_MEG);
for i=1:nb_approx_MEG
......@@ -105,23 +109,37 @@ for i=1:nb_approx_MEG
sp_facts = faust_transpose(sp_facts_trans);
trans_fc=matlab_faust(facts);
fc=transpose(trans_fc);
matlab_trans_faustS_mult{i}=@(x) f_mult(sp_facts_trans,x);
matlab_faustS_mult{i}=@(x) f_mult(sp_facts,x);
trans_faustS_mult{i}=@(x) trans_fc*x;
faustS_mult{i}=@(x) fc*x;
MEG_faustS{i}=trans_fc;
% matlab faust (no C++ Faust Toolbox code is running)
matlab_trans_faustS_mult{i}=@(x) f_mult(sp_facts_trans,x); % function handle
matlab_faustS_mult{i}=@(x) f_mult(sp_facts,x); % function handle
% wrapper C++ faust
trans_faustS_mult{i}=@(x) trans_fc*x;% function handle
faustS_mult{i}=@(x) fc*x; % function handle
MEG_faustS{i}=fc; % store the different faust approximations
end
M=size(X_norm,2);
params.Ntraining = 20; % Number of training vectors
params.Sparsity = 2; % Number of sources per training vector
params.dist_paliers = [0.01,0.05,0.08,0.5];
Ntraining = params.Ntraining;
Sparsity = params.Sparsity;
dist_paliers = params.dist_paliers;
solver_choice='omp';
%solver_choice='iht';
Ntraining = 500; % Number of training vectors
Sparsity = 2; % Number of sources per training vector
dist_paliers = [0.01,0.05,0.08]; dist_paliers = [dist_paliers, 0.5];
resDist = zeros(nb_approx_MEG+1,numel(dist_paliers)-1,Sparsity,Ntraining); % (Matrice,m�thode,dist_sources,src_nb,run);
compute_Times = zeros(nb_approx_MEG+1,numel(dist_paliers)-1,Ntraining);
resDist_matlab = zeros(nb_approx_MEG+1,numel(dist_paliers)-1,Sparsity,Ntraining); % (Matrice,m�thode,dist_sources,src_nb,run);
compute_Times_matlab = zeros(nb_approx_MEG+1,numel(dist_paliers)-1,Ntraining);
for k=1:numel(dist_paliers)-1
disp(['k=' num2str(k) '/' num2str(numel(dist_paliers)-1)])
%Parameters settings
......@@ -139,16 +157,16 @@ for k=1:numel(dist_paliers)-1
sol_omp = zeros(size(Gamma));
sol_omp_hat = zeros(size(Gamma));
sol_omp_hat2 = zeros(size(Gamma));
err_omp = zeros(2,Ntraining);
err_omp_hat = zeros(2,Ntraining);
err_omp_hat2 = zeros(2,Ntraining);
diff_omp = zeros(2,Ntraining);
dist_omp = zeros(Sparsity,Ntraining);
dist_omp_hat = zeros(Sparsity,Ntraining);
dist_omp_hat2 = zeros(Sparsity,Ntraining);
sol_solver = zeros(size(Gamma));
sol_solver_hat = zeros(size(Gamma));
sol_solver_hat2 = zeros(size(Gamma));
err_solver = zeros(2,Ntraining);
err_solver_hat = zeros(2,Ntraining);
err_solver_hat2 = zeros(2,Ntraining);
diff_solver = zeros(2,Ntraining);
dist_solver = zeros(Sparsity,Ntraining);
dist_solver_hat = zeros(Sparsity,Ntraining);
dist_solver_hat2 = zeros(Sparsity,Ntraining);
......@@ -161,40 +179,78 @@ for k=1:numel(dist_paliers)-1
%OMP solving
tic
[sol_omp(:,i), err_mse_omp, iter_time_omp]=greed_omp_chol(Data(:,i),X_norm,size(X_norm,2),'stopTol',1*Sparsity);
t1=toc;
err_omp(1,i) = norm(X_norm*Gamma(:,i)-X_norm*sol_omp(:,i))/norm(X_norm*Gamma(:,i));
err_omp(2,i) = isequal(find(Gamma(:,i)),find(sol_omp(:,i)>1e-4));
idx_omp = find(sol_omp(:,i));
resDist(1,k,1,i) = min(norm(points(idx(1)) - points(idx_omp(1))),norm(points(idx(1)) - points(idx_omp(2))));
resDist(1,k,2,i) = min(norm(points(idx(2)) - points(idx_omp(1))),norm(points(idx(2)) - points(idx_omp(2))));
if strcmp(solver_choice,'omp')
%OMP
tic
[sol_solver(:,i), err_mse_solver, iter_time_solver]=greed_omp_chol(Data(:,i),X_norm,M,'stopTol',1*Sparsity);
t1=toc;
elseif strcmp(solver_choice,'iht')
%IHT
tic
[sol_solver(:,i), err_mse_solver, iter_time_solver]= hard_l0_Mterm(Data(:,i),X_norm,M,1*Sparsity,'verbose',false,'maxIter',1000);
t1=toc;
else
error('invalid solver choice: must be omp or iht');
end
err_solver(1,i) = norm(X_norm*Gamma(:,i)-X_norm*sol_solver(:,i))/norm(X_norm*Gamma(:,i));
err_solver(2,i) = isequal(find(Gamma(:,i)),find(sol_solver(:,i)>1e-4));
idx_solver = find(sol_solver(:,i));
resDist(1,k,1,i) = min(norm(points(idx(1)) - points(idx_solver(1))),norm(points(idx(1)) - points(idx_solver(2))));
resDist(1,k,2,i) = min(norm(points(idx(2)) - points(idx_solver(1))),norm(points(idx(2)) - points(idx_solver(2))));
compute_Times(1,k,i)=t1;
resDist_matlab(1,k,1,i) = min(norm(points(idx(1)) - points(idx_omp(1))),norm(points(idx(1)) - points(idx_omp(2))));
resDist_matlab(1,k,2,i) = min(norm(points(idx(2)) - points(idx_omp(1))),norm(points(idx(2)) - points(idx_omp(2))));
resDist_matlab(1,k,1,i) = min(norm(points(idx(1)) - points(idx_solver(1))),norm(points(idx(1)) - points(idx_solver(2))));
resDist_matlab(1,k,2,i) = min(norm(points(idx(2)) - points(idx_solver(1))),norm(points(idx(2)) - points(idx_solver(2))));
compute_Times_matlab(1,k,i)=t1;
for ll=1:nb_approx_MEG
X_approx_norm = MEG_approxS_norm{ll};
tic
[sol_omp_hat(:,i), err_mse_omp_hat, iter_time_omp_hat]=greed_omp_chol(Data(:,i),faustS_mult{ll},size(X_approx_norm,2),'stopTol',1*Sparsity,'P_trans',trans_faustS_mult{ll});
%% objet faust
if strcmp(solver_choice,'omp')
%OMP
tic
[sol_solver_hat(:,i), err_mse_solver_hat, iter_time_solver_hat]=greed_omp_chol(Data(:,i),MEG_faustS{ll},M,'stopTol',1*Sparsity);
t1=toc;
elseif strcmp(solver_choice,'iht')
%IHT
tic
[sol_solver_hat(:,i), err_mse_solver, iter_time_solver]=hard_l0_Mterm(Data(:,i),MEG_faustS{ll},M,1*Sparsity,'verbose',false,'maxIter',1000);
t1=toc;
else
error('invalid solver choice: must be omp or iht');
end
t1=toc;
err_omp_hat(1,i) = norm(X_norm*Gamma(:,i)-X_approx_norm*sol_omp_hat(:,i))/norm(X_norm*Gamma(:,i));
err_omp_hat(2,i) = isequal(find(Gamma(:,i)),find(sol_omp_hat(:,i)>1e-4));
idx_omp = find(sol_omp_hat(:,i));
resDist(ll+1,k,1,i) = min(norm(points(idx(1)) - points(idx_omp(1))),norm(points(idx(1)) - points(idx_omp(2))));
resDist(ll+1,k,2,i) = min(norm(points(idx(2)) - points(idx_omp(1))),norm(points(idx(2)) - points(idx_omp(2))));
err_solver_hat(1,i) = norm(X_norm*Gamma(:,i)-X_approx_norm*sol_solver_hat(:,i))/norm(X_norm*Gamma(:,i));
err_solver_hat(2,i) = isequal(find(Gamma(:,i)),find(sol_solver_hat(:,i)>1e-4));
idx_solver = find(sol_solver_hat(:,i));
resDist(ll+1,k,1,i) = min(norm(points(idx(1)) - points(idx_solver(1))),norm(points(idx(1)) - points(idx_solver(2))));
resDist(ll+1,k,2,i) = min(norm(points(idx(2)) - points(idx_solver(1))),norm(points(idx(2)) - points(idx_solver(2))));
compute_Times(ll+1,k,i)=t1;
tic
[sol_omp_hat(:,i), err_mse_omp_hat, iter_time_omp_hat]=greed_omp_chol(Data(:,i),matlab_faustS_mult{ll},size(X_approx_norm,2),'stopTol',1*Sparsity,'P_trans',matlab_trans_faustS_mult{ll});
t2=toc;
err_omp_hat(1,i) = norm(X_norm*Gamma(:,i)-X_approx_norm*sol_omp_hat(:,i))/norm(X_norm*Gamma(:,i));
err_omp_hat(2,i) = isequal(find(Gamma(:,i)),find(sol_omp_hat(:,i)>1e-4));
idx_omp = find(sol_omp_hat(:,i));
resDist_matlab(ll+1,k,1,i) = min(norm(points(idx(1)) - points(idx_omp(1))),norm(points(idx(1)) - points(idx_omp(2))));
resDist_matlab(ll+1,k,2,i) = min(norm(points(idx(2)) - points(idx_omp(1))),norm(points(idx(2)) - points(idx_omp(2))));
%% matlab function
if strcmp(solver_choice,'omp')
% OMP
tic
[sol_solver_hat(:,i), err_mse_solver_hat, iter_time_solver_hat]=greed_omp_chol(Data(:,i),matlab_faustS_mult{ll},M,'stopTol',1*Sparsity,'P_trans',matlab_trans_faustS_mult{ll});
t2=toc;
elseif strcmp(solver_choice,'iht')
% IHT
tic
[sol_solver_hat(:,i), err_mse_solver, iter_time_solver]=hard_l0_Mterm(Data(:,i),matlab_faustS_mult{ll},M,1*Sparsity,'verbose',false,'maxIter',1000,'P_trans',matlab_trans_faustS_mult{ll});
t2=toc;
else
error('invalid solver choice : must be omp or iht');
end
err_solver_hat(1,i) = norm(X_norm*Gamma(:,i)-X_approx_norm*sol_solver_hat(:,i))/norm(X_norm*Gamma(:,i));
err_solver_hat(2,i) = isequal(find(Gamma(:,i)),find(sol_solver_hat(:,i)>1e-4));
idx_solver = find(sol_solver_hat(:,i));
resDist_matlab(ll+1,k,1,i) = min(norm(points(idx(1)) - points(idx_solver(1))),norm(points(idx(1)) - points(idx_solver(2))));
resDist_matlab(ll+1,k,2,i) = min(norm(points(idx(2)) - points(idx_solver(1))),norm(points(idx(2)) - points(idx_solver(2))));
compute_Times_matlab(ll+1,k,i)=t2;
end
......@@ -205,7 +261,7 @@ toc
heure = clock ;
matfile = fullfile(pathname, 'output/results_BSL_user');
save(matfile,'resDist','resDist_matlab','RCG_approxS_MEG','nb_approx_MEG','compute_Times','compute_Times_matlab', 'RCG_approxS_MEG');
save(matfile,'resDist','params','resDist_matlab','RCG_approxS_MEG','nb_approx_MEG','compute_Times','compute_Times_matlab', 'RCG_approxS_MEG');
......@@ -50,12 +50,17 @@ if (not(exist(matfile)))
error('run BSL.m before Fig_BSL.m');
end
load(matfile);
Ntraining=params.Ntraining;
Sparsity=params.Sparsity;
Ntest=Ntraining*Sparsity;
%% convergence analysis
d1 = cat(4,resDist(:,1,1,:),resDist(:,1,2,:));
d2 = cat(4,resDist(:,2,1,:),resDist(:,2,2,:));
d3 = cat(4,resDist(:,3,1,:),resDist(:,3,2,:));
test2 = 100*[squeeze(d1);zeros(1,1000);squeeze(d2);zeros(1,1000);squeeze(d3)];
test2 = 100*[squeeze(d1);zeros(1,Ntest);squeeze(d2);zeros(1,Ntest);squeeze(d3)];
figure('color',[1 1 1]);
......@@ -105,7 +110,7 @@ title('Fig 9 : C++ wrapper faust');
d1 = cat(4,resDist_matlab(:,1,1,:),resDist_matlab(:,1,2,:));
d2 = cat(4,resDist_matlab(:,2,1,:),resDist_matlab(:,2,2,:));
d3 = cat(4,resDist_matlab(:,3,1,:),resDist_matlab(:,3,2,:));
test2 = 100*[squeeze(d1);zeros(1,1000);squeeze(d2);zeros(1,1000);squeeze(d3)];
test2 = 100*[squeeze(d1);zeros(1,Ntest);squeeze(d2);zeros(1,Ntest);squeeze(d3)];
figure('color',[1 1 1]);
title('MATLAB');
......
......@@ -7,6 +7,9 @@ y=x;
for i=length(facts):-1:1
y=facts{i}*y;
end
if issparse(y)
y=full(y);
end
end
......
......@@ -227,6 +227,7 @@ else error('P is of unsupported type. Use matrix, function_handle or obje
% This vector is used repeatedly, so pre-calculate;
Ptx=Pt(x);
......
function [s, err_mse, iter_time]=hard_l0_Mterm(x,A,m,M,varargin)
% hard_l0_Mterm: Hard thresholding algorithm that keeps exactly M elements
% in each iteration.
%
% This algorithm has certain performance guarantees as described in [1],
% [2] and [3].
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Usage
%
% [s, err_mse, iter_time]=hard_l0_Mterm(x,P,m,M,'option_name','option_value')
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Input
%
% Mandatory:
% x Observation vector to be decomposed
% P Either:
% 1) An nxm matrix (n must be dimension of x)
% 2) A function handle (type "help function_format"
% for more information)
% Also requires specification of P_trans option.
% 3) An object handle (type "help object_format" for
% more information)
% m length of s
% M non-zero elements to keep in each iteration
%
% Possible additional options:
% (specify as many as you want using 'option_name','option_value' pairs)
% See below for explanation of options:
%__________________________________________________________________________
% option_name | available option_values | default
%--------------------------------------------------------------------------
% stopTol | number (see below) | 1e-16
% P_trans | function_handle (see below) |
% maxIter | positive integer (see below) | n^2
% verbose | true, false | false
% start_val | vector of length m | zeros
% step_size | number | 0 (auto)
%
% stopping criteria used : (OldRMS-NewRMS)/RMS(x) < stopTol
%
% stopTol: Value for stopping criterion.
%
% P_trans: If P is a function handle, then P_trans has to be specified and
% must be a function handle.
%
% maxIter: Maximum number of allowed iterations.
%
% verbose: Logical value to allow algorithm progress to be displayed.
%
% start_val: Allows algorithms to start from partial solution.
%
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Outputs
%
% s Solution vector
% err_mse Vector containing mse of approximation error for each
% iteration
% iter_time Vector containing computation times for each iteration
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Description
%
% Implements the M-sparse algorithm described in [1], [2] and [3].
% This algorithm takes a gradient step and then thresholds to only retain
% M non-zero elements. It allows the step-size to be calculated
% automatically as described in [3] and is therefore now independent from
% a rescaling of P.
%
%
% References
% [1] T. Blumensath and M.E. Davies, "Iterative Thresholding for Sparse
% Approximations", submitted, 2007
% [2] T. Blumensath and M. Davies; "Iterative Hard Thresholding for
% Compressed Sensing" to appear Applied and Computational Harmonic
% Analysis
% [3] T. Blumensath and M. Davies; "A modified Iterative Hard
% Thresholding algorithm with guaranteed performance and stability"
% in preparation (title may change)
% See Also
% hard_l0_reg
%
% Copyright (c) 2007 Thomas Blumensath
%
% The University of Edinburgh
% Email: thomas.blumensath@ed.ac.uk
% Comments and bug reports welcome
%
% This file is part of sparsity Version 0.4
% Created: April 2007
% Modified January 2009
%
% Part of this toolbox was developed with the support of EPSRC Grant
% D000246/1
%
% Please read COPYRIGHT.m for terms and conditions.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Default values and initialisation
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
[n1 n2]=size(x);
if n2 == 1
n=n1;
elseif n1 == 1
x=x';
n=n2;
else
error('x must be a vector.');
end
sigsize = x'*x/n;
oldERR = sigsize;
err_mse = [];
iter_time = [];
STOPTOL = 1e-16;
MAXITER = n^2;
verbose = false;
initial_given=0;
s_initial = zeros(m,1);
MU = 0;
if verbose
display('Initialising...')
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Output variables
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
switch nargout
case 3
comp_err=true;
comp_time=true;
case 2
comp_err=true;
comp_time=false;
case 1
comp_err=false;
comp_time=false;
case 0
error('Please assign output variable.')
otherwise
error('Too many output arguments specified')
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Look through options
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Put option into nice format
Options={};
OS=nargin-4;
c=1;
for i=1:OS
if isa(varargin{i},'cell')
CellSize=length(varargin{i});
ThisCell=varargin{i};
for j=1:CellSize
Options{c}=ThisCell{j};
c=c+1;
end
else
Options{c}=varargin{i};
c=c+1;
end
end
OS=length(Options);
if rem(OS,2)
error('Something is wrong with argument name and argument value pairs.')
end
for i=1:2:OS
switch Options{i}
case {'stopTol'}
if isa(Options{i+1},'numeric') ; STOPTOL = Options{i+1};
else error('stopTol must be number. Exiting.'); end
case {'P_trans'}
if isa(Options{i+1},'function_handle'); Pt = Options{i+1};
else error('P_trans must be function _handle. Exiting.'); end
case {'maxIter'}
if isa(Options{i+1},'numeric'); MAXITER = Options{i+1};
else error('maxIter must be a number. Exiting.'); end
case {'verbose'}
if isa(Options{i+1},'logical'); verbose = Options{i+1};
else error('verbose must be a logical. Exiting.'); end
case {'start_val'}
if isa(Options{i+1},'numeric') && length(Options{i+1}) == m ;
s_initial = Options{i+1};
initial_given=1;
else error('start_val must be a vector of length m. Exiting.'); end
case {'step_size'}
if isa(Options{i+1},'numeric') && (Options{i+1}) > 0 ;
MU = Options{i+1};
else error('Stepsize must be between a positive number. Exiting.'); end
otherwise
error('Unrecognised option. Exiting.')
end
end
if nargout >=2
err_mse = zeros(MAXITER,1);
end
if nargout ==3
iter_time = zeros(MAXITER,1);
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Make P and Pt functions
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if isa(A,'float') P =@(z) A*z; Pt =@(z) A'*z;
elseif isobject(A) P =@(z) A*z; Pt =@(z) A'*z;
elseif isa(A,'function_handle')
try
if isa(Pt,'function_handle'); P=A;
else error('If P is a function handle, Pt also needs to be a function handle. Exiting.'); end
catch error('If P is a function handle, Pt needs to be specified. Exiting.'); end
else error('P is of unsupported type. Use matrix, function_handle or object. Exiting.'); end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Do we start from zero or not?
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if initial_given ==1;
if length(find(s_initial)) > M
display('Initial vector has more than M non-zero elements. Keeping only M largest.')
end
s = s_initial;
[ssort sortind] = sort(abs(s),'descend');
s(sortind(M+1:end)) = 0;
Ps = P(s);
Residual = x-Ps;
oldERR = Residual'*Residual/n;
else
s_initial = zeros(m,1);
Residual = x;
s = s_initial;
Ps = zeros(n,1);
oldERR = sigsize;
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Random Check to see if dictionary norm is below 1
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
x_test=randn(m,1);
x_test=x_test/norm(x_test);
nP=norm(P(x_test));
if abs(MU*nP)>1;
display('WARNING! Algorithm likely to become unstable.')
display('Use smaller step-size or || P ||_2 < 1.')
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Main algorithm
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if verbose
display('Main iterations...')
end
tic
t=0;
done = 0;
iter=1;
while ~done
if MU == 0
%Calculate optimal step size and do line search
olds = s;
oldPs = Ps;
IND = s~=0;
d = Pt(Residual);
% If the current vector is zero, we take the largest elements in d
if sum(IND)==0
[dsort sortdind] = sort(abs(d),'descend');
IND(sortdind(1:M)) = 1;
end
id = (IND.*d);
Pd = P(id);
mu = id'*id/(Pd'*Pd);
s = olds + mu * d;
[ssort sortind] = sort(abs(s),'descend');
s(sortind(M+1:end)) = 0;
Ps = P(s);
% Calculate step-size requirement
omega = (norm(s-olds)/norm(Ps-oldPs))^2;
% As long as the support changes and mu > omega, we decrease mu
while mu > (0.99)*omega && sum(xor(IND,s~=0))~=0 && sum(IND)~=0
% display(['decreasing mu'])
% We use a simple line search, halving mu in each step
mu = mu/2;
s = olds + mu * d;
[ssort sortind] = sort(abs(s),'descend');
s(sortind(M+1:end)) = 0;
Ps = P(s);
% Calculate step-size requirement
omega = (norm(s-olds)/norm(Ps-oldPs))^2;
end
else
% Use fixed step size
s = s + MU * Pt(Residual);
[ssort sortind] = sort(abs(s),'descend');
s(sortind(M+1:end)) = 0;
Ps = P(s);
end
Residual = x-Ps;
ERR=Residual'*Residual/n;
if comp_err
err_mse(iter)=ERR;
end
if comp_time
iter_time(iter)=toc;
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Are we done yet?
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if comp_err && iter >=2
if ((err_mse(iter-1)-err_mse(iter))/sigsize<STOPTOL);
if verbose
display(['Stopping. Approximation error changed less than ' num2str(STOPTOL)])
end
done = 1;
elseif verbose && toc-t>10
display(sprintf('Iteration %i. --- %i mse change',iter ,(err_mse(iter-1)-err_mse(iter))/sigsize))
t=toc;
end
else
if ((oldERR - ERR)/sigsize < STOPTOL) && iter >=2;
if verbose
display(['Stopping. Approximation error changed less than ' num2str(STOPTOL)])
end
done = 1;
elseif verbose && toc-t>10
display(sprintf('Iteration %i. --- %i mse change',iter ,(oldERR - ERR)/sigsize))
t=toc;
end
end
% Also stop if residual gets too small or maxIter reached
if comp_err
if err_mse(iter)<1e-16
display('Stopping. Exact signal representation found!')
done=1;
end
elseif iter>1
if ERR<1e-16
display('Stopping. Exact signal representation found!')
done=1;
end
end
if iter >= MAXITER
display('Stopping. Maximum number of iterations reached!')
done = 1;
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% If not done, take another round
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if ~done
iter=iter+1;
oldERR=ERR;
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Only return as many elements as iterations
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if nargout >=2
err_mse = err_mse(1:iter);
end
if nargout ==3
iter_time = iter_time(1:iter);
end
if verbose
display('Done')
end
......@@ -60,11 +60,15 @@ classdef matlab_faust < handle
%% Size
function Size=size(this,varargin);
%size - overload of the matlab size function
if (nargin == 1)
Size=mexFaust('size',this.objectHandle);
else (nargin == 2)
Size=mexFaust('size',this.objectHandle,varargin);
Size=mexFaust('size',this.objectHandle);
if (length(varargin)~=0)
if (varargin{1}==1)
Size=Size(1);
elseif (varargin{1}==2)
Size=Size(2);
end
end
end
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment