diff --git a/declearn/model/haiku/_model.py b/declearn/model/haiku/_model.py index 7ed66f802f670dd912102cdc33af24f94a5fbd66..c6e4646673f4edf5872ecb3524231f4a4de9a627 100644 --- a/declearn/model/haiku/_model.py +++ b/declearn/model/haiku/_model.py @@ -56,6 +56,7 @@ class HaikuModel(Model): instance to be learned federatively. Notes regarding device management (CPU, GPU, etc.): + * By default, jax places data and operations on GPU whenever one is available. * Our `HaikuModel` instead consults the device-placement policy (via @@ -257,33 +258,56 @@ class HaikuModel(Model): an explicit dict of names or even the index of the parameter leaves stored by our HaikuModel. - Example use : - >>> self.get_named_weights() = {'linear': {'w': None, 'b': None}} - Using a function as input - >>> criterion = lambda layer, name, value: name == 'w' - >>> self.set_trainable_weights(criterion) - >>> self._trainable - [0] - Using a dictionnary or pytree - >>> criterion = {'linear': {'b': None}} - >>> self.set_trainable_weights(criterion) - >>> self._trainable - [1] - - Note : model needs to be initialized - - Arguments - -------- - criterion : Callable or dict(str,dict(str,any)) or list(int) - Criterion to be used to identify trainable params. If Callable, - must be a function taking in the name of the module (e.g. - layer name), the element name (e.g. parameter name) and the - corresponding data and returning a boolean. See - [the haiku doc](https://tinyurl.com/3v28upaz) - for details. If a list of integers, should represent the index of - trainable parameters in the parameter tree leaves. If a dict, - should be formatted as a pytree. + Notes + ----- + - The model needs to be initialized for this method to work. + - The list of model weight names (general, or restricted to trainable + ones) may be consulted using the `get_weight_names` method. 
+ + Usage + ----- + + Let us pretend the model is made of a single linear layer; we want + to freeze its bias, leaving only the kernel weights trainable. + ``` + >>> # Display current names of trainable model weights. + >>> self.get_weight_names(trainable=True) + ["linear/~/w", "linear/~/b"] + ``` + - (A) Using a list of weight names: + ``` + >>> criterion = ["linear/~/w"] + >>> self.set_trainable_weights(criterion) + ``` + - (B) Using a function as input: + ``` + >>> criterion = lambda layer, name, value: name == 'w' + >>> self.set_trainable_weights(criterion) + ``` + - (C) Using a dictionary or pytree: + ``` + >>> criterion = {'linear': {'b': None}} + >>> self.set_trainable_weights(criterion) + ``` + - In all three cases, we can verify the results. + ``` + >>> self.get_weight_names(trainable=True) + ["linear/~/w"] + ``` + Parameters + ---------- + criterion: Callable or dict(str,dict(str,any)) or list(str) + Criterion to be used to identify trainable params. + + - If a list of strings, should represent the names of weights to + keep as trainable (freezing all others). + - If callable, must be a function taking in the name of the module + (e.g. layer name), the element name (e.g. parameter name) and the + corresponding data and returning a boolean. + See [the haiku doc](https://tinyurl.com/3v28upaz) for details. + - If a dict, should be formatted as a pytree, the keys of which + are the nodes/leaves that should remain trainable. """ if not self._initialized: raise ValueError("Model needs to be initialized first")