Mentions légales du service

Skip to content
Snippets Groups Projects
dijkstra_1_to_n.py 13.1 KiB
Newer Older
DEBREUVE Eric's avatar
DEBREUVE Eric committed
# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2019)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL  license under French law and
# abiding by the rules of distribution of free software.  You can  use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and  rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty  and the software's author,  the holder of the
# economic rights,  and the successive licensors  have only  limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using,  modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean  that it is complicated to manipulate,  and  that  also
# therefore means  that it is reserved for developers  and  experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and,  more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

"""
Dijkstra Shortest Weighted Path from one image/volume pixel/voxel to one or more image/volume pixel(s)/voxel(s)
    Graph nodes = pixels/voxels = sites
    Graph edges = pixel/voxel neighborhood relationships
    Edge weights = typically computed from pixel/voxel values

Adapted from material of a course by Marc Pegon:
    https://www.ljll.math.upmc.fr/pegon/teaching.html
    https://www.ljll.math.upmc.fr/pegon/documents/BCPST/TP06_src.tar.gz
    https://www.researchgate.net/profile/Marc_Pegon
"""
import heapq as hp_
from typing import Final, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np_
import scipy.ndimage.morphology as mp_

# Slightly slower alternative (look for SSA below)
# import scipy.ndimage as im_
# import skimage.draw as dw_


array_t = np_.ndarray  # Shorthand for NumPy arrays in signatures

number_h = Union[int, float]  # Any real scalar (e.g., a distance or a cost)

# tintegers = tuple of integers. A site (pixel/voxel) is identified by such a tuple,
# with one integer coordinate per axis.
tintegers_h = Tuple[int, ...]
site_h = tintegers_h

# A path is the sequence of sites it goes through; path_nfo_h pairs it with its
# accumulated distance.
path_h = Tuple[site_h, ...]
path_nfo_h = Tuple[path_h, float]


def DijkstraCosts(image: array_t, som_map: array_t, ext_map: array_t) -> array_t:
    """
    Compute the voxel-wise costs used by the Dijkstra shortest path search.

    A voxel that belongs to a soma (som_map > 0) or to an extension (ext_map > 0)
    costs inf, i.e., it cannot be crossed. Any other voxel costs 1 / (intensity + 1).

    For non-negative intensities, costs lie in (0, 1]. A cost of 1 (zero intensity)
    makes a connection least likely to go through the voxel. Lower costs (e.g., 0.5
    for an intensity of 1) make it more likely.
    """
    costs = 1.0 / (image + 1.0)

    is_occupied = (som_map > 0) | (ext_map > 0)
    costs[is_occupied] = np_.inf

    return costs


