diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py index bc73d25c0baf4de2dd2af6f7a717132c8190ba1d..401bbb0665ae95b80a78f0bbb55fa90feb16a3af 100644 --- a/declearn/model/tensorflow/_vector.py +++ b/declearn/model/tensorflow/_vector.py @@ -17,7 +17,8 @@ """TensorflowVector data arrays container.""" -from typing import Any, Callable, Dict, Optional, Set, Type, Union +import warnings +from typing import Any, Callable, Dict, Optional, Set, Type, TypeVar, Union # fmt: off import numpy as np @@ -31,6 +32,7 @@ from typing_extensions import Self # future: import from typing (Py>=3.11) from declearn.model.api import Vector, register_vector_type from declearn.model.sklearn import NumpyVector from declearn.model.tensorflow.utils import ( + add_indexed_slices_support, preserve_tensor_device, select_device, ) @@ -42,6 +44,33 @@ __all__ = [ ] +TensorT = TypeVar("TensorT", tf.Tensor, tf.IndexedSlices) + + +def enhance_tf_op( + tf_op: Callable[[tf.Tensor, Any], tf.Tensor], + inplc: bool = False, +) -> Callable[[TensorT, Any], TensorT]: + """Wrap up a tensorflow operation to preserve IndexedSlices and device.""" + func = add_indexed_slices_support(preserve_tensor_device(tf_op), inplc) + setattr(func, "_pre_wrapped", True) + return func + + +# Wrap up base tensorflow operations to add support for IndexedSlices +# inputs and preserve tensor's device-placement +tf_op_add = enhance_tf_op(tf.add) +tf_op_sub = enhance_tf_op(tf.subtract) +tf_op_mul = enhance_tf_op(tf.multiply) +tf_op_div = enhance_tf_op(tf.truediv) +tf_op_pow = enhance_tf_op(tf.pow) +tf_op_min = enhance_tf_op(tf.minimum) +tf_op_max = enhance_tf_op(tf.maximum) +tf_op_sign = enhance_tf_op(tf.sign, inplc=True) +tf_op_sqre = enhance_tf_op(tf.square, inplc=True) +tf_op_sqrt = enhance_tf_op(tf.sqrt, inplc=True) + + @register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices) class TensorflowVector(Vector): """Vector subclass to store tensorflow tensors. @@ -52,9 +81,14 @@ class TensorflowVector(Vector): two TensorflowVector with similar specifications). Note that support for IndexedSlices is implemented, as these are a - common type for auto-differentiated gradients. - - Note that this class does not (yet?) support special tensor types + common type for auto-differentiated gradients. When using built-in + operators and methods, these structures will be preserved, unless + densification is required (e.g. when summing with a dense tensor). + When using `TensorflowVector.apply_func` directly, support for the + IndexedSlices' preservation should be added manually, typically by + using `declearn.model.tensorflow.utils.add_indexed_slices_support`. + + Note that this class does not currently support special tensor types such as SparseTensor or RaggedTensor. Use `vector.coefs` to access the stored coefficients. @@ -84,23 +118,23 @@ class TensorflowVector(Vector): @property def _op_add(self) -> Callable[[Any, Any], Any]: - return tf.add + return tf_op_add @property def _op_sub(self) -> Callable[[Any, Any], Any]: - return tf.subtract + return tf_op_sub @property def _op_mul(self) -> Callable[[Any, Any], Any]: - return tf.multiply + return tf_op_mul @property def _op_div(self) -> Callable[[Any, Any], Any]: - return tf.divide + return tf_op_div @property def _op_pow(self) -> Callable[[Any, Any], Any]: - return tf.pow + return tf_op_pow @property def compatible_vector_types(self) -> Set[Type[Vector]]: @@ -118,7 +152,8 @@ class TensorflowVector(Vector): *args: Any, **kwargs: Any, ) -> Self: - func = preserve_tensor_device(func) + if not getattr(func, "_pre_wrapped", False): + func = preserve_tensor_device(func) return super().apply_func(func, *args, **kwargs) def _apply_operation( @@ -126,7 +161,8 @@ class TensorflowVector(Vector): other: Any, func: Callable[[Any, Any], Any], ) -> Self: - func = preserve_tensor_device(func) + if not getattr(func, "_pre_wrapped", False): + func = preserve_tensor_device(func) return super()._apply_operation(other, func) def dtypes( @@ -160,7 +196,8 @@ class TensorflowVector(Vector): if isinstance(tensor, tf.IndexedSlices): val = cls._pack_tensor(tensor.values) ind = cls._pack_tensor(tensor.indices) - return ["slices", val, ind] + shp = cls._pack_tensor(tensor.dense_shape) + return ["slices", val, ind, shp] return np.array(tensor.numpy()) @classmethod @@ -172,7 +209,8 @@ class TensorflowVector(Vector): if isinstance(data, list) and (data[0] == "slices"): val = cls._unpack_tensor(data[1]) ind = cls._unpack_tensor(data[2]) - return tf.IndexedSlices(val, ind) + shp = cls._unpack_tensor(data[3]) + return tf.IndexedSlices(val, ind, shp) try: return tf.convert_to_tensor(data) except TypeError as exc: @@ -209,29 +247,40 @@ class TensorflowVector(Vector): return tf.reduce_all(t_a == t_b).numpy() def sign(self) -> Self: - return self.apply_func(tf.sign) + return self.apply_func(tf_op_sign) def minimum( self, other: Any, ) -> Self: if isinstance(other, Vector): - return self._apply_operation(other, tf.minimum) - return self.apply_func(tf.minimum, other) + return self._apply_operation(other, tf_op_min) + return self.apply_func(tf_op_min, other) def maximum( self, other: Any, ) -> Self: if isinstance(other, Vector): - return self._apply_operation(other, tf.maximum) - return self.apply_func(tf.maximum, other) + return self._apply_operation(other, tf_op_max) + return self.apply_func(tf_op_max, other) def sum( self, axis: Optional[int] = None, keepdims: bool = False, ) -> Self: + if keepdims or (axis is not None): + if any( + isinstance(x, tf.IndexedSlices) for x in self.coefs.values() + ): + warnings.warn( + "Calling `TensorflowVector.sum()` with non-default " + "arguments and tf.IndexedSlices coefficients might " + "result in unexpected outputs, due to the latter " + "being converted to their dense counterpart.", + category=RuntimeWarning, + ) return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims) def __pow__( @@ -242,7 +291,7 @@ class TensorflowVector(Vector): # than tf.pow as results tend to differ for small values. if isinstance(other, (int, float)): if other == 2: - return self.apply_func(tf.square) + return self.apply_func(tf_op_sqre) if other == 0.5: - return self.apply_func(tf.sqrt) + return self.apply_func(tf_op_sqrt) return super().__pow__(other) diff --git a/declearn/model/tensorflow/utils/__init__.py b/declearn/model/tensorflow/utils/__init__.py index 5f7d66b1d29e47c7b07dc8b10f1b6e9d11187896..3c376866a9307db883ccdf79c325fed2f368797e 100644 --- a/declearn/model/tensorflow/utils/__init__.py +++ b/declearn/model/tensorflow/utils/__init__.py @@ -28,6 +28,10 @@ GPU/CPU backing device management utils: Loss function management utils: * build_keras_loss: Type-check, deserialize and/or wrap a keras loss into a Loss object. + +Better support for sparse tensor structures: +* add_indexed_slices_support: + Run a function on a pair of tensors, adding support for IndexedSlices. """ from ._gpu import ( @@ -36,3 +40,4 @@ from ._gpu import ( select_device, ) from ._loss import build_keras_loss +from ._slices import add_indexed_slices_support diff --git a/declearn/model/tensorflow/utils/_slices.py b/declearn/model/tensorflow/utils/_slices.py new file mode 100644 index 0000000000000000000000000000000000000000..a4c43959b0c7ca0705901a2a85ec7dabbe8fcf3a --- /dev/null +++ b/declearn/model/tensorflow/utils/_slices.py @@ -0,0 +1,137 @@ +# coding: utf-8 + +# Copyright 2023 Inria (Institut National de Recherche en Informatique +# et Automatique) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utils to handle tf.IndexedSlices as part of tensor-processing operations.""" + +import functools +import warnings +from typing import Any, Callable, TypeVar + +import numpy as np +import tensorflow as tf # type: ignore + + +__all__ = [ + "add_indexed_slices_support", +] + + +TensorT = TypeVar("TensorT", tf.Tensor, tf.IndexedSlices) + + +def apply_func_to_tensor_or_slices( + first: TensorT, + other: Any, + tf_op: Callable[[tf.Tensor, Any], tf.Tensor], +) -> TensorT: + """Run a function on a pair of tensors, adding support for IndexedSlices. + + The intended use of this function is to add support for IndexedSlices + to basic tensorflow operators, such as `tf.add` or `tf.multiply`. + + Parameters + ---------- + first: tf.Tensor or tf.IndexedSlices + Tensor or IndexedSlices data structure. + other: tf.Tensor or tf.IndexedSlices or np.ndarray or float or int + Scalar or data array that needs operating onto `first` via `tf_op`. + tf_op: function(tf.Tensor, any) -> tf.Tensor + Function that operates on a tf.Tensor and another value. + + Returns + ------- + output: tf.Tensor or tf.IndexedSlices + Result from running `tf_op(first, other)` if first is a tf.Tensor, + or from re-wrapping `tf_op(first.values, other[.values])` into a + tf.IndexedSlices structure if first is such a structure. The only + exception is when operating on IndexedSlices and a full-rank array + or tensor: then a full-rank output is returned, with a warning. + + Raises + ------ + TypeError: + If `first` and `other` are two tf.IndexedSlices with different + shapes or non-zero indices. + If `first` is a tf.IndexedSlices and `func` failed on its values. + """ + slice_inp = isinstance(first, tf.IndexedSlices) + # Case when combining two IndexedSlices objects. + if slice_inp and isinstance(other, tf.IndexedSlices): + if ( + (first.dense_shape.ndim == other.dense_shape.ndim) + and tf.reduce_all(first.dense_shape == other.dense_shape) + and (first.indices.shape == other.indices.shape) + and tf.reduce_all(first.indices == other.indices) + ): + values = tf_op(first.values, other.values) + return tf.IndexedSlices(values, first.indices, first.dense_shape) + raise TypeError( + f"Cannot apply function {tf_op.__name__} to two IndexedSlices " + "structures with different shapes or indices." + ) + # Case when operating into an IndexedSlices object. + if slice_inp: + # Case when operating with a dense tensor (or array) of same shape. + if isinstance(other, (tf.Tensor, np.ndarray)): + if first.shape == other.shape: + warnings.warn( + f"Applying function {tf_op.__name__} to IndexSlices with " + "a full-rank array or tensor results in densifying it.", + RuntimeWarning, + ) + return tf_op(tf.convert_to_tensor(first), other) + # Generic case (including mis-shaped tensor, to raise an error). + try: + values = tf_op(first.values, other) + except Exception as exc: + raise TypeError( + f"Failed to apply function {tf_op.__name__} to combine a " + f"{type(other)} object into an IndexedSlices tensor: {exc}." + ) from exc + return tf.IndexedSlices(values, first.indices, first.dense_shape) + # All other cases (including right-hand slices that will be converted). + return tf_op(first, other) + + +def add_indexed_slices_support( + tf_op: Callable[[tf.Tensor, Any], tf.Tensor], + inplc: bool = False, +) -> Callable[[TensorT, Any], TensorT]: + """Wrap an input function to overload the handling of tf.IndexedSlices. + + Parameters + ---------- + tf_op: function(tf.Tensor, [any]) -> tf.Tensor + Tensor-processing operation that needs wrapping. + inplc: bool, default=False + Whether to replace the second argument of `tf_op` with None. + Use this to transform tensor-processing functions (wich, in + general, have a `name=None` argument) rather than operations. + + Returns + ------- + func: function(<T>, any) -> <T>, with <T>:(tf.Tensor|tf.IndexedSlices) + Tensor-processing operation that wraps `tf_op` but supports and + preserves tf.IndexedSlices inputs as first (and opt. second) + argument. + Note that in the rare case when func(slices, dense) is called, + the output will be dense, and a RuntimeWarning will be raised. + """ + func = functools.partial(apply_func_to_tensor_or_slices, tf_op=tf_op) + if inplc: + func = functools.partial(func, other=None) + return functools.wraps(tf_op)(func) diff --git a/declearn/optimizer/modules/_noise.py b/declearn/optimizer/modules/_noise.py index 04785191068e1ad3defda614896ff26d0d70dfb4..f593a6bce116a0ef60bb51fc57c86bfb5d21660c 100644 --- a/declearn/optimizer/modules/_noise.py +++ b/declearn/optimizer/modules/_noise.py @@ -17,6 +17,7 @@ """Noise-addition modules for DP using cryptographically-strong RNG.""" +import warnings from abc import ABCMeta, abstractmethod from random import SystemRandom from typing import Any, ClassVar, Dict, Optional, Tuple @@ -94,7 +95,10 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False): for key in gradients.coefs } # Add the sampled noise to the gradients and return them. - return gradients + NumpyVector(noise) + # Silence warnings about sparse gradients getting sparsified. + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", ".*densifying.*", RuntimeWarning) + return gradients + NumpyVector(noise) @abstractmethod def _sample_noise( diff --git a/declearn/test_utils/_vectors.py b/declearn/test_utils/_vectors.py index 27f821674acec621db1fbf14d569d0c9505b55ea..7b7722c0cf6a8039932c3a84bd149c4f1b41d4d9 100644 --- a/declearn/test_utils/_vectors.py +++ b/declearn/test_utils/_vectors.py @@ -98,6 +98,11 @@ class GradientsTestCase: """Convert an input framework-based structure to a numpy array.""" if isinstance(array, np.ndarray): return array + if self.framework == "tensorflow": # add support for IndexedSlices + tensorflow = importlib.import_module("tensorflow") + if isinstance(array, tensorflow.IndexedSlices): + with tensorflow.device(array.device): + return tensorflow.convert_to_tensor(array).numpy() return array.numpy() # type: ignore @property @@ -111,9 +116,21 @@ class GradientsTestCase: rng = np.random.default_rng(self.seed) shapes = [(64, 32), (32,), (32, 16), (16,), (16, 1), (1,)] values = [rng.normal(size=shape) for shape in shapes] - return self.vector_cls( + vector = self.vector_cls( {str(idx): self.convert(value) for idx, value in enumerate(values)} ) + # In Tensorflow, convert the first gradients to IndexedSlices. + # In this case they are equivalent to dense ones, but this enables + # testing the support for these structures while maintaining the + # possibility to compare outputs' values with other frameworks. + if self.framework == "tensorflow": + tensorflow = importlib.import_module("tensorflow") + vector.coefs["0"] = tensorflow.IndexedSlices( + values=vector.coefs["0"], + indices=tensorflow.range(64), + dense_shape=tensorflow.convert_to_tensor([64, 32]), + ) + return vector @property def mock_ones(self) -> Vector: diff --git a/test/model/test_tflow.py b/test/model/test_tflow.py index 49ae73225dc910ee43652cf1d0cbce6847a335b8..aefd62e45489b2e264babbc69830fec68458147b 100644 --- a/test/model/test_tflow.py +++ b/test/model/test_tflow.py @@ -86,6 +86,8 @@ class TensorflowTestCase(ModelTestCase): tensor: Any, ) -> np.ndarray: """Convert an input tensor to a numpy array.""" + if isinstance(tensor, tf.IndexedSlices): + tensor = tf.convert_to_tensor(tensor) assert isinstance(tensor, tf.Tensor) return tensor.numpy() diff --git a/test/optimizer/test_tflow_optim.py b/test/optimizer/test_tflow_optim.py index 7ff29005a5b3c629878ee647f67d3108b5e8e45b..7475c1f447e632facf5780c83fca52f505b028a4 100644 --- a/test/optimizer/test_tflow_optim.py +++ b/test/optimizer/test_tflow_optim.py @@ -19,7 +19,7 @@ import sys import warnings -from typing import Iterator, Type +from typing import Iterator, Type, Union import numpy as np import pytest @@ -115,6 +115,13 @@ def fix_adam_epsilon( return module +def to_numpy(tensor: Union[tf.Tensor, tf.IndexedSlices]) -> np.ndarray: + """Convert a tensorflow Tensor or IndexedSlices to numpy.""" + if isinstance(tensor, tf.IndexedSlices): + return tensor.values.numpy() + return tensor.numpy() + + @pytest.fixture(name="framework") def framework_fixture(): """Fixture to ensure 'TensorflowOptiModule' only receives tf gradients.""" @@ -157,7 +164,7 @@ class TestTensorflowOptiModule(OptiModuleTestSuite): grads_dec = optim_dec.run(gradients).coefs # Assert that the output gradients are (nearly) identical. assert all( - np.allclose(grads_tfk[key].numpy(), grads_dec[key].numpy()) + np.allclose(to_numpy(grads_tfk[key]), to_numpy(grads_dec[key])) for key in gradients.coefs )