# skl_map.py
# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2018)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL  license under French law and
# abiding by the rules of distribution of free software.  You can  use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and  rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty  and the software's author,  the holder of the
# economic rights,  and the successive licensors  have only  limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using,  modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean  that it is complicated to manipulate,  and  that  also
# therefore means  that it is reserved for developers  and  experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and,  more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

# It is implicit that all functions below require the skeleton map to be valid (as tested by SkeletonIsValid())

from __future__ import annotations

from typing import Callable, Optional, Tuple

import numpy as np_
import scipy.ndimage as im_
import skimage.morphology as mp_


# Type alias used in the annotations of this module for NumPy arrays.
# np_.ndarray is the array class; np_.array is only a factory function and
# cannot serve as a type (e.g., in isinstance checks).
array_t = np_.ndarray


# 3x3 pattern with only the central element set (kept as nested tuples).
_CENTER_3x3 = (
    (0, 0, 0),
    (0, 1, 0),
    (0, 0, 0),
)

# 2-D cross: the central element and its 4 edge-sharing neighbors.
_CROSS_3x3 = np_.zeros((3, 3), dtype=np_.uint8)
_CROSS_3x3[1, :] = 1
_CROSS_3x3[:, 1] = 1

# 3-D cross: the central element and its 6 face-sharing neighbors, i.e. the
# three axis-aligned lines passing through the center.
_CROSS_3x3x3 = np_.zeros((3, 3, 3), dtype=np_.uint8)
_CROSS_3x3x3[:, 1, 1] = 1
_CROSS_3x3x3[1, :, 1] = 1
_CROSS_3x3x3[1, 1, :] = 1

# Indexed by the map dimension; only dimensions 2 and 3 have a cross.
_CROSS_FOR_DIM = (None, None, _CROSS_3x3, _CROSS_3x3x3)

# Offsets along each axis when moving to an adjacent element.
_UNIT_SHIFTS = (-1, 0, 1)

# The 8 shifts leading to the neighbors of a pixel (null shift excluded),
# in lexicographic order.
_SHIFTS_FOR_2D_NEIGHBORS = tuple(
    (row, col)
    for row in _UNIT_SHIFTS
    for col in _UNIT_SHIFTS
    if (row, col) != (0, 0)
)
# The 26 shifts leading to the neighbors of a voxel (null shift excluded),
# in lexicographic order.
_SHIFTS_FOR_3D_NEIGHBORS = tuple(
    (dep, row, col)
    for dep in _UNIT_SHIFTS
    for row in _UNIT_SHIFTS
    for col in _UNIT_SHIFTS
    if (dep, row, col) != (0, 0, 0)
)
# Indexed by the map dimension; only dimensions 2 and 3 are supported.
_SHIFTS_FOR_NEIGHBORS_FOR_DIM = (
    None,
    None,
    _SHIFTS_FOR_2D_NEIGHBORS,
    _SHIFTS_FOR_3D_NEIGHBORS,
)

# Full 3x3 and 3x3x3 structuring elements: 8- and 26-connectivity, respectively.
_SQUARE_3x3 = np_.ones((3, 3), dtype=np_.uint8)
_SQUARE_3x3x3 = np_.ones((3, 3, 3), dtype=np_.uint8)


def _LABELIZED_MAP_8_fct(map_: array_t) -> Tuple[array_t, int]:
    """
    Label the connected components of a 2-D map using the full 3x3 structure.

    Returns the int64 label map and the number of components.
    """
    return im_.label(map_, structure=_SQUARE_3x3, output=np_.int64)


def _LABELIZED_MAP_26_fct(map_: array_t) -> Tuple[array_t, int]:
    """
    Label the connected components of a 3-D map using the full 3x3x3 structure.

    Returns the int64 label map and the number of components.
    """
    return im_.label(map_, structure=_SQUARE_3x3x3, output=np_.int64)


# Indexed by the map dimension; only dimensions 2 and 3 have a labeling function.
LABELIZED_MAP_fct_FOR_DIM = (None, None, _LABELIZED_MAP_8_fct, _LABELIZED_MAP_26_fct)


