From fc20c9ddf9765ef7f652fff654be4f6e37cd4b0c Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Fri, 16 Feb 2024 17:45:39 +0100 Subject: [PATCH] Update 'black' version to '2.4.*'. --- declearn/dataset/tensorflow/_tensorflow.py | 2 +- declearn/model/torch/_model.py | 2 +- pyproject.toml | 2 +- test/dataset/dataset_testbase.py | 4 +--- test/dataset/test_torch_dataset.py | 6 ++---- test/main/test_checkpoint.py | 3 +-- test/metrics/test_binary_roc.py | 24 +++++++++------------- 7 files changed, 17 insertions(+), 26 deletions(-) diff --git a/declearn/dataset/tensorflow/_tensorflow.py b/declearn/dataset/tensorflow/_tensorflow.py index 070d2f76..02bf602d 100644 --- a/declearn/dataset/tensorflow/_tensorflow.py +++ b/declearn/dataset/tensorflow/_tensorflow.py @@ -334,7 +334,7 @@ def get_batch_function( def get_stack_function( batch_mode: BatchMode, -) -> Callable[[Union[List[None], List[tf.Tensor]]], Optional[tf.Tensor],]: +) -> Callable[[Union[List[None], List[tf.Tensor]]], Optional[tf.Tensor]]: """Return a function to stack sample-wise atomic elements.""" if batch_mode == "default": return _stack_default diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py index 639ac398..1ad81596 100644 --- a/declearn/model/torch/_model.py +++ b/declearn/model/torch/_model.py @@ -373,7 +373,7 @@ class TorchModel(Model): def compute_batch_predictions( self, batch: Batch, - ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray],]: + ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: inputs, y_true, s_wght = self._unpack_batch(batch) if y_true is None: raise TypeError( diff --git a/pyproject.toml b/pyproject.toml index d9d4a1d3..90175b68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ docs = [ ] # test-specific dependencies tests = [ - "black ~= 23.0", + "black ~= 24.0", "mypy ~= 1.0", "pylint ~= 3.0", "pytest ~= 7.4", diff --git a/test/dataset/dataset_testbase.py b/test/dataset/dataset_testbase.py index 7dc36b80..818df3f3 100644 --- a/test/dataset/dataset_testbase.py +++ b/test/dataset/dataset_testbase.py @@ -27,7 +27,6 @@ from declearn.test_utils import assert_batch_equal, to_numpy class DatasetTestToolbox: - """TestCase fixture-provider protocol.""" # pylint: disable=too-few-public-methods @@ -46,8 +45,7 @@ class DatasetTestToolbox: class DatasetTestSuite: - - """Base tests for declearn Dataset abstract methods""" + """Base tests for declearn Dataset abstract methods.""" def test_generate_batches_batchsize(self, toolbox: DatasetTestToolbox): """Test batch_size argument to test_generate_batches method""" diff --git a/test/dataset/test_torch_dataset.py b/test/dataset/test_torch_dataset.py index 1c8e8ce1..8af2fa4a 100644 --- a/test/dataset/test_torch_dataset.py +++ b/test/dataset/test_torch_dataset.py @@ -43,8 +43,7 @@ SEED = 0 class CustomDataset(torch.utils.data.Dataset): - - """Basic torch.utils.data.Dataset for testing purposes""" + """Basic torch.utils.data.Dataset for testing purposes.""" def __init__(self, inputs, labels, weights) -> None: self.inputs = inputs @@ -66,8 +65,7 @@ class CustomDataset(torch.utils.data.Dataset): class TorchDatasetTestToolbox(DatasetTestToolbox): - - """Toolbox for Torch Dataset""" + """Toolbox for Torch Dataset.""" # pylint: disable=too-few-public-methods diff --git a/test/main/test_checkpoint.py b/test/main/test_checkpoint.py index 03ee3a2c..61c26f17 100644 --- a/test/main/test_checkpoint.py +++ b/test/main/test_checkpoint.py @@ -89,8 +89,7 @@ def create_config_file(checkpointer: Checkpointer, type_obj: str) -> str: class TestCheckpointer: - - """Unit tests for Checkpointer class""" + """Unit tests for Checkpointer class.""" def test_init_default(self, tmp_path: str) -> None: """Test `Checkpointer.__init__` with `max_history=None`.""" diff --git a/test/metrics/test_binary_roc.py b/test/metrics/test_binary_roc.py index c1a29074..5010f712 100644 --- a/test/metrics/test_binary_roc.py +++ b/test/metrics/test_binary_roc.py @@ -78,13 +78,11 @@ def test_case_fixture( ) -def _test_case_1d() -> ( - Tuple[ - Dict[str, np.ndarray], - Dict[str, Union[float, np.ndarray]], - Dict[str, Union[float, np.ndarray]], - ] -): +def _test_case_1d() -> Tuple[ + Dict[str, np.ndarray], + Dict[str, Union[float, np.ndarray]], + Dict[str, Union[float, np.ndarray]], +]: """Return a test case with 1-D samples (standard binary classif).""" # similar inputs as for Binary APR; pylint: disable=duplicate-code inputs = { @@ -124,13 +122,11 @@ def _test_case_1d() -> ( return inputs, states, scores -def _test_case_2d() -> ( - Tuple[ - Dict[str, np.ndarray], - Dict[str, Union[float, np.ndarray]], - Dict[str, Union[float, np.ndarray]], - ] -): +def _test_case_2d() -> Tuple[ + Dict[str, np.ndarray], + Dict[str, Union[float, np.ndarray]], + Dict[str, Union[float, np.ndarray]], +]: """Return a test case with 2-D samples (multilabel binary classif).""" # similar inputs as for Binary APR; pylint: disable=duplicate-code inputs = { -- GitLab