def DijkstraShortestPath(
    costs: array_t,
    origin: site_h,
    target: Union[site_h, Sequence[site_h]],
    limit_to_sphere: bool = True,
    constrain_direction: bool = True,
    return_all: bool = False,
) -> Union[path_nfo_h, Tuple[path_nfo_h, ...]]:
    """
    Perform the Dijkstra shortest weighted path algorithm on the grid graph of costs.

    The nodes are the sites of costs. Each site is linked to its neighbors along the
    allowed directions. An edge weighs the neighbor cost times the direction length.
    Sites of non-finite cost are never entered. As usual for Dijkstra, costs are
    assumed non-negative.

    Parameters:
        costs: 2-D or 3-D cost array.
        origin: coordinates of the start site.
        target: either one site, or a sequence of sites. Sites may be given as tuples,
            lists or NumPy arrays; they are normalized to tuples.
        limit_to_sphere: restrict the search to balls spanning origin and each target.
        constrain_direction: only use directions not pointing away from every target.
        return_all: if True, return one (path, distance) per requested target, in request
            order. Otherwise, return the (path, distance) of the closest reached target.

    Returns:
        (path, distance), where path is the tuple of sites from origin to the target,
        both included. For an unreachable target (or if no target is reachable), the
        path is empty and the distance is None.

    Note: targets are not expanded, so a path never goes through another target.
    """
    # Plain tuples are hashable and compare equal to the tuple sites built below.
    origin = tuple(origin)
    requested_targets = _NormalizedTargets(target)

    costs, targets, valid_directions, dir_lengths = _DijkstraShortestPathPrologue(
        costs,
        origin,
        requested_targets,
        limit_to_sphere,
        constrain_direction,
    )
    # Set: O(1) membership per popped site; duplicated targets are counted once.
    pending_targets = set(targets)

    nearest_sites = nearest_site_queue_t()
    nearest_sites.Insert(0, origin)
    min_distance_to = {origin: 0.0}
    predecessor_of = {}
    visited_sites = set()

    while True:
        site_nfo = nearest_sites.Pop()
        if site_nfo is None:
            # Empty queue: full graph traversal did not allow to reach all targets.
            # Unreached targets are reported with an empty path below.
            break

        distance, site = site_nfo
        if site in pending_targets:
            # A popped target has its final distance; it is not expanded further.
            pending_targets.remove(site)
            if pending_targets:
                continue
            break

        if site in visited_sites:
            continue
        visited_sites.add(site)
        for successor, edge_length in _OutgoingEdges(
            site, valid_directions, dir_lengths, costs
        ):
            successor = tuple(successor)
            next_distance = distance + edge_length
            min_distance = min_distance_to.get(successor)
            if (min_distance is None) or (next_distance < min_distance):
                min_distance_to[successor] = next_distance
                predecessor_of[successor] = site
                nearest_sites.Insert(next_distance, successor)

    if return_all:
        return tuple(
            _PathTo(one_target, min_distance_to, predecessor_of)
            for one_target in requested_targets
        )

    # Closest reached target; ties are resolved by request order.
    min_distance = np_.inf
    closest_target = None
    for one_target in requested_targets:
        distance = min_distance_to.get(one_target)
        if (distance is not None) and (distance < min_distance):
            min_distance = distance
            closest_target = one_target

    if closest_target is None:
        return (), None
    return _PathTo(closest_target, min_distance_to, predecessor_of)


def _NormalizedTargets(target) -> Tuple[site_h, ...]:
    """
    Return target as a tuple of site tuples, whether it is one site or a sequence of sites.
    """
    # NumPy arrays are not registered as Sequence, hence the explicit ndarray check.
    if isinstance(target[0], (Sequence, np_.ndarray)):
        return tuple(tuple(one_target) for one_target in target)
    return (tuple(target),)


def _PathTo(
    target: site_h, min_distance_to: dict, predecessor_of: dict
) -> Tuple[path_h, Optional[float]]:
    """
    Return (path from origin to target, distance), or ((), None) if target was not reached.
    """
    distance = min_distance_to.get(target)
    if distance is None:
        return (), None

    path = []
    site = target
    while site is not None:  # The origin has no predecessor
        path.append(site)
        site = predecessor_of.get(site)
    path.reverse()

    return tuple(path), distance


def _DijkstraShortestPathPrologue(
    costs: array_t,
    origin: site_h,
    target: Union[site_h, Sequence[site_h]],
    limit_to_sphere: bool,
    constrain_direction: bool,
) -> Tuple[array_t, List[site_h], array_t, array_t]:
    """
    Prepare the inputs of the Dijkstra search.

    Returns (costs, targets, valid_directions, dir_lengths), where:
        - targets is target as a list of sites (a single site becomes a one-element list);
        - costs is, if limit_to_sphere, a copy restricted by _SphereLimitedCosts;
        - valid_directions and dir_lengths are the 2-D or 3-D neighborhood steps and their
          lengths, filtered by _FilteredDirections if constrain_direction.

    Raises ValueError if costs is neither 2-D nor 3-D.
    """
    targets = list(target) if isinstance(target[0], Sequence) else [target]

    if limit_to_sphere:
        costs = _SphereLimitedCosts(costs, origin, targets)

    n_dims = costs.ndim
    if n_dims == 2:
        all_directions, all_lengths = DIRECTIONS_2D, LENGTHS_2D
    elif n_dims == 3:
        all_directions, all_lengths = DIRECTIONS_3D, LENGTHS_3D
    else:
        raise ValueError(f"Cost matrix has {costs.ndim} dimension(s); Expecting 2 or 3")

    if not constrain_direction:
        return costs, targets, all_directions, all_lengths

    valid_directions, dir_lengths = _FilteredDirections(
        all_directions, all_lengths, origin, targets
    )
    return costs, targets, valid_directions, dir_lengths


