# edge.py
# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2018)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL  license under French law and
# abiding by the rules of distribution of free software.  You can  use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and  rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty  and the software's author,  the holder of the
# economic rights,  and the successive licensors  have only  limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using,  modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean  that it is complicated to manipulate,  and  that  also
# therefore means  that it is reserved for developers  and  experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and,  more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

from __future__ import annotations

import sklgraph.brick.elm_id as id_
from sklgraph.skl_map import LABELIZED_MAP_fct_FOR_DIM
from brick.processing.input import ToMicron

import itertools as it_
from collections import namedtuple as namedtuple_t
from typing import Callable, Iterable, List, Tuple, cast

import matplotlib.pyplot as pl_
import numpy as np_
import scipy.interpolate as in_
import scipy.spatial.distance as dt_
import skimage.measure as ms_


array_t = np_.ndarray

# Named bundle of edge length measurements:
# - length: total length (sum of the segment lengths)
# - ww_length: width-weighted length
# - lengths: per-segment lengths
# - sq_lengths: per-segment squared lengths; Interest: all integers
_EDGE_LENGTH_FIELDS = ("length", "ww_length", "lengths", "sq_lengths")
edge_lengths_t = namedtuple_t("edge_lengths_t", _EDGE_LENGTH_FIELDS)


class edge_t:
    #
    __slots__ = (
        "uid_",  # There is a uid property, hence the underscore here (see notes in property)
        "node_uids",
        "origin_node",
        "dim",
        "sites",
        "widths",
        "lengths",
        "as_curve",
        "origin_direction",
        "final_direction",
    )

    def __init__(self):
        #
        # origin_node: Node ID of node closest to (sites[0][0], sites[1][0])
        #
        super().__init__()
        for slot in self.__class__.__slots__:
            setattr(self, slot, None)

        self.node_uids = []

    @classmethod
    def WithSites(cls, sites: Tuple[array_t, ...]) -> edge_t:
        #
        instance = cls()

        instance.dim = sites.__len__()
        instance.sites = _ReOrderedSites(sites)

        return instance

    def SetWidths(self, widths: array_t) -> None:
        #
        if self.node_uids.__len__() != 2:
            raise ValueError("Edge: Missing sites from adjacent nodes")

        self.widths = widths[self.sites]

    def SetLengths(self, size_voxel: array_t, widths: array_t = None, check_validity: bool = False) -> None:
NADAL Morgane's avatar
NADAL Morgane committed
        #
        if self.node_uids.__len__() != 2:
            raise ValueError("Edge: Missing sites from adjacent nodes")

        sites_as_array = np_.array(self.sites)
        segments = np_.diff(sites_as_array, axis=1)
        # segmentsT = segments.transpose()
        # sq_lengths = (segmentsT.dot(np_.diag(size_voxel)).dot(segments)).sum(axis=0)
        sq_lengths = (segments ** 2).sum(axis=0)
NADAL Morgane's avatar
NADAL Morgane committed
        lengths = np_.sqrt(sq_lengths)
        length = lengths.sum().item()

        if (self.widths is None) and (widths is None):
            ww_length = -1.0
        else:
            if widths is not None:
                # If one bothers to pass widths, use it even if it overrides previous settings
                self.SetWidths(widths)
            ww_length = (
                (0.5 * (self.widths[1:] + self.widths[:-1]) * lengths).sum().item()
            )

        self.lengths = edge_lengths_t(
            length=length, ww_length=ww_length, lengths=lengths, sq_lengths=sq_lengths
        )

        if check_validity:
            # A global condition: self.sites[0].size - 1 <= length
            if cast(array_t, sq_lengths == 0).any():
                raise ValueError("Edge: Repeated sites")
            if cast(array_t, sq_lengths > self.sites.__len__()).any():
                raise ValueError("Edge: Site gaps")

    def SetCurveRepresentation(self, size_voxel: list) -> None:
NADAL Morgane's avatar
NADAL Morgane committed
        #
        if self.node_uids.__len__() != 2:
            raise ValueError("Edge: Missing sites from adjacent nodes")

        if self.sites[0].__len__() > 1:
            if self.lengths is None:
                self.SetLengths(size_voxel=size_voxel)
NADAL Morgane's avatar
NADAL Morgane committed
            arc_lengths = tuple(it_.accumulate((0, *self.lengths.sq_lengths.tolist())))
            self.as_curve = tuple(
                in_.PchipInterpolator(arc_lengths, self.sites[idx_])
                for idx_ in range(self.dim)
            )

    def SetEndPointDirections(self, size_voxel: list) -> None:
NADAL Morgane's avatar
NADAL Morgane committed
        #
        if self.as_curve is None:
            self.SetCurveRepresentation(size_voxel=size_voxel)
NADAL Morgane's avatar
NADAL Morgane committed

        if self.as_curve is not None:
            max_arclength = self.as_curve[0].x.item(-1)
            o_dir, f_dir = [], []
            for d_idx in range(self.dim):
                directions = self.as_curve[d_idx]((0, max_arclength), 1)
                o_dir.append(directions[0])
                f_dir.append(directions[1])
            self.origin_direction = np_.array(o_dir, dtype=np_.float64) / (
                -np_.linalg.norm(o_dir)
            )
            self.final_direction = np_.array(
                f_dir, dtype=np_.float64
            ) / np_.linalg.norm(f_dir)

    @property
    def uid(self) -> str:
        """
        node_uids are not set at edge instantiation. This property can then be called later when they are.
        """
        if self.uid_ is None:
            if self.node_uids.__len__() != 2:
                raise ValueError("Edge: Missing sites from adjacent nodes")

            node_uid_0, node_uid_1 = self.node_uids
            if node_uid_0 > node_uid_1:
                node_uid_0, node_uid_1 = node_uid_1, node_uid_0

            edge_id = [
                id_.EncodedNumber(coord) for coord in node_uid_0.split(id_.coord_sep_c)
            ]
            edge_id.append(id_.coord_sep_c)
            edge_id.extend(
                id_.EncodedNumber(coord) for coord in node_uid_1.split(id_.coord_sep_c)
            )

            self.uid_ = "".join(edge_id)

        return self.uid_


def RawEdges(
    skeleton_map: array_t, b_node_lmap: array_t
) -> Tuple[List[edge_t], array_t]:
    """
    Extract the edges of a skeleton as the connected pieces left once node sites are removed.

    raw = no valid node labels yet: the returned edges have empty node_uids.

    skeleton_map: skeleton map (copied, not modified).
    b_node_lmap: node label map of the same shape; sites where it is > 0 are removed from the skeleton.

    Returns the list of edges, where entry i corresponds to label i + 1 of the edge label map, and the edge
    label map itself (as produced by LABELIZED_MAP_fct_FOR_DIM for the map dimension).
    """
    edge_map = skeleton_map.copy()
    edge_map[b_node_lmap > 0] = 0
    edge_lmap, n_edges = LABELIZED_MAP_fct_FOR_DIM[skeleton_map.ndim](edge_map)

    # One distinct placeholder per slot: list multiplication would make every slot alias one single
    # instance. Each slot is expected to be replaced below, one region per label.
    edges = [edge_t() for _ in range(n_edges)]
    for props in ms_.regionprops(edge_lmap):
        # props.image is cropped to the region bounding box; the first ndim entries of props.bbox are the
        # box minimal coordinates, hence the shift back to map coordinates. (zip stops after ndim entries.)
        sites = tuple(
            coords + offset for coords, offset in zip(props.image.nonzero(), props.bbox)
        )
        edges[props.label - 1] = edge_t.WithSites(sites)

    return edges, edge_lmap


def Plot(
    edges: Iterable[Tuple[str, str, edge_t]],
    transformation: Callable[[array_t], array_t],
    vector_transf: Callable[[array_t], array_t],
    axes: pl_.axes.Axes,
    size_voxel : list,
    as_curve: bool = False,
    w_directions: bool = False,
) -> None:
    """
    Plot edges on the given axes.

    edges: (origin node, destination node, edge) triplets; self-loops (origin == destination) are drawn
        dotted, other edges solid. Any iterable is accepted, including one-shot iterators.
    transformation: applied to the first site coordinate before plotting. The first two coordinates are
        swapped, so the plotted abscissa is the former second coordinate. NOTE(review): presumably a
        row-to-ordinate mapping (e.g., vertical flip); confirm against callers.
    vector_transf: same role as transformation, applied to the first component of the end-point directions.
    axes: axes.plot is used for 2-D edges, axes.plot3D otherwise.
    size_voxel: forwarded to the edge curve/direction computations when they are missing.
    as_curve: plot the interpolated curve (sampled every 0.125 arc-length unit) instead of the sites.
    w_directions: also draw the end-point directions with quiver.
    """
    # Materialize once: edges is traversed twice (dimension probe, then drawing); a one-shot iterator
    # would otherwise lose its first edge to the probe.
    edges = tuple(edges)

    space_dim = 2
    if edges.__len__() > 0:
        _, _, first_edge = edges[0]
        space_dim = first_edge.dim

    plot_fct = axes.plot if space_dim == 2 else axes.plot3D
    plot_style = "k" if as_curve else "k."

    for origin, destination, edge in edges:
        if as_curve:
            if edge.as_curve is None:
                edge.SetCurveRepresentation(size_voxel)

            if edge.as_curve is None:
                # E.g., single-site edge: fall back to the raw sites
                sites = list(edge.sites)
            else:
                max_arc_length = edge.as_curve[0].x.item(-1)
                step = 0.125
                arc_lengths = np_.arange(0.0, max_arc_length + 0.5 * step, step)
                sites = list(
                    edge.as_curve[idx_](arc_lengths) for idx_ in range(space_dim)
                )
        else:
            sites = list(edge.sites)
        sites[0], sites[1] = sites[1], transformation(sites[0])

        line_style = ":" if origin == destination else "-"
        plot_fct(*sites, plot_style + line_style, linewidth=2, markersize=7)

        if w_directions:
            if edge.origin_direction is None:
                edge.SetEndPointDirections(size_voxel)

            if edge.origin_direction is not None:
                # Arrows start at the first and last plotted points
                dir_sites = tuple(
                    np_.hstack((sites[idx_][0], sites[idx_][-1]))
                    for idx_ in range(space_dim)
                )
                directions = list(zip(edge.origin_direction, edge.final_direction))
                directions[0], directions[1] = (
                    directions[1],
                    vector_transf(directions[0]),
                )
                axes.quiver(*dir_sites, *directions, color="b", linewidth=2)


def _ReOrderedSites(sites: Tuple[array_t, ...]) -> Tuple[array_t, ...]:
    """
    Reorder edge sites so that consecutive sites touch each other.

    sites: per-dimension coordinate arrays (as returned by nonzero()).

    Starting from the first site, an ordered path is grown at both ends, one site at a time. Neighbors are
    sites at Chebyshev distance exactly 1, i.e. touching sites, diagonals included.

    Returns the reordered coordinates in the same per-dimension tuple layout. With 2 sites or fewer, the
    input tuple is returned as is.

    Raises ValueError if the path can no longer grow at either end while some sites are still unvisited
    (gap between sites, branching, or a site bypassed by the path). Without this check the loop would
    never end.
    """
    n_sites = sites[0].__len__()
    if n_sites <= 2:
        return sites

    sites_as_array = np_.transpose(np_.array(sites))  # One row per site
    pairwise_dists = dt_.squareform(dt_.pdist(sites_as_array, "chebyshev"))
    reordered_sites_nfo = [(0, sites_as_array[0, :])]
    visited_sites = {0}

    def _UnvisitedNeighbors(s_idx) -> list:
        # Indices of the sites touching site s_idx that are not on the path yet
        return list(set((pairwise_dists[s_idx, :] == 1).nonzero()[0]) - visited_sites)

    while visited_sites.__len__() < n_sites:
        n_visited_before = visited_sites.__len__()

        # Grow at the front; no neighbor means the front has reached an extremity
        neighbor_idc = _UnvisitedNeighbors(reordered_sites_nfo[0][0])
        if neighbor_idc.__len__() > 0:
            reordered_sites_nfo.insert(
                0, (neighbor_idc[0], sites_as_array[neighbor_idc[0], :])
            )
            visited_sites.add(neighbor_idc[0])

        # Grow at the back
        if neighbor_idc.__len__() == 2:
            # The seed lies inside the path: its other neighbor starts the back end.
            # The one seed + the one just added above = 2
            assert reordered_sites_nfo.__len__() == 2
            back_idx = neighbor_idc[1]
        else:
            back_idc = _UnvisitedNeighbors(reordered_sites_nfo[-1][0])
            # No neighbor means the back has reached an extremity
            back_idx = back_idc[0] if back_idc.__len__() > 0 else None
        if back_idx is not None:
            reordered_sites_nfo.append((back_idx, sites_as_array[back_idx, :]))
            visited_sites.add(back_idx)

        if visited_sites.__len__() == n_visited_before:
            # Both ends are stuck: further iterations would repeat this one forever
            raise ValueError("Edge: Sites do not form a single path")

    reordered_coords = np_.array(
        tuple(site_nfo[1] for site_nfo in reordered_sites_nfo)
    )
    return tuple(reordered_coords[:, idx_] for idx_ in range(sites.__len__()))