Find a safe way to share a Torch Module
Description of the issue
The current implementation of the (de)serialization of a Torch model (i.e. a `torch.nn.Module`) uses the `torch.load` / `torch.save` functions, which rely on pickle and are therefore inherently unsafe (see the torch docs). However, no satisfying alternative has been found so far to avoid using these functions. As a result, a warning is emitted each time `declearn.model.torch.TorchModel.get_config` is called.
The key issue is that `torch.load` / `torch.save` rely on pickle. A secondary issue is that they save both the model's architecture and its weights, which is heavier than required and redundant with the `get_weights` / `set_weights` API.
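
For reference, here is a minimal sketch of the pattern at stake (the `TinyModel` class is illustrative, not declearn code): the whole Module gets pickled, although only the architecture actually needs sharing, since weights already transit via `get_weights` / `set_weights`.

```python
import io

import torch


class TinyModel(torch.nn.Module):
    """Hypothetical example module; not part of declearn."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


buffer = io.BytesIO()
# torch.save pickles the whole Module: architecture *and* weights.
torch.save(TinyModel(), buffer)
buffer.seek(0)
# torch.load unpickles it; a crafted payload may run arbitrary code.
# (weights_only=False is the historical behaviour at stake here.)
model = torch.load(buffer, weights_only=False)
```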
Discarded solutions
This section documents potential solutions that were investigated but found unsatisfactory.
- Safetensors:
  - @nbigaud suggested that safetensors might be a solution.
  - However, this library only tackles saving Tensors (i.e. weights or data), not Modules. (And it does so in quite a similar way to what declearn already does.) See the first sketch after this list.
- TorchScript:
  - Torch provides TorchScript, which enables translating Modules into a cross-platform format.
  - However, `torch.jit.save` still produces a zip file containing python code and pickle files, exposing the same issues as `torch.load` (see the second sketch after this list).
- ONNX:
  - Torch provides ONNX conversion tools, which enable exporting a Module to an ONNX model specification that may be saved to a text or JSON file.
  - This could be a nice way to save and share the model's architecture safely. Additionally, custom operators may be registered, which could be used to export user-defined layers when needed.
  - However, Torch does not enable reloading a model from ONNX into a Module. Hence this could be used to translate a Torch Module into another framework (e.g. TensorFlow), but not to send a Torch model to Torch-using clients (see the third sketch after this list).
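
For illustration, a minimal safetensors sketch (assuming the `safetensors` package is installed; the file name is arbitrary), showing that it only covers named tensors and thus mirrors the existing `get_weights` / `set_weights` API rather than replacing Module serialization:

```python
import torch
from safetensors.torch import load_file, save_file

model = torch.nn.Linear(4, 2)
# Only the named tensors are saved; the architecture is not serialized.
save_file(model.state_dict(), "weights.safetensors")
# Reloading requires an already-built Module with matching architecture.
state = load_file("weights.safetensors")
model.load_state_dict(state)
```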
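A TorchScript sketch of the same flavour (file name arbitrary); listing the archive's contents shows the embedded code and pickle files mentioned above:

```python
import zipfile

import torch

scripted = torch.jit.script(torch.nn.Linear(4, 2))
torch.jit.save(scripted, "model.torchscript")
# The saved file is a zip archive embedding code and .pkl entries.
print(zipfile.ZipFile("model.torchscript").namelist())
```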
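Finally, an ONNX export sketch (dummy input shape and file name arbitrary): exporting a Module is straightforward, but torch offers no converse ONNX-to-Module import, hence the limitation noted above:

```python
import torch

model = torch.nn.Linear(4, 2)
# Export the Module to an ONNX graph, traced on a dummy input.
torch.onnx.export(model, torch.zeros(1, 4), "model.onnx")
# The file may be parsed and run by other runtimes (onnxruntime, etc.),
# but there is no torch API to rebuild a torch.nn.Module from it.
```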