# Copyright CNRS/Inria/UNS # Contributor(s): Eric Debreuve (since 2018) # # eric.debreuve@cnrs.fr # # This software is governed by the CeCILL license under French law and # abiding by the rules of distribution of free software. You can use, # modify and/ or redistribute the software under the terms of the CeCILL # license as circulated by CEA, CNRS and INRIA at the following URL # "http://www.cecill.info". # # As a counterpart to the access to the source code and rights to copy, # modify and redistribute granted by the license, users are provided only # with a limited warranty and the software's author, the holder of the # economic rights, and the successive licensors have only limited # liability. # # In this respect, the user's attention is drawn to the risks associated # with loading, using, modifying and/or developing or reproducing the # software by the user in light of its specific status of free software, # that may mean that it is complicated to manipulate, and that also # therefore means that it is reserved for developers and experienced # professionals having in-depth computer knowledge. Users are therefore # encouraged to load and test the software's suitability as regards their # requirements in conditions enabling the security of their systems and/or # data to be ensured and, more generally, to use and operate it in the # same conditions as regards security. # # The fact that you are presently reading this means that you have had # knowledge of the CeCILL license and that you accept its terms. 
# Base skeleton graph

from __future__ import annotations

import sklgraph.brick.edge as dg_
import sklgraph.brick.edge_update as eu_
import sklgraph.brick.node as nd_
import sklgraph.skl_map as sm_
from sklgraph.skl_map import skl_map_t

from enum import Enum as enum_t
from typing import Callable, Dict, Optional, Tuple, Union

import matplotlib.pyplot as pl_
import networkx as nx_
import numpy as np_
import scipy.ndimage as im_
import skimage.draw as dw_
from mpl_toolkits import mplot3d as m3_


array_t = np_.ndarray

# Plotting modes accepted by skl_graph_t.Plot. In 2-D:
#     - Networkx: nx_.draw_networkx + edge labels
#     - SKL / SKL_Curve: explicit drawing via dg_.Plot and the nd_ node plotters (SKL_Curve sets as_curve=True)
#     - Graphviz: pygraphviz layout rendered to an image
# In 3-D, the explicit drawing is always used (SKL_Curve still selects as_curve=True).
plot_mode_e = enum_t("plot_mode_e", "Networkx SKL SKL_Curve Graphviz")


class skl_graph_t(nx_.MultiGraph):
    """
    Skeleton graph built on a networkx MultiGraph.

    Each node is keyed by its uid and carries the node object (end or branch node from sklgraph.brick.node) in
    its "as_node_t" attribute. Each edge carries the edge object (from sklgraph.brick.edge) in its "as_edge_t"
    attribute; parallel edges between the same two nodes get distinct, versioned keys (see AddEdge).
    """

    #
    # Must be a unique connected component
    #
    __slots__ = ("dim", "bbox_lengths", "n_e_nodes", "n_b_nodes", "has_widths")
    # dim: number of dimensions of the skeleton map (len of bbox_lengths)
    # bbox_lengths: shape of the skeleton map the graph was built from
    # n_e_nodes / n_b_nodes: numbers of end / branch nodes (left to None when a special case was handled)
    # has_widths: whether the skeleton map provided widths

    # Defaults for plotting: node color per degree (degrees >= 2 use entry 2), label font size, edge line width
    colormap = {0: "b", 1: "r", 2: "g"}
    font_size = 6
    width = 2

    def __init__(self):
        """Create an empty graph with all slot attributes set to None (FromSkeleton fills them in)."""
        super().__init__()
        for slot in self.__class__.__slots__:
            setattr(self, slot, None)

    @classmethod
    def FromSkeleton(cls, skeleton: skl_map_t, size_voxel: array_t) -> skl_graph_t:
        """
        Build a skeleton graph from a skeleton map.

        Parameters:
            skeleton: skeleton map object; its .map, .widths (may be None) and .PartMap() are used.
            size_voxel: voxel size, forwarded to edge_t.SetLengths.

        Returns:
            The new graph. When _DealsWithSpecialCases handles the skeleton, the graph is returned right away
            and n_e_nodes / n_b_nodes stay None.
        """
        instance = cls()

        instance.bbox_lengths = skeleton.map.shape
        instance.dim = instance.bbox_lengths.__len__()
        instance.has_widths = skeleton.widths is not None

        part_map = skeleton.PartMap()
        if instance._DealsWithSpecialCases(part_map, widths=skeleton.widths):
            return instance

        e_nodes, e_node_lmap = nd_.EndNodes(part_map, widths=skeleton.widths)
        b_nodes, b_node_lmap = nd_.BranchNodes(part_map, widths=skeleton.widths)
        edges, edge_lmap = dg_.RawEdges(skeleton.map, b_node_lmap)
        # Presumably fills edge.node_uids (AddEdge requires exactly 2 of them) - confirm in edge_update
        eu_.AssignNodeIDsToEdges(e_nodes, e_node_lmap, b_nodes, b_node_lmap, edges, edge_lmap)

        instance.add_nodes_from((node.uid, dict(as_node_t=node)) for node in e_nodes)
        instance.add_nodes_from((node.uid, dict(as_node_t=node)) for node in b_nodes)
        for edge in edges:
            edge.SetLengths(widths=skeleton.widths, size_voxel=size_voxel)
            instance.AddEdge(edge)

        instance.n_e_nodes = e_nodes.__len__()
        instance.n_b_nodes = b_nodes.__len__()

        return instance

    def _DealsWithSpecialCases(self, part_map: array_t, widths: Optional[array_t] = None) -> bool:
        """
        Handle the skeletons that the general node/edge extraction does not cover.

        Returns:
            True if a special case was detected (the caller must then stop building the graph):
                - part_map contains a 0 entry: it must be unique (asserted) and becomes a single end node;
                - every entry of part_map is 2 or the invalid-neighbor marker of sm_.InvalidNNeighborsForMap:
                  self-loop without nodes, not handled yet (a message is printed, no node is added).
            False otherwise.
        """
        invalid_n_neighbors = sm_.InvalidNNeighborsForMap(part_map)

        singleton = np_.where(part_map == 0)
        if singleton[0].size > 0:
            # Can only be 1 then (see SkeletonIsValid)
            assert singleton[0].size == 1

            singleton = np_.array(singleton, dtype=np_.int64).squeeze()
            end_node = nd_.end_node_t.WithPosition(singleton, widths=widths)
            self.add_node(end_node.uid, as_node_t=end_node)

            return True
        #
        elif np_.logical_or(part_map == 2, part_map == invalid_n_neighbors).all():
            # TODO: Handle '3x3-cross with missing central pixel'-case and other self-loops w/o nodes
            print("Self-loop skeleton: Not handled yet; An exception will probably raise shortly!")
            return True

        return False

    @property
    def is_valid(self) -> bool:
        """
        Whether node degrees agree with the node counts recorded by FromSkeleton: as many degree-1 nodes as end
        nodes, and as many nodes of degree > 1 as branch nodes. False whenever the counts are None.
        """
        degrees = tuple(degree for _, degree in self.degree)

        return (sum(degree == 1 for degree in degrees) == self.n_e_nodes) and (
            sum(degree > 1 for degree in degrees) == self.n_b_nodes
        )

    def AddEdge(self, edge: dg_.edge_t) -> None:
        """
        Add an edge between the two nodes listed in edge.node_uids.

        The edge key is edge.uid; if an edge with that key already links the same nodes, the key becomes
        edge.uid + "+2", then "+3", etc., until it is free.
        """
        assert edge.node_uids.__len__() == 2

        edge_id = edge.uid
        version_number = 1
        edge_id_w_vn = edge_id
        while self.has_edge(*edge.node_uids, key=edge_id_w_vn):
            version_number += 1
            edge_id_w_vn = edge_id + "+" + version_number.__str__()

        self.add_edge(*edge.node_uids, key=edge_id_w_vn, as_edge_t=edge)

    def RebuiltSkeletonMap(self) -> array_t:
        """
        Rebuild a labeled skeleton map from the graph.

        Returns:
            int8 array of shape bbox_lengths: 0 = background, 2 = edge sites, 3 = branch-node sites,
            1 = end-node positions. Nodes are drawn after edges and thus overwrite them.
        """
        # Not uint to allow for subtraction
        map_ = np_.zeros(self.bbox_lengths, dtype=np_.int8)

        for ___, ___, edge in self.edges.data("as_edge_t"):
            map_[edge.sites] = 2

        for ___, node in self.nodes.data("as_node_t"):
            if isinstance(node, nd_.branch_node_t):
                map_[node.sites] = 3
            else:
                map_[tuple(node.position)] = 1

        return map_

    def RebuiltObjectMap(self) -> array_t:
        """
        Rebuild an object map by drawing, at every node and edge site, a ball of radius round((width - 1) / 2).
        The ball is a disk in 2-D and a sphere otherwise.

        Returns:
            int8 array of shape bbox_lengths with 1 inside the object and 0 elsewhere.

        Raises:
            ValueError: if the graph was built from a skeleton without widths.
        """
        if not self.has_widths:
            raise ValueError("Requires an SKL graph with widths")

        # Not uint to allow for subtraction
        map_ = np_.zeros(self.bbox_lengths, dtype=np_.int8)

        # _Disk instead of dw_.circle, which no longer exists in recent scikit-image versions
        if self.dim == 2:
            ball_fct = _Disk
        else:
            ball_fct = _Sphere

        # NOTE(review): a radius of 0 may yield an empty ball with skimage's disk/ellipsoid; confirm that thin
        # parts of the object are drawn as expected.
        for ___, node in self.nodes.data("as_node_t"):
            if isinstance(node, nd_.branch_node_t):
                radii = np_.around(0.5 * (node.diameters - 1.0)).astype(np_.int64)
                for *sites, radius in zip(*node.sites, radii):
                    map_[ball_fct(*sites, radius, shape=map_.shape)] = 1
            else:
                radius = np_.around(0.5 * (node.diameter - 1.0)).astype(np_.int64).item()
                map_[ball_fct(*node.position, radius, shape=map_.shape)] = 1

        for ___, ___, edge in self.edges.data("as_edge_t"):
            radii = np_.around(0.5 * (edge.widths - 1.0)).astype(np_.int64)
            for *sites, radius in zip(*edge.sites, radii):
                map_[ball_fct(*sites, radius, shape=map_.shape)] = 1

        return map_

    def Plot(
        self,
        figure: pl_.Figure = None,
        axes: pl_.axes.Axes = None,
        mode: plot_mode_e = plot_mode_e.SKL,
        w_directions: bool = False,
        colormap: dict = None,
        font_size: int = None,
        width: float = None,
        should_block: bool = True,
        should_return_figure: bool = False,
        should_return_axes: bool = False,
    ) -> Optional[Union[object, Tuple[object, object]]]:
        """
        Plot the graph with matplotlib.

        Parameters:
            figure, axes: if axes is None, it is created in figure (itself created if None): the current axes in
                2-D, a 3-D subplot otherwise; its y axis is then inverted. If axes is given, figure is taken from
                it.
            mode: see plot_mode_e; in 3-D, the explicit drawing is always used.
            w_directions: forwarded to dg_.Plot in the explicit modes.
            colormap, width: used by the Networkx mode only; default to the class attributes.
            font_size: label font size; defaults to the class attribute.
            should_block: whether to call pl_.show().
            should_return_figure, should_return_axes: select what is returned.

        Returns:
            figure, axes, (figure, axes) or None depending on should_return_*; None for an empty graph.

        Raises:
            ValueError: for an unknown 2-D mode.
        """
        if self.number_of_nodes() < 1:
            print(f"{__name__}.{self.Plot.__name__}: Empty graph")
            return

        if axes is None:
            if figure is None:
                figure = pl_.figure()
            if self.dim == 2:
                axes = figure.gca()
            else:
                axes = figure.add_subplot(1, 1, 1, projection=m3_.Axes3D.name)
            axes.invert_yaxis()
        else:
            figure = axes.get_figure()

        # Positions are (row, col, ...). They are plotted as x = col and y = row; if the y axis is not inverted,
        # rows are flipped (and vertical vector components negated) so the drawing keeps the map orientation.
        if axes.yaxis_inverted():
            transformation = lambda y: y
            vector_transf = lambda y: y
        else:
            max_0 = self.bbox_lengths[0] - 1
            transformation = lambda y: max_0 - np_.asarray(y)
            vector_transf = lambda y: -np_.asarray(y)
        transform_coords = lambda pos: (pos[1], transformation(pos[0]), *pos[2:])
        positions_as_dict = dict(
            (uid, transform_coords(node.position)) for uid, node in self.nodes.data("as_node_t")
        )

        if font_size is None:
            font_size = skl_graph_t.font_size

        if self.dim == 2:
            if mode is plot_mode_e.Networkx:
                self._PlotWithNetworkX(positions_as_dict, axes, colormap, font_size, width)
            #
            elif mode in (plot_mode_e.SKL, plot_mode_e.SKL_Curve):
                self._PlotExplicitly(
                    positions_as_dict,
                    transformation,
                    vector_transf,
                    axes,
                    font_size,
                    mode is plot_mode_e.SKL_Curve,
                    w_directions,
                )
            #
            elif mode is plot_mode_e.Graphviz:
                self._PlotWithGraphviz(axes)
            #
            else:
                raise ValueError(f"{mode}: Invalid mode")
        #
        else:
            self._PlotExplicitly(
                positions_as_dict,
                transformation,
                vector_transf,
                axes,
                font_size,
                mode is plot_mode_e.SKL_Curve,
                w_directions,
            )

        if self.dim == 2:
            # Matplotlib says: NotImplementedError: It is not currently possible to manually set the aspect on 3D axes
            axes.axis("equal")

        if should_block:
            pl_.show()  # Better named as TriggerMatplotlibEventLoop

        if should_return_figure:
            if should_return_axes:
                return figure, axes
            else:
                return figure
        elif should_return_axes:
            return axes

    def _PlotWithNetworkX(
        self,
        positions_as_dict: Dict[str, Tuple[int, ...]],
        axes: pl_.axes.Axes,
        colormap: dict,
        font_size: int,
        width: float,
    ) -> None:
        """Draw the graph with networkx: nodes colored by degree (capped at 2), node labels and edge labels."""
        if colormap is None:
            colormap = skl_graph_t.colormap
        if width is None:
            width = skl_graph_t.width

        # Degrees are non-negative ints; every degree >= 2 shares the color of degree 2
        node_colors = tuple(colormap[min(degree, 2)] for _, degree in self.degree)

        nx_.draw_networkx(
            self,
            ax=axes,
            pos=positions_as_dict,
            with_labels=True,
            node_color=node_colors,
            font_size=font_size,
            width=width,
        )
        nx_.draw_networkx_edge_labels(
            self,
            ax=axes,
            pos=positions_as_dict,
            edge_labels=self._EdgeIDsForPlot(),
            font_size=font_size,
        )

    def _PlotExplicitly(
        self,
        positions_as_dict: Dict[str, Tuple[int, ...]],
        transformation: Callable[[array_t], array_t],
        vector_transf: Callable[[array_t], array_t],
        axes: pl_.axes.Axes,
        font_size: int,
        as_curve: bool,
        w_directions: bool,
    ) -> None:
        """Draw edges, end nodes, branch nodes and node labels with the sklgraph plotting helpers."""
        dg_.Plot(
            self.edges.data("as_edge_t"),
            transformation,
            vector_transf,
            axes,
            as_curve=as_curve,
            w_directions=w_directions,
        )
        nd_.PlotEndNodes(self.nodes.data("as_node_t"), transformation, axes)
        if self.dim == 2:
            nd_.Plot2DBranchNodes(self.nodes.data("as_node_t"), transformation, axes)
            nx_.draw_networkx_labels(self, ax=axes, pos=positions_as_dict, font_size=font_size)
        else:
            nd_.Plot3DBranchNodes(self.nodes.data("as_node_t"), transformation, axes)
            nd_.Plot3DNodeLabels(self, positions_as_dict, axes, font_size)

    def _PlotWithGraphviz(self, axes: pl_.axes.Axes) -> None:
        """
        Render the graph with pygraphviz into a temporary PNG file and show it in axes. On any failure (missing
        modules, layout or rendering error), an explanatory text is written in axes instead.
        """
        try:
            import imageio as io_
            import pygraphviz as gp_  # Imported to fail early here if pygraphviz is not installed
            import os.path as ph_
            import tempfile as tp_

            graph = nx_.nx_agraph.to_agraph(self)
            # A temporary directory is used instead of a NamedTemporaryFile: the latter cannot be opened a second
            # time (here by graphviz, then imageio) while still open on Windows.
            with tp_.TemporaryDirectory() as tmp_folder:
                img_name = ph_.join(tmp_folder, "graph.png")
                graph.layout()
                graph.draw(img_name, format="png")
                axes.imshow(io_.imread(img_name))
        except Exception as exc:
            axes.text(
                0,
                0,
                f"Unable to plot graph using pygraphviz/imageio.\nPlease check installed modules.\n[{exc}]",
                horizontalalignment="center",
            )

    def _EdgeIDsForPlot(self) -> Dict[str, str]:
        """
        Build edge labels for networkx: maps (u, v) to "<edge key>\\n<rounded length>[/<rounded w_length>]".
        With parallel edges, the last one visited wins.
        """
        # NOTE(review): in this file, edges only receive an "as_edge_t" attribute (see AddEdge), so the "length"
        # and "w_length" graph attributes are likely absent and the labels empty - confirm whether the lengths
        # should be read from the edge objects instead.
        lengths_as_dict = nx_.get_edge_attributes(self, "length")
        w_lengths_as_dict = nx_.get_edge_attributes(self, "w_length") if self.has_widths else None

        w_length_str = ""
        edge_ids = {}
        for key, value in lengths_as_dict.items():
            if w_lengths_as_dict is not None:
                w_length_str = "/" + str(round(w_lengths_as_dict[key]))
            edge_ids[key[0:2]] = key[2] + "\n" + str(round(value)) + w_length_str

        return edge_ids


