diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py index f347f212b97c37c3176401004f39d4d09aadf89a..1a4c099054032375e407e0e65706c4fd89d8f944 100644 --- a/declearn/model/tensorflow/_vector.py +++ b/declearn/model/tensorflow/_vector.py @@ -227,3 +227,16 @@ class TensorflowVector(Vector): keepdims: bool = False, ) -> Self: return self.apply_func(tf.reduce_sum, axis=axis, keepdims=keepdims) + + def __pow__( + self, + other: Any, + ) -> Self: + # For square and square root, use dedicated functions rather + # than tf.pow as results tend to differ for small values. + if isinstance(other, (int, float)): + if other == 2: + return self.apply_func(tf.square) + if other == 0.5: + return self.apply_func(tf.sqrt) + return super().__pow__(other)