# NOTE: the following lines are GitLab web-page residue captured with this file.
# Mentions légales du service
#
# Skip to content
# Snippets Groups Projects
# frangi_py_test.py 9.98 KiB
# Newer Older
# NADAL Morgane's avatar
# NADAL Morgane committed
# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2019)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL  license under French law and
# abiding by the rules of distribution of free software.  You can  use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and  rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty  and the software's author,  the holder of the
# economic rights,  and the successive licensors  have only  limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using,  modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean  that it is complicated to manipulate,  and  that  also
# therefore means  that it is reserved for developers  and  experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and,  more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

import brick.processing.frangi3 as fg_

import itertools as it_
import time as tm_
from typing import Sequence, Tuple

import matplotlib.pyplot as pl_
import numpy as np_
import scipy.ndimage as im_
import skimage.draw as dw_
import skimage.measure as ms_
import skimage.morphology as mp_
from mpl_toolkits.mplot3d import Axes3D
import skimage.io as io_


array_t = np_.ndarray  # Short alias of the NumPy array type, used in annotations


# Free parameters: in __main__, the Frangi enhancement is run for each
# (parallel, method, differentiation) combination not listed as useless there.
PARALLEL_CHOICES: Tuple[bool, ...] = (False, True)
METHOD_CHOICES: Tuple[str, ...] = ("python", "c")
DIFFERENTIATION_CHOICES: Tuple[str, ...] = ("indirect", "direct")

# Synthetic volume geometry (only used by the commented-out spiral code in __main__)
SHAPE: Tuple[int, int, int] = (200, 200, 200)
MARGIN: int = 4

# Scale sampling passed to fg_.FrangiEnhancement (scale_range, scale_step)
SCALE_RANGE: Tuple[float, float] = (0.1, 6.1)
SCALE_STEP: float = 1.0

# Spiral parameters (arguments of Spiral, in call order)
N_SAMPLES: int = 500
ANGLE_MAX: float = 50.0
HEIGHT_MAX: float = 2.0
ROTATION_SPEED: float = 0.05
RADIAL_SPEED: float = 0.05


def Spiral(
    n_samples: int,
    angle_max: float,
    height_max: float,
    rotation_speed: float,
    radial_speed: float,
) -> Tuple[array_t, array_t, array_t]:
    """
    Sample a 3-D spiral at n_samples points.

    The polar angle theta goes linearly from 0 to angle_max and the radius is
    rotation_speed * exp(radial_speed * theta) (a logarithmic spiral in x-y).
    The height z goes linearly from 0 to height_max.

    Returns the x, y and z coordinate arrays, each of length n_samples.
    """
    theta = np_.linspace(0.0, angle_max, n_samples)
    radius = rotation_speed * np_.exp(radial_speed * theta)
    heights = np_.linspace(0.0, height_max, n_samples)

    return radius * np_.cos(theta), radius * np_.sin(theta), heights


def ScaledAndShiftedAndPossiblyRounded(
    x_coords: array_t,
    y_coords: array_t,
    z_coords: array_t,
    target_bbox: Sequence[int],
    margin: int = 0,
    rounded: bool = False,
) -> Tuple[array_t, array_t, array_t]:
    """
    Affinely map 3-D curve coordinates into a target bounding box.

    Each axis is scaled and shifted independently so that its coordinate range
    [min, max] lands on [first + margin, last - margin] of the target box.

    target_bbox: either a volume shape (n_x, n_y, n_z), interpreted as the
        inclusive index box (0, n_x - 1, 0, n_y - 1, 0, n_z - 1), or an explicit
        inclusive index box (x_first, x_last, y_first, y_last, z_first, z_last).
    margin: number of indices left free at both ends of each axis.
    rounded: if True, the mapped coordinates are rounded to the nearest integer
        (the returned arrays keep a floating-point dtype).

    An axis along which all coordinates are equal (zero extent) is mapped to its
    lower bound, first + margin, instead of producing NaNs.

    Raises ValueError if target_bbox has neither 3 nor 6 elements.
    """
    n_bounds = len(target_bbox)
    if n_bounds == 3:  # It is a shape
        target_bbox = (
            0,
            target_bbox[0] - 1,
            0,
            target_bbox[1] - 1,
            0,
            target_bbox[2] - 1,
        )
    elif n_bounds != 6:
        raise ValueError(
            f"{n_bounds}: Invalid target bbox length; "
            f"Expected=3 (shape) or 6 (bounding box)"
        )

    all_coords = (x_coords, y_coords, z_coords)
    corner = np_.array(tuple(np_.min(coords) for coords in all_coords))
    extent = np_.array(tuple(np_.max(coords) for coords in all_coords)) - corner
    # Even if dealing with indices, do not do + 1: the last index is included
    target_shape = np_.diff(target_bbox)[0::2] - 2 * margin

    # A zero extent would give an infinite scaling and then 0 * inf = NaN; a zero
    # scaling instead sends the whole (flat) axis onto its shift.
    scaling = np_.divide(
        target_shape,
        extent,
        out=np_.zeros(3, dtype=np_.float64),
        where=extent != 0,
    )
    shift = tuple(tgt_coord + margin for tgt_coord in target_bbox[0::2])

    x_coords = scaling[0] * (x_coords - corner[0]) + shift[0]
    y_coords = scaling[1] * (y_coords - corner[1]) + shift[1]
    z_coords = scaling[2] * (z_coords - corner[2]) + shift[2]

    if rounded:
        x_coords = np_.around(x_coords)
        y_coords = np_.around(y_coords)
        z_coords = np_.around(z_coords)

    return x_coords, y_coords, z_coords


def WriteCurveInVolume(
    x_coords: array_t,
    y_coords: array_t,
    z_coords: array_t,
    volume: array_t,
    target_bbox: Sequence[int] = None,
    thickness: int = 1,
) -> None:
    """
    Draw a polyline into a volume, in place, optionally thickening it.

    Consecutive samples (x_coords[i], y_coords[i], z_coords[i]) and
    (x_coords[i + 1], y_coords[i + 1], z_coords[i + 1]) are joined by a digital
    line segment (skimage.draw.line_nd with integer=True, so coordinates are
    rounded to voxel indices) whose voxels are set to 1 in volume.

    If thickness > 1, the foreground (nonzero voxels) is then binary-dilated with
    a ball of radius thickness // 2. The dilation covers the whole volume when
    target_bbox is None. Otherwise it covers only the inclusive index box
    target_bbox = (x_first, x_last, y_first, y_last, z_first, z_last). The
    dilated region is overwritten with 0/1 values.
    """
    # Build (curr_x, next_x, curr_y, next_y, curr_z, next_z) iterators so that
    # zip() walks over pairs of consecutive samples, i.e. over curve segments.
    coords = []
    for coord in (x_coords, y_coords, z_coords):
        current, following = it_.tee(coord)
        next(following, None)
        coords.extend((current, following))

    for curr_x, next_x, curr_y, next_y, curr_z, next_z in zip(*coords):
        line = dw_.line_nd(
            (curr_x, curr_y, curr_z),
            (next_x, next_y, next_z),
            endpoint=True,
            integer=True,
        )

        volume[line] = 1

    if thickness > 1:
        # The np_.bool alias was removed in NumPy 1.24; builtin bool is the same dtype.
        selem = mp_.ball(thickness // 2, dtype=bool)
        if target_bbox is None:
            region = Ellipsis
        else:
            region = (
                slice(target_bbox[0], target_bbox[1] + 1),
                slice(target_bbox[2], target_bbox[3] + 1),
                slice(target_bbox[4], target_bbox[5] + 1),
            )
        dilated = im_.binary_dilation(volume[region], structure=selem)
        volume[region] = dilated.astype(np_.uint8, copy=False)


def PrepareCurvePlot(
    x_coords: array_t, y_coords: array_t, z_coords: array_t, title: str = None
) -> None:
    """
    Create a new figure showing the 3-D polyline (x_coords, y_coords, z_coords).

    The axes are labelled x, y and z. If title is given, it becomes the axes
    title. The figure is only prepared here; pl_.show is not called.
    """
    axes = pl_.figure().add_subplot(111, projection=Axes3D.name)
    axes.plot(x_coords, y_coords, z_coords)
    for set_label, label in (
        (axes.set_xlabel, "x"),
        (axes.set_ylabel, "y"),
        (axes.set_zlabel, "z"),
    ):
        set_label(label)
    if title is not None:
        axes.set_title(title)


def PrepareIsoSurfacePlot(
    volume: array_t,
    title: str = None,
    level: float = 0.0,
    spacing: Sequence[float] = (0.5, 0.5, 0.5),
) -> None:
    """
    Create a new figure showing an iso-surface of a 3-D volume.

    The surface is extracted with the classic (Lorensen) marching-cubes algorithm.
    It is drawn as a triangulated surface with the "Spectral" colormap. The
    figure is only prepared here; pl_.show is not called.

    volume: 3-D array (marching-cubes input).
    title: optional axes title.
    level: iso-value of the extracted surface (defaults to the previous 0.0).
    spacing: voxel size along each axis, used to scale vertex coordinates
        (defaults to the previous 0.5 on every axis).
    """
    fig = pl_.figure()
    ax = fig.add_subplot(111, projection=Axes3D.name)

    marching_cubes_classic = getattr(ms_, "marching_cubes_classic", None)
    if marching_cubes_classic is None:
        # marching_cubes_classic is gone from recent scikit-image releases; the
        # "lorensen" method of marching_cubes is the same algorithm. It also
        # returns normals and values, which are not needed here.
        verts, faces, *_ = ms_.marching_cubes(
            volume, level=level, spacing=tuple(spacing), method="lorensen"
        )
    else:
        verts, faces = marching_cubes_classic(
            volume, level=level, spacing=tuple(spacing)
        )

    ax.plot_trisurf(verts[:, 0], verts[:, 1], faces, verts[:, 2], cmap="Spectral", lw=1)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    if title is not None:
        ax.set_title(title)


if __name__ == "__main__":
    #
    import sys

    def PrintElapsedTime() -> None:
        """Print the wall-clock time elapsed since start_time, as HHh MMm SSs."""
        elapsed_time = tm_.gmtime(tm_.time() - start_time)
        print(f"\nElapsed Time={tm_.strftime('%Hh %Mm %Ss', elapsed_time)}")

    print(f"STARTED: {tm_.strftime('%a, %b %d %Y @ %H:%M:%S')}")
    start_time = tm_.time()

    print("--- Volume Creation")

    # volume_ = np_.zeros(SHAPE, dtype=np_.uint8)
    # x_local, y_local, z_local = Spiral(
    #     N_SAMPLES, ANGLE_MAX, HEIGHT_MAX, ROTATION_SPEED, RADIAL_SPEED
    # )
    #
    # half_shape = tuple(dim // 2 for dim in SHAPE)
    # thickness_ = 1
    # for x_range in ((0, half_shape[0]), (half_shape[0] + 1, SHAPE[0] - 1)):
    #     for y_range in ((0, half_shape[1]), (half_shape[1] + 1, SHAPE[1] - 1)):
    #         for z_range in ((0, half_shape[2]), (half_shape[2] + 1, SHAPE[2] - 1)):
    #             print(f"    Thickness: {thickness_}")
    #             target_bbox_ = x_range + y_range + z_range
    #             x_global, y_global, z_global = ScaledAndShiftedAndPossiblyRounded(
    #                 x_local,
    #                 y_local,
    #                 z_local,
    #                 target_bbox_,
    #                 margin=MARGIN,
    #                 rounded=True,
    #             )
    #             WriteCurveInVolume(
    #                 x_global,
    #                 y_global,
    #                 z_global,
    #                 volume_,
    #                 target_bbox_,
    #                 thickness=thickness_,
    #             )
    #             thickness_ += 1

    # The image to process can be given as the first command-line argument; the
    # original machine-specific path is kept as the default.
    data_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else 'D:\\MorganeNadal\\PyCharm\\nutrimorph\\data\\DIO_6H_6_1.70bis_2.2_3.tif'
    )
    volume_ = io_.imread(data_path)
    # Keep index 1 of the last axis; assumes the image loads as a 4-D array with
    # channels last — TODO confirm for other input files.
    volume_ = volume_[:, :, :, 1]
    # volume_ = volume_[:, 512:, 512:]

    print("--- Display Preparation")

    PrepareIsoSurfacePlot(volume_, "Original Volume")

    print("--- Done")

    # Parameter combinations that are skipped; commented-out entries are run.
    useless_choices = (
        # (False, "c", "direct"),
        # (False, "c", "indirect"),
        (True, "c", "indirect"),
        (True, "c", "direct"),
        # (True, "python", "direct"),
        # (True, "python", "indirect"),
        # (False, "python", "direct"),
        # (False, "python", "indirect"),
    )
    # Same order as nested loops over parallel, then method, then differentiation
    prm_choices = [
        choice
        for choice in it_.product(
            PARALLEL_CHOICES, METHOD_CHOICES, DIFFERENTIATION_CHOICES
        )
        if choice not in useless_choices
    ]

    for idx, parameters in enumerate(prm_choices, start=1):
        in_parallel, method, differentiation_mode = parameters
        prm_as_str = "/".join(str(elm).capitalize() for elm in parameters)
        prm_as_str = prm_as_str.replace("False", "Sequential").replace(
            "True", "Parallel"
        )

        PrintElapsedTime()

        print(f"--- Frangi Enhancement {idx} of {len(prm_choices)}: {prm_as_str}")

        enhanced, scale_map = fg_.FrangiEnhancement(
            volume_,
            scale_range=SCALE_RANGE,
            scale_step=SCALE_STEP,
            alpha=0.5,
            beta=0.5,
            frangi_c=500.0,
            in_parallel=in_parallel,
            method=method,
            differentiation_mode=differentiation_mode,
        )
        PrintElapsedTime()

        print("--- Display Preparation")

        PrepareIsoSurfacePlot(enhanced, f"Enhanced Volume {prm_as_str}")
        # Maximum projection of the scale map along axis 2
        fig_ = pl_.figure()
        ax_ = fig_.add_subplot(111)
        ax_.matshow(scale_map.max(axis=2))

        print("--- Done")
        PrintElapsedTime()

    PrintElapsedTime()
    print(f"DONE: {tm_.strftime('%a, %b %d %Y @ %H:%M:%S')}")

    pl_.show()