# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2019), Morgane Nadal (2020)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

from __future__ import annotations

import brick.processing.frangi3 as fg_
import brick.processing.map_labeling as ml_
from brick.component.glial_cmp import glial_cmp_t
from brick.general.type import array_t, site_h
from sklgraph.brick.edge import _ReOrderedSites

from typing import Callable, Optional, Sequence, Tuple

import numpy as np_
import skimage.filters as fl_
import skimage.measure as ms_
import skimage.morphology as mp_
from scipy import ndimage as im_
import matplotlib.pyplot as pl_


_CENTER_3x3 = ((0, 0, 0), (0, 1, 0), (0, 0, 0))
_CROSS_3x3 = np_.array(((0, 1, 0), (1, 1, 1), (0, 1, 0)), dtype=np_.uint8)
_CROSS_3x3x3 = np_.array((_CENTER_3x3, _CROSS_3x3, _CENTER_3x3), dtype=np_.uint8)
_CROSS_FOR_DIM = (None, None, _CROSS_3x3, _CROSS_3x3x3)

_FULL_SHIFTS_FOR_2D_NEIGHBORS = tuple(
    (i, j) for i in (-1, 0, 1) for j in (-1, 0, 1) if i != 0 or j != 0
)
_FULL_SHIFTS_FOR_3D_NEIGHBORS = tuple(
    (i, j, k)
    for i in (-1, 0, 1)
    for j in (-1, 0, 1)
    for k in (-1, 0, 1)
    if i != 0 or j != 0 or k != 0
)
_FULL_SHIFTS_FOR_NEIGHBORS_FOR_DIM = (
    None,
    None,
    _FULL_SHIFTS_FOR_2D_NEIGHBORS,
    _FULL_SHIFTS_FOR_3D_NEIGHBORS,
)

_MIN_SHIFTS_FOR_2D_NEIGHBORS = tuple(
    elm for elm in _FULL_SHIFTS_FOR_2D_NEIGHBORS if np_.abs(elm).sum() == 1
)
_MIN_SHIFTS_FOR_3D_NEIGHBORS = tuple(
    elm for elm in _FULL_SHIFTS_FOR_3D_NEIGHBORS if np_.abs(elm).sum() == 1
)
_MIN_SHIFTS_FOR_NEIGHBORS_FOR_DIM = (
    None,
    None,
    _MIN_SHIFTS_FOR_2D_NEIGHBORS,
    _MIN_SHIFTS_FOR_3D_NEIGHBORS,
)

_SQUARE_3x3 = np_.ones((3, 3), dtype=np_.uint8)
_SQUARE_3x3x3 = np_.ones((3, 3, 3), dtype=np_.uint8)

_LABELIZED_MAP_8_fct = lambda map_: im_.label(
    map_, structure=_SQUARE_3x3, output=np_.int64
)  # type: Callable[[array_t], Tuple[array_t, int]]
_LABELIZED_MAP_26_fct = lambda map_: im_.label(
    map_, structure=_SQUARE_3x3x3, output=np_.int64
)  # type: Callable[[array_t], Tuple[array_t, int]]
LABELIZED_MAP_fct_FOR_DIM = (None, None, _LABELIZED_MAP_8_fct, _LABELIZED_MAP_26_fct)
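
# Note: for a 2-D map, the "minimal" shifts enumerate the 4-connected
# neighborhood while the "full" shifts enumerate the 8-connected one, e.g.
#     _MIN_SHIFTS_FOR_2D_NEIGHBORS == ((-1, 0), (0, -1), (0, 1), (1, 0))
#     len(_FULL_SHIFTS_FOR_2D_NEIGHBORS) == 8
# Their 3-D counterparts enumerate the 6- and 26-connected neighborhoods.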


class extension_t(glial_cmp_t):
    #
    # soma_uid: connected to a soma somewhere upstream
    #
    __slots__ = ("end_points", "scales", "soma_uid", "__cache__")

    def __init__(self):
        #
        super().__init__()
        for slot in self.__class__.__slots__:
            setattr(self, slot, None)

    @classmethod
    def FromMap(cls, lmp: array_t, scales: array_t, uid: int) -> extension_t:
        '''
        Initialize and create the extension object based on the labelled map.
        '''
        #
        instance = cls()

        # Create a boolean map keeping only the extension number 'uid'.
        bmp = lmp == uid
        # Initialize the object with its different fields
        instance.InitializeFromMap(bmp, uid)
        # cls.sites = _ReOrderedSites(sites)
        # Find the endpoint sites of the extension
        end_point_map = cls.EndPointMap(bmp)
        instance.end_points = end_point_map.nonzero()
        # Store the Frangi scales of the extension
        instance.scales = scales[instance.sites]
        instance.__cache__ = {}

        return instance

    @property
    def is_unconnected(self) -> bool:
        #
        return self.soma_uid is None

    @property
    def end_points_as_array(self) -> array_t:
        #
        pty_name = 'end_points_as_array'
        if pty_name not in self.__cache__:
            self.__cache__[pty_name] = np_.array(self.end_points)

        return self.__cache__[pty_name]

    def EndPointsForSomaOrExt(
        self, soma_uid: int, influence_map: array_t
    ) -> Tuple[site_h, ...]:
        '''
        Find the extension endpoints that lie within the influence zone of the soma.
        '''
        # Create a boolean map with the endpoints under the soma's influence
        ep_bmp = influence_map[self.end_points] == soma_uid  # bmp = boolean map
        # Create a list of endpoints
        if ep_bmp.any():
            # Extract the endpoint indices from the map
            end_point_idc = ep_bmp.nonzero()[0]
            # Find the endpoint coordinates based on their indices in the endpoint soma/extension object
            end_points = self.end_points_as_array[:, end_point_idc]
            # Return the endpoint coordinates as a tuple ((x1, y1, z1), (x2, y2, z2), ...)
            return tuple(zip(*end_points.tolist()))

        # If there are no endpoints, return an empty tuple
        return ()

    def BackReferenceSoma(self, glial_cmp: glial_cmp_t) -> None:
        #
        if isinstance(glial_cmp, extension_t):
            self.soma_uid = glial_cmp.soma_uid
        else:
            self.soma_uid = glial_cmp.uid

    def __str__(self) -> str:
        #
        if self.extensions is None:
            n_extensions = 0
        else:
            n_extensions = self.extensions.__len__()

        return (
            f"Ext.{self.uid}, "
            f"sites={self.sites[0].__len__()}, "
            f"endpoints={self.end_points[0].__len__()}, "
            f"soma={self.soma_uid}, "
            f"extensions={n_extensions}"
        )

    @staticmethod
    def ExtensionContainingSite(
        extensions: Sequence[extension_t], site: site_h
    ) -> Optional[extension_t]:
        '''
        Return the extension that contains the given site, or None if no extension contains it.
        '''
        #
        for extension in extensions:
            if site in tuple(zip(*extension.sites)):
                return extension

        return None
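
    # Typical instantiation from a labelled extension map (illustrative sketch;
    # the variable names are hypothetical, e.g. a label map such as the one
    # returned by FilteredCoarseMap and the scale map returned by
    # EnhancedForDetection below):
    #     extensions = tuple(
    #         extension_t.FromMap(lmp, scale_map, uid)
    #         for uid in range(1, lmp.max() + 1)
    #     )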

    @staticmethod
    def EnhancedForDetection(
        image: array_t,
        scale_range,
        scale_step,
        alpha,
        beta,
        frangi_c,
        bright_on_dark,
        method,
        diff_mode,
        in_parallel: bool = False,
    ) -> Tuple[array_t, array_t]:
        '''
        Preprocess by white top hat.
        Perform Frangi vesselness enhancement.
        '''
        #
        # import os.path as ph_
        # if ph_.exists("./__runtime__/frangi.npz"):
        #     print("/!\\ Reading from precomputed data file")
        #     loaded = np_.load("./frangi.npz")
        #     enhanced_img = loaded["enhanced_img"]
        #     scale_map = loaded["scale_map"]
        #
        #     return enhanced_img, scale_map

        preprocessed_img = im_.morphology.white_tophat(
            image, size=2, mode="constant", cval=0.0, origin=0
        )

        enhanced_img, scale_map = fg_.FrangiEnhancement(
            preprocessed_img,
            scale_range,
            scale_step,
            alpha,
            beta,
            frangi_c,
            bright_on_dark,
            in_parallel,
            method,
            diff_mode,
        )

        # enhanced_img, scale_map = fl_.frangi(
        #     image=preprocessed_img,
        #     scale_range=scale_range,
        #     scale_step=scale_step,
        #     alpha=alpha,
        #     beta=beta,
        #     gamma=frangi_c,
        #     black_ridges=bright_on_dark)

        # np_.savez_compressed(
        #     "./runtime/frangi.npz", enhanced_img=enhanced_img, scale_map=scale_map
        # )

        return enhanced_img, scale_map

    @staticmethod
    def CoarseMap(image: array_t, low: float, high: float, selem: array_t) -> array_t:
        '''
        Perform hysteresis thresholding, then morphological closing/opening.
        '''
        #
        result = image.copy()
        if (low is not None) and (high is not None):
            result = __HysterisisImage__(result, low, high)
        if selem is not None:
            result = __MorphologicalCleaning__(result, selem)

        return result

    @staticmethod
    def FilteredCoarseMap(map_: array_t, ext_min_area_c: int) -> Tuple[array_t, array_t]:
        '''
        Delete the connected components whose area does not exceed the allowed minimum area.
        Return the filtered map and the corresponding label map.
        '''
        #
        result = map_.copy()
        # Label the extensions
        lmp = ms_.label(map_)
        # Measure the area of each extension
        for region in ms_.regionprops(lmp):
            # Delete the ones that are too small by setting their voxels to 0
            if region.area <= ext_min_area_c:
                region_sites = (lmp == region.label).nonzero()
                result[region_sites] = 0
                lmp[region_sites] = 0

        return result, lmp

    @staticmethod
    def FineMapFromCoarseMap(coarse_map: array_t) -> array_t:
        '''
        Skeletonize the 3D coarse map.
        The raw skeleton might contain True-voxels that can be removed without breaking
        connectivity, hence the additional thinning.
        '''
        #
        result = mp_.skeletonize_3d(coarse_map.astype(np_.uint8, copy=False))
        # Thinning of the skeleton
        ThinMap(result)

        return result.astype(np_.int8, copy=False)

    @staticmethod
    def EndPointMap(map_: array_t) -> array_t:
        '''
        Find the endpoints of the extensions.
        Endpoints only have a single neighboring voxel.
        '''
        # Find the 26-connectivity of each voxel
        part_map = ml_.PartLMap(map_)
        # The background is labeled with 27, and endpoints have a connectivity of 1.
        result = part_map == 1

        return result.astype(np_.int8)
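

# Typical chaining of the mapping steps above (illustrative sketch; the
# parameter values are hypothetical):
#     enhanced_img, scale_map = extension_t.EnhancedForDetection(image, ...)
#     coarse_map = extension_t.CoarseMap(enhanced_img, low=0.1, high=0.6, selem=mp_.disk(1))
#     coarse_map, lmp = extension_t.FilteredCoarseMap(coarse_map, ext_min_area_c=100)
#     fine_map = extension_t.FineMapFromCoarseMap(coarse_map)
#     end_point_map = extension_t.EndPointMap(fine_map)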


def __HysterisisImage__(image: array_t, low: float, high: float) -> array_t:
    '''
    Perform hysteresis thresholding based on the image intensities.
    The low and high arguments are relative factors applied to the minimum non-zero
    intensity and to the maximum intensity, respectively.
    '''
    #
    nonzero_sites = (image > 0).nonzero()
    nonzero_values = image[nonzero_sites]
    # print(nonzero_values.min(), image.max())
    low = low * nonzero_values.min()
    high = high * image.max()
    # print("low=", low, " high=", high)

    # lowt = low*(x_image_f-min_image_f)+max_image_f
    # hight = high*(max_image_f- min_image_f)+min_image_f
    # lowt = (image_f >lowt).astype(int)
    # hight = (image_f <hight).astype(int)

    result = fl_.apply_hysteresis_threshold(image, low, high)
    result = result.astype(np_.int8, copy=False)

    return result


def __MorphologicalCleaning__(image: array_t, selem) -> array_t:
    '''
    Perform closing and opening of the image, slice by slice along the first axis.
    '''
    #
    result = image.copy()

    for dep in range(result.shape[0]):
        result[dep, :, :] = mp_.closing(result[dep, :, :], selem)
        result[dep, :, :] = mp_.opening(result[dep, :, :], selem)

    return result


def ThinMap(skl_map: array_t) -> None:
    '''
    Removes, in place, all pixels whose removal does not break 8- (2D) or 26- (3D)
    connectivity.
    Works for multi-skeletons.
    '''
    background_label = BackgroundLabelForTmp(skl_map)

    def FixLocalMap_n(
        padded_sm_: array_t,
        part_map_: array_t,
        n_neighbors_: int,
        cross_: array_t,
        labelled_map_fct_: Callable[[array_t], Tuple[array_t, int]],
    ) -> bool:
        #
        skel_has_been_modified_ = False

        center = padded_sm_.ndim * (1,)
        for coords in zip(*np_.where(part_map_ == n_neighbors_)):
            lm_slices = tuple(slice(coord - 1, coord + 2) for coord in coords)
            local_map = padded_sm_[lm_slices]
            local_part_map = part_map_[lm_slices]
            if (local_part_map[cross_] == background_label).any():
                local_map[center] = 0
                _, n_components = labelled_map_fct_(local_map)
                if n_components == 1:
                    skel_has_been_modified_ = True
                else:
                    local_map[center] = 1

        return skel_has_been_modified_

    padded_map = np_.pad(skl_map, 1, "constant")
    cross = _CROSS_FOR_DIM[skl_map.ndim]
    labelized_map_fct = LABELIZED_MAP_fct_FOR_DIM[skl_map.ndim]
    excluded_n_neighbors = {
        0,
        1,
        2 * skl_map.ndim,
        background_label,
    }

    skel_has_been_modified = True
    while skel_has_been_modified:
        skel_has_been_modified = False

        part_map = TopologyMapOfSkeleton(padded_map, full_connectivity=False)
        included_n_neighbors = set(np_.unique(part_map)).difference(
            excluded_n_neighbors
        )
        for n_neighbors in sorted(included_n_neighbors, reverse=True):
            skel_has_been_modified = skel_has_been_modified or FixLocalMap_n(
                padded_map,
                part_map,
                n_neighbors,
                cross,
                labelized_map_fct,
            )

    if skl_map.ndim == 2:
        skl_map[:, :] = padded_map[1:-1, 1:-1]
    else:
        skl_map[:, :, :] = padded_map[1:-1, 1:-1, 1:-1]


def BackgroundLabelForTmp(a_map: array_t) -> int:
    """
    Must be equal to the max number of neighbors in a skeleton + 1.
    Note: using a_map avoids shadowing Python's map.
    """
    return 3 ** a_map.ndim


def TopologyMapOfSkeleton(skl_map: array_t, full_connectivity: bool = True) -> array_t:
    '''
    The topology map is labeled as follows:
    - background pixels: BackgroundLabelForTmp(skl_map), i.e. 3**ndim;
    - skeleton pixels: the number of neighboring pixels that belong to the skeleton
      (as expected, isolated pixels receive 0).
    Works for multi-skeletons.
    '''
    #
    tmap = np_.array(skl_map, dtype=np_.int8)

    if full_connectivity:
        shifts_for_dim = _FULL_SHIFTS_FOR_NEIGHBORS_FOR_DIM
    else:
        shifts_for_dim = _MIN_SHIFTS_FOR_NEIGHBORS_FOR_DIM

    padded_sm = np_.pad(skl_map, 1, "constant")
    unpadding_domain = skl_map.ndim * (slice(1, -1),)
    rolling_axes = tuple(range(skl_map.ndim))
    for shifts in shifts_for_dim[skl_map.ndim]:
        tmap += np_.roll(padded_sm, shifts, axis=rolling_axes)[unpadding_domain]
    tmap[skl_map == 0] = BackgroundLabelForTmp(skl_map) + 1

    return tmap - 1


def MaximumIntensityProjectionZ(
    img: array_t,
    cmap: str = 'tab20',
    axis: int = 0,
    output_image_file_name: Optional[str] = None,
) -> None:
    """
    Maximum intensity projection along the Z axis (or the given axis).
    """
    #
    xy = np_.amax(img, axis=axis)
    pl_.imshow(xy, cmap=cmap)
    pl_.show(block=True)
    if output_image_file_name is not None:
        pl_.imsave(output_image_file_name, xy, cmap=cmap)
        print('Image saved in', output_image_file_name)


def EndPointMap(map_: array_t) -> array_t:
    '''
    Find the endpoints of the extensions.
    Endpoints only have a single neighboring voxel.
    '''
    # Find the 26-connectivity of each voxel
    part_map = ml_.PartLMap(map_)
    # The background is labeled with 27, and endpoints have a connectivity of 1.
    result = part_map == 1

    return result.astype(np_.int8)
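

if __name__ == "__main__":
    # Minimal self-contained sketch exercising the skeleton helpers defined above
    # on a tiny synthetic 2-D skeleton. The toy shape and printouts are purely
    # illustrative and are not part of the processing pipeline.
    demo_skl = np_.zeros((7, 7), dtype=np_.int8)
    demo_skl[3, 1:6] = 1  # a horizontal segment
    demo_skl[1:4, 3] = 1  # a vertical branch joining it
    print("Topology map (neighbor counts; background = 9 in 2-D):")
    print(TopologyMapOfSkeleton(demo_skl))
    ThinMap(demo_skl)  # in-place thinning; connectivity is preserved
    print("Thinned skeleton:")
    print(demo_skl)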