From 2beb9f2eb1614a8d3ae593b51f2d0c3ddd1ff921 Mon Sep 17 00:00:00 2001
From: Paul Andrey <paul.andrey@inria.fr>
Date: Mon, 21 Nov 2022 16:57:33 +0100
Subject: [PATCH] Move `NumpyVector` from `model.api` to `model.sklearn`
 submodule.

---
 README.md                                  | 2 +-
 declearn/model/api/__init__.py             | 2 --
 declearn/model/sklearn/__init__.py         | 2 ++
 declearn/model/{api => sklearn}/_np_vec.py | 2 +-
 declearn/model/sklearn/_sgd.py             | 3 ++-
 declearn/model/tensorflow/_vector.py       | 5 +++--
 declearn/model/torch/_vector.py            | 5 +++--
 test/model/test_sksgd.py                   | 3 +--
 test/optimizer/optim_testing.py            | 3 ++-
 9 files changed, 15 insertions(+), 12 deletions(-)
 rename declearn/model/{api => sklearn}/_np_vec.py (97%)

diff --git a/README.md b/README.md
index a9dab1fc..398d96a2 100644
--- a/README.md
+++ b/README.md
@@ -343,7 +343,7 @@ new custom concrete implementations inheriting the abstraction.
   - Object: Interface framework-specific data structures.
   - Usage: Wrap and operate on model weights, gradients, updates...
   - Examples:
-    - `declearn.model.api.NumpyVector`
+    - `declearn.model.sklearn.NumpyVector`
     - `declearn.model.tensorflow.TensorflowVector`
     - `declearn.model.torch.TorchVector`
   - Extend: use `declearn.model.api.register_vector_type`
diff --git a/declearn/model/api/__init__.py b/declearn/model/api/__init__.py
index 62bdaeaa..35cecc33 100644
--- a/declearn/model/api/__init__.py
+++ b/declearn/model/api/__init__.py
@@ -3,12 +3,10 @@
 """Model Vector abstractions submodule."""
 
 from ._vector import Vector, register_vector_type
-from ._np_vec import NumpyVector
 from ._model import Model
 
 __all__ = [
     "Model",
-    "NumpyVector",
     "Vector",
     "register_vector_type",
 ]
diff --git a/declearn/model/sklearn/__init__.py b/declearn/model/sklearn/__init__.py
index a1f7b36e..06d3265f 100644
--- a/declearn/model/sklearn/__init__.py
+++ b/declearn/model/sklearn/__init__.py
@@ -7,7 +7,9 @@ and to the way their learning process is implemented, model-
 specific interfaces are required for declearn compatibility.
 
 This module currently implements:
+* NumpyVector: Vector subclass to wrap numpy.ndarray objects
 * SklearnSGDModel: interface to SGD-based linear models
 """
 
+from ._np_vec import NumpyVector
 from ._sgd import SklearnSGDModel
diff --git a/declearn/model/api/_np_vec.py b/declearn/model/sklearn/_np_vec.py
similarity index 97%
rename from declearn/model/api/_np_vec.py
rename to declearn/model/sklearn/_np_vec.py
index 26624dd2..74d989bb 100644
--- a/declearn/model/api/_np_vec.py
+++ b/declearn/model/sklearn/_np_vec.py
@@ -1,6 +1,6 @@
 # coding: utf-8
 
-"""NumpyVector model coefficients container."""
+"""NumpyVector data arrays container."""
 
 from typing import Any, Callable, Dict, Union
 
diff --git a/declearn/model/sklearn/_sgd.py b/declearn/model/sklearn/_sgd.py
index 867721ce..e6a1c305 100644
--- a/declearn/model/sklearn/_sgd.py
+++ b/declearn/model/sklearn/_sgd.py
@@ -11,7 +11,8 @@ from sklearn.linear_model import SGDClassifier, SGDRegressor  # type: ignore
 from typing_extensions import Literal  # future: import from typing (Py>=3.8)
 
 from declearn.data_info import aggregate_data_info
-from declearn.model.api import Model, NumpyVector
+from declearn.model.api import Model
+from declearn.model.sklearn._np_vec import NumpyVector
 from declearn.typing import Batch
 from declearn.utils import register_type
 
diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py
index 44e2a69f..b6d2ac5d 100644
--- a/declearn/model/tensorflow/_vector.py
+++ b/declearn/model/tensorflow/_vector.py
@@ -1,6 +1,6 @@
 # coding: utf-8
 
-"""TensorflowVector gradients container."""
+"""TensorflowVector data arrays container."""
 
 from typing import Any, Callable, Dict, Set, Type, Union
 
@@ -12,7 +12,8 @@ from tensorflow.python.framework.ops import EagerTensor  # type: ignore
 # pylint: enable=no-name-in-module
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
 
-from declearn.model.api import NumpyVector, Vector, register_vector_type
+from declearn.model.api import Vector, register_vector_type
+from declearn.model.sklearn import NumpyVector
 
 
 @register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices)
diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py
index e5b2be45..66ee0651 100644
--- a/declearn/model/torch/_vector.py
+++ b/declearn/model/torch/_vector.py
@@ -1,6 +1,6 @@
 # coding: utf-8
 
-"""TorchVector gradients container."""
+"""TorchVector data arrays container."""
 
 from typing import Any, Callable, Dict, Set, Tuple, Type
 
@@ -8,7 +8,8 @@ import numpy as np
 import torch
 from typing_extensions import Self  # future: import from typing (Py>=3.11)
 
-from declearn.model.api import NumpyVector, Vector, register_vector_type
+from declearn.model.api import Vector, register_vector_type
+from declearn.model.sklearn import NumpyVector
 
 
 @register_vector_type(torch.Tensor)
diff --git a/test/model/test_sksgd.py b/test/model/test_sksgd.py
index 61d1ad42..67ff9e64 100644
--- a/test/model/test_sksgd.py
+++ b/test/model/test_sksgd.py
@@ -10,8 +10,7 @@ import pytest
 from scipy.sparse import csr_matrix  # type: ignore
 from sklearn.linear_model import SGDClassifier, SGDRegressor  # type: ignore
 
-from declearn.model.api import NumpyVector
-from declearn.model.sklearn import SklearnSGDModel
+from declearn.model.sklearn import NumpyVector, SklearnSGDModel
 from declearn.typing import Batch
 
 # dirty trick to import from `model_testing.py`;
diff --git a/test/optimizer/optim_testing.py b/test/optimizer/optim_testing.py
index 1fcba856..0122f989 100644
--- a/test/optimizer/optim_testing.py
+++ b/test/optimizer/optim_testing.py
@@ -15,7 +15,8 @@ import torch
 from numpy.typing import ArrayLike
 from typing_extensions import Literal  # future: import from typing (Py>=3.8)
 
-from declearn.model.api import NumpyVector, Vector
+from declearn.model.api import Vector
+from declearn.model.sklearn import NumpyVector
 from declearn.model.tensorflow import TensorflowVector
 from declearn.model.torch import TorchVector
 from declearn.optimizer.modules import OptiModule
-- 
GitLab