How to pass **kwargs arguments to networks?
Issue
I have a network with a lot of parameters:
The relevant parts of the MixAttSPIDNA code:
```python
import torch.nn as nn


class MixAttSPIDNA(nn.Module):
    def __init__(self, n_outputs, n_blocks, n_features, device, **kwargs):
        super().__init__()
        # Every block receives the same extra keyword arguments
        self.blocks = nn.ModuleList([MixAttSPIDNABlock(n_features, n_outputs, **kwargs)
                                     for _ in range(n_blocks)])


class MixAttSPIDNABlock(nn.Module):
    def __init__(self, n_features, n_outputs, **kwargs):
        super().__init__()
        # Extra keyword arguments are forwarded unchanged to AttHub
        self.AttHub = AttHub(n_features, n_outputs, **kwargs)


class AttHub(nn.Module):
    def __init__(self, n_features,
                 third_n_features_out,
                 first_n_features_in=50,
                 first_n_heads_in=1,
                 first_n_hubs=10,
                 first_n_features_out=50,
                 first_n_heads_out=5,
                 second_n_features_in=25,
                 second_n_heads_in=1,
                 second_n_hubs=7,
                 second_n_features_out=25,
                 third_n_features_in=25,
                 third_n_heads_in=1):
        super().__init__()
        # ... layer construction omitted from this excerpt
```
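In plain Python, this forwarding pattern already lets a caller override only the options they care about while every other option keeps its default. The sketch below illustrates that with hypothetical, torch-free stand-ins (Net, Block and Hub are illustrative names, not DNADNA classes) and only two of the hub options:

```python
class Hub:
    # Hypothetical stand-in for AttHub: every tuning knob has a default value.
    def __init__(self, n_features, n_outputs, first_n_hubs=10, second_n_hubs=7):
        self.n_features = n_features
        self.n_outputs = n_outputs
        self.first_n_hubs = first_n_hubs
        self.second_n_hubs = second_n_hubs


class Block:
    # Hypothetical stand-in for MixAttSPIDNABlock: forwards extra options unchanged.
    def __init__(self, n_features, n_outputs, **kwargs):
        self.hub = Hub(n_features, n_outputs, **kwargs)


class Net:
    # Hypothetical stand-in for MixAttSPIDNA: builds n_blocks blocks with the same options.
    def __init__(self, n_outputs, n_blocks, n_features, **kwargs):
        self.blocks = [Block(n_features, n_outputs, **kwargs) for _ in range(n_blocks)]


net = Net(n_outputs=3, n_blocks=2, n_features=50, first_n_hubs=4)
print(net.blocks[0].hub.first_n_hubs)   # 4 (overridden by the caller)
print(net.blocks[0].hub.second_n_hubs)  # 7 (default kept)
```

A misspelled option such as first_n_hub=4 raises TypeError when Hub is constructed, so typos are not silently ignored.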
I don't want to have to specify all 10-20 of these parameters every time, especially since someone else using this network may not know what they are.
However, when I do it this way, passing the parameters through **kwargs does not work, and training fails with KeyError: 'kwargs':
```
$ dnadna train [training_config.yml]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/pjobic/Work/DNADNA/master/run.py", line 135, in train
    model_trainer.prepare(error_log=True, save_best=True, save_checkpoints=True)
  File "/home/pjobic/Work/DNADNA/master/dnadna/training.py", line 319, in prepare
    net_params = utils.dict_slice(net_params, *sig.parameters.keys())
  File "/home/pjobic/Work/DNADNA/master/dnadna/utils/__init__.py", line 176, in dict_slice
    s[k] = d[k]
  File "/home/pjobic/Work/DNADNA/master/dnadna/utils/config.py", line 782, in __getitem__
    value = super().__getitem__(key)
  File "/home/pjobic/Work/DNADNA/master/dnadna/utils/config.py", line 305, in __getitem__
    raise KeyError(key)
KeyError: 'kwargs'
```
The error comes from ModelTrainer.prepare(), near line 300 of dnadna/training.py:
```python
sig = inspect.signature(net_cls)
net_params = utils.dict_slice(net_params, *sig.parameters.keys())
```
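The problem is that inspect.signature() lists **kwargs as an ordinary parameter named "kwargs", so dict_slice() tries to read a "kwargs" key from the network config, and that key does not exist.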
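The failure can be reproduced without DNADNA. In this minimal sketch, Net is a hypothetical class with the same constructor shape as MixAttSPIDNA, and dict_slice is a hypothetical stand-in that assumes the real utils.dict_slice copies each requested key with d[k], as the traceback suggests:

```python
import inspect


class Net:
    # Hypothetical network with the same constructor shape as MixAttSPIDNA.
    def __init__(self, n_outputs, n_blocks, n_features, device, **kwargs):
        pass


def dict_slice(d, *keys):
    # Hypothetical stand-in for dnadna.utils.dict_slice: copy only the requested keys.
    return {k: d[k] for k in keys}


sig = inspect.signature(Net)  # signature of __init__, without self
print(list(sig.parameters))
# ['n_outputs', 'n_blocks', 'n_features', 'device', 'kwargs']
print(sig.parameters["kwargs"].kind.name)  # VAR_KEYWORD

net_params = {"n_outputs": 3, "n_blocks": 2, "n_features": 50,
              "device": "cpu", "first_n_hubs": 4}
try:
    dict_slice(net_params, *sig.parameters.keys())
except KeyError as exc:
    print("KeyError:", exc)  # KeyError: 'kwargs'
```

The parameter's kind (inspect.Parameter.VAR_KEYWORD) is what distinguishes **kwargs from a regular argument, so any fix has to check the kind rather than only the name.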
Question
Is there a way to solve this? @embray
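One possible direction, sketched under assumptions: replace the name-based slice with a hypothetical helper, select_constructor_params (not part of DNADNA), that copies the named parameters as before but, when the network class also accepts **kwargs, forwards the remaining config entries too. It assumes the network config behaves like a dict supporting `in` and .items():

```python
import inspect


def select_constructor_params(cls, params):
    """Hypothetical helper: keep only the config entries that cls can accept.

    Named parameters are copied by name when present. If cls also takes
    **kwargs, every remaining entry is forwarded as well, since cls may pass
    it on to its sub-modules. *args-style parameters are never filled.
    """
    sig = inspect.signature(cls)
    selected = {}
    accepts_var_keyword = False
    for name, param in sig.parameters.items():
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            accepts_var_keyword = True
        elif param.kind is inspect.Parameter.VAR_POSITIONAL:
            continue
        elif name in params:
            selected[name] = params[name]
    if accepts_var_keyword:
        # Forward everything else; the innermost constructor rejects unknown keys.
        for name, value in params.items():
            selected.setdefault(name, value)
    return selected


class Net:
    # Hypothetical network with the same constructor shape as MixAttSPIDNA.
    def __init__(self, n_outputs, n_blocks, n_features, device, **kwargs):
        self.kwargs = kwargs


config = {"n_outputs": 3, "n_blocks": 2, "n_features": 50,
          "device": "cpu", "first_n_hubs": 4}
net = Net(**select_constructor_params(Net, config))
print(net.kwargs)  # {'first_n_hubs': 4}
```

The trade-off is that, for classes taking **kwargs, any unrelated key left in the network config is also forwarded and ends up as a TypeError in the innermost constructor; that catches typos but means the config must only contain network options.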
Edited by JOBIC Pierre