Pjobic/refactoring/loss plugin
New plugin to dynamically use different loss functions.
How to use
in .yml
, just put something like that:
loss_config: # concerns __init__
loss_name: "default_loss_funcs" # the one already in use
some_other_init_parameters: 0.0
loss_params: # concerns forward
some_params_for_the_forward_pass: "blabla"
or
# nothing, By default, it uses `default_loss_funcs` like it is already doing in the master branch.
or
loss_config:
loss_name: "scenario_wise_loss" # it will use it automatically
Need to TODO
- dynamically have these
loss_params
like it would be in the code. e.g.:
loss_params:
learned_params: "self.learned_params" # automatically put self.learned_params from ModelTrainer.attr to loss_params dict
- get rid of
errors_mean
anderrors_grouped
, like adding a pluging for visualization/logging bcs it is specific to someone's task and should not be handled in training loop. (as you can see myscenario_wise_loss
do not compute these errors and should not)