class skl_map_t:
    """
    Skeleton map of a 2-D or 3-D shape.

    Attributes (all set to None by __init__, then filled in by the From*
    constructors):
        invalid_n_neighbors: value returned by InvalidNNeighborsForMap for the
            map. It must be equal to the max number of neighbors in a skeleton
            + 1, and is used as the background label in part maps.
        map: the skeleton map: ones on the skeleton, zeros elsewhere.
        widths: optional; twice the Euclidean distance transform of the shape
            map, plus one (only computed by FromShapeMap when requested).
    """

    __slots__ = (
        "invalid_n_neighbors",
        "map",
        "widths",
    )  # widths=Distances from skeleton to shape border

    def __init__(self):
        """Create an empty instance: every slot is set to None."""
        for slot in self.__class__.__slots__:
            setattr(self, slot, None)

    @classmethod
    def FromSkeletonMap(
        cls, skeleton_map: array_t, check_validity: Optional[str] = "single"
    ) -> skl_map_t:
        """
        Build an instance from an already computed skeleton map.

        skeleton_map: map with ones on the skeleton and zeros elsewhere; it is
            copied, not referenced.
        check_validity: None (no check), "single" (SkeletonIsValid) or "multi"
            (MultiSkeletonIsValid).

        Raises ValueError if the requested check fails, or if check_validity
        is none of the accepted values.
        """
        if check_validity is None:
            pass
        elif check_validity == "single":
            if not skl_map_t.SkeletonIsValid(skeleton_map, verbose=True):
                raise ValueError("Invalid skeleton")
        elif check_validity == "multi":
            if not skl_map_t.MultiSkeletonIsValid(skeleton_map, verbose=True):
                raise ValueError("Invalid multi-skeleton")
        else:
            raise ValueError(f"{check_validity}: Invalid \"check_validity\" value")

        instance = cls()

        # Must be equal to the max number of neighbors in a skeleton + 1.
        # Used for the background.
        instance.invalid_n_neighbors = InvalidNNeighborsForMap(skeleton_map)
        instance.map = skeleton_map.copy()

        return instance

    @classmethod
    def FromShapeMap(
        cls,
        shape_map: array_t,
        skeletonize: bool = True,
        do_post_thinning: bool = True,
        store_widths: bool = False,
    ) -> skl_map_t:
        """
        Build an instance from a shape map (non-zero inside the shape).

        shape_map: 2-D or 3-D map of the shape.
        skeletonize: if True, the skeleton is computed with mp_.thin (2-D) or
            mp_.skeletonize_3d (3-D); if False, the positive part of shape_map
            is used as the skeleton as is.
        do_post_thinning: if True, FixMap is applied to the skeleton.
        store_widths: if True, the widths attribute is computed from the
            Euclidean distance transform of shape_map.

        Raises ValueError if skeletonize is True and shape_map is neither 2-
        nor 3-dimensional.

        Works for multi-skeleton if mp_.thin and mp_.skeletonize_3d do
        TODO: check about mp_.thin and mp_.skeletonize_3d
        """
        if skeletonize:
            if shape_map.ndim == 2:
                # Doc says it removes every pixel up to breaking connectivity
                bmap = mp_.thin(shape_map)  # Not boolean map yet
            elif shape_map.ndim == 3:
                bmap = mp_.skeletonize_3d(shape_map)  # Not boolean map yet
            else:
                raise ValueError(
                    f"{shape_map.ndim}: Invalid map dimension; Expected: 2 or 3"
                )
        else:
            bmap = shape_map  # Not boolean map yet
        # >0: because max can be 255, which turns into -1 with int8 conversion
        bmap = bmap > 0  # Now it's a boolean map

        instance = cls()

        instance.invalid_n_neighbors = InvalidNNeighborsForMap(shape_map)
        instance.map = bmap.astype(np_.int8)  # Not uint to allow for subtraction
        if do_post_thinning:
            instance.FixMap()

        if store_widths:
            instance.widths = 2.0 * im_.distance_transform_edt(shape_map) + 1.0

        return instance

    def FixMap(self) -> None:
        """
        Remove, in place, the skeleton pixels whose removal leaves their 3x3
        (or 3x3x3) neighborhood 8- (or 26-) connected.

        Only pixels with at least one background pixel among their cross
        (4- or 6-) neighbors are candidates. Isolated pixels, end pixels, and
        pixels surrounded by skeleton pixels only are never removed.

        Works for multi-skeleton
        """

        def FixLocalMap_n(
            padded_sm_: array_t,
            part_map_: array_t,
            n_neighbors_: int,
            cross_: array_t,
            invalid_n_neighbors_: int,
            labelized_map_fct_: Callable[[array_t], Tuple[array_t, int]],
        ) -> bool:
            """
            Process the skeleton pixels having n_neighbors_ neighbors in
            part_map_. padded_sm_ is modified in place; part_map_ is not
            updated. Returns True if at least one pixel has been removed.
            """
            skel_has_been_modified_ = False

            center = padded_sm_.ndim * (1,)
            for coords in zip(*np_.where(part_map_ == n_neighbors_)):
                lm_slices = tuple(slice(coord - 1, coord + 2) for coord in coords)
                local_map = padded_sm_[lm_slices]  # View: edits reach padded_sm_
                # The background label only exists in the part map: the
                # skeleton map itself contains zeros and ones only.
                local_part_map = part_map_[lm_slices]
                if (local_part_map[cross_] == invalid_n_neighbors_).any():
                    # Tentative removal, reverted if it splits the neighborhood
                    local_map[center] = 0

                    _, n_components = labelized_map_fct_(local_map)
                    if n_components == 1:
                        skel_has_been_modified_ = True
                    else:
                        local_map[center] = 1

            return skel_has_been_modified_

        n_dims = self.map.ndim
        # Padding guarantees that every skeleton pixel has a full neighborhood
        padded_map = np_.pad(self.map, 1, "constant")

        # Boolean mask: with an integer dtype, indexing with the cross would be
        # an integer (fancy) indexing instead of a selection of the neighbors.
        cross = _CROSS_FOR_DIM[n_dims] > 0
        labelized_map_fct = LABELIZED_MAP_fct_FOR_DIM[n_dims]

        # Isolated pixels, end pixels, fully surrounded pixels, and background
        excluded_n_neighbors = {
            0,
            1,
            self.invalid_n_neighbors - 1,
            self.invalid_n_neighbors,
        }
        skel_has_been_modified = True
        while skel_has_been_modified:
            skel_has_been_modified = False

            part_map = SkeletonPartMap(padded_map, check_validity=None)
            included_n_neighbors = set(np_.unique(part_map)).difference(
                excluded_n_neighbors
            )

            for n_neighbors in sorted(included_n_neighbors, reverse=True):
                # Short-circuit: once a removal has happened, the remaining
                # neighbor counts are left for the next pass, which works on a
                # freshly computed part map.
                skel_has_been_modified = skel_has_been_modified or FixLocalMap_n(
                    padded_map,
                    part_map,
                    n_neighbors,
                    cross,
                    self.invalid_n_neighbors,
                    labelized_map_fct,
                )

        self.map[...] = padded_map[n_dims * (slice(1, -1),)]

    def PruneBasedOnWidth(self, min_width: float) -> None:
        """
        Remove, in place, the end pixels (exactly one neighbor) whose width is
        strictly smaller than min_width, repeatedly until no such pixel is left.

        Raises ValueError if the widths have not been stored.

        Works for multi-skeleton
        """
        if self.widths is None:
            raise ValueError("Widths are required for pruning but have not been stored")

        while True:
            part_map = self.PartMap()
            end_positions = np_.where(part_map == 1)
            distances = self.widths[end_positions]

            tiny_distances = distances < min_width
            if not tiny_distances.any():
                break

            extra_positions = tuple(site[tiny_distances] for site in end_positions)
            self.map[extra_positions] = 0

    def PartMap(self) -> array_t:
        """
        The part map is labeled as follows: background=invalid_n_neighbors;
        Pixels of the skeleton=number of neighboring pixels that belong to the
        skeleton (as expected, isolated pixels receive 0).

        Works for multi-skeleton
        """
        skl_map = self.map
        if skl_map.dtype == np_.bool_:
            # In-place addition on booleans saturates at True instead of counting
            skl_map = skl_map.astype(np_.int8)

        n_dims = skl_map.ndim
        axes = tuple(range(n_dims))
        unpadding_domain = n_dims * (slice(1, -1),)

        # Starts with the pixel itself (removed at the end), then adds each of
        # the shifted versions of the zero-padded map.
        part_map = skl_map.copy()
        padded_sm = np_.pad(skl_map, 1, "constant")
        for shifts in _SHIFTS_FOR_NEIGHBORS_FOR_DIM[n_dims]:
            part_map += np_.roll(padded_sm, shifts, axis=axes)[unpadding_domain]

        # +1 to compensate for the final subtraction
        part_map[skl_map == 0] = self.invalid_n_neighbors + 1

        return part_map - 1

    @staticmethod
    def SkeletonIsValid(skeleton_map: array_t, verbose: bool = False) -> bool:
        """
        Return True if the map is a valid multi-skeleton (see
        MultiSkeletonIsValid) with exactly one 8- or 26-connected component.

        verbose: if True, the reason for invalidity is printed.
        """
        if not skl_map_t.MultiSkeletonIsValid(skeleton_map, verbose=verbose):
            return False

        _, n_components = LABELIZED_MAP_fct_FOR_DIM[skeleton_map.ndim](skeleton_map)
        if n_components == 1:
            return True

        if verbose:
            print("Skeleton map has more than one connected component")

        return False

    @staticmethod
    def MultiSkeletonIsValid(skeleton_map: array_t, verbose: bool = False) -> bool:
        """
        Return True if the map is 2- or 3-dimensional and its set of values is
        exactly {0, 1}. Note that a map containing only zeros, or only ones,
        is therefore reported as invalid.

        verbose: if True, the reason for invalidity is printed.
        """
        if skeleton_map.ndim not in (2, 3):
            if verbose:
                print(
                    f"Skeleton map must be 2- or 3-dimensional; "
                    f"Actual dimensionality: {skeleton_map.ndim}"
                )

            return False

        if np_.array_equal(np_.unique(skeleton_map), (0, 1)):
            return True

        if verbose:
            print("Skeleton map has values other than zero and one")

        return False


def InvalidNNeighborsForMap(map: array_t) -> int:
    """
    Return 3 to the power of the number of dimensions of the map, i.e. the
    number of elements of a 3x...x3 neighborhood (9 in 2-D, 27 in 3-D).
    """
    n_dims = map.ndim
    return pow(3, n_dims)


def SkeletonPartMap(
    skeleton_map: array_t, check_validity: Optional[str] = "single"
) -> array_t:
    """
    Return the part map of a skeleton map, as computed by skl_map_t.PartMap on
    an instance built with skl_map_t.FromSkeletonMap.

    skeleton_map: the skeleton map.
    check_validity: passed on to skl_map_t.FromSkeletonMap.

    Works for multi-skeleton
    """
    return skl_map_t.FromSkeletonMap(
        skeleton_map, check_validity=check_validity
    ).PartMap()