diff --git a/declearn/dataset/tensorflow/_tensorflow.py b/declearn/dataset/tensorflow/_tensorflow.py
index 070d2f76d81d3a07c6f0a32ed074576bc39f0e11..02bf602d95660088a412c7882f7ef1bfb7cee11e 100644
--- a/declearn/dataset/tensorflow/_tensorflow.py
+++ b/declearn/dataset/tensorflow/_tensorflow.py
@@ -334,7 +334,7 @@ def get_batch_function(
 
 def get_stack_function(
     batch_mode: BatchMode,
-) -> Callable[[Union[List[None], List[tf.Tensor]]], Optional[tf.Tensor],]:
+) -> Callable[[Union[List[None], List[tf.Tensor]]], Optional[tf.Tensor]]:
     """Return a function to stack sample-wise atomic elements."""
     if batch_mode == "default":
         return _stack_default
diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index 639ac398d80f36cbb6f97585511de71db510a7b1..1ad815967221aacfddde1b15c7b79a2cf526565a 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -373,7 +373,7 @@ class TorchModel(Model):
     def compute_batch_predictions(
         self,
         batch: Batch,
-    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray],]:
+    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
         inputs, y_true, s_wght = self._unpack_batch(batch)
         if y_true is None:
             raise TypeError(
diff --git a/pyproject.toml b/pyproject.toml
index d9d4a1d34ea18ce2f1f4a0383bb4b5324c401b16..90175b68688a06a2be78ce29532160bd633d4286 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -94,7 +94,7 @@ docs = [
 ]
 # test-specific dependencies
 tests = [
-    "black ~= 23.0",
+    "black ~= 24.0",
     "mypy ~= 1.0",
     "pylint ~= 3.0",
     "pytest ~= 7.4",
diff --git a/test/dataset/dataset_testbase.py b/test/dataset/dataset_testbase.py
index 7dc36b80c52fb4cf863a2de0c83f142510cfb8c9..818df3f3286c8bfcaf89224b950fca2bcfb9bdbb 100644
--- a/test/dataset/dataset_testbase.py
+++ b/test/dataset/dataset_testbase.py
@@ -27,7 +27,6 @@ from declearn.test_utils import assert_batch_equal, to_numpy
 
 
 class DatasetTestToolbox:
-
     """TestCase fixture-provider protocol."""
 
     # pylint: disable=too-few-public-methods
@@ -46,8 +45,7 @@ class DatasetTestToolbox:
 
 
 class DatasetTestSuite:
-
-    """Base tests for declearn Dataset abstract methods"""
+    """Base tests for declearn Dataset abstract methods."""
 
     def test_generate_batches_batchsize(self, toolbox: DatasetTestToolbox):
         """Test batch_size argument to test_generate_batches method"""
diff --git a/test/dataset/test_torch_dataset.py b/test/dataset/test_torch_dataset.py
index 1c8e8ce1feda176907b212166c4a3d93a9dc7a00..8af2fa4a43f83302c7faf3faff674f96260487cf 100644
--- a/test/dataset/test_torch_dataset.py
+++ b/test/dataset/test_torch_dataset.py
@@ -43,8 +43,7 @@ SEED = 0
 
 
 class CustomDataset(torch.utils.data.Dataset):
-
-    """Basic torch.utils.data.Dataset for testing purposes"""
+    """Basic torch.utils.data.Dataset for testing purposes."""
 
     def __init__(self, inputs, labels, weights) -> None:
         self.inputs = inputs
@@ -66,8 +65,7 @@ class CustomDataset(torch.utils.data.Dataset):
 
 
 class TorchDatasetTestToolbox(DatasetTestToolbox):
-
-    """Toolbox for Torch Dataset"""
+    """Toolbox for Torch Dataset."""
 
     # pylint: disable=too-few-public-methods
 
diff --git a/test/main/test_checkpoint.py b/test/main/test_checkpoint.py
index 03ee3a2c42f31cfb379bd72f3b425b5562421791..61c26f17ddb10b21eca2a796210fcd3171eb540b 100644
--- a/test/main/test_checkpoint.py
+++ b/test/main/test_checkpoint.py
@@ -89,8 +89,7 @@ def create_config_file(checkpointer: Checkpointer, type_obj: str) -> str:
 
 
 class TestCheckpointer:
-
-    """Unit tests for Checkpointer class"""
+    """Unit tests for Checkpointer class."""
 
     def test_init_default(self, tmp_path: str) -> None:
         """Test `Checkpointer.__init__` with `max_history=None`."""
diff --git a/test/metrics/test_binary_roc.py b/test/metrics/test_binary_roc.py
index c1a290743f007ee1c1dcd4ee8f65f4db18deabaa..5010f712bfed2cf8ecd9ac3c2a15fa5f48cd33c0 100644
--- a/test/metrics/test_binary_roc.py
+++ b/test/metrics/test_binary_roc.py
@@ -78,13 +78,11 @@ def test_case_fixture(
     )
 
 
-def _test_case_1d() -> (
-    Tuple[
-        Dict[str, np.ndarray],
-        Dict[str, Union[float, np.ndarray]],
-        Dict[str, Union[float, np.ndarray]],
-    ]
-):
+def _test_case_1d() -> Tuple[
+    Dict[str, np.ndarray],
+    Dict[str, Union[float, np.ndarray]],
+    Dict[str, Union[float, np.ndarray]],
+]:
     """Return a test case with 1-D samples (standard binary classif)."""
     # similar inputs as for Binary APR; pylint: disable=duplicate-code
     inputs = {
@@ -124,13 +122,11 @@ def _test_case_1d() -> (
     return inputs, states, scores
 
 
-def _test_case_2d() -> (
-    Tuple[
-        Dict[str, np.ndarray],
-        Dict[str, Union[float, np.ndarray]],
-        Dict[str, Union[float, np.ndarray]],
-    ]
-):
+def _test_case_2d() -> Tuple[
+    Dict[str, np.ndarray],
+    Dict[str, Union[float, np.ndarray]],
+    Dict[str, Union[float, np.ndarray]],
+]:
     """Return a test case with 2-D samples (multilabel binary classif)."""
     # similar inputs as for Binary APR; pylint: disable=duplicate-code
     inputs = {