from __future__ import annotations

import frangi_3d as fg_
import map_labeling as ml_
from type import array_t, site_h

from typing import Sequence, Tuple

import numpy as np_
import skimage.filters as fl_
import skimage.measure as ms_
import skimage.morphology as mp_
from scipy import ndimage as im_

# Connected components with at most this many voxels are dropped by FilteredCoarseMap.
min_area_c: int = 100

class extension_t:
    """An extension extracted from labeled maps, with its end points and soma links.

    Abbreviations used in this module:
        bmp = boolean map
        lmp = labeled map (intXX or uintXX array)
        map = extension map (binary, int8 or uint8 array)
        ep_ = end point

    Attributes:
        uid: Label of this extension in the labeled map it was built from.
        sites: Voxel coordinates of the extension, in ``nonzero()`` format (one
            index array per axis). Might contain voxels that could be removed
            without breaking connectivity.
        scales: Values of the scale map at ``sites``.
        end_points: Coordinates of the end points, in the same per-axis
            ``nonzero()`` format as ``sites``.
        ep_closest_somas: Value of the soma influence map at each end point; set
            by ``CaptureClosestSomas``.
        soma_uid: Soma this extension is connected to somewhere upstream (as
            opposed to downstream extensions). Not set by any method here.
        extension_uids: Uids of downstream extensions (as opposed to being
            upstream-connected).
    """

    # Exactly the attributes assigned by __init__/FromMaps; no per-instance __dict__.
    __slots__ = (
        "uid",
        "sites",
        "scales",
        "end_points",
        "ep_closest_somas",
        "soma_uid",
        "extension_uids",
        "__cache__",
    )

    def __init__(self) -> None:
        self.uid = None
        self.sites = None
        self.scales = None
        self.end_points = None
        self.ep_closest_somas = None
        self.soma_uid = None
        self.extension_uids = None
        # Memo for derived properties. A dict (not None) so cached properties
        # also work on instances that were not built through FromMaps.
        self.__cache__ = {}

    @classmethod
    def FromMaps(
        cls, lmp: array_t, end_point_lmp: array_t, scales: array_t, uid: int
    ) -> extension_t:
        """Build the extension labeled *uid* from labeled maps.

        Args:
            lmp: Labeled extension map; voxels equal to *uid* form the extension.
            end_point_lmp: Labeled end-point map; voxels equal to *uid* are the
                extension's end points.
            scales: Map indexed with the extension sites; assumed to have the
                same shape as *lmp* (TODO confirm with callers).
            uid: Label of the extension to extract.

        Returns:
            A new instance with ``uid``, ``sites``, ``scales``, ``end_points``
            set and an empty ``extension_uids`` list.
        """
        instance = cls()

        instance.uid = uid
        # sites: might contain voxels that could be removed w/o breaking connectivity
        instance.sites = (lmp == uid).nonzero()
        instance.scales = scales[instance.sites]
        instance.end_points = (end_point_lmp == uid).nonzero()
        instance.extension_uids = []
        instance.__cache__ = {}

        return instance

    @property
    def end_points_as_array(self) -> array_t:
        """End points stacked into one array of shape (n_dims, n_end_points).

        Computed once and cached; reassigning ``end_points`` afterwards does not
        refresh the cached value.
        """
        pty_name = "end_points_as_array"
        if pty_name not in self.__cache__:
            self.__cache__[pty_name] = np_.array(self.end_points)

        return self.__cache__[pty_name]

    def CaptureClosestSomas(self, soma_influence_map: array_t) -> None:
        """Store the value of *soma_influence_map* at each end point.

        Presumably the map holds, per voxel, the uid of the closest soma (confirm).
        """
        self.ep_closest_somas = soma_influence_map[self.end_points]

    def EndPointsForSoma(self, soma_uid: int) -> Tuple[site_h, ...]:
        """Return the end points whose captured soma value equals *soma_uid*.

        ``CaptureClosestSomas`` must have been called first.

        Returns:
            A tuple of coordinate tuples, one per matching end point, or ``()``.
        """
        ep_bmp = self.ep_closest_somas == soma_uid  # bmp=boolean map
        if ep_bmp.any():
            end_point_idc = ep_bmp.nonzero()[0]
            # end_points_as_array is a property: (n_dims, n_end_points) columns.
            end_points = self.end_points_as_array[:, end_point_idc]

            return tuple(zip(*end_points.tolist()))

        return ()

    def Extend(self, extensions: Sequence[extension_t], costs: array_t) -> None:
        """Placeholder: only prints a "To be implemented" message."""
        print(f"{__name__}.py: {self.Extend.__name__}: To be implemented")


