diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index 189986b730d1ee3e75ccd723bc49697598371879..c962e8b6a6645feb0b564a4779db95e107bcfb28 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -114,6 +114,7 @@ run_declearn_tests() { commands=( "run_unit_tests $@" "run_integration_tests $@" + "run_torch13_tests $@" ) run_commands "declearn test suite" "${commands[@]}" status=$? @@ -152,6 +153,33 @@ run_integration_tests() { } +run_torch13_tests() { + : ' + Verbosely run Torch 1.13-specific unit tests. + + Install Torch 1.13 at the start of this function, and re-install + torch >=2.0 at the end of it, together with its co-dependencies. + ' + echo "Re-installing torch 1.13 and its co-dependencies." + pip install .[torch1] + if [[ $? -eq 0 ]]; then + echo "Running unit tests for torch 1.13." + command="pytest $@ + --cov --cov-append --cov-report= + test/model/test_torch.py + " + run_command $command + status=$? + else + echo "\e[31mSkipping tests as installation failed.\e[0m" + status=1 + fi + echo "Re-installing torch 2.X and its co-dependencies." + pip install .[torch2] + return $status +} + + main() { if [[ $# -eq 0 ]]; then echo "Missing required positional argument."