Verified commit 164ddb3b, authored by ANDREY Paul

Improve 'HaikuModel.set_trainable_weights' docs.

parent dd2eef50
@@ -56,6 +56,7 @@ class HaikuModel(Model):
instance to be learned federatively.
Notes regarding device management (CPU, GPU, etc.):
* By default, jax places data and operations on GPU whenever one
is available.
* Our `HaikuModel` instead consults the device-placement policy (via
@@ -257,33 +258,56 @@ class HaikuModel(Model):
an explicit dict of names or even the index of the parameter leaves
stored by our HaikuModel.
Example use :
>>> self.get_named_weights() = {'linear': {'w': None, 'b': None}}
Using a function as input
>>> criterion = lambda layer, name, value: name == 'w'
>>> self.set_trainable_weights(criterion)
>>> self._trainable
[0]
Using a dictionary or pytree
>>> criterion = {'linear': {'b': None}}
>>> self.set_trainable_weights(criterion)
>>> self._trainable
[1]
Note : model needs to be initialized
Arguments
--------
criterion : Callable or dict(str,dict(str,any)) or list(int)
Criterion to be used to identify trainable params. If Callable,
must be a function taking in the name of the module (e.g.
layer name), the element name (e.g. parameter name) and the
corresponding data and returning a boolean. See
[the haiku doc](https://tinyurl.com/3v28upaz)
for details. If a list of integers, should represent the index of
trainable parameters in the parameter tree leaves. If a dict,
should be formatted as a pytree.
Notes
-----
- The model needs to be initialized for this method to work.
- The list of model weight names (general, or restricted to trainable
ones) may be consulted using the `get_weight_names` method.
Usage
-----
Let us pretend the model is made of a single linear layer; we want
to freeze its bias, leaving only the kernel weights trainable.
```
>>> # Display current names of trainable model weights.
>>> self.get_weight_names(trainable=True)
["linear/~/w", "linear/~/b"]
```
- (A) Using a list of weight names:
```
>>> criterion = ["linear/~/w"]
>>> self.set_trainable_weights(criterion)
```
- (B) Using a function as input:
```
>>> criterion = lambda layer, name, value: name == 'w'
>>> self.set_trainable_weights(criterion)
```
- (C) Using a dictionary or pytree:
```
>>> criterion = {'linear': {'w': None}}
>>> self.set_trainable_weights(criterion)
```
- In all three cases, we can verify the results.
```
>>> self.get_weight_names(trainable=True)
["linear/~/w"]
```
Parameters
----------
criterion: Callable or dict(str,dict(str,any)) or list(str)
Criterion to be used to identify trainable params.
- If a list of strings, should represent the names of weights to
keep as trainable (freezing each and every other one).
- If callable, must be a function taking in the name of the module
(e.g. layer name), the element name (e.g. parameter name) and the
corresponding data and returning a boolean.
See [the haiku doc](https://tinyurl.com/3v28upaz) for details.
- If a dict, should be formatted as a pytree, the keys of which
are the nodes/leaves that should remain trainable.
"""
if not self._initialized:
raise ValueError("Model needs to be initialized first")
......
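The three criterion forms described in the updated docstring can be illustrated with a minimal pure-Python sketch. It uses neither declearn nor haiku: `select_trainable` and `weight_name` are hypothetical helpers, the `<layer>/~/<name>` naming is assumed from the example output above, and the actual `HaikuModel` implementation may resolve criteria differently.

```python
from typing import Any, Callable, Dict, List, Union

# Two-level mapping {layer_name: {param_name: value}}, as in the docstring example.
Params = Dict[str, Dict[str, Any]]
Criterion = Union[
    Callable[[str, str, Any], bool],  # (layer, name, value) -> keep trainable?
    Dict[str, Dict[str, Any]],  # pytree whose keys mark the trainable leaves
    List[str],  # explicit names of the weights to keep trainable
]


def weight_name(layer: str, name: str) -> str:
    """Hypothetical helper: assumes names follow '<layer>/~/<name>'."""
    return f"{layer}/~/{name}"


def select_trainable(params: Params, criterion: Criterion) -> List[str]:
    """Hypothetical helper: list the weight names a criterion keeps trainable."""
    if callable(criterion):
        def keep(layer: str, name: str, value: Any) -> bool:
            return bool(criterion(layer, name, value))
    elif isinstance(criterion, dict):
        def keep(layer: str, name: str, value: Any) -> bool:
            return name in criterion.get(layer, {})
    elif isinstance(criterion, list):
        def keep(layer: str, name: str, value: Any) -> bool:
            return weight_name(layer, name) in criterion
    else:
        raise TypeError(f"Unsupported criterion type: {type(criterion).__name__}")
    # Every weight that fails the criterion is treated as frozen.
    return [
        weight_name(layer, name)
        for layer, leaves in params.items()
        for name, value in leaves.items()
        if keep(layer, name, value)
    ]


# A single linear layer; all three forms keep only the kernel weights trainable.
params = {"linear": {"w": None, "b": None}}
by_names = select_trainable(params, ["linear/~/w"])
by_function = select_trainable(params, lambda layer, name, value: name == "w")
by_pytree = select_trainable(params, {"linear": {"w": None}})
```

In this sketch the three forms are interchangeable ways of stating the same selection: anything not matched by the criterion is frozen, which mirrors the "freezing each and every other one" wording in the parameter description.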