import warnings
import math
import torch
from spherical_models.compression_models import SphereCompressionModel
from spherical_models import SphereGaussianConditional
from compressai.models.utils import update_registered_buffers
from compressai.ans import BufferedRansEncoder, RansDecoder
from spherical_models import SLB_Downsample, SLB_Upsample
from spherical_models.sdpa_conv import sdpaconv_node_n
import numpy as np
# pyright: reportGeneralTypeIssues=warning
# From Ballé's TensorFlow compression examples
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):  # pylint: disable=W0622
    return torch.exp(torch.linspace(math.log(min), math.log(max), levels))
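# With the defaults above, get_scale_table() returns 64 log-spaced values between
# 0.11 and 256, which update() installs as the candidate scales of the
# SphereGaussianConditional entropy model.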


class SphereScaleHyperprior(SphereCompressionModel):
    r"""Scale Hyperprior model for spherical data.

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
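
    Example (illustrative only; the ``conv_name``/``pool_func``/``unpool_func``
    strings must name operators provided by ``spherical_models``)::

        >>> model = SphereScaleHyperprior(128, 192, "SDPAConv", "sum",
        ...                               "max_pool", "nearest")
        >>> offsets = model.get_resOffset()  # resolution offsets of the required graphs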
    """
    def __init__(self, N, M, conv_name, skip_conn_aggr, pool_func, unpool_func,
                 attention: bool = False, activation: str = 'GDN', single_conv: bool = False,
                 context_model: str = '', arrelu: bool = False, arabs: bool = False,
                 arhops: int = 2, **kwargs):
        super().__init__(entropy_bottleneck_channels=N, **kwargs)
        self.attention = attention
        self.arrelu = arrelu
        self.arabs = arabs
        self.arhops = arhops
        a_pos = 1 if attention else -1
        self.activation = activation
        ####################### g_a #######################
        self.g_a = torch.nn.ModuleList()
        self.g_a.append(SLB_Downsample(conv_name, 3, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                        activation=activation,
                                        pool_func=pool_func, pool_size_sqrt=2))
        self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       activation=activation,
                                       pool_func=pool_func, pool_size_sqrt=2, attention_pos=a_pos))
        self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       activation=activation,
                                       pool_func=pool_func, pool_size_sqrt=2))
        # The last layer has no GDN:
        self.g_a.append(SLB_Downsample(conv_name, N, M, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       pool_func=pool_func, pool_size_sqrt=2, attention_pos=a_pos))
        ####################### g_s #######################
        self.g_s = torch.nn.ModuleList()
        a_pos = (0 if attention else -1)
        self.g_s.append(SLB_Upsample(conv_name, M, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     activation=activation, activation_args={"inverse":True},
                                     unpool_func=unpool_func, unpool_size_sqrt=2, attention_pos=a_pos))
        a_pos = (2 if attention else -1)
        self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     activation=activation, activation_args={"inverse": True},
                                     unpool_func=unpool_func, unpool_size_sqrt=2, attention_pos=a_pos))
        self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     activation=activation, activation_args={"inverse": True},
                                     unpool_func=unpool_func, unpool_size_sqrt=2))
        # The last layer has no GDN:
        self.g_s.append(SLB_Upsample(conv_name, N, 3, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     unpool_func=unpool_func, unpool_size_sqrt=2))
        ####################### h_a #######################
        self.h_a = torch.nn.ModuleList()
        # effective hop=1 => num_conv = 1, and there is no downsampling
        self.h_a.append(SLB_Downsample(conv_name, M, N, hop=1, single_conv=single_conv, skip_conn_aggr=None,
                                       activation="ReLU", activation_args={"inplace": True}))
        self.h_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       activation="ReLU", activation_args={"inplace": True},
                                       pool_func=pool_func, pool_size_sqrt=2))
        # The last layer has no ReLU:
        self.h_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       pool_func=pool_func, pool_size_sqrt=2))
        ####################### h_s #######################
        self.h_s = torch.nn.ModuleList()
        self.h_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     activation="ReLU", activation_args={"inplace": True},
                                     unpool_func=unpool_func, unpool_size_sqrt=2))
        self.h_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     activation="ReLU", activation_args={"inplace": True},
                                     unpool_func=unpool_func, unpool_size_sqrt=2))
        # effective hop=1 => num_conv = 1, and there is no upsampling
        self.h_s.append(SLB_Upsample(conv_name, N, M, hop=1, single_conv=single_conv, skip_conn_aggr=None,
                                     activation="ReLU", activation_args={"inplace": True}))
        ################## Context Model ##################
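        # A masked convolution over the quantized latents provides a spatial context
        # that is fused with the hyperprior output by three 1x1 SDPAConv layers
        # (combine_ar_hp) to predict the Gaussian scales, in the spirit of Minnen et
        # al.'s joint autoregressive and hyperprior entropy model.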
        self.context_model = context_model
        if context_model:
            mask = 'full' if context_model=='full' else 'A'
            if arrelu:
                self.autoregressive = SLB_Downsample(conv_name, M, M, hop=arhops, skip_conn_aggr=skip_conn_aggr,
                                     activation="ReLU", activation_args={"inplace": True}, mask=mask)
            else: # default
                self.autoregressive = SLB_Downsample(conv_name, M, M, hop=arhops, skip_conn_aggr=skip_conn_aggr,
                                     activation=None, mask=mask)
            self.combine_ar_hp = torch.nn.Sequential(
                SLB_Downsample('SDPAConv', 2*M, M+256, hop=1, skip_conn_aggr=skip_conn_aggr,
                               activation="ReLU", activation_args={"inplace": True}, conv1x1=True),
                SLB_Downsample('SDPAConv', M+256, M+128, hop=1, skip_conn_aggr=skip_conn_aggr,
                               activation="ReLU", activation_args={"inplace": True}, conv1x1=True),
                SLB_Downsample('SDPAConv', M+128, M, hop=1, skip_conn_aggr=skip_conn_aggr, conv1x1=True)
            )
        ###################################################
        self.gaussian_conditional = SphereGaussianConditional(None)
        self.N = int(N)
        self.M = int(M)
        self._computeResOffset()

    def _computeResOffset(self):
        # compute convolution resolution offset
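        # Each SLB block shifts the icosahedral resolution level when it pools/unpools;
        # the cumulative sums below record, for every block, the offset of the graph
        # (the dict_index/dict_weight key) on which its convolution must be applied.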
        g_a_output = list(np.cumsum([layerBlock.get_output_res_offset() for layerBlock in self.g_a]))
        self._g_a_offset = [self.g_a[0].get_conv_input_res_offset()]
        self._g_a_offset.extend([self.g_a[i].get_conv_input_res_offset() + g_a_output[i-1] for i in range(1, len(self.g_a))])

        h_a_output = list(np.cumsum([layerBlock.get_output_res_offset() for layerBlock in self.h_a]))
        h_a_output = [res+g_a_output[-1] for res in h_a_output]
        self._h_a_offset = [self.h_a[0].get_conv_input_res_offset() + g_a_output[-1]]
        self._h_a_offset.extend([self.h_a[i].get_conv_input_res_offset() + h_a_output[i-1] for i in range(1, len(self.h_a))])

        h_s_output = list(np.cumsum([layerBlock.get_output_res_offset() for layerBlock in self.h_s]))
        h_s_output = [res+h_a_output[-1] for res in h_s_output]
        self._h_s_offset = [self.h_s[0].get_conv_input_res_offset() + h_a_output[-1]]
        self._h_s_offset.extend([self.h_s[i].get_conv_input_res_offset() + h_s_output[i-1] for i in range(1, len(self.h_s))])

        assert h_s_output[-1] == g_a_output[-1], "resolutions do not match"

        g_s_output = list(np.cumsum([layerBlock.get_output_res_offset() for layerBlock in self.g_s]))
        g_s_output = [res+g_a_output[-1] for res in g_s_output]
        self._g_s_offset = [self.g_s[0].get_conv_input_res_offset() + g_a_output[-1]]
        self._g_s_offset.extend([self.g_s[i].get_conv_input_res_offset() + g_s_output[i-1] for i in range(1, len(self.g_s))])

    def get_resOffset(self):
        return set(self._g_a_offset + self._h_a_offset + self._h_s_offset + self._g_s_offset)

    def forward(self, x, dict_index, dict_weight, res, patch_res=None, dict_valid_index=None):
        # x is a tensor of size [batch_size, num_nodes, num_features]
        data_res = res if patch_res is None else (res, patch_res)
        ########### apply g_a ###########
        y = x
        for i in range(len(self.g_a)):
            conv_res = type(data_res)(np.add(data_res, self._g_a_offset[i]))
            if (self.activation.upper()=='RB') or (self.attention and (self.g_a[i].attention_pos > 0)): # for attention after convolution or RBs as nonlinearity
                conv_res_ = type(data_res)(np.add(data_res, self._g_a_offset[i+1] if i<(len(self.g_a)-1) else self._g_s_offset[0]))
                index_, weight_, valid_index_ = dict_index[conv_res_], dict_weight[conv_res_], dict_valid_index[conv_res_] if dict_valid_index is not None else None
            else:
                index_, weight_, valid_index_ = None, None, None
            y = self.g_a[i](y, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None,
                index_=index_, weight_=weight_, valid_index_=valid_index_)
        # print("applying g_a")
        # print("y.mean()=", y.mean(), "x.mean()=", x.mean())
        # print("y.max()=", y.max(), "y.min()=", y.min())
        # print("x.max()=", x.max(), "x.min()=", x.min())
        ########### apply h_a ###########
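        # the hyper-encoder operates on the magnitude of the latents, as in the
        # scale-hyperprior model of Ballé et al. (2018)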
        z = torch.abs(y)
        for i in range(len(self.h_a)):
            conv_res = type(data_res)(np.add(data_res, self._h_a_offset[i]))
            z = self.h_a[i](z, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
        # print("applying h_a")
        # print("z.mean()=", z.mean(), "torch.abs(y).mean()=", torch.abs(y).mean())
        # print("z.max()=", z.max(), "z.min()=", z.min())
        # print("torch.abs(y).max()=", torch.abs(y).max(), "torch.abs(y).min()=", torch.abs(y).min())
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        ########### apply h_s ###########
        scales_hat = z_hat
        for i in range(len(self.h_s)):
            conv_res = type(data_res)(np.add(data_res, self._h_s_offset[i]))
            scales_hat = self.h_s[i](scales_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
        # print("applying h_s")
        # print("scales_hat.mean()=", scales_hat.mean(), "z_hat.mean()=", z_hat.mean())
        # print("scales_hat.max()=", scales_hat.max(), "scales_hat.min()=", scales_hat.min())
        # print("z_hat.max()=", z_hat.max(), "z_hat.min()=", z_hat.min())
        ###### apply context model ######
        skip_quant = bool(self.context_model)
        if skip_quant:  # quantize here rather than inside self.gaussian_conditional.forward()
            y_hat = self.gaussian_conditional.quantize(torch.abs(y) if self.arabs else y, 'noise' if self.gaussian_conditional.training else 'dequantize', means=None)
            conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
            phi_sp = self.autoregressive(y_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
            scales_hat = self.combine_ar_hp(torch.cat([scales_hat, phi_sp], dim=-1))
            # print("applying context model")
            # print("scales_hat.mean()=", scales_hat.mean(), "y_hat.mean()=", y_hat.mean())
            # print("scales_hat.max()=", scales_hat.max(), "scales_hat.min()=", scales_hat.min())
            # print("y_hat.max()=", y_hat.max(), "y_hat.min()=", y_hat.min())
        y_hat, y_likelihoods = self.gaussian_conditional(y if not skip_quant else y_hat, scales_hat, skip_quant=skip_quant)
        ########### apply g_s ###########
        x_hat = y_hat
        for i in range(len(self.g_s)):
            conv_res = type(data_res)(np.add(data_res, self._g_s_offset[i]))
            if (self.activation.upper()=='RB') or (self.attention and (self.g_s[i].attention_pos > 0)): # for attention after convolution
                conv_res_ = type(data_res)(np.add(data_res, self._g_s_offset[i+1] if i<(len(self.g_s)-1) else self._g_a_offset[0]))
                index_, weight_, valid_index_ = dict_index[conv_res_], dict_weight[conv_res_], dict_valid_index[conv_res_] if dict_valid_index is not None else None
            else:
                index_, weight_, valid_index_ = None, None, None
            x_hat = self.g_s[i](x_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None,
                index_=index_, weight_=weight_, valid_index_=valid_index_)
        # print("applying g_s")
        # print("x_hat.mean()=", x_hat.mean(), "y_hat.mean()=", y_hat.mean())
        # print("x_hat.max()=", x_hat.max(), "x_hat.min()=", x_hat.min())
        # print("y_hat.max()=", y_hat.max(), "y_hat.min()=", y_hat.min())
        # with torch.no_grad():
        #     print("input/out mean ratio=", x.mean()/x_hat.mean())

        return {
            'x_hat': x_hat,
            'likelihoods': {
                'y': y_likelihoods,
                'z': z_likelihoods
            },
        }

    def load_state_dict(self, state_dict):
        update_registered_buffers(
            self.gaussian_conditional, "gaussian_conditional",
            ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
            state_dict)
        super().load_state_dict(state_dict)

    def update(self, scale_table=None, force=False):
        if scale_table is None:
            scale_table = get_scale_table()
        updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
        updated |= super().update(force=force)
        return updated

    def compress(self, x, dict_index, dict_weight, res, patch_res=None, dict_valid_index=None):
        if self.context_model:
            if next(self.parameters()).device != torch.device("cpu"):
                warnings.warn(
                    "Inference on GPU is not recommended for the autoregressive "
                    "models (the entropy coder is run sequentially on CPU).",
                    stacklevel=2)
        data_res = res if patch_res is None else (res, patch_res)
        ########### apply g_a ###########
        y = x
        for i in range(len(self.g_a)):
            conv_res = type(data_res)(np.add(data_res, self._g_a_offset[i]))
            if (self.activation.upper()=='RB') or (self.attention and (self.g_a[i].attention_pos > 0)):
                conv_res_ = type(data_res)(np.add(data_res, self._g_a_offset[i+1] if i<(len(self.g_a)-1) else self._g_s_offset[0]))
                index_, weight_, valid_index_ = dict_index[conv_res_], dict_weight[conv_res_], dict_valid_index[conv_res_] if dict_valid_index is not None else None
            else:
                index_, weight_, valid_index_ = None, None, None
            y = self.g_a[i](y, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None,
                index_=index_, weight_=weight_, valid_index_=valid_index_)
        ########### apply h_a ###########
        z = torch.abs(y)
        for i in range(len(self.h_a)):
            conv_res = type(data_res)(np.add(data_res, self._h_a_offset[i]))
            z = self.h_a[i](z, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)

        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[1])
        ########### apply h_s ###########
        scales_hat = z_hat
        for i in range(len(self.h_s)):
            conv_res = type(data_res)(np.add(data_res, self._h_s_offset[i]))
            scales_hat = self.h_s[i](scales_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
        if not self.context_model:
            indexes = self.gaussian_conditional.build_indexes(scales_hat)
            y_strings = self.gaussian_conditional.compress(y, indexes)
        else:
            y_strings = []
            conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
            for i in range(y.size(0)):
                string = self._compress_ar(
                    y[i:i+1],
                    scales_hat[i:i+1],
                    dict_index[conv_res],
                    dict_weight[conv_res],
                    self.autoregressive.skipconn,
                )
                y_strings.append(string)
        return {"strings": [y_strings, z_strings], "shape": z.size()[1]}
    
    def _compress_ar(self, y_hat, params, neighbors_indices, neighbors_weights, skipconn):
        assert len(self.autoregressive.list_conv) in [1,2], "only 1 or 2-hop convolution supported for autoregressive context model"
        neighbors_indices = neighbors_indices.to(y_hat.device)
        neighbors_weights = neighbors_weights.to(y_hat.device)
        
        cdf = self.gaussian_conditional._quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional._cdf_length.tolist()
        offsets = self.gaussian_conditional._offset.tolist()

        encoder = BufferedRansEncoder()
        symbols_list = []
        indexes_list = []

        n_nodes = y_hat.size(1)
        weights = [conv.weight for conv in self.autoregressive.list_conv]
        bias = [conv.bias for conv in self.autoregressive.list_conv]
        # encode node-by-node: the masked convolution only sees neighbors that have
        # already been quantized (causal context)
        for n in range(n_nodes):
            ctx_p = sdpaconv_node_n(
                torch.abs(y_hat) if self.arabs else y_hat,
                self.arhops,
                weights,
                bias,
                n,
                neighbors_indices,
                neighbors_weights,
                skipconn,
                masked=True,
            )
            if self.arrelu:
                ctx_p = torch.nn.functional.relu(ctx_p)
            # combine_ar_hp only uses 1x1 convolutions, so the entropy parameters for
            # node n depend only on that node's hyperprior and context features
            p = params[:, n:n+1, :]
            scales_hat = self.combine_ar_hp(torch.cat((p, ctx_p), dim=-1))
            scales_hat = scales_hat.squeeze(1)

            indexes = self.gaussian_conditional.build_indexes(scales_hat)

            y_q = self.gaussian_conditional.quantize(y_hat[:, n, :], "symbols")
            y_hat[:, n, :] = y_q

            symbols_list.extend(y_q.squeeze().tolist())
            indexes_list.extend(indexes.squeeze().tolist())

        encoder.encode_with_indexes(
            symbols_list,
            indexes_list,
            cdf,
            cdf_lengths,
            offsets,
        )
        string = encoder.flush()
        return string
    
    def decompress(self, strings, shape, dict_index, dict_weight, res, patch_res=None, dict_valid_index=None):
        assert isinstance(strings, list) and len(strings) == 2
        if self.context_model:
            if next(self.parameters()).device != torch.device("cpu"):
                warnings.warn(
                    "Inference on GPU is not recommended for the autoregressive "
                    "models (the entropy coder is run sequentially on CPU).",
                    stacklevel=2,
                )
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)

        data_res = res if patch_res is None else (res, patch_res)
        ########### apply h_s ###########
        scales_hat = z_hat
        for i in range(len(self.h_s)):
            conv_res = type(data_res)(np.add(data_res, self._h_s_offset[i]))
            scales_hat = self.h_s[i](scales_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
        if self.context_model:
            s = 4**2  # h_a pools twice (pool_size_sqrt=2 each), so y has 16x as many nodes as z
            y_hat = torch.zeros(z_hat.size(0), s*shape, self.M, dtype=z_hat.dtype, device=z_hat.device)
            conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
            for i, y_string in enumerate(strings[0]):
                self._decompress_ar(
                    y_string,
                    y_hat[i:i+1],
                    scales_hat[i:i+1],
                    dict_index[conv_res],
                    dict_weight[conv_res],
                    self.autoregressive.skipconn,
                )
        else:
            indexes = self.gaussian_conditional.build_indexes(scales_hat)
            y_hat = self.gaussian_conditional.decompress(strings[0], indexes)
        ########### apply g_s ###########
        x_hat = y_hat
        for i in range(len(self.g_s)):
            conv_res = type(data_res)(np.add(data_res, self._g_s_offset[i]))
            if (self.activation.upper()=='RB') or (self.attention and (self.g_s[i].attention_pos > 0)):
                conv_res_ = type(data_res)(np.add(data_res, self._g_s_offset[i+1] if i<(len(self.g_s)-1) else self._g_a_offset[0]))
                index_, weight_, valid_index_ = dict_index[conv_res_], dict_weight[conv_res_], dict_valid_index[conv_res_] if dict_valid_index is not None else None
            else:
                index_, weight_, valid_index_ = None, None, None
            x_hat = self.g_s[i](x_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None,
                index_=index_, weight_=weight_, valid_index_=valid_index_)
        x_hat = x_hat.clamp_(0, 1)
        return {"x_hat": x_hat}
    
    def _decompress_ar(self, y_string, y_hat, params, neighbors_indices, neighbors_weights, skipconn):
        cdf = self.gaussian_conditional._quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional._cdf_length.tolist()
        offsets = self.gaussian_conditional._offset.tolist()

        neighbors_indices = neighbors_indices.to(y_hat.device)
        neighbors_weights = neighbors_weights.to(y_hat.device)
        
        decoder = RansDecoder()
        decoder.set_stream(y_string)
        
        n_nodes = y_hat.size(1)
        weights = [conv.weight for conv in self.autoregressive.list_conv]
        bias = [conv.bias for conv in self.autoregressive.list_conv]
        # Warning: this is slow due to the auto-regressive nature of the decoding
        for n in range(n_nodes):
            ctx_p = sdpaconv_node_n(
                torch.abs(y_hat) if self.arabs else y_hat,
                self.arhops,
                weights,
                bias,
                n,
                neighbors_indices,
                neighbors_weights,
                skipconn,
                masked=False,
            )
            if self.arrelu:
                ctx_p = torch.nn.functional.relu(ctx_p)
            # combine_ar_hp only uses 1x1 convolutions, so the entropy parameters for
            # node n depend only on that node's hyperprior and context features
            p = params[:, n:n+1, :]
            scales_hat = self.combine_ar_hp(torch.cat((p, ctx_p), dim=-1))

            indexes = self.gaussian_conditional.build_indexes(scales_hat)
            rv = decoder.decode_stream(
                indexes.squeeze().tolist(),
                cdf,
                cdf_lengths,
                offsets,
            )
            rv = torch.Tensor(rv).reshape(1, 1, -1)
            rv = self.gaussian_conditional.dequantize(rv)
            y_hat[:, n:n+1, :] = rv


if __name__ == '__main__':
    ssh = SphereScaleHyperprior(128, 192, "SDPAConv", "sum", "max_pool", "nearest")
    print(ssh.get_resOffset())
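    # A minimal usage sketch (illustrative only): refresh the entropy-coder tables
    # before compress()/decompress(); real use assumes trained weights loaded via
    # load_state_dict() and the per-resolution dict_index/dict_weight graphs built
    # elsewhere in the project.
    ssh.update(force=True)
    print(get_scale_table()[:4])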