def _OutgoingEdges(
    site: site_h, valid_directions: array_t, dir_lengths: array_t, costs: array_t
) -> Iterator:
    #
    neighbors = valid_directions + np_.array(site, dtype=valid_directions.dtype)
    n_dims = site.__len__()

    inside = np_.all(neighbors >= 0, axis=1)
    for c_idx in range(n_dims):
        np_.logical_and(
            inside, neighbors[:, c_idx] <= costs.shape[c_idx] - 1, out=inside
        )
    neighbors = neighbors[inside, :]
    dir_lengths = dir_lengths[inside]

    # For any n_dims: neighbors_costs = costs[tuple(zip(*neighbors.tolist()))]
    if n_dims == 2:
        neighbors_costs = costs[(neighbors[:, 0], neighbors[:, 1])]
    else:
        neighbors_costs = costs[(neighbors[:, 0], neighbors[:, 1], neighbors[:, 2])]

    valid_sites = np_.isfinite(neighbors_costs)  # Excludes inf and nan
    neighbors = neighbors[valid_sites, :]
    neighbors_costs = neighbors_costs[valid_sites] * dir_lengths[valid_sites]

    return zip(neighbors.tolist(), neighbors_costs)


def _UnitNeighborhoodOffsets(n_dims: int) -> np_.ndarray:
    """
    Return the 3**n_dims - 1 unit steps from a site to its fully-connected neighbors.

    Each coordinate of a step is -1, 0 or 1. The null step is skipped. Steps are listed
    in lexicographic order of their coordinates, one per row of an int16 array.
    """
    center = n_dims * (1,)
    steps = (
        tuple(coord - 1 for coord in index)
        for index in np_.ndindex(*(n_dims * (3,)))
        if index != center
    )
    return np_.array(tuple(steps), dtype=np_.int16)


DIRECTIONS_2D: Final = _UnitNeighborhoodOffsets(2)
DIRECTIONS_3D: Final = _UnitNeighborhoodOffsets(3)

# Euclidean length of each direction (1, sqrt(2) or sqrt(3)), row-aligned with the directions
LENGTHS_2D: Final = np_.linalg.norm(DIRECTIONS_2D, axis=1)
LENGTHS_3D: Final = np_.linalg.norm(DIRECTIONS_3D, axis=1)


def _FilteredDirections(
    all_directions: array_t,
    all_lengths: array_t,
    origin: site_h,
    targets: Sequence[site_h],
) -> Tuple[array_t, array_t]:
    #
    n_dims = origin.__len__()
    n_targets = targets.__len__()

    straight_lines = np_.empty((n_dims, n_targets), dtype=np_.float64)
    for p_idx, target in enumerate(targets):
        for c_idx in range(n_dims):
            straight_lines[c_idx, p_idx] = target[c_idx] - origin[c_idx]

    inner_prods = all_directions @ straight_lines
    valid_directions = np_.any(inner_prods >= 0, axis=1)

    return all_directions[valid_directions, :], all_lengths[valid_directions]


