From 21f9b0d96efcfa2f0286bc96db56aeb61bf93d80 Mon Sep 17 00:00:00 2001 From: BIGAUD Nathan <nathan.bigaud@inria.fr> Date: Wed, 1 Feb 2023 17:29:25 +0100 Subject: [PATCH] Add ClassVar type-hints. Details : * In `/optimizer`, `/metrics`, and `/aggregator` :`name` and `aux_name` * In `/communications` : `protocol` * In `/dataset` : `_type_key` * In `/data_info` : `field`, `type`, and `doc` Excluding all classes decorated with `@dataclasses.dataclass`, see https://stackoverflow.com/a/52099457 --- declearn/aggregator/_api.py | 6 ++--- declearn/aggregator/_base.py | 8 +++--- declearn/aggregator/_gma.py | 8 +++--- declearn/communication/api/_client.py | 8 ++---- declearn/communication/api/_server.py | 4 +-- declearn/communication/grpc/_client.py | 7 +++-- declearn/communication/websockets/_client.py | 4 +-- declearn/data_info/_base.py | 9 ++++--- declearn/data_info/_fields.py | 27 ++++++++++---------- declearn/dataset/_base.py | 5 ++-- declearn/dataset/_inmemory.py | 5 ++-- declearn/main/utils/_training.py | 4 +-- declearn/metrics/_api.py | 2 +- declearn/metrics/_classif.py | 6 ++--- declearn/metrics/_mean.py | 6 ++--- declearn/metrics/_roc_auc.py | 5 ++-- declearn/optimizer/modules/_adaptive.py | 11 ++++---- declearn/optimizer/modules/_api.py | 7 +++-- declearn/optimizer/modules/_clipping.py | 6 ++--- declearn/optimizer/modules/_momentum.py | 8 +++--- declearn/optimizer/modules/_noise.py | 6 ++--- declearn/optimizer/modules/_scaffold.py | 12 ++++----- declearn/optimizer/regularizers/_api.py | 5 ++-- declearn/optimizer/regularizers/_base.py | 9 +++---- examples/adding_rmsprop/readme.md | 2 +- test/optimizer/test_optimizer.py | 6 ++--- 26 files changed, 82 insertions(+), 104 deletions(-) diff --git a/declearn/aggregator/_api.py b/declearn/aggregator/_api.py index e2cd5df4..9e3fa75e 100644 --- a/declearn/aggregator/_api.py +++ b/declearn/aggregator/_api.py @@ -3,13 +3,11 @@ """Model updates aggregation API.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict - +from typing import Any, ClassVar, Dict from declearn.model.api import Vector from declearn.utils import create_types_registry, register_type - __all__ = [ "Aggregator", ] @@ -59,7 +57,7 @@ class Aggregator(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: str = NotImplemented + name: ClassVar[str] = NotImplemented def __init_subclass__( cls, diff --git a/declearn/aggregator/_base.py b/declearn/aggregator/_base.py index e1b64f6a..419650ad 100644 --- a/declearn/aggregator/_base.py +++ b/declearn/aggregator/_base.py @@ -2,12 +2,10 @@ """FedAvg-like mean-aggregation class.""" -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional - -from declearn.model.api import Vector from declearn.aggregator._api import Aggregator - +from declearn.model.api import Vector __all__ = [ "AveragingAggregator", @@ -24,7 +22,7 @@ class AveragingAggregator(Aggregator): that use simple weighting schemes. """ - name = "averaging" + name: ClassVar[str] = "averaging" def __init__( self, diff --git a/declearn/aggregator/_gma.py b/declearn/aggregator/_gma.py index 98363e17..232fb54e 100644 --- a/declearn/aggregator/_gma.py +++ b/declearn/aggregator/_gma.py @@ -2,12 +2,10 @@ """Gradient Masked Averaging aggregation class.""" -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional - -from declearn.model.api import Vector from declearn.aggregator._base import AveragingAggregator - +from declearn.model.api import Vector __all__ = [ "GradientMaskedAveraging", @@ -42,7 +40,7 @@ class GradientMaskedAveraging(AveragingAggregator): https://arxiv.org/abs/2201.11986 """ - name = "gradient-masked-averaging" + name: ClassVar[str] = "gradient-masked-averaging" def __init__( self, diff --git a/declearn/communication/api/_client.py b/declearn/communication/api/_client.py index e0383436..926b5d2f 100644 --- a/declearn/communication/api/_client.py +++ b/declearn/communication/api/_client.py @@ -5,8 +5,7 @@ import logging import types from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional, Type, Union - +from typing import Any, ClassVar, Dict, Optional, Type, Union from declearn.communication.messaging import ( Empty, @@ -18,7 +17,6 @@ from declearn.communication.messaging import ( ) from declearn.utils import create_types_registry, get_logger, register_type - __all__ = [ "NetworkClient", ] @@ -58,7 +56,7 @@ class NetworkClient(metaclass=ABCMeta): probably be rejected by the server if the client has not registered. """ - protocol: str = NotImplemented + protocol: ClassVar[str] = NotImplemented def __init_subclass__(cls, register: bool = True) -> None: """Automate the type-registration of NetworkClient subclasses.""" @@ -107,7 +105,6 @@ class NetworkClient(metaclass=ABCMeta): The return type is communication-protocol dependent. """ - return NotImplemented # similar to NetworkServer API; pylint: disable=duplicate-code @@ -118,7 +115,6 @@ class NetworkClient(metaclass=ABCMeta): Note: this method can be called safely even if the client is already running (simply having no effect). """ - return None @abstractmethod async def stop(self) -> None: diff --git a/declearn/communication/api/_server.py b/declearn/communication/api/_server.py index 1194ee30..1f12c1ca 100644 --- a/declearn/communication/api/_server.py +++ b/declearn/communication/api/_server.py @@ -6,7 +6,7 @@ import asyncio import logging import types from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional, Set, Type, Union +from typing import Any, Dict, Optional, Set, Type, Union, ClassVar from declearn.communication.api._service import MessagesHandler @@ -54,7 +54,7 @@ class NetworkServer(metaclass=ABCMeta): of the awaitable `wait_for_clients` method. """ - protocol: str = NotImplemented + protocol: ClassVar[str] = NotImplemented def __init_subclass__(cls, register: bool = True) -> None: """Automate the type-registration of NetworkServer subclasses.""" diff --git a/declearn/communication/grpc/_client.py b/declearn/communication/grpc/_client.py index 941d7f76..468ce0d8 100644 --- a/declearn/communication/grpc/_client.py +++ b/declearn/communication/grpc/_client.py @@ -3,17 +3,16 @@ """Client-side communication endpoint implementation using gRPC""" import logging -from typing import Any, Dict, Optional, Union +from typing import Any, ClassVar, Dict, Optional, Union import grpc # type: ignore from declearn.communication.api import NetworkClient -from declearn.communication.messaging import Message, parse_message_from_string from declearn.communication.grpc.protobufs import message_pb2 from declearn.communication.grpc.protobufs.message_pb2_grpc import ( MessageBoardStub, ) - +from declearn.communication.messaging import Message, parse_message_from_string __all__ = [ "GrpcClient", @@ -26,7 +25,7 @@ CHUNK_LENGTH = 2**22 - 50 # 2**22 - sys.getsizeof("") - 1 class GrpcClient(NetworkClient): """Client-side communication endpoint using gRPC.""" - protocol = "grpc" + protocol: ClassVar[str] = "grpc" def __init__( self, diff --git a/declearn/communication/websockets/_client.py b/declearn/communication/websockets/_client.py index 59627c50..f8774f46 100644 --- a/declearn/communication/websockets/_client.py +++ b/declearn/communication/websockets/_client.py @@ -5,7 +5,7 @@ import asyncio import logging import ssl -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, ClassVar import websockets as ws from websockets.client import WebSocketClientProtocol @@ -25,7 +25,7 @@ CHUNK_LENGTH = 100000 class WebsocketsClient(NetworkClient): """Client-side communication endpoint using WebSockets.""" - protocol = "websockets" + protocol: ClassVar[str] = "websockets" def __init__( self, diff --git a/declearn/data_info/_base.py b/declearn/data_info/_base.py index 99bc1f68..0f0fb9b4 100644 --- a/declearn/data_info/_base.py +++ b/declearn/data_info/_base.py @@ -34,7 +34,7 @@ data_info fields, are implemented (although unexposed) here. import warnings from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple, Type +from typing import Any, Dict, List, Optional, Set, Tuple, Type, ClassVar __all__ = [ @@ -75,9 +75,9 @@ class DataInfoField(metaclass=ABCMeta): is called, run `is_valid` on each and every input. """ - field: str - types: Tuple[Type, ...] - doc: str + field: ClassVar[str] = NotImplemented + types: ClassVar[Tuple[Type, ...]] = NotImplemented + doc: ClassVar[str] = NotImplemented @classmethod def is_valid( @@ -85,6 +85,7 @@ class DataInfoField(metaclass=ABCMeta): value: Any, ) -> bool: """Check that a given value may belong to this field.""" + # false-pos; pylint: disable=isinstance-second-argument-not-valid-type return isinstance(value, cls.types) @classmethod diff --git a/declearn/data_info/_fields.py b/declearn/data_info/_fields.py index 7d456646..d84e9e08 100644 --- a/declearn/data_info/_fields.py +++ b/declearn/data_info/_fields.py @@ -2,13 +2,12 @@ """DataInfoField subclasses specifying common 'data_info' metadata fields.""" -from typing import Any, List, Optional, Set +from typing import Any, ClassVar, List, Optional, Set, Tuple, Type import numpy as np from declearn.data_info._base import DataInfoField, register_data_info_field - __all__ = [ "ClassesField", "InputShapeField", @@ -21,9 +20,9 @@ __all__ = [ class ClassesField(DataInfoField): """Specifications for 'classes' data_info field.""" - field = "classes" - types = (list, set, tuple, np.ndarray) - doc = "Set of classification targets, combined by union." + field: ClassVar[str] = "classes" + types: ClassVar[Tuple[Type, ...]] = (list, set, tuple, np.ndarray) + doc: ClassVar[str] = "Set of classification targets, combined by union." @classmethod def is_valid( @@ -47,9 +46,9 @@ class ClassesField(DataInfoField): class InputShapeField(DataInfoField): """Specifications for 'input_shape' data_info field.""" - field = "input_shape" - types = (tuple, list) - doc = "Input features' batched shape, checked to be equal." + field: ClassVar[str] = "input_shape" + types: ClassVar[Tuple[Type, ...]] = (tuple, list) + doc: ClassVar[str] = "Input features' batched shape, checked to be equal." @classmethod def is_valid( @@ -96,9 +95,9 @@ class InputShapeField(DataInfoField): class NbFeaturesField(DataInfoField): """Specifications for 'n_features' data_info field.""" - field = "n_features" - types = (int,) - doc = "Number of input features, checked to be equal." + field: ClassVar[str] = "n_features" + types: ClassVar[Tuple[Type, ...]] = (int,) + doc: ClassVar[str] = "Number of input features, checked to be equal." @classmethod def is_valid( @@ -128,9 +127,9 @@ class NbFeaturesField(DataInfoField): class NbSamplesField(DataInfoField): """Specifications for 'n_samples' data_info field.""" - field = "n_samples" - types = (int,) - doc = "Number of data samples, combined by summation." + field: ClassVar[str] = "n_samples" + types: ClassVar[Tuple[Type, ...]] = (int,) + doc: ClassVar[str] = "Number of data samples, combined by summation." @classmethod def is_valid( diff --git a/declearn/dataset/_base.py b/declearn/dataset/_base.py index 665ba2b9..ffbb4805 100644 --- a/declearn/dataset/_base.py +++ b/declearn/dataset/_base.py @@ -4,12 +4,11 @@ from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Any, Iterator, Optional, Set +from typing import Any, ClassVar, Iterator, Optional, Set from declearn.typing import Batch from declearn.utils import access_registered, create_types_registry, json_load - __all__ = [ "DataSpecs", "Dataset", @@ -41,7 +40,7 @@ class Dataset(metaclass=ABCMeta): straightforward to specify as part of FL algorithms. """ - _type_key: str = NotImplemented + _type_key: ClassVar[str] = NotImplemented @abstractmethod def save_to_json( diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py index 97987cff..6ab54441 100644 --- a/declearn/dataset/_inmemory.py +++ b/declearn/dataset/_inmemory.py @@ -4,7 +4,7 @@ import functools import os -from typing import Any, Dict, Iterator, List, Optional, Set, Union +from typing import Any, ClassVar, Dict, Iterator, List, Optional, Set, Union import numpy as np import pandas as pd # type: ignore @@ -17,7 +17,6 @@ from declearn.dataset._sparse import sparse_from_file, sparse_to_file from declearn.typing import Batch from declearn.utils import json_dump, json_load, register_type - __all__ = [ "InMemoryDataset", ] @@ -56,7 +55,7 @@ class InMemoryDataset(Dataset): # attributes serve clarity; pylint: disable=too-many-instance-attributes - _type_key = "InMemoryDataset" + _type_key: ClassVar[str] = "InMemoryDataset" def __init__( self, diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py index ea4b2106..be142608 100644 --- a/declearn/main/utils/_training.py +++ b/declearn/main/utils/_training.py @@ -3,7 +3,7 @@ """Wrapper to run local training and evaluation rounds in a FL process.""" import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, ClassVar, Dict, List, Optional, Union import numpy as np @@ -96,7 +96,7 @@ class TrainingManager: class LossMetric(MeanMetric, register=False): """Ad hoc Metric wrapping a model's loss function.""" - name = "loss" + name: ClassVar[str] = "loss" def metric_func( self, y_true: np.ndarray, y_pred: np.ndarray diff --git a/declearn/metrics/_api.py b/declearn/metrics/_api.py index c408b991..54ea61ae 100644 --- a/declearn/metrics/_api.py +++ b/declearn/metrics/_api.py @@ -104,7 +104,7 @@ class Metric(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: ClassVar[str] + name: ClassVar[str] = NotImplemented def __init__( self, diff --git a/declearn/metrics/_classif.py b/declearn/metrics/_classif.py index b6cb9dfb..6dc60811 100644 --- a/declearn/metrics/_classif.py +++ b/declearn/metrics/_classif.py @@ -2,7 +2,7 @@ """Iterative and federative classification evaluation metrics.""" -from typing import Any, Collection, Dict, Optional, Union +from typing import Any, ClassVar, Collection, Dict, Optional, Union import numpy as np import sklearn # type: ignore @@ -41,7 +41,7 @@ class BinaryAccuracyPrecisionRecall(Metric): Confusion matrix of predictions. Values: [[TN, FP], [FN, TP]] """ - name = "binary-classif" + name: ClassVar[str] = "binary-classif" def __init__( self, @@ -131,7 +131,7 @@ class MulticlassAccuracyPrecisionRecall(Metric): were predicted to belong to label j. """ - name = "multi-classif" + name: ClassVar[str] = "multi-classif" def __init__( self, diff --git a/declearn/metrics/_mean.py b/declearn/metrics/_mean.py index 573cbed9..b9f8abc3 100644 --- a/declearn/metrics/_mean.py +++ b/declearn/metrics/_mean.py @@ -3,7 +3,7 @@ """Iterative and federative generic evaluation metrics.""" from abc import ABCMeta, abstractmethod -from typing import Dict, Optional, Union +from typing import ClassVar, Dict, Optional, Union import numpy as np @@ -110,7 +110,7 @@ class MeanAbsoluteError(MeanMetric): summed over channels for (>=2)-dimensional inputs). """ - name = "mae" + name: ClassVar[str] = "mae" def metric_func( self, @@ -138,7 +138,7 @@ class MeanSquaredError(MeanMetric): summed over channels for (>=2)-dimensional inputs). """ - name = "mse" + name: ClassVar[str] = "mse" def metric_func( self, diff --git a/declearn/metrics/_roc_auc.py b/declearn/metrics/_roc_auc.py index 7a3d2980..3772f84d 100644 --- a/declearn/metrics/_roc_auc.py +++ b/declearn/metrics/_roc_auc.py @@ -2,7 +2,7 @@ """Iterative and federative ROC AUC evaluation metrics.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, Optional, Tuple, Union import numpy as np import sklearn # type: ignore @@ -10,7 +10,6 @@ import sklearn.metrics # type: ignore from declearn.metrics._api import Metric - __all__ = [ "BinaryRocAUC", ] @@ -42,7 +41,7 @@ class BinaryRocAUC(Metric): unless its """ - name = "binary-roc" + name: ClassVar[str] = "binary-roc" def __init__( self, diff --git a/declearn/optimizer/modules/_adaptive.py b/declearn/optimizer/modules/_adaptive.py index 38528f73..3c9ea808 100644 --- a/declearn/optimizer/modules/_adaptive.py +++ b/declearn/optimizer/modules/_adaptive.py @@ -2,13 +2,12 @@ """Adaptive algorithms for optimizers, implemented as plug-in modules.""" -from typing import Any, Dict, Optional, Union +from typing import Any, ClassVar, Dict, Optional, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule from declearn.optimizer.modules._momentum import EWMAModule, YogiMomentumModule - __all__ = [ "AdaGradModule", "AdamModule", @@ -38,7 +37,7 @@ class AdaGradModule(OptiModule): https://jmlr.org/papers/v12/duchi11a.html """ - name = "adagrad" + name: ClassVar[str] = "adagrad" def __init__( self, @@ -102,7 +101,7 @@ class RMSPropModule(OptiModule): Average of its Recent Magnitude. """ - name = "rmsprop" + name: ClassVar[str] = "rmsprop" def __init__( self, @@ -184,7 +183,7 @@ class AdamModule(OptiModule): https://arxiv.org/abs/1904.09237 """ - name = "adam" + name: ClassVar[str] = "adam" def __init__( self, @@ -304,7 +303,7 @@ class YogiModule(AdamModule): https://arxiv.org/abs/1904.09237 """ - name = "yogi" + name: ClassVar[str] = "yogi" def __init__( self, diff --git a/declearn/optimizer/modules/_api.py b/declearn/optimizer/modules/_api.py index 75444d6b..5d02f7c7 100644 --- a/declearn/optimizer/modules/_api.py +++ b/declearn/optimizer/modules/_api.py @@ -3,7 +3,7 @@ """Base API for plug-in gradients-alteration algorithms.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional from declearn.model.api import Vector from declearn.utils import ( @@ -76,9 +76,8 @@ class OptiModule(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: str = NotImplemented - - aux_name: Optional[str] = None + name: ClassVar[str] = NotImplemented + aux_name: ClassVar[Optional[str]] = None def __init_subclass__( cls, diff --git a/declearn/optimizer/modules/_clipping.py b/declearn/optimizer/modules/_clipping.py index c30691bb..cfda6cd2 100644 --- a/declearn/optimizer/modules/_clipping.py +++ b/declearn/optimizer/modules/_clipping.py @@ -2,13 +2,11 @@ """Batch-averaged gradients clipping plug-in modules.""" -from typing import Any, Dict - +from typing import Any, ClassVar, Dict from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule - __all__ = ["L2Clipping"] @@ -32,7 +30,7 @@ class L2Clipping(OptiModule): to bound the sensitivity associated to that action. """ - name = "l2-clipping" + name: ClassVar[str] = "l2-clipping" def __init__( self, diff --git a/declearn/optimizer/modules/_momentum.py b/declearn/optimizer/modules/_momentum.py index 1b97bf27..cdc5580f 100644 --- a/declearn/optimizer/modules/_momentum.py +++ b/declearn/optimizer/modules/_momentum.py @@ -2,7 +2,7 @@ """Base API and common examples of plug-in gradients-alteration algorithms.""" -from typing import Any, Dict, Union +from typing import Any, ClassVar, Dict, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule @@ -48,7 +48,7 @@ class MomentumModule(OptiModule): https://proceedings.mlr.press/v28/sutskever13.pdf """ - name = "momentum" + name: ClassVar[str] = "momentum" def __init__( self, @@ -114,7 +114,7 @@ class EWMAModule(OptiModule): decaying moving-average of past gradients. """ - name = "ewma" + name: ClassVar[str] = "ewma" def __init__( self, @@ -184,7 +184,7 @@ class YogiMomentumModule(EWMAModule): Adaptive Methods for Nonconvex Optimization. """ - name = "yogi-momentum" + name: ClassVar[str] = "yogi-momentum" def run( self, diff --git a/declearn/optimizer/modules/_noise.py b/declearn/optimizer/modules/_noise.py index cf91f076..651300c0 100644 --- a/declearn/optimizer/modules/_noise.py +++ b/declearn/optimizer/modules/_noise.py @@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod from random import SystemRandom -from typing import Any, Dict, Optional, Tuple +from typing import Any, ClassVar, Dict, Optional, Tuple import numpy as np import scipy.stats # type: ignore @@ -26,7 +26,7 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False): or slower cryptographically secure pseudo-random numbers (CSPRN). """ - name = "abstract-noise" + name: ClassVar[str] = "abstract-noise" def __init__( self, @@ -98,7 +98,7 @@ class GaussianNoiseModule(NoiseModule): or slower cryptographically secure pseudo-random numbers (CSPRN). """ - name = "gaussian-noise" + name: ClassVar[str] = "gaussian-noise" def __init__( self, diff --git a/declearn/optimizer/modules/_scaffold.py b/declearn/optimizer/modules/_scaffold.py index b36ce060..80d2b212 100644 --- a/declearn/optimizer/modules/_scaffold.py +++ b/declearn/optimizer/modules/_scaffold.py @@ -16,13 +16,11 @@ References: https://arxiv.org/abs/1910.06378 """ -from typing import Any, Dict, List, Optional, Union - +from typing import Any, ClassVar, Dict, List, Optional, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule - __all__ = [ "ScaffoldClientModule", "ScaffoldServerModule", @@ -80,8 +78,8 @@ class ScaffoldClientModule(OptiModule): https://arxiv.org/abs/1910.06378 """ - name = "scaffold-client" - aux_name = "scaffold" + name: ClassVar[str] = "scaffold-client" + aux_name: ClassVar[str] = "scaffold" def __init__( self, @@ -214,8 +212,8 @@ class ScaffoldServerModule(OptiModule): https://arxiv.org/abs/1910.06378 """ - name = "scaffold-server" - aux_name = "scaffold" + name: ClassVar[str] = "scaffold-server" + aux_name: ClassVar[str] = "scaffold" def __init__( self, diff --git a/declearn/optimizer/regularizers/_api.py b/declearn/optimizer/regularizers/_api.py index a199dab0..c45e52c2 100644 --- a/declearn/optimizer/regularizers/_api.py +++ b/declearn/optimizer/regularizers/_api.py @@ -3,7 +3,7 @@ """Base API for loss regularization optimizer plug-ins.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict +from typing import Any, ClassVar, Dict from declearn.model.api import Vector from declearn.utils import ( @@ -12,7 +12,6 @@ from declearn.utils import ( register_type, ) - __all__ = [ "Regularizer", ] @@ -68,7 +67,7 @@ class Regularizer(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: str = NotImplemented + name: ClassVar[str] = NotImplemented def __init_subclass__( cls, diff --git a/declearn/optimizer/regularizers/_base.py b/declearn/optimizer/regularizers/_base.py index 0bcac9cf..f05376f3 100644 --- a/declearn/optimizer/regularizers/_base.py +++ b/declearn/optimizer/regularizers/_base.py @@ -2,12 +2,11 @@ """Common plug-in loss-regularization plug-ins.""" -from typing import Optional +from typing import ClassVar, Optional from declearn.model.api import Vector from declearn.optimizer.regularizers._api import Regularizer - __all__ = [ "FedProxRegularizer", "LassoRegularizer", @@ -40,7 +39,7 @@ class FedProxRegularizer(Regularizer): https://arxiv.org/abs/1812.06127 """ - name = "fedprox" + name: ClassVar[str] = "fedprox" def __init__( self, @@ -76,7 +75,7 @@ class LassoRegularizer(Regularizer): grads += alpha * sign(weights) """ - name = "lasso" + name: ClassVar[str] = "lasso" def run( self, @@ -97,7 +96,7 @@ class RidgeRegularizer(Regularizer): grads += alpha * 2 * weights """ - name = "ridge" + name: ClassVar[str] = "ridge" def run( self, diff --git a/examples/adding_rmsprop/readme.md b/examples/adding_rmsprop/readme.md index 7d47d6e9..40b748e5 100644 --- a/examples/adding_rmsprop/readme.md +++ b/examples/adding_rmsprop/readme.md @@ -65,7 +65,7 @@ class RMSPropModule(OptiModule): # Identifier, that must be unique across modules for type-registration # purposes. This enables specifying the module in configuration files. - name = "rmsprop" + name:ClassVar[str] = "rmsprop" # Define optimizer parameters, here beta and eps diff --git a/test/optimizer/test_optimizer.py b/test/optimizer/test_optimizer.py index b4fe63fa..ac5e7b01 100644 --- a/test/optimizer/test_optimizer.py +++ b/test/optimizer/test_optimizer.py @@ -3,7 +3,7 @@ """Unit tests for `declearn.optimizer.Optimizer`.""" -from typing import Any, Dict, Tuple +from typing import Any, ClassVar, Dict, Tuple from unittest import mock from uuid import uuid4 @@ -19,7 +19,7 @@ from declearn.test_utils import assert_json_serializable_dict class MockOptiModule(OptiModule): """Type-registered mock OptiModule subclass.""" - name = f"mock-{uuid4()}" + name: ClassVar[str] = f"mock-{uuid4()}" def __init__(self, **kwargs: Any) -> None: super().__init__() @@ -35,7 +35,7 @@ class MockOptiModule(OptiModule): class MockRegularizer(Regularizer): """Type-registered mock Regularizer subclass.""" - name = f"mock-{uuid4()}" + name: ClassVar[str] = f"mock-{uuid4()}" def __init__(self, **kwargs: Any) -> None: super().__init__() -- GitLab