from __future__ import annotations

import brick.processing.frangi3 as fg_
import brick.processing.map_labeling as ml_
from brick.component.glial_cmp import glial_cmp_t
from brick.general.type import array_t, site_h

from typing import Optional, Sequence, Tuple, Callable

import numpy as np_
import skimage.filters as fl_
import skimage.measure as ms_
import skimage.morphology as mp_
from scipy import ndimage as im_
import matplotlib.pyplot as pl_

# Masks of a site and its face neighbors (4 in 2-D, 6 in 3-D) within a 3x3(x3)
# neighborhood. They are boolean on purpose: an integer 0/1 array used as an index
# is interpreted as a list of (fancy) indices, not as a mask.
_CENTER_3x3 = ((0, 0, 0), (0, 1, 0), (0, 0, 0))
_CROSS_3x3 = np_.array(((0, 1, 0), (1, 1, 1), (0, 1, 0)), dtype=bool)
_CROSS_3x3x3 = np_.array((_CENTER_3x3, _CROSS_3x3, _CENTER_3x3), dtype=bool)
# Indexed by the number of dimensions (only 2 and 3 are supported).
_CROSS_FOR_DIM = (None, None, _CROSS_3x3, _CROSS_3x3x3)

    (i, j) for i in (-1, 0, 1) for j in (-1, 0, 1) if i != 0 or j != 0
    (i, j, k)
    for i in (-1, 0, 1)
    for j in (-1, 0, 1)
    for k in (-1, 0, 1)
    if i != 0 or j != 0 or k != 0
_MIN_SHIFTS_FOR_2D_NEIGHBORS = tuple(elm for elm in _FULL_SHIFTS_FOR_2D_NEIGHBORS if np_.abs(elm).sum() == 1)
_MIN_SHIFTS_FOR_3D_NEIGHBORS = tuple(elm for elm in _FULL_SHIFTS_FOR_3D_NEIGHBORS if np_.abs(elm).sum() == 1)

# Full-connectivity structuring elements: 8-connectivity in 2-D, 26-connectivity in 3-D.
_SQUARE_3x3 = np_.ones((3, 3), dtype=np_.uint8)
_SQUARE_3x3x3 = np_.ones((3, 3, 3), dtype=np_.uint8)


def _LABELIZED_MAP_8_fct(map_: array_t) -> Tuple[array_t, int]:
    """Label the 8-connected components of a 2-D map; return (labelled map, number of components)."""
    return im_.label(map_, structure=_SQUARE_3x3, output=np_.int64)


def _LABELIZED_MAP_26_fct(map_: array_t) -> Tuple[array_t, int]:
    """Label the 26-connected components of a 3-D map; return (labelled map, number of components)."""
    return im_.label(map_, structure=_SQUARE_3x3x3, output=np_.int64)


# Full-connectivity labelling function indexed by the number of dimensions (2 or 3).
LABELIZED_MAP_fct_FOR_DIM = (None, None, _LABELIZED_MAP_8_fct, _LABELIZED_MAP_26_fct)

class extension_t(glial_cmp_t):
    """
    Extension of a glial cell, built from a labelled map of extensions.

    Attributes defined here:
        end_points: per-axis index arrays of the extension end points (as returned by nonzero())
        scales: values of the Frangi scale map at the extension sites
        soma_uid: uid of the soma this extension is connected to somewhere upstream;
            None while unconnected
        __cache__: per-instance cache of derived values

    Also reads uid, sites and extensions and calls InitializeFromMap, which are not
    defined here (presumably provided by glial_cmp_t — confirm).
    """

    # soma_uid: connected to a soma somewhere upstream
    __slots__ = ("end_points", "scales", "soma_uid", "__cache__")

    def __init__(self):
        for slot in self.__class__.__slots__:
            setattr(self, slot, None)
        # Empty cache (rather than None) so cached properties also work on instances
        # that were not built through FromMap.
        self.__cache__ = {}

    @classmethod
    def FromMap(cls, lmp: array_t, scales: array_t, uid: int) -> extension_t:
        """
        Initialize and create the extension object based on the labelled map.

        lmp: labelled map; the extension is the set of sites where lmp == uid.
        scales: Frangi scale map, sampled at the extension sites.
        uid: label of the extension in lmp.
        """
        instance = cls()

        # Create a boolean map keeping only the extension number 'uid'.
        bmp = lmp == uid
        # Initialize the object with its different fields
        instance.InitializeFromMap(bmp, uid)
        # Find the endpoints sites of the extension
        end_point_map = cls.EndPointMap(bmp)
        instance.end_points = end_point_map.nonzero()
        # Store the frangi scales of the extensions
        instance.scales = scales[instance.sites]
        instance.__cache__ = {}

        return instance

    @property
    def is_unconnected(self) -> bool:
        """True while the extension is not connected to any soma. Exposed as a property, like end_points_as_array."""
        return self.soma_uid is None

    @property
    def end_points_as_array(self) -> array_t:
        """End points as an (n_dims, n_end_points) array; computed once, then cached."""
        pty_name = 'end_points_as_array'
        if pty_name not in self.__cache__:
            self.__cache__[pty_name] = np_.array(self.end_points)

        return self.__cache__[pty_name]

    def EndPointsForSoma(
        self, soma_uid: int, influence_map: array_t
    ) -> Tuple[site_h, ...]:
        """
        Find the extension end points lying in the influence area of the soma.

        Returns the coordinates as a tuple of per-point tuples ((x1, y1, z1), (x2, y2, z2), ...),
        or an empty tuple if no end point is under the soma's influence.
        """
        # Boolean map (one entry per end point) of the end points under the soma's influence
        ep_bmp = influence_map[self.end_points] == soma_uid

        if ep_bmp.any():
            # Indices, among all end points, of those under the soma's influence
            end_point_idc = ep_bmp.nonzero()[0]
            end_points = self.end_points_as_array[:, end_point_idc]

            # Convert per-axis rows into per-point coordinate tuples
            return tuple(zip(*end_points.tolist()))

        return ()

    def BackReferenceSoma(self, glial_cmp: glial_cmp_t) -> None:
        """
        Record the soma this extension is connected to: inherit soma_uid from an
        upstream extension, or take the uid of glial_cmp otherwise (presumably a soma).
        """
        if isinstance(glial_cmp, extension_t):
            self.soma_uid = glial_cmp.soma_uid
        else:
            self.soma_uid = glial_cmp.uid

    def __str__(self) -> str:
        n_extensions = 0 if self.extensions is None else len(self.extensions)

        return (
            f"Ext.{self.uid}, "
            f"sites={len(self.sites[0])}, "
            f"endpoints={len(self.end_points[0])}, "
            f"soma={self.soma_uid}, "
            f"ext={n_extensions}"
        )

    @staticmethod
    def ExtensionContainingSite(
        extensions: Sequence[extension_t], site: site_h
    ) -> Optional[extension_t]:
        """Return the first extension whose sites contain the given site, or None."""
        for extension in extensions:
            # Membership test directly on the iterator: stops at the first match
            # without materializing all sites as a tuple.
            if site in zip(*extension.sites):
                return extension

        return None

    @staticmethod
    def EnhancedForDetection(
        image: array_t,
        in_parallel: bool = False,
        **frangi_kwargs,
    ) -> Tuple[array_t, array_t]:
        """
        Preprocess by white top hat, then perform Frangi vesselness enhancement.

        image: image to enhance.
        in_parallel: forwarded to fg_.FrangiEnhancement.
        frangi_kwargs: any other Frangi parameter, forwarded unchanged (by keyword)
            to fg_.FrangiEnhancement.

        Returns (enhanced_img, scale_map) as produced by fg_.FrangiEnhancement.
        """
        # White top-hat (image minus its opening) keeps bright details smaller than the
        # structuring element. scipy.ndimage.white_tophat replaces the deprecated
        # scipy.ndimage.morphology namespace.
        preprocessed_img = im_.white_tophat(
            image, size=2, mode="constant", cval=0.0, origin=0
        )

        # NOTE(review): assumes FrangiEnhancement takes the image positionally and
        # in_parallel by keyword — confirm against brick.processing.frangi3.
        enhanced_img, scale_map = fg_.FrangiEnhancement(
            preprocessed_img, in_parallel=in_parallel, **frangi_kwargs
        )

        return enhanced_img, scale_map

    @staticmethod
    def CoarseMap(
        image: array_t,
        low: Optional[float],
        high: Optional[float],
        selem: Optional[array_t],
    ) -> array_t:
        """
        Perform hysteresis thresholding and closing/opening.

        The hysteresis step runs only when both low and high are given; the morphological
        step runs only when selem is given. The input image is not modified.
        """
        result = image.copy()
        if (low is not None) and (high is not None):
            result = __HysterisisImage__(result, low, high)

        if selem is not None:
            result = __MorphologicalCleaning__(result, selem)

        return result

    @staticmethod
    def FilteredCoarseMap(
        map_: array_t, ext_min_area_c: int
    ) -> Tuple[array_t, array_t]:
        """
        Delete connected components whose area (voxel count) is at most ext_min_area_c.

        Returns (filtered copy of map_, label map of the remaining components).
        Labels of the remaining components are kept as produced by skimage.measure.label.
        """
        result = map_.copy()
        lmp = ms_.label(map_)

        # Voxel count of every label in one pass; index 0 is the background.
        areas = np_.bincount(lmp.ravel(), minlength=1)
        too_small = areas <= ext_min_area_c
        too_small[0] = False  # never touch the background
        small_sites = too_small[lmp]

        result[small_sites] = 0
        lmp[small_sites] = 0

        return result, lmp

    @staticmethod
    def FineMapFromCoarseMap(coarse_map: array_t) -> array_t:
        """
        Skeletonize the coarse map; returns an int8 map.

        The skeleton might contain True-voxels that could be removed without breaking
        connectivity.
        """
        result = mp_.skeletonize_3d(coarse_map.astype(np_.uint8, copy=False))
        # NOTE(review): no thinning is applied here; ThinMap(result) can remove the
        # redundant voxels in place if a minimal skeleton is required.

        return result.astype(np_.int8, copy=False)

    @staticmethod
    def EndPointMap(map_: array_t) -> array_t:
        """
        Find the end points of the extensions, i.e. sites with exactly one neighbor.

        Returns an int8 map: 1 at end points, 0 elsewhere.
        """
        # NOTE(review): assumes ml_.PartLMap labels each site with its number of
        # neighbors (background reportedly labelled 27) — confirm.
        part_map = ml_.PartLMap(map_)
        result = part_map == 1

        return result.astype(np_.int8)

