Improve support for `tf.IndexedSlices` row-sparse gradients.
The current implementation of `TensorflowVector` supports wrapping `tf.IndexedSlices` data structures, but the wrapped computations are wrong.
### Technical context and problem root
`tf.IndexedSlices` are a specific kind of row-sparse tensor, commonly used in TensorFlow to wrap gradients resulting from look-up operations. Using a non-frozen embedding layer is a typical use case that produces such gradients. By default, Python operators fail on these structures, as TensorFlow expects only specific kernels to update them as part of optimizers' backends. TensorFlow operators such as `tf.add`, `tf.multiply`, etc., which we use under the hood in `TensorflowVector`, silently convert `tf.IndexedSlices` to full-rank (dense) tensors. This not only causes unwanted memory use (zero-valued rows are created and allocated in memory), but also alters zero-valued rows that should be left out of the gradients, which is mathematically wrong.
### Proposed solution
We should update the backend of `TensorflowVector` to better handle how operations are applied to wrapped `tf.IndexedSlices` structures. This could probably be done with a generic wrapper that overloads common TensorFlow operators. The point is to enable combining similarly-indexed slices and, most importantly, applying scalar (or broadcastable-tensor) operations to the slices' values, preserving the sparse structure and leaving zero-valued rows unaltered.
Notionally:

- `op(slices, scalar)` should result in `op(slices.values, scalar)`
- `op(slices, tensor)` should result in `op(slices.values, tensor)`, under condition that shapes enable broadcasting the tensor
- `op(tensor, slices)` should result in `op(tensor, tf.convert_to_tensor(slices))`, under condition that shapes match
- `op(slices, slices)` should result in `op(slices.values, slices.values)`, under condition that `slices.indices` and `slices.dense_shape` match
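The dispatch rules above could be sketched as follows. This is a pure-Python illustration under stated assumptions: `Slices`, `to_dense`, `sparse_aware`, and `add` are hypothetical stand-ins for `tf.IndexedSlices`, `tf.convert_to_tensor`, the generic wrapper, and a TensorFlow operator; they are not the actual `TensorflowVector` API.

```python
class Slices:
    """Stand-in for tf.IndexedSlices (values + row indices + dense shape)."""
    def __init__(self, values, indices, dense_shape):
        self.values, self.indices, self.dense_shape = values, indices, dense_shape

def to_dense(s):
    """Stand-in for tf.convert_to_tensor on IndexedSlices."""
    dense = [[0.0] * s.dense_shape[1] for _ in range(s.dense_shape[0])]
    for row, idx in zip(s.values, s.indices):
        dense[idx] = list(row)
    return dense

def sparse_aware(op):
    """Wrap a binary operator so it preserves row-sparsity per the rules above."""
    def wrapped(a, b):
        if isinstance(a, Slices) and isinstance(b, Slices):
            # op(slices, slices): indices and dense shapes must match.
            if a.indices != b.indices or a.dense_shape != b.dense_shape:
                raise ValueError("mismatched indices or dense shapes")
            return Slices(op(a.values, b.values), a.indices, a.dense_shape)
        if isinstance(a, Slices):
            # op(slices, scalar-or-broadcastable-tensor): act on values only.
            return Slices(op(a.values, b), a.indices, a.dense_shape)
        if isinstance(b, Slices):
            # op(tensor, slices): densify the slices, then apply the op.
            return op(a, to_dense(b))
        return op(a, b)
    return wrapped

def add(x, y):
    """Toy elementwise addition with scalar broadcasting over nested lists."""
    if isinstance(x, list) and isinstance(y, list):
        return [add(a, b) for a, b in zip(x, y)]
    if isinstance(x, list):
        return [add(a, y) for a in x]
    if isinstance(y, list):
        return [add(x, b) for b in y]
    return x + y

sadd = sparse_aware(add)
g = Slices([[1.0, 2.0], [3.0, 4.0]], [0, 3], (5, 2))
out = sadd(g, 1.0)    # scalar case: sparsity preserved
print(out.values)     # [[2.0, 3.0], [4.0, 5.0]]
both = sadd(g, g)     # slices + slices with matching indices
print(both.values)    # [[2.0, 4.0], [6.0, 8.0]]
```

In a real implementation, the same dispatch logic would be applied around the actual TensorFlow operators (`tf.add`, `tf.multiply`, etc.) so that zero-valued rows are never materialized when one operand is an `tf.IndexedSlices` instance.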