From 2a61cd47923c381e4c85bcc42222cfd866e526bf Mon Sep 17 00:00:00 2001 From: Paul Andrey <paul.andrey@inria.fr> Date: Thu, 23 Feb 2023 15:25:17 +0100 Subject: [PATCH] Fix the (un)packing of scalar tensors in 'Vector' subclasses. --- declearn/model/tensorflow/_vector.py | 3 ++- declearn/model/torch/_vector.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py index 1a4c0990..d8f2c877 100644 --- a/declearn/model/tensorflow/_vector.py +++ b/declearn/model/tensorflow/_vector.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Optional, Set, Type, Union # fmt: off +import numpy as np import tensorflow as tf # type: ignore # false-positive; pylint: disable=no-name-in-module from tensorflow.python.framework.ops import EagerTensor # type: ignore @@ -155,7 +156,7 @@ class TensorflowVector(Vector): val = cls._pack_tensor(tensor.values) ind = cls._pack_tensor(tensor.indices) return ["slices", val, ind] - return tensor.numpy() + return np.array(tensor.numpy()) @classmethod def _unpack_tensor( diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py index 6b2a19e2..403efbc5 100644 --- a/declearn/model/torch/_vector.py +++ b/declearn/model/torch/_vector.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Dict, Optional, Set, Tuple, Type +import numpy as np import torch from typing_extensions import Self # future: import from typing (Py>=3.11) @@ -125,7 +126,9 @@ class TorchVector(Vector): def pack( self, ) -> Dict[str, Any]: - return {key: tns.cpu().numpy() for key, tns in self.coefs.items()} + return { + key: np.array(tns.cpu().numpy()) for key, tns in self.coefs.items() + } @classmethod def unpack( -- GitLab