From 2a61cd47923c381e4c85bcc42222cfd866e526bf Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Thu, 23 Feb 2023 15:25:17 +0100
Subject: [PATCH] Fix the (un)packing of scalar tensors in 'Vector' subclasses.

---
 declearn/model/tensorflow/_vector.py | 3 ++-
 declearn/model/torch/_vector.py      | 5 ++++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py
index 1a4c0990..d8f2c877 100644
--- a/declearn/model/tensorflow/_vector.py
+++ b/declearn/model/tensorflow/_vector.py
@@ -20,6 +20,7 @@
 from typing import Any, Callable, Dict, Optional, Set, Type, Union
 
 # fmt: off
+import numpy as np
 import tensorflow as tf  # type: ignore
 # false-positive; pylint: disable=no-name-in-module
 from tensorflow.python.framework.ops import EagerTensor  # type: ignore
@@ -155,7 +156,7 @@ class TensorflowVector(Vector):
             val = cls._pack_tensor(tensor.values)
             ind = cls._pack_tensor(tensor.indices)
             return ["slices", val, ind]
-        return tensor.numpy()
+        return np.array(tensor.numpy())
 
     @classmethod
     def _unpack_tensor(
diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py
index 6b2a19e2..403efbc5 100644
--- a/declearn/model/torch/_vector.py
+++ b/declearn/model/torch/_vector.py
@@ -19,6 +19,7 @@
 
 from typing import Any, Callable, Dict, Optional, Set, Tuple, Type
 
+import numpy as np
 import torch
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
 
@@ -125,7 +126,9 @@ class TorchVector(Vector):
     def pack(
         self,
     ) -> Dict[str, Any]:
-        return {key: tns.cpu().numpy() for key, tns in self.coefs.items()}
+        return {
+            key: np.array(tns.cpu().numpy()) for key, tns in self.coefs.items()
+        }
 
     @classmethod
     def unpack(
-- 
GitLab