# Copyright CNRS/Inria/UNS
# Contributor(s): Eric Debreuve (since 2019), Morgane Nadal (2020)
#
# eric.debreuve@cnrs.fr
#
# This software is governed by the CeCILL  license under French law and
# abiding by the rules of distribution of free software.  You can  use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
#
# As a counterpart to the access to the source code and  rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty  and the software's author,  the holder of the
# economic rights,  and the successive licensors  have only  limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using,  modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean  that it is complicated to manipulate,  and  that  also
# therefore means  that it is reserved for developers  and  experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and,  more generally, to use and operate it in the
# same conditions as regards security.
#
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

import re as re_
import numpy as np_
import math as mt_
import scipy.stats as st_
import pandas as pd_

from brick.component.soma import soma_t
from brick.general.type import array_t
import brick.processing.best_fit_ellipsoid as bf_
import brick.processing.input as in_
from typing import Tuple, Dict, Union, Any


def FindGraphsRootWithEdges(soma: soma_t, ext_nfo: Dict[str, Union[array_t, Any]]) -> dict:
    """
    Find, for a given soma, the graph nodes at which its primary extensions are rooted.

    A node is a root candidate when it has degree 1 and belongs to an edge whose extension uid (the largest
    label found in ext_nfo['lmp'] at the edge sites) is the uid of one of soma.extensions. The candidate is
    kept only if at least one of its 26 neighboring voxels is a soma contour point.

    soma: read through its extensions, skl_graph and contour_points attributes.
    ext_nfo: must contain the key 'lmp'; it is indexed with the sites of each candidate edge.
    Returns: dict {extension uid: root node label}. If several edges lead to the same uid, the last one wins.

    Note: the uids of the primary extensions and their count are printed on the standard output.
    """
    root_nodes = {}

    # Finds the primary extensions
    primary_extension_uids = tuple(extension.uid for extension in soma.extensions)
    print(primary_extension_uids, '\nn = ', len(primary_extension_uids))

    # Offsets to the 26 neighboring voxels (the central voxel is excluded); identical for every candidate
    neighbor_offsets = tuple(
        (i, j, k)
        for i in (-1, 0, 1)
        for j in (-1, 0, 1)
        for k in (-1, 0, 1)
        if i != 0 or j != 0 or k != 0
    )
    # Built on first use only, and then shared by all the candidates instead of being rebuilt for each one
    contour_sites = None
    degree = soma.skl_graph.degree

    for node1_id, node2_id, edge_nfo in soma.skl_graph.edges.data('as_edge_t'):
        node1_is_leaf = degree[node1_id] == 1
        # Only the edges with (at least) one degree-1 end node can carry a root
        if not (node1_is_leaf or degree[node2_id] == 1):
            continue

        # Find the pixels of the terminal extension. np_.unique returns sorted values, hence [-1] is the
        # largest label along the edge (the original author notes that ext_nfo['lmp'] does not contain the
        # connexions, so the extension uid is the only positive label expected there).
        sites = ext_nfo['lmp'][edge_nfo.sites]
        ext_uid = np_.unique(sites)[-1]
        if ext_uid not in primary_extension_uids:
            continue

        # Root node candidate: the degree-1 end node.
        # NOTE(review): when both end nodes have degree 1, only node1 is tested.
        root_node = node1_id if node1_is_leaf else node2_id

        # Get the node coordinates and extend them to the 26 neighboring voxels
        root_x, root_y, root_z = GetNodesCoordinates((root_node,))[0]  # 'x-y-z' -> (x, y, z)
        root_sites = {(root_x + i, root_y + j, root_z + k) for i, j, k in neighbor_offsets}

        if contour_sites is None:
            contour_sites = set(soma.contour_points)

        # Keep the candidate only if its neighborhood touches the soma contour.
        # Key = ext uid, Value = root node
        if not contour_sites.isdisjoint(root_sites):
            root_nodes[ext_uid] = root_node

    return root_nodes  # TODO: find out why there are less root points than extensions !!


