Improve GPU support for TensorFlow and Torch

ANDREY Paul requested to merge gpu-support into develop

This MR aims to add proper support for GPU acceleration of tensor computations.

It extends the current APIs so that the kind of device (CPU or GPU) backing computations can be selected explicitly when using a compatible framework, such as TensorFlow or Torch (or Jax/Haiku in the future).

At the moment:

  • Since Torch operates on CPU by default, no GPU acceleration is available in Torch with declearn.
  • Since TensorFlow automatically places computations on a GPU when one is available, GPU acceleration is most likely already used in TensorFlow with declearn, but this behaviour has not been properly tested, documented, or modularized. The snippet below illustrates both frameworks' default behaviours.
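
For reference, a minimal illustration of those defaults (not declearn code; it assumes torch and tensorflow are installed, with CUDA available where noted):

```python
import tensorflow as tf
import torch

# Torch: tensors and modules live on the CPU unless a device is requested.
cpu_tensor = torch.zeros(4)
print(cpu_tensor.device)  # cpu
if torch.cuda.is_available():
    gpu_tensor = torch.zeros(4, device="cuda")  # explicit opt-in is required
    print(gpu_tensor.device)  # cuda:0

# TensorFlow: eager operations are placed on a GPU automatically when visible.
print(tf.config.list_physical_devices("GPU"))
tensor = tf.zeros(4)
print(tensor.device)  # ends in 'GPU:0' when a GPU is available, 'CPU:0' otherwise
```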

This MR will:

  • Add a dedicated method to the Model API to select the kind of device backing computations (a rough sketch is given after this list).
  • Revise the backend of TorchModel and TensorflowModel to properly manage the kind of device being used.
  • Possibly extend the Vector API and/or revise the backend of TorchVector and TensorflowVector to ensure computations happening within optimizer plug-ins, aggregators, etc. preserve proper device placement.
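
As a purely illustrative sketch of what such a method could look like, assuming a Torch backend; the method name, signature and class layout are hypothetical and do not prescribe the final declearn API:

```python
from abc import ABC, abstractmethod

import torch


class Model(ABC):
    """Simplified stand-in for the declearn Model API."""

    @abstractmethod
    def set_device(self, gpu: bool = False, idx: int = 0) -> None:
        """Select the kind of device backing computations.

        gpu : whether to run computations on a GPU (when one is available).
        idx : index of the GPU to use when several are available.
        """


class TorchModel(Model):
    """Torch-backed model wrapper (hypothetical excerpt)."""

    def __init__(self, module: torch.nn.Module) -> None:
        self._module = module

    def set_device(self, gpu: bool = False, idx: int = 0) -> None:
        use_gpu = gpu and torch.cuda.is_available()
        device = torch.device(f"cuda:{idx}" if use_gpu else "cpu")
        self._module.to(device)  # move parameters and buffers in place
```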

Tasklist:

  • Add GPU support for Torch
    • Experiment with Torch to work out the device-management tools and requirements.
    • Implement device-management in TorchModel and TorchVector.
    • Ensure device-placement errors are avoided, or caught and fixed, when processing gradients and updates.
    • Ensure that a TorchModel placed on GPU remains there and that all computations run on that GPU.
    • Ensure device placement is optimal when processing gradients and updates.
  • Add GPU support for TensorFlow
    • Experiment with TensorFlow to work out the device-management tools and requirements.
    • Implement device-management in TensorflowModel and TensorflowVector.
    • Ensure device-placement errors are avoided, or caught and fixed, when processing gradients and updates.
  • Abstract the GPU-support API changes, and add warnings in SklearnSGDModel about the lack of GPU support.
  • Implement mechanisms for clients to specify the (kind of) device they want to use for computations (a rough sketch follows this tasklist).
  • Enhance the test suite to take device-placement into account and/or write GPU-specific tests.
  • Document GPU support and device policy management as part of the README.
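
To make the device-policy and placement-preservation tasks above more concrete, here is a minimal sketch assuming a Torch backend; the DevicePolicy dataclass and place_on_device helper are hypothetical names for illustration only, not actual declearn interfaces:

```python
from dataclasses import dataclass
from typing import Dict, Optional

import torch


@dataclass
class DevicePolicy:
    """Client-side specification of the kind of device to use (hypothetical)."""

    gpu: bool = False          # whether to use a GPU when one is available
    idx: Optional[int] = None  # index of the GPU to use, if several are present

    def to_torch_device(self) -> torch.device:
        if self.gpu and torch.cuda.is_available():
            return torch.device(f"cuda:{self.idx or 0}")
        return torch.device("cpu")


def place_on_device(
    coefs: Dict[str, torch.Tensor],
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """Move named tensors (e.g. gradients or updates) onto a target device.

    Tensors already on `device` are returned as-is, avoiding needless copies,
    which is the kind of check needed to keep placement optimal.
    """
    return {name: tensor.to(device) for name, tensor in coefs.items()}


# Usage example: gradients coming back from (possibly CPU-backed) optimizer
# plug-ins are moved to the model's device before updates are applied.
policy = DevicePolicy(gpu=True)
device = policy.to_torch_device()
gradients = {"weight": torch.randn(3, 3), "bias": torch.randn(3)}
gradients = place_on_device(gradients, device)
```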

Closes #11 - see that issue for notes on how the frameworks behave and on implementation choices.

