Mentions légales du service

Skip to content
Snippets Groups Projects
Commit ec4f3e17 authored by ANDREY Paul's avatar ANDREY Paul
Browse files

Enhance support for 'tf.IndexedSlices' in 'TensorflowVector'.

* Implement a public util to wrap tensorflow operations in order to preserve
  `tf.IndexedSlices` structures and run appropriate computations with them.
* Deploy the former wrapper to cover all usual operations in the backend of
  `TensorflowVector`.
* Refactor the use of the device-placement-handling wrapper in the backend of
  `TensorflowVector` to reduce runtime overheads and factor it with the new
  indexed-slices-handling wrapper.
parent 71debe68
No related branches found
No related tags found
1 merge request!33Enhance support for 'tf.IndexedSlices' in 'TensorflowVector'.
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
"""TensorflowVector data arrays container.""" """TensorflowVector data arrays container."""
from typing import Any, Callable, Dict, Optional, Set, Type, Union import warnings
from typing import Any, Callable, Dict, Optional, Set, Type, TypeVar, Union
# fmt: off # fmt: off
import numpy as np import numpy as np
...@@ -31,6 +32,7 @@ from typing_extensions import Self # future: import from typing (Py>=3.11) ...@@ -31,6 +32,7 @@ from typing_extensions import Self # future: import from typing (Py>=3.11)
from declearn.model.api import Vector, register_vector_type from declearn.model.api import Vector, register_vector_type
from declearn.model.sklearn import NumpyVector from declearn.model.sklearn import NumpyVector
from declearn.model.tensorflow.utils import ( from declearn.model.tensorflow.utils import (
add_indexed_slices_support,
preserve_tensor_device, preserve_tensor_device,
select_device, select_device,
) )
...@@ -42,6 +44,33 @@ __all__ = [ ...@@ -42,6 +44,33 @@ __all__ = [
] ]
TensorT = TypeVar("TensorT", tf.Tensor, tf.IndexedSlices)
def enhance_tf_op(
    tf_op: Callable[[tf.Tensor, Any], tf.Tensor],
    inplc: bool = False,
) -> Callable[[TensorT, Any], TensorT]:
    """Decorate a tensorflow operation for use by `TensorflowVector`.

    Parameters
    ----------
    tf_op: function(tf.Tensor, any) -> tf.Tensor
        Tensorflow operation that needs wrapping.
    inplc: bool, default=False
        Flag passed on to `add_indexed_slices_support`.

    Returns
    -------
    func: function(<T>, any) -> <T>, with <T>:(tf.Tensor|tf.IndexedSlices)
        `tf_op` wrapped first with `preserve_tensor_device`, then with
        `add_indexed_slices_support`, and carrying a `_pre_wrapped`
        attribute set to True.
    """
    device_aware = preserve_tensor_device(tf_op)
    wrapped = add_indexed_slices_support(device_aware, inplc)
    # Flag the output as already wrapped: `TensorflowVector.apply_func` and
    # `TensorflowVector._apply_operation` look this attribute up to avoid
    # applying `preserve_tensor_device` a second time.
    setattr(wrapped, "_pre_wrapped", True)
    return wrapped