def FindGraphsRootWithNodes(soma: soma_t) -> dict:
    """
    Find the nodes of the soma skeleton graph that are roots of its extensions.

    A node is selected when it has degree 1 and its coordinates are equal to one of the root coordinates
    listed by GetListRoots(soma).

    Returns: dict {extension uid: node label}. If several nodes match the same extension, the last one wins.
    """
    degree_view = soma.skl_graph.degree
    is_leaf = tuple(degree == 1 for _, degree in degree_view)
    labels = tuple(label for label, _ in degree_view)

    # Node labels ('x-y-z') converted to integer coordinates (x, y, z), in the same order as the labels
    positions = GetNodesCoordinates(labels)

    # Pairs (extension uid, root coordinates)
    candidates = GetListRoots(soma)

    root_nodes = {}
    for label, position, leaf in zip(labels, positions, is_leaf):
        if not leaf:
            # Only the degree-1 nodes can be roots
            continue
        for ext_uid, root_position in candidates:
            if root_position == position:
                root_nodes[ext_uid] = label

    return root_nodes


def GetListRoots(soma: soma_t) -> list:
    """
    Pair each root of soma.ext_roots with the uid of the extension found at the same index in soma.extensions.

    Returns: List[Tuple(extension uid, root)], in the order of soma.ext_roots.
    """
    return [(soma.extensions[position].uid, root) for position, root in enumerate(soma.ext_roots)]


def GetNodesCoordinates(node_coord: Tuple[str, ...]) -> list:
    """
    Convert node labels into integer coordinates.

    Input: node labels -> Tuple('x1-y1-z1', 'x2-y2-z2', ...) .
    Output: coordinates -> List[Tuple(x1,y1,z1), Tuple(x2,y2,z2), ...], in the same order as the input.

    Only the first three runs of digits of each label are used; any further one is ignored. Signs are not
    parsed: any non-digit character, '-' included, acts as a separator.
    Raises IndexError if a label contains fewer than three runs of digits.
    """
    # Raw string: '\d' is an invalid escape sequence in a plain string literal
    find_numbers = re_.compile(r"\d+").findall

    coordinates = []
    for label in node_coord:
        numbers = find_numbers(label)
        # Explicit indexing (rather than slicing) so that a malformed label fails loudly
        coordinates.append((int(numbers[0]), int(numbers[1]), int(numbers[2])))

    return coordinates


# Number of features describing one set of extensions: 8 on lengths, 6 on thickness and 6 on volume
_N_FEATURES_PER_EXTENSION_SET = 20
# Number of features describing the degrees of the nodes: min, mean, median, max, std
_N_DEGREE_FEATURES = 5


def _SummaryStatistics(values) -> list:
    """
    Return [min, mean, median, max, standard deviation, entropy] of a non-empty sequence of values.
    The entropy is the one computed by scipy.stats.entropy on the raw values.
    """
    return [
        min(values),
        np_.mean(values),
        np_.median(values),
        max(values),
        np_.std(values),
        st_.entropy(values),
    ]


def _ExtensionSetFeatures(lengths: list, reduced_widths, number_of_bins: int, hist_range: tuple) -> list:
    """
    Return the features of a set of extensions, in the order of the "*_P" and "*_S" columns:
    [total length, 6 statistics of the lengths, histogram of the lengths,
    6 statistics of the thickness, 6 statistics of the volume].

    The thickness of an extension is its squared reduced width, and its volume is length x thickness.
    """
    thickness = np_.array(reduced_widths) ** 2
    volume = np_.array(lengths) * thickness
    histogram = np_.histogram(lengths, bins=number_of_bins, range=hist_range)[0]

    return (
        [sum(lengths)]
        + _SummaryStatistics(lengths)
        + [histogram]
        + _SummaryStatistics(thickness)
        + _SummaryStatistics(volume)
    )


