HaikuModel Unit testing

Commit 6227db0a authored by BIGAUD Nathan (parent 10399088)
Merge request !32: Add support for Jax / Haiku
# coding: utf-8
# Copyright 2023 Inria (Institut National de Recherche en Informatique
# et Automatique)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for TensorflowModel."""
import sys
from functools import partial
from typing import Any, List, Literal
import numpy as np
import pytest
try:
import chex
import haiku as hk
import jax
import jax.numpy as jnp
from jax.config import config as jaxconfig
from jaxtyping import Array
except ModuleNotFoundError:
pytest.skip("jax and/or haiku are unavailable", allow_module_level=True)
from declearn.model.haiku import HaikuModel, JaxNumpyVector
from declearn.typing import Batch
from declearn.utils import set_device_policy
# dirty trick to import from `model_testing.py`;
# pylint: disable=wrong-import-order, wrong-import-position
sys.path.append(".")
from model_testing import ModelTestCase, ModelTestSuite
# Overriding float32 default in jax
jaxconfig.update("jax_enable_x64", True)
# TODO: implement an LSTM with unroll_net; see https://github.com/deepmind/dm-haiku/blob/main/examples/haiku_lstms.ipynb
# TODO: add chex variants
def cnn_fn(x: Array) -> jnp.ndarray:
"""Simple CNN in a purely functional form"""
model = hk.Sequential(
[
hk.Conv2D(output_channels=32, kernel_shape=(7, 7), padding="SAME"),
jax.nn.relu,
hk.MaxPool(window_shape=(8, 8), strides=(1, 1), padding="SAME"),
hk.Conv2D(output_channels=16, kernel_shape=(5, 5), padding="SAME"),
jax.nn.relu,
hk.AvgPool(window_shape=(8, 8), strides=(1, 1), padding="SAME"),
hk.Flatten(),
hk.Linear(1),
]
)
return model(x)
def mlp_fn(x: Array) -> jnp.ndarray:
"""Simple MLP in a purely functional form"""
model = hk.nets.MLP([32, 16, 1])
return model(x)
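
# Hedged sketch addressing the LSTM TODO above (an illustration, not used by
# the tests below): a purely functional recurrent model built from hk.LSTM
# and hk.dynamic_unroll, in the spirit of dm-haiku's haiku_lstms example.
# The hidden size and the time-major [time, batch, features] input layout
# are assumptions made for this example.
def lstm_fn(x: Array) -> jnp.ndarray:
    """Simple LSTM in a purely functional form (time-major inputs)."""
    core = hk.LSTM(hidden_size=32)
    state = core.initial_state(batch_size=x.shape[1])
    outputs, _ = hk.dynamic_unroll(core, x, state)
    return hk.Linear(1)(outputs[-1])  # predict from the last time step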
def loss_fn(y_pred: Array, y_true: Array) -> Array:
"""Per-sample binary cross entropy"""
y_pred = jax.nn.sigmoid(y_pred)
y_pred = jnp.squeeze(y_pred)
log_p, log_not_p = jnp.log(y_pred), jnp.log(1.0 - y_pred)
return -y_true * log_p - (1.0 - y_true) * log_not_p
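
# Hedged aside (illustrative only, not used by the tests): the same per-sample
# binary cross entropy computed directly from logits via jax.nn.log_sigmoid,
# which avoids the log(0) risk when the sigmoid saturates. The name
# `stable_loss_fn` is an assumption made for this example.
def stable_loss_fn(y_pred: Array, y_true: Array) -> Array:
    """Per-sample binary cross entropy computed from raw logits."""
    y_pred = jnp.squeeze(y_pred)
    log_p = jax.nn.log_sigmoid(y_pred)       # log(sigmoid(x))
    log_not_p = jax.nn.log_sigmoid(-y_pred)  # log(1 - sigmoid(x))
    return -y_true * log_p - (1.0 - y_true) * log_not_p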
class HaikuTestCase(ModelTestCase):
"""Tensorflow Keras test-case-provider fixture.
Implemented architectures are:
* "MLP":
- input: 64-dimensional features vectors
- stack: 32-neurons fully-connected layer with ReLU
16-neurons fully-connected layer with ReLU
1 output neuron with sigmoid activation
* "CNN":
- input: 64x64 image with 3 channels (normalized values)
- stack: 32 7x7 conv. filters, then 8x8 max pooling
16 5x5 conv. filters, then 8x8 avg pooling
1 output neuron with sigmoid activation
"""
vector_cls = JaxNumpyVector
tensor_cls = Array
def __init__(
self,
kind: Literal["MLP", "CNN"],
device: Literal["cpu", "gpu"],
) -> None:
"""Specify the desired model architecture."""
if kind not in ("MLP", "CNN"):
raise ValueError(f"Invalid test architecture: '{kind}'.")
if device not in ("cpu", "gpu"):
raise ValueError(f"Invalid device choice for test: '{device}'.")
self.kind = kind
self.device = device
set_device_policy(gpu=(device == "gpu"), idx=0)
@staticmethod
def to_numpy(
tensor: Any,
) -> np.ndarray:
"""Convert an input jax array to a numpy array."""
assert isinstance(tensor, Array)
return np.asarray(tensor)
@property
def dataset(
self,
) -> List[Batch]:
"""Suited toy binary-classification dataset."""
rng = np.random.default_rng(seed=0)
if self.kind.startswith("MLP"):
inputs = rng.normal(size=(2, 32, 64))
elif self.kind == "CNN":
inputs = rng.normal(size=(2, 32, 64, 64, 3))
labels = rng.choice([0, 1], size=(2, 32)).astype(float)
inputs = jnp.asarray(inputs)
labels = jnp.asarray(labels)
batches = list(zip(inputs, labels, [None, None]))
return batches
@property
def model(self) -> HaikuModel:
"""Suited toy binary-classification haiku models."""
if self.kind == "CNN":
shape = [64, 64, 3]
model_fn = cnn_fn
elif self.kind == "MLP":
shape = [64]
model_fn = mlp_fn
model = HaikuModel(model_fn, loss_fn)
model.initialize({"features_shape": shape, "data_type": "float64"})
return model
def assert_correct_device(
self,
vector: JaxNumpyVector,
) -> None:
"""Raise if a vector is backed on the wrong type of device."""
name = f"{self.device}:0"
assert all(
f"{arr.device().platform}:{arr.device().id}" == name
for arr in vector.coefs.values()
)
@pytest.fixture(name="test_case")
def fixture_test_case(
kind: Literal["MLP", "CNN"],
device: Literal["cpu", "gpu"],
) -> HaikuTestCase:
"""Fixture to access a TensorflowTestCase."""
return HaikuTestCase(kind, device)
DEVICES = ["cpu"]
if "gpu" in [d.platform for d in jax.devices()]:
DEVICES.append("gpu")
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kind", ["MLP", "CNN"])
class TestHaikuModel(ModelTestSuite):
"""Unit tests for declearn.model.haiku.HaikuModel."""
def test_proper_model_placement(
self,
test_case: HaikuTestCase,
) -> None:
"""Check that at instantiation, model weights are properly placed."""
model = test_case.model
policy = model.device_policy
assert policy.gpu == (test_case.device == "gpu")
assert policy.idx == 0
params = getattr(model, "_params_leaves")
device = f"{test_case.device}:0"
for arr in params:
assert f"{arr.device().platform}:{arr.device().id}" == device