-
NADAL Morgane authored
some refactoring + identification of an essential process for soma-ext and connexions : NormalizedImage()
NADAL Morgane authored some refactoring + identification of an essential process for soma-ext and connexions : NormalizedImage()
skl_graph.py 13.85 KiB
# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2018)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.
# Base skeleton graph
from __future__ import annotations
import sklgraph.brick.edge as dg_
import sklgraph.brick.edge_update as eu_
import sklgraph.brick.node as nd_
import sklgraph.skl_map as sm_
from sklgraph.skl_map import skl_map_t
from enum import Enum as enum_t
from typing import Callable, Dict, Optional, Tuple, Union
import matplotlib.pyplot as pl_
import networkx as nx_
import numpy as np_
import scipy.ndimage as im_
import skimage.draw as dw_
from mpl_toolkits import mplot3d as m3_
# Alias used in type hints for NumPy arrays
array_t = np_.ndarray

# Plotting styles accepted by skl_graph_t.Plot (members numbered 1 to 4, in this order)
plot_mode_e = enum_t("plot_mode_e", ("Networkx", "SKL", "SKL_Curve", "Graphviz"))
class skl_graph_t(nx_.MultiGraph):
    """
    Skeleton graph, built from a skeleton map with FromSkeleton.

    The skeleton must be a unique connected component.

    Nodes are keyed by their uid and store the node object in the "as_node_t"
    attribute. Edges are keyed by their uid and store the edge object in the
    "as_edge_t" attribute. Several edges with the same uid joining the same pair
    of nodes get the keys uid, uid+"+2", uid+"+3"...

    Extra attributes (all None until set by FromSkeleton):
    - dim: number of dimensions of the skeleton map
    - bbox_lengths: shape of the skeleton map
    - n_e_nodes, n_b_nodes: number of end nodes and of branch nodes
    - has_widths: whether the skeleton map came with widths
    """

    __slots__ = ("dim", "bbox_lengths", "n_e_nodes", "n_b_nodes", "has_widths")

    # Default plotting parameters. colormap maps a node degree (degrees above 2
    # use the entry of 2) to a matplotlib color; used in Networkx plot mode.
    colormap = {0: "b", 1: "r", 2: "g"}
    font_size = 6
    width = 2

    def __init__(self):
        """Create an empty graph; all slot attributes are set to None."""
        super().__init__()
        for slot in self.__class__.__slots__:
            setattr(self, slot, None)

    @classmethod
    def FromSkeleton(cls, skeleton: skl_map_t) -> skl_graph_t:
        """
        Build a skeleton graph from a skeleton map.

        Parameters
        ----------
        skeleton: skeleton map. Its "map" array gives the bounding box; its
            "widths" attribute (None if absent) is forwarded to node and edge
            construction.

        Returns
        -------
        The new graph. For the special cases handled by _DealsWithSpecialCases,
        construction stops early and n_e_nodes/n_b_nodes are left to None.
        """
        instance = cls()

        instance.bbox_lengths = skeleton.map.shape
        instance.dim = len(instance.bbox_lengths)
        instance.has_widths = skeleton.widths is not None

        part_map = skeleton.PartMap()
        if instance._DealsWithSpecialCases(part_map, widths=skeleton.widths):
            return instance

        e_nodes, e_node_lmap = nd_.EndNodes(part_map, widths=skeleton.widths)
        b_nodes, b_node_lmap = nd_.BranchNodes(part_map, widths=skeleton.widths)
        edges, edge_lmap = dg_.RawEdges(skeleton.map, b_node_lmap)
        # Presumably fills in edge.node_uids, which AddEdge requires (2 uids per edge)
        eu_.AssignNodeIDsToEdges(
            e_nodes, e_node_lmap, b_nodes, b_node_lmap, edges, edge_lmap
        )

        instance.add_nodes_from((node.uid, dict(as_node_t=node)) for node in e_nodes)
        instance.add_nodes_from((node.uid, dict(as_node_t=node)) for node in b_nodes)
        for edge in edges:
            edge.SetLengths(skeleton.widths)
            instance.AddEdge(edge)

        instance.n_e_nodes = len(e_nodes)
        instance.n_b_nodes = len(b_nodes)

        return instance

    def _DealsWithSpecialCases(
        self, part_map: array_t, widths: Optional[array_t] = None
    ) -> bool:
        """
        Handle the skeletons that the regular node/edge extraction does not support.

        - A part map containing a 0 (presumably an isolated pixel; there can only be
          one, see SkeletonIsValid): the graph receives a single end node at that
          position.
        - A part map made only of 2's and of the "invalid number of neighbors"
          value (a self-loop without nodes): not handled; a message is printed.

        Returns True if part_map was a special case, False otherwise.
        """
        invalid_n_neighbors = sm_.InvalidNNeighborsForMap(part_map)

        singleton = np_.where(part_map == 0)
        if singleton[0].size > 0:
            # Can only be 1 then (see SkeletonIsValid)
            assert singleton[0].size == 1
            singleton = np_.array(singleton, dtype=np_.int64).squeeze()
            end_node = nd_.end_node_t.WithPosition(singleton, widths=widths)
            self.add_node(end_node.uid, as_node_t=end_node)
            return True

        if np_.logical_or(part_map == 2, part_map == invalid_n_neighbors).all():
            # TODO: Handle '3x3-cross with missing central pixel'-case and other self-loops w/o nodes
            print(
                "Self-loop skeleton: Not handled yet; An exception will probably raise shortly!"
            )
            return True

        return False

    @property
    def is_valid(self) -> bool:
        """
        Whether the node degrees agree with the node counts recorded by FromSkeleton:
        as many degree-1 nodes as end nodes, and as many nodes of degree > 1 as
        branch nodes.
        """
        degrees = [degree for _, degree in self.degree]
        n_degree_1 = sum(degree == 1 for degree in degrees)
        n_degree_above_1 = sum(degree > 1 for degree in degrees)

        return (n_degree_1 == self.n_e_nodes) and (n_degree_above_1 == self.n_b_nodes)

    def AddEdge(self, edge: dg_.edge_t) -> None:
        """
        Add an edge between its 2 nodes, keyed by its uid.

        If an edge with the same key already joins these nodes (parallel edges of a
        multigraph), the key is suffixed with "+2", "+3"... until it is free.
        """
        assert len(edge.node_uids) == 2

        edge_id = edge.uid
        version_number = 1
        edge_id_w_vn = edge_id
        while self.has_edge(*edge.node_uids, key=edge_id_w_vn):
            version_number += 1
            edge_id_w_vn = edge_id + "+" + str(version_number)

        self.add_edge(*edge.node_uids, key=edge_id_w_vn, as_edge_t=edge)

    def RebuiltSkeletonMap(self) -> array_t:
        """
        Rebuild a labeled skeleton map of shape bbox_lengths (int8).

        Labels: 0 background, 1 end node, 2 edge site, 3 branch node site. Nodes are
        drawn after edges, so node labels win where sites are shared.
        """
        # Not uint to allow for subtraction
        map_ = np_.zeros(self.bbox_lengths, dtype=np_.int8)

        for ___, ___, edge in self.edges.data("as_edge_t"):
            map_[edge.sites] = 2

        for ___, node in self.nodes.data("as_node_t"):
            if isinstance(node, nd_.branch_node_t):
                map_[node.sites] = 3
            else:
                map_[tuple(node.position)] = 1

        return map_

    def RebuiltObjectMap(self) -> array_t:
        """
        Rebuild a binary object map of shape bbox_lengths (int8, 1 = object).

        A ball (disk in 2D, sphere in 3D) is drawn at every branch node site, end
        node position, and edge site. Its radius is round((width - 1) / 2), where
        width is the node diameter or edge width. np.around rounds halves to even.

        Raises
        ------
        ValueError: if the graph was built without widths.
        """
        if not self.has_widths:
            raise ValueError("Requires an SKL graph with widths")

        # Not uint to allow for subtraction
        map_ = np_.zeros(self.bbox_lengths, dtype=np_.int8)

        if self.dim == 2:
            if hasattr(dw_, "disk"):
                # skimage.draw.circle was deprecated (0.17) in favor of disk, then
                # removed (0.19); circle(r, c, radius) was disk((r, c), radius).
                def ball_fct(row, col, radius, shape=None):
                    return dw_.disk((row, col), radius, shape=shape)

            else:
                ball_fct = dw_.circle
        else:
            ball_fct = _Sphere

        for ___, node in self.nodes.data("as_node_t"):
            if isinstance(node, nd_.branch_node_t):
                for *sites, radius in zip(
                    *node.sites,
                    np_.around(0.5 * (node.diameters - 1.0)).astype(np_.int64),
                ):
                    map_[ball_fct(*sites, radius, shape=map_.shape)] = 1
            else:
                radius = np_.around(0.5 * (node.diameter - 1.0)).astype(np_.int64).item()
                map_[ball_fct(*node.position, radius, shape=map_.shape)] = 1

        for ___, ___, edge in self.edges.data("as_edge_t"):
            for *sites, radius in zip(
                *edge.sites, np_.around(0.5 * (edge.widths - 1.0)).astype(np_.int64)
            ):
                map_[ball_fct(*sites, radius, shape=map_.shape)] = 1

        return map_

    def Plot(
        self,
        figure: pl_.Figure = None,
        axes: pl_.Axes = None,
        mode: plot_mode_e = plot_mode_e.SKL,
        w_directions: bool = False,
        colormap: dict = None,
        font_size: int = None,
        width: float = None,
        should_block: bool = True,
        should_return_figure: bool = False,
        should_return_axes: bool = False,
    ) -> Optional[Union[object, Tuple[object, object]]]:
        """
        Plot the graph with matplotlib.

        Parameters
        ----------
        figure: figure to create the axes in when axes is None (a new figure is
            created if both are None).
        axes: axes to draw in. If None, new axes are created (3D axes for a 3D
            graph) and their y-axis is inverted.
        mode: plotting style. Only honored for 2D graphs; 3D graphs are always
            plotted explicitly (as curves if mode is SKL_Curve).
        w_directions: forwarded to the explicit edge plotting.
        colormap, width: only used in Networkx mode (class defaults if None).
        font_size: label font size (class default if None).
        should_block: call pyplot.show() at the end.
        should_return_figure, should_return_axes: select what is returned.

        Returns
        -------
        None, figure, axes, or (figure, axes), according to the should_return_*
        flags. Returns None after printing a message if the graph has no nodes.

        Raises
        ------
        ValueError: invalid mode for a 2D graph.
        """
        if self.number_of_nodes() < 1:
            print(f"{__name__}.{self.Plot.__name__}: Empty graph")
            return None

        if axes is None:
            if figure is None:
                figure = pl_.figure()
            if self.dim == 2:
                axes = figure.gca()
            else:
                axes = figure.add_subplot(1, 1, 1, projection=m3_.Axes3D.name)
            # Rows then increase downward (y is the row coordinate, see below)
            axes.invert_yaxis()
        else:
            figure = axes.get_figure()

        if axes.yaxis_inverted():
            transformation = lambda y: y
            vector_transf = lambda y: y
        else:
            # Flip rows so that row 0 is still drawn at the top
            max_0 = self.bbox_lengths[0] - 1
            transformation = lambda y: max_0 - np_.asarray(y)
            vector_transf = lambda y: -np_.asarray(y)
        # Node positions are (row, col, ...); plot coordinates are (x=col, y=row, ...)
        transform_coords = lambda pos: (pos[1], transformation(pos[0]), *pos[2:])
        positions_as_dict = {
            uid: transform_coords(node.position)
            for uid, node in self.nodes.data("as_node_t")
        }

        if font_size is None:
            font_size = skl_graph_t.font_size

        if self.dim == 2:
            if mode is plot_mode_e.Networkx:
                self._PlotWithNetworkX(
                    positions_as_dict, axes, colormap, font_size, width
                )
            elif mode in (plot_mode_e.SKL, plot_mode_e.SKL_Curve):
                self._PlotExplicitly(
                    positions_as_dict,
                    transformation,
                    vector_transf,
                    axes,
                    font_size,
                    mode is plot_mode_e.SKL_Curve,
                    w_directions,
                )
            elif mode is plot_mode_e.Graphviz:
                self._PlotWithGraphviz(axes)
            else:
                raise ValueError(f"{mode}: Invalid mode")
        else:
            self._PlotExplicitly(
                positions_as_dict,
                transformation,
                vector_transf,
                axes,
                font_size,
                mode is plot_mode_e.SKL_Curve,
                w_directions,
            )

        if self.dim == 2:
            # Matplotlib says: NotImplementedError: It is not currently possible to manually set the aspect on 3D axes
            axes.axis("equal")

        if should_block:
            pl_.show()  # Better named as TriggerMatplotlibEventLoop

        if should_return_figure:
            if should_return_axes:
                return figure, axes
            return figure
        if should_return_axes:
            return axes
        return None

    def _PlotWithNetworkX(
        self,
        positions_as_dict: Dict[str, Tuple[int, ...]],
        axes: pl_.Axes,
        colormap: dict,
        font_size: int,
        width: float,
    ) -> None:
        """
        Draw the graph with networkx drawing functions.

        Nodes are colored by degree (degrees above 2 use the color of degree 2), and
        edges are labeled with _EdgeIDsForPlot. colormap and width default to the
        class attributes when None.
        """
        if colormap is None:
            colormap = skl_graph_t.colormap
        if width is None:
            width = skl_graph_t.width

        node_degrees = (elm[1] for elm in self.degree)
        node_colors = tuple(
            colormap[degree] if degree < 3 else colormap[2] for degree in node_degrees
        )

        nx_.draw_networkx(
            self,
            ax=axes,
            pos=positions_as_dict,
            with_labels=True,
            node_color=node_colors,
            font_size=font_size,
            width=width,
        )
        nx_.draw_networkx_edge_labels(
            self,
            ax=axes,
            pos=positions_as_dict,
            edge_labels=self._EdgeIDsForPlot(),
            font_size=font_size,
        )

    def _PlotExplicitly(
        self,
        positions_as_dict: Dict[str, Tuple[int, ...]],
        transformation: Callable[[array_t], array_t],
        vector_transf: Callable[[array_t], array_t],
        axes: pl_.Axes,
        font_size: int,
        as_curve: bool,
        w_directions: bool,
    ) -> None:
        """
        Draw edges, end nodes, branch nodes, and node labels with the project's
        plotting helpers (dg_.Plot and nd_.Plot*), in 2D or 3D according to dim.
        """
        dg_.Plot(
            self.edges.data("as_edge_t"),
            transformation,
            vector_transf,
            axes,
            as_curve=as_curve,
            w_directions=w_directions,
        )
        nd_.PlotEndNodes(self.nodes.data("as_node_t"), transformation, axes)
        if self.dim == 2:
            nd_.Plot2DBranchNodes(self.nodes.data("as_node_t"), transformation, axes)
            nx_.draw_networkx_labels(
                self, ax=axes, pos=positions_as_dict, font_size=font_size
            )
        else:
            nd_.Plot3DBranchNodes(self.nodes.data("as_node_t"), transformation, axes)
            nd_.Plot3DNodeLabels(self, positions_as_dict, axes, font_size)

    def _PlotWithGraphviz(self, axes: pl_.Axes) -> None:
        """
        Lay out and render the graph with Graphviz (through pygraphviz) into a
        temporary PNG file, then show it in axes.

        This is best-effort: on any failure (missing module, rendering error...), an
        explanatory text is drawn in axes instead.
        """
        try:
            import imageio as io_
            import pygraphviz as gp_  # Only imported to check availability
            import tempfile as tp_

            graph = nx_.nx_agraph.to_agraph(self)
            with tp_.NamedTemporaryFile() as tmp_accessor:
                img_name = tmp_accessor.name
                graph.layout()
                graph.draw(img_name, format="png")
                axes.imshow(io_.imread(img_name))
        except Exception as exc:
            axes.text(
                0,
                0,
                f"Unable to plot graph using pygraphviz/imageio.\nPlease check installed modules.\n[{exc}]",
                horizontalalignment="center",
            )

    def _EdgeIDsForPlot(self) -> Dict[str, str]:
        """
        Edge labels for networkx drawing: {(u, v): "<edge key>\\n<length>[/<w_length>]"},
        with lengths rounded to integers.

        For parallel edges between the same (u, v), the last one iterated wins.
        NOTE(review): reads the "length" and "w_length" edge attributes, but AddEdge
        only stores "as_edge_t"; unless they are set elsewhere, no labels are
        produced. Confirm where these attributes are supposed to come from.
        """
        lengths_as_dict = nx_.get_edge_attributes(self, "length")
        w_lengths_as_dict = (
            nx_.get_edge_attributes(self, "w_length") if self.has_widths else None
        )

        w_length_str = ""
        edge_ids = {}
        for key, value in lengths_as_dict.items():
            # key is (u, v, edge key) for a multigraph
            if w_lengths_as_dict is not None:
                w_length_str = "/" + str(round(w_lengths_as_dict[key]))
            edge_ids[key[0:2]] = key[2] + "\n" + str(round(value)) + w_length_str

        return edge_ids
def _Sphere(
row: int, col: int, dep: int, radius: int, shape: Tuple[int, int, int]
) -> array_t:
#
sphere = np_.zeros(shape, dtype=np_.bool)
# dw_.ellipsoid leaves a one pixel margin around the ellipse, hence [1:-1, 1:-1, 1:-1]
ellipse = dw_.ellipsoid(radius, radius, radius)[1:-1, 1:-1, 1:-1]
sp_slices = tuple(
slice(0, min(sphere.shape[idx_], ellipse.shape[idx_])) for idx_ in (0, 1, 2)
)
sphere[sp_slices] = ellipse[sp_slices]
return im_.shift(
sphere, (row - radius, col - radius, dep - radius), order=0, prefilter=False
)