# Wrap up base tensorflow operations once, at import time, to add support
# for IndexedSlices inputs and preserve tensors' device placement.
# Binary operations: called with a second operand.
tf_op_add = enhance_tf_op(tf.add)
tf_op_sub = enhance_tf_op(tf.subtract)
tf_op_mul = enhance_tf_op(tf.multiply)
tf_op_div = enhance_tf_op(tf.truediv)
tf_op_pow = enhance_tf_op(tf.pow)
tf_op_min = enhance_tf_op(tf.minimum)
tf_op_max = enhance_tf_op(tf.maximum)
# Unary functions: `inplc=True` is forwarded to `add_indexed_slices_support`,
# which documents it as replacing the second argument of the wrapped
# function with None.
tf_op_sign = enhance_tf_op(tf.sign, inplc=True)
tf_op_sqre = enhance_tf_op(tf.square, inplc=True)
tf_op_sqrt = enhance_tf_op(tf.sqrt, inplc=True)
@register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices) @register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices)
class TensorflowVector(Vector): class TensorflowVector(Vector):
"""Vector subclass to store tensorflow tensors. """Vector subclass to store tensorflow tensors.
...@@ -52,9 +81,14 @@ class TensorflowVector(Vector): ...@@ -52,9 +81,14 @@ class TensorflowVector(Vector):
two TensorflowVector with similar specifications). two TensorflowVector with similar specifications).
Note that support for IndexedSlices is implemented, as these are a Note that support for IndexedSlices is implemented, as these are a
common type for auto-differentiated gradients. common type for auto-differentiated gradients. When using built-in
operators and methods, these structures will be preserved, unless
Note that this class does not (yet?) support special tensor types densification is required (e.g. when summing with a dense tensor).
When using `TensorflowVector.apply_func` directly, support for the
IndexedSlices' preservation should be added manually, typically by
using `declearn.model.tensorflow.utils.add_indexed_slices_support`.
Note that this class does not currently support special tensor types
such as SparseTensor or RaggedTensor. such as SparseTensor or RaggedTensor.
Use `vector.coefs` to access the stored coefficients. Use `vector.coefs` to access the stored coefficients.
...@@ -84,23 +118,23 @@ class TensorflowVector(Vector): ...@@ -84,23 +118,23 @@ class TensorflowVector(Vector):
@property @property
def _op_add(self) -> Callable[[Any, Any], Any]: def _op_add(self) -> Callable[[Any, Any], Any]:
return tf.add return tf_op_add
@property @property
def _op_sub(self) -> Callable[[Any, Any], Any]: def _op_sub(self) -> Callable[[Any, Any], Any]:
return tf.subtract return tf_op_sub
@property @property
def _op_mul(self) -> Callable[[Any, Any], Any]: def _op_mul(self) -> Callable[[Any, Any], Any]:
return tf.multiply return tf_op_mul
@property @property
def _op_div(self) -> Callable[[Any, Any], Any]: def _op_div(self) -> Callable[[Any, Any], Any]:
return tf.divide return tf_op_div
@property @property
def _op_pow(self) -> Callable[[Any, Any], Any]: def _op_pow(self) -> Callable[[Any, Any], Any]:
return tf.pow return tf_op_pow
@property @property
def compatible_vector_types(self) -> Set[Type[Vector]]: def compatible_vector_types(self) -> Set[Type[Vector]]:
...@@ -118,7 +152,8 @@ class TensorflowVector(Vector): ...@@ -118,7 +152,8 @@ class TensorflowVector(Vector):
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Self: ) -> Self:
func = preserve_tensor_device(func) if not getattr(func, "_pre_wrapped", False):
func = preserve_tensor_device(func)
return super().apply_func(func, *args, **kwargs) return super().apply_func(func, *args, **kwargs)
def _apply_operation( def _apply_operation(
...@@ -126,7 +161,8 @@ class TensorflowVector(Vector): ...@@ -126,7 +161,8 @@ class TensorflowVector(Vector):
other: Any, other: Any,
func: Callable[[Any, Any], Any], func: Callable[[Any, Any], Any],
) -> Self: ) -> Self:
func = preserve_tensor_device(func) if not getattr(func, "_pre_wrapped", False):
func = preserve_tensor_device(func)
return super()._apply_operation(other, func) return super()._apply_operation(other, func)
def dtypes( def dtypes(
...@@ -160,7 +196,8 @@ class TensorflowVector(Vector): ...@@ -160,7 +196,8 @@ class TensorflowVector(Vector):
if isinstance(tensor, tf.IndexedSlices): if isinstance(tensor, tf.IndexedSlices):
val = cls._pack_tensor(tensor.values) val = cls._pack_tensor(tensor.values)
ind = cls._pack_tensor(tensor.indices) ind = cls._pack_tensor(tensor.indices)
return ["slices", val, ind] shp = cls._pack_tensor(tensor.dense_shape)
return ["slices", val, ind, shp]
return np.array(tensor.numpy()) return np.array(tensor.numpy())
@classmethod @classmethod
...@@ -172,7 +209,8 @@ class TensorflowVector(Vector): ...@@ -172,7 +209,8 @@ class TensorflowVector(Vector):
if isinstance(data, list) and (data[0] == "slices"): if isinstance(data, list) and (data[0] == "slices"):
val = cls._unpack_tensor(data[1]) val = cls._unpack_tensor(data[1])
ind = cls._unpack_tensor(data[2]) ind = cls._unpack_tensor(data[2])
return tf.IndexedSlices(val, ind) shp = cls._unpack_tensor(data[3])
return tf.IndexedSlices(val, ind, shp)
try: try:
return tf.convert_to_tensor(data) return tf.convert_to_tensor(data)
except TypeError as exc: except TypeError as exc:
...@@ -209,29 +247,40 @@ class TensorflowVector(Vector): ...@@ -209,29 +247,40 @@ class TensorflowVector(Vector):
return tf.reduce_all(t_a == t_b).numpy() return tf.reduce_all(t_a == t_b).numpy()
def sign(self) -> Self: def sign(self) -> Self:
return self.apply_func(tf.sign) return self.apply_func(tf_op_sign)
def minimum( def minimum(
self, self,
other: Any, other: Any,
) -> Self: ) -> Self:
if isinstance(other, Vector): if isinstance(other, Vector):
return self._apply_operation(other, tf.minimum) return self._apply_operation(other, tf_op_min)
return self.apply_func(tf.minimum, other) return self.apply_func(tf_op_min, other)
def maximum( def maximum(
self, self,
other: Any, other: Any,
) -> Self: ) -> Self:
if isinstance(other, Vector): if isinstance(other, Vector):
return self._apply_operation(other, tf.maximum) return self._apply_operation(other, tf_op_max)
return self.apply_func(tf.maximum, other) return self.apply_func(tf_op_max, other)
def sum( def sum(
self, self,
axis: Optional[int] = None, axis: Optional[int] = None,
keepdims: bool = False, keepdims: bool = False,
) -> Self: ) -> Self:
if keepdims or (axis is not None):
if any(
isinstance(x, tf.IndexedSlices) for x in self.coefs.values()
):
warnings.warn(
"Calling `TensorflowVector.sum()` with non-default "
"arguments and tf.IndexedSlices coefficients might "
"result in unexpected outputs, due to the latter "
"being converted to their dense counterpart.",
category=RuntimeWarning,
)
return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims) return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims)
def __pow__( def __pow__(
...@@ -242,7 +291,7 @@ class TensorflowVector(Vector): ...@@ -242,7 +291,7 @@ class TensorflowVector(Vector):
# than tf.pow as results tend to differ for small values. # than tf.pow as results tend to differ for small values.
if isinstance(other, (int, float)): if isinstance(other, (int, float)):
if other == 2: if other == 2:
return self.apply_func(tf.square) return self.apply_func(tf_op_sqre)
if other == 0.5: if other == 0.5:
return self.apply_func(tf.sqrt) return self.apply_func(tf_op_sqrt)
return super().__pow__(other) return super().__pow__(other)
...@@ -28,6 +28,10 @@ GPU/CPU backing device management utils: ...@@ -28,6 +28,10 @@ GPU/CPU backing device management utils:
Loss function management utils: Loss function management utils:
* build_keras_loss: * build_keras_loss:
Type-check, deserialize and/or wrap a keras loss into a Loss object. Type-check, deserialize and/or wrap a keras loss into a Loss object.
Better support for sparse tensor structures:
* add_indexed_slices_support:
Run a function on a pair of tensors, adding support for IndexedSlices.
""" """
from ._gpu import ( from ._gpu import (
...@@ -36,3 +40,4 @@ from ._gpu import ( ...@@ -36,3 +40,4 @@ from ._gpu import (
select_device, select_device,
) )
from ._loss import build_keras_loss from ._loss import build_keras_loss
from ._slices import add_indexed_slices_support
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to handle tf.IndexedSlices as part of tensor-processing operations."""
import functools
import warnings
from typing import Any, Callable, TypeVar
import numpy as np
import tensorflow as tf # type: ignore
__all__ = [
"add_indexed_slices_support",
]
TensorT = TypeVar("TensorT", tf.Tensor, tf.IndexedSlices)
def apply_func_to_tensor_or_slices(
    first: TensorT,
    other: Any,
    tf_op: Callable[[tf.Tensor, Any], tf.Tensor],
) -> TensorT:
    """Run a function on a pair of tensors, adding support for IndexedSlices.

    The intended use of this function is to add support for IndexedSlices
    to basic tensorflow operators, such as `tf.add` or `tf.multiply`.

    Parameters
    ----------
    first: tf.Tensor or tf.IndexedSlices
        Tensor or IndexedSlices data structure.
    other: tf.Tensor or tf.IndexedSlices or np.ndarray or float or int
        Scalar or data array that needs operating onto `first` via `tf_op`.
    tf_op: function(tf.Tensor, any) -> tf.Tensor
        Function that operates on a tf.Tensor and another value.

    Returns
    -------
    output: tf.Tensor or tf.IndexedSlices
        Result from running `tf_op(first, other)` if first is a tf.Tensor,
        or from re-wrapping `tf_op(first.values, other[.values])` into a
        tf.IndexedSlices structure if first is such a structure. The only
        exception is when operating on IndexedSlices and a full-rank array
        or tensor: then a full-rank output is returned, with a warning.

    Raises
    ------
    TypeError:
        If `first` and `other` are two tf.IndexedSlices with different
        dense shapes or different indices.
        If `first` is a tf.IndexedSlices and `tf_op` failed on its values.
    """
    # Case of a non-slices left-hand input: run the operation as-is
    # (right-hand slices, if any, are left to `tf_op` to handle).
    if not isinstance(first, tf.IndexedSlices):
        return tf_op(first, other)
    op_name = _get_op_name(tf_op)
    # Case when combining two IndexedSlices objects.
    if isinstance(other, tf.IndexedSlices):
        if not _share_slices_structure(first, other):
            raise TypeError(
                f"Cannot apply function {op_name} to two IndexedSlices "
                "structures with different shapes or indices."
            )
        values = tf_op(first.values, other.values)
        return tf.IndexedSlices(values, first.indices, first.dense_shape)
    # Case when operating with a dense tensor (or array) of same shape.
    if isinstance(other, (tf.Tensor, np.ndarray)) and (
        first.shape == other.shape
    ):
        warnings.warn(
            f"Applying function {op_name} to IndexSlices with "
            "a full-rank array or tensor results in densifying it.",
            RuntimeWarning,
        )
        return tf_op(tf.convert_to_tensor(first), other)
    # Generic case (including mis-shaped tensor, to raise an error).
    try:
        values = tf_op(first.values, other)
    except Exception as exc:
        # Deliberately broad: any failure of `tf_op` on the slices' values
        # is re-raised as the documented TypeError, with its cause chained.
        raise TypeError(
            f"Failed to apply function {op_name} to combine a "
            f"{type(other)} object into an IndexedSlices tensor: {exc}."
        ) from exc
    return tf.IndexedSlices(values, first.indices, first.dense_shape)


def _get_op_name(tf_op: Callable[..., Any]) -> str:
    """Return a printable name for `tf_op`, even if it lacks `__name__`."""
    # Some callables (e.g. `functools.partial` objects) have no `__name__`;
    # fall back to their repr so that messages can always be formatted.
    return getattr(tf_op, "__name__", repr(tf_op))


def _share_slices_structure(
    first: tf.IndexedSlices,
    other: tf.IndexedSlices,
) -> bool:
    """Return whether two IndexedSlices have the same shape and indices."""
    shp_a, shp_b = first.dense_shape, other.dense_shape
    # `dense_shape` is optional: a structure that specifies it cannot be
    # matched against one that does not.
    if (shp_a is None) != (shp_b is None):
        return False
    if shp_a is None:
        # Neither dense shape is known: compare the stored values' shapes.
        same_shape = first.values.shape == other.values.shape
    else:
        # Compare the shape vectors' lengths first, so that the
        # element-wise check only runs on same-sized vectors.
        same_shape = (shp_a.shape == shp_b.shape) and bool(
            tf.reduce_all(shp_a == shp_b)
        )
    return bool(
        same_shape
        and (first.indices.shape == other.indices.shape)
        and tf.reduce_all(first.indices == other.indices)
    )
def add_indexed_slices_support(
    tf_op: Callable[[tf.Tensor, Any], tf.Tensor],
    inplc: bool = False,
) -> Callable[[TensorT, Any], TensorT]:
    """Wrap an input function to overload the handling of tf.IndexedSlices.

    Parameters
    ----------
    tf_op: function(tf.Tensor, [any]) -> tf.Tensor
        Tensor-processing operation that needs wrapping.
    inplc: bool, default=False
        Whether to replace the second argument of `tf_op` with None.
        Use this to transform tensor-processing functions (which, in
        general, have a `name=None` argument) rather than operations.

    Returns
    -------
    func: function(<T>, any) -> <T>, with <T>:(tf.Tensor|tf.IndexedSlices)
        Tensor-processing operation that wraps `tf_op` but supports and
        preserves tf.IndexedSlices inputs as first (and opt. second)
        argument.
        Note that in the rare case when func(slices, dense) is called,
        the output will be dense, and a RuntimeWarning will be emitted.
    """
    # Bind `tf_op` (and, for one-input functions, `other=None`) as keyword
    # arguments of the slices-aware routine.
    bound = {"other": None} if inplc else {}
    wrapped = functools.partial(
        apply_func_to_tensor_or_slices, tf_op=tf_op, **bound
    )
    # Copy the metadata of `tf_op` (name, docstring...) onto the wrapper.
    return functools.update_wrapper(wrapped, tf_op)
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
"""Noise-addition modules for DP using cryptographically-strong RNG.""" """Noise-addition modules for DP using cryptographically-strong RNG."""
import warnings
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from random import SystemRandom from random import SystemRandom
from typing import Any, ClassVar, Dict, Optional, Tuple from typing import Any, ClassVar, Dict, Optional, Tuple
...@@ -94,7 +95,10 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False): ...@@ -94,7 +95,10 @@ class NoiseModule(OptiModule, metaclass=ABCMeta, register=False):
for key in gradients.coefs for key in gradients.coefs
} }
# Add the sampled noise to the gradients and return them. # Add the sampled noise to the gradients and return them.
return gradients + NumpyVector(noise) # Silence warnings about sparse gradients getting densified.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*densifying.*", RuntimeWarning)
return gradients + NumpyVector(noise)
@abstractmethod @abstractmethod
def _sample_noise( def _sample_noise(
......
...@@ -98,6 +98,11 @@ class GradientsTestCase: ...@@ -98,6 +98,11 @@ class GradientsTestCase:
"""Convert an input framework-based structure to a numpy array.""" """Convert an input framework-based structure to a numpy array."""
if isinstance(array, np.ndarray): if isinstance(array, np.ndarray):
return array return array
if self.framework == "tensorflow": # add support for IndexedSlices
tensorflow = importlib.import_module("tensorflow")
if isinstance(array, tensorflow.IndexedSlices):
with tensorflow.device(array.device):
return tensorflow.convert_to_tensor(array).numpy()
return array.numpy() # type: ignore return array.numpy() # type: ignore
@property @property
...@@ -111,9 +116,21 @@ class GradientsTestCase: ...@@ -111,9 +116,21 @@ class GradientsTestCase:
rng = np.random.default_rng(self.seed) rng = np.random.default_rng(self.seed)
shapes = [(64, 32), (32,), (32, 16), (16,), (16, 1), (1,)] shapes = [(64, 32), (32,), (32, 16), (16,), (16, 1), (1,)]
values = [rng.normal(size=shape) for shape in shapes] values = [rng.normal(size=shape) for shape in shapes]
return self.vector_cls( vector = self.vector_cls(
{str(idx): self.convert(value) for idx, value in enumerate(values)} {str(idx): self.convert(value) for idx, value in enumerate(values)}
) )
# In Tensorflow, convert the first gradients to IndexedSlices.
# In this case they are equivalent to dense ones, but this enables
# testing the support for these structures while maintaining the
# possibility to compare outputs' values with other frameworks.
if self.framework == "tensorflow":
tensorflow = importlib.import_module("tensorflow")
vector.coefs["0"] = tensorflow.IndexedSlices(
values=vector.coefs["0"],
indices=tensorflow.range(64),
dense_shape=tensorflow.convert_to_tensor([64, 32]),
)
return vector
@property @property
def mock_ones(self) -> Vector: def mock_ones(self) -> Vector:
......
...@@ -86,6 +86,8 @@ class TensorflowTestCase(ModelTestCase): ...@@ -86,6 +86,8 @@ class TensorflowTestCase(ModelTestCase):
tensor: Any, tensor: Any,
) -> np.ndarray: ) -> np.ndarray:
"""Convert an input tensor to a numpy array.""" """Convert an input tensor to a numpy array."""
if isinstance(tensor, tf.IndexedSlices):
tensor = tf.convert_to_tensor(tensor)
assert isinstance(tensor, tf.Tensor) assert isinstance(tensor, tf.Tensor)
return tensor.numpy() return tensor.numpy()
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
import sys import sys
import warnings import warnings
from typing import Iterator, Type from typing import Iterator, Type, Union
import numpy as np import numpy as np
import pytest import pytest
...@@ -115,6 +115,13 @@ def fix_adam_epsilon( ...@@ -115,6 +115,13 @@ def fix_adam_epsilon(
return module return module
def to_numpy(tensor: Union[tf.Tensor, tf.IndexedSlices]) -> np.ndarray:
    """Return the numpy array backing a tensorflow Tensor or IndexedSlices.

    For IndexedSlices inputs, the array of stored `values` is returned,
    without converting the structure to its dense counterpart.
    """
    dense = tensor.values if isinstance(tensor, tf.IndexedSlices) else tensor
    return dense.numpy()
@pytest.fixture(name="framework") @pytest.fixture(name="framework")
def framework_fixture(): def framework_fixture():
"""Fixture to ensure 'TensorflowOptiModule' only receives tf gradients.""" """Fixture to ensure 'TensorflowOptiModule' only receives tf gradients."""
...@@ -157,7 +164,7 @@ class TestTensorflowOptiModule(OptiModuleTestSuite): ...@@ -157,7 +164,7 @@ class TestTensorflowOptiModule(OptiModuleTestSuite):
grads_dec = optim_dec.run(gradients).coefs grads_dec = optim_dec.run(gradients).coefs
# Assert that the output gradients are (nearly) identical. # Assert that the output gradients are (nearly) identical.
assert all( assert all(
np.allclose(grads_tfk[key].numpy(), grads_dec[key].numpy()) np.allclose(to_numpy(grads_tfk[key]), to_numpy(grads_dec[key]))
for key in gradients.coefs for key in gradients.coefs
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment