diff --git a/declearn/main/privacy/_dp_trainer.py b/declearn/main/privacy/_dp_trainer.py
index fb79e7837b467cbedfa9d730da08829bc288f179..60b419aa79270d8b714dc66a6eca3cb4402223e0 100644
--- a/declearn/main/privacy/_dp_trainer.py
+++ b/declearn/main/privacy/_dp_trainer.py
@@ -72,9 +72,12 @@ class DPTrainingManager(TrainingManager):
         valid_data: Optional[Dataset] = None,
         metrics: Union[MetricSet, List[MetricInputType], None] = None,
         logger: Union[logging.Logger, str, None] = None,
+        verbose: bool = True,
     ) -> None:
         # inherited signature; pylint: disable=too-many-arguments
-        super().__init__(model, optim, train_data, valid_data, metrics, logger)
+        super().__init__(
+            model, optim, train_data, valid_data, metrics, logger, verbose
+        )
         # Add DP-related fields: accountant, clipping norm and budget.
         self.accountant = None  # type: Optional[IAccountant]
         self.sclip_norm = None  # type: Optional[float]
diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py
index 901d953868e64b48a84d30a945c59d680862edb8..92ec0e80e603320765e8e2c6e3a29d2970b728a5 100644
--- a/declearn/main/utils/_training.py
+++ b/declearn/main/utils/_training.py
@@ -21,6 +21,7 @@ import logging
 from typing import Any, ClassVar, Dict, List, Optional, Union
 
 import numpy as np
+import tqdm
 
 from declearn.communication import messaging
 from declearn.dataset import Dataset
@@ -51,6 +52,7 @@ class TrainingManager:
         valid_data: Optional[Dataset] = None,
         metrics: Union[MetricSet, List[MetricInputType], None] = None,
         logger: Union[logging.Logger, str, None] = None,
+        verbose: bool = True,
     ) -> None:
         """Instantiate the client-side training and evaluation process.
 
@@ -74,6 +76,9 @@ class TrainingManager:
             Logger to use, or name of a logger to set up with
             `declearn.utils.get_logger`.
             If None, use `type(self).__name__`.
+        verbose: bool, default=True
+            Whether to display progress bars when running training
+            and validation rounds.
         """
         # arguments serve modularity; pylint: disable=too-many-arguments
         self.model = model
@@ -84,6 +89,7 @@ class TrainingManager:
         if not isinstance(logger, logging.Logger):
             logger = get_logger(logger or f"{type(self).__name__}")
         self.logger = logger
+        self.verbose = verbose
 
     def _prepare_metrics(
         self,
@@ -224,6 +230,8 @@ class TrainingManager:
         )
         # Run batch train steps for as long as constraints allow it.
         stop_training = False
+        if self.verbose:
+            progress_bar = tqdm.tqdm(desc="Training round", unit=" steps")
         while not (stop_training or epochs.saturated):
             for batch in self.train_data.generate_batches(**batch_cfg):
                 try:
@@ -232,6 +240,8 @@ class TrainingManager:
                     self.logger.warning("Interrupting training round: %s", exc)
                     stop_training = True
                     break
+                if self.verbose:
+                    progress_bar.update()
                 constraints.increment()
                 if constraints.saturated:
                     stop_training = True
@@ -334,9 +344,13 @@ class TrainingManager:
         self.metrics.reset()
         # Run batch evaluation steps for as long as constraints allow it.
         dataset = self.valid_data or self.train_data
+        if self.verbose:
+            progress_bar = tqdm.tqdm(desc="Evaluation round", unit=" batches")
         for batch in dataset.generate_batches(**batch_cfg):
             inputs = self.model.compute_batch_predictions(batch)
             self.metrics.update(*inputs)
+            if self.verbose:
+                progress_bar.update()
             constraints.increment()
             if constraints.saturated:
                 break
diff --git a/pyproject.toml b/pyproject.toml
index 131e9ff8e927a08496c02f92985cb5b8219a51a3..1a3f87a1d16bce78f91d72d61768407cc0008c14 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
     "requests ~= 2.18",
     "scikit-learn >= 1.0",
     "tomli >= 2.0 ; python_version < '3.11'",
+    "tqdm ~= 4.62",
     "typing_extensions >= 4.0",
 ]
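
For context, a minimal usage sketch of the new flag, assuming `TrainingManager` is re-exported from `declearn.main.utils` (as the file path suggests) and using placeholder `model`, `optim` and `train_data` objects that are not part of this patch:

    from declearn.main.utils import TrainingManager

    # verbose defaults to True, so existing callers keep their progress bars;
    # pass verbose=False to silence them, e.g. when logging to a file.
    manager = TrainingManager(model, optim, train_data, verbose=False)

Since `DPTrainingManager.__init__` forwards the argument to `super().__init__`, the same keyword applies to the differential-privacy variant.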