def __HysterisisImage__(image: array_t, low: float, high: float) -> array_t:
    """
    Perform hysteresis thresholding with thresholds relative to the image intensities.

    low is scaled by the smallest strictly positive intensity, high by the maximum
    intensity; the scaled values are passed to skimage.filters.apply_hysteresis_threshold.

    Returns an int8 map (1 = kept, 0 = discarded). If the image has no strictly positive
    value, the low threshold cannot be computed and an all-zero map is returned instead
    of raising on the minimum of an empty array.
    """
    positive_values = image[image > 0]
    if positive_values.size == 0:
        return np_.zeros(image.shape, dtype=np_.int8)

    low_threshold = low * positive_values.min()
    high_threshold = high * image.max()

    result = fl_.apply_hysteresis_threshold(image, low_threshold, high_threshold)

    return result.astype(np_.int8, copy=False)

def __MorphologicalCleaning__(image: array_t, selem) -> array_t:
    """
    Perform a closing followed by an opening with the structuring element selem.

    3-D images are processed slice by slice along the first axis (selem is then applied
    to 2-D slices); 2-D images are processed directly. Returns a new array; the input is
    not modified.
    """
    if image.ndim == 2:
        return mp_.opening(mp_.closing(image, selem), selem)

    result = image.copy()
    for dep in range(result.shape[0]):
        result[dep] = mp_.opening(mp_.closing(result[dep], selem), selem)

    return result

