Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 742d23eb authored by hhakim's avatar hhakim
Browse files

Add a unit test for pyfaust rand_butterfly/opt_butterfly_faust and complete...

Add a unit test for pyfaust rand_butterfly/opt_butterfly_faust and complete the unit test for the dft butterfly opt (transpose case).

Issue #275
parent cc8facbb
No related branches found
No related tags found
No related merge requests found
......@@ -1401,17 +1401,54 @@ class TestFaustFactory(unittest.TestCase):
pow2_exp = random.Random().randint(1,10)
n = 2**pow2_exp
for normed in [True, False]:
x = np.random.rand(n)
F = dft(n, normed=normed, diag_opt=False)
oF = dft(n, normed=normed, diag_opt=True)
fF = oF.toarray()
ref_fft = F.toarray()
self.assertAlmostEqual(norm(ref_fft-fF)/norm(ref_fft),0)
self.assertTrue(np.allclose(F@x, oF@x))
# transpose case
FT = F.T
oFT = oF.T
fFT = oFT.toarray()
ref_fft = FT.toarray()
self.assertAlmostEqual(norm(ref_fft-fFT)/norm(ref_fft),0)
self.assertTrue(np.allclose(FT@x, oFT@x))
def testRandButterflyOpt(self):
print("Test pyfaust.opt_butterfly_faust()")
from pyfaust import rand_butterfly, opt_butterfly_faust
from scipy.sparse import csr_matrix
pow2_exp = random.Random().randint(1,10)
n = 2**pow2_exp
x = np.random.rand(n)
F = rand_butterfly(n)
# add a permutation
P = np.zeros((n, n))
randperm = np.random.permutation(n)
for i in range(P.shape[1]):
P[i, randperm[i]] = 1
F = F @ Faust(csr_matrix(P))
###
oF = opt_butterfly_faust(F)
fF = oF.toarray()
ref_F = F.toarray()
self.assertAlmostEqual(norm(ref_F-fF)/norm(ref_F),0)
self.assertTrue(np.allclose(F@x, oF@x))
# transpose case
FT = F.T
oFT = oF.T
fFT = oFT.toarray()
ref_F = FT.toarray()
self.assertAlmostEqual(norm(ref_F-fFT)/norm(ref_F),0)
self.assertTrue(np.allclose(FT@x, oFT@x))
def testRandButterfly(self):
print("Test pyfaust.rand_butterfly")
from pyfaust import wht, rand_butterfly
H = wht(32).toarray()
for dtype in ['float32', 'double', 'complex']:
for dtype in ['double', 'complex', 'float32']:
F = rand_butterfly(32, dtype=dtype)
self.assertTrue(not np.allclose(F.toarray(), H))
ref_I, ref_J = np.nonzero(H)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment