From 8b67d07294745d5aae37b96a0ca1fa53563d33a1 Mon Sep 17 00:00:00 2001
From: manuelpett <manuel.petit@inria.fr>
Date: Wed, 23 Oct 2024 15:19:14 +0200
Subject: [PATCH] Make the use of VariableDirichlet more flexible. Add tests
 for ConstantDirichlet and VariableDirichlet.

Extended the test suite to include tests for newly added `ConstantDirichlet` and `VariableDirichlet` classes in `test_dirichlet.py`. Also refactored the `VariableDirichlet` class to utilize a new utility function for expression creation and pre-formatting.
---
 src/bvpy/boundary_conditions/dirichlet.py | 31 ++++++----
 test/test_dirichlet.py                    | 75 ++++++++++++++++++++---
 2 files changed, 85 insertions(+), 21 deletions(-)

diff --git a/src/bvpy/boundary_conditions/dirichlet.py b/src/bvpy/boundary_conditions/dirichlet.py
index 7255301..b1d018d 100644
--- a/src/bvpy/boundary_conditions/dirichlet.py
+++ b/src/bvpy/boundary_conditions/dirichlet.py
@@ -23,16 +23,10 @@ import fenics as fe
 from bvpy import logger
 
 import numpy as np
-from sympy import Symbol
-from sympy.parsing.sympy_parser import parse_expr
 
 from .boundary import Boundary
 from bvpy.domains.geometry import boundary_normal
-
-_XYZ = {'x': Symbol('x[0]'),
-        'y': Symbol('x[1]'),
-        'z': Symbol('x[2]')}
-
+from bvpy.utils.pre_processing import create_expression, _stringify_expression_from_string
 
 class ConstantDirichlet(object):
     """Defines a constant Dirichlet condition on a domain.
@@ -121,10 +115,21 @@ class ConstantDirichlet(object):
 class VariableDirichlet(object):
     """Defines a variable Dirichlet condition on a domain.
 
+    The `val` argument allows to pass a string expression that can be :
+
+        val = "sin(x) + cos(y)"    # scalar
+        val = "(x, y * pow(x,2))"  # vector
+
+        Tensor expressions of rank 2 (matrices) may also be created:
+
+        val = "((exp(x), sin(y)),
+                (sin(x), tan(y))"
+
+    See <cmath> for the available functions. https://cplusplus.com/reference/cmath/
     """
 
     def __init__(self, val, boundary="all",
-                 method='topological', subspace=None, degree=None):
+                 method='topological', subspace=None, degree=None, **kwargs):
         """Generator of the VariableDirichlet class.
 
         Parameters
@@ -161,6 +166,7 @@ class VariableDirichlet(object):
         self._method = method
         self._subspace = subspace
         self._degree = degree
+        self._extra_kwargs = kwargs
 
     def apply(self, functionSpace):
         """Applies the specified condition on the actual function space.
@@ -177,12 +183,13 @@ class VariableDirichlet(object):
             The boundary condition in a fenics-readable format.
 
         """
-        parse = parse_expr(self._val, local_dict=_XYZ, evaluate=False)
+        # Pre-format the input value (if vector or tensor)
+        val = _stringify_expression_from_string(self._val)
+
         if self._degree is None:
-            expr = fe.Expression(str(parse),
-                                 element=functionSpace.ufl_element())
+            expr = create_expression(val, functionspace=functionSpace, **self._extra_kwargs)
         else:
