Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 29e99c0c authored by hhakim's avatar hhakim
Browse files

Refactor the pyfaust.tests module in a modular package in order to properly receive more tests.

parent 872dd655
No related branches found
No related tags found
No related merge requests found
import sys
import unittest import unittest
from pyfaust import (rand as frand, Faust, vstack, hstack, isFaust, dot, from pyfaust import (rand as frand, Faust, vstack, hstack, isFaust, dot,
concatenate, pinv, eye, dft, wht, is_gpu_mod_enabled) concatenate, pinv, eye, dft, wht, is_gpu_mod_enabled)
...@@ -14,7 +13,7 @@ dev = 'cpu' ...@@ -14,7 +13,7 @@ dev = 'cpu'
field = 'real' field = 'real'
class PyfaustSimpleTest(unittest.TestCase): class TestFaust(unittest.TestCase):
MIN_NUM_FACTORS = 1 MIN_NUM_FACTORS = 1
MAX_NUM_FACTORS = 4 MAX_NUM_FACTORS = 4
...@@ -24,12 +23,12 @@ class PyfaustSimpleTest(unittest.TestCase): ...@@ -24,12 +23,12 @@ class PyfaustSimpleTest(unittest.TestCase):
def setUp(self): def setUp(self):
""" """
""" """
nrows = randint(PyfaustSimpleTest.MIN_DIM_SIZE, nrows = randint(TestFaust.MIN_DIM_SIZE,
PyfaustSimpleTest.MAX_DIM_SIZE+1) TestFaust.MAX_DIM_SIZE+1)
ncols = randint(PyfaustSimpleTest.MIN_DIM_SIZE, ncols = randint(TestFaust.MIN_DIM_SIZE,
PyfaustSimpleTest.MAX_DIM_SIZE+1) TestFaust.MAX_DIM_SIZE+1)
nfacts = randint(PyfaustSimpleTest.MIN_NUM_FACTORS, nfacts = randint(TestFaust.MIN_NUM_FACTORS,
PyfaustSimpleTest.MAX_NUM_FACTORS+1) TestFaust.MAX_NUM_FACTORS+1)
self.F = frand(nrows, ncols, num_factors=nfacts, dev=dev, field=field) self.F = frand(nrows, ncols, num_factors=nfacts, dev=dev, field=field)
self.nrows = nrows self.nrows = nrows
self.ncols = ncols self.ncols = ncols
...@@ -282,7 +281,7 @@ class PyfaustSimpleTest(unittest.TestCase): ...@@ -282,7 +281,7 @@ class PyfaustSimpleTest(unittest.TestCase):
self.assertEqual(self.F.astype(np.complex).dtype, np.complex) self.assertEqual(self.F.astype(np.complex).dtype, np.complex)
else: else:
self.assertEqual(self.F.astype(np.float).dtype, np.float) self.assertEqual(self.F.astype(np.float).dtype, np.float)
except ValueError as e: except ValueError:
# complex > float not yet supported # complex > float not yet supported
pass pass
...@@ -387,32 +386,3 @@ class PyfaustSimpleTest(unittest.TestCase): ...@@ -387,32 +386,3 @@ class PyfaustSimpleTest(unittest.TestCase):
self._assertAlmostEqual(eye(self.nrows, self.ncols), self._assertAlmostEqual(eye(self.nrows, self.ncols),
np.eye(self.nrows, np.eye(self.nrows,
self.ncols)) self.ncols))
def run_tests(_dev, _field):
global dev, field
dev = _dev
field = _field
suite = unittest.makeSuite(PyfaustSimpleTest, 'test')
runner = unittest.TextTestRunner()
runner.run(suite)
if __name__ == "__main__":
nargs = len(sys.argv)
if(nargs > 1):
dev = sys.argv[1]
if dev != 'cpu' and not dev.startswith('gpu'):
raise ValueError("dev argument must be cpu or gpu.")
if(nargs > 2):
field = sys.argv[2]
if field not in ['complex', 'real']:
raise ValueError("field must be complex or float")
del sys.argv[2] # deleted to avoid interfering with unittest
del sys.argv[1]
if(len(sys.argv) > 1):
# ENOTE: test only a single test if name passed on command line
singleton = unittest.TestSuite()
singleton.addTest(PyfaustSimpleTest(sys.argv[1]))
unittest.TextTestRunner().run(singleton)
else:
unittest.main()
import unittest
from pyfaust.tests.TestFaust import TestFaust
dev = 'cpu'
field = 'real'
def run_tests(_dev, _field):
global dev, field
dev = _dev
field = _field
suite = unittest.makeSuite(TestFaust, 'test')
runner = unittest.TextTestRunner()
runner.run(suite)
import unittest
from pyfaust.tests.TestFaust import TestFaust
import sys
if __name__ == "__main__":
nargs = len(sys.argv)
if(nargs > 1):
dev = sys.argv[1]
if dev != 'cpu' and not dev.startswith('gpu'):
raise ValueError("dev argument must be cpu or gpu.")
if(nargs > 2):
field = sys.argv[2]
if field not in ['complex', 'real']:
raise ValueError("field must be complex or float")
del sys.argv[2] # deleted to avoid interfering with unittest
del sys.argv[1]
if(len(sys.argv) > 1):
# ENOTE: test only a single test if name passed on command line
singleton = unittest.TestSuite()
singleton.addTest(TestFaust(sys.argv[1]))
unittest.TextTestRunner().run(singleton)
else:
unittest.main()
...@@ -33,7 +33,7 @@ setup( ...@@ -33,7 +33,7 @@ setup(
name = 'pyfaust@PYFAUST_PKG_SUFFIX@', name = 'pyfaust@PYFAUST_PKG_SUFFIX@',
version = version, # cf. header version = version, # cf. header
ext_modules = cythonize(PyFaust, compiler_directives={'language_level': sys.version_info.major }), ext_modules = cythonize(PyFaust, compiler_directives={'language_level': sys.version_info.major }),
packages = [ 'pyfaust' ], packages = [ 'pyfaust', 'pyfaust.tests' ],
url = 'https://faust.inria.fr', url = 'https://faust.inria.fr',
author = 'INRIA', author = 'INRIA',
author_email = 'remi.gribonval@inria.fr', author_email = 'remi.gribonval@inria.fr',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment