import tkinter as tknt
from typing import Callable, Sequence, Tuple, Union

import matplotlib.cm as mpcm
import numpy as nmpy
import skimage.segmentation as sisg
from PIL import Image as plim
from PIL import ImageTk as pltk


# Type aliases used in the annotations of this module
array_t = nmpy.ndarray  # NumPy arrays: input images/label maps, their MIPs, and display versions
image_t = plim.Image  # PIL images
tk_image_t = pltk.PhotoImage  # PIL-provided Tk images, shown in the Tk widgets


# Value passed as "minsize" to main_window.rowconfigure for the rows that do not hold the images (presumably
# pixels, since Tk interprets a plain integer as a screen distance in pixels)
STATIC_ROW_MIN_HEIGHT = 30


class soma_validation_window_t:
    """
    Tk window for interactively validating a soma label map.

    The maximum-intensity projection (MIP) of the GFP image (left) and of the label map (right) are displayed side
    by side. Moving the cursor over the label-map MIP shows the label under it; left-clicking a soma deletes it (its
    label is set to 0) in the label map passed to the constructor, which is therefore modified in place.
    """

    __slots__ = (
        "gfp",
        "gfp_mip",
        "gfp_pil_image_initial",
        "gfp_pil_image",
        "gfp_tk_image",
        "lmap",
        "lmap_mip",
        "lmap_mip_4_display",
        "lmap_pil_image",
        "lmap_tk_image",
        "color_version",
        "with_cm",
        "main_window",
        "mip_axis_wgt",
        "gfp_wgt",
        "lmap_wgt",
        "cursor_nfo",
    )
    gfp: array_t
    gfp_mip: array_t
    gfp_pil_image_initial: image_t
    gfp_pil_image: image_t
    gfp_tk_image: tk_image_t
    lmap: array_t
    lmap_mip: array_t
    lmap_mip_4_display: array_t
    lmap_pil_image: image_t
    lmap_tk_image: tk_image_t
    color_version: bool
    with_cm: Union[str, None]
    main_window: tknt.Tk
    mip_axis_wgt: tknt.Menubutton
    gfp_wgt: tknt.Label
    lmap_wgt: tknt.Label
    cursor_nfo: tknt.Label

    def __init__(
        self,
        gfp: array_t,
        lmap: array_t,
        mip_axis: int = -1,
        color_version: bool = True,
        with_cm: Union[str, None] = None,
    ):
        """
        gfp: image whose MIP is displayed on the left
        lmap: label map whose MIP is displayed on the right (presumably with the same shape as gfp); it is modified
            in place by soma deletions and by LaunchValidation
        mip_axis: axis along which the MIPs are computed; negative values count from the end
        color_version: if True, the label-map MIP is displayed in color; otherwise, in gray levels
        with_cm: name of a Matplotlib colormap used for the colored display, or None for the built-in coloring;
            "plasma" and "viridis" seem to be good options
        """
        main_window = tknt.Tk()
        # Closing the window with the window-manager button must behave like "Done": it stops the main loop but
        # keeps the root window alive, so that LaunchValidation can destroy it exactly once (destroying an
        # already-destroyed root raises TclError).
        main_window.protocol("WM_DELETE_WINDOW", main_window.quit)

        # --- Creation of MIPs and Tk images
        (
            gfp_mip,
            lmap_mip,
            lmap_mip_4_display,
            gfp_pil_image,
            gfp_tk_image,
            lmap_pil_image,
            lmap_tk_image,
        ) = _MIPImages(
            gfp,
            lmap,
            mip_axis,
            color_version,
            with_cm,
            main_window,
        )

        # ---- Creation of widgets
        if mip_axis < 0:
            mip_axis = gfp.ndim + mip_axis
        mip_axis_wgt = _MIPAxisChoiceWidget(
            mip_axis, self._ChangeMIPAxis, gfp.shape, main_window
        )
        gfp_wgt = tknt.Label(main_window, image=gfp_tk_image)
        lmap_wgt = tknt.Label(main_window, image=lmap_tk_image)
        cursor_nfo = tknt.Label(main_window, text="")
        done_button = tknt.Button(main_window, text="Done", command=main_window.quit)

        # --- Event management
        gfp_wgt.bind("<Configure>", self._OnResize)
        lmap_wgt.bind("<Motion>", self._DisplaySomaLabel)
        lmap_wgt.bind("<Button-1>", self._DeleteSoma)

        # --- Widget placement in grid
        next_available_row = 0

        mip_axis_wgt.grid(row=next_available_row, column=0)
        next_available_row += 1

        gfp_wgt.grid(row=next_available_row, column=0)
        lmap_wgt.grid(row=next_available_row, column=1)
        next_available_row += 1

        cursor_nfo.grid(row=next_available_row, column=0)
        done_button.grid(row=next_available_row, column=1)
        next_available_row += 1

        # --- Window resize management
        main_window.rowconfigure(0, weight=1, minsize=STATIC_ROW_MIN_HEIGHT)
        main_window.rowconfigure(1, weight=10)
        main_window.rowconfigure(2, weight=1, minsize=STATIC_ROW_MIN_HEIGHT)
        main_window.columnconfigure(0, weight=1)
        main_window.columnconfigure(1, weight=1)

        # --- Saving required variables as object attributes (no event is processed before mainloop starts)
        self.gfp = gfp
        self.gfp_mip = gfp_mip
        self.gfp_pil_image_initial = gfp_pil_image
        self.gfp_pil_image = gfp_pil_image
        self.gfp_tk_image = gfp_tk_image
        self.lmap = lmap
        self.lmap_mip = lmap_mip
        self.lmap_mip_4_display = lmap_mip_4_display
        self.lmap_pil_image = lmap_pil_image
        self.lmap_tk_image = lmap_tk_image
        self.color_version = color_version
        self.with_cm = with_cm
        self.main_window = main_window
        self.mip_axis_wgt = mip_axis_wgt
        self.gfp_wgt = gfp_wgt
        self.lmap_wgt = lmap_wgt
        self.cursor_nfo = cursor_nfo

    def LaunchValidation(self) -> int:
        """
        Run the window until "Done" is pressed or the window is closed, then destroy it and relabel the (possibly
        edited) label map in place with skimage.segmentation.relabel_sequential.

        Returns: the maximum label of the relabeled map (with labels made sequential from 1, the number of
            remaining somas).
        """
        self.main_window.mainloop()
        self.main_window.destroy()

        relabeled, _, _ = sisg.relabel_sequential(self.lmap)
        self.lmap[...] = relabeled

        return int(nmpy.amax(self.lmap))

    def _OnResize(self, event: tknt.Event) -> None:
        """
        Rescale the GFP MIP to the new size of its widget.

        The rescaling always starts from the unscaled image, presumably to avoid accumulating resampling artifacts.
        NOTE(review): gfp_wgt is gridded without "sticky", so its size follows its image (plus border and padding);
        resizing the image to the event size may therefore re-trigger <Configure> - confirm the intended behavior.
        """
        self.gfp_pil_image = self.gfp_pil_image_initial.resize((event.width, event.height))
        self.gfp_tk_image = pltk.PhotoImage(master=self.main_window, image=self.gfp_pil_image)
        self.gfp_wgt.configure(image=self.gfp_tk_image)

    def _ChangeMIPAxis(self, mip_axis: tknt.IntVar, *_, **__) -> None:
        """
        Recompute and redisplay both MIPs along the axis held by mip_axis.

        Extra positional/keyword arguments (those of the Tk variable trace callback) are ignored.
        """
        new_mip_axis = mip_axis.get()

        (
            gfp_mip,
            lmap_mip,
            lmap_mip_4_display,
            gfp_pil_image,
            gfp_tk_image,
            lmap_pil_image,
            lmap_tk_image,
        ) = _MIPImages(
            self.gfp,
            self.lmap,
            new_mip_axis,
            self.color_version,
            self.with_cm,
            self.main_window,
        )
        self.gfp_wgt.configure(image=gfp_tk_image)
        self.lmap_wgt.configure(image=lmap_tk_image)

        self.gfp_mip = gfp_mip
        # Later resizes must start from the NEW projection, not from the previous axis' image.
        self.gfp_pil_image_initial = gfp_pil_image
        self.gfp_pil_image = gfp_pil_image
        self.gfp_tk_image = gfp_tk_image

        self.lmap_mip = lmap_mip
        self.lmap_mip_4_display = lmap_mip_4_display
        self.lmap_pil_image = lmap_pil_image
        self.lmap_tk_image = lmap_tk_image

    def _LabelUnderCursor(self, event: tknt.Event) -> Union[Tuple[int, int, int], None]:
        """
        Return (label, row, col) of the label-map MIP pixel under the cursor, or None if the cursor is not over
        the image.

        Event coordinates are relative to the label widget, in which the image is centered (default anchor) and
        surrounded by the widget border and padding; hence the offset correction. Positions outside the image
        (e.g. on the border, or negative coordinates while dragging) are rejected instead of raising IndexError or
        silently wrapping around with negative indices.
        """
        n_rows, n_cols = self.lmap_mip.shape[:2]
        row = event.y - (self.lmap_wgt.winfo_height() - n_rows) // 2
        col = event.x - (self.lmap_wgt.winfo_width() - n_cols) // 2
        if (0 <= row < n_rows) and (0 <= col < n_cols):
            return self.lmap_mip[row, col], row, col

        return None

    def _DisplaySomaLabel(self, event: tknt.Event) -> None:
        """Show, in the status label, the label of the label-map MIP pixel under the cursor and its position."""
        hit = self._LabelUnderCursor(event)
        if hit is None:
            return

        label, row, col = hit
        self.cursor_nfo.configure(text=f"Label:{label}@{row}x{col}")

    def _DeleteSoma(self, event: tknt.Event) -> None:
        """Delete the soma under the cursor (if any) from the label map, its MIP, and its display."""
        hit = self._LabelUnderCursor(event)
        if hit is None:
            return

        label = hit[0]
        if label <= 0:
            return

        # In place: the caller's label map is edited.
        self.lmap[self.lmap == label] = 0

        # NOTE(review): somas hidden behind this one in the MIP are not revealed until the MIP axis is changed.
        soma_bmap = self.lmap_mip == label
        self.lmap_mip[soma_bmap] = 0
        # A 2-D boolean mask indexes the first two axes, so this clears gray-level (2-D) as well as
        # RGB (3-D) display arrays.
        self.lmap_mip_4_display[soma_bmap] = 0
        self.lmap_tk_image, self.lmap_pil_image = _TkImageFromNumpyArray(
            self.lmap_mip_4_display, self.main_window
        )
        self.lmap_wgt.configure(image=self.lmap_tk_image)


def _MIPImages(
    gfp: array_t,
    lmap: array_t,
    mip_axis: int,
    color_version: bool,
    with_cm: str,
    parent: Union[tknt.Widget, tknt.Tk],
) -> Tuple[array_t, array_t, array_t, image_t, tk_image_t, image_t, tk_image_t]:
    """
    Compute the maximum-intensity projections (MIPs) of gfp and lmap along mip_axis, and their displayable images.

    gfp_mip is converted to floating point and rescaled so that its maximum is 255 (left unscaled if its maximum is
    not positive). lmap_mip is displayed in color (built-in coloring if with_cm is None, else the with_cm colormap)
    when color_version is True, in gray levels otherwise. Tk images are created with parent as master.

    Returns: (gfp_mip, lmap_mip, lmap_mip_4_display, gfp_pil_image, gfp_tk_image, lmap_pil_image, lmap_tk_image)
    """
    # Floating-point copy: in-place scaling of an integer MIP (e.g. uint8/uint16 data) by a float would raise a
    # casting error.
    gfp_mip = nmpy.amax(gfp, axis=mip_axis).astype(nmpy.float64)
    gfp_max = nmpy.amax(gfp_mip)
    if gfp_max > 0:
        # Guarded: an all-zero image would otherwise divide by zero and yield NaNs.
        gfp_mip *= 255.0 / gfp_max

    lmap_mip = nmpy.amax(lmap, axis=mip_axis)

    if color_version:
        if with_cm is None:
            lmap_mip_4_display = _ColoredVersion(lmap_mip)
        else:
            lmap_mip_4_display = _ColoredVersionFromColormap(lmap_mip, with_cm)
    else:
        lmap_mip_4_display = _ScaledVersion(lmap_mip)

    gfp_tk_image, gfp_pil_image = _TkImageFromNumpyArray(gfp_mip, parent)
    lmap_tk_image, lmap_pil_image = _TkImageFromNumpyArray(lmap_mip_4_display, parent)

    return (
        gfp_mip,
        lmap_mip,
        lmap_mip_4_display,
        gfp_pil_image,
        gfp_tk_image,
        lmap_pil_image,
        lmap_tk_image,
    )


def _ScaledVersion(image: array_t, offset: int = 50) -> array_t:
    """
    offset: Value of darkest non-background intensity
    """
    scaling = (255.0 - offset) / nmpy.max(image)
    output = scaling * image + offset
    output[image == 0] = 0

    return output.astype(nmpy.uint8)


def _ColoredVersion(image: array_t) -> array_t:
    """"""
    max_label = nmpy.amax(image)

    labels = tuple(range(1, max_label + 1))
    half_length = int(round(0.5 * max_label))
    shuffled_labels = labels[half_length:] + labels[:half_length]
    shuffled_image = nmpy.zeros_like(image)
    for label, shuffled_label in enumerate(shuffled_labels):
        shuffled_image[image == label] = shuffled_label

    output = nmpy.dstack((image, shuffled_image, max_label - image))
    output = (255.0 / max_label) * output
    output[image == 0] = 0

    return output.astype(nmpy.uint8)


def _ColoredVersionFromColormap(image: array_t, colormap_name: str) -> array_t:
    """"""
    output = nmpy.zeros(image.shape + (3,), dtype=nmpy.uint8)

    LinearValueToRGB = mpcm.get_cmap(colormap_name)
    max_label = nmpy.amax(image)
    for label in range(1, max_label + 1):
        color_01 = LinearValueToRGB((label - 1.0) / (max_label - 1.0))
        color_255 = nmpy.round(255.0 * nmpy.array(color_01[:3]))
        output[image == label, :] = color_255

    return output


