Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 21f9b0d9 authored by BIGAUD Nathan's avatar BIGAUD Nathan Committed by ANDREY Paul
Browse files

Add ClassVar type-hints.

Details :
* In `/optimizer`, `/metrics`, and `/aggregator` :`name` and `aux_name`
* In `/communications` : `protocol`
* In `/dataset` : `_type_key`
* In `/data_info` : `field`, `type`, and `doc`

Excluding all classes decorated with `@dataclasses.dataclass`,
see https://stackoverflow.com/a/52099457
parent ddc04028
No related branches found
No related tags found
1 merge request!23Release version 2.0
Showing
with 64 additions and 82 deletions
......@@ -3,13 +3,11 @@
"""Model updates aggregation API."""
from abc import ABCMeta, abstractmethod
from typing import Any, Dict
from typing import Any, ClassVar, Dict
from declearn.model.api import Vector
from declearn.utils import create_types_registry, register_type
__all__ = [
"Aggregator",
]
......@@ -59,7 +57,7 @@ class Aggregator(metaclass=ABCMeta):
See `declearn.utils.register_type` for details on types registration.
"""
name: str = NotImplemented
name: ClassVar[str] = NotImplemented
def __init_subclass__(
cls,
......
......@@ -2,12 +2,10 @@
"""FedAvg-like mean-aggregation class."""
from typing import Any, Dict, Optional
from typing import Any, ClassVar, Dict, Optional
from declearn.model.api import Vector
from declearn.aggregator._api import Aggregator
from declearn.model.api import Vector
__all__ = [
"AveragingAggregator",
......@@ -24,7 +22,7 @@ class AveragingAggregator(Aggregator):
that use simple weighting schemes.
"""
name = "averaging"
name: ClassVar[str] = "averaging"
def __init__(
self,
......
......@@ -2,12 +2,10 @@
"""Gradient Masked Averaging aggregation class."""
from typing import Any, Dict, Optional
from typing import Any, ClassVar, Dict, Optional
from declearn.model.api import Vector
from declearn.aggregator._base import AveragingAggregator
from declearn.model.api import Vector
__all__ = [
"GradientMaskedAveraging",
......@@ -42,7 +40,7 @@ class GradientMaskedAveraging(AveragingAggregator):
https://arxiv.org/abs/2201.11986
"""
name = "gradient-masked-averaging"
name: ClassVar[str] = "gradient-masked-averaging"
def __init__(
self,
......
......@@ -5,8 +5,7 @@
import logging
import types
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional, Type, Union
from typing import Any, ClassVar, Dict, Optional, Type, Union
from declearn.communication.messaging import (
Empty,
......@@ -18,7 +17,6 @@ from declearn.communication.messaging import (
)
from declearn.utils import create_types_registry, get_logger, register_type
__all__ = [
"NetworkClient",
]
......@@ -58,7 +56,7 @@ class NetworkClient(metaclass=ABCMeta):
probably be rejected by the server if the client has not registered.
"""
protocol: str = NotImplemented
protocol: ClassVar[str] = NotImplemented
def __init_subclass__(cls, register: bool = True) -> None:
"""Automate the type-registration of NetworkClient subclasses."""
......@@ -107,7 +105,6 @@ class NetworkClient(metaclass=ABCMeta):
The return type is communication-protocol dependent.
"""
return NotImplemented
# similar to NetworkServer API; pylint: disable=duplicate-code
......@@ -118,7 +115,6 @@ class NetworkClient(metaclass=ABCMeta):
Note: this method can be called safely even if the
client is already running (simply having no effect).
"""
return None
@abstractmethod
async def stop(self) -> None:
......
......@@ -6,7 +6,7 @@ import asyncio
import logging
import types
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional, Set, Type, Union
from typing import Any, Dict, Optional, Set, Type, Union, ClassVar
from declearn.communication.api._service import MessagesHandler
......@@ -54,7 +54,7 @@ class NetworkServer(metaclass=ABCMeta):
of the awaitable `wait_for_clients` method.
"""
protocol: str = NotImplemented
protocol: ClassVar[str] = NotImplemented
def __init_subclass__(cls, register: bool = True) -> None:
"""Automate the type-registration of NetworkServer subclasses."""
......
......@@ -3,17 +3,16 @@
"""Client-side communication endpoint implementation using gRPC"""
import logging
from typing import Any, Dict, Optional, Union
from typing import Any, ClassVar, Dict, Optional, Union
import grpc # type: ignore
from declearn.communication.api import NetworkClient
from declearn.communication.messaging import Message, parse_message_from_string
from declearn.communication.grpc.protobufs import message_pb2
from declearn.communication.grpc.protobufs.message_pb2_grpc import (
MessageBoardStub,
)
from declearn.communication.messaging import Message, parse_message_from_string
__all__ = [
"GrpcClient",
......@@ -26,7 +25,7 @@ CHUNK_LENGTH = 2**22 - 50 # 2**22 - sys.getsizeof("") - 1
class GrpcClient(NetworkClient):
"""Client-side communication endpoint using gRPC."""
protocol = "grpc"
protocol: ClassVar[str] = "grpc"
def __init__(
self,
......
......@@ -5,7 +5,7 @@
import asyncio
import logging
import ssl
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Union, ClassVar
import websockets as ws
from websockets.client import WebSocketClientProtocol
......@@ -25,7 +25,7 @@ CHUNK_LENGTH = 100000
class WebsocketsClient(NetworkClient):
"""Client-side communication endpoint using WebSockets."""
protocol = "websockets"
protocol: ClassVar[str] = "websockets"
def __init__(
self,
......
......@@ -34,7 +34,7 @@ data_info fields, are implemented (although unexposed) here.
import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type
from typing import Any, Dict, List, Optional, Set, Tuple, Type, ClassVar
__all__ = [
......@@ -75,9 +75,9 @@ class DataInfoField(metaclass=ABCMeta):
is called, run `is_valid` on each and every input.
"""
field: str
types: Tuple[Type, ...]
doc: str
field: ClassVar[str] = NotImplemented
types: ClassVar[Tuple[Type, ...]] = NotImplemented
doc: ClassVar[str] = NotImplemented
@classmethod
def is_valid(
......@@ -85,6 +85,7 @@ class DataInfoField(metaclass=ABCMeta):
value: Any,
) -> bool:
"""Check that a given value may belong to this field."""
# false-pos; pylint: disable=isinstance-second-argument-not-valid-type
return isinstance(value, cls.types)
@classmethod
......
......@@ -2,13 +2,12 @@
"""DataInfoField subclasses specifying common 'data_info' metadata fields."""
from typing import Any, List, Optional, Set
from typing import Any, ClassVar, List, Optional, Set, Tuple, Type
import numpy as np
from declearn.data_info._base import DataInfoField, register_data_info_field
__all__ = [
"ClassesField",
"InputShapeField",
......@@ -21,9 +20,9 @@ __all__ = [
class ClassesField(DataInfoField):
"""Specifications for 'classes' data_info field."""
field = "classes"
types = (list, set, tuple, np.ndarray)
doc = "Set of classification targets, combined by union."
field: ClassVar[str] = "classes"
types: ClassVar[Tuple[Type, ...]] = (list, set, tuple, np.ndarray)
doc: ClassVar[str] = "Set of classification targets, combined by union."
@classmethod
def is_valid(
......@@ -47,9 +46,9 @@ class ClassesField(DataInfoField):
class InputShapeField(DataInfoField):
"""Specifications for 'input_shape' data_info field."""
field = "input_shape"
types = (tuple, list)
doc = "Input features' batched shape, checked to be equal."
field: ClassVar[str] = "input_shape"
types: ClassVar[Tuple[Type, ...]] = (tuple, list)
doc: ClassVar[str] = "Input features' batched shape, checked to be equal."
@classmethod
def is_valid(
......@@ -96,9 +95,9 @@ class InputShapeField(DataInfoField):
class NbFeaturesField(DataInfoField):
"""Specifications for 'n_features' data_info field."""
field = "n_features"
types = (int,)
doc = "Number of input features, checked to be equal."
field: ClassVar[str] = "n_features"
types: ClassVar[Tuple[Type, ...]] = (int,)
doc: ClassVar[str] = "Number of input features, checked to be equal."
@classmethod
def is_valid(
......@@ -128,9 +127,9 @@ class NbFeaturesField(DataInfoField):
class NbSamplesField(DataInfoField):
"""Specifications for 'n_samples' data_info field."""
field = "n_samples"
types = (int,)
doc = "Number of data samples, combined by summation."
field: ClassVar[str] = "n_samples"
types: ClassVar[Tuple[Type, ...]] = (int,)
doc: ClassVar[str] = "Number of data samples, combined by summation."
@classmethod
def is_valid(
......
......@@ -4,12 +4,11 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Iterator, Optional, Set
from typing import Any, ClassVar, Iterator, Optional, Set
from declearn.typing import Batch
from declearn.utils import access_registered, create_types_registry, json_load
__all__ = [
"DataSpecs",
"Dataset",
......@@ -41,7 +40,7 @@ class Dataset(metaclass=ABCMeta):
straightforward to specify as part of FL algorithms.
"""
_type_key: str = NotImplemented
_type_key: ClassVar[str] = NotImplemented
@abstractmethod
def save_to_json(
......
......@@ -4,7 +4,7 @@
import functools
import os
from typing import Any, Dict, Iterator, List, Optional, Set, Union
from typing import Any, ClassVar, Dict, Iterator, List, Optional, Set, Union
import numpy as np
import pandas as pd # type: ignore
......@@ -17,7 +17,6 @@ from declearn.dataset._sparse import sparse_from_file, sparse_to_file
from declearn.typing import Batch
from declearn.utils import json_dump, json_load, register_type
__all__ = [
"InMemoryDataset",
]
......@@ -56,7 +55,7 @@ class InMemoryDataset(Dataset):
# attributes serve clarity; pylint: disable=too-many-instance-attributes
_type_key = "InMemoryDataset"
_type_key: ClassVar[str] = "InMemoryDataset"
def __init__(
self,
......
......@@ -3,7 +3,7 @@
"""Wrapper to run local training and evaluation rounds in a FL process."""
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, ClassVar, Dict, List, Optional, Union
import numpy as np
......@@ -96,7 +96,7 @@ class TrainingManager:
class LossMetric(MeanMetric, register=False):
"""Ad hoc Metric wrapping a model's loss function."""
name = "loss"
name: ClassVar[str] = "loss"
def metric_func(
self, y_true: np.ndarray, y_pred: np.ndarray
......
......@@ -104,7 +104,7 @@ class Metric(metaclass=ABCMeta):
See `declearn.utils.register_type` for details on types registration.
"""
name: ClassVar[str]
name: ClassVar[str] = NotImplemented
def __init__(
self,
......
......@@ -2,7 +2,7 @@
"""Iterative and federative classification evaluation metrics."""
from typing import Any, Collection, Dict, Optional, Union
from typing import Any, ClassVar, Collection, Dict, Optional, Union
import numpy as np
import sklearn # type: ignore
......@@ -41,7 +41,7 @@ class BinaryAccuracyPrecisionRecall(Metric):
Confusion matrix of predictions. Values: [[TN, FP], [FN, TP]]
"""
name = "binary-classif"
name: ClassVar[str] = "binary-classif"
def __init__(
self,
......@@ -131,7 +131,7 @@ class MulticlassAccuracyPrecisionRecall(Metric):
were predicted to belong to label j.
"""
name = "multi-classif"
name: ClassVar[str] = "multi-classif"
def __init__(
self,
......
......@@ -3,7 +3,7 @@
"""Iterative and federative generic evaluation metrics."""
from abc import ABCMeta, abstractmethod
from typing import Dict, Optional, Union
from typing import ClassVar, Dict, Optional, Union
import numpy as np
......@@ -110,7 +110,7 @@ class MeanAbsoluteError(MeanMetric):
summed over channels for (>=2)-dimensional inputs).
"""
name = "mae"
name: ClassVar[str] = "mae"
def metric_func(
self,
......@@ -138,7 +138,7 @@ class MeanSquaredError(MeanMetric):
summed over channels for (>=2)-dimensional inputs).
"""
name = "mse"
name: ClassVar[str] = "mse"
def metric_func(
self,
......
......@@ -2,7 +2,7 @@
"""Iterative and federative ROC AUC evaluation metrics."""
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, ClassVar, Dict, Optional, Tuple, Union
import numpy as np
import sklearn # type: ignore
......@@ -10,7 +10,6 @@ import sklearn.metrics # type: ignore
from declearn.metrics._api import Metric
__all__ = [
"BinaryRocAUC",
]
......@@ -42,7 +41,7 @@ class BinaryRocAUC(Metric):
unless its
"""
name = "binary-roc"
name: ClassVar[str] = "binary-roc"
def __init__(
self,
......
......@@ -2,13 +2,12 @@
"""Adaptive algorithms for optimizers, implemented as plug-in modules."""
from typing import Any, Dict, Optional, Union
from typing import Any, ClassVar, Dict, Optional, Union
from declearn.model.api import Vector
from declearn.optimizer.modules._api import OptiModule
from declearn.optimizer.modules._momentum import EWMAModule, YogiMomentumModule
__all__ = [
"AdaGradModule",
"AdamModule",
......@@ -38,7 +37,7 @@ class AdaGradModule(OptiModule):
https://jmlr.org/papers/v12/duchi11a.html
"""
name = "adagrad"
name: ClassVar[str] = "adagrad"
def __init__(
self,
......@@ -102,7 +101,7 @@ class RMSPropModule(OptiModule):
Average of its Recent Magnitude.
"""
name = "rmsprop"
name: ClassVar[str] = "rmsprop"
def __init__(
self,
......@@ -184,7 +183,7 @@ class AdamModule(OptiModule):
https://arxiv.org/abs/1904.09237
"""
name = "adam"
name: ClassVar[str] = "adam"
def __init__(
self,
......@@ -304,7 +303,7 @@ class YogiModule(AdamModule):
https://arxiv.org/abs/1904.09237
"""
name = "yogi"
name: ClassVar[str] = "yogi"
def __init__(
self,
......
......@@ -3,7 +3,7 @@
"""Base API for plug-in gradients-alteration algorithms."""
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Optional
from typing import Any, ClassVar, Dict, Optional
from declearn.model.api import Vector
from declearn.utils import (
......@@ -76,9 +76,8 @@ class OptiModule(metaclass=ABCMeta):
See `declearn.utils.register_type` for details on types registration.
"""
name: str = NotImplemented
aux_name: Optional[str] = None
name: ClassVar[str] = NotImplemented
aux_name: ClassVar[Optional[str]] = None
def __init_subclass__(
cls,
......
......@@ -2,13 +2,11 @@
"""Batch-averaged gradients clipping plug-in modules."""
from typing import Any, Dict
from typing import Any, ClassVar, Dict
from declearn.model.api import Vector
from declearn.optimizer.modules._api import OptiModule
__all__ = ["L2Clipping"]
......@@ -32,7 +30,7 @@ class L2Clipping(OptiModule):
to bound the sensitivity associated to that action.
"""
name = "l2-clipping"
name: ClassVar[str] = "l2-clipping"
def __init__(
self,
......
......@@ -2,7 +2,7 @@
"""Base API and common examples of plug-in gradients-alteration algorithms."""
from typing import Any, Dict, Union
from typing import Any, ClassVar, Dict, Union
from declearn.model.api import Vector
from declearn.optimizer.modules._api import OptiModule
......@@ -48,7 +48,7 @@ class MomentumModule(OptiModule):
https://proceedings.mlr.press/v28/sutskever13.pdf
"""
name = "momentum"
name: ClassVar[str] = "momentum"
def __init__(
self,
......@@ -114,7 +114,7 @@ class EWMAModule(OptiModule):
decaying moving-average of past gradients.
"""
name = "ewma"
name: ClassVar[str] = "ewma"
def __init__(
self,
......@@ -184,7 +184,7 @@ class YogiMomentumModule(EWMAModule):
Adaptive Methods for Nonconvex Optimization.
"""
name = "yogi-momentum"
name: ClassVar[str] = "yogi-momentum"
def run(
self,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment