Mentions légales du service

Skip to content
Snippets Groups Projects
Commit d6183f47 authored by hhakim's avatar hhakim
Browse files

Add a unit test for pyfaust.poly.poly python impl.

parent f21bc746
Branches
Tags
No related merge requests found
Pipeline #880252 passed
...@@ -93,49 +93,56 @@ class TestPoly(unittest.TestCase): ...@@ -93,49 +93,56 @@ class TestPoly(unittest.TestCase):
def test_poly(self): def test_poly(self):
print("Test poly()") print("Test poly()")
self._test_poly_impl('native')
def test_poly_py(self):
print("Test poly(impl='py')")
self._test_poly_impl('py')
def _test_poly_impl(self, impl):
d = self.d d = self.d
L = self.L L = self.L
K = self.K K = self.K
density = self.density density = self.density
F = basis(L, K, 'chebyshev', dev=self.dev).astype(self.dtype) F = basis(L, K, 'chebyshev', dev=self.dev, impl=impl).astype(self.dtype)
coeffs = np.random.rand(K+1).astype(self.dtype) coeffs = np.random.rand(K+1).astype(self.dtype)
G = poly(coeffs, F) G = poly(coeffs, F, impl=impl)
# Test polynomial as Faust # Test polynomial as Faust
poly_ref = np.zeros((d,d)) poly_ref = np.zeros((d,d))
for i,c in enumerate(coeffs[:]): for i,c in enumerate(coeffs[:]):
poly_ref += c * F[d*i:(i+1)*d, :] poly_ref += c * F[d*i:(i+1)*d, :]
self.assertAlmostEqual((G-poly_ref).norm(), 0) self.assertAlmostEqual((G-poly_ref).norm(), 0)
# Test polynomial as array # Test polynomial as array
GM = poly(coeffs, F.toarray()) GM = poly(coeffs, F.toarray(), impl=impl)
self.assertTrue(isinstance(GM, np.ndarray)) self.assertTrue(isinstance(GM, np.ndarray))
err = norm(GM - poly_ref.toarray())/norm(poly_ref.toarray()) err = norm(GM - poly_ref.toarray())/norm(poly_ref.toarray())
self.assertLessEqual(err, 1e-6) self.assertLessEqual(err, 1e-6)
# Test polynomial-vector product # Test polynomial-vector product
x = np.random.rand(F.shape[1], 1).astype(L.dtype) x = np.random.rand(F.shape[1], 1).astype(L.dtype)
# Three ways to do (not all as efficient as each other) # Three ways to do (not all as efficient as each other)
Fx1 = poly(coeffs, F, dev=self.dev)@x Fx1 = poly(coeffs, F, dev=self.dev, impl=impl)@x
Fx2 = poly(coeffs, F@x, dev=self.dev) Fx2 = poly(coeffs, F@x, dev=self.dev, impl=impl)
Fx3 = poly(coeffs, F, X=x, dev=self.dev) Fx3 = poly(coeffs, F, X=x, dev=self.dev, impl=impl)
err = norm(Fx1-Fx2)/norm(Fx1) err = norm(Fx1-Fx2)/norm(Fx1)
self.assertLessEqual(err, 1e-6) self.assertLessEqual(err, 1e-6)
self.assertTrue(np.allclose(Fx1, Fx3)) self.assertTrue(np.allclose(Fx1, Fx3))
# Test polynomial-matrix product # Test polynomial-matrix product
X = np.random.rand(F.shape[1], 18).astype(L.dtype) X = np.random.rand(F.shape[1], 18).astype(L.dtype)
FX1 = poly(coeffs, F, dev=self.dev)@X FX1 = poly(coeffs, F, dev=self.dev, impl=impl)@X
FX2 = poly(coeffs, F@X, dev=self.dev) FX2 = poly(coeffs, F@X, dev=self.dev, impl=impl)
FX3 = poly(coeffs, F, X=X, dev=self.dev) FX3 = poly(coeffs, F, X=X, dev=self.dev, impl=impl)
err = norm(FX1-FX2)/norm(FX1) err = norm(FX1-FX2)/norm(FX1)
self.assertLessEqual(err, 1e-6) self.assertLessEqual(err, 1e-6)
self.assertTrue(np.allclose(FX2, FX3)) self.assertTrue(np.allclose(FX2, FX3))
# Test creating the polynomial basis on the fly # Test creating the polynomial basis on the fly
G2 = poly(coeffs, 'chebyshev', L) G2 = poly(coeffs, 'chebyshev', L, impl=impl)
self.assertAlmostEqual((G-G2).norm(), 0) self.assertAlmostEqual((G-G2).norm(), 0)
GX = poly(coeffs, 'chebyshev', L, X=X, dev=self.dev) GX = poly(coeffs, 'chebyshev', L, X=X, dev=self.dev, impl=impl)
err = norm(FX1-GX)/norm(FX1) err = norm(FX1-GX)/norm(FX1)
self.assertLessEqual(err, 1e-6) self.assertLessEqual(err, 1e-6)
# Test polynomial-matrix product with arbitrary T0 # Test polynomial-matrix product with arbitrary T0
F_ = basis(L, K, 'chebyshev', dev=self.dev, T0=csr_matrix(X)) F_ = basis(L, K, 'chebyshev', dev=self.dev, T0=csr_matrix(X), impl=impl)
GT0eqX = poly(coeffs, F_, dev=self.dev).toarray() GT0eqX = poly(coeffs, F_, dev=self.dev, impl=impl).toarray()
self.assertTrue(np.allclose(GT0eqX, FX1)) self.assertTrue(np.allclose(GT0eqX, FX1))
def test_expm_multiply(self): def test_expm_multiply(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment