# soma.py (committed by DEBREUVE Eric)
from __future__ import annotations

import dijkstra_1_to_n as dk_
import map_labeling as ml_
from extension import extension_t
from type import array_t, py_array_picker_h, site_h

from collections import namedtuple as namedtuple_t
from typing import Optional, Sequence, Tuple

import numpy as np_
import scipy.ndimage as im_
import skimage.filters as fl_
import skimage.measure as ms_
import skimage.morphology as mp_


# Record describing one candidate soma-extension connection: the extension, the length
# of the path (as returned by dk_.DijkstraShortestPath), and the sites of that path.
som_ext_path_t = namedtuple_t("som_ext_path_t", "extension length sites")


# Passed as max_distance to soma_t.ContourPointsCloseTo, where it is compared with the
# SQUARED straight-line distance between an extension end point and a contour point.
max_straight_sq_dist = 30 ** 2
# Upper bound on the path length returned by dk_.DijkstraShortestPath for a
# soma-extension path to be accepted in soma_t.Extend.
max_weighted_dist = 20.0


# Area threshold used by soma_t.FilteredMap: regions whose area is lower than or equal
# to this value are set to zero.
min_area_c = 1000


class soma_t:
    """
    A labeled soma region together with the extensions connected to it.

    Naming conventions used in this class:
        bmp=boolean map
        lmp=labeled map (intXX or uintXX array)
        map=extension map (map=binary, int8 or uint8 array)
    """

    __slots__ = (
        # Label of the soma in the labeled maps
        "uid",
        # Labeled map passed to FromMaps (stored by reference, not copied)
        "lmp_ref",
        # Tuple of the sites (coordinate tuples) of the contour map equal to uid
        "contour_points",
        # Dictionary: extension uid -> path between soma and extension in array-indexing
        # form (one coordinate tuple per axis), or None if the path has no interior site
        "connection_path",
        # uids of the extensions connected to the soma, in connection order
        "extension_uids",
    )

    def __init__(self):
        """Create an empty instance; use FromMaps to get an initialized one."""
        self.uid = None
        self.lmp_ref = None
        self.contour_points = None
        self.connection_path = None
        self.extension_uids = None

    @classmethod
    def FromMaps(cls, lmp: array_t, contour_lmp: array_t, uid: int) -> soma_t:
        """
        Build the soma with label uid.

        lmp: labeled map, stored by reference in lmp_ref.
        contour_lmp: labeled contour map; its sites equal to uid become contour_points.
        uid: label of the soma.
        """
        instance = cls()

        instance.uid = uid
        instance.lmp_ref = lmp
        # nonzero() returns one coordinate array per axis; zip(*...) turns them into
        # one coordinate tuple per site.
        instance.contour_points = tuple(zip(*(contour_lmp == uid).nonzero()))
        instance.connection_path = {}
        instance.extension_uids = []

        return instance

    @property
    def has_extensions(self) -> bool:
        """True if at least one extension has been connected to the soma."""
        # bool() also covers an instance that has not been initialized by FromMaps
        # (extension_uids is None), which is reported as having no extensions.
        return bool(self.extension_uids)

    def ContourPointsCloseTo(
        self, point: site_h, max_distance: float
    ) -> Tuple[Optional[Tuple[site_h, ...]], Optional[py_array_picker_h]]:
        """
        Select the contour points close to point.

        point: reference site.
        max_distance: threshold compared with the SQUARED Euclidean distance between
            point and each contour point (inclusive).

        Returns (points, picker) where points is the tuple of selected contour points
        and picker is the same set in array-indexing form (one coordinate tuple per
        axis). Returns (None, None) if no contour point is close enough.
        """
        points = tuple(
            contour_point
            for contour_point in self.contour_points
            if (np_.subtract(point, contour_point) ** 2).sum() <= max_distance
        )

        if not points:
            return None, None

        return points, tuple(zip(*points))

    def Extend(
        self, extensions: Sequence[extension_t], dist_to_soma: array_t, costs: array_t
    ) -> None:
        """
        Connect extensions to the soma, one per pass, shortest path first.

        extensions: candidate extensions; those whose soma_uid is not None are skipped.
            The soma_uid of each connected extension is set to the uid of the soma.
        dist_to_soma: array indexed by end point, used to order the candidate end points.
        costs: cost array passed to dk_.DijkstraShortestPath. It is modified in place:
            the sites of each connection path and connected extension are set to
            infinity.

        Each pass computes a path from every remaining candidate end point to the
        close contour points, keeps the paths of length at most max_weighted_dist, and
        connects the extension with the shortest one. The method returns when a pass
        finds no acceptable path.
        """
        candidate_ext_eps = [  # eps=end points
            (end_point, extension)
            for extension in extensions
            for end_point in extension.EndPointsForSoma(self.uid)
        ]
        candidate_ext_eps.sort(key=lambda elm: dist_to_soma[elm[0]])

        while True:
            som_ext_paths = []
            for ep_idx, (ext_end_point, extension) in enumerate(candidate_ext_eps):
                if extension.soma_uid is not None:
                    continue

                close_contour_points, contour_indexing = self.ContourPointsCloseTo(
                    ext_end_point, max_straight_sq_dist
                )
                if close_contour_points is None:
                    continue

                # Temporarily make the path end points free to reach, then set them
                # to infinity once the path has been computed.
                costs[ext_end_point] = 0.0
                costs[contour_indexing] = 0.0
                print(f"    Soma.{self.uid} <-?-> Ext.{extension.uid}.{ep_idx}")
                sites, length = dk_.DijkstraShortestPath(
                    costs, ext_end_point, close_contour_points
                )
                costs[ext_end_point] = np_.inf
                costs[contour_indexing] = np_.inf
                if length <= max_weighted_dist:
                    som_ext_paths.append(
                        som_ext_path_t(extension=extension, length=length, sites=sites)
                    )
                    # A 2-site path has no interior site: stop looking at the
                    # remaining candidates for this pass.
                    if len(sites) == 2:
                        break

            if not som_ext_paths:
                break

            # min returns the first path of minimal length, like a stable sort would.
            shortest_path = min(som_ext_paths, key=lambda path: path.length)
            # Interior sites only (end point and contour point excluded), converted
            # to array-indexing form.
            connection_path = tuple(zip(*shortest_path.sites[1:-1]))
            if not connection_path:
                connection_path = None

            closest_extension = shortest_path.extension
            closest_extension.soma_uid = self.uid
            extension_path = closest_extension.sites

            self.connection_path[closest_extension.uid] = connection_path
            self.extension_uids.append(closest_extension.uid)

            # TODO: Ideally, these paths should be dilated
            # but in ext-ext connections, there must not be dilation around the current ext
            # (current ext plays the role of a soma in soma-ext step)
            if connection_path is not None:
                costs[connection_path] = np_.inf
            costs[extension_path] = np_.inf

            print(f"        => {self.uid} <-> {closest_extension.uid}")

    @staticmethod
    def Map(image: array_t, low: float, high: float, selem: array_t) -> array_t:
        """
        Compute an int8 map of image by hysteresis thresholding, then clean it plane
        by plane with a morphological closing followed by an opening.

        image: array indexed with 3 indices; the morphological operations are applied
            to each image[dep, :, :] plane.
        low, high: thresholds expressed relative to the range between the smallest
            non-zero value and the maximum of image (previously tried values:
            low = 10 #0.15, high = 67.4 # 0.7126).
        selem: structuring element passed to the closing and the opening.

        Raises ValueError (from NumPy) if image has no non-zero value.
        """
        max_image = image.max()
        min_image = image[image.nonzero()].min()

        # Convert the relative thresholds into image values
        intensity_range = max_image - min_image
        low = low * intensity_range + min_image
        high = high * intensity_range + min_image
        result = fl_.apply_hysteresis_threshold(image, low, high)
        result = result.astype(np_.int8)

        for dep in range(image.shape[0]):
            result[dep, :, :] = mp_.closing(result[dep, :, :], selem)
            result[dep, :, :] = mp_.opening(result[dep, :, :], selem)

        return result

    @staticmethod
    def FilteredMap(map_: array_t) -> array_t:
        """
        Return a copy of map_ where the labeled regions of area lower than or equal to
        min_area_c are set to zero.
        """
        result = map_.copy()
        lmp = ms_.label(map_)

        # One pass over the labeled map gives the area of every region at once
        # (entry i is the number of sites with label i), instead of one full-array
        # comparison per region.
        region_areas = np_.bincount(lmp.ravel(), minlength=1)
        is_small = region_areas <= min_area_c
        is_small[0] = False  # Label 0 is the background: not a region
        result[is_small[lmp]] = 0

        return result

    @staticmethod
    def ContourMap(map_: array_t) -> array_t:
        """Return the int8 map of the sites whose ml_.PartLMap label is below 26."""
        part_map = ml_.PartLMap(map_)
        # Works because the background is labeled with 27
        # NOTE(review): the test is "< 26" while the comment above mentions 27; confirm
        # against ml_.PartLMap which labels must be excluded.
        result = part_map < 26

        return result.astype(np_.int8)

    @staticmethod
    def InfluenceMaps(map_: array_t) -> Tuple[array_t, array_t]:
        """
        Compute the distance-to-object and closest-object maps of map_.

        Returns (dist_map, obj_map): dist_map is the Euclidean distance of each site to
        the closest non-zero site of map_, and obj_map is the value of map_ at that
        closest site.
        """
        background = (map_ == 0).astype(np_.int8)
        # Called from scipy.ndimage directly: the scipy.ndimage.morphology namespace is
        # deprecated.
        dist_map, idx_map = im_.distance_transform_edt(background, return_indices=True)

        # idx_map holds, along its first axis, one coordinate array per axis of map_:
        # indexing map_ with them reads, for each site, the value of its closest
        # non-zero site.
        return dist_map, np_.array(map_[tuple(idx_map)])

    @staticmethod
    def SomasWithExtensionsLMap(
        somas: Sequence[soma_t], soma_lmp: array_t, extensions: Sequence[extension_t]
    ) -> array_t:
        """
        Return a copy of soma_lmp where the connection paths and the sites of the
        extensions attached to each soma are set to the uid of the soma.

        somas: somas whose extension_uids and connection_path are read.
        soma_lmp: labeled map of the somas; not modified.
        extensions: extensions among which the attached ones are looked up by uid.
        """
        result = soma_lmp.copy()

        # Index the extensions by uid once instead of scanning them for every attached
        # uid. Lists are used so that extensions sharing a uid are all written.
        extensions_by_uid = {}
        for extension in extensions:
            extensions_by_uid.setdefault(extension.uid, []).append(extension)

        for soma in somas:
            for ext_uid in soma.extension_uids:
                connection_path = soma.connection_path[ext_uid]
                if connection_path is not None:
                    result[connection_path] = soma.uid
                for extension in extensions_by_uid.get(ext_uid, ()):
                    result[extension.sites] = soma.uid

        return result


def NormalizedImage(image: array_t) -> array_t:
    """
    Return image divided by the average of its non-extreme values, i.e., the values
    strictly between zero and the image maximum.
    """
    is_positive = image > 0.0
    is_below_max = image < image.max()
    reference_level = image[is_positive & is_below_max].mean()

    return image / reference_level