Mentions légales du service

Skip to content
Snippets Groups Projects
soma_validation.py 21.4 KiB
Newer Older
# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2019)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL  license under French law and
# abiding by the rules of distribution of free software.  You can  use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and  rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty  and the software's author,  the holder of the
# economic rights,  and the successive licensors  have only  limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using,  modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean  that it is complicated to manipulate,  and  that  also
# therefore means  that it is reserved for developers  and  experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and,  more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

from __future__ import annotations

import tkinter as tknt
from typing import Callable, Optional, Sequence, Tuple, Union
import matplotlib.cm as mpcm
import matplotlib.pyplot as pypl
import numpy as nmpy
import skimage.measure as sims
import skimage.segmentation as sisg
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg as matplotlib_widget_t
from matplotlib.backends.backend_tkagg import NavigationToolbar2Tk as toolbar_widget_t
from matplotlib.figure import Figure as figure_t
from mpl_toolkits.mplot3d import Axes3D as axes_3d_t
from PIL import Image as plim
DEBREUVE Eric's avatar
DEBREUVE Eric committed
from PIL import ImageTk as pltk
# Type aliases used in annotations throughout this module
array_t = nmpy.ndarray
image_t = plim.Image
tk_image_t = pltk.PhotoImage


# Minimum height, in pixels (Tk grid "minsize"), of the main-window grid rows that do not hold the image views
STATIC_ROW_MIN_HEIGHT = 30
# Initial isovalue of the GFP isosurface in 3-D mode, expressed relative to the GFP intensities
INITIAL_RELATIVE_ISOVALUE = 0.6


class soma_validation_window_t:
        "three_d_selector",
DEBREUVE Eric's avatar
DEBREUVE Eric committed
        "mip_axis_wgt",
DEBREUVE Eric's avatar
DEBREUVE Eric committed
        "gfp_wgt",
        "gfp_3d_wgt",
DEBREUVE Eric's avatar
DEBREUVE Eric committed
        "lmap_wgt",
        "lmap_3d_wgt",
        "cursor_nfo",
        "invisible_widgets",
    three_d_selector: tknt.Checkbutton
DEBREUVE Eric's avatar
DEBREUVE Eric committed
    mip_axis_wgt: tknt.Menubutton
    isovalue_wgt: tknt.Scale
    gfp_wgt: mip_widget_t
    gfp_3d_wgt: Optional[three_d_widget_t]
    lmap_wgt: mip_widget_t
    lmap_3d_wgt: Optional[three_d_widget_t]
    cursor_nfo: tknt.Label
    invisible_widgets: Sequence
DEBREUVE Eric's avatar
DEBREUVE Eric committed
        gfp: array_t,
        mip_axis: int = -1,
        with_cm: str = None,
    ):
        """
        with_cm: "plasma" and "viridis" seem to be good options
        """
DEBREUVE Eric's avatar
DEBREUVE Eric committed

DEBREUVE Eric's avatar
DEBREUVE Eric committed
        # ---- Creation of widgets
        if mip_axis < 0:
            mip_axis = gfp.ndim + mip_axis
        gfp_wgt = mip_widget_t(
            gfp,
            mip_axis=mip_axis,
            color_version=False,
            static_image=False,
            probed_image=False,
            resizeable=True,
            parent=main_window,
        )
        lmap_wgt = mip_widget_t(
            lmap,
            mip_axis=mip_axis,
            color_version=True,
            with_cm=with_cm,
            static_image=False,
            probed_image=True,
            resizeable=True,
            parent=main_window,
        )
        isovalue_wgt = None
        gfp_3d_wgt = None
        lmap_3d_wgt = None
        state_variable = tknt.IntVar()
        Toggle3D = lambda *args, **kwargs: self._Toggle3D(state_variable)
        three_d_selector = tknt.Checkbutton(
            main_window, text="3D View", variable=state_variable, command=Toggle3D
        )
DEBREUVE Eric's avatar
DEBREUVE Eric committed
        mip_axis_wgt = _MIPAxisChoiceWidget(
            mip_axis,
            (gfp_wgt.ChangeMIPAxis, lmap_wgt.ChangeMIPAxis),
            gfp.shape,
            main_window,
        cursor_nfo = tknt.Label(main_window, text="")
        done_button = tknt.Button(main_window, text="Done", command=main_window.quit)
        invisible_widgets = ()
DEBREUVE Eric's avatar
DEBREUVE Eric committed
        # --- Event management
        lmap_wgt.AddProbe(cursor_nfo)
DEBREUVE Eric's avatar
DEBREUVE Eric committed
        lmap_wgt.bind("<Button-1>", self._DeleteSoma)

        # --- Widget placement in grid
        next_available_row = 0

        mip_axis_wgt.grid(row=next_available_row, column=0)
        three_d_selector.grid(row=next_available_row, column=1)
DEBREUVE Eric's avatar
DEBREUVE Eric committed
        next_available_row += 1

        # sticky=... solves the super-slow-to-resize issue!
        gfp_wgt.grid(
            row=next_available_row, column=0, sticky=tknt.W + tknt.E + tknt.N + tknt.S
        )
        lmap_wgt.grid(
            row=next_available_row, column=1, sticky=tknt.W + tknt.E + tknt.N + tknt.S
        )
        next_available_row += 2  # Leave one row free for isovalue slider
DEBREUVE Eric's avatar
DEBREUVE Eric committed
        cursor_nfo.grid(row=next_available_row, column=0)
        done_button.grid(row=next_available_row, column=1)
        next_available_row += 1
DEBREUVE Eric's avatar
DEBREUVE Eric committed
        # --- Window resize management
        main_window.rowconfigure(0, weight=1, minsize=STATIC_ROW_MIN_HEIGHT)
        main_window.rowconfigure(1, weight=10)
        main_window.rowconfigure(2, weight=1, minsize=STATIC_ROW_MIN_HEIGHT)
        main_window.columnconfigure(0, weight=1)
        main_window.columnconfigure(1, weight=1)
DEBREUVE Eric's avatar
DEBREUVE Eric committed

        # --- Saving required variables as object attributes
        for attribute in self.__class__.__slots__:
            setattr(self, attribute, eval(attribute))
    def LaunchValidation(self) -> int:
        """
        Run the interactive validation session, then relabel the validated label map.

        The Tk main loop runs until it is exited (the "Done" button calls quit), after which the main window is
        destroyed. The label map held by the label-map MIP widget is then relabeled in place so that the remaining
        labels are sequential (1, 2, ..., N).

        Returns the maximum label after relabeling, i.e. the number of remaining labels, as a Python int.
        """
        self.main_window.mainloop()
        self.main_window.destroy()

        lmap = self.lmap_wgt.image
        # relabel_sequential also returns the forward and inverse label maps; they are not needed here
        relabeled, *_ = sisg.relabel_sequential(lmap)
        # In-place update so that other references to this array (presumably the caller's label map) see the result
        lmap[...] = relabeled

        # Cast the NumPy scalar to match the declared return type
        return int(nmpy.amax(lmap))

    def _DeleteSoma(self, event: tknt.EventType.ButtonPress, /) -> None:
        """
        Delete the soma under the mouse pointer (bound to a left click on the label-map view).

        The clicked label is zeroed in the full label map, in its MIP, and in the displayed version of the MIP. The
        label-map widget is then refreshed. Clicking on background (label 0) does nothing.

        Note: MIP pixels that showed the deleted soma become background even if another soma lies behind it along the
        MIP axis. That soma reappears only when the MIP is recomputed from the label map (e.g., when changing the MIP
        axis).
        """
        row, col = self.lmap_wgt.ArrayIndicesFromPixel(event.y, event.x)

        lmap_mip = self.lmap_wgt.image_mip
        label = lmap_mip[row, col]
        if label <= 0:
            return

        lmap = self.lmap_wgt.image
        lmap[lmap == label] = 0

        soma_bmap = lmap_mip == label
        lmap_mip[soma_bmap] = 0

        # Indexing with the 2-D mask zeroes the pixels of a grayscale (rows x cols) display version, and all the
        # channels at once of an RGB (rows x cols x channels) one
        lmap_mip_4_display = self.lmap_wgt.display_version
        lmap_mip_4_display[soma_bmap] = 0

        self.lmap_wgt.UpdateImage(lmap, lmap_mip, lmap_mip_4_display)

    def _Toggle3D(self, state: tknt.IntVar) -> None:
        """
        Switch between the 2-D (MIP) view and the 3-D (isosurface) view.

        state: Variable of the "3D View" check button; 0 requests the MIP view, any other value the 3-D view.

        The widgets of the view being left are removed from the grid, and those of the view being entered are
        re-gridded (grid() without options puts them back where they were). The 3-D widgets are created lazily, on
        the first switch to 3-D.
        """
        three_d_state = state.get()

        if three_d_state == 0:
            # The isovalue slider belongs to the 3-D view: hide it together with the 3-D plots
            soon_invisible_widgets = (
                self.gfp_3d_wgt,
                self.lmap_3d_wgt,
                self.isovalue_wgt,
            )
        else:
            soon_invisible_widgets = (self.gfp_wgt, self.lmap_wgt, self.mip_axis_wgt)
        for widget in soon_invisible_widgets:
            widget.grid_remove()

        if self.invisible_widgets:
            for widget in self.invisible_widgets:
                widget.grid()
        else:
            # First switch to 3-D: nothing has been hidden yet, so the 3-D widgets do not exist yet
            self._Create3DWidgets()
        self.invisible_widgets = soon_invisible_widgets

        if three_d_state == 0:
            self.cursor_nfo.configure(text="")
        else:
            self.cursor_nfo.configure(text="No MIP info in 3-D mode")

    def _Create3DWidgets(self) -> None:
        """
        Create, plot, and grid the 3-D views of the GFP image and of the label map, plus the GFP isovalue slider.
        """
        gfp = self.gfp_wgt.image
        lmap = self.lmap_wgt.image
        # Python floats: avoids unsigned-integer arithmetic pitfalls and suits Tk variables
        gfp_min = float(nmpy.amin(gfp))
        gfp_max = float(nmpy.amax(gfp))
        gfp_extent = gfp_max - gfp_min

        isovalue_variable = tknt.DoubleVar()
        # Relative to the intensity range (not to the maximum only). For INITIAL_RELATIVE_ISOVALUE in [0.1, 0.9], this
        # keeps the initial isovalue within the slider range, hence within the data range required by marching cubes,
        # even when gfp_min > 0.
        isovalue_variable.set(gfp_min + INITIAL_RELATIVE_ISOVALUE * gfp_extent)

        self.gfp_3d_wgt = three_d_widget_t(gfp, self.main_window)
        self.lmap_3d_wgt = three_d_widget_t(lmap, self.main_window)
        self.gfp_3d_wgt.ComputeAndPlotIsosurface(isovalue_variable)
        # 0.5 separates the background (0) from the labeled somas (>= 1)
        self.lmap_3d_wgt.ComputeAndPlotIsosurface(0.5)
        # Mouse view changes in either 3-D view are replicated in the other one
        self.gfp_3d_wgt.AddCompanionAxes(self.lmap_3d_wgt.axes)
        self.lmap_3d_wgt.AddCompanionAxes(self.gfp_3d_wgt.axes)

        self.isovalue_wgt = tknt.Scale(
            self.main_window,
            orient="horizontal",
            from_=gfp_min + 0.1 * gfp_extent,
            to=gfp_max - 0.1 * gfp_extent,
            resolution=0.8 * gfp_extent / 100.0,
            tickinterval=0.8 * gfp_extent / 10.0,
            variable=isovalue_variable,
            label="Isovalue",
        )
        self.isovalue_wgt.set(isovalue_variable.get())
        # Traced only after the initial set so that it does not recompute the isosurface just plotted
        isovalue_variable.trace_add(
            "write",
            lambda *_, **__: self.gfp_3d_wgt.ComputeAndPlotIsosurface(
                isovalue_variable
            ),
        )

        sticky = tknt.W + tknt.E + tknt.N + tknt.S
        self.gfp_3d_wgt.grid(row=1, column=0, sticky=sticky)
        self.lmap_3d_wgt.grid(row=1, column=1, sticky=sticky)
        self.isovalue_wgt.grid(row=2, column=0, sticky=sticky)

class mip_widget_t(tknt.Label):
    """
    Label widget displaying the maximum intensity projection (MIP) of an image (e.g., a GFP image or a label map).

    Optionally, the widget keeps the projected image (non-static), reports the MIP value under the mouse pointer
    (probed; see AddProbe), and rescales its display to the widget size (resizeable).
    """

    image: Optional[array_t]  # Projected image; None for a static widget
    color_version: bool  # Whether the MIP is rendered in color (as opposed to 8-bit grayscale)
    with_cm: Optional[str]  # Matplotlib colormap name for the color rendering, or None
    image_mip: Optional[array_t]  # MIP of image; None for a non-probed widget
    display_version: Optional[array_t]  # Rendered MIP as an array; None for a static widget
    pil_version_original: Optional[image_t]  # Rendered MIP at MIP resolution; None if not resizeable
    pil_version: image_t  # Rendered MIP as currently displayed
    tk_version: tk_image_t  # Tk counterpart of pil_version
    mip_axis: int  # Axis along which the current MIP was computed
    static_image: bool
    probed_image: bool
    resizeable: bool
    probe_info_wgt: Optional[tknt.Widget]  # Widget showing the probed value, if any
    parent: Union[tknt.Widget, tknt.Tk]

    def __init__(
        self,
        image: array_t,
        /,
        *,
        mip_axis: int = -1,
        color_version: bool = True,
        with_cm: Optional[str] = None,
        static_image: bool = False,
        probed_image: bool = True,
        resizeable: bool = True,
        parent: Union[tknt.Widget, tknt.Tk] = None,
    ):
        """
        image: Image to project.
        mip_axis: Projection axis.
        color_version: If True, the MIP is rendered in color (label maps); otherwise in 8-bit grayscale.
        with_cm: Matplotlib colormap name for the color rendering, or None for the default label coloring.
        static_image: If True, neither the image nor its display version is kept, so the MIP axis cannot be changed.
        probed_image: If True, the MIP is kept so that a probe can be added (see AddProbe).
        resizeable: If True, the displayed image follows the widget size.
        parent: Tk parent widget.
        """
        image_mip, display_version, pil_version, tk_version = _MIPImages(
            image,
            mip_axis,
            color_version,
            with_cm,
            parent,
        )
        self.color_version = color_version
        self.with_cm = with_cm
        self.image = None if static_image else image
        self.display_version = None if static_image else display_version
        self.image_mip = image_mip if probed_image else None
        self.pil_version_original = pil_version if resizeable else None
        self.pil_version = pil_version
        self.tk_version = tk_version

        self.mip_axis = mip_axis

        self.static_image = static_image
        self.probed_image = probed_image
        self.resizeable = resizeable

        self.probe_info_wgt = None
        self.parent = parent

        super().__init__(parent, image=self.tk_version, borderwidth=0, padx=0, pady=0)
        if resizeable:
            # Binding cannot be done before super init
            self.bind("<Configure>", self._OnResize)

    def AddProbe(
        self, probe_info_wgt: tknt.Widget, /, *, for_event: str = "<Motion>"
    ) -> None:
        """
        probe_info_wgt: Must have a text attribute updatable through the configure(text=...) method.

        Instead of checking that self.image_mip is not None here, the widget constructor could have been declared with
        a probe info widget as a parameter. However, to avoid requiring the probe info widget being created before the
        MIP widget, this 2-step+checking option was chosen.

        Raises ValueError if the widget was instantiated with probed_image=False.
        """
        if self.image_mip is None:
            raise ValueError(
                'Adding a probe to a MIP widget instantiated with "probed_image" to False'
            )

        self.probe_info_wgt = probe_info_wgt
        self.bind(for_event, self._DisplayInfo)

    def _DisplayInfo(self, event: tknt.EventType.Motion, /) -> None:
        """
        Show the MIP value under the mouse pointer, with its array indices, in the probe info widget.
        """
        row, col = self.ArrayIndicesFromPixel(event.y, event.x)
        label = self.image_mip[row, col]
        self.probe_info_wgt.configure(text=f"Label: {label} @ {row}x{col}")

    def _OnResize(self, event: tknt.EventType.Configure, /) -> None:
        """
        Rescale the displayed image to the new widget size (bound to <Configure> for resizeable widgets).
        """
        self.pil_version = self.pil_version_original.resize((event.width, event.height))
        self.tk_version = pltk.PhotoImage(master=self.parent, image=self.pil_version)
        self.configure(image=self.tk_version)

    def UpdateImage(
        self, image: array_t, image_mip: array_t, display_version: array_t, /
    ) -> None:
        """
        Replace the widget content.

        image: New image (stored only for a non-static widget).
        image_mip: New MIP (stored only for a probed widget).
        display_version: New rendered MIP. It is always the one displayed, and stored only for a non-static widget.
        """
        if not self.static_image:
            self.image = image
            self.display_version = display_version
        if self.probed_image:
            self.image_mip = image_mip

        # Render the argument rather than self.display_version, which is None for a static widget
        tk_version, pil_version = _TkAndPILImagesFromNumpyArray(
            display_version, self.parent
        )
        self._ShowImages(pil_version, tk_version)

    def ChangeMIPAxis(self, mip_axis: tknt.IntVar, /, *_, **__) -> None:
        """
        Recompute and display the MIP along a new axis.

        mip_axis: Tk variable holding the new projection axis. Extra arguments (e.g., those passed by a Tk variable
        trace) are ignored.

        Raises ValueError for a static widget, since it does not keep the image to project.
        """
        if self.image is None:
            raise ValueError(
                'Changing the MIP axis of a MIP widget instantiated with "static_image" to True'
            )

        new_axis = mip_axis.get()
        image_mip, display_version, pil_version, tk_version = _MIPImages(
            self.image,
            new_axis,
            self.color_version,
            self.with_cm,
            self.parent,
        )
        # Keep the attribute consistent with the displayed projection
        self.mip_axis = new_axis
        if not self.static_image:
            self.display_version = display_version
        if self.probed_image:
            self.image_mip = image_mip
        self._ShowImages(pil_version, tk_version)

    def _ShowImages(self, pil_version: image_t, tk_version: tk_image_t, /) -> None:
        """
        Display freshly rendered images (at MIP resolution), keeping the PIL one as the reference for resizing.

        If the widget is resizeable and already on screen, the image is rescaled to the current widget size.
        Otherwise, it would be shown at MIP resolution until the next <Configure> event, which is out of sync with
        ArrayIndicesFromPixel, since that method maps pixels using the widget size.
        """
        if self.resizeable:
            self.pil_version_original = pil_version
            if self.winfo_ismapped():
                pil_version = pil_version.resize((self.winfo_width(), self.winfo_height()))
                tk_version = pltk.PhotoImage(master=self.parent, image=pil_version)
        self.pil_version = pil_version
        self.tk_version = tk_version
        self.configure(image=self.tk_version)

    def ArrayIndicesFromPixel(self, row: int, col: int, /) -> Tuple[int, int]:
        """
        Convert widget pixel coordinates into (row, column) indices of the MIP.

        The displayed image is assumed to fill the whole widget, which holds when it is rescaled to the widget size.
        Pixel p of an axis of n widget pixels covering m MIP cells falls in cell floor(m * p / n). The indices are
        clamped to the MIP bounds, so that coordinates on or beyond the widget border never yield out-of-range
        indices.
        """
        n_rows, n_cols = self.image_mip.shape[:2]
        row = min(max(int(n_rows * row / self.winfo_height()), 0), n_rows - 1)
        col = min(max(int(n_cols * col / self.winfo_width()), 0), n_cols - 1)

        return row, col


class three_d_widget_t(tknt.Frame):
    """
    Tk frame embedding a Matplotlib 3-D axes and its navigation toolbar, displaying an isosurface of a 3-D image.

    The view can be linked to a companion 3-D axes (see AddCompanionAxes). Mouse-driven view changes in this view are
    then replicated in the companion axes, either live (during mouse motion) or once, on button release.
    """

    def __init__(
        self,
        image: array_t,
        parent: Union[tknt.Widget, tknt.Tk],
        /,
        *args,
        **kwargs,
    ):
        """
        image: Array whose isosurfaces are computed (with marching cubes) and plotted.
        parent: Tk parent widget.
        args, kwargs: Passed to tknt.Frame. Keyword options override the default zero border width and padding.
        """
        # Defaults first, so that caller options override them instead of raising a duplicate-keyword TypeError
        options = {"borderwidth": 0, "padx": 0, "pady": 0, **kwargs}
        super().__init__(
            parent,
            *args,
            **options,
            class_=self.__class__.__name__,
        )
        figure = figure_t()
        axes = figure.add_subplot(111, projection=axes_3d_t.name)
        axes.set_xlabel("First")
        axes.set_ylabel("Second")
        axes.set_zlabel("Third")
        plot_wgt = matplotlib_widget_t(figure, master=self)
        plot_wgt.draw()

        toolbar = toolbar_widget_t(plot_wgt, self, pack_toolbar=False)
        toolbar.update()

        plot_wgt.get_tk_widget().pack(fill=tknt.BOTH, expand=True)
        toolbar.pack(side=tknt.BOTTOM, fill=tknt.X)

        self.image = image

        self.figure = figure
        self.axes = axes
        # Last computed isosurface: marching-cubes mesh and the corresponding plotted artist
        self.vertices = None
        self.triangles = None
        self.isosurface = None

        # View synchronization with a companion axes (see AddCompanionAxes)
        self.companion_axes = None
        self.live_synchronization = False
        self.synchronization_context = None
        self.button_press_id = None
        self.motion_event_id = None
        self.button_release_id = None

        self.plot_wgt = plot_wgt
        self.toolbar = toolbar

    def ComputeAndPlotIsosurface(
        self, isovalue: Union[float, tknt.DoubleVar], *, step_size: int = 2
    ) -> None:
        """
        Compute the isosurface of the image at the given isovalue and plot it in place of the previous one.

        isovalue: Isovalue, or Tk variable holding it (e.g., the variable of an isovalue slider).
        step_size: Marching-cubes step size in voxels; larger is coarser but faster.

        The new surface is computed before the previous one is removed. If the computation fails (e.g., marching_cubes
        raises ValueError for an isovalue outside the image range), the previous surface stays displayed and remains
        consistent with the stored attributes, so later calls still work.
        """
        if isinstance(isovalue, tknt.DoubleVar):
            isovalue = isovalue.get()
        vertices, triangles, *_ = sims.marching_cubes(
            self.image, level=isovalue, step_size=step_size
        )

        if self.isosurface is not None:
            self.isosurface.remove()
        isosurface = self.axes.plot_trisurf(
            vertices[:, 0],
            vertices[:, 1],
            triangles,
            vertices[:, 2],
            cmap="Spectral",
            lw=1,
        )

        self.vertices = vertices
        self.triangles = triangles
        self.isosurface = isosurface

        self.figure.canvas.draw_idle()

    def AddCompanionAxes(
        self, axes: pypl.Axes, /, *, live_synchronization: bool = False
    ) -> None:
        """
        Replicate mouse view changes of this view in the given 3-D axes.

        live_synchronization: If True, the companion view is updated during mouse motion; otherwise, once on button
        release.

        Calling it again replaces the companion axes; the button-press handler is connected only once.
        """
        self.companion_axes = axes
        self.live_synchronization = live_synchronization
        if self.button_press_id is None:
            self.button_press_id = self.figure.canvas.mpl_connect(
                "button_press_event", self._OnButtonPress
            )

    def _OnButtonPress(self, _) -> None:
        """
        Start tracking a mouse drag by connecting the motion and release handlers.
        """
        # A press while another drag is ongoing must not stack a second pair of handlers that would never be
        # disconnected
        self._DisconnectDragHandlers()
        self.motion_event_id = self.figure.canvas.mpl_connect(
            "motion_notify_event", self._OnMotion
        )
        self.button_release_id = self.figure.canvas.mpl_connect(
            "button_release_event", self._OnButtonRelease
        )

    def _OnButtonRelease(self, _, /) -> None:
        """
        End a mouse drag: disconnect the drag handlers and, in non-live mode, synchronize the companion view once.
        """
        self._DisconnectDragHandlers()
        if not self.live_synchronization and self.synchronization_context is not None:
            _Synchronize3DViews(
                self.synchronization_context, self.axes, self.companion_axes
            )
            self.companion_axes.figure.canvas.draw_idle()
        # Forget the drag context so that a later click without motion does not replay it
        self.synchronization_context = None

    def _OnMotion(self, event, /) -> None:
        """
        During a drag in this view, synchronize the companion view (live mode) or record the button used (otherwise).
        """
        if event.inaxes == self.axes:
            if self.live_synchronization:
                _Synchronize3DViews(
                    self.axes.button_pressed, self.axes, self.companion_axes
                )
                self.companion_axes.figure.canvas.draw_idle()
            else:
                self.synchronization_context = self.axes.button_pressed

    def _DisconnectDragHandlers(self) -> None:
        """
        Disconnect the motion and release handlers of an ongoing drag, if any.
        """
        for connection_id in (self.motion_event_id, self.button_release_id):
            if connection_id is not None:
                self.figure.canvas.mpl_disconnect(connection_id)
        self.motion_event_id = None
        self.button_release_id = None


def _MIPAxisChoiceWidget(
    current_axis: int,
    Action: Union[Callable, Sequence[Callable]],
    shape: Sequence[int],
    parent: Union[tknt.Widget, tknt.Tk],
    /,
) -> tknt.Menubutton:
    """
    Create a menu button for choosing the MIP axis among the first three dimensions.

    current_axis: Initially selected axis (index of a menu entry: 0, 1, or 2).
    Action: Callable, or sequence of callables, called each time the selection changes. Each receives the Tk variable
        holding the selected axis, followed by the arguments of the Tk variable trace callback.
    shape: Image shape, displayed in the button title.
    parent: Tk parent widget.
    """
    title = f"MIP Axis [{','.join(str(_lgt) for _lgt in shape)}]"
    output = tknt.Menubutton(parent, text=title, relief="raised")

    menu = tknt.Menu(output, tearoff=False)
    entries = ("First dim", "Second dim", "Third dim")
    selected_mip_axis = tknt.IntVar()
    for idx, entry in enumerate(entries):
        menu.add_radiobutton(label=entry, value=idx, variable=selected_mip_axis)
    # Invoking the radio entry selects it, which sets selected_mip_axis
    menu.invoke(current_axis)

    # Set action only after calling invoke to avoid redundant call at window creation.
    # tuple(...) also makes a one-shot iterable of actions reusable across callbacks.
    Actions = (Action,) if callable(Action) else tuple(Action)

    def Callback(*args, **kwargs) -> None:
        # Forward the Tk trace arguments after the variable holding the selected axis
        for OneAction in Actions:
            OneAction(selected_mip_axis, *args, **kwargs)

    selected_mip_axis.trace_add("write", Callback)

    output["menu"] = menu

    return output


def _MIPImages(
    image: array_t,
    mip_axis: int,
    color_version: bool,
    with_cm: Optional[str],
    parent: Union[tknt.Widget, tknt.Tk],
    /,
    *,
    offset: int = 0,
) -> Tuple[array_t, array_t, image_t, tk_image_t]:
    """
    Compute the maximum intensity projection (MIP) of an image and its display versions.

    image: Image (or label map) to project.
    mip_axis: Axis along which the maximum is taken.
    color_version: If True, the MIP is rendered in color: with the Matplotlib colormap with_cm if given, with the
        default label coloring otherwise. If False, it is rendered in 8-bit grayscale.
    with_cm: Matplotlib colormap name, or None.
    parent: Tk parent of the Tk image.
    offset: Passed to _ScaledVersion (grayscale rendering only).

    Returns the MIP, its rendered version (NumPy array), and the corresponding PIL and Tk images, in this order.
    """
    image_mip = nmpy.amax(image, axis=mip_axis)

    if not color_version:
        display_version = _ScaledVersion(image_mip, offset=offset)
    elif with_cm is None:
        display_version = _ColoredVersion(image_mip)
    else:
        display_version = _ColoredVersionFromColormap(image_mip, with_cm)
    tk_version, pil_version = _TkAndPILImagesFromNumpyArray(display_version, parent)

    return image_mip, display_version, pil_version, tk_version
def _ScaledVersion(image: array_t, /, *, offset: int = 0) -> array_t:
    """
    offset: Value of darkest non-background intensity
    """
    scaling = (255.0 - offset) / nmpy.amax(image)
    output = scaling * image + offset
    output[image == 0] = 0

    return nmpy.around(output).astype(nmpy.uint8)
def _ColoredVersion(image: array_t, /) -> array_t:
    """"""
    max_label = nmpy.amax(image)

    labels = tuple(range(1, max_label + 1))
    half_length = int(round(0.5 * max_label))
    shuffled_labels = labels[half_length:] + labels[:half_length]
    shuffled_image = nmpy.zeros_like(image)
    for label, shuffled_label in enumerate(shuffled_labels):
        shuffled_image[image == label] = shuffled_label

    output = nmpy.dstack((image, shuffled_image, max_label - image))
    output = (255.0 / max_label) * output
    output[image == 0] = 0

    return output.astype(nmpy.uint8)


def _ColoredVersionFromColormap(image: array_t, colormap_name: str, /) -> array_t:
    """
    Render a label map in RGB using a Matplotlib colormap, label 0 (background) being black.

    Labels 1 to the maximum label are mapped to evenly spaced positions of the colormap, from its start to its end;
    a single label gets the start color.

    colormap_name: Name of a registered Matplotlib colormap (e.g., "plasma", "viridis").
    Returns an array of shape image.shape + (3,) with dtype uint8.
    """
    output = nmpy.zeros(image.shape + (3,), dtype=nmpy.uint8)

    # pyplot.get_cmap exists across Matplotlib versions, whereas matplotlib.cm.get_cmap was removed in version 3.9
    LinearValueToRGB = pypl.get_cmap(colormap_name)
    max_label = nmpy.amax(image)
    # With a single label, max_label - 1 is 0. The resulting 0/0 is NaN, which the colormap renders with its "bad"
    # color (transparent black by default), making the soma invisible.
    denominator = max(max_label - 1.0, 1.0)
    for label in range(1, max_label + 1):
        color_01 = LinearValueToRGB((label - 1.0) / denominator)
        color_255 = nmpy.around(255.0 * nmpy.array(color_01[:3]))
        output[image == label, :] = color_255

    return output


def _TkAndPILImagesFromNumpyArray(
    array: array_t, parent: Union[tknt.Widget, tknt.Tk], /
) -> Tuple[tk_image_t, image_t]:
    """
    Wrap a NumPy array into a PIL image, and the latter into a Tk photo image whose master is parent.

    Returns the Tk image first, then the PIL image.
    """
    as_pil = plim.fromarray(array)
    as_tk = pltk.PhotoImage(master=parent, image=as_pil)

    return as_tk, as_pil
def _Synchronize3DViews(context, source: pypl.Axes, target: pypl.Axes) -> None:
    """
    Replicate in target the view change made with the mouse in source.

    context: Mouse button involved in the view change (as found in the 3-D axes attribute "button_pressed").
    source, target: Matplotlib 3-D axes.

    With the rotation button, the viewing angles are copied, including the roll angle when the Matplotlib version
    has one. With the zoom button, or the pan button in versions that have one, the axis limits are copied. Any other
    context (including None) leaves target unchanged.
    """
    if context in source._rotate_btn:
        if hasattr(source, "roll"):
            target.view_init(elev=source.elev, azim=source.azim, roll=source.roll)
        else:
            target.view_init(elev=source.elev, azim=source.azim)
    elif context in source._zoom_btn or context in getattr(source, "_pan_btn", ()):
        target.set_xlim3d(source.get_xlim3d())
        target.set_ylim3d(source.get_ylim3d())
        target.set_zlim3d(source.get_zlim3d())