Merged
requested to merge feature/472-create-model-abstraction-for-declearn-integration into develop
MR description
Implements model abstraction for pytorch and scikit learn Closes #472 (closed).
Here a list of points that i found difficult to implement (and that need a closer look).
- BaseSklearnModel:
- batch_size computation / reset that is done internally
- sklearn model n_iter attribute increment/decrement (hidden through methods)
- apply_updates method: slightly modified from the one of the poc, that adds gradients instead of changing them, in the same spirit as the
apply_updates
of pytorch. Check if computation is correct - Toolbox classes, that implement some method using multiple inheritance
- TorchModel
- method train not implemented: should we compute loss in this method ?
- is get_gradients method correct?
- saving state of the TorchModel: initial parameters are saved in a specific attribute:
- make sure it handles frozen layers
More broadly speaking:
- appropriate raise of exceptions: some exceptions may not be caught
- appropriate naming of variables
- create multiple modules for each classes (ie one class per file) rather than having 3 classes in a file.
Developer Certificate Of Origin (DCO)
By opening this merge request, you agree the Developer Certificate of Origin (DCO)
This DCO essentially means that:
- you offer the changes under the same license agreement as the project, and
- you have the right to do that,
- you did not steal somebody else’s work.
License
Project code files should begin with these comment lines to help trace their origin:
# This file is originally part of Fed-BioMed
# SPDX-License-Identifier: Apache-2.0
Code files can be reused from another project with a compatible non-contaminating license.
They shall retain the original license and copyright mentions.
The CREDIT.md
file and credit/
directory shall be completed and updated accordingly.
Guidelines for MR review
General:
- give a glance to DoD
- check coding rules and coding style
- check docstrings (eg run
tests/docstrings/check_docstrings
)
Specific to some cases:
- update all conda envs consistently (
development
andvpn
, Linux and MacOS) - if modified researcher (eg new attributes in classes) check if breakpoint needs update (
breakpoint
/load_breakpoint
inExperiment()
,save_state
/load_state
in aggregators, strategies, secagg, etc.)