diff --git a/declearn/dataset/tensorflow/_tensorflow.py b/declearn/dataset/tensorflow/_tensorflow.py index 070d2f76d81d3a07c6f0a32ed074576bc39f0e11..02bf602d95660088a412c7882f7ef1bfb7cee11e 100644 --- a/declearn/dataset/tensorflow/_tensorflow.py +++ b/declearn/dataset/tensorflow/_tensorflow.py @@ -334,7 +334,7 @@ def get_batch_function( def get_stack_function( batch_mode: BatchMode, -) -> Callable[[Union[List[None], List[tf.Tensor]]], Optional[tf.Tensor],]: +) -> Callable[[Union[List[None], List[tf.Tensor]]], Optional[tf.Tensor]]: """Return a function to stack sample-wise atomic elements.""" if batch_mode == "default": return _stack_default diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py index 639ac398d80f36cbb6f97585511de71db510a7b1..1ad815967221aacfddde1b15c7b79a2cf526565a 100644 --- a/declearn/model/torch/_model.py +++ b/declearn/model/torch/_model.py @@ -373,7 +373,7 @@ class TorchModel(Model): def compute_batch_predictions( self, batch: Batch, - ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray],]: + ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: inputs, y_true, s_wght = self._unpack_batch(batch) if y_true is None: raise TypeError( diff --git a/pyproject.toml b/pyproject.toml index d9d4a1d34ea18ce2f1f4a0383bb4b5324c401b16..90175b68688a06a2be78ce29532160bd633d4286 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ docs = [ ] # test-specific dependencies tests = [ - "black ~= 23.0", + "black ~= 24.0", "mypy ~= 1.0", "pylint ~= 3.0", "pytest ~= 7.4", diff --git a/test/dataset/dataset_testbase.py b/test/dataset/dataset_testbase.py index 7dc36b80c52fb4cf863a2de0c83f142510cfb8c9..818df3f3286c8bfcaf89224b950fca2bcfb9bdbb 100644 --- a/test/dataset/dataset_testbase.py +++ b/test/dataset/dataset_testbase.py @@ -27,7 +27,6 @@ from declearn.test_utils import assert_batch_equal, to_numpy class DatasetTestToolbox: - """TestCase fixture-provider protocol.""" # pylint: disable=too-few-public-methods @@ -46,8 +45,7 @@ class DatasetTestToolbox: class DatasetTestSuite: - - """Base tests for declearn Dataset abstract methods""" + """Base tests for declearn Dataset abstract methods.""" def test_generate_batches_batchsize(self, toolbox: DatasetTestToolbox): """Test batch_size argument to test_generate_batches method""" diff --git a/test/dataset/test_torch_dataset.py b/test/dataset/test_torch_dataset.py index 1c8e8ce1feda176907b212166c4a3d93a9dc7a00..8af2fa4a43f83302c7faf3faff674f96260487cf 100644 --- a/test/dataset/test_torch_dataset.py +++ b/test/dataset/test_torch_dataset.py @@ -43,8 +43,7 @@ SEED = 0 class CustomDataset(torch.utils.data.Dataset): - - """Basic torch.utils.data.Dataset for testing purposes""" + """Basic torch.utils.data.Dataset for testing purposes.""" def __init__(self, inputs, labels, weights) -> None: self.inputs = inputs @@ -66,8 +65,7 @@ class CustomDataset(torch.utils.data.Dataset): class TorchDatasetTestToolbox(DatasetTestToolbox): - - """Toolbox for Torch Dataset""" + """Toolbox for Torch Dataset.""" # pylint: disable=too-few-public-methods diff --git a/test/main/test_checkpoint.py b/test/main/test_checkpoint.py index 03ee3a2c42f31cfb379bd72f3b425b5562421791..61c26f17ddb10b21eca2a796210fcd3171eb540b 100644 --- a/test/main/test_checkpoint.py +++ b/test/main/test_checkpoint.py @@ -89,8 +89,7 @@ def create_config_file(checkpointer: Checkpointer, type_obj: str) -> str: class TestCheckpointer: - - """Unit tests for Checkpointer class""" + """Unit tests for Checkpointer class.""" def test_init_default(self, tmp_path: str) -> None: """Test `Checkpointer.__init__` with `max_history=None`.""" diff --git a/test/metrics/test_binary_roc.py b/test/metrics/test_binary_roc.py index c1a290743f007ee1c1dcd4ee8f65f4db18deabaa..5010f712bfed2cf8ecd9ac3c2a15fa5f48cd33c0 100644 --- a/test/metrics/test_binary_roc.py +++ b/test/metrics/test_binary_roc.py @@ -78,13 +78,11 @@ def test_case_fixture( ) -def _test_case_1d() -> ( - Tuple[ - Dict[str, np.ndarray], - Dict[str, Union[float, np.ndarray]], - Dict[str, Union[float, np.ndarray]], - ] -): +def _test_case_1d() -> Tuple[ + Dict[str, np.ndarray], + Dict[str, Union[float, np.ndarray]], + Dict[str, Union[float, np.ndarray]], +]: """Return a test case with 1-D samples (standard binary classif).""" # similar inputs as for Binary APR; pylint: disable=duplicate-code inputs = { @@ -124,13 +122,11 @@ def _test_case_1d() -> ( return inputs, states, scores -def _test_case_2d() -> ( - Tuple[ - Dict[str, np.ndarray], - Dict[str, Union[float, np.ndarray]], - Dict[str, Union[float, np.ndarray]], - ] -): +def _test_case_2d() -> Tuple[ + Dict[str, np.ndarray], + Dict[str, Union[float, np.ndarray]], + Dict[str, Union[float, np.ndarray]], +]: """Return a test case with 2-D samples (multilabel binary classif).""" # similar inputs as for Binary APR; pylint: disable=duplicate-code inputs = {