Commit a748ae71 authored by BIGAUD Nathan

Merge branch 'haiku-jax' into 'develop'

Add support for Jax / Haiku

See merge request !32
parents 2258d64f 11274f41
Pipeline #801710 passed
@@ -42,6 +42,9 @@ The automatically-imported submodules implemented here are:
Optional Submodules
-------------------
The optional-dependency-based submodules that may be manually imported are:
* [haiku][declearn.model.haiku]:
    jax- and haiku-interfacing tools
    - HaikuModel: Model to wrap a haiku-transformable model function.
    - JaxNumpyVector: Vector for jax array data structures.
* [tensorflow][declearn.model.tensorflow]:
TensorFlow-interfacing tools
@@ -65,6 +68,7 @@ from . import api
from . import sklearn
OPTIONAL_MODULES = [
"jax",
"tensorflow",
"torch",
]
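As with the other backends listed above, the haiku submodule is optional-dependency-based and is only importable when its requirements are present. A minimal import sketch, assuming `jax` and `dm-haiku` are installed:

```python
# The haiku submodule is optional: importing it raises a
# ModuleNotFoundError unless jax and dm-haiku are installed.
from declearn.model.haiku import HaikuModel, JaxNumpyVector
```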
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Haiku models interfacing tools.
This submodule provides with a generic interface to wrap up
any Haiku module instance that is to be trained
through gradient descent.
This module exposes:
* HaikuModel: Model subclass to wrap haiku.Model objects
* JaxNumpyVector: Vector subclass to wrap jax.numpy.ndarray objects
"""
from . import utils
from ._vector import JaxNumpyVector
from ._model import HaikuModel
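To illustrate the API exposed here, below is a hedged usage sketch of wrapping a toy linear model, mirroring the functional-form models used in the unit tests further down; the feature shape, data type and seed are illustrative assumptions:

```python
import haiku as hk
import jax

from declearn.model.haiku import HaikuModel


def model_fn(inputs: jax.Array) -> jax.Array:
    """Linear model, in haiku-transformable (purely functional) form."""
    return hk.Linear(1)(inputs)


def loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
    """Sample-wise squared-error loss."""
    return (y_pred - y_true) ** 2


model = HaikuModel(model_fn, loss=loss_fn, seed=0)
# `initialize` builds the haiku parameters from data specifications.
model.initialize({"features_shape": (8,), "data_type": "float32"})
```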
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model subclass to wrap Haiku models."""
import functools
import inspect
import io
import warnings
from random import SystemRandom
from typing import (
# fmt: off
Any, Callable, Dict, List, Optional, Set, Tuple, Union
)
import haiku as hk
import jax
import jax.numpy as jnp
import joblib # type: ignore
import numpy as np
from typing_extensions import Self
from declearn.data_info import aggregate_data_info
from declearn.model._utils import raise_on_stringsets_mismatch
from declearn.model.api import Model
from declearn.model.haiku._vector import JaxNumpyVector
from declearn.model.haiku.utils import select_device
from declearn.typing import Batch
from declearn.utils import DevicePolicy, get_device_policy, register_type
__all__ = [
"HaikuModel",
]
# alias for unpacked Batch structures, converted to jax arrays
# FUTURE: add support for lists of inputs
JaxBatch = Tuple[List[jax.Array], Optional[jax.Array], Optional[jax.Array]]
@register_type(name="HaikuModel", group="Model")
class HaikuModel(Model):
"""Model wrapper for Haiku Model instances.
This `Model` subclass is designed to wrap a `hk.Module`
instance to be learned federatively.
Notes regarding device management (CPU, GPU, etc.):
* By default, jax places data and operations on GPU whenever one
is available.
* Our `HaikuModel` instead consults the device-placement policy (via
`declearn.utils.get_device_policy`), places the wrapped haiku model's
weights there, and runs computations defined under public methods on
that device.
* Note that there is no guarantee that calling a private method directly
will result in abiding by that policy. Hence, be careful when writing
custom code, and use your own context managers to get guarantees.
* Note that if the global device-placement policy is updated, this will
only be propagated to existing instances by manually calling their
`update_device_policy` method.
* You may consult the device policy enforced by a HaikuModel
instance by accessing its `device_policy` property.
"""
# pylint: disable=too-many-instance-attributes
def __init__(
self,
model: Callable[[jax.Array], jax.Array],
loss: Callable[[jax.Array, jax.Array], jax.Array],
seed: Optional[int] = None,
) -> None:
"""Instantiate a Model interface wrapping a jax-haiku model.
Parameters
----------
model: callable(jax.Array) -> jax.Array
Function encapsulating a `haiku.Module` such that `model(x)`
returns `haiku_module(x)`, constituting a model's forward.
loss: callable(jax.Array, jax.Array) -> jax.Array
Jax-compatible function that defines the model's loss.
It must expect `y_pred` and `y_true` as input arguments (in
that order) and return sample-wise loss values.
seed: int or None, default=None
            Random seed used to initialize the haiku-wrapped pseudo-random
number generator. If none is provided, draw an integer between
0 and 10^6 using `random.SystemRandom`.
"""
super().__init__(model)
# Assign loss module.
self._loss_fn = loss
# Get pure functions from haiku transform.
self._model_fn = model
self._transformed_model = hk.transform(model)
# Select the device where to place computations.
policy = get_device_policy()
self._device = select_device(gpu=policy.gpu, idx=policy.idx)
# Create model state attributes.
self._params = {} # type: hk.Params
self._pnames = [] # type: List[str]
self._trainable = [] # type: List[str]
# Initialize the PRNG.
if seed is None:
            seed = int(SystemRandom().random() * 10**6)  # draw in [0, 10^6)
self._rng_gen = hk.PRNGSequence(seed)
        # Initialization status and data_info attributes.
self._initialized = False
self.data_info = {} # type: Dict[str, Any]
@property
def device_policy(
self,
) -> DevicePolicy:
device = self._device
return DevicePolicy(gpu=(device.platform == "gpu"), idx=device.id)
@property
def required_data_info(
self,
) -> Set[str]:
return set() if self._initialized else {"data_type", "features_shape"}
def initialize(
self,
data_info: Dict[str, Any],
) -> None:
if self._initialized:
return
# Check that required fields are available and of valid type.
self.data_info = aggregate_data_info(
[data_info], self.required_data_info
)
# Initialize the model parameters.
params = self._transformed_model.init(
next(self._rng_gen),
jnp.zeros(
(1, *data_info["features_shape"]), data_info["data_type"]
),
)
self._params = jax.device_put(params, self._device)
self._pnames = [
f"{layer}:{weight}"
for layer, weights in self._params.items()
for weight in weights
]
self._trainable = self._pnames.copy()
self._initialized = True
def get_config(
self,
) -> Dict[str, Any]:
        warnings.warn(
            "Our custom Haiku serialization relies on pickle, "
            "which may be unsafe."
        )
with io.BytesIO() as buffer:
joblib.dump(self._model_fn, buffer)
model = buffer.getbuffer().hex()
with io.BytesIO() as buffer:
joblib.dump(self._loss_fn, buffer)
loss = buffer.getbuffer().hex()
return {
"model": model,
"loss": loss,
"data_info": self.data_info,
}
@classmethod
def from_config(
cls,
config: Dict[str, Any],
) -> Self:
with io.BytesIO(bytes.fromhex(config["model"])) as buffer:
model = joblib.load(buffer)
with io.BytesIO(bytes.fromhex(config["loss"])) as buffer:
loss = joblib.load(buffer)
model = cls(model=model, loss=loss)
if config.get("data_info"):
model.initialize(config["data_info"])
return model
def get_weights(
self,
trainable: bool = False,
) -> JaxNumpyVector:
params = {
f"{layer}:{wname}": value
for layer, weights in self._params.items()
for wname, value in weights.items()
}
if trainable:
params = {k: v for k, v in params.items() if k in self._trainable}
return JaxNumpyVector(params)
def set_weights( # type: ignore # Vector subtype specification
self,
weights: JaxNumpyVector,
trainable: bool = False,
) -> None:
if not isinstance(weights, JaxNumpyVector):
raise TypeError("HaikuModel requires JaxNumpyVector weights.")
self._verify_weights_compatibility(weights, trainable=trainable)
for key, val in weights.coefs.items():
layer, weight = key.rsplit(":", 1)
self._params[layer][weight] = val.copy() # type: ignore
def _verify_weights_compatibility(
self,
vector: JaxNumpyVector,
trainable: bool = False,
) -> None:
"""Verify that a vector has the same names as the model's weights.
Parameters
----------
vector: JaxNumpyVector
Vector wrapping weight-related coefficients (e.g. weight
values or gradient-based updates).
trainable: bool, default=False
            Whether to restrict the comparison to the model's trainable
weights rather than to all of its weights.
Raises
------
KeyError:
In case some expected keys are missing, or additional keys
are present. Be verbose about the identified mismatch(es).
"""
received = set(vector.coefs)
expected = set(self._trainable if trainable else self._pnames)
raise_on_stringsets_mismatch(
received, expected, context="model weights"
)
def set_trainable_weights(
self,
criterion: Union[
Callable[[str, str, jax.Array], bool],
Dict[str, Dict[str, Any]],
List[str],
],
) -> None:
"""Sets the index of trainable weights.
The split can be done by providing a functions applying conditions on
the named weights, as haiku users are used to do, but can also accept
an explicit dict of names or even the index of the parameter leaves
stored by our HaikuModel.
Example use :
>>> self.get_named_weights() = {'linear': {'w': None, 'b': None}}
Using a function as input
>>> criterion = lambda layer, name, value: name == 'w'
>>> self.set_trainable_weights(criterion)
>>> self._trainable
[0]
Using a dictionnary or pytree
>>> criterion = {'linear': {'b': None}}
>>> self.set_trainable_weights(criterion)
>>> self._trainable
[1]
Note : model needs to be initialized
Arguments
--------
criterion : Callable or dict(str,dict(str,any)) or list(int)
Criterion to be used to identify trainable params. If Callable,
must be a function taking in the name of the module (e.g.
layer name), the element name (e.g. parameter name) and the
corresponding data and returning a boolean. See
[the haiku doc](https://tinyurl.com/3v28upaz)
for details. If a list of integers, should represent the index of
trainable parameters in the parameter tree leaves. If a dict,
should be formatted as a pytree.
"""
if not self._initialized:
raise ValueError("Model needs to be initialized first")
if (
isinstance(criterion, list)
and all(isinstance(c, str) for c in criterion)
and all(c in self._pnames for c in criterion)
):
self._trainable = criterion
else:
self._trainable = [] # reset if needed
if inspect.isfunction(criterion):
include_fn = (
criterion
) # type: Callable[[str, str, jax.Array], bool]
elif isinstance(criterion, dict):
include_fn = self._build_include_fn(criterion)
else:
raise TypeError(
"The provided criterion does not conform "
"to the expected format and or type."
)
gen = hk.data_structures.traverse(self._params)
for layer, name, value in gen:
if include_fn(layer, name, value):
self._trainable.append(f"{layer}:{name}")
@staticmethod
def _build_include_fn(
node_names: Dict[str, Dict[str, Any]],
) -> Callable[[str, str, jax.Array], bool]:
"""Build an equality-checking function for parameters' traversal."""
def include_fn(layer: str, name: str, value: jax.Array) -> bool:
# mandatory signature; pylint: disable=unused-argument
            if layer in node_names:
                return name in node_names[layer]
return False
return include_fn
def get_weight_names(
self,
trainable: bool = False,
) -> List[str]:
"""Return the list of names of the model's weights.
Parameters
----------
trainable: bool
Whether to return only the names of trainable weights,
rather than including both trainable and frozen ones.
Returns
-------
names:
Ordered list of model weights' names.
"""
return self._trainable.copy() if trainable else self._pnames.copy()
def compute_batch_gradients(
self,
batch: Batch,
max_norm: Optional[float] = None,
) -> JaxNumpyVector:
# Unpack input batch and prepare model parameters.
inputs = self._unpack_batch(batch)
train_params, fixed_params = hk.data_structures.partition(
predicate=lambda l, w, _: f"{l}:{w}" in self._trainable,
structure=self._params,
)
rng = next(self._rng_gen)
# Compute batch-averaged gradients, opt. clipped on a per-sample basis.
if max_norm:
grads = self._clipped_grad_fn(
train_params, fixed_params, rng, inputs, max_norm
)
grads = [value.mean(0) for value in grads]
else:
grads = jax.tree_util.tree_leaves(
self._grad_fn(train_params, fixed_params, rng, inputs)
)
# Return the gradients, flattened into a JaxNumpyVector container.
return JaxNumpyVector(dict(zip(self._trainable, grads)))
@functools.cached_property
def _grad_fn(
self,
) -> Callable[[hk.Params, hk.Params, jax.Array, JaxBatch], hk.Params]:
"""Lazy-built jax function to compute batch-averaged gradients."""
return jax.jit(jax.grad(self._forward))
def _forward(
self,
train_params: hk.Params,
fixed_params: hk.Params,
rng: jax.Array,
batch: JaxBatch,
) -> jax.Array:
"""The forward pass chaining the model to the loss as a pure function.
Parameters
-------
params: haiku.Params
The model parameters, as a nested dict of jax arrays.
rng: jax.Array
A jax pseudo-random number generator (PRNG) key.
batch: (list[jax.Array], jax.Array, optional[jax.Array])
Batch of jax-converted inputs, comprising (in that order)
input data, ground-truth labels and optional sample weights.
Returns
-------
loss: jax.Array
The mean loss over the input data provided.
"""
inputs, y_true, s_wght = batch
params = hk.data_structures.merge(train_params, fixed_params)
y_pred = self._transformed_model.apply(params, rng, *inputs)
s_loss = self._loss_fn(y_pred, y_true) # type: ignore
if s_wght is not None:
s_loss = s_loss * s_wght
return jnp.mean(s_loss)
@functools.cached_property
def _clipped_grad_fn(
self,
) -> Callable[
[hk.Params, hk.Params, jax.Array, JaxBatch, float], List[jax.Array]
]:
"""Lazy-built jax function to compute clipped sample-wise gradients.
Note : The vmap in_axis parameters work thank to the jax feature of
applying optional parameters to pytrees.
"""
def clipped_grad_fn(
train_params: hk.Params,
fixed_params: hk.Params,
rng: jax.Array,
batch: JaxBatch,
max_norm: float,
) -> List[jax.Array]:
"""Compute and clip gradients wrt parameters for a sample."""
inputs, y_true, s_wght = batch
batch = (inputs, y_true, None)
grads = jax.grad(self._forward)(
train_params, fixed_params, rng, batch
)
grads_flat = [
grad / jnp.maximum(jnp.linalg.norm(grad) / max_norm, 1.0)
for grad in jax.tree_util.tree_leaves(grads)
]
if s_wght is not None:
grads_flat = [g * s_wght for g in grads_flat]
return grads_flat
in_axes = [None, None, None, 0, None] # map on inputs' first dimension
return jax.jit(jax.vmap(clipped_grad_fn, in_axes))
@staticmethod
def _unpack_batch(batch: Batch) -> JaxBatch:
"""Unpack and enforce jnp.array conversion to an input data batch."""
def convert(data: Any) -> Optional[jax.Array]:
if (data is None) or isinstance(data, jax.Array):
return data
if isinstance(data, np.ndarray):
return jnp.array(data) # pylint: disable=no-member
raise TypeError("HaikuModel requires numpy or jax.numpy data.")
# similar code to TorchModel; pylint: disable=duplicate-code
# Convert batched data to jax Arrays.
inputs, y_true, s_wght = batch
if not isinstance(inputs, (tuple, list)):
inputs = [inputs]
output = [list(map(convert, inputs)), convert(y_true), convert(s_wght)]
return output # type: ignore
def apply_updates( # type: ignore # Vector subtype specification
self,
updates: JaxNumpyVector,
) -> None:
if not isinstance(updates, JaxNumpyVector):
raise TypeError("HaikuModel requires JaxNumpyVector updates.")
self._verify_weights_compatibility(updates, trainable=True)
for key, val in updates.coefs.items():
layer, weight = key.rsplit(":", 1)
self._params[layer][weight] += val # type: ignore
def compute_batch_predictions(
self,
batch: Batch,
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
inputs, y_true, s_wght = self._unpack_batch(batch)
if y_true is None:
raise TypeError(
"`HaikuModel.compute_batch_predictions` received a "
"batch with `y_true=None`, which is unsupported. Please "
"correct the inputs, or override this method to support "
"creating labels from the base inputs."
)
y_pred = self._transformed_model.apply(
self._params, next(self._rng_gen), *inputs
)
return (
np.asarray(y_true),
np.asarray(y_pred),
None if s_wght is None else np.asarray(s_wght),
)
def loss_function(
self,
y_true: np.ndarray,
y_pred: np.ndarray,
) -> np.ndarray:
s_loss = self._loss_fn(jnp.array(y_pred), jnp.array(y_true))
return np.array(s_loss).squeeze()
def update_device_policy(
self,
policy: Optional[DevicePolicy] = None,
) -> None:
# similar code to tensorflow Model; pylint: disable=duplicate-code
# Select the device to use based on the provided or global policy.
if policy is None:
policy = get_device_policy()
device = select_device(gpu=policy.gpu, idx=policy.idx)
# When needed, re-create the model to force moving it to the device.
if self._device is not device:
self._device = device
self._params = jax.device_put(self._params, self._device)
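Continuing the instantiation sketch above, here is a hedged example of the public training API implemented by this file; `x` and `y` are assumed toy arrays, and the bare SGD step merely stands in for the updates that declearn Optimizers would normally produce:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 8)).astype("float32")  # assumed toy inputs
y = rng.normal(size=(32, 1)).astype("float32")  # assumed toy labels

# Optionally freeze bias terms: only 'w' parameters remain trainable.
model.set_trainable_weights(lambda layer, name, value: name == "w")

# Batches are (inputs, y_true, s_wght) triplets; `max_norm` triggers
# per-sample gradient clipping prior to batch averaging.
grads = model.compute_batch_gradients((x, y, None), max_norm=1.0)
model.apply_updates(grads * (-0.01))  # naive SGD step, for illustration
```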
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JaxNumpyVector data arrays container."""
from typing import Any, Callable, Dict, Optional, Set, Type
import jax
import jax.numpy as jnp
import numpy as np
from jax.config import config as jaxconfig
from typing_extensions import Self # future: import from typing (Py>=3.11)
from declearn.model.api import Vector, register_vector_type
from declearn.model.haiku.utils import select_device
from declearn.model.sklearn import NumpyVector
from declearn.utils import get_device_policy
__all__ = [
"JaxNumpyVector",
]
jaxconfig.update("jax_enable_x64", True) # enable float64 support
@register_vector_type(jax.Array)
class JaxNumpyVector(Vector):
"""Vector subclass to store jax.numpy.ndarray coefficients.
This Vector is designed to store a collection of named
jax numpy arrays or scalars, enabling computations that are
either applied to each and every coefficient, or imply
two sets of aligned coefficients (i.e. two JaxNumpyVector
instances with similar coefficients specifications).
Use `vector.coefs` to access the stored coefficients.
Notes
-----
    - A `JaxNumpyVector` can be operated with either a:
- scalar value
- `NumpyVector` that has similar specifications
- `JaxNumpyVector` that has similar specifications
=> resulting in a `JaxNumpyVector` in each of these cases.
    - The wrapped arrays may be placed on any device (CPU, GPU...),
      and need not all be on the same device.
- The device-placement of the initial `JaxNumpyVector`'s data
is preserved by operations, including with `NumpyVector`.
    - When combining two `JaxNumpyVector` instances, the device
      placement of the left-most one is used; hence one ends up with
      `gpu + cpu = gpu` while `cpu + gpu = cpu`. In both cases, a
      warning is emitted to prevent silent un-optimized copies.
- When deserializing a `JaxNumpyVector` (either by directly using
`JaxNumpyVector.unpack` or loading one from a JSON dump), loaded
arrays are placed based on the global device-placement policy
(accessed via `declearn.utils.get_device_policy`). Thus it may
have a different device-placement schema than at dump time but
should be coherent with that of `HaikuModel` computations.
"""
@property
def _op_add(self) -> Callable[[Any, Any], Any]:
return jnp.add
@property
def _op_sub(self) -> Callable[[Any, Any], Any]:
return jnp.subtract
@property
def _op_mul(self) -> Callable[[Any, Any], Any]:
return jnp.multiply
@property
def _op_div(self) -> Callable[[Any, Any], Any]:
return jnp.divide
@property
def _op_pow(self) -> Callable[[Any, Any], Any]:
return jnp.power
@property
def compatible_vector_types(self) -> Set[Type[Vector]]:
types = super().compatible_vector_types
return types.union({NumpyVector, JaxNumpyVector})
def __init__(self, coefs: Dict[str, jax.Array]) -> None:
super().__init__(coefs)
def _apply_operation(
self,
other: Any,
func: Callable[[Any, Any], Any],
) -> Self:
# Ensure 'other' JaxNumpyVector shares this vector's device placement.
if isinstance(other, JaxNumpyVector):
coefs = {
key: jax.device_put(val, self.coefs[key].device())
for key, val in other.coefs.items()
}
other = JaxNumpyVector(coefs)
return super()._apply_operation(other, func)
def __eq__(self, other: Any) -> bool:
valid = isinstance(other, JaxNumpyVector)
valid = valid and (self.coefs.keys() == other.coefs.keys())
return valid and all(
jnp.array_equal(self.coefs[k], other.coefs[k]) for k in self.coefs
)
def sign(
self,
) -> Self:
return self.apply_func(jnp.sign)
def minimum(
self,
other: Any,
) -> Self:
if isinstance(other, JaxNumpyVector):
return self._apply_operation(other, jnp.minimum)
return self.apply_func(jnp.minimum, other)
def maximum(
self,
other: Any,
) -> Self:
if isinstance(other, Vector):
return self._apply_operation(other, jnp.maximum)
return self.apply_func(jnp.maximum, other)
def sum(
self,
axis: Optional[int] = None,
keepdims: bool = False,
) -> Self:
coefs = {
key: jnp.array(jnp.sum(val, axis=axis, keepdims=keepdims))
for key, val in self.coefs.items()
}
return self.__class__(coefs)
def pack(
self,
) -> Dict[str, Any]:
return {key: np.asarray(arr) for key, arr in self.coefs.items()}
@classmethod
def unpack(
cls,
data: Dict[str, Any],
) -> Self:
policy = get_device_policy()
device = select_device(gpu=policy.gpu, idx=policy.idx)
coefs = {k: jax.device_put(arr, device) for k, arr in data.items()}
return cls(coefs)
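A quick sketch of the coefficient-wise semantics documented above; the coefficient names and values are illustrative:

```python
import jax.numpy as jnp

from declearn.model.haiku import JaxNumpyVector

u = JaxNumpyVector({"linear:w": jnp.ones((2, 2)), "linear:b": jnp.zeros(2)})
v = JaxNumpyVector({"linear:w": jnp.full((2, 2), 3.0), "linear:b": jnp.ones(2)})

avg = (u + v) * 0.5        # FedAvg-style coefficient-wise averaging
data = avg.pack()          # named numpy arrays, fit for serialization
back = JaxNumpyVector.unpack(data)  # placed as per the device policy
assert back == avg
```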
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for jax backend support code.
GPU/CPU backing device management utils:
* select_device:
Select a backing device to use based on inputs and availability.
"""
from ._gpu import select_device
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for GPU support and device management in jax."""
import warnings
from typing import Optional
import jax
import jaxlib.xla_extension as xe
__all__ = ["select_device"]
def select_device(
gpu: bool,
idx: Optional[int] = None,
) -> xe.Device: # pylint: disable=c-extension-no-member
"""Select a backing device to use based on inputs and availability.
Parameters
----------
gpu: bool
Whether to select a GPU device rather than the CPU one.
idx: int or None, default=None
Optional pre-selected device index. Only used when `gpu=True`.
If `idx is None` or exceeds the number of available GPU devices,
use the first available one.
Warns
-----
UserWarning:
If `gpu=True` but no GPU is available.
If `idx` exceeds the number of available GPU devices.
Returns
-------
device: jaxlib.xla_extension.Device
Selected device.
"""
idx = 0 if idx is None else idx
# List available CPU or GPU devices.
device_type = "gpu" if gpu else "cpu"
devices = [d for d in jax.devices() if d.platform == device_type]
# Case when no GPU is available: warn and use a CPU instead.
if gpu and not devices:
warnings.warn(
"Cannot use a GPU device: either CUDA is unavailable "
"or no GPU is visible to jax."
)
device_type, idx = "cpu", 0
devices = jax.devices(device_type)
# similar code to tensorflow util; pylint: disable=duplicate-code
# Case when the desired device index is invalid: select another one.
if idx >= len(devices):
warnings.warn(
f"Cannot use {device_type} device n°{idx}: index is out-of-range."
f"\nUsing {device_type} device n°0 instead."
)
idx = 0
# Return the selected device.
return devices[idx]
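For instance (actual device availability obviously depends on the host):

```python
from declearn.model.haiku.utils import select_device

cpu = select_device(gpu=False)  # first available CPU device
# Request the second GPU; warns and falls back to the first available
# device if no GPU is visible or the index is out of range.
gpu = select_device(gpu=True, idx=1)
```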
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow models interfacing tools.
"""Pytorch models interfacing tools.
This submodule provides a generic interface to wrap up any PyTorch
`torch.nn.Module` instance that is to be trained with gradient descent.
......
@@ -35,7 +35,7 @@ __all__ = [
]
FrameworkType = Literal["numpy", "tensorflow", "torch"]
FrameworkType = Literal["numpy", "tensorflow", "torch", "jax"]
def list_available_frameworks() -> List[FrameworkType]:
@@ -81,6 +81,9 @@ class GradientsTestCase:
if self.framework == "torch":
module = importlib.import_module("declearn.model.torch")
return module.TorchVector
if self.framework == "jax":
module = importlib.import_module("declearn.model.haiku")
return module.JaxNumpyVector
raise ValueError(f"Invalid framework '{self.framework}'")
def convert(self, array: np.ndarray) -> ArrayLike:
@@ -94,12 +97,17 @@ class GradientsTestCase:
if self.framework == "torch":
torch = importlib.import_module("torch")
return torch.from_numpy(array)
if self.framework == "jax":
jnp = importlib.import_module("jax.numpy")
return jnp.asarray(array)
raise ValueError(f"Invalid framework '{self.framework}'")
def to_numpy(self, array: ArrayLike) -> np.ndarray:
"""Convert an input framework-based structure to a numpy array."""
if isinstance(array, np.ndarray):
return array
if self.framework == "jax":
return np.asarray(array)
if self.framework == "tensorflow": # add support for IndexedSlices
tensorflow = importlib.import_module("tensorflow")
if isinstance(array, tensorflow.IndexedSlices):
......
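Based on the hunks above, the framework-switching test helpers might be used along these lines; the exact `GradientsTestCase` constructor signature is an assumption from context:

```python
import numpy as np

# Hypothetical usage, inferred from the diff above: round-trip a numpy
# array through the jax conversion helpers.
case = GradientsTestCase("jax")
array = np.random.default_rng(0).normal(size=(4, 2))
jax_array = case.convert(array)          # numpy -> jax.numpy array
assert np.allclose(case.to_numpy(jax_array), array)
```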
@@ -46,10 +46,12 @@ dependencies = [
]
[project.optional-dependencies]
# all non-docs, non-tests extra dependencies
all = [
all = [ # all non-tests extra dependencies
"dm-haiku == 0.0.9",
"jax == 0.4.4",
"functorch",
"grpcio >= 1.45",
"jax[cpu] == 0.4.4",
"opacus ~= 1.1",
"protobuf >= 3.19",
"tensorflow ~= 2.5",
@@ -64,6 +66,10 @@ grpc = [
"grpcio >= 1.45",
"protobuf >= 3.19",
]
haiku = [
"jax == 0.4.4",
"dm-haiku == 0.0.9",
]
tensorflow = [
"tensorflow ~= 2.5",
]
......
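With these extras declared, the new backend's dependencies can be installed via `pip install declearn[haiku]`, or alongside every other optional backend via `pip install declearn[all]`.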
@@ -62,24 +62,44 @@ from declearn.test_utils import FrameworkType
from declearn.utils import run_as_processes
from declearn.utils import set_device_policy
# pylint: disable=ungrouped-imports; optional frameworks' dependencies
# optional frameworks' dependencies pylint: disable=ungrouped-imports
# pylint: disable=duplicate-code
# tensorflow imports
try:
import tensorflow as tf # type: ignore
from declearn.model.tensorflow import TensorflowModel
except ModuleNotFoundError:
pass
else:
from declearn.model.tensorflow import TensorflowModel
# torch imports
try:
import torch
except ModuleNotFoundError:
pass
else:
from declearn.model.torch import TorchModel
# pylint: enable=duplicate-code
# haiku and jax imports
try:
import haiku as hk
import jax
except ModuleNotFoundError:
pass
else:
from declearn.model.haiku import HaikuModel
def haiku_model_fn(inputs: jax.Array) -> jax.Array:
"""Simple linear model implemented with Haiku."""
return hk.Linear(1)(inputs)
def haiku_loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
"""Sample-wise squared error loss function."""
return (y_pred - y_true) ** 2
SEED = 0
R2_THRESHOLD = 0.999
# pylint: disable=too-many-function-args
def get_model(framework: FrameworkType) -> Model:
"""Set up a simple toy regression model."""
@@ -101,6 +121,8 @@ def get_model(framework: FrameworkType) -> Model:
torch.nn.Flatten(0),
)
model = TorchModel(torchmod, loss=torch.nn.MSELoss())
elif framework == "jax":
model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn)
else:
raise ValueError("unrecognised framework")
return model
@@ -245,7 +267,9 @@ def _server_routine(
# pylint: disable=too-many-arguments
# Set up the FederatedServer.
model = get_model(framework)
netwk = NetworkServerConfig("websockets", "127.0.0.1", 8765)
netwk = NetworkServerConfig.from_params(
protocol="websockets", host="127.0.0.1", port=8765
)
optim = FLOptimConfig.from_params(
aggregator="averaging",
client_opt={
@@ -254,7 +278,6 @@ },
},
server_opt=1.0,
)
server = FederatedServer(
model,
netwk,
@@ -281,8 +304,9 @@ def _client_routine(
name: str = "client",
) -> None:
"""Routine to run a FL client, called by `run_declearn_experiment`."""
# Run the declearn FL client routine.
netwk = NetworkClientConfig("websockets", "ws://localhost:8765", name)
netwk = NetworkClientConfig.from_params(
protocol="websockets", server_uri="ws://localhost:8765", name=name
)
client = FederatedClient(netwk, train, valid)
client.run()
......
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for HaikuModel."""
import sys
from typing import Any, Callable, Dict, List, Literal, Union
import numpy as np
import pytest
try:
import haiku as hk
import jax
import jax.numpy as jnp
from jax.config import config as jaxconfig
except ModuleNotFoundError:
pytest.skip("jax and/or haiku are unavailable", allow_module_level=True)
from declearn.model.haiku import HaikuModel, JaxNumpyVector
from declearn.typing import Batch
from declearn.utils import set_device_policy
# dirty trick to import from `model_testing.py`;
# pylint: disable=wrong-import-order, wrong-import-position
sys.path.append(".")
from model_testing import ModelTestCase, ModelTestSuite
# Overriding float32 default in jax
jaxconfig.update("jax_enable_x64", True)
def cnn_fn(inputs: jax.Array) -> jax.Array:
"""Simple CNN in a purely functional form"""
stack = [
hk.Conv2D(output_channels=32, kernel_shape=(7, 7), padding="SAME"),
jax.nn.relu,
hk.MaxPool(window_shape=(8, 8, 1), strides=(8, 8, 1), padding="VALID"),
hk.Conv2D(output_channels=16, kernel_shape=(5, 5), padding="SAME"),
jax.nn.relu,
hk.AvgPool(window_shape=(8, 8, 1), strides=(8, 8, 1), padding="VALID"),
hk.Reshape((16,)),
hk.Linear(1),
]
model = hk.Sequential(stack) # type: ignore
return model(inputs)
def mlp_fn(inputs: jax.Array) -> jax.Array:
"""Simple MLP in a purely functional form"""
model = hk.nets.MLP([32, 16, 1])
return model(inputs)
def rnn_fn(inputs: jax.Array) -> jax.Array:
"""Simple RNN in a purely functional form"""
inputs = inputs[None, :] if len(inputs.shape) == 1 else inputs
core = hk.DeepRNN(
[
hk.Embed(100, 32),
hk.LSTM(32),
jax.nn.tanh,
]
)
batch_size = inputs.shape[0]
initial_state = core.initial_state(batch_size)
logits, _ = hk.dynamic_unroll(
core, inputs, initial_state, time_major=False
)
return hk.Linear(1)(logits)[:, -1, :]
def loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
"""Per-sample binary cross entropy"""
y_pred = jax.nn.sigmoid(y_pred)
y_pred = jnp.squeeze(y_pred)
log_p, log_not_p = jnp.log(y_pred), jnp.log(1.0 - y_pred)
return -y_true * log_p - (1.0 - y_true) * log_not_p
class HaikuTestCase(ModelTestCase):
"""Tensorflow Keras test-case-provider fixture.
Implemented architectures are:
* "MLP":
- input: 64-dimensional features vectors
- stack: 32-neurons fully-connected layer with ReLU
16-neurons fully-connected layer with ReLU
                 1 output neuron with sigmoid activation
    * "MLP-tune":
        - same as MLP, but with some weights frozen through
          `set_trainable_weights`
    * "RNN":
- input: 128-tokens-sequence in a 100-tokens-vocabulary
- stack: 32-dimensional embedding matrix
16-neurons LSTM layer with tanh activation
1 output neuron with sigmoid activation
* "CNN":
- input: 64x64 image with 3 channels (normalized values)
- stack: 32 7x7 conv. filters, then 8x8 max pooling
16 5x5 conv. filters, then 8x8 avg pooling
1 output neuron with sigmoid activation
"""
vector_cls = JaxNumpyVector
tensor_cls = jax.Array
def __init__(
self,
kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
device: Literal["cpu", "gpu"],
) -> None:
"""Specify the desired model architecture."""
if kind not in ("MLP", "MLP-tune", "RNN", "CNN"):
raise ValueError(f"Invalid test architecture: '{kind}'.")
if device not in ("cpu", "gpu"):
raise ValueError(f"Invalid device choice for test: '{device}'.")
self.kind = kind
self.device = device
set_device_policy(gpu=(device == "gpu"), idx=0)
@staticmethod
def to_numpy(
tensor: Any,
) -> np.ndarray:
"""Convert an input jax jax.Array to a numpy jax.Array."""
assert isinstance(tensor, jax.Array)
return np.asarray(tensor)
@property
def dataset(
self,
) -> List[Batch]:
"""Suited toy binary-classification dataset."""
rng = np.random.default_rng(seed=0)
if self.kind.startswith("MLP"):
inputs = rng.normal(size=(2, 32, 64)).astype("float32")
elif self.kind == "RNN":
inputs = rng.choice(100, size=(2, 32, 128))
elif self.kind == "CNN":
inputs = rng.normal(size=(2, 32, 64, 64, 3)).astype("float32")
labels = rng.choice(2, size=(2, 32))
inputs = jnp.asarray(inputs) # type: ignore
labels = jnp.asarray(labels) # type: ignore
batches = list(zip(inputs, labels, [None, None]))
return batches # type: ignore
@property
def model(self) -> HaikuModel:
"""Suited toy binary-classification haiku models."""
if self.kind == "CNN":
shape = [64, 64, 3]
model_fn = cnn_fn
elif self.kind.startswith("MLP"):
shape = [64]
model_fn = mlp_fn
elif self.kind == "RNN":
shape = [128]
model_fn = rnn_fn
model = HaikuModel(model_fn, loss_fn)
model.initialize(
{
"features_shape": shape,
"data_type": "int" if self.kind == "RNN" else "float32",
}
)
if self.kind == "MLP-tune":
names = model.get_weight_names()
model.set_trainable_weights([names[i] for i in range(3)])
return model
def assert_correct_device(
self,
vector: JaxNumpyVector,
) -> None:
"""Raise if a vector is backed on the wrong type of device."""
name = f"{self.device}:0"
assert all(
f"{arr.device().platform}:{arr.device().id}" == name
for arr in vector.coefs.values()
)
def get_trainable_criterion(
self,
c_type: str,
) -> Union[
List[str],
Dict[str, Dict[str, Any]],
Callable[[str, str, jax.Array], bool],
]:
"Build different weight freezing criteria"
if c_type == "names":
names = self.model.get_weight_names()
return [names[2], names[3]]
if c_type == "pytree":
params = getattr(self.model, "_params")
return {k: v for i, (k, v) in enumerate(params.items()) if i != 1}
if c_type == "predicate":
return lambda m, n, p: n != "b"
raise KeyError(f"Invalid 'c_type' parameter: {c_type}.")
@pytest.fixture(name="test_case")
def fixture_test_case(
kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
device: Literal["cpu", "gpu"],
) -> HaikuTestCase:
"""Fixture to access a TensorflowTestCase."""
return HaikuTestCase(kind, device)
DEVICES = ["cpu"]
if any(d.platform == "gpu" for d in jax.devices()):
DEVICES.append("gpu")
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kind", ["MLP", "MLP-tune", "RNN", "CNN"])
class TestHaikuModel(ModelTestSuite):
"""Unit tests for declearn.model.tensorflow.TensorflowModel."""
@pytest.mark.filterwarnings("ignore: Our custom Haiku serialization")
def test_serialization(
self,
test_case: ModelTestCase,
) -> None:
super().test_serialization(test_case)
@pytest.mark.parametrize(
"criterion_type", ["names", "pytree", "predicate"]
)
def test_get_frozen_weights(
self,
test_case: HaikuTestCase,
criterion_type: str,
) -> None:
"""Check that `get_weights` behaves properly with frozen weights."""
model = test_case.model # type: HaikuModel
criterion = test_case.get_trainable_criterion(criterion_type)
model.set_trainable_weights(criterion) # freeze some weights
w_all = model.get_weights()
w_trn = model.get_weights(trainable=True)
assert set(w_trn.coefs).issubset(w_all.coefs) # check on keys
n_params = len(model.get_weight_names())
n_trainable = len(model.get_weight_names(trainable=True))
assert n_trainable < n_params
assert len(w_trn.coefs) == n_trainable
assert len(w_all.coefs) == n_params
@pytest.mark.parametrize(
"criterion_type", ["names", "pytree", "predicate"]
)
def test_set_frozen_weights(
self,
test_case: HaikuTestCase,
criterion_type: str,
) -> None:
"""Check that `set_weights` behaves properly with frozen weights."""
# similar code to TorchModel tests; pylint: disable=duplicate-code
# Setup a model with some frozen weights, and gather trainable ones.
model = test_case.model
criterion = test_case.get_trainable_criterion(criterion_type)
model.set_trainable_weights(criterion) # freeze some weights
w_trn = model.get_weights(trainable=True)
# Test that `set_weights` works if and only if properly parametrized.
with pytest.raises(KeyError):
model.set_weights(w_trn)
with pytest.raises(KeyError):
model.set_weights(model.get_weights(), trainable=True)
model.set_weights(w_trn, trainable=True)
def test_proper_model_placement(
self,
test_case: HaikuTestCase,
) -> None:
"""Check that at instantiation, model weights are properly placed."""
model = test_case.model
policy = model.device_policy
assert policy.gpu == (test_case.device == "gpu")
assert policy.idx == 0
params = jax.tree_util.tree_leaves(getattr(model, "_params"))
device = f"{test_case.device}:0"
for arr in params:
assert f"{arr.device().platform}:{arr.device().id}" == device
@@ -17,8 +17,8 @@
"""Unit tests for TensorflowModel."""
import warnings
import sys
import warnings
from typing import Any, List, Literal
import numpy as np
@@ -38,7 +38,7 @@ from declearn.utils import set_device_policy
# dirty trick to import from `model_testing.py`;
# pylint: disable=wrong-import-order, wrong-import-position
sys.path.append(".")
from model_testing import ModelTestSuite, ModelTestCase
from model_testing import ModelTestCase, ModelTestSuite
class TensorflowTestCase(ModelTestCase):
@@ -51,7 +51,7 @@ class TensorflowTestCase(ModelTestCase):
16-neurons fully-connected layer with ReLU
1 output neuron with sigmoid activation
* "MLP-tune":
- same as NLP, but freeze the first layer of the stack
- same as MLP, but freeze the first layer of the stack
* "RNN":
- input: 128-tokens-sequence in a 100-tokens-vocabulary
- stack: 32-dimensional embedding matrix
......