Mentions légales du service

Skip to content
Snippets Groups Projects
Commit c20f8ffd authored by hhakim's avatar hhakim
Browse files

Add pyfaust.lazylinop.vstack/hstack and their unit tests.

parent e884e3cc
No related branches found
No related tags found
No related merge requests found
...@@ -540,3 +540,29 @@ def asLazyLinearOp(obj): ...@@ -540,3 +540,29 @@ def asLazyLinearOp(obj):
Creates a LazyLinearOp based on the object obj which must be of a linear operator compatible type. Creates a LazyLinearOp based on the object obj which must be of a linear operator compatible type.
""" """
return LazyLinearOp.create(obj) return LazyLinearOp.create(obj)
def hstack(lop1, obj):
"""
Concatenates lop1 and obj horizontally.
Args:
lop1: a LazyLinearOp object.
obj: any array / matrix / LazyLinearOp compatible in dimensions.
"""
if isLazyLinearOp(lop1):
return lop1.concatenate(obj, axis=1)
else:
raise TypeError('lop1 must be a LazyLinearOp')
def vstack(lop1, obj):
"""
Concatenates lop1 and obj vertically.
Args:
lop1: a LazyLinearOp object.
obj: any array / matrix / LazyLinearOp compatible in dimensions.
"""
if isLazyLinearOp(lop1):
return lop1.concatenate(obj, axis=0)
else:
raise TypeError('lop1 must be a LazyLinearOp')
import unittest import unittest
import pyfaust as pf import pyfaust as pf
from pyfaust.lazylinop import LazyLinearOp from pyfaust.lazylinop import LazyLinearOp, vstack, hstack
import numpy.linalg as LA import numpy.linalg as LA
import numpy as np import numpy as np
...@@ -152,6 +152,18 @@ class TestLazyLinearOpFaust(unittest.TestCase): ...@@ -152,6 +152,18 @@ class TestLazyLinearOpFaust(unittest.TestCase):
self.lopA))), self.lopA))),
0) 0)
self.assertEqual(lcat.shape[0], self.lop.shape[0] + self.lop.shape[0]) self.assertEqual(lcat.shape[0], self.lop.shape[0] + self.lop.shape[0])
# using hstack and vstack
lcat = vstack(self.lop, self.lop2)
self.assertAlmostEqual(LA.norm(lcat.toarray() - np.vstack((self.lopA,
self.lop2A))),
0)
self.assertEqual(lcat.shape[0], self.lop.shape[0] + self.lop2.shape[0])
lcat = hstack(self.lop, self.lop2)
self.assertAlmostEqual(LA.norm(lcat.toarray() - np.hstack((self.lopA,
self.lop2A))),
0)
self.assertEqual(lcat.shape[1], self.lop.shape[1] + self.lop2.shape[1])
def test_chain_ops(self): def test_chain_ops(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment