# File: skl_graph.py
# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2018)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL  license under French law and
# abiding by the rules of distribution of free software.  You can  use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and  rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty  and the software's author,  the holder of the
# economic rights,  and the successive licensors  have only  limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using,  modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean  that it is complicated to manipulate,  and  that  also
# therefore means  that it is reserved for developers  and  experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and,  more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

# Base skeleton graph

from __future__ import annotations

import sklgraph.brick.edge as dg_
import sklgraph.brick.edge_update as eu_
import sklgraph.brick.node as nd_
import sklgraph.skl_map as sm_
from sklgraph.skl_map import skl_map_t

from enum import Enum as enum_t
from typing import Callable, Dict, Optional, Tuple, Union

import matplotlib.pyplot as pl_
import networkx as nx_
import numpy as np_
import scipy.ndimage as im_
import skimage.draw as dw_
from mpl_toolkits import mplot3d as m3_


# Short alias for NumPy arrays, used in the annotations of this module
array_t = np_.ndarray
# Rendering modes; the functional Enum API numbers the members from 1 in the order given
plot_mode_e = enum_t("plot_mode_e", ("Networkx", "SKL", "SKL_Curve", "Graphviz"))


class skl_graph_t(nx_.MultiGraph):
    """
    Skeleton graph stored as a NetworkX multigraph.

    Every graph node carries its node object under the "as_node_t" attribute, and every graph edge
    carries its edge object under the "as_edge_t" attribute.

    Must be a unique connected component.
    """

    # dim:          number of dimensions of the skeleton map (length of bbox_lengths)
    # bbox_lengths: shape of the skeleton map the graph was built from
    # n_e_nodes:    number of end nodes found when building the graph
    # n_b_nodes:    number of branch nodes found when building the graph
    # has_widths:   whether the skeleton the graph was built from had widths
    __slots__ = ("dim", "bbox_lengths", "n_e_nodes", "n_b_nodes", "has_widths")

    # Default plotting parameters, used when the corresponding Plot arguments are None
    colormap = {0: "b", 1: "r", 2: "g"}
    font_size = 6
    width = 2

    def __init__(self):
        """Create an empty graph with all the slots set to None."""
        super().__init__()
        for slot in self.__class__.__slots__:
            setattr(self, slot, None)

    @classmethod
    def FromSkeleton(cls, skeleton: skl_map_t, size_voxel: array_t) -> skl_graph_t:
        """
        Build the graph of a skeleton.

        skeleton: skeleton map object; its "map", "widths" and "PartMap()" members are used.
        size_voxel: passed as is to the SetLengths method of each edge
            (presumably the physical voxel size -- TODO confirm against sklgraph.brick.edge).

        Returns the new graph. If the skeleton is one of the special cases detected by
        _DealsWithSpecialCases, the graph is returned early with n_e_nodes and n_b_nodes left to None
        and without any edge.
        """
        instance = cls()

        instance.bbox_lengths = skeleton.map.shape
        instance.dim = len(instance.bbox_lengths)
        instance.has_widths = skeleton.widths is not None

        part_map = skeleton.PartMap()

        if instance._DealsWithSpecialCases(part_map, widths=skeleton.widths):
            return instance

        e_nodes, e_node_lmap = nd_.EndNodes(part_map, widths=skeleton.widths)
        b_nodes, b_node_lmap = nd_.BranchNodes(part_map, widths=skeleton.widths)
        edges, edge_lmap = dg_.RawEdges(skeleton.map, b_node_lmap)
        eu_.AssignNodeIDsToEdges(
            e_nodes, e_node_lmap, b_nodes, b_node_lmap, edges, edge_lmap
        )

        instance.add_nodes_from((node.uid, dict(as_node_t=node)) for node in e_nodes)
        instance.add_nodes_from((node.uid, dict(as_node_t=node)) for node in b_nodes)
        for edge in edges:
            edge.SetLengths(widths=skeleton.widths, size_voxel=size_voxel)
            instance.AddEdge(edge)

        instance.n_e_nodes = len(e_nodes)
        instance.n_b_nodes = len(b_nodes)

        return instance

    def _DealsWithSpecialCases(self, part_map: array_t, widths: array_t = None) -> bool:
        """
        Handle the skeletons that the regular node/edge extraction does not deal with.

        - Singleton (a site of part_map equal to 0): a unique end node is added to the graph.
        - All the sites of part_map are either 2 or the value returned by
          sm_.InvalidNNeighborsForMap: reported as a self-loop skeleton, which is not handled yet.

        Returns True if a special case was detected (the caller must then stop building the graph),
        False otherwise.
        """
        invalid_n_neighbors = sm_.InvalidNNeighborsForMap(part_map)

        singleton = np_.where(part_map == 0)
        if singleton[0].size > 0:
            # Can only be 1 then (see SkeletonIsValid)
            assert singleton[0].size == 1
            # From a tuple of one-element index arrays to a flat array of coordinates
            singleton = np_.array(singleton, dtype=np_.int64).squeeze()
            end_node = nd_.end_node_t.WithPosition(singleton, widths=widths)
            self.add_node(end_node.uid, as_node_t=end_node)

            return True
        #
        elif np_.logical_or(part_map == 2, part_map == invalid_n_neighbors).all():
            # TODO: Handle '3x3-cross with missing central pixel'-case and other self-loops w/o nodes
            print(
                "Self-loop skeleton: Not handled yet; An exception will probably raise shortly!"
            )

            return True

        return False

    @property
    def is_valid(self) -> bool:
        """
        True if the numbers of graph nodes of degree 1 and of degree higher than 1 are equal to the
        stored numbers of end nodes and branch nodes, respectively.
        """
        degrees = tuple(degree for _, degree in self.degree)
        n_degree_1 = sum(1 for degree in degrees if degree == 1)
        n_degree_above_1 = sum(1 for degree in degrees if degree > 1)

        return (n_degree_1 == self.n_e_nodes) and (n_degree_above_1 == self.n_b_nodes)

    def AddEdge(self, edge: dg_.edge_t) -> None:
        """
        Add an edge between the two nodes referenced by edge.node_uids.

        The edge uid is used as the multigraph key. If an edge with this key already exists between
        the two nodes, "+2", "+3"... is appended to the uid until the key is unused.
        """
        assert len(edge.node_uids) == 2

        edge_id = edge.uid
        version_number = 1
        edge_id_w_vn = edge_id
        while self.has_edge(*edge.node_uids, key=edge_id_w_vn):
            version_number += 1
            edge_id_w_vn = edge_id + "+" + str(version_number)

        self.add_edge(*edge.node_uids, key=edge_id_w_vn, as_edge_t=edge)

    def RebuiltSkeletonMap(self) -> array_t:
        """
        Return a map with the shape bbox_lengths where the edge sites are set to 2, the sites of the
        branch nodes to 3, and the position of the other nodes to 1 (0 elsewhere).
        """
        # Not uint to allow for subtraction
        map_ = np_.zeros(self.bbox_lengths, dtype=np_.int8)

        for ___, ___, edge in self.edges.data("as_edge_t"):
            map_[edge.sites] = 2

        # Nodes are drawn after the edges so that they overwrite them where they overlap
        for ___, node in self.nodes.data("as_node_t"):
            if isinstance(node, nd_.branch_node_t):
                map_[node.sites] = 3
            else:
                map_[tuple(node.position)] = 1

        return map_

    @staticmethod
    def _Disk(
        row: int, col: int, radius: int, shape: Tuple[int, int]
    ) -> Tuple[array_t, array_t]:
        """
        Return the (rows, cols) coordinates of a disk, clipped to shape.

        skimage.draw.circle was replaced with skimage.draw.disk in recent versions of scikit-image.
        The newer function is used when available, the older one otherwise.
        """
        disk_fct = getattr(dw_, "disk", None)
        if disk_fct is not None:
            return disk_fct((row, col), radius, shape=shape)

        return dw_.circle(row, col, radius, shape=shape)

    def RebuiltObjectMap(self) -> array_t:
        """
        Return a map with the shape bbox_lengths where a ball (disk in 2-D, sphere otherwise) is set
        to 1 around each node and edge site, with a radius derived from the local diameter or width.

        Raises ValueError if the graph was built from a skeleton without widths.
        """
        if not self.has_widths:
            raise ValueError("Requires an SKL graph with widths")

        # Not uint to allow for subtraction
        map_ = np_.zeros(self.bbox_lengths, dtype=np_.int8)

        if self.dim == 2:
            ball_fct = self._Disk
        else:
            ball_fct = _Sphere

        for ___, node in self.nodes.data("as_node_t"):
            if isinstance(node, nd_.branch_node_t):
                # One ball per site of the branch node; each zip item is (*coordinates, radius)
                for *sites, radius in zip(
                    *node.sites,
                    np_.around(0.5 * (node.diameters - 1.0)).astype(np_.int64),
                ):
                    map_[ball_fct(*sites, radius, shape=map_.shape)] = 1
            else:
                radius = (
                    np_.around(0.5 * (node.diameter - 1.0)).astype(np_.int64).item()
                )
                map_[ball_fct(*node.position, radius, shape=map_.shape)] = 1

        for ___, ___, edge in self.edges.data("as_edge_t"):
            for *sites, radius in zip(
                *edge.sites, np_.around(0.5 * (edge.widths - 1.0)).astype(np_.int64)
            ):
                map_[ball_fct(*sites, radius, shape=map_.shape)] = 1

        return map_

    def Plot(
        self,
        figure: pl_.Figure = None,
        axes: pl_.axes.Axes = None,
        mode: plot_mode_e = plot_mode_e.SKL,
        w_directions: bool = False,
        colormap: dict = None,
        font_size: int = None,
        width: float = None,
        should_block: bool = True,
        should_return_figure: bool = False,
        should_return_axes: bool = False,
    ) -> Optional[Union[object, Tuple[object, object]]]:
        """
        Plot the graph.

        figure: figure to plot in if axes is None; a new one is created if both are None.
        axes: axes to plot in; if None, axes are obtained from the figure (3-D axes if dim is not 2)
            and their y-axis is inverted.
        mode: plotting mode. It is only taken into account for 2-D graphs, except for the
            curve-versus-not distinction (SKL_Curve) which is also passed on in the other cases.
        w_directions: passed as is to the edge plotting function in the SKL and SKL_Curve modes.
        colormap, width: only used in Networkx mode; class defaults are used if None.
        font_size: font size of the labels; class default is used if None.
        should_block: whether to call matplotlib.pyplot.show() after plotting.
        should_return_figure, should_return_axes: select what is returned: the figure, the axes,
            both as a (figure, axes) tuple, or None.

        Raises ValueError if the graph is 2-D and mode is not a valid plotting mode.
        """
        if self.number_of_nodes() < 1:
            print(f"{__name__}.{self.Plot.__name__}: Empty graph")
            return

        if axes is None:
            if figure is None:
                figure = pl_.figure()
            if self.dim == 2:
                axes = figure.gca()
            else:
                axes = figure.add_subplot(1, 1, 1, projection=m3_.Axes3D.name)
            axes.invert_yaxis()
        else:
            figure = axes.get_figure()

        # The first coordinate (rows) is plotted along the y-axis. If the y-axis is not inverted,
        # the rows must be flipped, and the sign of the row component of vectors must be changed.
        if axes.yaxis_inverted():

            def transformation(y):
                return y

            def vector_transf(y):
                return y

        else:
            max_0 = self.bbox_lengths[0] - 1

            def transformation(y):
                return max_0 - np_.asarray(y)

            def vector_transf(y):
                return -np_.asarray(y)

        # From (row, col, ...) positions to (x, y, ...) plotting coordinates
        positions_as_dict = {
            uid: (node.position[1], transformation(node.position[0]), *node.position[2:])
            for uid, node in self.nodes.data("as_node_t")
        }

        if font_size is None:
            font_size = skl_graph_t.font_size

        is_2d = self.dim == 2
        if is_2d and (mode is plot_mode_e.Networkx):
            self._PlotWithNetworkX(positions_as_dict, axes, colormap, font_size, width)
        elif is_2d and (mode is plot_mode_e.Graphviz):
            self._PlotWithGraphviz(axes)
        elif (not is_2d) or (mode in (plot_mode_e.SKL, plot_mode_e.SKL_Curve)):
            # Graphs that are not 2-D are always plotted explicitly, whatever the mode
            self._PlotExplicitly(
                positions_as_dict,
                transformation,
                vector_transf,
                axes,
                font_size,
                mode is plot_mode_e.SKL_Curve,
                w_directions,
            )
        else:
            raise ValueError(f"{mode}: Invalid mode")

        if is_2d:
            # Matplotlib says: NotImplementedError: It is not currently possible to manually set the aspect on 3D axes
            axes.axis("equal")

        if should_block:
            pl_.show()  # Better named as TriggerMatplotlibEventLoop

        if should_return_figure and should_return_axes:
            return figure, axes
        if should_return_figure:
            return figure
        if should_return_axes:
            return axes

    def _PlotWithNetworkX(
        self,
        positions_as_dict: Dict[str, Tuple[int, ...]],
        axes: pl_.axes.Axes,
        colormap: dict,
        font_size: int,
        width: float,
    ) -> None:
        """
        Plot the graph with the NetworkX drawing functions: nodes colored by degree, node labels,
        and edge labels as built by _EdgeIDsForPlot.

        colormap: color per node degree for the keys 0, 1 and 2; the class default is used if None.
        width: edge line width; the class default is used if None.
        """
        if colormap is None:
            colormap = skl_graph_t.colormap
        if width is None:
            width = skl_graph_t.width

        # Nodes of degree 3 and above share the color of the nodes of degree 2
        node_colors = tuple(
            colormap[degree] if degree < 3 else colormap[2]
            for _, degree in self.degree
        )

        nx_.draw_networkx(
            self,
            ax=axes,
            pos=positions_as_dict,
            with_labels=True,
            node_color=node_colors,
            font_size=font_size,
            width=width,
        )
        nx_.draw_networkx_edge_labels(
            self,
            ax=axes,
            pos=positions_as_dict,
            edge_labels=self._EdgeIDsForPlot(),
            font_size=font_size,
        )

    def _PlotExplicitly(
        self,
        positions_as_dict: Dict[str, Tuple[int, ...]],
        transformation: Callable[[array_t], array_t],
        vector_transf: Callable[[array_t], array_t],
        axes: pl_.axes.Axes,
        font_size: int,
        as_curve: bool,
        w_directions: bool,
    ) -> None:
        """
        Plot the graph with the plotting functions of the edge and node modules: edges first, then
        end nodes, then branch nodes and node labels (2-D or 3-D versions depending on dim).

        transformation, vector_transf: row coordinate and row vector component transformations
            accounting for the orientation of the y-axis (see Plot).
        """
        dg_.Plot(
            self.edges.data("as_edge_t"),
            transformation,
            vector_transf,
            axes,
            as_curve=as_curve,
            w_directions=w_directions,
        )
        nd_.PlotEndNodes(self.nodes.data("as_node_t"), transformation, axes)

        if self.dim == 2:
            nd_.Plot2DBranchNodes(self.nodes.data("as_node_t"), transformation, axes)
            nx_.draw_networkx_labels(
                self, ax=axes, pos=positions_as_dict, font_size=font_size
            )
        else:
            nd_.Plot3DBranchNodes(self.nodes.data("as_node_t"), transformation, axes)
            nd_.Plot3DNodeLabels(self, positions_as_dict, axes, font_size)

    def _PlotWithGraphviz(self, axes: pl_.axes.Axes) -> None:
        """
        Lay out and render the graph with Graphviz into a temporary PNG file, and show the image in
        the axes.

        This is a best-effort display: if anything fails (typically pygraphviz or imageio is not
        installed), a message including the exception is written in the axes instead of raising.
        """
        try:
            import os.path as ph_
            import tempfile as tp_

            import imageio as io_
            import pygraphviz as gp_  # Not used directly; imported to fail early if missing

            graph = nx_.nx_agraph.to_agraph(self)
            graph.layout()
            # The image is written in a temporary folder, not in an open NamedTemporaryFile, since
            # an open named temporary file cannot be reopened by name on all platforms. The
            # ".png" extension also lets the image reader identify the format.
            with tp_.TemporaryDirectory() as tmp_folder:
                img_name = ph_.join(tmp_folder, "skl_graph.png")
                graph.draw(img_name, format="png")
                image = io_.imread(img_name)
            axes.imshow(image)
        except Exception as exc:
            axes.text(
                0,
                0,
                f"Unable to plot graph using pygraphviz/imageio.\nPlease check installed modules.\n[{exc}]",
                horizontalalignment="center",
            )

    def _EdgeIDsForPlot(self) -> Dict[Tuple, str]:
        """
        Return the edge labels for the NetworkX plot, indexed by node uid pairs.

        Each label is made of the edge key, a newline, the rounded "length" edge attribute and, if
        the graph has widths, a slash followed by the rounded "w_length" edge attribute.
        """
        # NOTE(review): AddEdge stores only the "as_edge_t" attribute on the edges. Unless "length"
        # and "w_length" are set somewhere else, the dictionaries below are empty and no edge label
        # is produced -- TODO confirm where these attributes are supposed to be set.
        lengths_as_dict = nx_.get_edge_attributes(self, "length")
        w_lengths_as_dict = (
            nx_.get_edge_attributes(self, "w_length") if self.has_widths else None
        )

        w_length_str = ""
        edge_ids = {}
        # The dictionary keys are (node uid, node uid, edge key) tuples
        for key, value in lengths_as_dict.items():
            if w_lengths_as_dict is not None:
                w_length_str = "/" + str(round(w_lengths_as_dict[key]))
            # Parallel edges share the same node uid pair, so only the last label is kept for them
            edge_ids[key[0:2]] = key[2] + "\n" + str(round(value)) + w_length_str

        return edge_ids


def _Sphere(
    row: int, col: int, dep: int, radius: int, shape: Tuple[int, int, int]
) -> array_t:
    """
    Return a boolean volume with the given shape containing a sphere centered on (row, col, dep).

    The sphere is first drawn in the corner of the volume (truncated if it is larger than the
    volume), then moved into place with a nearest-neighbor shift.
    """
    # np_.bool_ instead of np_.bool: the latter was a deprecated alias removed from recent NumPy
    sphere = np_.zeros(shape, dtype=np_.bool_)
    # dw_.ellipsoid leaves a one pixel margin around the ellipse, hence [1:-1, 1:-1, 1:-1]
    ellipse = dw_.ellipsoid(radius, radius, radius)[1:-1, 1:-1, 1:-1]
    sp_slices = tuple(
        slice(0, min(sphere.shape[idx_], ellipse.shape[idx_])) for idx_ in (0, 1, 2)
    )
    sphere[sp_slices] = ellipse[sp_slices]

    # Once in the corner, the center of the sphere is at (radius, radius, radius)
    return im_.shift(
        sphere, (row - radius, col - radius, dep - radius), order=0, prefilter=False
    )