Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 986ed789 authored by NADAL Morgane's avatar NADAL Morgane
Browse files

update: thinning after skeletonization in nutrimorph.py and skl_map.py

parent 40ccd520
No related branches found
No related tags found
No related merge requests found
......@@ -36,7 +36,7 @@ import brick.processing.map_labeling as ml_
from brick.component.glial_cmp import glial_cmp_t
from brick.general.type import array_t, site_h
from typing import Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple, Callable
import numpy as np_
import skimage.filters as fl_
......@@ -46,6 +46,48 @@ from scipy import ndimage as im_
import matplotlib.pyplot as pl_
_CENTER_3x3 = ((0, 0, 0), (0, 1, 0), (0, 0, 0))
_CROSS_3x3 = np_.array(((0, 1, 0), (1, 1, 1), (0, 1, 0)), dtype=np_.uint8)
_CROSS_3x3x3 = np_.array((_CENTER_3x3, _CROSS_3x3, _CENTER_3x3), dtype=np_.uint8)
_CROSS_FOR_DIM = (None, None, _CROSS_3x3, _CROSS_3x3x3)
_FULL_SHIFTS_FOR_2D_NEIGHBORS = tuple(
(i, j) for i in (-1, 0, 1) for j in (-1, 0, 1) if i != 0 or j != 0
)
_FULL_SHIFTS_FOR_3D_NEIGHBORS = tuple(
(i, j, k)
for i in (-1, 0, 1)
for j in (-1, 0, 1)
for k in (-1, 0, 1)
if i != 0 or j != 0 or k != 0
)
_FULL_SHIFTS_FOR_NEIGHBORS_FOR_DIM = (
None,
None,
_FULL_SHIFTS_FOR_2D_NEIGHBORS,
_FULL_SHIFTS_FOR_3D_NEIGHBORS,
)
_MIN_SHIFTS_FOR_2D_NEIGHBORS = tuple(elm for elm in _FULL_SHIFTS_FOR_2D_NEIGHBORS if np_.abs(elm).sum() == 1)
_MIN_SHIFTS_FOR_3D_NEIGHBORS = tuple(elm for elm in _FULL_SHIFTS_FOR_3D_NEIGHBORS if np_.abs(elm).sum() == 1)
_MIN_SHIFTS_FOR_NEIGHBORS_FOR_DIM = (
None,
None,
_MIN_SHIFTS_FOR_2D_NEIGHBORS,
_MIN_SHIFTS_FOR_3D_NEIGHBORS,
)
_SQUARE_3x3 = np_.ones((3, 3), dtype=np_.uint8)
_SQUARE_3x3x3 = np_.ones((3, 3, 3), dtype=np_.uint8)
_LABELIZED_MAP_8_fct = lambda map_: im_.label(
map_, structure=_SQUARE_3x3, output=np_.int64
) # type: Callable[[array_t], Tuple[array_t, int]]
_LABELIZED_MAP_26_fct = lambda map_: im_.label(
map_, structure=_SQUARE_3x3x3, output=np_.int64
) # type: Callable[[array_t], Tuple[array_t, int]]
LABELIZED_MAP_fct_FOR_DIM = (None, None, _LABELIZED_MAP_8_fct, _LABELIZED_MAP_26_fct)
class extension_t(glial_cmp_t):
#
# soma_uid: connected to a soma somewhere upstream
......@@ -254,6 +296,8 @@ class extension_t(glial_cmp_t):
'''
#
result = mp_.skeletonize_3d(coarse_map.astype(np_.uint8, copy=False))
# TODO thinning of the skeleton
TurnSkeletonMapIntoSKLMapByThining(result)
return result.astype(np_.int8, copy=False)
......@@ -325,6 +369,102 @@ def __MorphologicalCleaning__(image: array_t, selem) -> array_t:
return result
def TurnSkeletonMapIntoSKLMapByThining(skl_map: array_t) -> None:
'''
Removes all pixels that do not break 8- or 26-connectivity
Works for multi-skeleton
'''
background_label = BackgroundLabelForTmp(skl_map)
def FixLocalMap_n(
padded_sm_: array_t,
part_map_: array_t,
n_neighbors_: int,
cross_: array_t,
labelled_map_fct_: Callable[[array_t], Tuple[array_t, int]],
) -> bool:
#
skel_has_been_modified_ = False
center = padded_sm_.ndim * (1,)
for coords in zip(*np_.where(part_map_ == n_neighbors_)):
lm_slices = tuple(slice(coord - 1, coord + 2) for coord in coords)
local_map = padded_sm_[lm_slices]
local_part_map = part_map_[lm_slices]
if (local_part_map[cross_] == background_label).any():
local_map[center] = 0
_, n_components = labelled_map_fct_(local_map)
if n_components == 1:
skel_has_been_modified_ = True
else:
local_map[center] = 1
return skel_has_been_modified_
padded_map = np_.pad(skl_map, 1, "constant")
cross = _CROSS_FOR_DIM[skl_map.ndim]
labelized_map_fct = LABELIZED_MAP_fct_FOR_DIM[skl_map.ndim]
excluded_n_neighbors = {
0,
1,
2 * skl_map.ndim,
background_label,
}
skel_has_been_modified = True
while skel_has_been_modified:
skel_has_been_modified = False
part_map = TopologyMapOfSkeleton(padded_map, full_connectivity=False)
included_n_neighbors = set(np_.unique(part_map)).difference(
excluded_n_neighbors
)
for n_neighbors in sorted(included_n_neighbors, reverse=True):
skel_has_been_modified = skel_has_been_modified or FixLocalMap_n(
padded_map, part_map, n_neighbors, cross, labelized_map_fct,
)
if skl_map.ndim == 2:
skl_map[:, :] = padded_map[1:-1, 1:-1]
else:
skl_map[:, :, :] = padded_map[1:-1, 1:-1, 1:-1]
def BackgroundLabelForTmp(a_map: array_t) -> int:
"""
Must be equal to the max number of neighbors in a skeleton + 1.
Note: using a_map avoids shadowing Python's map.
"""
return 3 ** a_map.ndim
def TopologyMapOfSkeleton(skl_map: array_t, full_connectivity: bool = True) -> array_t:
'''
The topology map is labeled as follows: background=invalid_n_neighbors_Xd_c; Pixels of the skeleton=number of
neighboring pixels that belong to the skeleton (as expected, isolated pixels receive 0).
Works for multi-skeleton
'''
#
tmap = np_.array(skl_map, dtype=np_.int8)
if full_connectivity:
shifts_for_dim = _FULL_SHIFTS_FOR_NEIGHBORS_FOR_DIM
else:
shifts_for_dim = _MIN_SHIFTS_FOR_NEIGHBORS_FOR_DIM
padded_sm = np_.pad(skl_map, 1, "constant")
unpadding_domain = skl_map.ndim * (slice(1, -1),)
rolling_axes = tuple(range(skl_map.ndim))
for shifts in shifts_for_dim[skl_map.ndim]:
tmap += np_.roll(padded_sm, shifts, axis=rolling_axes)[unpadding_domain]
tmap[skl_map == 0] = BackgroundLabelForTmp(skl_map) + 1
return tmap - 1
def MaximumIntensityProjectionZ(img: array_t, cmap: str ='tab20', axis: int = 0, output_image_file_name: str = None) -> None:
""" Maximum Image Projection on the Z axis. """
#
......
......@@ -449,7 +449,7 @@ print('\n--- Graph extraction')
for soma in somas:
print(f" Soma {soma.uid}", end="")
# Create SKLGraph skeletonized map
ext_map = skl_map_t.FromShapeMap(ext_nfo['lmp_soma'] == soma.uid, store_widths=True, skeletonize=False, do_post_thinning=False)
ext_map = skl_map_t.FromShapeMap(ext_nfo['lmp_soma'] == soma.uid, store_widths=True, skeletonize=False, do_post_thinning=True)
# do_post_thinning = True, in order to remove pixels that are not breaking connectivity
# Create the graph from the SKLGaph skeletonized map
......
......@@ -64,6 +64,14 @@ _SHIFTS_FOR_NEIGHBORS_FOR_DIM = (
_SHIFTS_FOR_2D_NEIGHBORS,
_SHIFTS_FOR_3D_NEIGHBORS,
)
_MIN_SHIFTS_FOR_2D_NEIGHBORS = tuple(elm for elm in _SHIFTS_FOR_2D_NEIGHBORS if np_.abs(elm).sum() == 1)
_MIN_SHIFTS_FOR_3D_NEIGHBORS = tuple(elm for elm in _SHIFTS_FOR_3D_NEIGHBORS if np_.abs(elm).sum() == 1)
_MIN_SHIFTS_FOR_NEIGHBORS_FOR_DIM = (
None,
None,
_MIN_SHIFTS_FOR_2D_NEIGHBORS,
_MIN_SHIFTS_FOR_3D_NEIGHBORS,
)
_SQUARE_3x3 = np_.ones((3, 3), dtype=np_.uint8)
_SQUARE_3x3x3 = np_.ones((3, 3, 3), dtype=np_.uint8)
......@@ -174,7 +182,8 @@ class skl_map_t:
for coords in zip(*np_.where(part_map_ == n_neighbors_)):
lm_slices = tuple(slice(coord - 1, coord + 2) for coord in coords)
local_map = padded_sm_[lm_slices]
if (local_map[cross_] == invalid_n_neighbors_).any():
local_part_map = part_map_[lm_slices]
if (local_part_map[cross_] == invalid_n_neighbors_).any():
local_map[center] = 0
_, n_components = labelized_map_fct_(local_map)
......@@ -193,14 +202,14 @@ class skl_map_t:
excluded_n_neighbors = {
0,
1,
self.invalid_n_neighbors - 1,
2 * self.map.ndim,
self.invalid_n_neighbors,
}
skel_has_been_modified = True
while skel_has_been_modified:
skel_has_been_modified = False
part_map = SkeletonPartMap(padded_map, check_validity=None)
part_map = SkeletonPartMap(padded_map, full_connectivity=False, check_validity=None)
included_n_neighbors = set(np_.unique(part_map)).difference(
excluded_n_neighbors
)
......@@ -236,7 +245,7 @@ class skl_map_t:
else:
break
def PartMap(self: array_t) -> array_t:
def PartMap(self: array_t, full_connectivity: bool = True) -> array_t:
#
'''
The part map is labeled as follows: background=invalid_n_neighbors_Xd_c; Pixels of the skeleton=number of
......@@ -244,12 +253,15 @@ class skl_map_t:
Works for multi-skeleton
'''
part_map = self.map.copy()
padded_sm = np_.pad(self.map, 1, "constant")
if full_connectivity:
shifts_for_dim = _SHIFTS_FOR_NEIGHBORS_FOR_DIM
else:
shifts_for_dim = _MIN_SHIFTS_FOR_NEIGHBORS_FOR_DIM
unpadding_domain = self.map.ndim * (slice(1, -1),)
for shifts in _SHIFTS_FOR_NEIGHBORS_FOR_DIM[self.map.ndim]:
for shifts in shifts_for_dim[self.map.ndim]:
part_map += np_.roll(padded_sm, shifts, axis=range(self.map.ndim))[
unpadding_domain
]
......@@ -308,7 +320,7 @@ def InvalidNNeighborsForMap(map: array_t) -> int:
def SkeletonPartMap(
skeleton_map: array_t, check_validity: Optional[str] = "single"
skeleton_map: array_t, full_connectivity: bool = True, check_validity: Optional[str] = "single"
) -> array_t:
#
# The part map is labeled as follows: background=invalid_n_neighbors_Xd_c; Pixels of the skeleton=number of
......@@ -318,4 +330,4 @@ def SkeletonPartMap(
#
skeleton = skl_map_t.FromSkeletonMap(skeleton_map, check_validity=check_validity)
return skeleton.PartMap()
return skeleton.PartMap(full_connectivity=full_connectivity)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment