Mentions légales du service

Skip to content
Snippets Groups Projects
plot.py 7.22 KiB
Newer Older
DEBREUVE Eric's avatar
DEBREUVE Eric committed
# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2019), Morgane Nadal (2020)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL  license under French law and
# abiding by the rules of distribution of free software.  You can  use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and  rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty  and the software's author,  the holder of the
# economic rights,  and the successive licensors  have only  limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using,  modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean  that it is complicated to manipulate,  and  that  also
# therefore means  that it is reserved for developers  and  experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and,  more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

from typing import Any, Optional, Sequence, Tuple, Union

import dijkstra_img as dk_
DEBREUVE Eric's avatar
DEBREUVE Eric committed
import matplotlib.pyplot as pl_
import numpy as nmpy
DEBREUVE Eric's avatar
DEBREUVE Eric committed
import skimage.measure as ms_
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

from brick.type.base import array_t, py_array_picker_h
from brick.type.extension import extension_t
from brick.type.soma import soma_t
def MaximumIntensityProjectionZ(
    img: array_t,
    cmap: str = "tab20",
    axis: int = 0,
    block: bool = True,
    output_image_file_name: Optional[str] = None,
) -> None:
    """Show, and optionally save, the maximum intensity projection of an image.

    The projection is the maximum of img along `axis` (axis 0 by default, called
    "Z" in the function name).

    Parameters
    ----------
    img: image array to project.
    cmap: matplotlib colormap name, used both for display and for saving.
    axis: axis along which the maximum is taken.
    block: forwarded to matplotlib.pyplot.show.
    output_image_file_name: if not None, the projection is also written to this
        file with matplotlib.pyplot.imsave.
    """
    #
    xy = nmpy.amax(img, axis=axis)

    pl_.imshow(xy, cmap=cmap)
    pl_.show(block=block)
    if output_image_file_name is not None:
        # imsave colormaps and writes the array itself, not the displayed figure
        pl_.imsave(output_image_file_name, xy, cmap=cmap)
        pl_.close()
        print("Image saved in", output_image_file_name)
DEBREUVE Eric's avatar
DEBREUVE Eric committed


# Matplotlib one-letter color codes; plots pick one with (label or uid) % 6.
colors_c = tuple("grbmcy")
# Step size, in voxels, given to marching cubes: larger values are faster but coarser.
mc_precision_c = 5  # mc=marching cubes


def PlotLMap(
    lmp: array_t,
    axes=None,
    labels: Union[int, Sequence[int], None] = None,
) -> Optional[Any]:
    """Plot, as 3-D meshes, the surface of each requested label of a labeled map.

    Each label is meshed with marching cubes (level 0.5 on the mask lmp == label,
    step size mc_precision_c). The first vertex coordinate is rescaled by the
    depth factor, and the mesh is colored from colors_c according to the label.
    Labels for which no surface can be extracted are reported and skipped.

    Parameters
    ----------
    lmp: labeled map; voxels equal to a label form that label's object.
    axes: 3-D axes to draw into; if None, a new figure and axes are created.
    labels: one label, a sequence of labels, or None for every label from 1 to
        the maximum value of lmp.

    Returns
    -------
    The newly created axes if `axes` was None, otherwise None.
    """
    #
    depth_factor, depth_limit = __DepthFactorAndLimit__(lmp.shape)
    new_axes = axes is None
    if new_axes:
        _, axes = __FigAndAxes__(lmp.shape, depth_limit)

    if isinstance(labels, int):
        labels = (labels,)
    elif labels is None:
        labels = range(1, nmpy.amax(lmp).item() + 1)

    # marching_cubes_lewiner was removed in scikit-image 0.19; its replacement is
    # marching_cubes(..., method="lewiner").
    if hasattr(ms_, "marching_cubes_lewiner"):
        MarchingCubes = ms_.marching_cubes_lewiner
    else:

        def MarchingCubes(volume, **kwargs):
            return ms_.marching_cubes(volume, method="lewiner", **kwargs)

    for label in labels:
        try:
            vertices, faces, *_ = MarchingCubes(
                lmp == label, level=0.5, step_size=mc_precision_c
            )
        except (RuntimeError, ValueError) as exc:
            # ValueError: the level is outside the mask's value range (e.g., the label
            # is absent from lmp); RuntimeError: no surface found at the given level.
            print(f"{PlotLMap.__name__}: label.{label}: {exc}")
            continue

        vertices[:, 0] *= depth_factor
        triangles = vertices[faces]

        mesh = Poly3DCollection(triangles)
        mesh.set_facecolor(colors_c[label % colors_c.__len__()])
        axes.add_collection3d(mesh)

    pl_.tight_layout()

    if new_axes:
        return axes


def PlotConnection(
    connection: py_array_picker_h, soma_uid: int, shape: Sequence[int], axes=None
) -> Optional[Any]:
    """Draw one connection path as a 3-D polyline colored after its soma.

    connection[0] is plotted, rescaled by the depth factor of `shape`, as the
    first coordinate, and connection[1:] as the remaining coordinates. Nothing is
    drawn when connection is None.

    Returns the newly created axes if `axes` was None, otherwise None.
    """
    depth_factor, depth_limit = __DepthFactorAndLimit__(shape)
    created_here = axes is None
    if created_here:
        axes = __FigAndAxes__(shape, depth_limit)[1]

    # TODO This test is put here but could be moved outside this function
    if connection is not None:
        scaled_depth = depth_factor * nmpy.array(connection[0])
        other_coords = connection[1:]
        line_color = colors_c[soma_uid % len(colors_c)]
        axes.plot(scaled_depth, *other_coords, line_color)

    pl_.tight_layout()

    return axes if created_here else None