def _TkImageFromNumpyArray(
    array: array_t, parent: Union[tknt.Widget, tknt.Tk]
) -> Tuple[tk_image_t, image_t]:
    """
    Convert a NumPy array into an image displayable in Tk widgets.

    The array is first wrapped into a PIL image, from which a Tk photo image is built with parent as its master.
    The PIL image is returned as well so that callers can keep working on it.

    Returns: (Tk photo image, PIL image)
    """
    as_pil = plim.fromarray(array)
    return pltk.PhotoImage(image=as_pil, master=parent), as_pil


def _MIPAxisChoiceWidget(
    current_axis: int,
    Action: Callable,
    shape: Sequence[int],
    parent: Union[tknt.Widget, tknt.Tk],
) -> tknt.Menubutton:
    """
    Build the menu button used to select the MIP axis among the first three dimensions.

    current_axis: index (0, 1, or 2) of the initially selected axis
    Action: called as Action(axis_variable, *trace_arguments) whenever the selection changes, axis_variable being
        the tknt.IntVar that holds the selected axis
    shape: image shape, displayed in the button title
    parent: widget hosting the button
    """
    dimensions = ",".join(map(str, shape))
    button = tknt.Menubutton(parent, text=f"MIP Axis [{dimensions}]", relief="raised")

    axis_menu = tknt.Menu(button, tearoff=False)
    axis_variable = tknt.IntVar()
    for axis, entry_label in enumerate(("First dim", "Second dim", "Third dim")):
        axis_menu.add_radiobutton(label=entry_label, value=axis, variable=axis_variable)
    # The initial selection is made BEFORE the trace is registered so that Action is not called while the window
    # is being built.
    axis_menu.invoke(current_axis)

    def _OnAxisWrite(*args, **kwargs):
        return Action(axis_variable, *args, **kwargs)

    axis_variable.trace_add("write", _OnAxisWrite)

    button["menu"] = axis_menu

    return button