Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 243a2cae authored by CARRIVAIN Pascal's avatar CARRIVAIN Pascal
Browse files

add tests to fast-deformable-butterfly factorization function

parent 0f57ba1f
1 merge request!1review of fdb function
Pipeline #984375 failed
import numpy as np
import os
import scipy as sp
import unittest
from pyfaust import fact
class TestFact(unittest.TestCase):
"""Test case for the 'fact' module."""
def test_fdb(self):
"""Test of the function 'fdb'."""
os.environ["GB_DISABLE_EINOPS"] = "1"
os.environ["GB_DISABLE_PYTORCH"] = "1"
M = 2 ** np.random.randint(1, high=11) + 1
N = 2 ** np.random.randint(1, high=11) + 1
matrix = np.random.randn(M, N)
self.assertRaises(Exception, fact.fdb, matrix, 2, 1)
M = 2 ** np.random.randint(1, high=11)
N = 2 ** np.random.randint(1, high=11)
matrix = np.random.randn(M, N)
self.assertRaises(Exception, fact.fdb, matrix, 2, max(M, N) + 1)
self.assertRaises(NotImplementedError,
fact.fdb, matrix, 2, 1, True, 'nothing')
for i in range(100):
rank = 1
M, N = rank, rank
while M <= rank:
M = 2 ** np.random.randint(1, high=11)
while N <= rank:
N = 2 ** np.random.randint(1, high=11)
n_factors = np.random.randint(2, high=11)
matrix = np.ones((M, N))
print("ones {0:d}: shape=({1:d}, {2:d}),".format(i, M, N) +
" number of factors={0:d}".format(n_factors))
F = fact.fdb(matrix, n_factors=n_factors, rank=rank)
ncols = F.factors(F.numfactors() - 1).shape[1]
approx = np.eye(ncols, M=N, k=0)
for f in range(F.numfactors()):
approx = F.factors(F.numfactors() - 1 - f) @ approx
error = np.linalg.norm(matrix - approx) / np.linalg.norm(matrix)
self.assertTrue(error < 1e-12)
# print('true:')
# print(matrix)
# print('approximation:')
# print(approx)
# print('error={0:e}'.format(error))
N = 2 ** np.random.randint(1, high=11)
H = sp.linalg.hadamard(N)
print("hadamard {0:d}: shape=({1:d}, {2:d}),".format(i, N, N) +
" number of factors={0:d}".format(n_factors))
F = fact.fdb(H, n_factors=n_factors, rank=rank)
print(F)
ncols = F.factors(F.numfactors() - 1).shape[1]
approx = np.eye(ncols, M=N, k=0)
for f in range(F.numfactors()):
approx = F.factors(F.numfactors() - 1 - f) @ approx
error = np.linalg.norm(H - approx) / np.linalg.norm(H)
print(error)
self.assertTrue(error < 1e-12)
M = 2 ** np.random.randint(1, high=11)
N = 2 ** np.random.randint(1, high=11)
N = 64
M = N
if N == 4:
rank = 1
elif N == 8:
rank = 2
elif N == 16:
rank = 4
elif N == 32:
rank = 4
else:
continue
x = np.exp(-2.0j * np.pi * np.arange(M) / M)
V = np.vander(x, N, increasing=True)
print("vandermonde {0:d}: shape=({1:d}, {2:d}),".format(i, N, N) +
" number of factors={0:d}".format(n_factors))
F = fact.fdb(V, n_factors=n_factors, rank=4)#rank)
print(F)
ncols = F.factors(F.numfactors() - 1).shape[1]
approx = np.eye(ncols, M=N, k=0)
for f in range(F.numfactors()):
approx = F.factors(F.numfactors() - 1 - f) @ approx
error = np.linalg.norm(V - approx) / np.linalg.norm(V)
print(np.round(V, 3))
print(np.round(approx, 3))
print(error)
self.assertTrue(error < 1e-12)
...@@ -2,6 +2,7 @@ import unittest ...@@ -2,6 +2,7 @@ import unittest
from pyfaust.tests.TestFaust import TestFaust from pyfaust.tests.TestFaust import TestFaust
from pyfaust.tests.TestPoly import TestPoly from pyfaust.tests.TestPoly import TestPoly
from pyfaust.tests.TestFactParams import TestFactParams from pyfaust.tests.TestFactParams import TestFactParams
from pyfaust.tests.TestFact import TestFact
def run_tests(dev, dtype): def run_tests(dev, dtype):
...@@ -11,9 +12,10 @@ def run_tests(dev, dtype): ...@@ -11,9 +12,10 @@ def run_tests(dev, dtype):
""" """
runner = unittest.TextTestRunner() runner = unittest.TextTestRunner()
suite = unittest.TestSuite() suite = unittest.TestSuite()
for class_name in ['TestFaust', 'TestPoly', 'TestFactParams']: for class_name in ['TestFaust', 'TestPoly', 'TestFactParams', 'TestFact']:
testloader = unittest.TestLoader() testloader = unittest.TestLoader()
test_names = eval("testloader.getTestCaseNames("+class_name+")") test_names = eval("testloader.getTestCaseNames("+class_name+")")
print('here', test_names)
for meth_name in test_names: for meth_name in test_names:
test = eval(""+class_name+"('"+meth_name+"', dev=dev, dtype=dtype)") test = eval(""+class_name+"('"+meth_name+"', dev=dev, dtype=dtype)")
suite.addTest(test) suite.addTest(test)
......
import unittest import unittest
from pyfaust.tests.TestFact import TestFact
from pyfaust.tests.TestFaust import TestFaust from pyfaust.tests.TestFaust import TestFaust
from pyfaust.tests.TestPoly import TestPoly from pyfaust.tests.TestPoly import TestPoly
import sys import sys
......
  • hhakim @hricha ·
    Owner

    @pcarriva are you planning to add a matlab version in matfaust after?

  • Author Owner

    First of all I would like to ask you for a review. However, I am 100% ready to ask for a MR. I let you know asap.

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment