diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py index 1a4c099054032375e407e0e65706c4fd89d8f944..d8f2c87723f3214bd64e8a8e90cf6b823dd1d6d9 100644 --- a/declearn/model/tensorflow/_vector.py +++ b/declearn/model/tensorflow/_vector.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Optional, Set, Type, Union # fmt: off +import numpy as np import tensorflow as tf # type: ignore # false-positive; pylint: disable=no-name-in-module from tensorflow.python.framework.ops import EagerTensor # type: ignore @@ -155,7 +156,7 @@ class TensorflowVector(Vector): val = cls._pack_tensor(tensor.values) ind = cls._pack_tensor(tensor.indices) return ["slices", val, ind] - return tensor.numpy() + return np.array(tensor.numpy()) @classmethod def _unpack_tensor( diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py index 6b2a19e2d05ea79e8e83c4206153c6d614e3e735..403efbc57a8266f7396dfaf6b2f6e3c87a733ce3 100644 --- a/declearn/model/torch/_vector.py +++ b/declearn/model/torch/_vector.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, Optional, Set, Tuple, Type +import numpy as np import torch from typing_extensions import Self # future: import from typing (Py>=3.11) @@ -125,7 +126,9 @@ class TorchVector(Vector): def pack( self, ) -> Dict[str, Any]: - return {key: tns.cpu().numpy() for key, tns in self.coefs.items()} + return { + key: np.array(tns.cpu().numpy()) for key, tns in self.coefs.items() + } @classmethod def unpack(