def _Disk(row: int, col: int, radius: int, shape: Tuple[int, int] = None) -> Tuple[array_t, array_t]:
    """
    Return the (row, col) coordinates of the disk of given center and radius, limited to shape if given.

    skimage.draw.circle(r, c, radius, shape=...) was removed from scikit-image (0.19) in favor of
    skimage.draw.disk((r, c), radius, shape=...); disk is used when available, circle otherwise.
    """
    if hasattr(dw_, "disk"):
        return dw_.disk((row, col), radius, shape=shape)

    return dw_.circle(row, col, radius, shape=shape)


def _Sphere(row: int, col: int, dep: int, radius: int, shape: Tuple[int, int, int]) -> array_t:
    """
    Return a boolean mask of the given shape containing a ball of given radius centered at (row, col, dep).

    The ball is first written at the origin corner (truncated to shape if larger), then shifted into place with
    nearest-neighbor interpolation (order=0); what is shifted out of the array is lost.
    """
    # Builtin bool rather than the np_.bool alias, which NumPy 1.24-1.26 no longer provide
    sphere = np_.zeros(shape, dtype=bool)
    # dw_.ellipsoid leaves a one pixel margin around the ellipse, hence [1:-1, 1:-1, 1:-1]
    ellipse = dw_.ellipsoid(radius, radius, radius)[1:-1, 1:-1, 1:-1]
    sp_slices = tuple(slice(0, min(sphere.shape[idx_], ellipse.shape[idx_])) for idx_ in (0, 1, 2))
    sphere[sp_slices] = ellipse[sp_slices]

    return im_.shift(sphere, (row - radius, col - radius, dep - radius), order=0, prefilter=False)