def ExtractFeaturesInDF(somas, size_voxel_in_micron: list, number_of_bins: int, max_range: float, hist_min_length: float, scale_map: array_t, decimals: int = 4) -> pd_.DataFrame:
    """
    Extract the features from somas and graphs.

    somas: iterable of somas. Each soma is read through its uid, sites, skl_graph and graph_roots attributes.
    size_voxel_in_micron: size of a voxel along each axis. Its product is used as the voxel volume, and its
        element of index 1 is used to scale the widths of the extensions.
    number_of_bins, hist_min_length, max_range: number of bins and (min, max) range of the length histograms.
    scale_map: array indexed with the sites of each edge to get the widths along the edge.
    decimals: passed to in_.ToMicron when converting lengths.

    Returns a pandas dataframe with one row per soma (index "soma <uid>") and one column per feature. The
    features of all the extensions, of the primary ones ("_P") and of the secondary ones ("_S") are set to 0
    when the corresponding set of extensions is empty.

    Side effects: sets soma.volume_soma_micron and soma.axes_ellipsoid, sets the widths of the graph edges and
    calls their SetEndPointDirections method, and prints the node and extension counts of each soma.
    """
    columns = [
        "Coef_V_soma__V_convex_hull",
        "Coef_axes_ellips_y__x",
        "Coef_axes_ellips_z__x",
        #
        "N_nodes",
        "N_ext",
        "N_primary_ext",
        "N_sec_ext",
        "min_degree",
        "mean_degree",
        "median_degree",
        "max_degree",
        "std_degree",
        #
        "total_ext_length",
        "min_length",
        "mean_length",
        "median_length",
        "max_length",
        "std_lengths",
        "entropy_lengths",
        "hist_lengths",
        "min_thickness",
        "mean_thickness",
        "median_thickness",
        "max_thickness",
        "std_thickness",
        "entropy_thickness",
        "min_volume",
        "mean_volume",
        "median_volume",
        "max_volume",
        "std_volume",
        "entropy_volume",
        #
        "total_ext_length_P",
        "min_length_P",
        "mean_length_P",
        "median_length_P",
        "max_length_P",
        "std_lengths_P",
        "entropy_lengths_P",
        "hist_lengths_P",
        "min_thickness_P",
        "mean_thickness_P",
        "median_thickness_P",
        "max_thickness_P",
        "std_thickness_P",
        "entropy_thickness_P",
        "min_volume_P",
        "mean_volume_P",
        "median_volume_P",
        "max_volume_P",
        "std_volume_P",
        "entropy_volume_P",
        #
        "total_ext_length_S",
        "min_length_S",
        "mean_length_S",
        "median_length_S",
        "max_length_S",
        "std_lengths_S",
        "entropy_lengths_S",
        "hist_lengths_S",
        "min_thickness_S",
        "mean_thickness_S",
        "median_thickness_S",
        "max_thickness_S",
        "std_thickness_S",
        "entropy_thickness_S",
        "min_volume_S",
        "mean_volume_S",
        "median_volume_S",
        "max_volume_S",
        "std_volume_S",
        "entropy_volume_S",
    ]

    hist_range = (hist_min_length, max_range)
    # Volume of one voxel: it does not depend on the soma
    volume_pixel_micron = round(np_.prod(size_voxel_in_micron), 4)

    def _ToMicron(length):
        return in_.ToMicron(length, size_voxel_in_micron, decimals=decimals)

    somas_features_dict = {}  # Dict{soma 1: [features], soma 2: [features], ...}

    for soma in somas:
        # -- Soma features
        # # Volume of the soma, relative to the volume of its convex hull
        soma.volume_soma_micron = volume_pixel_micron * len(soma.sites[0])
        volume_convex_hull = volume_pixel_micron * bf_.GetConvexHull3D(soma.sites)[1]
        Coef_V_soma__V_convex_hull = soma.volume_soma_micron / volume_convex_hull

        # # Axes of the best fitting ellipsoid
        soma.axes_ellipsoid = bf_.FindBestFittingEllipsoid3D(soma)[2]
        Coef_axes_ellips_y__x = soma.axes_ellipsoid[1] / soma.axes_ellipsoid[0]
        Coef_axes_ellips_z__x = soma.axes_ellipsoid[2] / soma.axes_ellipsoid[0]

        # -- Extension features
        # # Graph features
        graph = soma.skl_graph
        N_nodes = graph.n_nodes  # number of nodes
        # number of edges except the constructed ones from node soma to the roots
        N_ext = graph.n_edges - len(soma.graph_roots)
        # number of primary edges = linked to the soma except the constructed ones from node soma to the roots
        N_primary_ext = len(soma.graph_roots)
        # number of secondary edges = not linked to the soma
        N_sec_ext = N_ext - N_primary_ext

        print(
            f"\n Soma {soma.uid}\n"
            f"N nodes = {N_nodes}\n"
            f"N edges = {N_ext}\n"
            f"N primary extensions = {N_primary_ext}\n"
            f"N secondary extensions = {N_sec_ext}\n"
        )

        if N_primary_ext > 0:
            # ALL extensions
            # Lengths: min, mean, median, max and total are read from the graph; std, entropy and histogram are
            # computed from the individual edge lengths
            ext_lengths = [_ToMicron(length) for length in graph.edge_lengths]
            total_ext_length = _ToMicron(graph.length)
            hist_lengths = np_.histogram(ext_lengths, bins=number_of_bins, range=hist_range)[0]
            all_lengths_features = [
                total_ext_length,
                _ToMicron(graph.min_length),
                _ToMicron(graph.mean_length),
                _ToMicron(graph.median_length),
                _ToMicron(graph.max_length),
                np_.std(ext_lengths),
                st_.entropy(ext_lengths),
                hist_lengths,
            ]

            # Curvature: only the end point directions are set for now.
            # The constructed edges carry no "as_edge_t" data (None), hence the filtering.
            for _, _, edge in graph.edges.data("as_edge_t"):
                if edge is not None:
                    edge.SetEndPointDirections(size_voxel_in_micron)

            # Find the thickness of the extensions. The widths must be set on the edges before the reduced
            # widths are requested from the graph.
            for _, _, edge in graph.edges.data("as_edge_t"):
                if edge is not None:
                    edge.widths = scale_map[edge.sites] * size_voxel_in_micron[1]
            ext_thickness = np_.array(graph.edge_reduced_widths()) ** 2
            ext_volume = np_.array(ext_lengths) * ext_thickness

            all_features = (
                all_lengths_features
                + _SummaryStatistics(ext_thickness)
                + _SummaryStatistics(ext_volume)
            )

            # PRIMARY extensions
            primary_features = _ExtensionSetFeatures(
                [_ToMicron(length) for length in graph.primary_edge_lengths(soma)],
                graph.P_edge_reduced_widths(soma),
                number_of_bins,
                hist_range,
            )

            if N_sec_ext > 0:
                # min, mean, median, max and standard deviation of the degrees of non-leaves nodes
                degree_features = [
                    graph.min_degree_except_leaves_and_roots,
                    graph.mean_degree_except_leaves_and_roots,
                    graph.median_degree_except_leaves_and_roots,
                    # NOTE(review): "an_roots" (not "and_roots") is the name used by the original code; confirm
                    # against the graph class before renaming.
                    graph.max_degree_except_leaves_an_roots,
                    graph.std_degree_except_leaves_and_roots,
                ]

                # SECONDARY extensions
                secondary_features = _ExtensionSetFeatures(
                    [_ToMicron(length) for length in graph.secondary_edge_lengths(soma)],
                    graph.S_edge_reduced_widths(soma),
                    number_of_bins,
                    hist_range,
                )
            else:
                # No secondary extension. A plain "else" (instead of a test for equality with 0) guarantees
                # that these features are always defined for the current soma, and never inherited from the
                # previous one.
                degree_features = [1, 1, 1, 1, 0]
                secondary_features = [0] * _N_FEATURES_PER_EXTENSION_SET
        else:
            # No extension at all
            degree_features = [0] * _N_DEGREE_FEATURES
            all_features = [0] * _N_FEATURES_PER_EXTENSION_SET
            primary_features = [0] * _N_FEATURES_PER_EXTENSION_SET
            secondary_features = [0] * _N_FEATURES_PER_EXTENSION_SET

        # Same order as columns
        somas_features_dict[f"soma {soma.uid}"] = (
            [
                Coef_V_soma__V_convex_hull,
                Coef_axes_ellips_y__x,
                Coef_axes_ellips_z__x,
                N_nodes,
                N_ext,
                N_primary_ext,
                N_sec_ext,
            ]
            + degree_features
            + all_features
            + primary_features
            + secondary_features
        )

    features_df = pd_.DataFrame.from_dict(somas_features_dict, orient="index", columns=columns)

    return features_df