Implement framework-specific OptiModule subclasses.
This MR implements a pair of framework-specific OptiModule subclasses that make it possible to wrap optimizers from TensorFlow or PyTorch (respectively) into plug-in modules, and thereby integrate them into a declearn optimization pipeline.
Rationale
While the declearn Optimizer API aims to provide framework-agnostic, combinable optimization algorithms, and is designed to be extended by end-users who need algorithms that are not directly part of the package, there is value in letting end-users rely on framework-specific, third-party algorithms: to limit the amount of declearn-specific syntax they need to assimilate (even though some things still need to be set through the declearn Optimizer), to avoid deterring users who do not want to go as far as implementing their own plug-ins for yet-unavailable algorithms (such as Adan, Lamb, etc.), and to enable quickly testing whether an algorithm is of interest before considering integrating it into declearn.
It is also relatively easy to hack around the OptiModule API to wrap framework-specific tools (at least for TensorFlow and PyTorch), but just as easy to get confused in doing so without intimate knowledge of both the declearn and the targeted framework's APIs. Therefore, shipping a documented version of such a hack in declearn provides end-users with a working solution, a clear statement of its limitations, and opinions as to why it might not be the proper way to go about things.
Implementation
I implemented two relatively similar OptiModule subclasses, respectively TensorflowOptiModule and TorchOptiModule, under the dedicated declearn.model.<framework> manual-import-only modules (more on that choice in the Discussion section). Due to differences in the two frameworks' APIs, the Torch module expects to be provided with a torch optimizer class and its non-default kwargs, while the TensorFlow one may be provided with an instance directly (or a configuration dict). When using default hyper-parameters for the wrapped framework-specific algorithm, the specification may be as simple as a single string (naming the target class and/or providing its import string).
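For illustration, here is a hedged usage sketch based on the description above; the exact argument names and accepted formats are assumptions, not a specification of the final API.

```python
import tensorflow as tf
import torch

from declearn.model.tensorflow import TensorflowOptiModule
from declearn.model.torch import TorchOptiModule

# Torch: pass the optimizer class (or its name / import string) together
# with any non-default kwargs; here, torch's Adam with a custom epsilon.
torch_mod = TorchOptiModule(torch.optim.Adam, eps=1e-7)

# TensorFlow: pass an optimizer instance directly (or its config dict).
tflow_mod = TensorflowOptiModule(tf.keras.optimizers.Adam(epsilon=1e-7))

# With default hyper-parameters, a single string naming the class may do.
torch_adam = TorchOptiModule("Adam")
```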
Both modules use the same kind of hack (a torch-flavoured sketch follows the list below):
- force the wrapped optimizer to have a constant base learning rate of 1.0;
- set up artificial, zero-valued framework-specific objects that mimic the model's parameters;
- forcefully bind the input gradients to these artificial parameters;
- have the wrapped optimizer perform its step (which modifies the parameters in place);
- collect and return the resulting values, which are the desired updates (up to their sign).
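As a minimal, torch-flavoured sketch of these steps (this illustrates the trick rather than the actual module code, and is restricted to a single call, whereas the real wrapper keeps the optimizer and its internal state across calls):

```python
import torch

def gradients_to_updates(optim_cls, gradients, **kwargs):
    """Turn raw gradient tensors into updates via a wrapped torch optimizer."""
    # Artificial, zero-valued parameters mimicking the model's weights.
    params = [torch.nn.Parameter(torch.zeros_like(grad)) for grad in gradients]
    # Wrapped optimizer with a constant base learning rate of 1.0.
    optim = optim_cls(params, lr=1.0, **kwargs)
    # Forcefully bind the input gradients to the artificial parameters.
    for param, grad in zip(params, gradients):
        param.grad = grad
    # Perform the step, which modifies the parameters in place.
    optim.step()
    # The resulting values are the desired updates (up to their sign).
    return [param.detach().clone() for param in params]
```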
Finally, I wrote unit tests for both classes, placed in dedicated test files. Note that I did not go so far as to factor out the redundancies between the two files, nor did I implement checks that the framework-specific submodules are available (which could be a nice addition in the near future, enabling end-users to run the tests in environments where optional dependencies are not (all) installed).
Discussion points
Import paths of the modules
Logically, one would expect all OptiModule subclasses to be implemented under the declearn.optimizer.modules submodule. However, I went for implementing the hacky modules under declearn.model.tensorflow and declearn.model.torch respectively, for a couple of (arguable) reasons:
- This enables importing all framework-specific tools at once (and, reciprocally, avoids importing optional third-party dependencies when the end-user does not need them).
- This might also help underline that these are not standard optimizer modules, but hacky framework-specific tools that should be used sparingly.
On the other hand, it could be preferable to move the modules to declearn.optimizer.modules, possibly not importing them by default but triggering their import as part of declearn.model.<framework>.__init__.py. As I write this, I am not entirely convinced that it is a good idea, but perhaps it could be considered in a broader discussion about where to place and how to import framework-specific tools (e.g. adding some utils to check their availability and import them: something like declearn.utils.import_framework()?).
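For the sake of discussion, such a utility might look like the following sketch; the function name comes from the question above and everything about it is hypothetical.

```python
import importlib

def import_framework(framework: str):
    """Import declearn's tools for a given framework, if it is installed."""
    try:
        # Check that the optional third-party dependency is available.
        importlib.import_module(framework)
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            f"The optional '{framework}' dependency is not installed."
        ) from exc
    # Import (and return) the matching declearn submodule.
    return importlib.import_module(f"declearn.model.{framework}")
```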
Imports' validation in Torch
For the Torch module, I implemented a user-prompt mechanism (optional, but enabled by default and forced in de-serialization contexts) to validate or refuse running some import statements. This is not optimal, but I think it would be wrong to let server-emitted instructions trigger arbitrary import statements without notifying client-side users and expecting their validation or denial.
Any external thoughts on this issue, on the practicality and efficiency of the implemented solution, and on possible alternatives are most welcome.
Edit: this was updated after discussion to bypass user validation when importing from the 'torch' module itself, which should cover most cases. We could consider building an allow-list of trusted packages in the future, based on user feedback.
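For reference, the kind of guard at play could be sketched as follows; the function name, prompt wording, and default module are illustrative, not the merged code.

```python
import importlib

def get_torch_optimizer_class(name: str, module: str = "torch.optim"):
    """Import a torch optimizer class, prompting the user when needed."""
    # Imports from the 'torch' package itself bypass user validation.
    if module.split(".")[0] != "torch":
        reply = input(f"Allow importing '{name}' from '{module}'? [y/N] ")
        if reply.strip().lower() not in ("y", "yes"):
            raise RuntimeError(f"Import of '{module}.{name}' was denied.")
    return getattr(importlib.import_module(module), name)
```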
Bypassing the declearn Optimizer API
An alternative to the current implementation would be to go as far as enabling end-users to replace the declearn Optimizer with a framework-specific one (including the use of framework-specific learning-rate schedulers, weight-decay implementations, etc.). To me this would go against the philosophy of our framework and would be harmful to FL-specific algorithms (such as FedProx, Scaffold, etc.) and DP features, hence my attempt to wrap third-party optimizers as plug-ins that are expected to preserve combinability with other declearn-provided tools.
Enabling cross-framework optimizer use
The current implementation enforces strict framework-specificity of the hacky plug-ins. However, if we go for making conversion to and from numpy a requirement of the Vector API (an idea that has been under discussion for a while now), we could easily use those features to enable torch or tensorflow optimizers to process input gradients from any framework. I do not think that we should do or discuss this as a priority, but I write it down for future reference. If we went for it, I would be in favor of refactoring the two plug-ins' code, probably introducing some intermediate abstraction layer.
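As a purely speculative sketch of what that could look like, assuming Vector subclasses gained numpy conversion methods (the to_numpy / from_numpy names below are placeholders; neither exists at this point):

```python
def run_with_torch_backend(torch_module, gradients):
    """Process gradients from any framework through a torch-backed module."""
    from declearn.model.torch import TorchVector  # torch-specific Vector
    # Hypothetical conversions: framework Vector -> numpy -> TorchVector.
    arrays = gradients.to_numpy()
    updates = torch_module.run(TorchVector.from_numpy(arrays))
    # Convert the resulting updates back to the input gradients' framework.
    return type(gradients).from_numpy(updates.to_numpy())
```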
reset method
I implemented a reset method for both modules, which enables changing the expected input specifications and resetting any internal state variables. This might be useless or, on the contrary, might be considered as a possible extension of the current OptiModule API (possibly together with the addition of get_state and set_state methods).
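To illustrate the intended behaviour, here is a hedged usage sketch for the Torch variant, assuming the module's usual run entry point and that the torch-specific Vector class is exposed as TorchVector; the gradient specifications are made up for the example.

```python
import torch

from declearn.model.torch import TorchOptiModule, TorchVector

module = TorchOptiModule(torch.optim.Adam)
# First use: internal state gets bound to the input specifications.
grads_a = TorchVector({"weight": torch.ones(4, 2), "bias": torch.ones(2)})
updates_a = module.run(grads_a)
# Reset: discard state variables and expected input specifications...
module.reset()
# ...so that gradients with different specifications can now be processed.
grads_b = TorchVector({"kernel": torch.ones(3, 3)})
updates_b = module.run(grads_b)
```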