def PlotExtensions(
    extensions: Union[extension_t, Sequence[extension_t]],
    shape: Sequence[int],
    axes=None,
) -> Optional[Any]:
    """Plot extensions as 3-D polylines joining every pair of their end points.

    For each extension, a cost map is set to 1 on the extension sites and to +inf
    everywhere else. A Dijkstra shortest path is then computed between each pair
    of end points, so each drawn path follows the extension. Paths are colored
    after the soma uid, or after the extension uid when soma_uid is None.

    Parameters
    ----------
    extensions: one extension or a sequence of extensions.
    shape: shape of the volume the extensions live in.
    axes: 3-D axes to draw into; if None, a new figure and axes are created.

    Returns
    -------
    The newly created axes if `axes` was None, otherwise None.
    """
    from itertools import combinations

    depth_factor, depth_limit = __DepthFactorAndLimit__(shape)
    new_axes = axes is None
    if new_axes:
        _, axes = __FigAndAxes__(shape, depth_limit)

    # Allocated once, then refilled for each extension
    costs = nmpy.empty(shape, dtype=nmpy.float32)

    if isinstance(extensions, extension_t):
        extensions = (extensions,)
    for extension in extensions:
        # Restrict paths to the extension: every other voxel is impassable
        costs.fill(nmpy.inf)
        costs[extension.sites] = 1

        if extension.soma_uid is None:
            uid = extension.uid
        else:
            uid = extension.soma_uid

        end_points = extension.end_points_as_array
        # Same (src < tgt) pair order as a double loop over end-point indices
        for src_ep_idx, tgt_ep_idx in combinations(range(end_points.shape[1]), 2):
            src_point = tuple(end_points[:, src_ep_idx].tolist())
            tgt_point = tuple(end_points[:, tgt_ep_idx].tolist())
            sites, _ = dk_.DijkstraShortestPath(
                costs,
                src_point,
                tgt_point,
                should_constrain_steps=False,
            )
            if len(sites) == 0:
                # Empty path: nothing to draw (indexing below would fail)
                continue
            # Transpose the sequence of points into one coordinate sequence per axis
            sites = tuple(zip(*sites))

            # /!\ Redundant plots within ep-to-ep path
            # A tuple cannot be multiplied by a float depth factor: convert to an array
            axes.plot(
                depth_factor * nmpy.array(sites[0]),
                *sites[1:],
                colors_c[uid % colors_c.__len__()],
            )

    pl_.tight_layout()

    if new_axes:
        return axes


def PlotSomaWithExtensions(soma: soma_t, soma_lmp: array_t, axes=None) -> Optional[Any]:
    """Plot a soma together with its extensions and all their connection paths.

    The soma surface is drawn from soma_lmp (label soma.uid). Every non-None
    connection path of the soma, and of each of its extensions, is drawn in the
    soma color, and each extension is drawn with PlotExtensions.

    Returns the newly created axes if `axes` was None, otherwise None.
    """
    shape = soma_lmp.shape
    _, depth_limit = __DepthFactorAndLimit__(shape)
    created_here = axes is None
    if created_here:
        axes = __FigAndAxes__(shape, depth_limit)[1]

    PlotLMap(soma_lmp, labels=soma.uid, axes=axes)

    for path in soma.connection_path.values():
        if path is not None:
            PlotConnection(path, soma.uid, shape, axes=axes)

    for extension in soma.Extensions():
        for path in extension.connection_path.values():
            if path is not None:
                PlotConnection(path, soma.uid, shape, axes=axes)
        PlotExtensions(extension, shape, axes=axes)

    pl_.tight_layout()

    return axes if created_here else None


def __DepthFactorAndLimit__(shape: Sequence[int]) -> Tuple[float, int]:
    """Return the depth scaling factor and the depth-axis plot limit for `shape`.

    The factor is the mean of shape[1] and shape[2] divided by shape[0], capped
    at 1. Because the cap is the int literal 1, the factor is exactly int 1 when
    capped. The limit is the scaled depth, truncated to an int, plus one.
    """
    mean_lateral_extent = 0.5 * (shape[1] + shape[2])
    depth_factor = min(mean_lateral_extent / shape[0], 1)

    return depth_factor, int(depth_factor * shape[0]) + 1


def __FigAndAxes__(shape: Sequence[int], depth_limit: float) -> Tuple[Any, Any]:
    """Create a figure holding one 3-D axes set up for a volume of `shape`.

    The x axis shows depth (limited to depth_limit), y shows rows and z shows
    columns. Each axis label includes the corresponding size, and each axis
    starts at 0.

    Returns the (figure, axes) pair.
    """
    figure = pl_.figure()
    axes = figure.add_subplot(111, projection=Axes3D.name)

    label_setters = (axes.set_xlabel, axes.set_ylabel, axes.set_zlabel)
    label_texts = (f"depth: {shape[0]}", f"row: {shape[1]}", f"col: {shape[2]}")
    for set_label, text in zip(label_setters, label_texts):
        set_label(text)

    limit_setters = (axes.set_xlim3d, axes.set_ylim3d, axes.set_zlim3d)
    upper_bounds = (depth_limit, shape[1], shape[2])
    for set_limits, upper in zip(limit_setters, upper_bounds):
        set_limits(0, upper)

    return figure, axes