Newer
Older
# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2018)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.
from __future__ import annotations
import sklgraph.brick.elm_id as id_
from sklgraph.skl_map import LABELIZED_MAP_fct_FOR_DIM
from brick.processing.input import ToMicron
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import itertools as it_
from collections import namedtuple as namedtuple_t
from typing import Callable, Iterable, List, Tuple, cast
import matplotlib.pyplot as pl_
import numpy as np_
import scipy.interpolate as in_
import scipy.spatial.distance as dt_
import skimage.measure as ms_
# Shorthand for numpy arrays in type annotations
array_t = np_.ndarray

# Lengths of an edge, as computed by edge_t.SetLengths:
#     length: total length, i.e., sum of the segment lengths
#     ww_length: width-weighted length (-1.0 when no widths are available)
#     lengths: length of each segment joining two consecutive sites
#     sq_lengths: squared segment lengths; interest: all integers for integer site coordinates
edge_lengths_t = namedtuple_t(
    "edge_lengths_t", ("length", "ww_length", "lengths", "sq_lengths")
)
class edge_t:
    """
    Skeleton edge: a chain of sites (voxel coordinates) joining two nodes.

    __init__ sets every attribute to None, except node_uids which starts as an empty list. The attributes are
    then filled in stages by the Set* methods, most of which require both adjacent node uids to be known.
    """

    __slots__ = (
        "uid_",  # There is a uid property, hence the underscore here (see notes in property)
        "node_uids",  # Uids of the adjacent nodes; most methods require exactly 2 of them
        "origin_node",  # Node ID of node closest to (sites[0][0], sites[1][0])
        "dim",  # Number of coordinate arrays in sites
        "sites",  # Tuple of dim coordinate arrays (see WithSites)
        "widths",  # Width values sampled at the sites (see SetWidths)
        "lengths",  # edge_lengths_t (see SetLengths)
        "as_curve",  # One PCHIP interpolator per dimension: arc length -> coordinate
        "origin_direction",  # Unit vector at the first site, pointing outward from the edge
        "final_direction",  # Unit vector at the last site, pointing outward from the edge
    )

    def __init__(self):
        """Create an empty edge: all slots set to None, except node_uids set to an empty list."""
        super().__init__()
        for slot in self.__class__.__slots__:
            setattr(self, slot, None)
        self.node_uids = []

    @classmethod
    def WithSites(cls, sites: Tuple[array_t, ...]) -> edge_t:
        """
        Create an edge from its sites, given as one coordinate array per dimension (e.g., the output of
        numpy's nonzero). The sites are reordered by _ReOrderedSites so that consecutive sites are adjacent.
        """
        instance = cls()
        instance.dim = sites.__len__()
        instance.sites = _ReOrderedSites(sites)
        return instance

    def SetWidths(self, widths: array_t) -> None:
        """
        Sample widths at the edge sites.

        widths: array indexed by the site coordinates (presumably a width map with the skeleton map shape).
        Raises ValueError if the two adjacent nodes are not known yet.
        """
        if self.node_uids.__len__() != 2:
            raise ValueError("Edge: Missing sites from adjacent nodes")
        self.widths = widths[self.sites]

    def SetLengths(
        self, size_voxel: array_t, widths: array_t = None, check_validity: bool = False
    ) -> None:
        """
        Compute the edge lengths and store them in self.lengths (an edge_lengths_t).

        size_voxel: currently unused; lengths are expressed in voxel units.
        widths: if passed, the edge widths are (re)set from it (see SetWidths) before computing ww_length.
            If no widths are available at all, ww_length is -1.0.
        check_validity: if True, raise ValueError on repeated consecutive sites or on gaps between
            consecutive sites (checked after self.lengths is set).
        Raises ValueError if the two adjacent nodes are not known yet.
        """
        if self.node_uids.__len__() != 2:
            raise ValueError("Edge: Missing sites from adjacent nodes")

        segments = np_.diff(np_.array(self.sites), axis=1)
        # NOTE: size_voxel is not applied here, so squared lengths remain integers for integer sites.
        sq_lengths = (segments ** 2).sum(axis=0)
        lengths = np_.sqrt(sq_lengths)
        length = lengths.sum().item()

        if widths is not None:
            # If one bothers to pass widths, use it even if it overrides previous settings
            self.SetWidths(widths)
        if self.widths is None:
            ww_length = -1.0
        else:
            # Trapezoidal rule: each segment is weighted by the mean width of its two end sites
            ww_length = (
                (0.5 * (self.widths[1:] + self.widths[:-1]) * lengths).sum().item()
            )

        self.lengths = edge_lengths_t(
            length=length, ww_length=ww_length, lengths=lengths, sq_lengths=sq_lengths
        )

        if check_validity:
            # A global condition: self.sites[0].size - 1 <= length
            if cast(array_t, sq_lengths == 0).any():
                raise ValueError("Edge: Repeated sites")
            # Adjacent sites differ by at most 1 along each of the len(self.sites) axes,
            # hence a squared step length of at most len(self.sites).
            if cast(array_t, sq_lengths > self.sites.__len__()).any():
                raise ValueError("Edge: Site gaps")

    def SetCurveRepresentation(self, size_voxel: list) -> None:
        """
        Build self.as_curve: one PCHIP interpolator per dimension mapping the cumulative arc length to the
        site coordinate. Nothing is (re)computed for single-site edges. Lengths are computed first if needed.

        Consecutive sites must differ: PchipInterpolator requires strictly increasing abscissae.
        Raises ValueError if the two adjacent nodes are not known yet.
        """
        if self.node_uids.__len__() != 2:
            raise ValueError("Edge: Missing sites from adjacent nodes")

        if self.sites[0].__len__() > 1:
            if self.lengths is None:
                self.SetLengths(size_voxel=size_voxel)
            # Arc length = cumulative Euclidean segment lengths. Accumulating the *squared* lengths would
            # overweight diagonal steps and would not be an arc length.
            arc_lengths = tuple(it_.accumulate((0.0, *self.lengths.lengths.tolist())))
            self.as_curve = tuple(
                in_.PchipInterpolator(arc_lengths, self.sites[idx_])
                for idx_ in range(self.dim)
            )

    def SetEndPointDirections(self, size_voxel: list) -> None:
        """
        Compute unit direction vectors at both edge extremities from the curve representation, which is
        (re)computed here (see SetCurveRepresentation).

        Both vectors point outward from the edge. origin_direction is the negated tangent at the first
        site, and final_direction is the tangent at the last site. Nothing is set if no curve
        representation is available (e.g., single-site edge).
        """
        self.SetCurveRepresentation(size_voxel=size_voxel)

        if self.as_curve is not None:
            max_arclength = self.as_curve[0].x.item(-1)
            o_dir, f_dir = [], []
            for d_idx in range(self.dim):
                # First derivative (nu=1) at both ends of the arc-length range
                directions = self.as_curve[d_idx]((0, max_arclength), 1)
                o_dir.append(directions[0])
                f_dir.append(directions[1])
            # Negated: the tangent at the origin points into the edge
            self.origin_direction = np_.array(o_dir, dtype=np_.float64) / (
                -np_.linalg.norm(o_dir)
            )
            self.final_direction = np_.array(
                f_dir, dtype=np_.float64
            ) / np_.linalg.norm(f_dir)

    @property
    def uid(self) -> str:
        """
        Edge uid built from the uids of its two adjacent nodes.

        The two node uids are ordered (string comparison), so the result does not depend on their storage
        order. Each node uid is split on id_.coord_sep_c, and each part is encoded with id_.EncodedNumber.
        The two encoded nodes are then joined with id_.coord_sep_c. The result is cached in uid_.

        node_uids are not set at edge instantiation. This property can then be called later when they are.
        Raises ValueError if the two adjacent nodes are not known yet.
        """
        if self.uid_ is None:
            if self.node_uids.__len__() != 2:
                raise ValueError("Edge: Missing sites from adjacent nodes")

            node_uid_0, node_uid_1 = self.node_uids
            if node_uid_0 > node_uid_1:
                node_uid_0, node_uid_1 = node_uid_1, node_uid_0

            edge_id = [
                id_.EncodedNumber(coord) for coord in node_uid_0.split(id_.coord_sep_c)
            ]
            edge_id.append(id_.coord_sep_c)
            edge_id.extend(
                id_.EncodedNumber(coord) for coord in node_uid_1.split(id_.coord_sep_c)
            )

            self.uid_ = "".join(edge_id)

        return self.uid_
