diff --git a/README.md b/README.md
index 79f103078166eee32305da4ebf2c7d083a81e3d8..b638674b1b558ad4258255087de58892729d0753 100644
--- a/README.md
+++ b/README.md
@@ -227,6 +227,38 @@ client = declearn.main.FederatedClient(netwk, train, valid, checkpoint="outputs"
 client.run()
 ```
 
+### Support for GPU acceleration
+
+TL;DR: GPU acceleration is natively available in `declearn` for model
+frameworks that support it, with one line of code and without changing
+your original model.
+
+Details:
+
+Most machine learning frameworks, including TensorFlow and Torch, support
+accelerating computations by using computational devices other than the CPU.
+`declearn` interfaces with the supported frameworks so that a device policy
+can be set in a single line of code, across frameworks.
+
+`declearn` internalizes the framework-specific code adaptations to place the
+data, model weights and computations on such a device. `declearn` provides
+a simple API to define a global device policy. This enables using a
+single GPU to accelerate computations, or forcing the use of the CPU.
+
+By default, the policy is set to use the first available GPU, and otherwise
+use the CPU, with a warning that can safely be ignored.
+
+The device policy can be set in local scripts, on either the client or the
+server side. The policy is local and is not synchronized between
+federated learning participants.
+
+Here are some examples of the one-liner used:
+```python
+declearn.utils.set_device_policy(gpu=False)  # disable GPU use
+declearn.utils.set_device_policy(gpu=True)  # use any available GPU
+declearn.utils.set_device_policy(gpu=True, idx=2)  # specifically use GPU n°2
+```
+
 ### Note on dependency sharing
 
 One important issue however that is not handled by declearn itself is that
diff --git a/declearn/model/api/_model.py b/declearn/model/api/_model.py
index 6bd5fa081494f3862affba1fdbf95b2919228ebd..41d2c9abd09a0530c2b9b0b9629acd560467a848 100644
--- a/declearn/model/api/_model.py
+++ b/declearn/model/api/_model.py
@@ -25,7 +25,7 @@ from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api._vector import Vector
 from declearn.typing import Batch
-from declearn.utils import create_types_registry
+from declearn.utils import DevicePolicy, create_types_registry
 
 
 __all__ = [
@@ -43,6 +43,13 @@ class Model(metaclass=ABCMeta):
     writing algorithms and operations agnostic to the framework in which
     the underlying model is implemented (e.g. PyTorch, TensorFlow,
     Scikit-Learn...).
+
+    Device-placement (i.e. running computations on CPU or GPU)
+    is also handled as part of Model classes' backend, mapping
+    the generic `declearn.utils.DevicePolicy` parameters to any
+    required framework-specific instruction to adequately pick
+    the device to use and ensure the wrapped model, input data
+    and interfaced computations are placed there.
     """
 
     def __init__(
@@ -52,6 +59,13 @@ class Model(metaclass=ABCMeta):
         """Instantiate a Model interface wrapping a 'model' object."""
         self._model = model
 
+    @property
+    @abstractmethod
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        """Return the device-placement policy currently used by this model."""
+
     @property
     @abstractmethod
     def required_data_info(
@@ -220,3 +234,27 @@ class Model(metaclass=ABCMeta):
         s_loss: np.ndarray
             Sample-wise loss values, as a 1-d numpy array.
         """
+
+    @abstractmethod
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
+        """Update the device-placement policy of this model.
+
+        This method is designed to be called after a change in the global
+        device-placement policy (e.g. to disable using a GPU, or move to
+        a specific one), so as to re-place pre-existing Model instances on
+        the adequate device and avoid policy inconsistencies that might
+        cause repeated memory or runtime costs from moving data or weights
+        around each time they are used. You should otherwise not worry about
+        a Model's device placement, as it is handled at instantiation based
+        on the global device policy (see `declearn.utils.set_device_policy`).
+
+        Parameters
+        ----------
+        policy: DevicePolicy or None, default=None
+            Optional DevicePolicy dataclass instance to be used.
+            If None, use the global device policy, accessed via
+            `declearn.utils.get_device_policy`.
+        """
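
The following snippet sketches how this new abstract API is meant to be used,
here with the CPU-only scikit-learn wrapper updated further below (a minimal
sketch, not part of the patch; it assumes `SklearnSGDModel` can be built
directly around a fresh `SGDRegressor`):

```python
from sklearn.linear_model import SGDRegressor

import declearn
from declearn.model.sklearn import SklearnSGDModel

declearn.utils.set_device_policy(gpu=False)  # enforce CPU placement
model = SklearnSGDModel(SGDRegressor())
print(model.device_policy)    # DevicePolicy(gpu=False, idx=None)
model.update_device_policy()  # no-op for this CPU-only Model subclass
```
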
diff --git a/declearn/model/api/_vector.py b/declearn/model/api/_vector.py
index c55986773c830b5d0ddaecdee596879ee1efdfb5..fc6fdc78b15ddc5eb406b10d3ffef50d50ec2857 100644
--- a/declearn/model/api/_vector.py
+++ b/declearn/model/api/_vector.py
@@ -315,10 +315,7 @@ class Vector(metaclass=ABCMeta):
             return type(self)(coefs)
         # Case when the two vectors have incompatible types.
         if isinstance(other, Vector):
-            raise TypeError(
-                f"Cannot {func.__name__} {type(self).__name__} object with "
-                f"a vector of incompatible type {type(other).__name__}."
-            )
+            return NotImplemented
         # Case when operating with another object (e.g. a scalar).
         try:
             return type(self)(
diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py
index 5903f7dd91ce4d7707d22511645899fd047d2b2e..65ac351d93904c3375f6d2244f8b271f3a067eb1 100644
--- a/declearn/model/sklearn/_np_vec.py
+++ b/declearn/model/sklearn/_np_vec.py
@@ -41,6 +41,19 @@ class NumpyVector(Vector):
     instances with similar coefficients specifications).
 
     Use `vector.coefs` to access the stored coefficients.
+
+    Notes
+    -----
+    - A `NumpyVector` can be operated with either a scalar value,
+      or another `NumpyVector` that has similar specifications
+      (same coefficient names, shapes and compatible dtypes).
+    - Some other `Vector` classes might be made compatible with
+      `NumpyVector`; in that case, operating with a `NumpyVector`
+      will always result in a vector of the other type. This is
+      notably the case with `TensorflowVector` and `TorchVector`.
+    - There is currently no support for GPU-acceleration with the
+      `NumpyVector` class, which only handles arrays and operations
+      placed on a CPU device.
     """
 
     @property
diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index eabc2883ed169072bae0e31219082ec72703d1c2..28abc4913bf30b1f3e4899c41b7c3555f71394d2 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -18,6 +18,7 @@
 """Model subclass to wrap scikit-learn SGD classifier and regressor models."""
 
 import typing
+import warnings
 from typing import Any, Callable, Dict, Literal, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -29,7 +30,7 @@ from declearn.data_info import aggregate_data_info
 from declearn.model.api import Model
 from declearn.model.sklearn._np_vec import NumpyVector
 from declearn.typing import Batch
-from declearn.utils import register_type
+from declearn.utils import DevicePolicy, register_type
 
 
 __all__ = [
@@ -63,6 +64,13 @@ class SklearnSGDModel(Model):
     This `Model` subclass is designed to wrap a `SGDClassifier`
     or `SGDRegressor` instance (from `sklearn.linear_model`) to
     be learned federatively.
+
+    Notes regarding device management (CPU, GPU, etc.):
+    * This Model may only run on CPU, and is unaffected by device-
+      management policies.
+    * Calling the `update_device_policy` method has no effect, and
+      raises a UserWarning if a GPU-targeting policy is passed to
+      it directly.
     """
 
     def __init__(
@@ -104,6 +112,12 @@ class SklearnSGDModel(Model):
             None
         )  # type: Optional[Callable[[np.ndarray, np.ndarray], np.ndarray]]
 
+    @property
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        return DevicePolicy(gpu=False, idx=None)
+
     @property
     def required_data_info(
         self,
@@ -382,3 +396,10 @@ class SklearnSGDModel(Model):
         else:
             loss_fn = loss_1d
         return loss_fn
+
+    def update_device_policy(
+        self,
+        policy: Optional[DevicePolicy] = None,
+    ) -> None:
+        if policy is not None and policy.gpu:
+            warnings.warn("'SklearnSGDModel' only runs on a CPU backend.")
diff --git a/declearn/model/tensorflow/__init__.py b/declearn/model/tensorflow/__init__.py
index ff94c495b4647614cc5f9e9e44b0a9d636f3c3d3..a8db61f23c5e09cb1a25b2c2abdf70a9128d9972 100644
--- a/declearn/model/tensorflow/__init__.py
+++ b/declearn/model/tensorflow/__init__.py
@@ -24,7 +24,11 @@ through gradient descent.
 
 This module exposes:
 * TensorflowModel: Model subclass to wrap tensorflow.keras.Model objects
 * TensorflowVector: Vector subclass to wrap tensorflow.Tensor objects
+
+It also exposes the `utils` submodule, which mainly aims at
+providing tools used in the backend of the former objects.
 """
+from . import utils
 from ._vector import TensorflowVector
 from ._model import TensorflowModel
diff --git a/declearn/model/tensorflow/_model.py b/declearn/model/tensorflow/_model.py
index fc1a9eb88dc4d42f07314276b266d20749661318..a56ff2096e0f4d50eb32c65aeb84bf268d30f9fe 100644
--- a/declearn/model/tensorflow/_model.py
+++ b/declearn/model/tensorflow/_model.py
@@ -28,18 +28,43 @@
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.data_info import aggregate_data_info
 from declearn.model.api import Model
-from declearn.model.tensorflow._utils import build_keras_loss
+from declearn.model.tensorflow.utils import (
+    build_keras_loss,
+    move_layer_to_device,
+    select_device,
+)
 from declearn.model.tensorflow._vector import TensorflowVector
 from declearn.typing import Batch
-from declearn.utils import register_type
+from declearn.utils import DevicePolicy, get_device_policy, register_type
+
+
+__all__ = [
+    "TensorflowModel",
+]
 
 
 @register_type(name="TensorflowModel", group="Model")
 class TensorflowModel(Model):
     """Model wrapper for TensorFlow Model instances.
 
-    This `Model` subclass is designed to wrap a `tf.keras.Model`
-    instance to be learned federatively.
+    This `Model` subclass is designed to wrap a `tf.keras.Model` instance
+    to be trained federatively.
+
+    Notes regarding device management (CPU, GPU, etc.):
+    * By default, tensorflow places data and operations on GPU whenever one
+      is available.
+    * Our `TensorflowModel` instead consults the device-placement policy (via
+      `declearn.utils.get_device_policy`), places the wrapped keras model's
+      weights there, and runs computations defined under public methods in
+      a `tensorflow.device` context, to enforce that policy.
+    * Note that there is no guarantee that calling a private method directly
+      will result in abiding by that policy. Hence, be careful when writing
+      custom code, and use your own context managers to get guarantees.
+    * Note that if the global device-placement policy is updated, this will
+      only be propagated to existing instances by manually calling their
+      `update_device_policy` method.
+ * You may consult the device policy currently enforced by a TensorflowModel + instance by accessing its `device_policy` property. """ def __init__( @@ -47,6 +72,7 @@ class TensorflowModel(Model): model: tf.keras.layers.Layer, loss: Union[str, tf.keras.losses.Loss], metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None, + _from_config: bool = False, **kwargs: Any, ) -> None: """Instantiate a Model interface wrapping a tensorflow.keras model. @@ -66,7 +92,7 @@ class TensorflowModel(Model): compiled with the model and computed using the `evaluate` method of the returned TensorflowModel instance. **kwargs: Any - Any addition keyword argument to `tf.keras.Model.compile` + Any additional keyword argument to `tf.keras.Model.compile` may be passed. """ # Type-check the input Model and wrap it up. @@ -79,12 +105,30 @@ class TensorflowModel(Model): super().__init__(model) # Ensure the loss is a keras.Loss object and set its reduction to none. loss = build_keras_loss(loss, reduction=tf.keras.losses.Reduction.NONE) - # Compile the wrapped model and retain compilation arguments. - kwargs.update({"loss": loss, "metrics": metrics}) - model.compile(**kwargs) - self._kwargs = kwargs - # Instantiate a SGD optimizer to apply updates as-provided. - self._sgd = tf.keras.optimizers.SGD(learning_rate=1.0) + # Select the device where to place computations and move the model. + policy = get_device_policy() + self._device = select_device(gpu=policy.gpu, idx=policy.idx) + if not _from_config: + self._model = move_layer_to_device(self._model, self._device) + # Finalize initialization using the selected device. + with tf.device(self._device): + # Compile the wrapped model and retain compilation arguments. + kwargs.update({"loss": loss, "metrics": metrics}) + self._model.compile(**kwargs) + self._kwargs = kwargs + # Instantiate a SGD optimizer to apply updates as-provided. + self._sgd = tf.keras.optimizers.SGD(learning_rate=1.0) + + @property + def device_policy( + self, + ) -> DevicePolicy: + device = self._device + try: + idx = int(device.name.rsplit(":", 1)[-1]) + except ValueError: + idx = None + return DevicePolicy(gpu=(device.device_type == "GPU"), idx=idx) @property def required_data_info( @@ -98,7 +142,8 @@ class TensorflowModel(Model): ) -> None: if not self._model.built: data_info = aggregate_data_info([data_info], {"input_shape"}) - self._model.build(data_info["input_shape"]) + with tf.device(self._device): + self._model.build(data_info["input_shape"]) # Warn about frozen weights. # similar to TorchModel warning; pylint: disable=duplicate-code if len(self._model.trainable_weights) < len(self._model.weights): @@ -129,9 +174,15 @@ class TensorflowModel(Model): for key in ("model", "loss", "kwargs"): if key not in config.keys(): raise KeyError(f"Missing key '{key}' in the config dict.") - model = tf.keras.layers.deserialize(config["model"]) - loss = tf.keras.losses.deserialize(config["loss"]) - return cls(model, loss, **config["kwargs"]) + # Set up the device policy. + policy = get_device_policy() + device = select_device(gpu=policy.gpu, idx=policy.idx) + # Deserialize the model and loss keras objects on the device. + with tf.device(device): + model = tf.keras.layers.deserialize(config["model"]) + loss = tf.keras.losses.deserialize(config["loss"]) + # Instantiate the TensorflowModel, avoiding device-to-device copies. 
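+        # (`_from_config=True` makes `__init__` skip `move_layer_to_device`,
+        # as the keras objects above were already created under `tf.device`.)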
+ return cls(model, loss, **config["kwargs"], _from_config=True) def get_weights( self, @@ -151,8 +202,9 @@ class TensorflowModel(Model): ) self._verify_weights_compatibility(weights, trainable_only=False) variables = {var.name: var for var in self._model.weights} - for name, value in weights.coefs.items(): - variables[name].assign(value, read_value=False) + with tf.device(self._device): + for name, value in weights.coefs.items(): + variables[name].assign(value, read_value=False) def _verify_weights_compatibility( self, @@ -197,12 +249,13 @@ class TensorflowModel(Model): batch: Batch, max_norm: Optional[float] = None, ) -> TensorflowVector: - data = self._unpack_batch(batch) - if max_norm is None: - grads = self._compute_batch_gradients(*data) - else: - norm = tf.constant(max_norm) - grads = self._compute_clipped_gradients(*data, norm) + with tf.device(self._device): + data = self._unpack_batch(batch) + if max_norm is None: + grads = self._compute_batch_gradients(*data) + else: + norm = tf.constant(max_norm) + grads = self._compute_clipped_gradients(*data, norm) grads_and_vars = zip(grads, self._model.trainable_weights) return TensorflowVector( {var.name: grad for grad, var in grads_and_vars} @@ -285,13 +338,14 @@ class TensorflowModel(Model): updates: TensorflowVector, ) -> None: self._verify_weights_compatibility(updates, trainable_only=True) - # Delegate updates' application to a tensorflow Optimizer. - values = (-1 * updates).coefs.values() - zipped = zip(values, self._model.trainable_weights) - upd_op = self._sgd.apply_gradients(zipped) - # Ensure ops have been performed before exiting. - with tf.control_dependencies([upd_op]): - return None + with tf.device(self._device): + # Delegate updates' application to a tensorflow Optimizer. + values = (-1 * updates).coefs.values() + zipped = zip(values, self._model.trainable_weights) + upd_op = self._sgd.apply_gradients(zipped) + # Ensure ops have been performed before exiting. + with tf.control_dependencies([upd_op]): + return None def evaluate( self, @@ -311,21 +365,23 @@ class TensorflowModel(Model): metrics: dict[str, float] Dictionary associating evaluation metrics' values to their name. """ - return self._model.evaluate(dataset, return_dict=True) + with tf.device(self._device): + return self._model.evaluate(dataset, return_dict=True) def compute_batch_predictions( self, batch: Batch, ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: - inputs, y_true, s_wght = self._unpack_batch(batch) - if y_true is None: - raise TypeError( - "`TensorflowModel.compute_batch_predictions` received a " - "batch with `y_true=None`, which is unsupported. Please " - "correct the inputs, or override this method to support " - "creating labels from the base inputs." - ) - y_pred = self._model(inputs, training=False).numpy() + with tf.device(self._device): + inputs, y_true, s_wght = self._unpack_batch(batch) + if y_true is None: + raise TypeError( + "`TensorflowModel.compute_batch_predictions` received a " + "batch with `y_true=None`, which is unsupported. Please " + "correct the inputs, or override this method to support " + "creating labels from the base inputs." 
+ ) + y_pred = self._model(inputs, training=False).numpy() y_true = y_true.numpy() s_wght = s_wght.numpy() if s_wght is not None else s_wght return y_true, y_pred, s_wght @@ -335,7 +391,21 @@ class TensorflowModel(Model): y_true: np.ndarray, y_pred: np.ndarray, ) -> np.ndarray: - tns_true = tf.convert_to_tensor(y_true) - tns_pred = tf.convert_to_tensor(y_pred) - s_loss = self._model.compute_loss(y=tns_true, y_pred=tns_pred) + with tf.device(self._device): + tns_true = tf.convert_to_tensor(y_true) + tns_pred = tf.convert_to_tensor(y_pred) + s_loss = self._model.compute_loss(y=tns_true, y_pred=tns_pred) return s_loss.numpy().squeeze() + + def update_device_policy( + self, + policy: Optional[DevicePolicy] = None, + ) -> None: + # Select the device to use based on the provided or global policy. + if policy is None: + policy = get_device_policy() + device = select_device(gpu=policy.gpu, idx=policy.idx) + # When needed, re-create the model to force moving it to the device. + if self._device is not device: + self._device = device + self._model = move_layer_to_device(self._model, self._device) diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py index f10fa5154d3d82b65d747de31d2d8824bc9bbb2d..f347f212b97c37c3176401004f39d4d09aadf89a 100644 --- a/declearn/model/tensorflow/_vector.py +++ b/declearn/model/tensorflow/_vector.py @@ -25,30 +25,55 @@ import tensorflow as tf # type: ignore from tensorflow.python.framework.ops import EagerTensor # type: ignore # pylint: enable=no-name-in-module from typing_extensions import Self # future: import from typing (Py>=3.11) +# fmt: on from declearn.model.api import Vector, register_vector_type from declearn.model.sklearn import NumpyVector - -# fmt: on +from declearn.model.tensorflow.utils import ( + preserve_tensor_device, + select_device, +) +from declearn.utils import get_device_policy @register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices) class TensorflowVector(Vector): """Vector subclass to store tensorflow tensors. - This Vector is designed to store a collection of named - TensorFlow tensors, enabling computations that are either - applied to each and every coefficient, or imply two sets - of aligned coefficients (i.e. two TensorflowVector with - similar specifications). + This Vector is designed to store a collection of named TensorFlow + tensors, enabling computations that are either applied to each and + every coefficient, or imply two sets of aligned coefficients (i.e. + two TensorflowVector with similar specifications). + + Note that support for IndexedSlices is implemented, as these are a + common type for auto-differentiated gradients. - Note that support for IndexedSlices is implemented, - as these are a common type for auto-differentiated - gradients. - Note that this class does not (yet?) support special - tensor types such as SparseTensor or RaggedTensor. + Note that this class does not (yet?) support special tensor types + such as SparseTensor or RaggedTensor. Use `vector.coefs` to access the stored coefficients. + + Notes + ----- + - A `TensorflowVector` can be operated with either a: + - scalar value + - `NumpyVector` that has similar specifications + - `TensorflowVector` that has similar specifications + => resulting in a `TensorflowVector` in each of these cases. + - The wrapped tensors may be placed on any device (CPU, GPU...) + and may not be all on the same device. + - The device-placement of the initial `TensorflowVector`'s data + is preserved by operations, including with `NumpyVector`. 
+ - When combining two `TensorflowVector`, the device-placement + of the left-most one is used; in that case, one ends up with + `gpu + cpu = gpu` while `cpu + gpu = cpu`. In both cases, a + warning will be emitted to prevent silent un-optimized copies. + - When deserializing a `TensorflowVector` (either by directly using + `TensorflowVector.unpack` or loading one from a JSON dump), loaded + tensors are placed based on the global device-placement policy + (accessed via `declearn.utils.get_device_policy`). Thus it may + have a different device-placement schema than at dump time but + should be coherent with that of `TensorflowModel` computations. """ @property @@ -81,6 +106,23 @@ class TensorflowVector(Vector): ) -> None: super().__init__(coefs) + def apply_func( + self, + func: Callable[..., Any], + *args: Any, + **kwargs: Any, + ) -> Self: + func = preserve_tensor_device(func) + return super().apply_func(func, *args, **kwargs) + + def _apply_operation( + self, + other: Any, + func: Callable[[Any, Any], Any], + ) -> Self: + func = preserve_tensor_device(func) + return super()._apply_operation(other, func) + def dtypes( self, ) -> Dict[str, str]: @@ -97,7 +139,10 @@ class TensorflowVector(Vector): cls, data: Dict[str, Any], ) -> Self: - coef = {key: cls._unpack_tensor(dat) for key, dat in data.items()} + policy = get_device_policy() + device = select_device(gpu=policy.gpu, idx=policy.idx) + with tf.device(device): + coef = {key: cls._unpack_tensor(dat) for key, dat in data.items()} return cls(coef) @classmethod @@ -149,10 +194,13 @@ class TensorflowVector(Vector): if not isinstance(t_a, type(t_b)): return False if isinstance(t_a, tf.IndexedSlices): - return TensorflowVector._tensor_equal( - t_a.indices, t_b.indices - ) and TensorflowVector._tensor_equal(t_a.values, t_b.values) - return tf.reduce_all(t_a == t_b).numpy() + # fmt: off + return ( + TensorflowVector._tensor_equal(t_a.indices, t_b.indices) + and TensorflowVector._tensor_equal(t_a.values, t_b.values) + ) + with tf.device(t_a.device): + return tf.reduce_all(t_a == t_b).numpy() def sign(self) -> Self: return self.apply_func(tf.sign) @@ -178,8 +226,4 @@ class TensorflowVector(Vector): axis: Optional[int] = None, keepdims: bool = False, ) -> Self: - coefs = { - key: tf.reduce_sum(val, axis=axis, keepdims=keepdims) - for key, val in self.coefs.items() - } - return self.__class__(coefs) + return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims) diff --git a/declearn/model/tensorflow/utils/__init__.py b/declearn/model/tensorflow/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f7d66b1d29e47c7b07dc8b10f1b6e9d11187896 --- /dev/null +++ b/declearn/model/tensorflow/utils/__init__.py @@ -0,0 +1,38 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for tensorflow backend support code. 
+ +GPU/CPU backing device management utils: +* move_layer_to_device: + Create a copy of an input keras layer placed on a given device. +* preserve_tensor_device: + Wrap a tensor-processing function to have it run on its inputs' device. +* select_device: + Select a backing device to use based on inputs and availability. + +Loss function management utils: +* build_keras_loss: + Type-check, deserialize and/or wrap a keras loss into a Loss object. +""" + +from ._gpu import ( + move_layer_to_device, + preserve_tensor_device, + select_device, +) +from ._loss import build_keras_loss diff --git a/declearn/model/tensorflow/utils/_gpu.py b/declearn/model/tensorflow/utils/_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..8164250ab560d4db1e08146f960d3ba3f4c6a6d8 --- /dev/null +++ b/declearn/model/tensorflow/utils/_gpu.py @@ -0,0 +1,137 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for GPU support and device management in tensorflow.""" + +import functools +import warnings +from typing import Any, Callable, Optional, Union + +import tensorflow as tf # type: ignore + + +__all__ = [ + "move_layer_to_device", + "preserve_tensor_device", + "select_device", +] + + +def select_device( + gpu: bool, + idx: Optional[int] = None, +) -> tf.config.LogicalDevice: + """Select a backing device to use based on inputs and availability. + + Parameters + ---------- + gpu: bool + Whether to select a GPU device rather than the CPU one. + idx: int or None, default=None + Optional pre-selected device index. Only used when `gpu=True`. + If `idx is None` or exceeds the number of available GPU devices, + use the first available one. + + Warns + ----- + UserWarning: + If `gpu=True` but no GPU is available. + If `idx` exceeds the number of available GPU devices. + + Returns + ------- + device: tf.config.LogicalDevice + Selected device, usable as `tf.device` argument. + """ + idx = 0 if idx is None else idx + # List available CPU or GPU devices. + device_type = "GPU" if gpu else "CPU" + devices = tf.config.list_logical_devices(device_type) + # Case when no GPU is available: warn and use a CPU instead. + if gpu and not devices: + warnings.warn( + "Cannot use a GPU device: either CUDA is unavailable " + "or no GPU is visible to tensorflow." + ) + device_type, idx = "CPU", 0 + devices = tf.config.list_logical_devices("CPU") + # Case when the desired device index is invalid: select another one. + if idx >= len(devices): + warnings.warn( + f"Cannot use {device_type} device n°{idx}: index is out-of-range." + f"\nUsing {device_type} device n°0 instead." + ) + idx = 0 + # Return the selected device. + return devices[idx] + + +def move_layer_to_device( + layer: tf.keras.layers.Layer, + device: Union[tf.config.LogicalDevice, str], +) -> tf.keras.layers.Layer: + """Create a copy of an input keras layer placed on a given device. 
+
+    This function creates a copy of the input layer and of all its weights.
+    It may therefore be costly and should be used sparingly, to move
+    variables to a device where all further computations are expected to
+    be run.
+
+    Parameters
+    ----------
+    layer: tf.keras.layers.Layer
+        Keras layer that needs moving to another device.
+    device: tf.config.LogicalDevice or str
+        Device where to place the layer's weights.
+
+    Returns
+    -------
+    layer: tf.keras.layers.Layer
+        Copy of the input layer, with its weights backed on `device`.
+    """
+    config = tf.keras.layers.serialize(layer)
+    weights = layer.get_weights()
+    with tf.device(device):
+        layer = tf.keras.layers.deserialize(config)
+        layer.set_weights(weights)
+    return layer
+
+
+def preserve_tensor_device(
+    func: Callable[..., tf.Tensor],
+) -> Callable[..., tf.Tensor]:
+    """Wrap a tensor-processing function to have it run on its inputs' device.
+
+    Parameters
+    ----------
+    func: function(tf.Tensor, ...) -> tf.Tensor:
+        Function to wrap, that takes a tensorflow Tensor as first argument.
+
+    Returns
+    -------
+    func: function(tf.Tensor, ...) -> tf.Tensor:
+        Similar function to the input one, that operates under a `tf.device`
+        context so as to run computations on the first input tensor's device.
+    """
+
+    @functools.wraps(func)
+    def wrapped(tensor: tf.Tensor, *args: Any, **kwargs: Any) -> tf.Tensor:
+        """Wrapped function, running under a `tf.device` context."""
+        with tf.device(tensor.device):
+            return func(tensor, *args, **kwargs)
+
+    return wrapped
diff --git a/declearn/model/tensorflow/_utils.py b/declearn/model/tensorflow/utils/_loss.py
similarity index 98%
rename from declearn/model/tensorflow/_utils.py
rename to declearn/model/tensorflow/utils/_loss.py
index b858e4e4631862a07b8639a8a0229fc6717d547d..c6b880c8e8aaa6dc90cb197e18bf7677f32285fc 100644
--- a/declearn/model/tensorflow/_utils.py
+++ b/declearn/model/tensorflow/utils/_loss.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Backend utils for the declearn.model.tensorflow module."""
+"""Function to parse and/or wrap a keras loss for use with declearn."""
 
 import inspect
 
diff --git a/declearn/model/torch/__init__.py b/declearn/model/torch/__init__.py
index 352d3a398f6958488404c2ecd5aecc4a0f29a43b..efdf95f1d5a2ab94b777667fcd3787fb27fdea29 100644
--- a/declearn/model/torch/__init__.py
+++ b/declearn/model/torch/__init__.py
@@ -24,7 +24,11 @@ gradient descent.
 
 This module exposes:
 * TorchModel: Model subclass to wrap torch.nn.Module objects
 * TorchVector: Vector subclass to wrap torch.Tensor objects
+
+It also exposes the `utils` submodule, which mainly aims at
+providing tools used in the backend of the former objects.
 """
+from . import utils
 from ._vector import TorchVector
 from ._model import TorchModel
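
As a minimal usage sketch (not part of the patch) of how the torch device
utils introduced further below are meant to compose, assuming torch is
installed:

```python
import torch

from declearn.model.torch.utils import AutoDeviceModule, select_device

device = select_device(gpu=True)  # falls back to CPU, with a warning, if needed
module = AutoDeviceModule(torch.nn.Linear(4, 2), device=device)
output = module(torch.randn(8, 4))  # inputs are auto-moved to `device`
```
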
diff --git a/declearn/model/torch/_model.py b/declearn/model/torch/_model.py
index 488d7782be84ffc9ef38ec9cee4d38bf31ff6d37..4bdea3aaeab477078dd28a24bcc32b0fcc68c3ee 100644
--- a/declearn/model/torch/_model.py
+++ b/declearn/model/torch/_model.py
@@ -27,9 +27,15 @@ import torch
 from typing_extensions import Self  # future: import from typing (py >=3.11)
 
 from declearn.model.api import Model
+from declearn.model.torch.utils import AutoDeviceModule, select_device
 from declearn.model.torch._vector import TorchVector
 from declearn.typing import Batch
-from declearn.utils import register_type
+from declearn.utils import DevicePolicy, get_device_policy, register_type
+
+
+__all__ = [
+    "TorchModel",
+]
 
 
 # alias for unpacked Batch structures, converted to torch.Tensor objects
@@ -42,8 +48,23 @@ TensorBatch = Tuple[
 class TorchModel(Model):
     """Model wrapper for PyTorch Model instances.
 
-    This `Model` subclass is designed to wrap a `torch.nn.Module`
-    instance to be learned federatively.
+    This `Model` subclass is designed to wrap a `torch.nn.Module` instance
+    to be trained federatively.
+
+    Notes regarding device management (CPU, GPU, etc.):
+    * By default, torch operates on CPU, and it does not automatically move
+      tensors between devices. This means users have to be careful where
+      tensors are placed, to avoid operations between tensors on different
+      devices, which raise runtime errors.
+    * Our `TorchModel` instead consults the global device-placement policy
+      (via `declearn.utils.get_device_policy`), places the wrapped torch
+      modules' weights there, and automates the placement of input data on
+      the same device as the wrapped model.
+    * Note that if the global device-placement policy is updated, this will
+      only be propagated to existing instances by manually calling their
+      `update_device_policy` method.
+    * You may consult the device policy currently enforced by a TorchModel
+      instance by accessing its `device_policy` property.
     """
 
     def __init__(
@@ -62,18 +83,29 @@ class TorchModel(Model):
             is to be minimized through training. Note that it will be
             altered when wrapped.
         """
-        # Type-check the input Model and wrap it up.
+        # Type-check the input model.
         if not isinstance(model, torch.nn.Module):
             raise TypeError("'model' should be a torch.nn.Module instance.")
+        # Select the device where to place computations, and wrap the model.
+        policy = get_device_policy()
+        device = select_device(gpu=policy.gpu, idx=policy.idx)
+        model = AutoDeviceModule(model, device=device)
         super().__init__(model)
         # Assign loss module and set it not to reduce sample-wise values.
         if not isinstance(loss, torch.nn.Module):
             raise TypeError("'loss' should be a torch.nn.Module instance.")
-        self._loss_fn = loss
-        self._loss_fn.reduction = "none"  # type: ignore
+        loss.reduction = "none"  # type: ignore
+        self._loss_fn = AutoDeviceModule(loss, device=device)
         # Compute and assign a functional version of the model.
         self._func_model = functorch.make_functional(self._model)[0]
 
+    @property
+    def device_policy(
+        self,
+    ) -> DevicePolicy:
+        device = self._model.device
+        return DevicePolicy(gpu=(device.type == "cuda"), idx=device.index)
+
     @property
     def required_data_info(
         self,
@@ -102,15 +134,12 @@ class TorchModel(Model):
             "PyTorch JSON serialization relies on pickle, which may be unsafe."
) with io.BytesIO() as buffer: - torch.save(self._model, buffer) + torch.save(self._model.module, buffer) model = buffer.getbuffer().hex() with io.BytesIO() as buffer: - torch.save(self._loss_fn, buffer) + torch.save(self._loss_fn.module, buffer) loss = buffer.getbuffer().hex() - return { - "model": model, - "loss": loss, - } + return {"model": model, "loss": loss} @classmethod def from_config( @@ -139,6 +168,7 @@ class TorchModel(Model): ) -> None: if not isinstance(weights, TorchVector): raise TypeError("TorchModel requires TorchVector weights.") + # NOTE: this preserves the device placement of current states self._model.load_state_dict(weights.coefs) def compute_batch_gradients( @@ -197,7 +227,7 @@ class TorchModel(Model): """Compute the average (opt. weighted) loss over given predictions.""" loss = self._loss_fn(y_pred, y_true) if s_wght is not None: - loss.mul_(s_wght) + loss.mul_(s_wght.to(loss.device)) return loss.mean() def _compute_samplewise_gradients( @@ -236,7 +266,7 @@ class TorchModel(Model): # false-positive; pylint: disable=no-member grad.mul_(torch.clamp(max_norm / norm, max=1)) if s_wght is not None: - grad.mul_(s_wght) + grad.mul_(s_wght.to(grad.device)) return grads # Vectorize the function to compute sample-wise clipped gradients. with torch.no_grad(): @@ -322,7 +352,7 @@ class TorchModel(Model): try: for key, upd in updates.coefs.items(): tns = self._model.get_parameter(key) - tns.add_(upd) + tns.add_(upd.to(tns.device)) except KeyError as exc: raise KeyError( "Invalid model parameter name(s) found in updates." @@ -342,9 +372,9 @@ class TorchModel(Model): ) self._model.eval() with torch.no_grad(): - y_pred = self._model(*inputs).numpy() - y_true = y_true.numpy() - s_wght = s_wght.numpy() if s_wght is not None else s_wght + y_pred = self._model(*inputs).cpu().numpy() + y_true = y_true.cpu().numpy() + s_wght = None if s_wght is None else s_wght.cpu().numpy() return y_true, y_pred, s_wght # type: ignore def loss_function( @@ -355,4 +385,16 @@ class TorchModel(Model): tns_pred = torch.from_numpy(y_pred) # pylint: disable=no-member tns_true = torch.from_numpy(y_true) # pylint: disable=no-member s_loss = self._loss_fn(tns_pred, tns_true) - return s_loss.numpy().squeeze() + return s_loss.cpu().numpy().squeeze() + + def update_device_policy( + self, + policy: Optional[DevicePolicy] = None, + ) -> None: + # Select the device to use based on the provided or global policy. + if policy is None: + policy = get_device_policy() + device = select_device(gpu=policy.gpu, idx=policy.idx) + # Place the wrapped model and loss function modules on that device. + self._model.set_device(device) + self._loss_fn.set_device(device) diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py index 5a6ee87b5a5e1da03e9e2998d9e3338639e6e394..6b2a19e2d05ea79e8e83c4206153c6d614e3e735 100644 --- a/declearn/model/torch/_vector.py +++ b/declearn/model/torch/_vector.py @@ -19,25 +19,47 @@ from typing import Any, Callable, Dict, Optional, Set, Tuple, Type -import numpy as np import torch from typing_extensions import Self # future: import from typing (Py>=3.11) from declearn.model.api import Vector, register_vector_type from declearn.model.sklearn import NumpyVector +from declearn.model.torch.utils import select_device +from declearn.utils import get_device_policy @register_vector_type(torch.Tensor) class TorchVector(Vector): """Vector subclass to store PyTorch tensors. 
- This Vector is designed to store a collection of named - PyTorch tensors, enabling computations that are either - applied to each and every coefficient, or imply two sets - of aligned coefficients (i.e. two TorchVector with - similar specifications). + This Vector is designed to store a collection of named PyTorch + tensors, enabling computations that are either applied to each + and every coefficient, or imply two sets of aligned coefficients + (i.e. two TorchVector with similar specifications). Use `vector.coefs` to access the stored coefficients. + + Notes + ----- + - A `TorchVector` can be operated with either a: + - scalar value + - `NumpyVector` that has similar specifications + - `TorchVector` that has similar specifications + => resulting in a `TorchVector` in each of these cases. + - The wrapped tensors may be placed on any device (CPU, GPU...) + and may not be all on the same device. + - The device-placement of the initial `TorchVector`'s data + is preserved by operations, including with `NumpyVector`. + - When combining two `TorchVector`, the device-placement + of the left-most one is used; in that case, one ends up with + `gpu + cpu = gpu` while `cpu + gpu = cpu`. In both cases, a + warning will be emitted to prevent silent un-optimized copies. + - When deserializing a `TorchVector` (either by directly using + `TorchVector.unpack` or loading one from a JSON dump), loaded + tensors are placed based on the global device-placement policy + (accessed via `declearn.utils.get_device_policy`). Thus it may + have a different device-placement schema than at dump time but + should be coherent with that of `TorchModel` computations. """ @property @@ -73,12 +95,20 @@ class TorchVector(Vector): other: Any, func: Callable[[Any, Any], Any], ) -> Self: + # Convert 'other' NumpyVector to a (CPU-backed) TorchVector. if isinstance(other, NumpyVector): # false-positive; pylint: disable=no-member coefs = { key: torch.from_numpy(val) for key, val in other.coefs.items() } other = TorchVector(coefs) + # Ensure 'other' TorchVector shares this vector's device placement. 
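+        # (Each coefficient of 'other' is moved to the device of the
+        # matching 'self' coefficient, preserving left-most placement.)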
+ if isinstance(other, TorchVector): + coefs = { + key: val.to(self.coefs[key].device) + for key, val in other.coefs.items() + } + other = TorchVector(coefs) return super()._apply_operation(other, func) def dtypes( @@ -95,15 +125,20 @@ class TorchVector(Vector): def pack( self, ) -> Dict[str, Any]: - return {key: tns.numpy() for key, tns in self.coefs.items()} + return {key: tns.cpu().numpy() for key, tns in self.coefs.items()} @classmethod def unpack( cls, data: Dict[str, Any], ) -> Self: - # false-positive; pylint: disable=no-member - coefs = {key: torch.from_numpy(dat) for key, dat in data.items()} + policy = get_device_policy() + device = select_device(gpu=policy.gpu, idx=policy.idx) + coefs = { + # false-positive on `torch.from_numpy`; pylint: disable=no-member + key: torch.from_numpy(dat).to(device) + for key, dat in data.items() + } return cls(coefs) def __eq__( @@ -115,8 +150,9 @@ class TorchVector(Vector): valid = self.coefs.keys() == other.coefs.keys() if valid: valid = all( - np.array_equal(self.coefs[k].numpy(), other.coefs[k].numpy()) - for k in self.coefs + # false-positive on 'torch.equal'; pylint: disable=no-member + torch.equal(tns, other.coefs[key].to(tns.device)) + for key, tns in self.coefs.items() ) return valid diff --git a/declearn/model/torch/utils/__init__.py b/declearn/model/torch/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..225118fa3cc035d9bac36542a59a389906587a6e --- /dev/null +++ b/declearn/model/torch/utils/__init__.py @@ -0,0 +1,27 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils for torch backend support code. + +GPU/CPU backing device management utils: +* AutoDeviceModule: + Wrapper for a `torch.nn.Module`, automating device-management. +* select_device: + Select a backing device to use based on inputs and availability. +""" + +from ._gpu import AutoDeviceModule, select_device diff --git a/declearn/model/torch/utils/_gpu.py b/declearn/model/torch/utils/_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..59ed28dc340721d5e92d0e65a851592f4c18fffc --- /dev/null +++ b/declearn/model/torch/utils/_gpu.py @@ -0,0 +1,140 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utils for GPU support and device management in torch.""" + +import warnings +from typing import Any, Optional + +import torch + + +__all__ = [ + "AutoDeviceModule", + "select_device", +] + + +def select_device( + gpu: bool, + idx: Optional[int] = None, +) -> torch.device: # pylint: disable=no-member + """Select a backing device to use based on inputs and availability. + + Parameters + ---------- + gpu: bool + Whether to select a GPU device rather than the CPU one. + idx: int or None, default=None + Optional pre-selected GPU device index. Only used when `gpu=True`. + If `idx is None` or exceeds the number of available GPU devices, + use `torch.cuda.current_device()`. + + Warns + ----- + UserWarning: + If `gpu=True` but no GPU is available. + If `idx` exceeds the number of available GPU devices. + + Returns + ------- + device: torch.device + Selected torch device, with type "cpu" or "cuda". + """ + # Case when instructed to use the CPU device. + if not gpu: + return torch.device("cpu") # pylint: disable=no-member + # Case when no GPU is available: warn and use the CPU instead. + if gpu and not torch.cuda.is_available(): + warnings.warn( + "Cannot use a GPU device: either CUDA is unavailable " + "or no GPU is visible to torch." + ) + return torch.device("cpu") # pylint: disable=no-member + # Case when the desired GPU is invalid: select another one. + if (idx or 0) >= torch.cuda.device_count(): + warnings.warn( + f"Cannot use GPU device n°{idx}: index is out-of-range.\n" + f"Using GPU device n°{torch.cuda.current_device()} instead." + ) + idx = None + # Return the selected or auto-selected GPU device index. + if idx is None: + idx = torch.cuda.current_device() + return torch.device("cuda", index=idx) # pylint: disable=no-member + + +class AutoDeviceModule(torch.nn.Module): + """Wrapper for a `torch.nn.Module`, automating device-management. + + This `torch.nn.Module` subclass enables wrapping another one, and + provides: + * a `device` attribute (and instantiation parameter) indicating + where the wrapped module is placed + * automatic placement of input tensors on that device as part of + `forward` calls to the module + * a `set_device` method to change the device and move the wrapped + module to it + + This aims at internalizing device-management boilerplate code. + The wrapped module is assigned to the `module` attribute and thus + can be accessed directly. + """ + + def __init__( + self, + module: torch.nn.Module, + device: torch.device, # pylint: disable=no-member + ) -> None: + """Wrap a torch Module into an AutoDeviceModule. + + Parameters + ---------- + module: torch.nn.Module + Torch module that needs wrapping. + device: torch.device + Torch device where to place the wrapped module and computations. + """ + super().__init__() + self.device = device + self.module = module.to(self.device) + + def forward(self, *inputs: Any) -> torch.Tensor: + """Run the forward computation, automating device-placement of inputs. + + Please refer to `self.module.forward` for details on the wrapped + module's forward specifications. + """ + inputs = tuple( + x.to(self.device) if isinstance(x, torch.Tensor) else x + for x in inputs + ) + return self.module(*inputs) + + def set_device( + self, + device: torch.device, # pylint: disable=no-member + ) -> None: + """Move the wrapped module to a pre-selected torch device. + + Parameters + ---------- + device: torch.device + Torch device where to place the wrapped module and computations. 
+ """ + self.device = device + self.module.to(device) diff --git a/declearn/test_utils/_vectors.py b/declearn/test_utils/_vectors.py index d789cec4933795a0364b3083164ff449b9f79c57..27f821674acec621db1fbf14d569d0c9505b55ea 100644 --- a/declearn/test_utils/_vectors.py +++ b/declearn/test_utils/_vectors.py @@ -87,7 +87,8 @@ class GradientsTestCase: return array if self.framework == "tensorflow": tensorflow = importlib.import_module("tensorflow") - return tensorflow.convert_to_tensor(array) + with tensorflow.device("CPU"): + return tensorflow.convert_to_tensor(array) if self.framework == "torch": torch = importlib.import_module("torch") return torch.from_numpy(array) diff --git a/declearn/utils/__init__.py b/declearn/utils/__init__.py index 00bf9fc14932fef61b6eaa4ab9b8aefeb404374c..bee5e57cbbe3f30f9b6acd2753ad9096c3c25ab8 100644 --- a/declearn/utils/__init__.py +++ b/declearn/utils/__init__.py @@ -67,6 +67,17 @@ And examples of pre-registered (de)serialization functions: * (deserialize_numpy, serialize_numpy): Pair of functions to (un)pack a numpy ndarray as JSON-serializable data. +Device-policy utils +------------------- +Utils to access or update parameters defining a global device-selection policy. + +* DevicePolicy: + Dataclass to store parameters defining a device-selection policy. +* get_device_policy: + Access a copy of the current global device policy. +* set_device_policy: + Update the current global device policy. + Miscellaneous ------------- @@ -84,6 +95,11 @@ from ._dataclass import ( dataclass_from_func, dataclass_from_init, ) +from ._device_policy import ( + DevicePolicy, + get_device_policy, + set_device_policy, +) from ._json import ( add_json_support, json_dump, diff --git a/declearn/utils/_device_policy.py b/declearn/utils/_device_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..f26ea9653d4018d88c23e0020fb9aee38b968a72 --- /dev/null +++ b/declearn/utils/_device_policy.py @@ -0,0 +1,122 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to define a computation device policy. + +This private submodule defines: +* A dataclass defining a standard to hold a device-selection policy. +* A private global variable holding the current package-wise device policy. +* A public pair of functions acting as a getter and a setter for that variable. +""" + +import dataclasses +from typing import Optional + + +__all__ = [ + "DevicePolicy", + "get_device_policy", + "set_device_policy", +] + + +@dataclasses.dataclass +class DevicePolicy: + """Dataclass to store parameters defining a device-selection policy. + + This class merely defines a shared language of keyword-arguments to + define whether to back computations on CPU or on a GPU device. + + It is meant to be instantiated as a global variable that holds that + information, and can be accessed by framework-specific backend code + so as to take the required steps towards implementing that policy. 
+
+    To access or update the current global DevicePolicy, please use the
+    getter and setter functions: `declearn.utils.get_device_policy` and
+    `declearn.utils.set_device_policy`.
+
+    Attributes
+    ----------
+    gpu: bool
+        Whether to use a GPU device rather than the CPU one to back data
+        and computations. If no GPU is available, use CPU with a warning.
+    idx: int or None
+        Optional index of the GPU device to use.
+        If None, select one arbitrarily.
+        If this index exceeds the number of available GPUs, select one
+        arbitrarily, with a warning.
+    """
+
+    gpu: bool
+    idx: Optional[int]
+
+    def __post_init__(self) -> None:
+        if not isinstance(self.gpu, bool):
+            raise TypeError(
+                f"DevicePolicy 'gpu' should be a bool, not '{type(self.gpu)}'."
+            )
+        if not (self.idx is None or isinstance(self.idx, int)):
+            raise TypeError(
+                "DevicePolicy 'idx' should be None or an int, not "
+                f"'{type(self.idx)}'."
+            )
+
+
+DEVICE_POLICY = DevicePolicy(gpu=True, idx=None)
+
+
+def get_device_policy() -> DevicePolicy:
+    """Return a copy of the current global device policy.
+
+    This method is meant to be used:
+    - By end-users that wish to check the current device policy.
+    - By the backend code of framework-specific objects so as to
+      take the required steps towards implementing that policy.
+
+    To update the current policy, use `declearn.utils.set_device_policy`.
+
+    Returns
+    -------
+    policy: DevicePolicy
+        DevicePolicy dataclass instance, wrapping parameters that specify
+        the device policy to be enforced by Model and Vector to properly
+        place data and computations.
+    """
+    return DevicePolicy(**dataclasses.asdict(DEVICE_POLICY))
+
+
+def set_device_policy(
+    gpu: bool,
+    idx: Optional[int] = None,
+) -> None:
+    """Update the current global device policy.
+
+    To access the current policy, use `declearn.utils.get_device_policy`.
+
+    Parameters
+    ----------
+    gpu: bool
+        Whether to use a GPU device rather than the CPU one to back data
+        and computations. If no GPU is available, use CPU with a warning.
+    idx: int or None, default=None
+        Optional index of the GPU device to use.
+        If this index exceeds the number of available GPUs, select one
+        arbitrarily, with a warning.
+    """
+    # Using a global statement to have a proper setter to a private variable.
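+    # Note: DevicePolicy.__post_init__ type-checks the inputs, so invalid
+    # (gpu, idx) arguments raise a TypeError upon setting the policy.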
+    global DEVICE_POLICY  # pylint: disable=global-statement
+    DEVICE_POLICY = DevicePolicy(gpu, idx)
diff --git a/test/model/model_testing.py b/test/model/model_testing.py
index 3d632814f678f506a81a25dc7f384141fcbf55df..9936788969cecc3eef00913bcc2f7e0659a8eee8 100644
--- a/test/model/model_testing.py
+++ b/test/model/model_testing.py
@@ -18,7 +18,7 @@
 """Shared testing code for TensorFlow and Torch models' unit tests."""
 
 import json
-from typing import Any, List, Protocol, Tuple, Type, Union
+from typing import Any, Generic, List, Protocol, Tuple, Type, TypeVar, Union
 
 import numpy as np
 
@@ -27,10 +27,13 @@ from declearn.typing import Batch
 from declearn.utils import json_pack, json_unpack
 
 
-class ModelTestCase(Protocol):
+VectorT = TypeVar("VectorT", bound=Vector)
+
+
+class ModelTestCase(Protocol, Generic[VectorT]):
     """TestCase fixture-provider protocol."""
 
-    vector_cls: Type[Vector]
+    vector_cls: Type[VectorT]
     tensor_cls: Union[Type[Any], Tuple[Type[Any], ...]]
 
     @staticmethod
@@ -51,6 +54,12 @@ class ModelTestCase(Protocol):
     ) -> Model:
         """Suited toy binary-classification model."""
 
+    def assert_correct_device(
+        self,
+        vector: VectorT,
+    ) -> None:
+        """Raise if a vector is backed on the wrong type of device."""
+
 
 class ModelTestSuite:
     """Unit tests for a declearn Model."""
@@ -64,6 +73,7 @@ class ModelTestSuite:
         config = json.dumps(model.get_config())
         other = model.from_config(json.loads(config))
         assert model.get_config() == other.get_config()
+        assert model.device_policy == other.device_policy
 
     def test_get_set_weights(
         self,
@@ -75,7 +85,12 @@ class ModelTestSuite:
         assert isinstance(w_srt, test_case.vector_cls)
         w_end = w_srt + 1.0
         model.set_weights(w_end)
-        assert model.get_weights() == w_end
+        w_upd = model.get_weights()
+        assert w_upd == w_end
+        # Check that weight tensors are properly placed.
+        test_case.assert_correct_device(w_srt)
+        test_case.assert_correct_device(w_end)
+        test_case.assert_correct_device(w_upd)
 
     def test_compute_batch_gradients(
         self,
@@ -113,7 +128,13 @@ class ModelTestSuite:
         np_grads = model.compute_batch_gradients(np_batch)  # type: ignore
         assert isinstance(np_grads, test_case.vector_cls)
         my_grads = model.compute_batch_gradients(my_batch)
-        assert my_grads == np_grads
+        # Allow for a numerical imprecision of 10^-8.
+        diff = my_grads - np_grads
+        max_err = max(
+            np.abs(test_case.to_numpy(weight)).max()
+            for weight in diff.coefs.values()
+        )
+        assert max_err < 1e-8
 
     def test_compute_batch_gradients_clipped(
         self,
@@ -137,6 +158,9 @@ class ModelTestSuite:
             for k in grads_a.coefs
         )
         assert grads_a != grads_b
+        # Check that gradients are properly placed.
+        test_case.assert_correct_device(grads_a)
+        test_case.assert_correct_device(grads_b)
 
     def test_apply_updates(
         self,
@@ -157,9 +181,14 @@ class ModelTestSuite:
         # NOTE: if the model had frozen weights, this test would xfail.
         w_end = model.get_weights()
         assert w_end != w_srt
-        updt = [test_case.to_numpy(val) for val in grads.coefs.values()]
-        diff = list((w_end - w_srt).coefs.values())
-        assert all(np.abs(a - b).max() < 1e-6 for a, b in zip(diff, updt))
+        diff = (w_end - w_srt) - grads
+        assert all(
+            np.abs(test_case.to_numpy(weight)).max() < 1e-6
+            for weight in diff.coefs.values()
+        )
+        # Check that gradients and updated weights are properly placed.
+ test_case.assert_correct_device(grads) + test_case.assert_correct_device(w_end) def test_serialize_gradients( self, diff --git a/test/model/test_sksgd.py b/test/model/test_sksgd.py index b214f28ac1c0fb595365dec187ca8ac94d52ed71..e53febae8f98d3fd43fbabc61695e9ae366c054b 100644 --- a/test/model/test_sksgd.py +++ b/test/model/test_sksgd.py @@ -106,6 +106,12 @@ class SklearnSGDTestCase(ModelTestCase): model.initialize(data_info) return model + def assert_correct_device( + self, + vector: NumpyVector, + ) -> None: + pass + @pytest.fixture(name="test_case") def fixture_test_case( diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py index aafd5d5d567ff37baba544ab8ff34904a529674d..ba1193f6d19e2e15da5dc98a27291bf5fc0722a3 100644 --- a/test/model/test_tflow.py +++ b/test/model/test_tflow.py @@ -33,6 +33,7 @@ except ModuleNotFoundError: from declearn.model.tensorflow import TensorflowModel, TensorflowVector from declearn.typing import Batch +from declearn.utils import set_device_policy # dirty trick to import from `model_testing.py`; # pylint: disable=wrong-import-order, wrong-import-position @@ -67,11 +68,16 @@ class TensorflowTestCase(ModelTestCase): def __init__( self, kind: Literal["MLP", "RNN", "CNN"], + device: Literal["CPU", "GPU"], ) -> None: """Specify the desired model architecture.""" if kind not in ("MLP", "RNN", "CNN"): raise ValueError(f"Invalid keras test architecture: '{kind}'.") + if device not in ("CPU", "GPU"): + raise ValueError(f"Invalid device choice for test: '{device}'.") self.kind = kind + self.device = device + set_device_policy(gpu=(device == "GPU"), idx=0) @staticmethod def to_numpy( @@ -79,7 +85,7 @@ class TensorflowTestCase(ModelTestCase): ) -> np.ndarray: """Convert an input tensor to a numpy array.""" assert isinstance(tensor, tf.Tensor) - return tensor.numpy() # type: ignore + return tensor.numpy() @property def dataset( @@ -132,16 +138,47 @@ class TensorflowTestCase(ModelTestCase): tfmod.build(shape) # as model is built, no data_info is required return TensorflowModel(tfmod, loss="binary_crossentropy", metrics=None) + def assert_correct_device( + self, + vector: TensorflowVector, + ) -> None: + """Raise if a vector is backed on the wrong type of device.""" + name = f"{self.device}:0" + assert all( + tensor.device.endswith(name) for tensor in vector.coefs.values() + ) + @pytest.fixture(name="test_case") def fixture_test_case( - kind: Literal["MLP", "RNN", "CNN"] + kind: Literal["MLP", "RNN", "CNN"], + device: Literal["CPU", "GPU"], ) -> TensorflowTestCase: """Fixture to access a TensorflowTestCase.""" - return TensorflowTestCase(kind) + return TensorflowTestCase(kind, device) + +DEVICES = ["CPU"] +if tf.config.list_logical_devices("GPU"): + DEVICES.append("GPU") + +@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("kind", ["MLP", "RNN", "CNN"]) @pytest.mark.filterwarnings("ignore:.*randrange.*:DeprecationWarning") class TestTensorflowModel(ModelTestSuite): """Unit tests for declearn.model.tensorflow.TensorflowModel.""" + + def test_proper_model_placement( + self, + test_case: TensorflowTestCase, + ) -> None: + """Check that at instantiation, model weights are properly placed.""" + model = test_case.model + policy = model.device_policy + assert policy.gpu == (test_case.device == "GPU") + assert policy.idx == 0 + tfmod = getattr(model, "_model") + device = f"{test_case.device}:0" + for var in tfmod.weights: + assert var.device.endswith(device) diff --git a/test/model/test_torch.py b/test/model/test_torch.py index 
index 18b6418a632182262e040ecada563e3cc958ebee..5a6729d59bd53e0e69273ddf5ad936776ddfc78f 100644
--- a/test/model/test_torch.py
+++ b/test/model/test_torch.py
@@ -17,6 +17,7 @@
 
 """Unit tests for TorchModel."""
 
+import json
 import sys
 from typing import Any, List, Literal, Tuple
 
@@ -30,6 +31,7 @@ except ModuleNotFoundError:
 
 from declearn.model.torch import TorchModel, TorchVector
 from declearn.typing import Batch
+from declearn.utils import set_device_policy
 
 # dirty trick to import from `model_testing.py`;
 # pylint: disable=wrong-import-order, wrong-import-position
@@ -84,11 +86,14 @@ class TorchTestCase(ModelTestCase):
     def __init__(
         self,
         kind: Literal["MLP", "RNN", "CNN"],
+        device: Literal["CPU", "GPU"],
     ) -> None:
         """Specify the desired model architecture."""
         if kind not in ("MLP", "RNN", "CNN"):
             raise ValueError(f"Invalid torch test architecture: '{kind}'.")
         self.kind = kind
+        self.device = device
+        set_device_policy(gpu=(device == "GPU"), idx=0)
 
     @staticmethod
     def to_numpy(
@@ -96,7 +101,7 @@
     ) -> np.ndarray:
         """Convert an input tensor to a numpy array."""
         assert isinstance(tensor, torch.Tensor)
-        return tensor.numpy()  # type: ignore
+        return tensor.cpu().numpy()
 
     @property
     def dataset(
@@ -133,7 +138,7 @@
         elif self.kind == "RNN":
             stack = [
                 torch.nn.Embedding(100, 32),
-                torch.nn.LSTM(32, 16, batch_first=True),  # type: ignore
+                torch.nn.LSTM(32, 16, batch_first=True),
                 ExtractLSTMFinalOutput(),
                 torch.nn.Tanh(),
                 torch.nn.Linear(16, 1),
@@ -154,13 +159,32 @@
         nnmod = torch.nn.Sequential(*stack)
         return TorchModel(nnmod, loss=torch.nn.BCELoss())
 
+    def assert_correct_device(
+        self,
+        vector: TorchVector,
+    ) -> None:
+        """Raise if a vector is backed on the wrong type of device."""
+        dev_type = "cuda" if self.device == "GPU" else "cpu"
+        assert all(
+            tensor.device.type == dev_type for tensor in vector.coefs.values()
+        )
+
 
 @pytest.fixture(name="test_case")
-def fixture_test_case(kind: Literal["MLP", "RNN", "CNN"]) -> TorchTestCase:
+def fixture_test_case(
+    kind: Literal["MLP", "RNN", "CNN"],
+    device: Literal["CPU", "GPU"],
+) -> TorchTestCase:
     """Fixture to access a TorchTestCase."""
-    return TorchTestCase(kind)
+    return TorchTestCase(kind, device)
+
+
+DEVICES = ["CPU"]
+if torch.cuda.device_count():
+    DEVICES.append("GPU")
 
 
+@pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.parametrize("kind", ["MLP", "RNN", "CNN"])
 class TestTorchModel(ModelTestSuite):
     """Unit tests for declearn.model.torch.TorchModel."""
@@ -175,12 +199,33 @@
         # due to the (de)serialization of a custom nn.Module
         # the expected model behaviour is, however, correct
         try:
-            super().test_serialization(test_case)
+            self._test_serialization(test_case)
         except AssertionError:
             pytest.skip(
                 "skipping failed test due to custom nn.Module pickling"
             )
-        super().test_serialization(test_case)
+        self._test_serialization(test_case)
+
+    def _test_serialization(
+        self,
+        test_case: ModelTestCase,
+    ) -> None:
+        """Check that the model can be JSON-(de)serialized properly.
+
+        This method replaces the parent `test_serialization` one.
+        """
+        # Same setup as in parent test: a model and a config-based other.
+        model = test_case.model
+        config = json.dumps(model.get_config())
+        other = model.from_config(json.loads(config))
+        # Verify that both models have the same device policy.
+        assert model.device_policy == other.device_policy
+        # Verify that both models have a similar structure of modules.
+        mod_a = list(getattr(model, "_model").modules())
+        mod_b = list(getattr(other, "_model").modules())
+        assert len(mod_a) == len(mod_b)
+        assert all(isinstance(a, type(b)) for a, b in zip(mod_a, mod_b))
+        assert all(repr(a) == repr(b) for a, b in zip(mod_a, mod_b))
 
     def test_compute_batch_gradients_clipped(
         self,
@@ -195,3 +240,17 @@
             )
         else:
             super().test_compute_batch_gradients_clipped(test_case)
+
+    def test_proper_model_placement(
+        self,
+        test_case: TorchTestCase,
+    ) -> None:
+        """Check that at instantiation, model weights are properly placed."""
+        model = test_case.model
+        policy = model.device_policy
+        assert policy.gpu == (test_case.device == "GPU")
+        assert (policy.idx == 0) if policy.gpu else (policy.idx is None)
+        ptmod = getattr(model, "_model").module
+        device_type = "cpu" if test_case.device == "CPU" else "cuda"
+        for param in ptmod.parameters():
+            assert param.device.type == device_type
diff --git a/test/model/test_vector.py b/test/model/test_vector.py
index 6374ad409370145a4f6b2ca4b8a953f3bf9f777e..ca5592a76d25067569903d78b652f202b70fd783 100644
--- a/test/model/test_vector.py
+++ b/test/model/test_vector.py
@@ -32,7 +32,10 @@ from declearn.test_utils import (
     GradientsTestCase,
     list_available_frameworks,
 )
-from declearn.utils import json_pack, json_unpack
+from declearn.utils import json_pack, json_unpack, set_device_policy
+
+
+set_device_policy(gpu=False)  # run Vector unit tests on CPU only
 
 
 @pytest.fixture(name="framework", params=list_available_frameworks())
diff --git a/test/optimizer/test_modules.py b/test/optimizer/test_modules.py
index 0bbde31ad6746bd94b2cc631c4d2b0a991645682..ffd1e3132ffb64691b04ebdd03cdad4bd5d4e355 100644
--- a/test/optimizer/test_modules.py
+++ b/test/optimizer/test_modules.py
@@ -44,7 +44,7 @@ from declearn.test_utils import (
     GradientsTestCase,
     assert_json_serializable_dict,
 )
-from declearn.utils import access_types_mapping
+from declearn.utils import access_types_mapping, set_device_policy
 
 # relative import; pylint: disable=wrong-import-order, wrong-import-position
 # fmt: off
@@ -56,6 +56,8 @@ sys.path.pop()
 
 OPTIMODULE_SUBCLASSES = access_types_mapping(group="OptiModule")
 
+set_device_policy(gpu=False)  # run all OptiModule tests on CPU
+
 
 @pytest.mark.parametrize(
     "cls", OPTIMODULE_SUBCLASSES.values(), ids=OPTIMODULE_SUBCLASSES.keys()