diff --git a/declearn/model/__init__.py b/declearn/model/__init__.py index 5a84dc83c7ad4381526c2cc8a409495c69b02313..71b604e8ad37c380324cd623c2294259ffc54327 100644
--- a/declearn/model/__init__.py
+++ b/declearn/model/__init__.py
@@ -42,6 +42,9 @@ The automatically-imported submodules implemented here are: Optional Submodules
 -------------------
 The optional-dependency-based submodules that may be manually imported are:
+* [haiku][declearn.model.haiku]: Jax- and Haiku-interfacing tools
+    - HaikuModel: Model to wrap a haiku-transformable model function.
+    - JaxNumpyVector: Vector for jax array data structures.
 * [tensorflow][declearn.model.tensorflow]: TensorFlow-interfacing tools
@@ -65,6 +68,7 @@ from . import api
 from . import sklearn
 OPTIONAL_MODULES = [
+    "haiku",
     "tensorflow",
     "torch",
 ]
diff --git a/declearn/model/haiku/__init__.py b/declearn/model/haiku/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1adb8f1f2d0d9ec2593a3dcacc1d25bd808070cf
--- /dev/null
+++ b/declearn/model/haiku/__init__.py
@@ -0,0 +1,31 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Haiku models interfacing tools.
+
+This submodule provides with a generic interface to wrap up
+any Haiku module instance that is to be trained
+through gradient descent.
+
+This module exposes:
+* HaikuModel: Model subclass to wrap haiku-transformable model functions
+* JaxNumpyVector: Vector subclass to wrap jax.numpy.ndarray objects
+"""
+
+from . import utils
+from ._vector import JaxNumpyVector
+from ._model import HaikuModel
diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..17d19c87af034f41c9064dad7ac9460bf1cafb48
--- /dev/null
+++ b/declearn/model/haiku/_model.py
@@ -0,0 +1,521 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Model subclass to wrap Haiku models.""" + +import functools +import inspect +import io +import warnings +from random import SystemRandom +from typing import ( + # fmt: off + Any, Callable, Dict, List, Optional, Set, Tuple, Union +) + +import haiku as hk +import jax +import jax.numpy as jnp +import joblib # type: ignore +import numpy as np +from typing_extensions import Self + +from declearn.data_info import aggregate_data_info +from declearn.model._utils import raise_on_stringsets_mismatch +from declearn.model.api import Model +from declearn.model.haiku._vector import JaxNumpyVector +from declearn.model.haiku.utils import select_device +from declearn.typing import Batch +from declearn.utils import DevicePolicy, get_device_policy, register_type + +__all__ = [ + "HaikuModel", +] + +# alias for unpacked Batch structures, converted to jax arrays +# FUTURE: add support for lists of inputs +JaxBatch = Tuple[List[jax.Array], Optional[jax.Array], Optional[jax.Array]] + + +@register_type(name="HaikuModel", group="Model") +class HaikuModel(Model): + """Model wrapper for Haiku Model instances. + + This `Model` subclass is designed to wrap a `hk.Module` + instance to be learned federatively. + + Notes regarding device management (CPU, GPU, etc.): + * By default, jax places data and operations on GPU whenever one + is available. + * Our `HaikuModel` instead consults the device-placement policy (via + `declearn.utils.get_device_policy`), places the wrapped haiku model's + weights there, and runs computations defined under public methods on + that device. + * Note that there is no guarantee that calling a private method directly + will result in abiding by that policy. Hence, be careful when writing + custom code, and use your own context managers to get guarantees. + * Note that if the global device-placement policy is updated, this will + only be propagated to existing instances by manually calling their + `update_device_policy` method. + * You may consult the device policy enforced by a HaikuModel + instance by accessing its `device_policy` property. + """ + + # pylint: disable=too-many-instance-attributes + + def __init__( + self, + model: Callable[[jax.Array], jax.Array], + loss: Callable[[jax.Array, jax.Array], jax.Array], + seed: Optional[int] = None, + ) -> None: + """Instantiate a Model interface wrapping a jax-haiku model. + + Parameters + ---------- + model: callable(jax.Array) -> jax.Array + Function encapsulating a `haiku.Module` such that `model(x)` + returns `haiku_module(x)`, constituting a model's forward. + loss: callable(jax.Array, jax.Array) -> jax.Array + Jax-compatible function that defines the model's loss. + It must expect `y_pred` and `y_true` as input arguments (in + that order) and return sample-wise loss values. + seed: int or None, default=None + Random seed used to initialize the haiku-wrapped Pseudo-random + number generator. If none is provided, draw an integer between + 0 and 10^6 using `random.SystemRandom`. + """ + super().__init__(model) + # Assign loss module. + self._loss_fn = loss + # Get pure functions from haiku transform. + self._model_fn = model + self._transformed_model = hk.transform(model) + # Select the device where to place computations. + policy = get_device_policy() + self._device = select_device(gpu=policy.gpu, idx=policy.idx) + # Create model state attributes. + self._params = {} # type: hk.Params + self._pnames = [] # type: List[str] + self._trainable = [] # type: List[str] + # Initialize the PRNG. 
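+        # (The `PRNGSequence` set up below yields a fresh jax PRNG key
+        # for the parameters' initialization and for each forward pass.)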
+        if seed is None:
+            seed = int(SystemRandom().random() * 1e6)
+        self._rng_gen = hk.PRNGSequence(seed)
+        # Declare initialization-status and data-info attributes.
+        self._initialized = False
+        self.data_info = {}  # type: Dict[str, Any]
+
+    @property
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        device = self._device
+        return DevicePolicy(gpu=(device.platform == "gpu"), idx=device.id)
+
+    @property
+    def required_data_info(
+        self,
+    ) -> Set[str]:
+        return set() if self._initialized else {"data_type", "features_shape"}
+
+    def initialize(
+        self,
+        data_info: Dict[str, Any],
+    ) -> None:
+        if self._initialized:
+            return
+        # Check that required fields are available and of valid type.
+        self.data_info = aggregate_data_info(
+            [data_info], self.required_data_info
+        )
+        # Initialize the model parameters.
+        params = self._transformed_model.init(
+            next(self._rng_gen),
+            jnp.zeros(
+                (1, *data_info["features_shape"]), data_info["data_type"]
+            ),
+        )
+        self._params = jax.device_put(params, self._device)
+        self._pnames = [
+            f"{layer}:{weight}"
+            for layer, weights in self._params.items()
+            for weight in weights
+        ]
+        self._trainable = self._pnames.copy()
+        self._initialized = True
+
+    def get_config(
+        self,
+    ) -> Dict[str, Any]:
+        warnings.warn(
+            "Our custom Haiku serialization relies on pickle, "
+            "which may be unsafe."
+        )
+        with io.BytesIO() as buffer:
+            joblib.dump(self._model_fn, buffer)
+            model = buffer.getbuffer().hex()
+        with io.BytesIO() as buffer:
+            joblib.dump(self._loss_fn, buffer)
+            loss = buffer.getbuffer().hex()
+        return {
+            "model": model,
+            "loss": loss,
+            "data_info": self.data_info,
+        }
+
+    @classmethod
+    def from_config(
+        cls,
+        config: Dict[str, Any],
+    ) -> Self:
+        with io.BytesIO(bytes.fromhex(config["model"])) as buffer:
+            model = joblib.load(buffer)
+        with io.BytesIO(bytes.fromhex(config["loss"])) as buffer:
+            loss = joblib.load(buffer)
+        model = cls(model=model, loss=loss)
+        if config.get("data_info"):
+            model.initialize(config["data_info"])
+        return model
+
+    def get_weights(
+        self,
+        trainable: bool = False,
+    ) -> JaxNumpyVector:
+        params = {
+            f"{layer}:{wname}": value
+            for layer, weights in self._params.items()
+            for wname, value in weights.items()
+        }
+        if trainable:
+            params = {k: v for k, v in params.items() if k in self._trainable}
+        return JaxNumpyVector(params)
+
+    def set_weights(  # type: ignore  # Vector subtype specification
+        self,
+        weights: JaxNumpyVector,
+        trainable: bool = False,
+    ) -> None:
+        if not isinstance(weights, JaxNumpyVector):
+            raise TypeError("HaikuModel requires JaxNumpyVector weights.")
+        self._verify_weights_compatibility(weights, trainable=trainable)
+        for key, val in weights.coefs.items():
+            layer, weight = key.rsplit(":", 1)
+            self._params[layer][weight] = val.copy()  # type: ignore
+
+    def _verify_weights_compatibility(
+        self,
+        vector: JaxNumpyVector,
+        trainable: bool = False,
+    ) -> None:
+        """Verify that a vector has the same names as the model's weights.
+
+        Parameters
+        ----------
+        vector: JaxNumpyVector
+            Vector wrapping weight-related coefficients (e.g. weight
+            values or gradient-based updates).
+        trainable: bool, default=False
+            Whether to restrict the comparison to the model's trainable
+            weights rather than to all of its weights.
+
+        Raises
+        ------
+        KeyError:
+            In case some expected keys are missing, or additional keys
+            are present. Be verbose about the identified mismatch(es).
+ """ + received = set(vector.coefs) + expected = set(self._trainable if trainable else self._pnames) + raise_on_stringsets_mismatch( + received, expected, context="model weights" + ) + + def set_trainable_weights( + self, + criterion: Union[ + Callable[[str, str, jax.Array], bool], + Dict[str, Dict[str, Any]], + List[str], + ], + ) -> None: + """Sets the index of trainable weights. + + The split can be done by providing a functions applying conditions on + the named weights, as haiku users are used to do, but can also accept + an explicit dict of names or even the index of the parameter leaves + stored by our HaikuModel. + + Example use : + >>> self.get_named_weights() = {'linear': {'w': None, 'b': None}} + Using a function as input + >>> criterion = lambda layer, name, value: name == 'w' + >>> self.set_trainable_weights(criterion) + >>> self._trainable + [0] + Using a dictionnary or pytree + >>> criterion = {'linear': {'b': None}} + >>> self.set_trainable_weights(criterion) + >>> self._trainable + [1] + + Note : model needs to be initialized + + Arguments + -------- + criterion : Callable or dict(str,dict(str,any)) or list(int) + Criterion to be used to identify trainable params. If Callable, + must be a function taking in the name of the module (e.g. + layer name), the element name (e.g. parameter name) and the + corresponding data and returning a boolean. See + [the haiku doc](https://tinyurl.com/3v28upaz) + for details. If a list of integers, should represent the index of + trainable parameters in the parameter tree leaves. If a dict, + should be formatted as a pytree. + + """ + if not self._initialized: + raise ValueError("Model needs to be initialized first") + if ( + isinstance(criterion, list) + and all(isinstance(c, str) for c in criterion) + and all(c in self._pnames for c in criterion) + ): + self._trainable = criterion + else: + self._trainable = [] # reset if needed + if inspect.isfunction(criterion): + include_fn = ( + criterion + ) # type: Callable[[str, str, jax.Array], bool] + elif isinstance(criterion, dict): + include_fn = self._build_include_fn(criterion) + else: + raise TypeError( + "The provided criterion does not conform " + "to the expected format and or type." + ) + gen = hk.data_structures.traverse(self._params) + for layer, name, value in gen: + if include_fn(layer, name, value): + self._trainable.append(f"{layer}:{name}") + + @staticmethod + def _build_include_fn( + node_names: Dict[str, Dict[str, Any]], + ) -> Callable[[str, str, jax.Array], bool]: + """Build an equality-checking function for parameters' traversal.""" + + def include_fn(layer: str, name: str, value: jax.Array) -> bool: + # mandatory signature; pylint: disable=unused-argument + if layer in list(node_names.keys()): + return name in list(node_names[layer].keys()) + return False + + return include_fn + + def get_weight_names( + self, + trainable: bool = False, + ) -> List[str]: + """Return the list of names of the model's weights. + + Parameters + ---------- + trainable: bool + Whether to return only the names of trainable weights, + rather than including both trainable and frozen ones. + + Returns + ------- + names: + Ordered list of model weights' names. + """ + return self._trainable.copy() if trainable else self._pnames.copy() + + def compute_batch_gradients( + self, + batch: Batch, + max_norm: Optional[float] = None, + ) -> JaxNumpyVector: + # Unpack input batch and prepare model parameters. 
+        inputs = self._unpack_batch(batch)
+        train_params, fixed_params = hk.data_structures.partition(
+            predicate=lambda l, w, _: f"{l}:{w}" in self._trainable,
+            structure=self._params,
+        )
+        rng = next(self._rng_gen)
+        # Compute batch-averaged gradients, opt. clipped on a per-sample basis.
+        if max_norm:
+            grads = self._clipped_grad_fn(
+                train_params, fixed_params, rng, inputs, max_norm
+            )
+            grads = [value.mean(0) for value in grads]
+        else:
+            grads = jax.tree_util.tree_leaves(
+                self._grad_fn(train_params, fixed_params, rng, inputs)
+            )
+        # Return the gradients, flattened into a JaxNumpyVector container.
+        return JaxNumpyVector(dict(zip(self._trainable, grads)))
+
+    @functools.cached_property
+    def _grad_fn(
+        self,
+    ) -> Callable[[hk.Params, hk.Params, jax.Array, JaxBatch], hk.Params]:
+        """Lazy-built jax function to compute batch-averaged gradients."""
+        return jax.jit(jax.grad(self._forward))
+
+    def _forward(
+        self,
+        train_params: hk.Params,
+        fixed_params: hk.Params,
+        rng: jax.Array,
+        batch: JaxBatch,
+    ) -> jax.Array:
+        """The forward pass chaining the model to the loss as a pure function.
+
+        Parameters
+        ----------
+        train_params: haiku.Params
+            The trainable model parameters, as a nested dict of jax arrays.
+        fixed_params: haiku.Params
+            The frozen model parameters, as a nested dict of jax arrays.
+        rng: jax.Array
+            A jax pseudo-random number generator (PRNG) key.
+        batch: (list[jax.Array], jax.Array, optional[jax.Array])
+            Batch of jax-converted inputs, comprising (in that order)
+            input data, ground-truth labels and optional sample weights.
+
+        Returns
+        -------
+        loss: jax.Array
+            The mean loss over the input data provided.
+        """
+        inputs, y_true, s_wght = batch
+        params = hk.data_structures.merge(train_params, fixed_params)
+        y_pred = self._transformed_model.apply(params, rng, *inputs)
+        s_loss = self._loss_fn(y_pred, y_true)  # type: ignore
+        if s_wght is not None:
+            s_loss = s_loss * s_wght
+        return jnp.mean(s_loss)
+
+    @functools.cached_property
+    def _clipped_grad_fn(
+        self,
+    ) -> Callable[
+        [hk.Params, hk.Params, jax.Array, JaxBatch, float], List[jax.Array]
+    ]:
+        """Lazy-built jax function to compute clipped sample-wise gradients.
+
+        Note: the vmap `in_axes` parameters work thanks to jax's support
+        for applying optional parameters to pytrees.
+        """
+
+        def clipped_grad_fn(
+            train_params: hk.Params,
+            fixed_params: hk.Params,
+            rng: jax.Array,
+            batch: JaxBatch,
+            max_norm: float,
+        ) -> List[jax.Array]:
+            """Compute and clip gradients wrt parameters for a sample."""
+            inputs, y_true, s_wght = batch
+            batch = (inputs, y_true, None)
+            grads = jax.grad(self._forward)(
+                train_params, fixed_params, rng, batch
+            )
+            grads_flat = [
+                grad / jnp.maximum(jnp.linalg.norm(grad) / max_norm, 1.0)
+                for grad in jax.tree_util.tree_leaves(grads)
+            ]
+            if s_wght is not None:
+                grads_flat = [g * s_wght for g in grads_flat]
+            return grads_flat
+
+        in_axes = [None, None, None, 0, None]  # map on inputs' first dimension
+        return jax.jit(jax.vmap(clipped_grad_fn, in_axes))
+
+    @staticmethod
+    def _unpack_batch(batch: Batch) -> JaxBatch:
+        """Unpack an input data batch and convert its contents to jax arrays."""
+
+        def convert(data: Any) -> Optional[jax.Array]:
+            if (data is None) or isinstance(data, jax.Array):
+                return data
+            if isinstance(data, np.ndarray):
+                return jnp.array(data)  # pylint: disable=no-member
+            raise TypeError("HaikuModel requires numpy or jax.numpy data.")
+
+        # similar code to TorchModel; pylint: disable=duplicate-code
+        # Convert batched data to jax Arrays.
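+        # Single-array inputs are wrapped into a list, matching the
+        # `JaxBatch` spec and the `*inputs` unpacking in `_forward`.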
+ inputs, y_true, s_wght = batch + if not isinstance(inputs, (tuple, list)): + inputs = [inputs] + output = [list(map(convert, inputs)), convert(y_true), convert(s_wght)] + return output # type: ignore + + def apply_updates( # type: ignore # Vector subtype specification + self, + updates: JaxNumpyVector, + ) -> None: + if not isinstance(updates, JaxNumpyVector): + raise TypeError("HaikuModel requires JaxNumpyVector updates.") + self._verify_weights_compatibility(updates, trainable=True) + for key, val in updates.coefs.items(): + layer, weight = key.rsplit(":", 1) + self._params[layer][weight] += val # type: ignore + + def compute_batch_predictions( + self, + batch: Batch, + ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: + inputs, y_true, s_wght = self._unpack_batch(batch) + if y_true is None: + raise TypeError( + "`HaikuModel.compute_batch_predictions` received a " + "batch with `y_true=None`, which is unsupported. Please " + "correct the inputs, or override this method to support " + "creating labels from the base inputs." + ) + y_pred = self._transformed_model.apply( + self._params, next(self._rng_gen), *inputs + ) + return ( + np.asarray(y_true), + np.asarray(y_pred), + None if s_wght is None else np.asarray(s_wght), + ) + + def loss_function( + self, + y_true: np.ndarray, + y_pred: np.ndarray, + ) -> np.ndarray: + s_loss = self._loss_fn(jnp.array(y_pred), jnp.array(y_true)) + return np.array(s_loss).squeeze() + + def update_device_policy( + self, + policy: Optional[DevicePolicy] = None, + ) -> None: + # similar code to tensorflow Model; pylint: disable=duplicate-code + # Select the device to use based on the provided or global policy. + if policy is None: + policy = get_device_policy() + device = select_device(gpu=policy.gpu, idx=policy.idx) + # When needed, re-create the model to force moving it to the device. + if self._device is not device: + self._device = device + self._params = jax.device_put(self._params, self._device) diff --git a/declearn/model/haiku/_vector.py b/declearn/model/haiku/_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..74ac8cafe77c12898d9e56b1da7166da51163a65 --- /dev/null +++ b/declearn/model/haiku/_vector.py @@ -0,0 +1,170 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""JaxNumpyVector data arrays container."""
+
+from typing import Any, Callable, Dict, Optional, Set, Type
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.config import config as jaxconfig
+from typing_extensions import Self  # future: import from typing (Py>=3.11)
+
+from declearn.model.api import Vector, register_vector_type
+from declearn.model.haiku.utils import select_device
+from declearn.model.sklearn import NumpyVector
+from declearn.utils import get_device_policy
+
+__all__ = [
+    "JaxNumpyVector",
+]
+
+
+jaxconfig.update("jax_enable_x64", True)  # enable float64 support
+
+
+@register_vector_type(jax.Array)
+class JaxNumpyVector(Vector):
+    """Vector subclass to store jax.numpy.ndarray coefficients.
+
+    This Vector is designed to store a collection of named
+    jax numpy arrays or scalars, enabling computations that are
+    either applied to each and every coefficient, or imply
+    two sets of aligned coefficients (i.e. two JaxNumpyVector
+    instances with similar coefficients specifications).
+
+    Use `vector.coefs` to access the stored coefficients.
+
+    Notes
+    -----
+    - A `JaxNumpyVector` can be operated with either a:
+      - scalar value
+      - `NumpyVector` that has similar specifications
+      - `JaxNumpyVector` that has similar specifications
+      => resulting in a `JaxNumpyVector` in each of these cases.
+    - The wrapped arrays may be placed on any device (CPU, GPU...)
+      and may not all be on the same device.
+    - The device-placement of the initial `JaxNumpyVector`'s data
+      is preserved by operations, including with `NumpyVector`.
+    - When combining two `JaxNumpyVector` instances, the device
+      placement of the left-most one is used; as a result, one
+      ends up with `gpu + cpu = gpu` while `cpu + gpu = cpu`. In
+      both cases, a warning will be emitted to prevent silent
+      un-optimized copies.
+    - When deserializing a `JaxNumpyVector` (either by directly using
+      `JaxNumpyVector.unpack` or loading one from a JSON dump), loaded
+      arrays are placed based on the global device-placement policy
+      (accessed via `declearn.utils.get_device_policy`). Thus it may
+      have a different device-placement schema than at dump time but
+      should be coherent with that of `HaikuModel` computations.
+    """
+
+    @property
+    def _op_add(self) -> Callable[[Any, Any], Any]:
+        return jnp.add
+
+    @property
+    def _op_sub(self) -> Callable[[Any, Any], Any]:
+        return jnp.subtract
+
+    @property
+    def _op_mul(self) -> Callable[[Any, Any], Any]:
+        return jnp.multiply
+
+    @property
+    def _op_div(self) -> Callable[[Any, Any], Any]:
+        return jnp.divide
+
+    @property
+    def _op_pow(self) -> Callable[[Any, Any], Any]:
+        return jnp.power
+
+    @property
+    def compatible_vector_types(self) -> Set[Type[Vector]]:
+        types = super().compatible_vector_types
+        return types.union({NumpyVector, JaxNumpyVector})
+
+    def __init__(self, coefs: Dict[str, jax.Array]) -> None:
+        super().__init__(coefs)
+
+    def _apply_operation(
+        self,
+        other: Any,
+        func: Callable[[Any, Any], Any],
+    ) -> Self:
+        # Ensure 'other' JaxNumpyVector shares this vector's device placement.
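+        # This keeps computations on a single device, with the left-most
+        # operand's placement prevailing, as documented in the class notes.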
+        if isinstance(other, JaxNumpyVector):
+            coefs = {
+                key: jax.device_put(val, self.coefs[key].device())
+                for key, val in other.coefs.items()
+            }
+            other = JaxNumpyVector(coefs)
+        return super()._apply_operation(other, func)
+
+    def __eq__(self, other: Any) -> bool:
+        valid = isinstance(other, JaxNumpyVector)
+        valid = valid and (self.coefs.keys() == other.coefs.keys())
+        return valid and all(
+            jnp.array_equal(self.coefs[k], other.coefs[k]) for k in self.coefs
+        )
+
+    def sign(
+        self,
+    ) -> Self:
+        return self.apply_func(jnp.sign)
+
+    def minimum(
+        self,
+        other: Any,
+    ) -> Self:
+        if isinstance(other, Vector):
+            return self._apply_operation(other, jnp.minimum)
+        return self.apply_func(jnp.minimum, other)
+
+    def maximum(
+        self,
+        other: Any,
+    ) -> Self:
+        if isinstance(other, Vector):
+            return self._apply_operation(other, jnp.maximum)
+        return self.apply_func(jnp.maximum, other)
+
+    def sum(
+        self,
+        axis: Optional[int] = None,
+        keepdims: bool = False,
+    ) -> Self:
+        coefs = {
+            key: jnp.array(jnp.sum(val, axis=axis, keepdims=keepdims))
+            for key, val in self.coefs.items()
+        }
+        return self.__class__(coefs)
+
+    def pack(
+        self,
+    ) -> Dict[str, Any]:
+        return {key: np.asarray(arr) for key, arr in self.coefs.items()}
+
+    @classmethod
+    def unpack(
+        cls,
+        data: Dict[str, Any],
+    ) -> Self:
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        coefs = {k: jax.device_put(arr, device) for k, arr in data.items()}
+        return cls(coefs)
diff --git a/declearn/model/haiku/utils/__init__.py b/declearn/model/haiku/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce83acaf4aaaed40a31a1713fc967d190abddc2
--- /dev/null
+++ b/declearn/model/haiku/utils/__init__.py
@@ -0,0 +1,25 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for jax backend support code.
+
+GPU/CPU backing device management utils:
+* select_device:
+    Select a backing device to use based on inputs and availability.
+"""
+
+from ._gpu import select_device
diff --git a/declearn/model/haiku/utils/_gpu.py b/declearn/model/haiku/utils/_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c79288e8d234c480e2f839ab026f7181a572aac
--- /dev/null
+++ b/declearn/model/haiku/utils/_gpu.py
@@ -0,0 +1,76 @@
+# coding: utf-8
+
+# Copyright 2023 Inria (Institut National de Recherche en Informatique
+# et Automatique)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for GPU support and device management in jax.""" + +import warnings +from typing import Optional + +import jax +import jaxlib.xla_extension as xe + +__all__ = ["select_device"] + + +def select_device( + gpu: bool, + idx: Optional[int] = None, +) -> xe.Device: # pylint: disable=c-extension-no-member + """Select a backing device to use based on inputs and availability. + + Parameters + ---------- + gpu: bool + Whether to select a GPU device rather than the CPU one. + idx: int or None, default=None + Optional pre-selected device index. Only used when `gpu=True`. + If `idx is None` or exceeds the number of available GPU devices, + use the first available one. + + Warns + ----- + UserWarning: + If `gpu=True` but no GPU is available. + If `idx` exceeds the number of available GPU devices. + + Returns + ------- + device: jaxlib.xla_extension.Device + Selected device. + """ + idx = 0 if idx is None else idx + # List available CPU or GPU devices. + device_type = "gpu" if gpu else "cpu" + devices = [d for d in jax.devices() if d.platform == device_type] + # Case when no GPU is available: warn and use a CPU instead. + if gpu and not devices: + warnings.warn( + "Cannot use a GPU device: either CUDA is unavailable " + "or no GPU is visible to jax." + ) + device_type, idx = "cpu", 0 + devices = jax.devices(device_type) + # similar code to tensorflow util; pylint: disable=duplicate-code + # Case when the desired device index is invalid: select another one. + if idx >= len(devices): + warnings.warn( + f"Cannot use {device_type} device n°{idx}: index is out-of-range." + f"\nUsing {device_type} device n°0 instead." + ) + idx = 0 + # Return the selected device. + return devices[idx] diff --git a/declearn/model/torch/__init__.py b/declearn/model/torch/__init__.py index 1f2618b5ab095c45b3fb2b247d6e0b336a970611..ff0f3cead89b7f6755cb80d87814d23a36b67362 100644 --- a/declearn/model/torch/__init__.py +++ b/declearn/model/torch/__init__.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tensorflow models interfacing tools. +"""Pytorch models interfacing tools. This submodule provides with a generic interface to wrap up any PyTorch `torch.nn.Module` instance that is to be trained with gradient descent. 
diff --git a/declearn/test_utils/_vectors.py b/declearn/test_utils/_vectors.py
index fdc53dd79d6161b8de6f7d0d0ed004b3954ecea0..e0f37b873848748fc87b22b7f23f62d2cf9361ed 100644
--- a/declearn/test_utils/_vectors.py
+++ b/declearn/test_utils/_vectors.py
@@ -35,7 +35,7 @@ __all__ = [
 ]
 
 
-FrameworkType = Literal["numpy", "tensorflow", "torch"]
+FrameworkType = Literal["numpy", "tensorflow", "torch", "jax"]
 
 
 def list_available_frameworks() -> List[FrameworkType]:
@@ -81,6 +81,9 @@ class GradientsTestCase:
         if self.framework == "torch":
             module = importlib.import_module("declearn.model.torch")
             return module.TorchVector
+        if self.framework == "jax":
+            module = importlib.import_module("declearn.model.haiku")
+            return module.JaxNumpyVector
         raise ValueError(f"Invalid framework '{self.framework}'")
 
     def convert(self, array: np.ndarray) -> ArrayLike:
@@ -94,12 +97,17 @@ class GradientsTestCase:
         if self.framework == "torch":
             torch = importlib.import_module("torch")
             return torch.from_numpy(array)
+        if self.framework == "jax":
+            jnp = importlib.import_module("jax.numpy")
+            return jnp.asarray(array)
         raise ValueError(f"Invalid framework '{self.framework}'")
 
     def to_numpy(self, array: ArrayLike) -> np.ndarray:
         """Convert an input framework-based structure to a numpy array."""
         if isinstance(array, np.ndarray):
             return array
+        if self.framework == "jax":
+            return np.asarray(array)
         if self.framework == "tensorflow":  # add support for IndexedSlices
             tensorflow = importlib.import_module("tensorflow")
             if isinstance(array, tensorflow.IndexedSlices):
diff --git a/pyproject.toml b/pyproject.toml
index 70ca5ffd51716408a94016d3818cbb7610ae0c45..e486430064c298cb3162945bd694a790e711753e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,10 +46,11 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-# all non-docs, non-tests extra dependencies
-all = [
+all = [  # all non-tests extra dependencies
+    "dm-haiku == 0.0.9",
     "functorch",
     "grpcio >= 1.45",
+    "jax[cpu] == 0.4.4",
     "opacus ~= 1.1",
     "protobuf >= 3.19",
     "tensorflow ~= 2.5",
@@ -64,6 +65,10 @@ grpc = [
     "grpcio >= 1.45",
     "protobuf >= 3.19",
 ]
+haiku = [
+    "dm-haiku == 0.0.9",
+    "jax == 0.4.4",
+]
 tensorflow = [
     "tensorflow ~= 2.5",
 ]
diff --git a/test/functional/test_regression.py b/test/functional/test_regression.py
index e065e364f4618f0b20ceaf000c3fb8ca759460aa..66f41ecd47605246e6a477980de65ba0a319399e 100644
--- a/test/functional/test_regression.py
+++ b/test/functional/test_regression.py
@@ -62,24 +62,44 @@ from declearn.test_utils import FrameworkType
 from declearn.utils import run_as_processes
 from declearn.utils import set_device_policy
 
-# pylint: disable=ungrouped-imports; optional frameworks' dependencies
+# optional frameworks' dependencies; pylint: disable=ungrouped-imports
+# pylint: disable=duplicate-code
+# tensorflow imports
 try:
     import tensorflow as tf  # type: ignore
-    from declearn.model.tensorflow import TensorflowModel
 except ModuleNotFoundError:
     pass
+else:
+    from declearn.model.tensorflow import TensorflowModel
+# torch imports
 try:
     import torch
+except ModuleNotFoundError:
+    pass
+else:
     from declearn.model.torch import TorchModel
+# pylint: enable=duplicate-code
+# haiku and jax imports
+try:
+    import haiku as hk
+    import jax
 except ModuleNotFoundError:
     pass
+else:
+    from declearn.model.haiku import HaikuModel
+
+    def haiku_model_fn(inputs: jax.Array) -> jax.Array:
+        """Simple linear model implemented with Haiku."""
+        return hk.Linear(1)(inputs)
+
+    def haiku_loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
+
"""Sample-wise squared error loss function.""" + return (y_pred - y_true) ** 2 SEED = 0 R2_THRESHOLD = 0.999 -# pylint: disable=too-many-function-args - def get_model(framework: FrameworkType) -> Model: """Set up a simple toy regression model.""" @@ -101,6 +121,8 @@ def get_model(framework: FrameworkType) -> Model: torch.nn.Flatten(0), ) model = TorchModel(torchmod, loss=torch.nn.MSELoss()) + elif framework == "jax": + model = HaikuModel(haiku_model_fn, loss=haiku_loss_fn) else: raise ValueError("unrecognised framework") return model @@ -245,7 +267,9 @@ def _server_routine( # pylint: disable=too-many-arguments # Set up the FederatedServer. model = get_model(framework) - netwk = NetworkServerConfig("websockets", "127.0.0.1", 8765) + netwk = NetworkServerConfig.from_params( + protocol="websockets", host="127.0.0.1", port=8765 + ) optim = FLOptimConfig.from_params( aggregator="averaging", client_opt={ @@ -254,7 +278,6 @@ def _server_routine( }, server_opt=1.0, ) - server = FederatedServer( model, netwk, @@ -281,8 +304,9 @@ def _client_routine( name: str = "client", ) -> None: """Routine to run a FL client, called by `run_declearn_experiment`.""" - # Run the declearn FL client routine. - netwk = NetworkClientConfig("websockets", "ws://localhost:8765", name) + netwk = NetworkClientConfig.from_params( + protocol="websockets", server_uri="ws://localhost:8765", name=name + ) client = FederatedClient(netwk, train, valid) client.run() diff --git a/test/model/test_haiku.py b/test/model/test_haiku.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b8a2c6bf0dcdfca20dbaee05ad63c0379222b5 --- /dev/null +++ b/test/model/test_haiku.py @@ -0,0 +1,295 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Unit tests for HaikuModel."""
+
+import sys
+from typing import Any, Callable, Dict, List, Literal, Union
+
+import numpy as np
+import pytest
+
+try:
+    import haiku as hk
+    import jax
+    import jax.numpy as jnp
+    from jax.config import config as jaxconfig
+except ModuleNotFoundError:
+    pytest.skip("jax and/or haiku are unavailable", allow_module_level=True)
+
+from declearn.model.haiku import HaikuModel, JaxNumpyVector
+from declearn.typing import Batch
+from declearn.utils import set_device_policy
+
+# dirty trick to import from `model_testing.py`;
+# pylint: disable=wrong-import-order, wrong-import-position
+sys.path.append(".")
+from model_testing import ModelTestCase, ModelTestSuite
+
+# Override jax's float32 default (enable float64 support).
+jaxconfig.update("jax_enable_x64", True)
+
+
+def cnn_fn(inputs: jax.Array) -> jax.Array:
+    """Simple CNN in a purely functional form."""
+    stack = [
+        hk.Conv2D(output_channels=32, kernel_shape=(7, 7), padding="SAME"),
+        jax.nn.relu,
+        hk.MaxPool(window_shape=(8, 8, 1), strides=(8, 8, 1), padding="VALID"),
+        hk.Conv2D(output_channels=16, kernel_shape=(5, 5), padding="SAME"),
+        jax.nn.relu,
+        hk.AvgPool(window_shape=(8, 8, 1), strides=(8, 8, 1), padding="VALID"),
+        hk.Reshape((16,)),
+        hk.Linear(1),
+    ]
+    model = hk.Sequential(stack)  # type: ignore
+    return model(inputs)
+
+
+def mlp_fn(inputs: jax.Array) -> jax.Array:
+    """Simple MLP in a purely functional form."""
+    model = hk.nets.MLP([32, 16, 1])
+    return model(inputs)
+
+
+def rnn_fn(inputs: jax.Array) -> jax.Array:
+    """Simple RNN in a purely functional form."""
+    inputs = inputs[None, :] if len(inputs.shape) == 1 else inputs
+    core = hk.DeepRNN(
+        [
+            hk.Embed(100, 32),
+            hk.LSTM(32),
+            jax.nn.tanh,
+        ]
+    )
+    batch_size = inputs.shape[0]
+    initial_state = core.initial_state(batch_size)
+    logits, _ = hk.dynamic_unroll(
+        core, inputs, initial_state, time_major=False
+    )
+    return hk.Linear(1)(logits)[:, -1, :]
+
+
+def loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
+    """Per-sample binary cross entropy."""
+    y_pred = jax.nn.sigmoid(y_pred)
+    y_pred = jnp.squeeze(y_pred)
+    log_p, log_not_p = jnp.log(y_pred), jnp.log(1.0 - y_pred)
+    return -y_true * log_p - (1.0 - y_true) * log_not_p
+
+
+class HaikuTestCase(ModelTestCase):
+    """Haiku test-case-provider fixture.
+
+    Implemented architectures are:
+    * "MLP":
+        - input: 64-dimensional features vectors
+        - stack: 32-neurons fully-connected layer with ReLU
+                 16-neurons fully-connected layer with ReLU
+                 1 output neuron with sigmoid activation
+    * "RNN":
+        - input: 128-tokens-sequence in a 100-tokens-vocabulary
+        - stack: 32-dimensional embedding matrix
+                 32-neurons LSTM layer with tanh activation
+                 1 output neuron with sigmoid activation
+    * "CNN":
+        - input: 64x64 image with 3 channels (normalized values)
+        - stack: 32 7x7 conv. filters, then 8x8 max pooling
+                 16 5x5 conv. filters, then 8x8 avg pooling
+                 1 output neuron with sigmoid activation
+    """
+
+    vector_cls = JaxNumpyVector
+    tensor_cls = jax.Array
+
+    def __init__(
+        self,
+        kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+        device: Literal["cpu", "gpu"],
+    ) -> None:
+        """Specify the desired model architecture."""
+        if kind not in ("MLP", "MLP-tune", "RNN", "CNN"):
+            raise ValueError(f"Invalid test architecture: '{kind}'.")
+        if device not in ("cpu", "gpu"):
+            raise ValueError(f"Invalid device choice for test: '{device}'.")
+        self.kind = kind
+        self.device = device
+        set_device_policy(gpu=(device == "gpu"), idx=0)
+
+    @staticmethod
+    def to_numpy(
+        tensor: Any,
+    ) -> np.ndarray:
+        """Convert an input jax.Array to a numpy ndarray."""
+        assert isinstance(tensor, jax.Array)
+        return np.asarray(tensor)
+
+    @property
+    def dataset(
+        self,
+    ) -> List[Batch]:
+        """Suited toy binary-classification dataset."""
+        rng = np.random.default_rng(seed=0)
+        if self.kind.startswith("MLP"):
+            inputs = rng.normal(size=(2, 32, 64)).astype("float32")
+        elif self.kind == "RNN":
+            inputs = rng.choice(100, size=(2, 32, 128))
+        elif self.kind == "CNN":
+            inputs = rng.normal(size=(2, 32, 64, 64, 3)).astype("float32")
+        labels = rng.choice(2, size=(2, 32))
+        inputs = jnp.asarray(inputs)  # type: ignore
+        labels = jnp.asarray(labels)  # type: ignore
+        batches = list(zip(inputs, labels, [None, None]))
+        return batches  # type: ignore
+
+    @property
+    def model(self) -> HaikuModel:
+        """Suited toy binary-classification haiku model."""
+        if self.kind == "CNN":
+            shape = [64, 64, 3]
+            model_fn = cnn_fn
+        elif self.kind.startswith("MLP"):
+            shape = [64]
+            model_fn = mlp_fn
+        elif self.kind == "RNN":
+            shape = [128]
+            model_fn = rnn_fn
+        model = HaikuModel(model_fn, loss_fn)
+        model.initialize(
+            {
+                "features_shape": shape,
+                "data_type": "int" if self.kind == "RNN" else "float32",
+            }
+        )
+        if self.kind == "MLP-tune":
+            names = model.get_weight_names()
+            model.set_trainable_weights([names[i] for i in range(3)])
+        return model
+
+    def assert_correct_device(
+        self,
+        vector: JaxNumpyVector,
+    ) -> None:
+        """Raise if a vector is backed on the wrong type of device."""
+        name = f"{self.device}:0"
+        assert all(
+            f"{arr.device().platform}:{arr.device().id}" == name
+            for arr in vector.coefs.values()
+        )
+
+    def get_trainable_criterion(
+        self,
+        c_type: str,
+    ) -> Union[
+        List[str],
+        Dict[str, Dict[str, Any]],
+        Callable[[str, str, jax.Array], bool],
+    ]:
+        """Build different weight-freezing criteria."""
+        if c_type == "names":
+            names = self.model.get_weight_names()
+            return [names[2], names[3]]
+        if c_type == "pytree":
+            params = getattr(self.model, "_params")
+            return {k: v for i, (k, v) in enumerate(params.items()) if i != 1}
+        if c_type == "predicate":
+            return lambda m, n, p: n != "b"
+        raise KeyError(f"Invalid 'c_type' parameter: {c_type}.")
+
+
+@pytest.fixture(name="test_case")
+def fixture_test_case(
+    kind: Literal["MLP", "MLP-tune", "RNN", "CNN"],
+    device: Literal["cpu", "gpu"],
+) -> HaikuTestCase:
+    """Fixture to access a HaikuTestCase."""
+    return HaikuTestCase(kind, device)
+
+
+DEVICES = ["cpu"]
+if any(d.platform == "gpu" for d in jax.devices()):
+    DEVICES.append("gpu")
+
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("kind", ["MLP", "MLP-tune", "RNN", "CNN"])
+class TestHaikuModel(ModelTestSuite):
+    """Unit tests for declearn.model.haiku.HaikuModel."""
+
+    @pytest.mark.filterwarnings("ignore: Our custom Haiku serialization")
+    def test_serialization(
+        self,
+        test_case:
ModelTestCase, + ) -> None: + super().test_serialization(test_case) + + @pytest.mark.parametrize( + "criterion_type", ["names", "pytree", "predicate"] + ) + def test_get_frozen_weights( + self, + test_case: HaikuTestCase, + criterion_type: str, + ) -> None: + """Check that `get_weights` behaves properly with frozen weights.""" + model = test_case.model # type: HaikuModel + criterion = test_case.get_trainable_criterion(criterion_type) + model.set_trainable_weights(criterion) # freeze some weights + w_all = model.get_weights() + w_trn = model.get_weights(trainable=True) + assert set(w_trn.coefs).issubset(w_all.coefs) # check on keys + n_params = len(model.get_weight_names()) + n_trainable = len(model.get_weight_names(trainable=True)) + assert n_trainable < n_params + assert len(w_trn.coefs) == n_trainable + assert len(w_all.coefs) == n_params + + @pytest.mark.parametrize( + "criterion_type", ["names", "pytree", "predicate"] + ) + def test_set_frozen_weights( + self, + test_case: HaikuTestCase, + criterion_type: str, + ) -> None: + """Check that `set_weights` behaves properly with frozen weights.""" + # similar code to TorchModel tests; pylint: disable=duplicate-code + # Setup a model with some frozen weights, and gather trainable ones. + model = test_case.model + criterion = test_case.get_trainable_criterion(criterion_type) + model.set_trainable_weights(criterion) # freeze some weights + w_trn = model.get_weights(trainable=True) + # Test that `set_weights` works if and only if properly parametrized. + with pytest.raises(KeyError): + model.set_weights(w_trn) + with pytest.raises(KeyError): + model.set_weights(model.get_weights(), trainable=True) + model.set_weights(w_trn, trainable=True) + + def test_proper_model_placement( + self, + test_case: HaikuTestCase, + ) -> None: + """Check that at instantiation, model weights are properly placed.""" + model = test_case.model + policy = model.device_policy + assert policy.gpu == (test_case.device == "gpu") + assert policy.idx == 0 + params = jax.tree_util.tree_leaves(getattr(model, "_params")) + device = f"{test_case.device}:0" + for arr in params: + assert f"{arr.device().platform}:{arr.device().id}" == device diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py index 2e89fb4c71f1f129a2ed2ab3d7d7eb014640bc14..a6da92e767b7db9760357a00777c46116d13e577 100644 --- a/test/model/test_tflow.py +++ b/test/model/test_tflow.py @@ -17,8 +17,8 @@ """Unit tests for TensorflowModel.""" -import warnings import sys +import warnings from typing import Any, List, Literal import numpy as np @@ -38,7 +38,7 @@ from declearn.utils import set_device_policy # dirty trick to import from `model_testing.py`; # pylint: disable=wrong-import-order, wrong-import-position sys.path.append(".") -from model_testing import ModelTestSuite, ModelTestCase +from model_testing import ModelTestCase, ModelTestSuite class TensorflowTestCase(ModelTestCase): @@ -51,7 +51,7 @@ class TensorflowTestCase(ModelTestCase): 16-neurons fully-connected layer with ReLU 1 output neuron with sigmoid activation * "MLP-tune": - - same as NLP, but freeze the first layer of the stack + - same as MLP, but freeze the first layer of the stack * "RNN": - input: 128-tokens-sequence in a 100-tokens-vocabulary - stack: 32-dimensional embedding matrix
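Reviewer note: the snippet below is an illustrative usage sketch of the new API, not part of the patch. It mirrors the toy linear model from test/functional/test_regression.py; the feature shape, batch size, seed and 0.1 learning rate are arbitrary placeholder choices.

import haiku as hk
import jax
import numpy as np

from declearn.model.haiku import HaikuModel


def model_fn(inputs: jax.Array) -> jax.Array:
    """Single linear unit, as in the functional regression test."""
    return hk.Linear(1)(inputs)


def loss_fn(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
    """Sample-wise squared error loss."""
    return (y_pred - y_true) ** 2


# Wrap and initialize the model; `data_info` provides the required
# "features_shape" and "data_type" fields.
model = HaikuModel(model_fn, loss=loss_fn, seed=0)
model.initialize({"features_shape": [8], "data_type": "float32"})

# Draw a toy (inputs, labels, sample-weights) batch.
rng = np.random.default_rng(seed=0)
x = rng.normal(size=(32, 8)).astype("float32")
y = rng.normal(size=(32, 1)).astype("float32")

# Run one plain SGD step: gradients come out as a JaxNumpyVector,
# which supports scalar operations, hence the `* -0.1` scaling.
grads = model.compute_batch_gradients((x, y, None))
model.apply_updates(grads * -0.1)

In an actual federated run, these calls are issued by declearn's training routines (via the Model API) rather than by hand.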