Mentions légales du service

Skip to content

Implement framework-specific OptiModule subclasses.

ANDREY Paul requested to merge hacky-optim into develop

This MR implements a couple of framework-specific OptiModule subclasses, that provide with the ability to wrap optimizers from TensorFlow or Torch (respectively) into plug-in modules and therefore integrate them into a declearn optimization pipeline.

Rationale

While the aim of the declearn Optimizer API is to provide with framework-agnostic combinable optimization algorithms, and is designed to be extendable by end-users that may require to use some algorithms that are not directly part of the package, there is a point in enabling end-users to use framework-specific, third-party-provided algorithms: whether to limit the amount of declearn-specific syntax they need to assimilate (even though some things are still required to be set through the declearn Optimizer), avoid deterring out users that do not want to go as far as implementing their own plug-ins for yet-unavailable algorithms (such as Adan, Lamb, etc.), or enable quick-testing that an algorithm is of interest prior to considering integrating is into declearn.

It is also relatively easy to hack around the OptiModule API to wrap up framework-specific tools (at least for TensorFlow and PyTorch), but easy as well to get confused in doing so for one that does not have an intricate knowledge of both the declearn and the targeted framework's APIs. Therefore, providing with a documented version of such a hack in declearn can be seen as a way to provide end-users with both a working solution, a clear documentation of its limitations, and opinions as to why it might not be the proper way to go around things.

Implementation

I implemented two relatively-similar OptiModule subclasses, respectively TensorflowOptiModule and TorchOptiModule, under the dedicated declearn.model.<framework> manual-import-only modules (more on that choice in the Discussion section). Due to differences in the two frameworks' APIs, the Torch module expects to be provided with a torch optimizer class and its non-default kwargs, while the Tensorflow one may be provided with an instance directly (or a configuration dict). When using default hyper-parameters for the wrapped framework-specific algorithm, the specification may be as simple as providing a single string (naming the target class and/or providing with its import string).

Both modules use the same kind of hack:

  • force the wrapped optimizer to have a constant base learning rate of 1.0
  • set up artificial, zero-valued framework-specific objects that mimic the model's parameters
  • forcefully bind the input gradients with these artificial parameters
  • have the wrapped optimizer perform its step (that modifies the parameters in place)
  • collect and return the resulting values, which are the desired updates (up to their sign)

Finally, I wrote some dedicated unit tests for both classes, that are placed in dedicated test files. Note that I did not go so far as optimizing the redundancies between both files, nor implemented checks that the framework-specific submodules are available (which could be a nice thing to add in the near future, enabling end-users to run tests in environments where optional dependencies are not (all) installed).

Discussion points

Import paths of the modules

Logically, one would expect all OptiModule subclasses to be implemented under the declearn.optimizer.modules submodule. However, I went for implementing the hacky modules under declearn.model.tensorflow and declearn.model.torch respectively, for a couple of (arguable reasons):

  • This enables importing all framework-specific tools together at once (and, reciprocally, avoiding to import optional third-party dependencies when it is not required by the end-user).
  • This might also help underlining that these are not standard optimizer modules, but hacky framework-specific tools that should be used sparingly.

On the other hand, it could be preferable to move the modules to declearn.optimizer.modules - possibly not importing them by default, but triggering their import as part of declearn.model.<framework>.__init__.py. As I write it I am not entirely convinced that this is a good idea, but perhaps it could be considered in a broader discussion about where to place and how to import framework-specific tools (e.g. adding some utils to check their availability and import them: something like declearn.utils.import_framework()?).

Imports' validation in Torch

For the Torch module, I implemented some (optional, but enabled by default and forced in de-serialization context) user-prompt mechanisms to validate or refuse to run some import statements. This is not optimal, but I think it would be wrong to enable server-emitted instructions to trigger arbitrary import statements without notifying client-side users and expecting their validation or denial.

Any external thoughts on this issue, the implemented solution's practicality and efficiency, and possible alternatives are most welcome.

Edit: this was updated after discussion to bypass user validation when importing from the 'torch' module itself, which should cover most cases. We could consider constituting an allow-list of trusted packages in the future based on users' feedback.

Bypassing the declearn Optimizer API

An alternative to the current implementation could be to go as far as enabling end-users to replace the declearn Optimizer with a framework-specific one (including the use of framework-specific learning rate schedulers, weight decay implementations, etc.). To me this would go against the philosophy of our framework, and be hurtful as to FL-specific algorithms (such as FedProx, Scaffold, etc.) and DP features, hence my attempt at wrapping up third-party optimizers as plug-ins that are expected to preserve combinability with other declearn-provided tools.

Enabling cross-framework optimizer use

The current implementation enforces a strict framework-specificity of the hacky plug-ins. However, if we go for making conversion to and from numpy a requirement of the Vector API (an idea which has been under discussion for a while now), we could easily use those features to enable using torch or tensorflow optimizers to process input gradients from any framework. I do not think that we should do or discuss this as a priority, but I write it down for future reference. If we went for this, I would be in favor or refactoring the two plug-ins' code, probably using some intermediate abstraction layer.

The reset method

I implemented a reset method for both modules, that enables changing the expected input specifications and resetting any internal state variables. This might be useless, or on the opposite, be considered as a possible extension of the current OptiModule API (possibly together with the addition of get_state and set_state methods).

Edited by ANDREY Paul

Merge request reports