diff --git a/declearn/aggregator/_api.py b/declearn/aggregator/_api.py index e2cd5df454e68db96e62613ddd299f9bd7d932f6..9e3fa75e34e8a53515750f301f51e7dfca01994a 100644 --- a/declearn/aggregator/_api.py +++ b/declearn/aggregator/_api.py @@ -3,13 +3,11 @@ """Model updates aggregation API.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict - +from typing import Any, ClassVar, Dict from declearn.model.api import Vector from declearn.utils import create_types_registry, register_type - __all__ = [ "Aggregator", ] @@ -59,7 +57,7 @@ class Aggregator(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: str = NotImplemented + name: ClassVar[str] = NotImplemented def __init_subclass__( cls, diff --git a/declearn/aggregator/_base.py b/declearn/aggregator/_base.py index e1b64f6a5da12569de81acb0b1dda858c19b69da..419650adbbb57b900da7c489d91a124f0fea24cf 100644 --- a/declearn/aggregator/_base.py +++ b/declearn/aggregator/_base.py @@ -2,12 +2,10 @@ """FedAvg-like mean-aggregation class.""" -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional - -from declearn.model.api import Vector from declearn.aggregator._api import Aggregator - +from declearn.model.api import Vector __all__ = [ "AveragingAggregator", @@ -24,7 +22,7 @@ class AveragingAggregator(Aggregator): that use simple weighting schemes. """ - name = "averaging" + name: ClassVar[str] = "averaging" def __init__( self, diff --git a/declearn/aggregator/_gma.py b/declearn/aggregator/_gma.py index 98363e1783e740ee230c4894ed4d03467948878a..232fb54ecc6b9ca2fca1e2870c9d727f4c66fd8c 100644 --- a/declearn/aggregator/_gma.py +++ b/declearn/aggregator/_gma.py @@ -2,12 +2,10 @@ """Gradient Masked Averaging aggregation class.""" -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional - -from declearn.model.api import Vector from declearn.aggregator._base import AveragingAggregator - +from declearn.model.api import Vector __all__ = [ "GradientMaskedAveraging", @@ -42,7 +40,7 @@ class GradientMaskedAveraging(AveragingAggregator): https://arxiv.org/abs/2201.11986 """ - name = "gradient-masked-averaging" + name: ClassVar[str] = "gradient-masked-averaging" def __init__( self, diff --git a/declearn/communication/api/_client.py b/declearn/communication/api/_client.py index e0383436fd6adaa8f2797797f4d54a237d93582d..926b5d2f2b55fbeb30d7abdf70af945261284cb9 100644 --- a/declearn/communication/api/_client.py +++ b/declearn/communication/api/_client.py @@ -5,8 +5,7 @@ import logging import types from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional, Type, Union - +from typing import Any, ClassVar, Dict, Optional, Type, Union from declearn.communication.messaging import ( Empty, @@ -18,7 +17,6 @@ from declearn.communication.messaging import ( ) from declearn.utils import create_types_registry, get_logger, register_type - __all__ = [ "NetworkClient", ] @@ -58,7 +56,7 @@ class NetworkClient(metaclass=ABCMeta): probably be rejected by the server if the client has not registered. """ - protocol: str = NotImplemented + protocol: ClassVar[str] = NotImplemented def __init_subclass__(cls, register: bool = True) -> None: """Automate the type-registration of NetworkClient subclasses.""" @@ -107,7 +105,6 @@ class NetworkClient(metaclass=ABCMeta): The return type is communication-protocol dependent. """ - return NotImplemented # similar to NetworkServer API; pylint: disable=duplicate-code @@ -118,7 +115,6 @@ class NetworkClient(metaclass=ABCMeta): Note: this method can be called safely even if the client is already running (simply having no effect). """ - return None @abstractmethod async def stop(self) -> None: diff --git a/declearn/communication/api/_server.py b/declearn/communication/api/_server.py index 1194ee303aed87ebc3c68740bee7a2c8bdfec0d8..1f12c1ca010a3a6c38d10e9964c2c14630af7450 100644 --- a/declearn/communication/api/_server.py +++ b/declearn/communication/api/_server.py @@ -6,7 +6,7 @@ import asyncio import logging import types from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional, Set, Type, Union +from typing import Any, Dict, Optional, Set, Type, Union, ClassVar from declearn.communication.api._service import MessagesHandler @@ -54,7 +54,7 @@ class NetworkServer(metaclass=ABCMeta): of the awaitable `wait_for_clients` method. """ - protocol: str = NotImplemented + protocol: ClassVar[str] = NotImplemented def __init_subclass__(cls, register: bool = True) -> None: """Automate the type-registration of NetworkServer subclasses.""" diff --git a/declearn/communication/grpc/_client.py b/declearn/communication/grpc/_client.py index 941d7f761715c5b33b3e08bf572a0b36eaf98686..468ce0d83401ed70d751bbf4398e721e73e9f2be 100644 --- a/declearn/communication/grpc/_client.py +++ b/declearn/communication/grpc/_client.py @@ -3,17 +3,16 @@ """Client-side communication endpoint implementation using gRPC""" import logging -from typing import Any, Dict, Optional, Union +from typing import Any, ClassVar, Dict, Optional, Union import grpc # type: ignore from declearn.communication.api import NetworkClient -from declearn.communication.messaging import Message, parse_message_from_string from declearn.communication.grpc.protobufs import message_pb2 from declearn.communication.grpc.protobufs.message_pb2_grpc import ( MessageBoardStub, ) - +from declearn.communication.messaging import Message, parse_message_from_string __all__ = [ "GrpcClient", @@ -26,7 +25,7 @@ CHUNK_LENGTH = 2**22 - 50 # 2**22 - sys.getsizeof("") - 1 class GrpcClient(NetworkClient): """Client-side communication endpoint using gRPC.""" - protocol = "grpc" + protocol: ClassVar[str] = "grpc" def __init__( self, diff --git a/declearn/communication/websockets/_client.py b/declearn/communication/websockets/_client.py index 59627c50180345fab328c7187a0282624414d434..f8774f4656d2fc98d59a9d2d11b998c890a266b1 100644 --- a/declearn/communication/websockets/_client.py +++ b/declearn/communication/websockets/_client.py @@ -5,7 +5,7 @@ import asyncio import logging import ssl -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, ClassVar import websockets as ws from websockets.client import WebSocketClientProtocol @@ -25,7 +25,7 @@ CHUNK_LENGTH = 100000 class WebsocketsClient(NetworkClient): """Client-side communication endpoint using WebSockets.""" - protocol = "websockets" + protocol: ClassVar[str] = "websockets" def __init__( self, diff --git a/declearn/data_info/_base.py b/declearn/data_info/_base.py index 99bc1f680fb878ef68595db8447801e2824e9580..0f0fb9b412e144164f50637796051e266bb11efd 100644 --- a/declearn/data_info/_base.py +++ b/declearn/data_info/_base.py @@ -34,7 +34,7 @@ data_info fields, are implemented (although unexposed) here. import warnings from abc import ABCMeta, abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple, Type +from typing import Any, Dict, List, Optional, Set, Tuple, Type, ClassVar __all__ = [ @@ -75,9 +75,9 @@ class DataInfoField(metaclass=ABCMeta): is called, run `is_valid` on each and every input. """ - field: str - types: Tuple[Type, ...] - doc: str + field: ClassVar[str] = NotImplemented + types: ClassVar[Tuple[Type, ...]] = NotImplemented + doc: ClassVar[str] = NotImplemented @classmethod def is_valid( @@ -85,6 +85,7 @@ class DataInfoField(metaclass=ABCMeta): value: Any, ) -> bool: """Check that a given value may belong to this field.""" + # false-pos; pylint: disable=isinstance-second-argument-not-valid-type return isinstance(value, cls.types) @classmethod diff --git a/declearn/data_info/_fields.py b/declearn/data_info/_fields.py index 7d456646f11728b426e246a200b98d654fbb9152..d84e9e08d9ff618476e1fbc3ee7994f8e25b2b15 100644 --- a/declearn/data_info/_fields.py +++ b/declearn/data_info/_fields.py @@ -2,13 +2,12 @@ """DataInfoField subclasses specifying common 'data_info' metadata fields.""" -from typing import Any, List, Optional, Set +from typing import Any, ClassVar, List, Optional, Set, Tuple, Type import numpy as np from declearn.data_info._base import DataInfoField, register_data_info_field - __all__ = [ "ClassesField", "InputShapeField", @@ -21,9 +20,9 @@ __all__ = [ class ClassesField(DataInfoField): """Specifications for 'classes' data_info field.""" - field = "classes" - types = (list, set, tuple, np.ndarray) - doc = "Set of classification targets, combined by union." + field: ClassVar[str] = "classes" + types: ClassVar[Tuple[Type, ...]] = (list, set, tuple, np.ndarray) + doc: ClassVar[str] = "Set of classification targets, combined by union." @classmethod def is_valid( @@ -47,9 +46,9 @@ class ClassesField(DataInfoField): class InputShapeField(DataInfoField): """Specifications for 'input_shape' data_info field.""" - field = "input_shape" - types = (tuple, list) - doc = "Input features' batched shape, checked to be equal." + field: ClassVar[str] = "input_shape" + types: ClassVar[Tuple[Type, ...]] = (tuple, list) + doc: ClassVar[str] = "Input features' batched shape, checked to be equal." @classmethod def is_valid( @@ -96,9 +95,9 @@ class InputShapeField(DataInfoField): class NbFeaturesField(DataInfoField): """Specifications for 'n_features' data_info field.""" - field = "n_features" - types = (int,) - doc = "Number of input features, checked to be equal." + field: ClassVar[str] = "n_features" + types: ClassVar[Tuple[Type, ...]] = (int,) + doc: ClassVar[str] = "Number of input features, checked to be equal." @classmethod def is_valid( @@ -128,9 +127,9 @@ class NbFeaturesField(DataInfoField): class NbSamplesField(DataInfoField): """Specifications for 'n_samples' data_info field.""" - field = "n_samples" - types = (int,) - doc = "Number of data samples, combined by summation." + field: ClassVar[str] = "n_samples" + types: ClassVar[Tuple[Type, ...]] = (int,) + doc: ClassVar[str] = "Number of data samples, combined by summation." @classmethod def is_valid( diff --git a/declearn/dataset/_base.py b/declearn/dataset/_base.py index 665ba2b99fb2787eab3c3e1c4b2c13d05c99d3d6..ffbb48059bdde0aac938db2786bca1941015b61c 100644 --- a/declearn/dataset/_base.py +++ b/declearn/dataset/_base.py @@ -4,12 +4,11 @@ from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import Any, Iterator, Optional, Set +from typing import Any, ClassVar, Iterator, Optional, Set from declearn.typing import Batch from declearn.utils import access_registered, create_types_registry, json_load - __all__ = [ "DataSpecs", "Dataset", @@ -41,7 +40,7 @@ class Dataset(metaclass=ABCMeta): straightforward to specify as part of FL algorithms. """ - _type_key: str = NotImplemented + _type_key: ClassVar[str] = NotImplemented @abstractmethod def save_to_json( diff --git a/declearn/dataset/_inmemory.py b/declearn/dataset/_inmemory.py index 97987cffd4c0b6953559f23edfa3d0f2b15f3fab..6ab5444166ba894d24c3a04a4c51aac3e151a6e4 100644 --- a/declearn/dataset/_inmemory.py +++ b/declearn/dataset/_inmemory.py @@ -4,7 +4,7 @@ import functools import os -from typing import Any, Dict, Iterator, List, Optional, Set, Union +from typing import Any, ClassVar, Dict, Iterator, List, Optional, Set, Union import numpy as np import pandas as pd # type: ignore @@ -17,7 +17,6 @@ from declearn.dataset._sparse import sparse_from_file, sparse_to_file from declearn.typing import Batch from declearn.utils import json_dump, json_load, register_type - __all__ = [ "InMemoryDataset", ] @@ -56,7 +55,7 @@ class InMemoryDataset(Dataset): # attributes serve clarity; pylint: disable=too-many-instance-attributes - _type_key = "InMemoryDataset" + _type_key: ClassVar[str] = "InMemoryDataset" def __init__( self, diff --git a/declearn/main/utils/_training.py b/declearn/main/utils/_training.py index ea4b2106afaef3d5ce6ce42ec19bb54733b35ebe..be142608915fe188b23ac023390c8b89aa8d75bb 100644 --- a/declearn/main/utils/_training.py +++ b/declearn/main/utils/_training.py @@ -3,7 +3,7 @@ """Wrapper to run local training and evaluation rounds in a FL process.""" import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, ClassVar, Dict, List, Optional, Union import numpy as np @@ -96,7 +96,7 @@ class TrainingManager: class LossMetric(MeanMetric, register=False): """Ad hoc Metric wrapping a model's loss function.""" - name = "loss" + name: ClassVar[str] = "loss" def metric_func( self, y_true: np.ndarray, y_pred: np.ndarray diff --git a/declearn/metrics/_api.py b/declearn/metrics/_api.py index c408b9918dfcc4e4f16fe2138533309721e72554..54ea61ae85be6d1a2734a736356858ba21a3b417 100644 --- a/declearn/metrics/_api.py +++ b/declearn/metrics/_api.py @@ -104,7 +104,7 @@ class Metric(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: ClassVar[str] + name: ClassVar[str] = NotImplemented def __init__( self, diff --git a/declearn/metrics/_classif.py b/declearn/metrics/_classif.py index b6cb9dfb3dd47ada463a399eb1c3d8a552597683..6dc608115ff7eca9ba2f0ec28af152222033c69f 100644 --- a/declearn/metrics/_classif.py +++ b/declearn/metrics/_classif.py @@ -2,7 +2,7 @@ """Iterative and federative classification evaluation metrics.""" -from typing import Any, Collection, Dict, Optional, Union +from typing import Any, ClassVar, Collection, Dict, Optional, Union import numpy as np import sklearn # type: ignore @@ -41,7 +41,7 @@ class BinaryAccuracyPrecisionRecall(Metric): Confusion matrix of predictions. Values: [[TN, FP], [FN, TP]] """ - name = "binary-classif" + name: ClassVar[str] = "binary-classif" def __init__( self, @@ -131,7 +131,7 @@ class MulticlassAccuracyPrecisionRecall(Metric): were predicted to belong to label j. """ - name = "multi-classif" + name: ClassVar[str] = "multi-classif" def __init__( self, diff --git a/declearn/metrics/_mean.py b/declearn/metrics/_mean.py index 573cbed906be3ee1a09f528a59670163c971a5d4..b9f8abc37b3a8d7d04f94f55943416b0cf72fdfb 100644 --- a/declearn/metrics/_mean.py +++ b/declearn/metrics/_mean.py @@ -3,7 +3,7 @@ """Iterative and federative generic evaluation metrics.""" from abc import ABCMeta, abstractmethod -from typing import Dict, Optional, Union +from typing import ClassVar, Dict, Optional, Union import numpy as np @@ -110,7 +110,7 @@ class MeanAbsoluteError(MeanMetric): summed over channels for (>=2)-dimensional inputs). """ - name = "mae" + name: ClassVar[str] = "mae" def metric_func( self, @@ -138,7 +138,7 @@ class MeanSquaredError(MeanMetric): summed over channels for (>=2)-dimensional inputs). """ - name = "mse" + name: ClassVar[str] = "mse" def metric_func( self, diff --git a/declearn/metrics/_roc_auc.py b/declearn/metrics/_roc_auc.py index 7a3d2980c21d82a31d85784cc008ba3a1cd38eec..3772f84dc5ec3acd5565d52b6582d8ad9ac6cbd9 100644 --- a/declearn/metrics/_roc_auc.py +++ b/declearn/metrics/_roc_auc.py @@ -2,7 +2,7 @@ """Iterative and federative ROC AUC evaluation metrics.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, Optional, Tuple, Union import numpy as np import sklearn # type: ignore @@ -10,7 +10,6 @@ import sklearn.metrics # type: ignore from declearn.metrics._api import Metric - __all__ = [ "BinaryRocAUC", ] @@ -42,7 +41,7 @@ class BinaryRocAUC(Metric): unless its """ - name = "binary-roc" + name: ClassVar[str] = "binary-roc" def __init__( self, diff --git a/declearn/optimizer/modules/_adaptive.py b/declearn/optimizer/modules/_adaptive.py index 38528f730f9f8e3c4c147649a77af2c4f9277870..3c9ea80896bda17f9e5bdcebee402f8598d915ef 100644 --- a/declearn/optimizer/modules/_adaptive.py +++ b/declearn/optimizer/modules/_adaptive.py @@ -2,13 +2,12 @@ """Adaptive algorithms for optimizers, implemented as plug-in modules.""" -from typing import Any, Dict, Optional, Union +from typing import Any, ClassVar, Dict, Optional, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule from declearn.optimizer.modules._momentum import EWMAModule, YogiMomentumModule - __all__ = [ "AdaGradModule", "AdamModule", @@ -38,7 +37,7 @@ class AdaGradModule(OptiModule): https://jmlr.org/papers/v12/duchi11a.html """ - name = "adagrad" + name: ClassVar[str] = "adagrad" def __init__( self, @@ -102,7 +101,7 @@ class RMSPropModule(OptiModule): Average of its Recent Magnitude. """ - name = "rmsprop" + name: ClassVar[str] = "rmsprop" def __init__( self, @@ -184,7 +183,7 @@ class AdamModule(OptiModule): https://arxiv.org/abs/1904.09237 """ - name = "adam" + name: ClassVar[str] = "adam" def __init__( self, @@ -304,7 +303,7 @@ class YogiModule(AdamModule): https://arxiv.org/abs/1904.09237 """ - name = "yogi" + name: ClassVar[str] = "yogi" def __init__( self, diff --git a/declearn/optimizer/modules/_api.py b/declearn/optimizer/modules/_api.py index 75444d6b1a78c142eb4ac52c5ab20879a93dad4a..5d02f7c7fad6eef02b259c85702837ae8c12215c 100644 --- a/declearn/optimizer/modules/_api.py +++ b/declearn/optimizer/modules/_api.py @@ -3,7 +3,7 @@ """Base API for plug-in gradients-alteration algorithms.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional from declearn.model.api import Vector from declearn.utils import ( @@ -76,9 +76,8 @@ class OptiModule(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: str = NotImplemented - - aux_name: Optional[str] = None + name: ClassVar[str] = NotImplemented + aux_name: ClassVar[Optional[str]] = None def __init_subclass__( cls, diff --git a/declearn/optimizer/modules/_clipping.py b/declearn/optimizer/modules/_clipping.py index c30691bb4511151f44d9c7022bd2ae4c75822fed..cfda6cd2dcb24e39743fe1811b12c6b82a723941 100644 --- a/declearn/optimizer/modules/_clipping.py +++ b/declearn/optimizer/modules/_clipping.py @@ -2,13 +2,11 @@ """Batch-averaged gradients clipping plug-in modules.""" -from typing import Any, Dict - +from typing import Any, ClassVar, Dict from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule - __all__ = ["L2Clipping"] @@ -32,7 +30,7 @@ class L2Clipping(OptiModule): to bound the sensitivity associated to that action. """ - name = "l2-clipping" + name: ClassVar[str] = "l2-clipping" def __init__( self, diff --git a/declearn/optimizer/modules/_momentum.py b/declearn/optimizer/modules/_momentum.py index 1b97bf2736e2dc2709282fd9214cc097b8aa1b6f..cdc5580f2b53bb01da00fb65097aec3424f0fd54 100644 --- a/declearn/optimizer/modules/_momentum.py +++ b/declearn/optimizer/modules/_momentum.py @@ -2,7 +2,7 @@ """Base API and common examples of plug-in gradients-alteration algorithms.""" -from typing import Any, Dict, Union +from typing import Any, ClassVar, Dict, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule @@ -48,7 +48,7 @@ class MomentumModule(OptiModule): https://proceedings.mlr.press/v28/sutskever13.pdf """ - name = "momentum" + name: ClassVar[str] = "momentum" def __init__( self, @@ -114,7 +114,7 @@ class EWMAModule(OptiModule): decaying moving-average of past gradients. """ - name = "ewma" + name: ClassVar[str] = "ewma" def __init__( self, @@ -184,7 +184,7 @@ class YogiMomentumModule(EWMAModule): Adaptive Methods for Nonconvex Optimization. """ - name = "yogi-momentum" + name: ClassVar[str] = "yogi-momentum" def run( self, diff --git a/declearn/optimizer/modules/_noise.py b/declearn/optimizer/modules/_noise.py index cf91f076ceffe2c87c5346f725018c81a04f9a73..651300c07104a1e6ef1d8af0d6effffa0025ec39 100644 --- a/declearn/optimizer/modules/_noise.py +++ b/declearn/optimizer/modules/_noise.py @@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod from random import SystemRandom -from typing import Any, Dict, Optional, Tuple +from typing import Any, ClassVar, Dict, Optional, Tuple import numpy as np import scipy.stats # type: ignore @@ -26,7 +26,7 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False): or slower cryptographically secure pseudo-random numbers (CSPRN). """ - name = "abstract-noise" + name: ClassVar[str] = "abstract-noise" def __init__( self, @@ -98,7 +98,7 @@ class GaussianNoiseModule(NoiseModule): or slower cryptographically secure pseudo-random numbers (CSPRN). """ - name = "gaussian-noise" + name: ClassVar[str] = "gaussian-noise" def __init__( self, diff --git a/declearn/optimizer/modules/_scaffold.py b/declearn/optimizer/modules/_scaffold.py index b36ce060099aa008b9879e6f696d2494a59e6dfe..80d2b21272fe62f0d57da3b6c1f2e601a51d51c3 100644 --- a/declearn/optimizer/modules/_scaffold.py +++ b/declearn/optimizer/modules/_scaffold.py @@ -16,13 +16,11 @@ References: https://arxiv.org/abs/1910.06378 """ -from typing import Any, Dict, List, Optional, Union - +from typing import Any, ClassVar, Dict, List, Optional, Union from declearn.model.api import Vector from declearn.optimizer.modules._api import OptiModule - __all__ = [ "ScaffoldClientModule", "ScaffoldServerModule", @@ -80,8 +78,8 @@ class ScaffoldClientModule(OptiModule): https://arxiv.org/abs/1910.06378 """ - name = "scaffold-client" - aux_name = "scaffold" + name: ClassVar[str] = "scaffold-client" + aux_name: ClassVar[str] = "scaffold" def __init__( self, @@ -214,8 +212,8 @@ class ScaffoldServerModule(OptiModule): https://arxiv.org/abs/1910.06378 """ - name = "scaffold-server" - aux_name = "scaffold" + name: ClassVar[str] = "scaffold-server" + aux_name: ClassVar[str] = "scaffold" def __init__( self, diff --git a/declearn/optimizer/regularizers/_api.py b/declearn/optimizer/regularizers/_api.py index a199dab091c88ba122d0b92ff3dbc9e75388784e..c45e52c28b63c2924139a35b88d39b53ee5a6dcb 100644 --- a/declearn/optimizer/regularizers/_api.py +++ b/declearn/optimizer/regularizers/_api.py @@ -3,7 +3,7 @@ """Base API for loss regularization optimizer plug-ins.""" from abc import ABCMeta, abstractmethod -from typing import Any, Dict +from typing import Any, ClassVar, Dict from declearn.model.api import Vector from declearn.utils import ( @@ -12,7 +12,6 @@ from declearn.utils import ( register_type, ) - __all__ = [ "Regularizer", ] @@ -68,7 +67,7 @@ class Regularizer(metaclass=ABCMeta): See `declearn.utils.register_type` for details on types registration. """ - name: str = NotImplemented + name: ClassVar[str] = NotImplemented def __init_subclass__( cls, diff --git a/declearn/optimizer/regularizers/_base.py b/declearn/optimizer/regularizers/_base.py index 0bcac9cf2b6cda276d27b3dfb5be91f9f1d1ea36..f05376f3a2c4ee9f3bee16282734314d203c2c4d 100644 --- a/declearn/optimizer/regularizers/_base.py +++ b/declearn/optimizer/regularizers/_base.py @@ -2,12 +2,11 @@ """Common plug-in loss-regularization plug-ins.""" -from typing import Optional +from typing import ClassVar, Optional from declearn.model.api import Vector from declearn.optimizer.regularizers._api import Regularizer - __all__ = [ "FedProxRegularizer", "LassoRegularizer", @@ -40,7 +39,7 @@ class FedProxRegularizer(Regularizer): https://arxiv.org/abs/1812.06127 """ - name = "fedprox" + name: ClassVar[str] = "fedprox" def __init__( self, @@ -76,7 +75,7 @@ class LassoRegularizer(Regularizer): grads += alpha * sign(weights) """ - name = "lasso" + name: ClassVar[str] = "lasso" def run( self, @@ -97,7 +96,7 @@ class RidgeRegularizer(Regularizer): grads += alpha * 2 * weights """ - name = "ridge" + name: ClassVar[str] = "ridge" def run( self, diff --git a/examples/adding_rmsprop/readme.md b/examples/adding_rmsprop/readme.md index 7d47d6e91f2c32c7aacc7f142a8437f24bbd175d..40b748e5a37475944def19bcdc67642991e45f43 100644 --- a/examples/adding_rmsprop/readme.md +++ b/examples/adding_rmsprop/readme.md @@ -65,7 +65,7 @@ class RMSPropModule(OptiModule): # Identifier, that must be unique across modules for type-registration # purposes. This enables specifying the module in configuration files. - name = "rmsprop" + name:ClassVar[str] = "rmsprop" # Define optimizer parameters, here beta and eps diff --git a/test/optimizer/test_optimizer.py b/test/optimizer/test_optimizer.py index b4fe63fa592cbda0f3ebcbf716b220d1d396a752..ac5e7b010278f9a1001370a07407c5d1121bf673 100644 --- a/test/optimizer/test_optimizer.py +++ b/test/optimizer/test_optimizer.py @@ -3,7 +3,7 @@ """Unit tests for `declearn.optimizer.Optimizer`.""" -from typing import Any, Dict, Tuple +from typing import Any, ClassVar, Dict, Tuple from unittest import mock from uuid import uuid4 @@ -19,7 +19,7 @@ from declearn.test_utils import assert_json_serializable_dict class MockOptiModule(OptiModule): """Type-registered mock OptiModule subclass.""" - name = f"mock-{uuid4()}" + name: ClassVar[str] = f"mock-{uuid4()}" def __init__(self, **kwargs: Any) -> None: super().__init__() @@ -35,7 +35,7 @@ class MockOptiModule(OptiModule): class MockRegularizer(Regularizer): """Type-registered mock Regularizer subclass.""" - name = f"mock-{uuid4()}" + name: ClassVar[str] = f"mock-{uuid4()}" def __init__(self, **kwargs: Any) -> None: super().__init__()