Mentions légales du service

Skip to content
Snippets Groups Projects
plot.py 7.23 KiB
Newer Older
DEBREUVE Eric's avatar
DEBREUVE Eric committed
# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2019), Morgane Nadal (2020)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL  license under French law and
# abiding by the rules of distribution of free software.  You can  use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and  rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty  and the software's author,  the holder of the
# economic rights,  and the successive licensors  have only  limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using,  modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean  that it is complicated to manipulate,  and  that  also
# therefore means  that it is reserved for developers  and  experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and,  more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

import brick.processing.dijkstra_1_to_n as dk_
from brick.component.extension import extension_t
from brick.component.soma import soma_t
from brick.general.type import array_t, py_array_picker_h

from typing import Any, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as pl_
import mpl_toolkits.mplot3d as p3_
import numpy as np_
import skimage.measure as ms_


def MaximumIntensityProjectionZ(img: array_t, cmap: str = 'tab20', axis: int = 0, block=True, output_image_file_name: Optional[str] = None) -> None:
    """
    Display, and optionally save, the maximum intensity projection of an image along one axis.

    img: image array to project.
    cmap: matplotlib colormap name, used both for display and for the saved file.
    axis: axis along which the maximum is taken (0 by default; presumably the Z/depth axis of a
        depth-first volume - confirm with callers).
    block: forwarded to matplotlib.pyplot.show.
    output_image_file_name: if not None, the projection is also written to this file with
        matplotlib.pyplot.imsave, the current figure is then closed, and a confirmation is printed.
    """
    projection = np_.amax(img, axis=axis)
    pl_.imshow(projection, cmap=cmap)
    pl_.show(block=block)
    if output_image_file_name is not None:
        # imsave writes the raw projection (no axes/decorations), independently of the shown figure
        pl_.imsave(output_image_file_name, projection, cmap=cmap)
        pl_.close()
        print('Image saved in', output_image_file_name)


# Matplotlib single-letter color codes; objects are colored by cycling through them with a
# label/uid modulo len(colors_c). Order: green, red, blue, magenta, cyan, yellow.
colors_c = tuple("grbmcy")
# step_size given to scikit-image marching cubes (mc = marching cubes)
mc_precision_c = 5


def _LewinerMarchingCubes(volume: array_t, level: float, step_size: int):
    """
    Run scikit-image's Lewiner marching cubes on volume, whatever the scikit-image version.

    marching_cubes_lewiner was deprecated in scikit-image 0.17 and removed in 0.19, where
    marching_cubes(..., method="lewiner") replaces it. The legacy function is preferred when it
    exists so that older installations behave exactly as before.
    Returns the (vertices, faces, normals, values) tuple of scikit-image.
    """
    legacy = getattr(ms_, "marching_cubes_lewiner", None)
    if legacy is not None:
        return legacy(volume, level, step_size=step_size)
    return ms_.marching_cubes(volume, level, step_size=step_size, method="lewiner")


def PlotLMap(
        lmp: array_t, axes=None, labels: Optional[Union[int, Sequence[int]]] = None
) -> Optional[Any]:
    """
    Render labels of a 3-D label map as marching-cubes surface meshes.

    lmp: label map; the voxels equal to a label form that label's object. Axis 0 is rescaled by
        the depth factor computed by __DepthFactorAndLimit__.
    axes: 3-D matplotlib axes to draw on. If None, a new figure and axes are created.
    labels: a single label, a sequence of labels, or None for every label in 1..max(lmp).
    Returns the newly created axes if axes was None, otherwise None.

    Labels whose surface cannot be extracted (e.g. absent from lmp) are reported on stdout and
    skipped instead of aborting the whole plot.
    """
    depth_factor, depth_limit = __DepthFactorAndLimit__(lmp.shape)
    new_axes = axes is None
    if new_axes:
        _, axes = __FigAndAxes__(lmp.shape, depth_limit)

    if isinstance(labels, (int, np_.integer)):
        # A numpy integer (e.g. a uid read from an array) is a single label, not an iterable
        labels = (labels,)
    elif labels is None:
        # int() so that range() also accepts a float-typed label map
        labels = range(1, int(lmp.max()) + 1)

    for label in labels:
        try:
            vertices, faces, _, _ = _LewinerMarchingCubes(lmp == label, 0.5, mc_precision_c)
        except (RuntimeError, ValueError) as exc:
            # Older scikit-image raises RuntimeError, newer ValueError, when no surface exists
            print(f"{PlotLMap.__name__}: label.{label}: {exc}")
            continue

        vertices[:, 0] *= depth_factor
        # One (3 vertices x 3 coordinates) array per triangular face
        triangles = vertices[faces]

        mesh = p3_.art3d.Poly3DCollection(triangles)
        mesh.set_facecolor(colors_c[label % len(colors_c)])
        axes.add_collection3d(mesh)

    pl_.tight_layout()

    return axes if new_axes else None


def PlotConnection(
        connection: py_array_picker_h, soma_uid: int, shape: Sequence[int], axes=None
) -> Optional[Any]:
    """
    Draw one connection path as a polyline colored after its soma.

    connection: per-axis coordinate sequences of the path; connection[0] is along axis 0 (the
        depth axis of the plot) and is rescaled by the depth factor. None draws nothing.
    soma_uid: selects the line color, cycling through colors_c.
    shape: shape of the image the path lives in (used for the depth factor and new axes).
    axes: 3-D matplotlib axes to draw on. If None, a new figure and axes are created.
    Returns the newly created axes if axes was None, otherwise None.
    """
    depth_factor, depth_limit = __DepthFactorAndLimit__(shape)
    owns_axes = axes is None
    if owns_axes:
        _, axes = __FigAndAxes__(shape, depth_limit)

    # NOTE: callers could filter out None connections themselves; accepted here for convenience
    if connection is not None:
        scaled_depths = depth_factor * np_.array(connection[0])
        remaining_coordinates = connection[1:]
        line_color = colors_c[soma_uid % len(colors_c)]
        axes.plot(scaled_depths, *remaining_coordinates, line_color)

    pl_.tight_layout()

    return axes if owns_axes else None


