From e084fd36edf0b5948cd5bd07bcec511bad57b914 Mon Sep 17 00:00:00 2001
From: manuelpett <manuel.petit@inria.fr>
Date: Mon, 25 Mar 2024 17:46:05 +0100
Subject: [PATCH] Add tests for visu_pyvista & fix small error in
 _visu_feDiriclet (2D data need to be convert in 3D)

---
 src/bvpy/utils/visu_pyvista.py |  2 +-
 test/test_visu_pyvista.py      | 93 ++++++++++++++++++++++++++++++++++
 2 files changed, 94 insertions(+), 1 deletion(-)
 create mode 100644 test/test_visu_pyvista.py

diff --git a/src/bvpy/utils/visu_pyvista.py b/src/bvpy/utils/visu_pyvista.py
index 4b76200..c5775ba 100644
--- a/src/bvpy/utils/visu_pyvista.py
+++ b/src/bvpy/utils/visu_pyvista.py
@@ -340,7 +340,7 @@ def _visu_feDiriclet(diri, val_range = 'auto', cmap: str = 'inferno', plotter: p
             coords_diri = np.array([dofs[i] for i in values.keys()])
             scalars = np.array(list(values.values()))
 
-            return pv.PointSet(coords_diri), scalars
+            return pv.PointSet(_check_is3D(coords_diri)), scalars
 
         elif (L > 1) and (L < 4):  # vector values
             # - reorganize the values
diff --git a/test/test_visu_pyvista.py b/test/test_visu_pyvista.py
new file mode 100644
index 0000000..de0c223
--- /dev/null
+++ b/test/test_visu_pyvista.py
@@ -0,0 +1,93 @@
+import unittest
+import shutil
+# -- test class
+
+from bvpy.utils.visu_pyvista import (_check_is3D, _check_scalar_bar_title, _visu_feDiriclet, _visu_feFunction,
+                                     _visu_feFunctionSizet, _visu_feMesh, visualize)
+
+import pyvista as pv
+import fenics as fe
+import numpy as np
+class TestVisu(unittest.TestCase):
+    def setUp(self):
+        """Initiates the test.
+        """
+        self.mesh = fe.UnitSquareMesh(2, 2)
+        self.V = fe.FunctionSpace(self.mesh, 'P', 1)
+        self.f = fe.project(fe.Constant(0), self.V)
+        self.bc = fe.DirichletBC(self.V, fe.Constant(0), "on_boundary")
+        self.mf = fe.MeshFunction("size_t", self.mesh, 2)
+
+    def tearDown(self):
+        """Concludes and closes the test.
+        """
+        pass
+
+    def test_is3D(self):
+        self.assertEqual(_check_is3D(np.ones((10, 1))).shape[1], 3)  # check if 1D data are converted in 3D data
+        self.assertEqual(_check_is3D(np.ones((10, 2))).shape[1], 3) # check if 2D data are converted in 3D data
+        self.assertEqual(_check_is3D(np.ones((10, 3))).shape[1], 3)  # check if 3D data re unchanged
+        self.assertRaises(AssertionError, _check_is3D, np.ones((10, 4)))  # wrong input if 4D
+
+    def test_check_scalar_bar_title(self):
+        pl = pv.Plotter(off_screen=True)
+        pl = _visu_feFunction(self.f, plotter=pl, show_plot=False)
+        scalar_bar_name = list(pl.scalar_bars.keys())[0]
+        self.assertNotEqual(scalar_bar_name, _check_scalar_bar_title(pl, scalar_bar_name)) # generate unique name ?
+        pl.close()
+
+    def test_visu_mesh(self):
+        # - use dedicated function
+        pl = pv.Plotter(off_screen=True)
+        self.assertIsNotNone(_visu_feMesh(self.mesh, plotter=pl, show_plot=False))
+        pl.close()
+
+        # - use generic function
+        pl = pv.Plotter(off_screen=True)
+        self.assertIsNotNone(visualize(self.mesh, plotter=pl, show_plot=False))
+        pl.close()
+
+    def test_visu_function(self):
+        # - use dedicated function
+        pl = pv.Plotter(off_screen=True)
+        self.assertIsNotNone(_visu_feFunction(self.f, plotter=pl, scale=1, show_plot=False))
+        pl.close()
+
+        # - use generic function
+        pl = pv.Plotter(off_screen=True)
+        self.assertIsNotNone(visualize(self.f, plotter=pl, scale=1, show_plot=False))
+        pl.close()
+    def test_visu_dirichlet(self):
+        # - use dedicated function
+        pl = pv.Plotter(off_screen=True)
+        self.assertIsNotNone(_visu_feDiriclet(self.bc, plotter=pl, show_plot=False))
+        pl.close()
+
+        # - use generic function
+        pl = pv.Plotter(off_screen=True)
+        self.assertIsNotNone(visualize(self.bc, plotter=pl, show_plot=False))
+        pl.close()
+    def test_visu_meshfunc(self):
+        # - use dedicated function
+        pl = pv.Plotter(off_screen=True)
+        self.assertIsNotNone(_visu_feFunctionSizet(self.mf, plotter=pl, show_plot=False, cmap="inferno"))
+        pl.close()
+
+        # - use generic function
+        pl = pv.Plotter(off_screen=True)
+        self.assertIsNotNone(visualize(self.mf, plotter=pl, show_plot=False, cmap="inferno"))
+        pl.close()
+    def test_multiplot(self):
+        # - use generic function
+        pl = pv.Plotter(off_screen=True, shape=(1, 4))
+        pl.subplot(0, 0)
+        visualize(self.mesh, plotter=pl, show_plot=False)
+        pl.subplot(0, 1)
+        visualize(self.bc, plotter=pl, show_plot=False)
+        pl.subplot(0, 2)
+        visualize(self.mf, plotter=pl, show_plot=False, cmap="inferno")
+        pl.subplot(0, 3)
+        visualize(self.f, plotter=pl, scale=1, show_plot=False)
+        self.assertIsNotNone(pl)
+        pl.close()
+
-- 
GitLab