def _SphereLimitedCosts(
    costs: array_t, origin: site_h, targets: Sequence[site_h]
) -> array_t:
    """
    Note: Un-numba-ble because of slice
    """
    valid_sites = np_.zeros_like(costs, dtype=np_.bool)
    center_map = np_.ones_like(costs, dtype=np_.uint8)
    distances = np_.empty_like(costs, dtype=np_.float64)

    targets_as_array = np_.array(targets)
    centers = np_.around(0.5 * targets_as_array.__add__(origin)).astype(
        np_.int64, copy=False
    )
    centers = tuple(tuple(center) for center in centers)

    n_dims = origin.__len__()

    for t_idx, target in enumerate(targets):
        sq_radius = max(
            (np_.subtract(centers[t_idx], origin) ** 2).sum(),
            (np_.subtract(centers[t_idx], target) ** 2).sum(),
        )
        radius = np_.sqrt(sq_radius).astype(np_.int64, copy=False) + 1
        # Note the +1 in slices ends to account for right-open ranginess
        bbox = tuple(
            slice(
                max(centers[t_idx][c_idx] - radius, 0),
                min(centers[t_idx][c_idx] + radius + 1, distances.shape[c_idx]),
            )
            for c_idx in range(n_dims)
        )
        if t_idx > 0:
            center_map[centers[t_idx - 1]] = 1
        center_map[centers[t_idx]] = 0
        mp_.distance_transform_edt(center_map[bbox], distances=distances[bbox])
        distance_thr = max(distances[origin], distances[target])

        valid_sites[bbox][distances[bbox] <= distance_thr] = True
        # # SSA
        # if n_dims == 2:
        #     valid_coords = dw_.circle(
        #         *centers[t_idx], radius, shape=valid_sites.shape
        #     )
        # else:
        #     local_shape = tuple(slc.stop - slc.start for slc in bbox)
        #     local_center = tuple(centers[t_idx][c_idx] - slc.start for c_idx, slc in enumerate(bbox))
        #     valid_coords = __SphereCoords__(*local_center, radius, shape=local_shape)
        #     valid_coords = tuple(valid_coords[c_idx] + slc.start for c_idx, slc in enumerate(bbox))
        # valid_sites[valid_coords] = True

    local_cost = costs.copy()
    local_cost[np_.logical_not(valid_sites)] = np_.inf

    return local_cost


class nearest_site_queue_t:
    """
    Min-priority queue of sites keyed by distance, with distance updates.

    Entries are [distance, insertion index, site] lists kept in a binary heap. Re-inserting
    a queued site does not remove its previous entry from the heap. Instead, it blanks the
    entry's site slot, and Pop skips such stale entries. The insertion index breaks
    distance ties in insertion order, so the heap never has to compare sites.
    """

    __slots__ = ("heap", "insertion_idx", "visited_sites")

    def __init__(self):
        # Heap of [distance, insertion index, site or None (stale entry)]
        self.heap = []
        # Number of insertions so far; used as the tie-breaker
        self.insertion_idx = 0
        # Site -> its live heap entry (i.e., the sites currently queued)
        self.visited_sites = {}

    def Insert(self, distance: number_h, site: site_h) -> None:
        """
        Queue site with the given distance, replacing its current distance if already queued.
        """
        if site in self.visited_sites:
            self._Delete(site)

        next_idx = self.insertion_idx + 1
        self.insertion_idx = next_idx
        entry = [distance, next_idx, site]
        self.visited_sites[site] = entry
        hp_.heappush(self.heap, entry)

    def Pop(self) -> Optional[Tuple[number_h, site_h]]:
        """
        Remove and return (distance, site) for the closest queued site; None if the queue is empty.
        """
        heap = self.heap
        while heap:
            distance, _, site = hp_.heappop(heap)
            if site is None:
                continue  # Stale entry left by an update
            del self.visited_sites[site]
            return distance, site

        return None

    def _Delete(self, site: site_h) -> None:
        """
        Unqueue site by blanking its heap entry (it stays in the heap until popped).
        """
        stale_entry = self.visited_sites.pop(site)
        stale_entry[-1] = None


# # SSA
# def __SphereCoords__(
#     row: int, col: int, dep: int, radius: int, shape: Tuple[int, int, int]
# ) -> np_array_picker_h:
#     #
#     sphere = np_.zeros(shape, dtype=np_.bool)
#     # dw_.ellipsoid leaves a one pixel margin around the ellipse, hence [1:-1, 1:-1, 1:-1]
#     ellipse = dw_.ellipsoid(radius, radius, radius)[1:-1, 1:-1, 1:-1]
#     sp_slices = tuple(
#         slice(0, min(sphere.shape[idx_], ellipse.shape[idx_])) for idx_ in (0, 1, 2)
#     )
#     sphere[sp_slices] = ellipse[sp_slices]
#
#     sphere = im_.shift(
#         sphere, (row - radius, col - radius, dep - radius), order=0, prefilter=False
#     )
#
#     return sphere.nonzero()