Commit 38dbe123 authored by ANDREY Paul

Merge branch 'gardening' into 'develop'

Minor gardening around the package

See merge request !44
Parents: a748ae71, eef2033f
Pipeline #801969 passed
Showing changes with 283 additions and 92 deletions
@@ -17,22 +17,32 @@
 """Model updates aggregating API and implementations.
 
-An Aggregator is typically meant to be used on a round-wise basis by
+An `Aggregator` is typically meant to be used on a round-wise basis by
 the orchestrating server of a centralized federated learning process,
-to aggregate the client-wise model updates into a Vector that may then
-be used as "gradients" by the server's Optimizer to update the global
+to aggregate the client-wise model updates into a `Vector` that may then
+be used as "gradients" by the server's `Optimizer` to update the global
 model.
 
 This declearn submodule provides with:
 
+API tools
+---------
+
 * [Aggregator][declearn.aggregator.Aggregator]:
-    abstract class defining an API for Vector aggregation
+    Abstract base class defining an API for Vector aggregation.
+* [list_aggregators][declearn.aggregator.list_aggregators]:
+    Return a mapping of registered Aggregator subclasses.
+
+Concrete classes
+----------------
+
 * [AveragingAggregator][declearn.aggregator.AveragingAggregator]:
-    average-based-aggregation Aggregator subclass
+    Average-based-aggregation Aggregator subclass.
 * [GradientMaskedAveraging][declearn.aggregator.GradientMaskedAveraging]:
-    gradient Masked Averaging Aggregator subclass
+    Gradient Masked Averaging Aggregator subclass.
 """
 
-from ._api import Aggregator
+from ._api import Aggregator, list_aggregators
 from ._base import AveragingAggregator
 from ._gma import GradientMaskedAveraging
@@ -18,18 +18,25 @@
 """Model updates aggregation API."""
 
 from abc import ABCMeta, abstractmethod
-from typing import Any, ClassVar, Dict
+from typing import Any, ClassVar, Dict, Type, TypeVar
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api import Vector
-from declearn.utils import create_types_registry, register_type
+from declearn.utils import (
+    access_types_mapping,
+    create_types_registry,
+    register_type,
+)
 
 __all__ = [
     "Aggregator",
 ]
 
+T = TypeVar("T")
+
 
 @create_types_registry
 class Aggregator(metaclass=ABCMeta):
     """Abstract class defining an API for Vector aggregation.
@@ -90,9 +97,9 @@ class Aggregator(metaclass=ABCMeta):
     @abstractmethod
     def aggregate(
         self,
-        updates: Dict[str, Vector],
+        updates: Dict[str, Vector[T]],
         n_steps: Dict[str, int],  # revise: abstract~generalize kwargs use
-    ) -> Vector:
+    ) -> Vector[T]:
         """Aggregate input vectors into a single one.
 
         Parameters
@@ -109,6 +116,11 @@ class Aggregator(metaclass=ABCMeta):
         gradients: Vector
             Aggregated updates, as a Vector - treated as gradients by
             the server-side optimizer.
+
+        Raises
+        ------
+        TypeError
+            If the input `updates` is an empty dict.
         """
     def get_config(
@@ -124,3 +136,29 @@ class Aggregator(metaclass=ABCMeta):
     ) -> Self:
         """Instantiate an Aggregator from its configuration dict."""
         return cls(**config)
+
+
+def list_aggregators() -> Dict[str, Type[Aggregator]]:
+    """Return a mapping of registered Aggregator subclasses.
+
+    This function aims at making it easy for end-users to list and access
+    all available Aggregator classes at any given time. The returned dict
+    uses unique identifier keys, which may be used to specify the desired
+    algorithm as part of a federated learning process without going through
+    the fuss of importing and instantiating it manually.
+
+    Note that the mapping will include all declearn-provided plug-ins,
+    but also registered plug-ins provided by user or third-party code.
+
+    See also
+    --------
+    * [declearn.aggregator.Aggregator][]:
+        API-defining abstract base class for the aggregation algorithms.
+
+    Returns
+    -------
+    mapping:
+        Dictionary mapping unique str identifiers to `Aggregator` class
+        constructors.
+    """
+    return access_types_mapping("Aggregator")
@@ -19,7 +19,6 @@
 from typing import Any, ClassVar, Dict, Optional
 
-from typing_extensions import Self  # future: import from typing (py >=3.11)
-
 from declearn.aggregator._api import Aggregator
 from declearn.model.api import Vector
@@ -76,13 +75,6 @@ class AveragingAggregator(Aggregator):
             "client_weights": self.client_weights,
         }
 
-    @classmethod
-    def from_config(
-        cls,
-        config: Dict[str, Any],
-    ) -> Self:
-        return cls(**config)
-
     def aggregate(
         self,
         updates: Dict[str, Vector],
......
@@ -22,7 +22,7 @@ import functools
 from typing import Any, Dict, Union
 
-from declearn.aggregator import Aggregator
+from declearn.aggregator import Aggregator, AveragingAggregator
 from declearn.optimizer import Optimizer
 from declearn.utils import TomlConfig, access_registered, deserialize_object
@@ -95,7 +95,9 @@ class FLOptimConfig(TomlConfig):
     server_opt: Optimizer = dataclasses.field(
         default_factory=functools.partial(Optimizer, lrate=1.0)
     )
-    aggregator: Aggregator = dataclasses.field(default_factory=Aggregator)
+    aggregator: Aggregator = dataclasses.field(
+        default_factory=AveragingAggregator
+    )
 
     @classmethod
     def parse_client_opt(
@@ -147,7 +149,7 @@ class FLOptimConfig(TomlConfig):
           - (opt.) group: str used to retrieve the registered class
           - (opt.) config: dict specifying kwargs for the constructor
           - any other field will be added to the `config` kwargs dict
-        - as None (or missing kwarg), using default AverageAggregator()
+        - as None (or missing kwarg), using default AveragingAggregator()
         """
         # Case when using the default value: delegate to the default parser.
         if inputs is None:
......
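
To make the new default concrete, here is a hedged sketch of configuring the aggregator through `FLOptimConfig`. The module path, the `from_params` constructor (shown for `NetworkClientConfig` later in this diff, and assumed available on this `TomlConfig` subclass as well) and the "averaging" identifier are all assumptions:

```python
from declearn.main.config import FLOptimConfig  # module path: assumption

# Specify the aggregator by its registered name, per the docstring above.
config = FLOptimConfig.from_params(
    client_opt={"lrate": 0.01},
    server_opt={"lrate": 1.0},
    aggregator="averaging",  # assumed identifier key
)

# Omitting `aggregator` now yields an AveragingAggregator() instance,
# rather than attempting to instantiate the abstract Aggregator class.
default = FLOptimConfig.from_params(client_opt={"lrate": 0.01})
```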
@@ -18,7 +18,7 @@
 """Model abstraction API."""
 
 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, Optional, Set, Tuple
+from typing import Any, Dict, Generic, Optional, Set, Tuple, TypeVar
 
 import numpy as np
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -33,8 +33,12 @@ __all__ = [
 ]
 
+VectorT = TypeVar("VectorT", bound=Vector)
+"""Type-annotation for the Vector subclass proper to a given Model."""
+
+
 @create_types_registry
-class Model(metaclass=ABCMeta):
+class Model(Generic[VectorT], metaclass=ABCMeta):
     """Abstract class defining an API to manipulate a ML model.
 
     A 'Model' is an abstraction that defines a generic interface
@@ -59,6 +63,21 @@ class Model(metaclass=ABCMeta):
         """Instantiate a Model interface wrapping a 'model' object."""
         self._model = model
 
+    def get_wrapped_model(self) -> Any:
+        """Getter to access the wrapped framework-specific model object.
+
+        This getter should be used sparingly, so as to avoid undesirable
+        side effects. In particular, it should not be used in declearn
+        backend code (but may be in examples or tests), as it is merely
+        a way for end-users to access the wrapped model after training.
+
+        Returns
+        -------
+        model:
+            Wrapped model, of (framework/Model-subclass)-specific type.
+        """
+        return self._model
+
     @property
     @abstractmethod
     def device_policy(
@@ -119,7 +138,7 @@ class Model(metaclass=ABCMeta):
     def get_weights(
         self,
         trainable: bool = False,
-    ) -> Vector:
+    ) -> VectorT:
         """Return the model's weights, optionally excluding frozen ones.
 
         Parameters
@@ -140,7 +159,7 @@ class Model(metaclass=ABCMeta):
     @abstractmethod
     def set_weights(
         self,
-        weights: Vector,
+        weights: VectorT,
         trainable: bool = False,
     ) -> None:
         """Assign values to the model's weights.
@@ -176,7 +195,7 @@ class Model(metaclass=ABCMeta):
         self,
         batch: Batch,
         max_norm: Optional[float] = None,
-    ) -> Vector:
+    ) -> VectorT:
         """Compute and return gradients computed over a given data batch.
 
         Compute the average gradients of the model's loss with respect
@@ -204,7 +223,7 @@ class Model(metaclass=ABCMeta):
     @abstractmethod
     def apply_updates(
         self,
-        updates: Vector,
+        updates: VectorT,
     ) -> None:
         """Apply updates to the model's weights."""
......
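
To make the new `Generic[VectorT]` parameter concrete, here is a self-contained, hedged sketch using stand-in classes (not the real declearn API) of how binding the type variable in a subclass lets type-checkers infer framework-specific return types; this is what allows the `# type: ignore  # Vector subtype specification` comments to be dropped further below:

```python
from abc import ABC, abstractmethod
from typing import Dict, Generic, TypeVar

import numpy as np


class MiniVector:
    """Stand-in for declearn's Vector: a named collection of arrays."""

    def __init__(self, coefs: Dict[str, np.ndarray]) -> None:
        self.coefs = coefs


VectorT = TypeVar("VectorT", bound=MiniVector)


class MiniModel(ABC, Generic[VectorT]):
    """Stand-in for declearn's Model, generic over its Vector type."""

    @abstractmethod
    def get_weights(self) -> VectorT:
        """Return the model's weights as the bound Vector type."""


class NumpyMiniModel(MiniModel[MiniVector]):
    """Concrete subclass, binding VectorT to MiniVector."""

    def __init__(self, weights: Dict[str, np.ndarray]) -> None:
        self._weights = weights

    def get_weights(self) -> MiniVector:
        return MiniVector(dict(self._weights))


# Type-checkers infer `MiniVector` here without a cast or "type: ignore".
vec = NumpyMiniModel({"w": np.zeros(3)}).get_weights()
print(vec.coefs["w"])
```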
@@ -19,7 +19,10 @@
 import operator
 from abc import ABCMeta, abstractmethod
-from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union
+from typing import (
+    # fmt: off
+    Any, Callable, Dict, Generic, Optional, Set, Tuple, Type, TypeVar, Union
+)
 
 from numpy.typing import ArrayLike
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
@@ -37,10 +40,15 @@ __all__ = [
 ]
 
 VECTOR_TYPES = {}  # type: Dict[Type[Any], Type[Vector]]
 """Private constant holding registered Vector types."""
 
+T = TypeVar("T")
+"""Type-annotation for the data structures proper to a given Vector class."""
+
+
 @create_types_registry
-class Vector(metaclass=ABCMeta):
+class Vector(Generic[T], metaclass=ABCMeta):
     """Abstract class defining an API to manipulate (sets of) data arrays.
 
     A Vector is an abstraction used to wrap a collection of data
@@ -62,27 +70,27 @@ class Vector(metaclass=ABCMeta):
     """
 
     @property
-    def _op_add(self) -> Callable[[Any, Any], Any]:
+    def _op_add(self) -> Callable[[Any, Any], T]:
         """Framework-compatible addition operator."""
         return operator.add
 
     @property
-    def _op_sub(self) -> Callable[[Any, Any], Any]:
+    def _op_sub(self) -> Callable[[Any, Any], T]:
         """Framework-compatible subtraction operator."""
         return operator.sub
 
     @property
-    def _op_mul(self) -> Callable[[Any, Any], Any]:
+    def _op_mul(self) -> Callable[[Any, Any], T]:
         """Framework-compatible multiplication operator."""
         return operator.mul
 
     @property
-    def _op_div(self) -> Callable[[Any, Any], Any]:
+    def _op_div(self) -> Callable[[Any, Any], T]:
         """Framework-compatible true division operator."""
         return operator.truediv
 
     @property
-    def _op_pow(self) -> Callable[[Any, Any], Any]:
+    def _op_pow(self) -> Callable[[Any, Any], T]:
         """Framework-compatible power operator."""
         return operator.pow
@@ -108,13 +116,13 @@ class Vector(metaclass=ABCMeta):
 
     def __init__(
         self,
-        coefs: Dict[str, Any],
+        coefs: Dict[str, T],
     ) -> None:
         """Instantiate the Vector to wrap a collection of data arrays.
 
         Parameters
         ----------
-        coefs: dict[str, any]
+        coefs: dict[str, <T>]
             Dict grouping a named collection of data arrays.
             The supported types of that dict's values depend
             on the concrete `Vector` subclass being used.
@@ -123,7 +131,7 @@ class Vector(metaclass=ABCMeta):
 
     @staticmethod
     def build(
-        coefs: Dict[str, Any],
+        coefs: Dict[str, T],
     ) -> "Vector":
         """Instantiate a Vector, inferring its exact subtype from coefs'.
@@ -136,7 +144,7 @@ class Vector(metaclass=ABCMeta):
 
         Parameters
         ----------
-        coefs: dict[str, any]
+        coefs: dict[str, <T>]
             Dict grouping a named collection of data arrays, that
             all belong to the same framework.
@@ -189,7 +197,10 @@ class Vector(metaclass=ABCMeta):
             indexed by the coefficient's name.
         """
         try:
-            return {key: coef.shape for key, coef in self.coefs.items()}
+            return {
+                key: coef.shape  # type: ignore  # exception caught
+                for key, coef in self.coefs.items()
+            }
         except AttributeError as exc:
             raise NotImplementedError(
                 "Wrapped coefficients appear not to implement `.shape`.\n"
@@ -210,7 +221,10 @@ class Vector(metaclass=ABCMeta):
             concrete framework of the Vector.
         """
         try:
-            return {key: str(coef.dtype) for key, coef in self.coefs.items()}
+            return {
+                key: str(coef.dtype)  # type: ignore  # exception caught
+                for key, coef in self.coefs.items()
+            }
         except AttributeError as exc:
             raise NotImplementedError(
                 "Wrapped coefficients appear not to implement `.dtype`.\n"
@@ -261,7 +275,7 @@ class Vector(metaclass=ABCMeta):
 
     def apply_func(
         self,
-        func: Callable[..., Any],
+        func: Callable[..., T],
         *args: Any,
         **kwargs: Any,
     ) -> Self:
@@ -290,14 +304,15 @@ class Vector(metaclass=ABCMeta):
     def _apply_operation(
         self,
         other: Any,
-        func: Callable[[Any, Any], Any],
+        func: Callable[[Any, Any], T],
     ) -> Self:
         """Apply an operation to combine this vector with another.
 
         Parameters
         ----------
-        other: Vector
-            Vector with the same names, shapes and dtypes as this one.
+        other:
+            Vector with the same names, shapes and dtypes as this one;
+            or scalar object on which to operate (e.g. a float value).
         func: function(<T>, <T>) -> <T>
             Function to be applied to combine the data arrays stored
             in this vector and the `other` one.
......
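
A quick, hedged illustration of the typed `Vector` API above (`Vector.build` and the `coefs` attribute are confirmed by this diff; the inferred subclass for plain ndarrays is assumed to be the NumPy-backed one):

```python
import numpy as np

from declearn.model.api import Vector

# `build` dispatches on the type of the wrapped arrays (via the private
# VECTOR_TYPES registry) to select the matching Vector subclass.
vec = Vector.build({"w": np.ones(3), "b": np.zeros(1)})

# Operators now advertise framework-typed outputs (np.ndarray here).
out = (vec + vec) * 0.5
print(out.coefs["w"])  # [1. 1. 1.]
```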
@@ -22,10 +22,7 @@ import inspect
 import io
 import warnings
 from random import SystemRandom
-from typing import (
-    # fmt: off
-    Any, Callable, Dict, List, Optional, Set, Tuple, Union
-)
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import haiku as hk
 import jax
@@ -205,7 +202,7 @@ class HaikuModel(Model):
         params = {k: v for k, v in params.items() if k in self._trainable}
         return JaxNumpyVector(params)
 
-    def set_weights(  # type: ignore  # Vector subtype specification
+    def set_weights(
         self,
         weights: JaxNumpyVector,
         trainable: bool = False,
@@ -466,7 +463,7 @@ class HaikuModel(Model):
         output = [list(map(convert, inputs)), convert(y_true), convert(s_wght)]
         return output  # type: ignore
 
-    def apply_updates(  # type: ignore  # Vector subtype specification
+    def apply_updates(
         self,
         updates: JaxNumpyVector,
     ) -> None:
......
@@ -74,23 +74,23 @@ class JaxNumpyVector(Vector):
     """
 
     @property
-    def _op_add(self) -> Callable[[Any, Any], Any]:
+    def _op_add(self) -> Callable[[Any, Any], jax.Array]:
         return jnp.add
 
     @property
-    def _op_sub(self) -> Callable[[Any, Any], Any]:
+    def _op_sub(self) -> Callable[[Any, Any], jax.Array]:
         return jnp.subtract
 
     @property
-    def _op_mul(self) -> Callable[[Any, Any], Any]:
+    def _op_mul(self) -> Callable[[Any, Any], jax.Array]:
         return jnp.multiply
 
     @property
-    def _op_div(self) -> Callable[[Any, Any], Any]:
+    def _op_div(self) -> Callable[[Any, Any], jax.Array]:
         return jnp.divide
 
     @property
-    def _op_pow(self) -> Callable[[Any, Any], Any]:
+    def _op_pow(self) -> Callable[[Any, Any], jax.Array]:
         return jnp.power
 
     @property
@@ -104,7 +104,7 @@ class JaxNumpyVector(Vector):
     def _apply_operation(
         self,
         other: Any,
-        func: Callable[[Any, Any], Any],
+        func: Callable[[jax.Array, Any], jax.Array],
     ) -> Self:
         # Ensure 'other' JaxNumpyVector shares this vector's device placement.
         if isinstance(other, JaxNumpyVector):
......
@@ -58,29 +58,35 @@ class NumpyVector(Vector):
     """
 
     @property
-    def _op_add(self) -> Callable[[Any, Any], Any]:
+    def _op_add(self) -> Callable[[Any, Any], np.ndarray]:
         return np.add
 
     @property
-    def _op_sub(self) -> Callable[[Any, Any], Any]:
+    def _op_sub(self) -> Callable[[Any, Any], np.ndarray]:
         return np.subtract
 
     @property
-    def _op_mul(self) -> Callable[[Any, Any], Any]:
+    def _op_mul(self) -> Callable[[Any, Any], np.ndarray]:
         return np.multiply
 
     @property
-    def _op_div(self) -> Callable[[Any, Any], Any]:
+    def _op_div(self) -> Callable[[Any, Any], np.ndarray]:
         return np.divide
 
     @property
-    def _op_pow(self) -> Callable[[Any, Any], Any]:
+    def _op_pow(self) -> Callable[[Any, Any], np.ndarray]:
         return np.power
 
-    def __init__(self, coefs: Dict[str, np.ndarray]) -> None:
+    def __init__(
+        self,
+        coefs: Dict[str, np.ndarray],
+    ) -> None:
         super().__init__(coefs)
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(
+        self,
+        other: Any,
+    ) -> bool:
         valid = isinstance(other, NumpyVector)
         if valid:
             valid = self.coefs.keys() == other.coefs.keys()
......
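
A small, hedged example of the reformatted `NumpyVector` dunder methods in action (the import path is an assumption, and `__eq__` is assumed to compare array contents after the `coefs.keys()` check shown above):

```python
import numpy as np

from declearn.model.sklearn import NumpyVector  # import path: assumption

a = NumpyVector({"w": np.arange(3.0)})
b = NumpyVector({"w": np.arange(3.0)})

# Equality requires matching keys and matching wrapped arrays.
assert a == b
assert (a + b) == NumpyVector({"w": np.array([0.0, 2.0, 4.0])})
```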
@@ -282,7 +282,7 @@ class SklearnSGDModel(Model):
         }
         return NumpyVector(weights)
 
-    def set_weights(  # type: ignore  # Vector subtype specification
+    def set_weights(
         self,
         weights: NumpyVector,
         trainable: bool = False,
@@ -356,7 +356,7 @@ class SklearnSGDModel(Model):
         # Compute gradients based on weights' update.
         return w_srt - w_end
 
-    def apply_updates(  # type: ignore  # Vector subtype specification
+    def apply_updates(
         self,
         updates: NumpyVector,
     ) -> None:
......
@@ -183,7 +183,7 @@ class TensorflowModel(Model):
         )
         return TensorflowVector({var.name: var.value() for var in variables})
 
-    def set_weights(  # type: ignore  # Vector subtype specification
+    def set_weights(
         self,
         weights: TensorflowVector,
         trainable: bool = False,
@@ -319,7 +319,7 @@ class TensorflowModel(Model):
             outp.append(tf.reduce_mean(grad * s_wght, axis=0))
         return outp
 
-    def apply_updates(  # type: ignore  # Vector subtype specification
+    def apply_updates(
         self,
         updates: TensorflowVector,
     ) -> None:
......
@@ -168,7 +168,7 @@ class TorchModel(Model):
         # Note: calling `tensor.clone()` to return a copy rather than a view.
         return TorchVector({k: t.detach().clone() for k, t in weights.items()})
 
-    def set_weights(  # type: ignore  # Vector subtype specification
+    def set_weights(
         self,
         weights: TorchVector,
         trainable: bool = False,
@@ -378,7 +378,7 @@ class TorchModel(Model):
             return grads_fn
         return functorch.compile.aot_function(grads_fn, functorch.compile.nop)
 
-    def apply_updates(  # type: ignore  # Vector subtype specification
+    def apply_updates(
         self,
         updates: TorchVector,
     ) -> None:
......
@@ -32,8 +32,16 @@ Submodules providing with plug-in algorithms:
     Gradients-alteration algorithms, implemented as plug-in modules.
 * [regularizers][declearn.optimizer.regularizers]:
     Loss-regularization algorithms, implemented as plug-in modules.
+
+Utils to list available plug-ins:
+
+* [list_optim_modules][declearn.optimizer.list_optim_modules]:
+    Return a mapping of registered OptiModule subclasses.
+* [list_optim_regularizers][declearn.optimizer.list_optim_regularizers]:
+    Return a mapping of registered Regularizer subclasses.
 """
 
 from . import modules, regularizers
 from ._base import Optimizer
+from ._utils import list_optim_modules, list_optim_regularizers
@@ -17,7 +17,10 @@
 """Base class to define gradient-descent-based optimizers."""
 
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import (
+    # fmt: off
+    Any, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
+)
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -31,6 +34,9 @@ __all__ = [
 ]
 
+T = TypeVar("T")
+
 
 class Optimizer:
     """Base class to define gradient-descent-based optimizers.
@@ -103,6 +109,13 @@ class Optimizer:
     [1] Loshchilov & Hutter, 2019.
         Decoupled Weight Decay Regularization.
         https://arxiv.org/abs/1711.05101
+
+    See also
+    --------
+    - [declearn.optimizer.list_optim_modules][]:
+        Return a mapping of registered OptiModule subclasses.
+    - [declearn.optimizer.list_optim_regularizers][]:
+        Return a mapping of registered Regularizer subclasses.
     """
 
     def __init__(
@@ -255,9 +268,9 @@ class Optimizer:
     def compute_updates_from_gradients(
         self,
-        model: Model,
-        gradients: Vector,
-    ) -> Vector:
+        model: Model[Vector[T]],
+        gradients: Vector[T],
+    ) -> Vector[T]:
         """Compute and return model updates based on pre-computed gradients.
 
         Parameters
@@ -393,8 +406,8 @@ class Optimizer:
     def apply_gradients(
         self,
-        model: Model,
-        gradients: Vector,
+        model: Model[Vector[T]],
+        gradients: Vector[T],
     ) -> None:
         """Compute and apply model updates based on pre-computed gradients.
......
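
A brief usage note: the only `Optimizer` constructor argument confirmed by this diff is `lrate` (via `functools.partial(Optimizer, lrate=1.0)` in the FLOptimConfig default above), so this sketch sticks to it:

```python
from declearn.optimizer import Optimizer

# Plain-SGD optimizer with unit learning rate, matching the
# FLOptimConfig default shown earlier in this diff.
opt = Optimizer(lrate=1.0)

# With the new generic annotations, a type-checker ties the arguments of
# `opt.apply_gradients(model, gradients)` to a single Vector subtype:
# e.g. pairing a TorchModel with a NumpyVector is now flagged statically.
```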
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to list available optimizer plug-ins (OptiModule and Regularizer)."""
from typing import Dict, Type

from declearn.optimizer.modules import OptiModule
from declearn.optimizer.regularizers import Regularizer
from declearn.utils import access_types_mapping

__all__ = [
    "list_optim_modules",
    "list_optim_regularizers",
]


def list_optim_modules() -> Dict[str, Type[OptiModule]]:
    """Return a mapping of registered OptiModule subclasses.

    This function aims at making it easy for end-users to list and access
    all available OptiModule optimizer plug-ins at any given time. The
    returned dict uses unique identifier keys, which may be used to add
    the associated plug-in to a [declearn.optimizer.Optimizer][] without
    going through the fuss of importing and instantiating it manually.

    Note that the mapping will include all declearn-provided plug-ins,
    but also registered plug-ins provided by user or third-party code.

    See also
    --------
    * [declearn.optimizer.modules.OptiModule][]:
        API-defining abstract base class for the OptiModule plug-ins.
    * [declearn.optimizer.list_optim_regularizers][]:
        Counterpart function for Regularizer plug-ins.

    Returns
    -------
    mapping:
        Dictionary mapping unique str identifiers to OptiModule
        class constructors.
    """
    return access_types_mapping("OptiModule")


def list_optim_regularizers() -> Dict[str, Type[Regularizer]]:
    """Return a mapping of registered Regularizer subclasses.

    This function aims at making it easy for end-users to list and access
    all available Regularizer optimizer plug-ins at any given time. The
    returned dict uses unique identifier keys, which may be used to add
    the associated plug-in to a [declearn.optimizer.Optimizer][] without
    going through the fuss of importing and instantiating it manually.

    Note that the mapping will include all declearn-provided plug-ins,
    but also registered plug-ins provided by user or third-party code.

    See also
    --------
    * [declearn.optimizer.regularizers.Regularizer][]:
        API-defining abstract base class for the Regularizer plug-ins.
    * [declearn.optimizer.list_optim_modules][]:
        Counterpart function for OptiModule plug-ins.

    Returns
    -------
    mapping:
        Dictionary mapping unique str identifiers to Regularizer
        class constructors.
    """
    return access_types_mapping("Regularizer")
@@ -18,7 +18,7 @@
 """Base API for plug-in gradients-alteration algorithms."""
 
 from abc import ABCMeta, abstractmethod
-from typing import Any, ClassVar, Dict, Optional
+from typing import Any, ClassVar, Dict, Optional, TypeVar
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -34,6 +34,9 @@ __all__ = [
 ]
 
+T = TypeVar("T")
+
 
 @create_types_registry
 class OptiModule(metaclass=ABCMeta):
     """Abstract class defining an API to implement gradients adaptation tools.
@@ -117,8 +120,8 @@ class OptiModule(metaclass=ABCMeta):
     @abstractmethod
     def run(
         self,
-        gradients: Vector,
-    ) -> Vector:
+        gradients: Vector[T],
+    ) -> Vector[T]:
         """Apply the module's algorithm to input gradients.
 
         Please refer to the module's main docstring for details
......
@@ -81,7 +81,7 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False):
         gradients: Vector,
     ) -> Vector:
         if not NumpyVector in gradients.compatible_vector_types:
-            raise TypeError(
+            raise TypeError(  # pragma: no cover
                 f"{self.__class__.__name__} requires input gradients to "
                 "be compatible with NumpyVector, which is not the case "
                 f"of {type(gradients).__name__}."
@@ -95,7 +95,7 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False):
             for key in gradients.coefs
         }
         # Add the sampled noise to the gradients and return them.
-        # Silence warnings about sparse gradients getting sparsified.
+        # Silence warnings about sparse gradients getting densified.
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", ".*densifying.*", RuntimeWarning)
             return gradients + NumpyVector(noise)
@@ -166,4 +166,6 @@ class GaussianNoiseModule(NoiseModule):
             # false-positive; pylint: disable=no-member
             return self._rng.normal(scale=self.std, size=shape).astype(dtype)
         # Theoretically-unreachable case.
-        raise RuntimeError("Unexpected `GaussianNoiseModule._rng` type.")
+        raise RuntimeError(  # pragma: no cover
+            "Unexpected `GaussianNoiseModule._rng` type."
+        )
@@ -18,7 +18,7 @@
 """Base API for loss regularization optimizer plug-ins."""
 
 from abc import ABCMeta, abstractmethod
-from typing import Any, ClassVar, Dict
+from typing import Any, ClassVar, Dict, TypeVar
 
 from typing_extensions import Self  # future: import from typing (py >=3.11)
@@ -34,6 +34,9 @@ __all__ = [
 ]
 
+T = TypeVar("T")
+
 
 @create_types_registry
 class Regularizer(metaclass=ABCMeta):
     """Abstract class defining an API to implement loss-regularizers.
@@ -115,9 +118,9 @@ class Regularizer(metaclass=ABCMeta):
     @abstractmethod
     def run(
         self,
-        gradients: Vector,
-        weights: Vector,
-    ) -> Vector:
+        gradients: Vector[T],
+        weights: Vector[T],
+    ) -> Vector[T]:
         """Compute and add the regularization term's derivative to gradients.
 
         Parameters
......
@@ -175,7 +175,7 @@ def server_to_client_network(
     "Convert server network config to client network config."
     return NetworkClientConfig.from_params(
         protocol=network_cfg.protocol,
-        server_uri=f"ws://localhost:{network_cfg.port}",
+        server_uri=network_cfg.build_server().uri,
         name="replaceme",
     )
......
@@ -22,7 +22,7 @@ import multiprocessing as mp
 import sys
 import traceback
 from queue import Queue
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Tuple, Union
 
 __all__ = [
     "run_as_processes",
@@ -105,12 +105,9 @@ def run_as_processes(
 def add_exception_catching(
     func: Callable[..., Any],
     queue: Queue,
-    name: Optional[str] = None,
+    name: str,
 ) -> Callable[..., Any]:
     """Wrap a function to catch exceptions and put them in a Queue."""
-    if not name:
-        name = func.__name__
     return functools.partial(wrapped, func=func, queue=queue, name=name)
......
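
The `wrapped` helper that `add_exception_catching` binds is not shown in this diff; the following is a hedged sketch of what such a wrapper has to do given the docstring above (the signature and the queued record format are assumptions):

```python
from queue import Queue
from typing import Any, Callable


def wrapped(
    *args: Any,
    func: Callable[..., Any],
    queue: Queue,
    name: str,
) -> None:
    """Run `func`, putting its result or raised exception on the queue."""
    try:
        queue.put((name, func(*args)))
    except Exception as exc:  # pylint: disable=broad-except
        queue.put((name, exc))
```

Making `name` a required argument (rather than defaulting to `func.__name__`) presumably ensures that error reports always use the caller-assigned routine names.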