def EnhancedForDetection(
    image: array_t, in_parallel: bool = False
) -> Tuple[array_t, array_t]:
    """Apply a white top-hat to *image*, then frangi_3d.FrangiEnhancement.

    If ``./frangi.npz`` exists, its ``enhanced_img`` and ``scale_map`` arrays
    are returned instead of recomputing; otherwise the result is computed and
    written to that file.

    Args:
        image: Input image.
        in_parallel: Forwarded to FrangiEnhancement.

    Returns:
        ``(enhanced_img, scale_map)``.
    """
    import os.path as ph_

    cache_path = "./frangi.npz"

    # NOTE(review): the cached arrays are returned whatever *image* is, so a
    # file left over from another input silently wins. Delete it when the
    # input changes.
    if ph_.exists(cache_path):
        print("/!\\ Reading from precomputed data file")
        # NpzFile keeps the archive open until closed; the arrays are read
        # into memory on access, so they stay valid after the with-block.
        with np_.load(cache_path) as loaded:
            enhanced_img = loaded["enhanced_img"]
            scale_map = loaded["scale_map"]

        return enhanced_img, scale_map

    preprocessed_img = im_.white_tophat(
        image, size=2, mode="constant", cval=0.0, origin=0
    )

    # NOTE(review): confirm these arguments against frangi_3d.FrangiEnhancement;
    # preprocessed_img and in_parallel are forwarded because nothing else here
    # uses them.
    enhanced_img, scale_map = fg_.FrangiEnhancement(
        preprocessed_img, scale_range=(0.1, 3), in_parallel=in_parallel
    )

    np_.savez_compressed(
        cache_path, enhanced_img=enhanced_img, scale_map=scale_map
    )

    return enhanced_img, scale_map


def CoarseMap(image: array_t, low: float, high: float) -> array_t:
    """Hysteresis-threshold *image*, then clean the result slice by slice.

    *low* and *high* are relative thresholds forwarded to __HysterisisImage__.
    Returns an int8 0/1 map.
    """
    result = __HysterisisImage__(image, low, high)
    result = __MorphologicalCleaning__(result)

    return result


def FilteredCoarseMap(map_: array_t, min_area: int | None = None) -> array_t:
    """Return a copy of *map_* with small connected components set to 0.

    Components are those produced by ``skimage.measure.label`` with its default
    connectivity and background. A component is removed when its voxel count
    is ``<= min_area``.

    Args:
        map_: Map to filter. It is not modified.
        min_area: Size threshold. Defaults to the module constant ``min_area_c``.
    """
    if min_area is None:
        min_area = min_area_c

    result = map_.copy()
    lmp = ms_.label(map_)

    # Voxel count of every label in a single pass (index 0 is the background),
    # instead of one full-map comparison per region.
    areas = np_.bincount(lmp.ravel())
    small_labels = areas <= min_area
    small_labels[0] = False  # the background is never a component to remove
    result[small_labels[lmp]] = 0

    return result


def FineMapFromCoarseMap(coarse_map: array_t) -> array_t:
    """Skeletonize *coarse_map* and return the skeleton as an int8 map.

    The skeleton might contain voxels that could be removed without breaking
    connectivity.
    """
    # NOTE(review): recent scikit-image releases deprecate skeletonize_3d in
    # favor of skeletonize; verify against the pinned skimage version.
    skeleton = mp_.skeletonize_3d(coarse_map.astype(np_.uint8, copy=False))

    return skeleton.astype(np_.int8, copy=False)


def EndPointMap(map_: array_t) -> array_t:
    """Return an int8 0/1 map of the voxels whose ``ml_.PartLMap`` label is 1.

    NOTE(review): assumes PartLMap uses label 1 for end points — confirm in
    map_labeling.
    """
    part_map = ml_.PartLMap(map_)
    result = part_map == 1

    return result.astype(np_.int8)

def NormalizedImage(image: array_t) -> array_t:
    """Divide *image* by the mean of its "middle" intensities.

    Middle intensities are the values strictly greater than 0 and strictly
    smaller than the image maximum. Presumably this keeps background zeros and
    saturated voxels out of the normalization factor.

    Raises:
        ValueError: if no value lies strictly between 0 and the maximum (e.g.
            all-zero, constant, or binary images). Without this check the mean
            of an empty selection is NaN and the whole result becomes NaN.
    """
    middle_values = image[np_.logical_and(image > 0.0, image < image.max())]
    if middle_values.size == 0:
        raise ValueError(
            "NormalizedImage: no intensity strictly between 0 and the image maximum"
        )

    return image / middle_values.mean()

def __HysterisisImage__(image: array_t, low: float, high: float) -> array_t:
    """Binarize *image* by hysteresis thresholding with relative thresholds.

    The effective low threshold is *low* times the smallest strictly positive
    intensity. The effective high threshold is *high* times the maximum
    intensity. Both are passed to ``skimage.filters.apply_hysteresis_threshold``.
    Example values from an earlier version: low=0.02, high=0.04.

    Returns:
        An int8 map of 0/1 values.
    """
    smallest_positive = image[image > 0].min()
    low_threshold = low * smallest_positive
    high_threshold = high * image.max()

    binary = fl_.apply_hysteresis_threshold(image, low_threshold, high_threshold)
    return binary.astype(np_.int8)

def __MorphologicalCleaning__(image: array_t) -> array_t:
    """Return a copy of *image* cleaned slice by slice along axis 0.

    Each 2-D slice ``image[k, :, :]`` gets a closing followed by an opening,
    both using a radius-1 disk footprint. The input array is not modified.
    """
    footprint = mp_.disk(1)
    cleaned = image.copy()

    for depth in range(cleaned.shape[0]):
        closed_slice = mp_.closing(cleaned[depth, :, :], footprint)
        cleaned[depth, :, :] = mp_.opening(closed_slice, footprint)

    return cleaned