Mentions légales du service

Skip to content
Snippets Groups Projects
Commit e846f7f9 authored by hhakim's avatar hhakim
Browse files

Support float scalar type for matfaust.poly.* functions.

parent c2395cb4
Branches
Tags
No related merge requests found
......@@ -23,11 +23,15 @@ classdef FaustPoly < matfaust.Faust
%================================================================
function M = next(self)
if(self.isreal)
core_obj = mexPolyReal('nextPolyFaust', self.matrix.objectHandle);
if(strcmp(self.dtype, 'float'))
core_obj = mexPolyRealFloat('nextPolyFaust', self.matrix.objectHandle);
else
core_obj = mexPolyReal('nextPolyFaust', self.matrix.objectHandle);
end
else
core_obj = mexPolyCplx('nextPolyFaust', self.matrix.objectHandle);
end
M = matfaust.poly.FaustPoly(core_obj, self.isreal);
M = matfaust.poly.FaustPoly(core_obj, self.isreal, 'cpu', self.dtype);
end
end
......@@ -42,11 +46,15 @@ classdef FaustPoly < matfaust.Faust
if(iscell(X))
% X is {}: no X passed (see matfaust.poly.poly())
if(self.isreal)
core_obj = mexPolyReal('polyFaust', coeffs, self.matrix.objectHandle);
if(strcmp(class(self), 'single'))
core_obj = mexPolyRealFloat('polyFaust', coeffs, self.matrix.objectHandle);
else
core_obj = mexPolyReal('polyFaust', coeffs, self.matrix.objectHandle);
end
else
core_obj = mexPolyCplx('polyFaust', coeffs, self.matrix.objectHandle);
end
M = matfaust.poly.FaustPoly(core_obj, self.isreal);
M = matfaust.poly.FaustPoly(core_obj, self.isreal, 'cpu', self.dtype);
elseif(ismatrix(X))
if(issparse(X))
error('X must be a dense matrix')
......@@ -55,7 +63,11 @@ classdef FaustPoly < matfaust.Faust
error('The faust and X dimensions must agree.')
end
if(self.isreal)
M = mexPolyReal('mulPolyFaust', coeffs, self.matrix.objectHandle, X);
if(strcmp(self.dtype, 'float'))
M = mexPolyRealFloat('mulPolyFaust', coeffs, self.matrix.objectHandle, X);
else
M = mexPolyReal('mulPolyFaust', coeffs, self.matrix.objectHandle, X);
end
else
M = mexPolyCplx('mulPolyFaust', coeffs, self.matrix.objectHandle, X);
end
......
......@@ -6,6 +6,7 @@
%> @param basis_name 'chebyshev', and others yet to come.
%> @param 'T0', matrix (optional): a sparse matrix to replace the identity as a 0-degree polynomial of the basis.
%> @param 'dev', str (optional): the computing device ('cpu' or 'gpu').
%> @param 'dtype', str (optional): to decide in which data type the resulting Faust will be encoded ('float' or 'double' by default).
%>
%> @retval F the Faust of the basis composed of the K+1 orthogonal polynomials.
%>
......@@ -59,6 +60,7 @@ function F = basis(L, K, basis_name, varargin)
T0 = []; % no T0 by default
argc = length(varargin);
dev = 'cpu';
dtype = 'double';
if(argc > 0)
for i=1:2:argc
if(argc > i)
......@@ -79,6 +81,12 @@ function F = basis(L, K, basis_name, varargin)
else
dev = tmparg;
end
case 'dtype'
if(argc == i || ~ strcmp(tmparg, 'float') && ~ startsWith(tmparg, 'double'))
error('dtype keyword argument is not followed by a valid value: float or double.')
else
dtype = tmparg;
end
otherwise
if((isstr(varargin{i}) || ischar(varargin{i})) && ~ strcmp(tmparg, 'cpu') && ~ startsWith(tmparg, 'gpu'))
error([ tmparg ' unrecognized argument'])
......@@ -87,6 +95,7 @@ function F = basis(L, K, basis_name, varargin)
end
end
is_float = strcmp(class(L), 'single') || strcmp(dtype, 'float'); % since it's impossible to create a float sparse matrix in matlab use a dtype argument
mex_args = {basis_name, L, K, startsWith(dev, 'gpu')};
if(T0_is_set)
......@@ -101,7 +110,11 @@ function F = basis(L, K, basis_name, varargin)
if(strcmp(basis_name, 'chebyshev'))
if(is_real)
core_obj = mexPolyReal(mex_args{:});
if(is_float)
core_obj = mexPolyRealFloat(mex_args{:});
else
core_obj = mexPolyReal(mex_args{:});
end
else
core_obj = mexPolyCplx(mex_args{:});
end
......@@ -109,5 +122,5 @@ function F = basis(L, K, basis_name, varargin)
error(['unknown basis name: ' basis_name])
end
F = matfaust.poly.FaustPoly(core_obj, is_real);
F = matfaust.poly.FaustPoly(core_obj, is_real, 'cpu', dtype);
end
......@@ -8,6 +8,7 @@
%> @param 'K', integer (default value is 10) the greatest polynomial degree of the Chebyshev polynomial basis. The greater it is, the better is the approximate accuracy but note that a larger K increases the computational cost.
%> @param 'tradeoff', str (optional): 'memory' or 'time' to specify what matters the most: a small memory footprint or a small time of execution. It changes the implementation of pyfaust.poly.poly used behind. It can help when the memory size is limited relatively to the value of rel_err or the size of A and B.
%> @param 'dev', str (optional): the computing device ('cpu' or 'gpu').
%> @param 'dtype', str (optional): to decide in which data type the resulting array C will be encoded ('float' or 'double' by default).
%>
%>
%> @retval C the approximate of e^{t_k A} B. C is a tridimensional array of size (sizef(A,1), size(B,2), size(t, 1)), each slice C(:,:,i) is the action of the matrix exponentatial of A on B according to the time point t(i).
......@@ -36,6 +37,7 @@ function C = expm_multiply(A, B, t, varargin)
argc = length(varargin);
dev = 'cpu';
tradeoff = 'time';
dtype = 'double';
if(argc > 0)
for i=1:2:argc
if(argc > i)
......@@ -73,6 +75,12 @@ function C = expm_multiply(A, B, t, varargin)
else
group_coeffs = tmparg;
end
case 'dtype'
if(argc == i || ~ strcmp(tmparg, 'float') && ~ startsWith(tmparg, 'double'))
error('dtype keyword argument is not followed by a valid value: float or double.')
else
dtype = tmparg;
end
otherwise
if((isstr(varargin{i}) || ischar(varargin{i})) && ~ strcmp(tmparg, 'cpu') && ~ startsWith(tmparg, 'gpu'))
error([ tmparg ' unrecognized argument'])
......@@ -102,7 +110,7 @@ function C = expm_multiply(A, B, t, varargin)
error('A must be symmetric positive definite')
end
phi = eigs(A, 1) / 2;
T = matfaust.poly.basis(A/phi-speye(size(A)), K, 'chebyshev', 'dev', dev);
T = matfaust.poly.basis(A/phi-speye(size(A)), K, 'chebyshev', 'dev', dev, 'dtype', dtype);
if (~ ismatrix(t) || ~ isreal(t) || size(t, 1) ~= 1 && size(t, 2) ~= 1)
error('t must be a real value or a real vector')
end
......
......@@ -9,6 +9,7 @@
%> @param 'L', matrix the sparse matrix on which the polynomial basis is built if basis is not already a Faust or a full array.
%> @param 'X', matrix if X is set, the linear combination of basis*X is computed (note that the memory space is optimized compared to the manual way of doing first B = basis*X and then calling poly on B without X set).
%> @param 'dev', str (optional): the computating device ('cpu' or 'gpu').
%> @param 'dtype', str (optional): to decide in which data type the resulting Faust or array will be encoded ('float' or 'double' by default). If basis is a Faust or an array its dtype/class is prioritary over this parameter which is in fact useful only if basis is the name of the basis (a str/char array).
%> @retval LC The linear combination Faust or full array depending on if basis is itself a Faust or a np.ndarray.
%>
%> @b Example
......@@ -75,6 +76,7 @@ function LC = poly(coeffs, basis, varargin)
X = {}; % by default no X argument is passed, set it as a cell (see why in matfaust.Faust.poly)
argc = length(varargin);
dev = 'cpu';
dtype = 'double';
if(argc > 0)
for i=1:2:argc
if(argc > i)
......@@ -100,6 +102,12 @@ function LC = poly(coeffs, basis, varargin)
else
X = tmparg;
end
case 'dtype'
if(argc == i || ~ strcmp(tmparg, 'float') && ~ startsWith(tmparg, 'double'))
error('dtype keyword argument is not followed by a valid value: float or double.')
else
dtype = tmparg;
end
otherwise
if((isstr(varargin{i}) || ischar(varargin{i})) && ~ strcmp(tmparg, 'cpu') && ~ startsWith(tmparg, 'gpu'))
error([ tmparg ' unrecognized argument'])
......@@ -118,17 +126,22 @@ function LC = poly(coeffs, basis, varargin)
error('coeffs and basis must be of the same scalar type (real or complex)')
end
K = size(coeffs, 1)-1;
if isvector(coeffs)
K = numel(coeffs)-1;
else
K = size(coeffs, 1)-1;
end
if(isstr(basis) || ischar(basis))
if(exist('L') ~= 1)
error('L key-value pair argument is missing in the argument list.')
end
basis = matfaust.poly.basis(L, K, basis, 'dev', dev);
basis = matfaust.poly.basis(L, K, basis, 'dev', dev, 'dtype', dtype);
end
is_real = isreal(basis);
is_float = strcmp(class(basis), 'single') || strcmp('dtype', 'float');
if(matfaust.isFaust(basis))
if(numfactors(basis) ~= numel(coeffs))
error('coeffs and basis dimensions must agree.')
......@@ -142,13 +155,21 @@ function LC = poly(coeffs, basis, varargin)
d = floor(size(basis,1) / (K+1));
if(size(coeffs, 2) == 1)
if(is_real)
LC = mexPolyReal('polyMatrix', d, K, size(basis,2), coeffs, basis, on_gpu);
if(is_float)
LC = mexPolyRealFloat('polyMatrix', d, K, size(basis,2), coeffs, basis, on_gpu);
else
LC = mexPolyReal('polyMatrix', d, K, size(basis,2), coeffs, basis, on_gpu);
end
else
LC = mexPolyCplx('polyMatrix', d, K, size(basis,2), coeffs, basis, on_gpu);
end
else
if(is_real)
LC = mexPolyReal('polyMatrixGroupCoeffs', d, K, size(basis,2), size(coeffs, 2), coeffs, basis, on_gpu);
if(is_float)
LC = mexPolyRealFloat('polyMatrixGroupCoeffs', d, K, size(basis,2), size(coeffs, 2), coeffs, basis, on_gpu);
else
LC = mexPolyReal('polyMatrixGroupCoeffs', d, K, size(basis,2), size(coeffs, 2), coeffs, basis, on_gpu);
end
else
LC = mexPolyCplx('polyMatrixGroupCoeffs', d, K, size(basis,2), size(coeffs, 2), coeffs, basis, on_gpu);
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment