Mentions légales du service

Skip to content
Snippets Groups Projects

Fix `TensorflowModel` gradients' labeling

Merged ANDREY Paul requested to merge hotfix-tf-gradient-names into main
1 file
+ 3
0
Compare changes
  • Side-by-side
  • Inline
@@ -91,6 +91,9 @@ class ModelTestSuite:
w_end = model.get_weights()
assert w_srt == w_end
assert isinstance(grads, test_case.vector_cls)
# Check that gradients and weights share the same labeling.
# Note: some weights may not have gradients (if they are frozen).
assert set(grads.coefs).issubset(w_end.coefs)
def test_compute_batch_gradients_np(
self,
Loading