def RawEdges(
    skeleton_map: array_t, b_node_lmap: array_t
) -> Tuple[List[edge_t], array_t]:
    """
    Extract the edges of a skeleton map: its connected pieces once node sites are removed.

    "Raw" means that the edges have no valid node labels yet.

    skeleton_map: skeleton map.
    b_node_lmap: labeled node map (presumably branch nodes); only its non-zero sites are used, to cut the
        skeleton into edges.
    Returns: (edges, edge_lmap), where edge_lmap is the labeled edge map and edges[label - 1] is the edge
        of that label.
    """
    edge_map = skeleton_map.copy()
    edge_map[b_node_lmap > 0] = 0
    edge_lmap, n_edges = LABELIZED_MAP_fct_FOR_DIM[skeleton_map.ndim](edge_map)

    # One independent instance per slot: n_edges * [edge_t()] would make every slot share a single instance.
    edges = [edge_t() for _ in range(n_edges)]
    for props in ms_.regionprops(edge_lmap):
        # props.image is cropped to the region bounding box: shift its coordinates back to the full map frame
        sites = tuple(
            coords + props.bbox[d_idx]
            for d_idx, coords in enumerate(props.image.nonzero())
        )
        edges[props.label - 1] = edge_t.WithSites(sites)

    return edges, edge_lmap
def Plot(
    edges: Iterable[Tuple[str, str, edge_t]],
    transformation: Callable[[array_t], array_t],
    vector_transf: Callable[[array_t], array_t],
    axes: pl_.axes.Axes,
    as_curve: bool = False,
    w_directions: bool = False,
    size_voxel: array_t = None,
) -> None:
    """
    Plot edges on 2-D axes (axes.plot) or 3-D axes (axes.plot3D).

    edges: (origin node uid, destination node uid, edge) triplets. Any iterable is accepted, including
        one-shot iterators: it is materialized once. Self-loops (origin == destination) are drawn dotted,
        other edges solid.
    transformation: the first two site coordinates are swapped for plotting, and transformation is applied
        to the original first coordinate (which becomes the second plotted one).
    vector_transf: same role as transformation, for the direction vectors.
    as_curve: if True, plot the curve representation sampled every 0.125 arc-length unit instead of the
        raw sites (computed on the fly if missing; raw sites are used if still unavailable).
    w_directions: if True, also draw the end-point directions as blue arrows (computed on the fly if
        missing).
    size_voxel: forwarded to the edge methods that compute the curve representation or the directions on
        the fly.
    """
    # Materialize once: the edges are needed twice (space dimension, then plotting), and a one-shot iterator
    # would otherwise lose its first edge.
    edges = tuple(edges)
    if edges.__len__() > 0:
        space_dim = edges[0][2].dim
    else:
        space_dim = 2  # Irrelevant: nothing to plot

    plot_fct = axes.plot if space_dim == 2 else axes.plot3D
    plot_style = "k" if as_curve else "k."

    for origin, destination, edge in edges:
        if as_curve:
            if edge.as_curve is None:
                edge.SetCurveRepresentation(size_voxel)
            if edge.as_curve is None:
                sites = list(edge.sites)
            else:
                max_arc_length = edge.as_curve[0].x.item(-1)
                step = 0.125
                # Half a step of margin so that max_arc_length itself is sampled
                arc_lengths = np_.arange(0.0, max_arc_length + 0.5 * step, step)
                sites = [edge.as_curve[idx_](arc_lengths) for idx_ in range(space_dim)]
        else:
            sites = list(edge.sites)
        sites[0], sites[1] = sites[1], transformation(sites[0])

        line_style = ":" if origin == destination else "-"
        plot_fct(*sites, plot_style + line_style, linewidth=2, markersize=7)

        if w_directions:
            if edge.origin_direction is None:
                edge.SetEndPointDirections(size_voxel)
            if edge.origin_direction is not None:
                # Arrows start at the first and last plotted sites
                dir_sites = tuple(
                    np_.hstack((sites[idx_][0], sites[idx_][-1]))
                    for idx_ in range(space_dim)
                )
                directions = list(zip(edge.origin_direction, edge.final_direction))
                directions[0], directions[1] = (
                    directions[1],
                    vector_transf(directions[0]),
                )
                axes.quiver(*dir_sites, *directions, color="b", linewidth=2)
def _ReOrderedSites(sites: Tuple[array_t, ...]) -> Tuple[array_t, ...]:
#
n_sites = sites[0].__len__()
if n_sites > 2:
sites_as_array = np_.transpose(np_.array(sites))
pairwise_dists = dt_.squareform(dt_.pdist(sites_as_array, "chebyshev"))
reordered_sites_nfo = [(0, sites_as_array[0, :])]
visited_sites = {0}
while visited_sites.__len__() < n_sites:
s_idx, first_site = reordered_sites_nfo[0]
neighbor_idc = list(
set((pairwise_dists[s_idx, :] == 1).nonzero()[0]) - visited_sites
)
# Length is equal to zero when reaching an extremity
if neighbor_idc.__len__() > 0:
reordered_sites_nfo.insert(
0, (neighbor_idc[0], sites_as_array[neighbor_idc[0], :])
)
visited_sites.add(neighbor_idc[0])
if neighbor_idc.__len__() == 2:
# The one seed + the one just added above = 2
assert reordered_sites_nfo.__len__() == 2
neighbor_idc[0] = neighbor_idc[1]
else:
s_idx, last_point = reordered_sites_nfo[-1]
neighbor_idc = tuple(
set((pairwise_dists[s_idx, :] == 1).nonzero()[0]) - visited_sites
)
# Length is equal to zero when reaching an extremity
if neighbor_idc.__len__() == 0:
continue
reordered_sites_nfo.append(
(neighbor_idc[0], sites_as_array[neighbor_idc[0], :])
)
visited_sites.add(neighbor_idc[0])
reordered_coords = np_.array(
tuple(site_nfo[1] for site_nfo in reordered_sites_nfo)
)
reordered_coords = tuple(
reordered_coords[:, idx_] for idx_ in range(sites.__len__())
)
#
else:
reordered_coords = sites
return reordered_coords