-            expr = fe.Expression(str(parse), degree=self._degree)
+            expr = create_expression(val, degree=self._degree, **self._extra_kwargs)
 
         if self._subspace is None:
             dir = fe.DirichletBC(functionSpace, expr,
diff --git a/test/test_dirichlet.py b/test/test_dirichlet.py
index d5e2ae9..8877412 100644
--- a/test/test_dirichlet.py
+++ b/test/test_dirichlet.py
@@ -1,6 +1,6 @@
 import unittest
-from bvpy.boundary_conditions.dirichlet import NormalDirichlet, ZeroDirichlet
-from bvpy.boundary_conditions import Boundary
+from bvpy.boundary_conditions.dirichlet import NormalDirichlet, ZeroDirichlet, ConstantDirichlet, VariableDirichlet
+from bvpy.boundary_conditions import Boundary, dirichlet
 import fenics as fe
 
 from numpy.testing import assert_array_equal
@@ -10,23 +10,80 @@ class TestDirichlet(unittest.TestCase):
     def setUp(self):
         """Initiates the test.
         """
+        self.mesh = fe.UnitSquareMesh(2, 2)
+        self.V = fe.VectorFunctionSpace(self.mesh, 'P', 1)
+        self.scalar_V = fe.FunctionSpace(self.mesh, 'P', 1)
 
     def tearDown(self):
         """Concludes and closes the test.
         """
 
     def test_normal(self):
-        mesh = fe.UnitSquareMesh(2, 2)
-        V = fe.VectorFunctionSpace(mesh, 'P', 1)
-        diri = NormalDirichlet(boundary='all')
         diri = NormalDirichlet(boundary=Boundary('all'))
-        self.assertEqual("<NormalDirichlet object, location: all, value: 1>",
-                         diri.__repr__())
-
-        d = diri.apply(V)
+        self.assertEqual("<NormalDirichlet object, location: all, value: 1>", diri.__repr__())
+        d = diri.apply(self.V)
         self.assertEqual(d.get_boundary_values()[5], 1)
 
     def test_zero(self):
         b = Boundary("all")
         diri = ZeroDirichlet(b, shape=3)
         assert_array_equal(diri._val, [0, 0, 0])
+
+    def test_constant(self):
+        # - vector
+        val = (1, 2)
+        boundary = Boundary('all')
+        diri = ConstantDirichlet(val, boundary=boundary, method='topological')
+        self.assertEqual(diri._val, val)
+        self.assertEqual(diri._boundary, boundary)
+        self.assertEqual(diri._method, 'topological')
+        d = diri.apply(self.V)
+        for bc_val in d.get_boundary_values().values():
+            self.assertIn(bc_val, val)
+
+        # - scalar
+        val = 2
+        boundary = Boundary('all')
+        diri = ConstantDirichlet(val, boundary=boundary)
+        self.assertEqual(diri._val, val)
+        self.assertEqual(diri._boundary, boundary)
+        d = diri.apply(self.scalar_V)
+        bc_vals = d.get_boundary_values().values()
+        assert_array_equal(list(bc_vals), [val] * len(bc_vals))
+
+        # - subspace
+        diri = ConstantDirichlet(val, boundary=boundary, subspace=0)
+        self.assertEqual(diri._val, val)
+        self.assertEqual(diri._boundary, boundary)
+        d = diri.apply(self.V)
+        bc_vals = d.get_boundary_values().values()
+        assert_array_equal(list(bc_vals), [val] * len(bc_vals))
+
+    def test_variable(self):
+        # - vector
+        val = '(x, y)'
+        boundary = Boundary('all')
+        diri = VariableDirichlet(val, boundary=boundary, degree=1)
+        self.assertEqual(diri._val, val)
+        self.assertEqual(diri._boundary, boundary)
+        d = diri.apply(self.V)
+        bc_vals = d.get_boundary_values().values()
+        self.assertTrue(all(val >= 0 and val <= 2 for val in bc_vals))
+
+        # - scalar
+        val = "x"
+        boundary = Boundary('all')
+        diri = VariableDirichlet(val, boundary=boundary, degree=1)
+        self.assertEqual(diri._val, val)
+        self.assertEqual(diri._boundary, boundary)
+        d = diri.apply(self.scalar_V)
+        bc_vals = d.get_boundary_values().values()
+        self.assertTrue(all(val >= 0 and val <= 1 for val in bc_vals))
+
+        # - subspace
+        diri = dirichlet(val, boundary=boundary, subspace=0, degree=1)
+        self.assertEqual(diri._val, val)
+        self.assertEqual(diri._boundary, boundary)
+        d = diri.apply(self.V)
+        bc_vals = d.get_boundary_values().values()
+        self.assertTrue(all(val >= 0 and val <= 1 for val in bc_vals))
-- 
GitLab