diff --git a/declearn/model/__init__.py b/declearn/model/__init__.py index 08d9cdb897d87c7572a1eb21499697ff53d4b0d1..00b1dc64623089e1032f58d44da3020caee816c9 100644 --- a/declearn/model/__init__.py +++ b/declearn/model/__init__.py @@ -19,8 +19,25 @@ This declearn submodule provides with: * Model and Vector abstractions, used as an API to design FL algorithms -* Submodules implementing interfaces to curretnly supported frameworks -and models. +* Submodules implementing interfaces to various frameworks and models. + +The automatically-imported submodules implemented here are: +* api: Model and Vector abstractions' defining module. + - Model: abstract API to interface framework-specific models. + - Vector: abstract API for data tensors containers. +* sklearn: scikit-learn based or oriented tools + - NumpyVector: Vector for numpy array data structures. + - SklearnSGDModel: Model for scikit-learn's SGDClassifier and SGDRegressor. + +The optional-dependency-based submodules that may be manually imported are: +* tensorflow: tensorflow-interfacing tools + - TensorflowModel: Model to wrap any tensorflow-keras Layer model. + - TensorflowOptiModule: Hacky OptiModule to wrap a keras Optimizer. + - TensorflowVector: Vector for tensorflow Tensor and IndexedSlices. +* torch: pytorch-interfacing tools + - TorchModel: Model to wrap any torch Module model. + - TorchOptiModule: Hacky OptiModule to wrap a torch Optimizer. + - TorchVector: Vector for torch Tensor objects. """ from . import api diff --git a/declearn/model/api/__init__.py b/declearn/model/api/__init__.py index ab147758c5187ff8c62a507d3c63824424045646..e78ce0b7eb11713488970743fef75676fcbb21ef 100644 --- a/declearn/model/api/__init__.py +++ b/declearn/model/api/__init__.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Model Vector abstractions submodule.""" +"""Model and Vector abstractions submodule.""" from ._vector import Vector, register_vector_type from ._model import Model diff --git a/declearn/model/sklearn/_np_vec.py b/declearn/model/sklearn/_np_vec.py index 65ac351d93904c3375f6d2244f8b271f3a067eb1..49da62cfb0e0e42ad69976c8a225c990aff805f9 100644 --- a/declearn/model/sklearn/_np_vec.py +++ b/declearn/model/sklearn/_np_vec.py @@ -25,6 +25,7 @@ from typing_extensions import Self # future: import from typing (Py>=3.11) from declearn.model.api._vector import Vector, register_vector_type + __all__ = [ "NumpyVector", ] diff --git a/declearn/model/tensorflow/_vector.py b/declearn/model/tensorflow/_vector.py index d8f2c87723f3214bd64e8a8e90cf6b823dd1d6d9..bc73d25c0baf4de2dd2af6f7a717132c8190ba1d 100644 --- a/declearn/model/tensorflow/_vector.py +++ b/declearn/model/tensorflow/_vector.py @@ -37,6 +37,11 @@ from declearn.model.tensorflow.utils import ( from declearn.utils import get_device_policy +__all__ = [ + "TensorflowVector", +] + + @register_vector_type(tf.Tensor, EagerTensor, tf.IndexedSlices) class TensorflowVector(Vector): """Vector subclass to store tensorflow tensors. diff --git a/declearn/model/torch/_vector.py b/declearn/model/torch/_vector.py index 403efbc57a8266f7396dfaf6b2f6e3c87a733ce3..662aaa100fde401335d9f80657f32b505986614e 100644 --- a/declearn/model/torch/_vector.py +++ b/declearn/model/torch/_vector.py @@ -29,6 +29,11 @@ from declearn.model.torch.utils import select_device from declearn.utils import get_device_policy +__all__ = [ + "TorchVector", +] + + @register_vector_type(torch.Tensor) class TorchVector(Vector): """Vector subclass to store PyTorch tensors.