def PlotExtensions(
        extensions: Union[extension_t, Sequence[extension_t]],
        shape: Sequence[int],
        axes=None,
) -> Optional[Any]:
    """
    Plot extensions as polylines joining every pair of their end points.

    For each extension, a cost map equal to 1 on the extension's sites and +inf elsewhere is
    built, and dk_.DijkstraShortestPath is run between each pair of end points on it. Each
    resulting path is drawn with the color of the extension's soma, or of the extension itself
    when it has no soma.

    extensions: one extension or a sequence of extensions.
    shape: shape of the image the extensions live in; axis 0 is rescaled by the depth factor.
    axes: 3-D matplotlib axes to draw on. If None, a new figure and axes are created.
    Returns the newly created axes if axes was None, otherwise None.
    """
    from itertools import combinations

    depth_factor, depth_limit = __DepthFactorAndLimit__(shape)
    new_axes = axes is None
    if new_axes:
        _, axes = __FigAndAxes__(shape, depth_limit)

    # Allocated once and refilled per extension
    costs = np_.empty(shape, dtype=np_.float32)

    if isinstance(extensions, extension_t):
        extensions = (extensions,)
    for extension in extensions:
        # Only the extension's own sites are traversable, so paths stay within the extension
        costs.fill(np_.inf)
        costs[extension.sites] = 1

        uid = extension.uid if extension.soma_uid is None else extension.soma_uid
        color = colors_c[uid % len(colors_c)]

        # Each column of end_points_as_array is one end point
        end_points = [tuple(column.tolist()) for column in extension.end_points_as_array.T]
        # Same pair order as nested loops over (src < tgt) column indices
        for src_point, tgt_point in combinations(end_points, 2):
            sites, _ = dk_.DijkstraShortestPath(
                costs,
                src_point,
                tgt_point,
                limit_to_sphere=False,
                constrain_direction=False,
            )
            # Transpose the sequence of path points into per-axis coordinate sequences
            coordinates = tuple(zip(*sites))
            if not coordinates:
                print(
                    f"{PlotExtensions.__name__}: extension.{extension.uid}: "
                    f"empty path between {src_point} and {tgt_point}"
                )
                continue

            # /!\ Redundant plots within ep-to-ep path
            axes.plot(
                # np_.array: a float depth factor cannot multiply a plain tuple
                depth_factor * np_.array(coordinates[0]),
                *coordinates[1:],
                color,
            )

    pl_.tight_layout()

    return axes if new_axes else None


def PlotSomaWithExtensions(soma: soma_t, soma_lmp: array_t, axes=None) -> Optional[Any]:
    """
    Plot a soma surface together with its connection paths and its extensions.

    soma: the soma; its uid selects its voxels in soma_lmp and the color of its connections.
    soma_lmp: label map passed to PlotLMap with labels=soma.uid.
    axes: 3-D matplotlib axes to draw on. If None, a new figure and axes are created.
    Returns the newly created axes if axes was None, otherwise None.

    Missing (None) connection paths of the soma and of its extensions are skipped.
    """
    shape = soma_lmp.shape
    _, depth_limit = __DepthFactorAndLimit__(shape)
    owns_axes = axes is None
    if owns_axes:
        _, axes = __FigAndAxes__(shape, depth_limit)

    PlotLMap(soma_lmp, labels=soma.uid, axes=axes)

    for path in soma.connection_path.values():
        if path is not None:
            PlotConnection(path, soma.uid, shape, axes=axes)

    for extension in soma.Extensions():
        for path in extension.connection_path.values():
            if path is not None:
                PlotConnection(path, soma.uid, shape, axes=axes)
        PlotExtensions(extension, shape, axes=axes)

    pl_.tight_layout()

    return axes if owns_axes else None


def __DepthFactorAndLimit__(shape: Sequence[int]) -> Tuple[float, int]:
    #
    depth_factor = min(0.5 * (shape[1] + shape[2]) / shape[0], 1)
    depth_limit = int(depth_factor * shape[0]) + 1

    return depth_factor, depth_limit


def __FigAndAxes__(shape: Sequence[int], depth_limit: float) -> Tuple[Any, Any]:
    """
    Create a new figure holding a single 3-D axes sized for an image of the given shape.

    The x, y and z axes show depth (shape[0]), rows (shape[1]) and columns (shape[2]); each label
    includes the corresponding size. The x range is [0, depth_limit] and the y and z ranges are
    [0, shape[1]] and [0, shape[2]].
    Returns the (figure, axes) pair.
    """
    figure = pl_.figure()
    axes = figure.add_subplot(111, projection=p3_.Axes3D.name)

    label_setters = (axes.set_xlabel, axes.set_ylabel, axes.set_zlabel)
    label_texts = (f"depth: {shape[0]}", f"row: {shape[1]}", f"col: {shape[2]}")
    for set_label, text in zip(label_setters, label_texts):
        set_label(text)

    limit_setters = (axes.set_xlim3d, axes.set_ylim3d, axes.set_zlim3d)
    upper_bounds = (depth_limit, shape[1], shape[2])
    for set_limits, upper in zip(limit_setters, upper_bounds):
        set_limits(0, upper)

    return figure, axes