Mentions légales du service

Skip to content
Snippets Groups Projects
Commit e846f7f9 authored by hhakim's avatar hhakim
Browse files

Support float scalar type for matfaust.poly.* functions.

parent c2395cb4
Branches
Tags
No related merge requests found
...@@ -23,11 +23,15 @@ classdef FaustPoly < matfaust.Faust ...@@ -23,11 +23,15 @@ classdef FaustPoly < matfaust.Faust
%================================================================ %================================================================
function M = next(self) function M = next(self)
if(self.isreal) if(self.isreal)
core_obj = mexPolyReal('nextPolyFaust', self.matrix.objectHandle); if(strcmp(self.dtype, 'float'))
core_obj = mexPolyRealFloat('nextPolyFaust', self.matrix.objectHandle);
else
core_obj = mexPolyReal('nextPolyFaust', self.matrix.objectHandle);
end
else else
core_obj = mexPolyCplx('nextPolyFaust', self.matrix.objectHandle); core_obj = mexPolyCplx('nextPolyFaust', self.matrix.objectHandle);
end end
M = matfaust.poly.FaustPoly(core_obj, self.isreal); M = matfaust.poly.FaustPoly(core_obj, self.isreal, 'cpu', self.dtype);
end end
end end
...@@ -42,11 +46,15 @@ classdef FaustPoly < matfaust.Faust ...@@ -42,11 +46,15 @@ classdef FaustPoly < matfaust.Faust
if(iscell(X)) if(iscell(X))
% X is {}: no X passed (see matfaust.poly.poly()) % X is {}: no X passed (see matfaust.poly.poly())
if(self.isreal) if(self.isreal)
core_obj = mexPolyReal('polyFaust', coeffs, self.matrix.objectHandle); if(strcmp(class(self), 'single'))
core_obj = mexPolyRealFloat('polyFaust', coeffs, self.matrix.objectHandle);
else
core_obj = mexPolyReal('polyFaust', coeffs, self.matrix.objectHandle);
end
else else
core_obj = mexPolyCplx('polyFaust', coeffs, self.matrix.objectHandle); core_obj = mexPolyCplx('polyFaust', coeffs, self.matrix.objectHandle);
end end
M = matfaust.poly.FaustPoly(core_obj, self.isreal); M = matfaust.poly.FaustPoly(core_obj, self.isreal, 'cpu', self.dtype);
elseif(ismatrix(X)) elseif(ismatrix(X))
if(issparse(X)) if(issparse(X))
error('X must be a dense matrix') error('X must be a dense matrix')
...@@ -55,7 +63,11 @@ classdef FaustPoly < matfaust.Faust ...@@ -55,7 +63,11 @@ classdef FaustPoly < matfaust.Faust
error('The faust and X dimensions must agree.') error('The faust and X dimensions must agree.')
end end
if(self.isreal) if(self.isreal)
M = mexPolyReal('mulPolyFaust', coeffs, self.matrix.objectHandle, X); if(strcmp(self.dtype, 'float'))
M = mexPolyRealFloat('mulPolyFaust', coeffs, self.matrix.objectHandle, X);
else
M = mexPolyReal('mulPolyFaust', coeffs, self.matrix.objectHandle, X);
end
else else
M = mexPolyCplx('mulPolyFaust', coeffs, self.matrix.objectHandle, X); M = mexPolyCplx('mulPolyFaust', coeffs, self.matrix.objectHandle, X);
end end
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
%> @param basis_name 'chebyshev', and others yet to come. %> @param basis_name 'chebyshev', and others yet to come.
%> @param 'T0', matrix (optional): a sparse matrix to replace the identity as a 0-degree polynomial of the basis. %> @param 'T0', matrix (optional): a sparse matrix to replace the identity as a 0-degree polynomial of the basis.
%> @param 'dev', str (optional): the computing device ('cpu' or 'gpu'). %> @param 'dev', str (optional): the computing device ('cpu' or 'gpu').
%> @param 'dtype', str (optional): to decide in which data type the resulting Faust will be encoded ('float' or 'double' by default).
%> %>
%> @retval F the Faust of the basis composed of the K+1 orthogonal polynomials. %> @retval F the Faust of the basis composed of the K+1 orthogonal polynomials.
%> %>
...@@ -59,6 +60,7 @@ function F = basis(L, K, basis_name, varargin) ...@@ -59,6 +60,7 @@ function F = basis(L, K, basis_name, varargin)
T0 = []; % no T0 by default T0 = []; % no T0 by default
argc = length(varargin); argc = length(varargin);
dev = 'cpu'; dev = 'cpu';
dtype = 'double';
if(argc > 0) if(argc > 0)
for i=1:2:argc for i=1:2:argc
if(argc > i) if(argc > i)
...@@ -79,6 +81,12 @@ function F = basis(L, K, basis_name, varargin) ...@@ -79,6 +81,12 @@ function F = basis(L, K, basis_name, varargin)
else else
dev = tmparg; dev = tmparg;
end end
case 'dtype'
if(argc == i || ~ strcmp(tmparg, 'float') && ~ startsWith(tmparg, 'double'))
error('dtype keyword argument is not followed by a valid value: float or double.')
else
dtype = tmparg;
end
otherwise otherwise
if((isstr(varargin{i}) || ischar(varargin{i})) && ~ strcmp(tmparg, 'cpu') && ~ startsWith(tmparg, 'gpu')) if((isstr(varargin{i}) || ischar(varargin{i})) && ~ strcmp(tmparg, 'cpu') && ~ startsWith(tmparg, 'gpu'))
error([ tmparg ' unrecognized argument']) error([ tmparg ' unrecognized argument'])
...@@ -87,6 +95,7 @@ function F = basis(L, K, basis_name, varargin) ...@@ -87,6 +95,7 @@ function F = basis(L, K, basis_name, varargin)
end end
end end
is_float = strcmp(class(L), 'single') || strcmp(dtype, 'float'); % since it's impossible to create a float sparse matrix in matlab use a dtype argument
mex_args = {basis_name, L, K, startsWith(dev, 'gpu')}; mex_args = {basis_name, L, K, startsWith(dev, 'gpu')};
if(T0_is_set) if(T0_is_set)
...@@ -101,7 +110,11 @@ function F = basis(L, K, basis_name, varargin) ...@@ -101,7 +110,11 @@ function F = basis(L, K, basis_name, varargin)
if(strcmp(basis_name, 'chebyshev')) if(strcmp(basis_name, 'chebyshev'))
if(is_real) if(is_real)
core_obj = mexPolyReal(mex_args{:}); if(is_float)
core_obj = mexPolyRealFloat(mex_args{:});
else
core_obj = mexPolyReal(mex_args{:});
end
else else
core_obj = mexPolyCplx(mex_args{:}); core_obj = mexPolyCplx(mex_args{:});
end end
...@@ -109,5 +122,5 @@ function F = basis(L, K, basis_name, varargin) ...@@ -109,5 +122,5 @@ function F = basis(L, K, basis_name, varargin)
error(['unknown basis name: ' basis_name]) error(['unknown basis name: ' basis_name])
end end
F = matfaust.poly.FaustPoly(core_obj, is_real); F = matfaust.poly.FaustPoly(core_obj, is_real, 'cpu', dtype);
end end
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
%> @param 'K', integer (default value is 10) the greatest polynomial degree of the Chebyshev polynomial basis. The greater it is, the better is the approximate accuracy but note that a larger K increases the computational cost. %> @param 'K', integer (default value is 10) the greatest polynomial degree of the Chebyshev polynomial basis. The greater it is, the better is the approximate accuracy but note that a larger K increases the computational cost.
%> @param 'tradeoff', str (optional): 'memory' or 'time' to specify what matters the most: a small memory footprint or a small time of execution. It changes the implementation of pyfaust.poly.poly used behind. It can help when the memory size is limited relatively to the value of rel_err or the size of A and B. %> @param 'tradeoff', str (optional): 'memory' or 'time' to specify what matters the most: a small memory footprint or a small time of execution. It changes the implementation of pyfaust.poly.poly used behind. It can help when the memory size is limited relatively to the value of rel_err or the size of A and B.
%> @param 'dev', str (optional): the computing device ('cpu' or 'gpu'). %> @param 'dev', str (optional): the computing device ('cpu' or 'gpu').
%> @param 'dtype', str (optional): to decide in which data type the resulting array C will be encoded ('float' or 'double' by default).
%> %>
%> %>
%> @retval C the approximate of e^{t_k A} B. C is a tridimensional array of size (sizef(A,1), size(B,2), size(t, 1)), each slice C(:,:,i) is the action of the matrix exponentatial of A on B according to the time point t(i). %> @retval C the approximate of e^{t_k A} B. C is a tridimensional array of size (sizef(A,1), size(B,2), size(t, 1)), each slice C(:,:,i) is the action of the matrix exponentatial of A on B according to the time point t(i).
...@@ -36,6 +37,7 @@ function C = expm_multiply(A, B, t, varargin) ...@@ -36,6 +37,7 @@ function C = expm_multiply(A, B, t, varargin)
argc = length(varargin); argc = length(varargin);
dev = 'cpu'; dev = 'cpu';
tradeoff = 'time'; tradeoff = 'time';
dtype = 'double';
if(argc > 0) if(argc > 0)
for i=1:2:argc for i=1:2:argc
if(argc > i) if(argc > i)
...@@ -73,6 +75,12 @@ function C = expm_multiply(A, B, t, varargin) ...@@ -73,6 +75,12 @@ function C = expm_multiply(A, B, t, varargin)
else else
group_coeffs = tmparg; group_coeffs = tmparg;
end end
case 'dtype'
if(argc == i || ~ strcmp(tmparg, 'float') && ~ startsWith(tmparg, 'double'))
error('dtype keyword argument is not followed by a valid value: float or double.')
else
dtype = tmparg;
end
otherwise otherwise
if((isstr(varargin{i}) || ischar(varargin{i})) && ~ strcmp(tmparg, 'cpu') && ~ startsWith(tmparg, 'gpu')) if((isstr(varargin{i}) || ischar(varargin{i})) && ~ strcmp(tmparg, 'cpu') && ~ startsWith(tmparg, 'gpu'))
error([ tmparg ' unrecognized argument']) error([ tmparg ' unrecognized argument'])
...@@ -102,7 +110,7 @@ function C = expm_multiply(A, B, t, varargin) ...@@ -102,7 +110,7 @@ function C = expm_multiply(A, B, t, varargin)
error('A must be symmetric positive definite') error('A must be symmetric positive definite')
end end
phi = eigs(A, 1) / 2; phi = eigs(A, 1) / 2;
T = matfaust.poly.basis(A/phi-speye(size(A)), K, 'chebyshev', 'dev', dev); T = matfaust.poly.basis(A/phi-speye(size(A)), K, 'chebyshev', 'dev', dev, 'dtype', dtype);
if (~ ismatrix(t) || ~ isreal(t) || size(t, 1) ~= 1 && size(t, 2) ~= 1) if (~ ismatrix(t) || ~ isreal(t) || size(t, 1) ~= 1 && size(t, 2) ~= 1)
error('t must be a real value or a real vector') error('t must be a real value or a real vector')
end end
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
%> @param 'L', matrix the sparse matrix on which the polynomial basis is built if basis is not already a Faust or a full array. %> @param 'L', matrix the sparse matrix on which the polynomial basis is built if basis is not already a Faust or a full array.
%> @param 'X', matrix if X is set, the linear combination of basis*X is computed (note that the memory space is optimized compared to the manual way of doing first B = basis*X and then calling poly on B without X set). %> @param 'X', matrix if X is set, the linear combination of basis*X is computed (note that the memory space is optimized compared to the manual way of doing first B = basis*X and then calling poly on B without X set).
%> @param 'dev', str (optional): the computating device ('cpu' or 'gpu'). %> @param 'dev', str (optional): the computating device ('cpu' or 'gpu').
%> @param 'dtype', str (optional): to decide in which data type the resulting Faust or array will be encoded ('float' or 'double' by default). If basis is a Faust or an array its dtype/class is prioritary over this parameter which is in fact useful only if basis is the name of the basis (a str/char array).
%> @retval LC The linear combination Faust or full array depending on if basis is itself a Faust or a np.ndarray. %> @retval LC The linear combination Faust or full array depending on if basis is itself a Faust or a np.ndarray.
%> %>
%> @b Example %> @b Example
...@@ -75,6 +76,7 @@ function LC = poly(coeffs, basis, varargin) ...@@ -75,6 +76,7 @@ function LC = poly(coeffs, basis, varargin)
X = {}; % by default no X argument is passed, set it as a cell (see why in matfaust.Faust.poly) X = {}; % by default no X argument is passed, set it as a cell (see why in matfaust.Faust.poly)
argc = length(varargin); argc = length(varargin);
dev = 'cpu'; dev = 'cpu';
dtype = 'double';
if(argc > 0) if(argc > 0)
for i=1:2:argc for i=1:2:argc
if(argc > i) if(argc > i)
...@@ -100,6 +102,12 @@ function LC = poly(coeffs, basis, varargin) ...@@ -100,6 +102,12 @@ function LC = poly(coeffs, basis, varargin)
else else
X = tmparg; X = tmparg;
end end
case 'dtype'
if(argc == i || ~ strcmp(tmparg, 'float') && ~ startsWith(tmparg, 'double'))
error('dtype keyword argument is not followed by a valid value: float or double.')
else
dtype = tmparg;
end
otherwise otherwise
if((isstr(varargin{i}) || ischar(varargin{i})) && ~ strcmp(tmparg, 'cpu') && ~ startsWith(tmparg, 'gpu')) if((isstr(varargin{i}) || ischar(varargin{i})) && ~ strcmp(tmparg, 'cpu') && ~ startsWith(tmparg, 'gpu'))
error([ tmparg ' unrecognized argument']) error([ tmparg ' unrecognized argument'])
...@@ -118,17 +126,22 @@ function LC = poly(coeffs, basis, varargin) ...@@ -118,17 +126,22 @@ function LC = poly(coeffs, basis, varargin)
error('coeffs and basis must be of the same scalar type (real or complex)') error('coeffs and basis must be of the same scalar type (real or complex)')
end end
K = size(coeffs, 1)-1; if isvector(coeffs)
K = numel(coeffs)-1;
else
K = size(coeffs, 1)-1;
end
if(isstr(basis) || ischar(basis)) if(isstr(basis) || ischar(basis))
if(exist('L') ~= 1) if(exist('L') ~= 1)
error('L key-value pair argument is missing in the argument list.') error('L key-value pair argument is missing in the argument list.')
end end
basis = matfaust.poly.basis(L, K, basis, 'dev', dev); basis = matfaust.poly.basis(L, K, basis, 'dev', dev, 'dtype', dtype);
end end
is_real = isreal(basis); is_real = isreal(basis);
is_float = strcmp(class(basis), 'single') || strcmp('dtype', 'float');
if(matfaust.isFaust(basis)) if(matfaust.isFaust(basis))
if(numfactors(basis) ~= numel(coeffs)) if(numfactors(basis) ~= numel(coeffs))
error('coeffs and basis dimensions must agree.') error('coeffs and basis dimensions must agree.')
...@@ -142,13 +155,21 @@ function LC = poly(coeffs, basis, varargin) ...@@ -142,13 +155,21 @@ function LC = poly(coeffs, basis, varargin)
d = floor(size(basis,1) / (K+1)); d = floor(size(basis,1) / (K+1));
if(size(coeffs, 2) == 1) if(size(coeffs, 2) == 1)
if(is_real) if(is_real)
LC = mexPolyReal('polyMatrix', d, K, size(basis,2), coeffs, basis, on_gpu); if(is_float)
LC = mexPolyRealFloat('polyMatrix', d, K, size(basis,2), coeffs, basis, on_gpu);
else
LC = mexPolyReal('polyMatrix', d, K, size(basis,2), coeffs, basis, on_gpu);
end
else else
LC = mexPolyCplx('polyMatrix', d, K, size(basis,2), coeffs, basis, on_gpu); LC = mexPolyCplx('polyMatrix', d, K, size(basis,2), coeffs, basis, on_gpu);
end end
else else
if(is_real) if(is_real)
LC = mexPolyReal('polyMatrixGroupCoeffs', d, K, size(basis,2), size(coeffs, 2), coeffs, basis, on_gpu); if(is_float)
LC = mexPolyRealFloat('polyMatrixGroupCoeffs', d, K, size(basis,2), size(coeffs, 2), coeffs, basis, on_gpu);
else
LC = mexPolyReal('polyMatrixGroupCoeffs', d, K, size(basis,2), size(coeffs, 2), coeffs, basis, on_gpu);
end
else else
LC = mexPolyCplx('polyMatrixGroupCoeffs', d, K, size(basis,2), size(coeffs, 2), coeffs, basis, on_gpu); LC = mexPolyCplx('polyMatrixGroupCoeffs', d, K, size(basis,2), size(coeffs, 2), coeffs, basis, on_gpu);
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment