From 480a5995f21651c7e4b1f7bc5cd00696c7160e82 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Wed, 30 Aug 2023 15:01:11 +0200
Subject: [PATCH] Add unit tests for 'EarlyStopping'.

---
 test/main/test_early_stopping.py | 108 +++++++++++++++++++++++++++++++
 1 file changed, 108 insertions(+)
 create mode 100644 test/main/test_early_stopping.py

diff --git a/test/main/test_early_stopping.py b/test/main/test_early_stopping.py
new file mode 100644
index 00000000..764b59ff
--- /dev/null
+++ b/test/main/test_early_stopping.py
@@ -0,0 +1,108 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for 'declearn.main.utils.EarlyStopping'."""
+
+
+from declearn.main.utils import EarlyStopping
+
+
+class TestEarlyStopping:
+    """Unit tests for 'declearn.main.utils.EarlyStopping'."""
+
+    def test_keep_training_initial(self) -> None:
+        """Test that a brand new EarlyStopping indicates to train."""
+        early_stop = EarlyStopping()
+        assert early_stop.keep_training
+
+    def test_update_first(self) -> None:
+        """Test that an instantiated EarlyStopping's update works."""
+        early_stop = EarlyStopping()
+        keep_going = early_stop.update(1.0)
+        assert keep_going
+        assert keep_going == early_stop.keep_training
+
+    def test_update_twice(self) -> None:
+        """Test that an EarlyStopping can be reached in a simple case."""
+        early_stop = EarlyStopping(tolerance=0.0, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert not early_stop.update(1.0)
+        assert not early_stop.keep_training
+
+    def test_reset_after_stopping(self) -> None:
+        """Test that 'EarlyStopping.reset()' works properly."""
+        # Reach the criterion once.
+        early_stop = EarlyStopping(tolerance=0.0, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert not early_stop.update(1.0)
+        assert not early_stop.keep_training
+        # Reset and test that the criterion has been properly reset.
+        early_stop.reset()
+        assert early_stop.keep_training
+        assert early_stop.update(1.0)
+        # Reach the criterion for the second time.
+        assert not early_stop.update(1.0)
+        assert not early_stop.keep_training
+
+    def test_with_two_steps_patience(self) -> None:
+        """Test an EarlyStopping criterion with 2-steps patience."""
+        early_stop = EarlyStopping(tolerance=0.0, patience=2, decrease=True)
+        assert early_stop.update(1.0)
+        assert early_stop.update(1.5)  # patience tempers stopping
+        assert early_stop.update(0.0)  # patience is reset
+        assert early_stop.update(0.5)  # patience tempers stopping
+        assert not early_stop.update(0.2)  # patience is exhausted
+
+    def test_with_absolute_tolerance_positive(self) -> None:
+        """Test an EarlyStopping criterion with 0.2 absolute tolerance."""
+        early_stop = EarlyStopping(tolerance=0.2, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert early_stop.update(0.7)
+        assert not early_stop.update(0.6)  # progress below tolerance
+
+    def test_with_absolute_tolerance_negative(self) -> None:
+        """Test an EarlyStopping criterion with -0.5 absolute tolerance."""
+        early_stop = EarlyStopping(tolerance=-0.5, patience=1, decrease=True)
+        assert early_stop.update(1.0)
+        assert early_stop.update(1.2)  # regression below tolerance
+        assert not early_stop.update(1.6)  # regression above tolerance
+
+    def test_with_relative_tolerance_positive(self) -> None:
+        """Test an EarlyStopping criterion with 0.1 relative tolerance."""
+        early_stop = EarlyStopping(
+            tolerance=0.1, patience=1, decrease=True, relative=True
+        )
+        assert early_stop.update(1.0)
+        assert early_stop.update(0.8)  # progress above tolerance
+        assert not early_stop.update(0.75)  # progress below tolerance
+
+    def test_with_relative_tolerance_negative(self) -> None:
+        """Test an EarlyStopping criterion with -0.1 relative tolerance."""
+        early_stop = EarlyStopping(
+            tolerance=-0.1, patience=1, decrease=True, relative=True
+        )
+        assert early_stop.update(1.0)
+        assert early_stop.update(0.80)  # progress
+        assert early_stop.update(0.85)  # regression below tolerance
+        assert not early_stop.update(0.89)  # regression above tolerance
+
+    def test_with_increasing_metric(self) -> None:
+        """Test an EarlyStopping that monitors an increasing metric."""
+        early_stop = EarlyStopping(tolerance=0.0, patience=1, decrease=False)
+        assert early_stop.update(1.0)
+        assert early_stop.update(2.0)  # progress
+        assert not early_stop.update(1.5)  # regression (no patience/tolerance)
-- 
GitLab