Mentions légales du service

Skip to content
Snippets Groups Projects
Commit ac32b2cd authored by hhakim's avatar hhakim
Browse files

Fix pyfaust expm_multiply unit test.

parent ceff7fe2
Branches
Tags
No related merge requests found
......@@ -2,6 +2,7 @@ import unittest
from pyfaust.poly import basis, poly, expm_multiply
from numpy.random import randint
import numpy as np
from numpy.linalg import norm
from scipy.sparse import csr_matrix, random
from scipy.sparse.linalg import expm_multiply as scipy_expm_multiply
import tempfile
......@@ -26,6 +27,7 @@ class TestPoly(unittest.TestCase):
else:
self.dtype = 'double'
self.L = random(self.d, self.d, .02, format='csr', dtype=self.dtype)
self.L @= self.L.H
self.K = 5
def test_basis(self):
......@@ -137,18 +139,18 @@ class TestPoly(unittest.TestCase):
t = np.linspace(**pts_args).astype(L.dtype)
y = expm_multiply(L, x, t)
y_ref = scipy_expm_multiply(L, x, **pts_args)
self.assertTrue(np.allclose(y, y_ref))
self.assertTrue(norm(y-y_ref)/norm(y_ref) < 1e-2)
# test expm_multiply on a matrix
X = np.random.rand(L.shape[1], 32).astype(L.dtype)
pts_args = {'start':-.5, 'stop':-0.1, 'num':3, 'endpoint':True}
t = np.linspace(**pts_args)
y = expm_multiply(L, X, t)
y_ref = scipy_expm_multiply(L, X, **pts_args)
self.assertTrue(np.allclose(y, y_ref))
self.assertTrue(norm(y-y_ref)/norm(y_ref) < 1e-2)
# test expm_multiply with (non-default) tradeoff=='memory'
X = np.random.rand(L.shape[1], 32).astype(L.dtype)
pts_args = {'start':-.5, 'stop':-0.1, 'num':3, 'endpoint':True}
t = np.linspace(**pts_args)
y = expm_multiply(L, X, t, tradeoff='memory')
y_ref = scipy_expm_multiply(L, X, **pts_args)
self.assertTrue(np.allclose(y, y_ref))
self.assertTrue(norm(y-y_ref)/norm(y_ref) < 1e-2)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment