diff --git a/declearn/main/utils/__init__.py b/declearn/main/utils/__init__.py index 55267c31ce50beed1695c0a62e3e468fe6bb302f..b779241e131c8e46c9999f846c4f3b1f0af0b2bb 100644 --- a/declearn/main/utils/__init__.py +++ b/declearn/main/utils/__init__.py @@ -6,4 +6,4 @@ from ._checkpoint import Checkpointer from ._constraints import Constraint, ConstraintSet, TimeoutConstraint from ._data_info import AggregationError, aggregate_clients_data_info from ._dataclasses import EvaluateConfig, RegisterConfig, TrainingConfig -from ._early_stop import EarlyStopping +from ._early_stop import EarlyStopping, EarlyStopConfig diff --git a/declearn/main/utils/_early_stop.py b/declearn/main/utils/_early_stop.py index 32e049f29f37c1e6d3d40f3c2334aba3488e194c..86c15346c1750f6a7d6df01b62f06fa68e8279ef 100644 --- a/declearn/main/utils/_early_stop.py +++ b/declearn/main/utils/_early_stop.py @@ -5,8 +5,12 @@ from typing import Optional +from declearn.utils import dataclass_from_init + + __all__ = [ "EarlyStopping", + "EarlyStopConfig", ] @@ -91,3 +95,6 @@ class EarlyStopping: else: self._n_iter_stuck = 0 return self.keep_training + + +EarlyStopConfig = dataclass_from_init(EarlyStopping, name="EarlyStopConfig")