def ThinMap(skl_map: array_t) -> None:
    """
    Remove skeleton sites whose removal breaks neither 8- (2-D) nor 26- (3-D) connectivity.

    Works for multi-skeleton maps (0 = background). skl_map is modified in place;
    nothing is returned. Sites with at most one face neighbor (isolated sites, end
    points) are never removed, so branches are not eroded pass after pass.

    Raises ValueError if skl_map is neither 2-D nor 3-D.
    """
    if skl_map.ndim not in (2, 3):
        raise ValueError(
            f"ThinMap: 2-D or 3-D map expected; actual dimension: {skl_map.ndim}"
        )

    background_label = BackgroundLabelForTmp(skl_map)

    def FixLocalMap_n(
        padded_sm_: array_t,
        part_map_: array_t,
        n_neighbors_: int,
        cross_: array_t,
        labelled_map_fct_: Callable[[array_t], Tuple[array_t, int]],
    ) -> bool:
        """
        Try to remove each site having n_neighbors_ face neighbors. A removal is kept
        only if its 3x3(x3) neighborhood stays a single connected component.
        Return True if at least one site was removed.
        """
        skel_has_been_modified_ = False

        center = padded_sm_.ndim * (1,)
        for coords in zip(*np_.nonzero(part_map_ == n_neighbors_)):
            lm_slices = tuple(slice(coord - 1, coord + 2) for coord in coords)
            # Basic slicing returns a view: writing to local_map edits padded_sm_.
            local_map = padded_sm_[lm_slices]
            local_part_map = part_map_[lm_slices]
            # Only sites touching the background through a face are candidates.
            if (local_part_map[cross_] == background_label).any():
                original_value = local_map[center]
                local_map[center] = 0

                _, n_components = labelled_map_fct_(local_map)
                if n_components == 1:
                    skel_has_been_modified_ = True
                else:
                    # Removal would split (or empty) the neighborhood: restore the site.
                    local_map[center] = original_value

        return skel_has_been_modified_

    padded_map = np_.pad(skl_map, 1, "constant")

    # Forced to bool: an integer 0/1 array would act as fancy indices, not as a mask.
    cross = np_.asarray(_CROSS_FOR_DIM[skl_map.ndim], dtype=bool)
    labelized_map_fct = (None, None, _LABELIZED_MAP_8_fct, _LABELIZED_MAP_26_fct)[
        skl_map.ndim
    ]

    # Face-neighbor counts that are never candidates for removal:
    # - 0 and 1: isolated sites and end points (removing them would shorten branches);
    # - 2 * ndim: all face neighbors present, cannot touch the background through a face;
    # - background_label: background sites (would otherwise be "removed" forever).
    excluded_n_neighbors = {0, 1, 2 * skl_map.ndim, background_label}

    skel_has_been_modified = True
    while skel_has_been_modified:
        skel_has_been_modified = False

        part_map = TopologyMapOfSkeleton(padded_map, full_connectivity=False)
        included_n_neighbors = set(np_.unique(part_map)).difference(
            excluded_n_neighbors
        )

        # Most-connected sites first. After any removal part_map is stale, so it is
        # recomputed before examining the other neighbor counts.
        for n_neighbors in sorted(included_n_neighbors, reverse=True):
            if FixLocalMap_n(
                padded_map, part_map, n_neighbors, cross, labelized_map_fct
            ):
                skel_has_been_modified = True
                break

    skl_map[...] = padded_map[skl_map.ndim * (slice(1, -1),)]

def BackgroundLabelForTmp(a_map: array_t) -> int:
    """
    Return the label of background sites in topology maps of a map with a_map's dimension.

    Must be equal to the max number of neighbors in a skeleton + 1, i.e. 3**ndim.
    Note: using a_map avoids shadowing Python's map.
    """
    return 3 ** a_map.ndim


def TopologyMapOfSkeleton(skl_map: array_t, full_connectivity: bool = True) -> array_t:
    """
    Return the topology map of a skeleton map (0 = background, any other value = skeleton).

    Labels: background = BackgroundLabelForTmp(skl_map); skeleton sites = number of
    neighboring sites that belong to the skeleton (as expected, isolated sites receive 0).
    Neighbors are all 3**ndim - 1 surrounding sites if full_connectivity, or only the
    2*ndim face neighbors otherwise.
    Works for multi-skeleton maps and any dimension >= 1.
    """
    n_dims = skl_map.ndim
    background_label = BackgroundLabelForTmp(skl_map)
    # int8 holds every label up to 4-D (3**4 = 81); wider maps need a wider type.
    dtype = np_.int8 if background_label <= np_.iinfo(np_.int8).max else np_.int32

    skeleton = (np_.asarray(skl_map) != 0).astype(dtype)

    # Every offset in {-1, 0, 1}**n_dims, one per row, and its number of unit steps.
    offsets = np_.indices(n_dims * (3,)).reshape(n_dims, -1).T - 1
    n_steps = np_.abs(offsets).sum(axis=1)
    if full_connectivity:
        shifts = offsets[n_steps > 0]
    else:
        shifts = offsets[n_steps == 1]

    # Padding with zeros makes np.roll's wrap-around bring in background only.
    padded_sm = np_.pad(skeleton, 1, "constant")
    unpadding_domain = n_dims * (slice(1, -1),)
    rolling_axes = tuple(range(n_dims))

    tmap = np_.zeros_like(skeleton)
    for shift in shifts:
        tmap += np_.roll(padded_sm, tuple(shift.tolist()), axis=rolling_axes)[
            unpadding_domain
        ]
    tmap[skeleton == 0] = background_label

    return tmap

def MaximumIntensityProjectionZ(
    img: array_t,
    cmap: str = "tab20",
    axis: int = 0,
    output_image_file_name: Optional[str] = None,
) -> None:
    """
    Maximum intensity projection of img along axis (0 by default, the Z axis).

    The projection is drawn on the current matplotlib axes with pl_.imshow (no show()
    call is made here). If output_image_file_name is given, the projection is also saved
    to that file with the same colormap.
    """
    xy = np_.amax(img, axis=axis)
    pl_.imshow(xy, cmap=cmap)
    if output_image_file_name is not None:
        pl_.imsave(output_image_file_name, xy, cmap=cmap)
        print("Image saved in", output_image_file_name)