# healpix_sdpa_struct_loader.py
import os
import torch
import numpy as np
import healpy as hp
from utils import healpix as hp_utils
# pyright: reportGeneralTypeIssues=warning

class HealpixSdpaStructLoader:
    def __init__(self, weight_type, use_geodesic, use_4connectivity, normalization_method, cutGraphForPatchOutside, load_save_folder=None):
        self.weight_type = weight_type
        self.use_geodesic = use_geodesic
        self.use_4connectivity = use_4connectivity
        self.isNest = True
        self.folder = load_save_folder
        self.normalization_method = normalization_method
        self.cutGraph = cutGraphForPatchOutside
        if self.folder:
            os.makedirs(self.folder, exist_ok=True)

    def getStruct(self, sampling_res, num_hops, patch_res=None, patch_id=-1) -> tuple:
        if (num_hops is None) or (num_hops <= 0):
            num_hops = 1

        if self.folder:
            filename = f"sdpa_{self.weight_type}_{self.normalization_method}_{self.use_geodesic}_{self.use_4connectivity}_{sampling_res}_{num_hops}"
            if patch_res:
                filename += f"_{patch_res}_{patch_id}_{self.cutGraph}"
            filename += ".pth"
            file_address = os.path.join(self.folder, filename)
            if os.path.isfile(file_address):
                # print(f"Loading file {file_address}")
                data_dict = torch.load(file_address)
                index = data_dict.get("index", None)
                weight = data_dict.get("weight", None)
                valid_neighbors = data_dict.get("mask_valid", None)
                if patch_res is None:
                    return index, weight, valid_neighbors
                nodes = data_dict.get("nodes", None)
                mapping = data_dict.get("mapping", None)
                return index, weight, valid_neighbors, nodes, mapping

        if patch_res is None:
            nside = hp.order2nside(sampling_res)  # == 2 ** sampling_res
            nPix = hp.nside2npix(nside)
            pixel_id = np.arange(0, nPix, dtype=int)

            index, weight, valid_neighbors = hp_utils.k_hop_healpix_weightmatrix(resolution=sampling_res,
                                                                                 weight_type=self.weight_type,
                                                                                 use_geodesic=self.use_geodesic,
                                                                                 use_4=self.use_4connectivity,
                                                                                 nodes_id=pixel_id,
                                                                                 dtype=np.float32,
                                                                                 nest=self.isNest,
                                                                                 num_hops=num_hops,
                                                                                 )
            # print("weights before=", weight[:10,:])
            # print("valid neighbor before=", valid_neighbors[:10, :])
            index, weight = self.__normalize(index, weight, valid_neighbors, self.normalization_method)
            # print("after=", weight[:10, :])
            # print("valid neighbor after=", valid_neighbors[:10, :])
            # index = torch.from_numpy(index)
            # weight = torch.from_numpy(weight)
            # valid_neighbors = torch.from_numpy(valid_neighbors)
            # index[~valid_neighbors] = 0
            # weight[~valid_neighbors] = 0
            if self.folder:
                print(f"Saving file {file_address}")
                torch.save({"index": index, "weight": weight, "mask_valid": valid_neighbors}, file_address)

            return index, weight, valid_neighbors

        # For the patch-based case, temporarily deactivate normalization for the whole sphere,
        # because we want the normalization to be applied per patch.
        tmp_norm = self.normalization_method
        self.normalization_method = "non"
        index, weight, valid_neighbors = self.getStruct(sampling_res=sampling_res, num_hops=num_hops)
        self.normalization_method = tmp_norm  # restore the original normalization method

        n_patches, nPix_per_patch = self.getPatchesInfo(sampling_res, patch_res)
        assert (patch_id >= 0) and (patch_id < n_patches), f"patch_id={patch_id} is not in valid range [0, {n_patches})"

        # https://github.com/rusty1s/pytorch_geometric/issues/1205
        # https://github.com/rusty1s/pytorch_geometric/issues/973
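        # In NESTED ordering, the pixels of patch `patch_id` occupy the contiguous
        # pixel-index range [nPix_per_patch * patch_id, nPix_per_patch * (patch_id + 1)).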
        interested_nodes = torch.arange(nPix_per_patch * patch_id, nPix_per_patch * (patch_id + 1), dtype=torch.long)



        if self.cutGraph:
            index = index.narrow(dim=0, start=nPix_per_patch * patch_id, length=nPix_per_patch).detach().clone()
            weight = weight.narrow(dim=0, start=nPix_per_patch * patch_id, length=nPix_per_patch).detach().clone()
            valid_neighbors = ((index >= nPix_per_patch * patch_id) & (index < nPix_per_patch * (patch_id + 1))).detach().clone()
            index -= nPix_per_patch * patch_id

            nodes = interested_nodes
            mapping = None
        else:
            tmp_valid = valid_neighbors.narrow(dim=0, start=nPix_per_patch * patch_id, length=nPix_per_patch).clone().detach()
            nodes, inv = index.narrow(dim=0, start=nPix_per_patch * patch_id, length=nPix_per_patch)[tmp_valid].unique(return_inverse=True)
            mapping = (nodes.unsqueeze(1) == interested_nodes).nonzero()[:, 0]
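            # `mapping[i]` is the row in the sorted, unique `nodes` tensor that holds
            # the i-th patch pixel, so patch-local rows can be located inside the
            # extended (patch + neighborhood) subgraph.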
            index = index.index_select(dim=0, index=nodes)
            weight = weight.index_select(dim=0, index=nodes)
            valid_neighbors = torch.zeros(len(nodes), valid_neighbors.size(1), dtype=torch.bool)
            valid_neighbors[mapping, :] = tmp_valid
            index[valid_neighbors] = inv

        # print("before=", weight[:10, :])
        # print("valid neighbor before=", valid_neighbors[:10, :])
        index, weight = self.__normalize(index, weight, valid_neighbors, self.normalization_method)
        # print("after=", weight[:10, :])
        # print("valid neighbor after=", valid_neighbors[:10, :])
        # index[~valid_neighbors] = 0
        # weight[~valid_neighbors] = 0

        if self.folder:
            print(f"Saving file {file_address}")
            torch.save({"index": index,
                        "weight": weight,
                        "mask_valid": valid_neighbors,
                        "nodes": nodes,
                        "mapping": mapping},
                       file_address)
        return index, weight, valid_neighbors, nodes, mapping

    def __normalize(self, index, weight, valid_neighbors, normalization_method):
        assert normalization_method in ['non', 'sym', 'sym8', 'sym_neighbors', 'global_directional_avg'], 'normalization_method not defined'
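        # For example, "sym" divides each row of `weight` by its row sum, so every
        # pixel's neighbor weights sum to 1; "sym8" rescales that sum to 8 (the
        # neighborhood size of interior HEALPix pixels); "sym_neighbors" rescales it
        # to the pixel's actual number of valid neighbors.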

        if not isinstance(index, torch.Tensor):
            index = torch.from_numpy(index)
        if not isinstance(weight, torch.Tensor):
            weight = torch.from_numpy(weight)
        if not isinstance(valid_neighbors, torch.Tensor):
            valid_neighbors = torch.from_numpy(valid_neighbors)

        index[~valid_neighbors] = 0
        weight[~valid_neighbors] = 0

        if normalization_method == "non":
            return index, weight

        if normalization_method == "sym":
            weight.div_(weight.sum(dim=1, keepdim=True))
        elif normalization_method == "sym8":
            weight.div_(weight.sum(dim=1, keepdim=True))
            weight *= 8
        elif normalization_method == "sym_neighbors":
            n_neighbors = valid_neighbors.sum(dim=1, keepdim=True)
            weight.div_(weight.sum(dim=1, keepdim=True))
            weight.mul_(n_neighbors)
        elif normalization_method == "global_directional_avg":
            for col in range(weight.shape[1]):
                weight_col = weight[:, col]
                weight_col.div_(weight_col.sum())
                if self.weight_type == "distance":
                    weight_col = 2. - weight_col
                    raise NotImplementedError("Not sure about it")

        return index, weight


    def getPatchesInfo(self, sampling_res, patch_res):
        nside = hp.order2nside(sampling_res)  # == 2 ** sampling_res

        if (patch_res is None) or (patch_res < 0):   # None or a negative value means the whole sphere is desired
            return 1, hp.nside2npix(nside)

        assert patch_res <= sampling_res, "patch_res cannot be greater than sampling_res"

        patch_width = hp.order2nside(patch_res)
        nPix_per_patch = patch_width * patch_width
        nside_patch = nside // patch_width
        n_patches = hp.nside2npix(nside_patch)
        return n_patches, nPix_per_patch
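    # A worked example (illustrative, assuming standard HEALPix conventions):
    # getPatchesInfo(sampling_res=4, patch_res=2) gives nside=16 and patch_width=4,
    # so each patch covers 4 * 4 = 16 pixels and the sphere splits into
    # hp.nside2npix(16 // 4) = 12 * 4**2 = 192 patches.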


    def getLayerStructUpsampling(self, scaling_factor_upsampling, hop_upsampling, resolution, patch_resolution=None, patch_id=-1, inputHopFromDownsampling=None):
        # print("starting unsampling graph construction", flush=True)
        assert len(scaling_factor_upsampling)==len(hop_upsampling), "list size for scaling factor and hop numbers must be equal"
        nconv_layers = len(scaling_factor_upsampling)
        list_sampling_res_conv, list_patch_res_conv = [[None] * nconv_layers for i in range(2)]
        list_sampling_res_conv[0] = resolution
        list_patch_res_conv[0] = patch_resolution

        patching = False
        if (patch_resolution is not None) and (patch_id != -1) and (patch_resolution > 0):
            patching = True

        for l in range(1, nconv_layers):
            list_sampling_res_conv[l] = hp_utils.healpix_getResolutionUpsampled(list_sampling_res_conv[l-1], scaling_factor_upsampling[l-1])
            if patching:
                list_patch_res_conv[l] = hp_utils.healpix_getResolutionUpsampled(list_patch_res_conv[l-1], scaling_factor_upsampling[l-1])

        highest_sampling_res = hp_utils.healpix_getResolutionUpsampled(list_sampling_res_conv[-1], scaling_factor_upsampling[-1])
        if patching:
            highest_patch_res = hp_utils.healpix_getResolutionUpsampled(list_patch_res_conv[-1], scaling_factor_upsampling[-1])

        list_index, list_weight, list_mapping_upsampling = [[None] * nconv_layers for i in range(3)]

        K = hop_upsampling.copy()
        if inputHopFromDownsampling is not None:
            K[0] += inputHopFromDownsampling

        if not patching:
            l_first = next((i for i in reversed(range(nconv_layers)) if list_sampling_res_conv[-1] != list_sampling_res_conv[i]), -1) + 1
            aggregated_K = np.sum(K[l_first:])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
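            # e.g., two stacked 1-hop conv layers reach 2-hop neighbors, so a single
            # structure with num_hops=2 serves both layers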
            index, weight, _ = self.getStruct(sampling_res=list_sampling_res_conv[-1], num_hops=aggregated_K)
            list_index[l_first], list_weight[l_first] = index, weight
            for l in reversed(range(nconv_layers - 1)):
                if list_sampling_res_conv[l] != list_sampling_res_conv[l+1]:
                    l_first = next((i for i in reversed(range(l+1)) if list_sampling_res_conv[l] != list_sampling_res_conv[i]), -1) + 1
                    aggregated_K = np.sum(K[l_first:l + 1])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
                    index, weight, _ = self.getStruct(sampling_res=list_sampling_res_conv[l], num_hops=aggregated_K)

                    list_index[l_first], list_weight[l_first] = index, weight

            return {"list_sampling_res":list_sampling_res_conv, "list_index":list_index, "list_weight":list_weight, "output_sampling_res":highest_sampling_res}


        if self.cutGraph:    # cut the graph at the patch border: border nodes lose their connectivity to nodes outside the patch
            l_first = next((i for i in reversed(range(nconv_layers)) if list_sampling_res_conv[-1] != list_sampling_res_conv[i]), -1) + 1
            aggregated_K = np.sum(K[l_first:])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
            index, weight, _, _, _ = self.getStruct(sampling_res=list_sampling_res_conv[-1], num_hops=aggregated_K, patch_res=list_patch_res_conv[-1], patch_id=patch_id)
            list_index[l_first], list_weight[l_first] = index, weight
            for l in reversed(range(nconv_layers - 1)):
                if list_sampling_res_conv[l] != list_sampling_res_conv[l + 1]:
                    l_first = next((i for i in reversed(range(l + 1)) if list_sampling_res_conv[l] != list_sampling_res_conv[i]), -1) + 1
                    aggregated_K = np.sum(K[l_first:l + 1])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
                    index, weight, _, _, _ = self.getStruct(sampling_res=list_sampling_res_conv[l], num_hops=aggregated_K, patch_res=list_patch_res_conv[l], patch_id=patch_id)
                    list_index[l_first], list_weight[l_first] = index, weight

            return {"list_sampling_res": list_sampling_res_conv, "list_patch_res": list_patch_res_conv,
                    "list_index": list_index, "list_weight": list_weight,
                    "output_sampling_res": highest_sampling_res, "output_patch_res": highest_patch_res}

        # TODO: This part has not been checked for bugs
        l_first = next((i for i in reversed(range(nconv_layers)) if list_sampling_res_conv[-1] != list_sampling_res_conv[i]), -1) + 1
        aggregated_K = np.sum(K[l_first:])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
        index, weight, _, nodes, mapping = self.getStruct(sampling_res=list_sampling_res_conv[-1], num_hops=aggregated_K, patch_res=list_patch_res_conv[-1], patch_id=patch_id)

        if highest_sampling_res != list_sampling_res_conv[-1]:
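            # In NESTED ordering each resolution increment splits a pixel into 4
            # children (2 index bits per level): parent p at resolution r maps to
            # children 4*p .. 4*p+3 at resolution r+1.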
            n_bitshift = 2 * (highest_sampling_res - list_sampling_res_conv[-1])
            n_children = 1 << n_bitshift
            mapping = mapping << n_bitshift
            mapping = mapping.unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)
            mapping = mapping.flatten()
        list_mapping_upsampling[-1] = mapping
        list_index[l_first], list_weight[l_first] = index, weight

        for l in reversed(range(nconv_layers-1)):
            if list_sampling_res_conv[l] != list_sampling_res_conv[l+1]:
                n_bitshift = 2 * (list_sampling_res_conv[l+1] - list_sampling_res_conv[l])
                parent_nodes = nodes >> n_bitshift
                parent_nodes = parent_nodes.unique()

                l_first = next((i for i in reversed(range(l+1)) if list_sampling_res_conv[l] != list_sampling_res_conv[i]), -1) + 1
                aggregated_K = np.sum(K[l_first:l+1])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
                index, weight, valid_neighbors = self.getStruct(sampling_res=list_sampling_res_conv[l], num_hops=aggregated_K)

                index = index.index_select(0, parent_nodes)
                weight = weight.index_select(0, parent_nodes)
                valid_neighbors = valid_neighbors.index_select(0, parent_nodes)

                parent_nodes, inv = index[valid_neighbors].unique(return_inverse=True)
                index[valid_neighbors] = inv

                index[~valid_neighbors] = 0
                weight[~valid_neighbors] = 0

                n_children = 1 << n_bitshift
                generated_children_nodes_next_layer = parent_nodes << n_bitshift
                generated_children_nodes_next_layer = generated_children_nodes_next_layer.unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)
                generated_children_nodes_next_layer = generated_children_nodes_next_layer.flatten()
                mapping = (nodes.unsqueeze(1) == generated_children_nodes_next_layer).nonzero()[:, 1]

                nodes = parent_nodes

                list_mapping_upsampling[l] = mapping
                list_index[l_first], list_weight[l_first] = index, weight

        # print("ending unsampling graph construction", flush=True)
        return {"list_sampling_res": list_sampling_res_conv, "list_patch_res": list_patch_res_conv,
                "list_index": list_index, "list_weight": list_weight,
                "list_mapping": list_mapping_upsampling,
                "input_nodes": nodes,
                "output_sampling_res": highest_sampling_res, "output_patch_res": highest_patch_res}


    def getLayerStructs(self, scaling_factor_downsampling, hop_downsampling, scaling_factor_upsampling, hop_upsampling, upsampled_resolution, patch_upsampled_resolution=None, patch_id=-1):
        assert len(scaling_factor_downsampling) == len(hop_downsampling), "scale-factor and hop lists must have the same number of layers"
        nlayers_downsampling = len(scaling_factor_downsampling)

        assert len(scaling_factor_upsampling) == len(hop_upsampling), "scale-factor and hop lists must have the same number of layers"


        patching = False
        if (patch_upsampled_resolution is not None) and (patch_id != -1) and (patch_upsampled_resolution > 0):
            patching = True

        list_downsampling_res_conv, list_downsampling_patch_res_conv = [[None] * nlayers_downsampling for i in range(2)]
        list_downsampling_res_conv[0] = upsampled_resolution
        list_downsampling_patch_res_conv[0] = patch_upsampled_resolution

        for l in range(1, nlayers_downsampling):
            list_downsampling_res_conv[l] = hp_utils.healpix_getResolutionDownsampled(list_downsampling_res_conv[l-1], scaling_factor_downsampling[l-1])
            if patching:
                list_downsampling_patch_res_conv[l] = hp_utils.healpix_getResolutionDownsampled(list_downsampling_patch_res_conv[l-1], scaling_factor_downsampling[l-1])

        lowest_sampling_res = hp_utils.healpix_getResolutionDownsampled(list_downsampling_res_conv[-1], scaling_factor_downsampling[-1])
        if patching:
            lowest_patch_res = hp_utils.healpix_getResolutionDownsampled(list_downsampling_patch_res_conv[-1], scaling_factor_downsampling[-1])

        list_index_downsampling, list_weight_downsampling, list_mapping_downsampling = [[None] * nlayers_downsampling for i in range(3)]

        lowest_res_aggregated_hop = 0
        if list_downsampling_res_conv[-1] == lowest_sampling_res:
            l_first = next((i for i in reversed(range(nlayers_downsampling)) if list_downsampling_res_conv[-1] != list_downsampling_res_conv[i]), -1) + 1
            lowest_res_aggregated_hop = np.sum(hop_downsampling[l_first:])

        if not patching:
            dict_graphs = dict()
            dict_graphs["upsampling"] = self.getLayerStructUpsampling(scaling_factor_upsampling, hop_upsampling, lowest_sampling_res, inputHopFromDownsampling=lowest_res_aggregated_hop)
            l_first = next((i for i in reversed(range(nlayers_downsampling)) if list_downsampling_res_conv[-1] != list_downsampling_res_conv[i]), -1) + 1
            if list_downsampling_res_conv[-1] == lowest_sampling_res:
                index = dict_graphs["upsampling"]["list_index"][0]
                weight = dict_graphs["upsampling"]["list_weight"][0]
            else:
                aggregated_K = np.sum(hop_downsampling[l_first:])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
                index, weight, _ = self.getStruct(sampling_res=list_downsampling_res_conv[-1], num_hops=aggregated_K)

            list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
            for l in reversed(range(nlayers_downsampling - 1)):
                if list_downsampling_res_conv[l] != list_downsampling_res_conv[l + 1]:
                    l_first = next((i for i in reversed(range(l + 1)) if list_downsampling_res_conv[l] != list_downsampling_res_conv[i]), -1) + 1
                    aggregated_K = np.sum(hop_downsampling[l_first:l + 1])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
                    index, weight, _ = self.getStruct(sampling_res=list_downsampling_res_conv[l], num_hops=aggregated_K)
                    list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight

            dict_graphs["downsampling"] = {"list_sampling_res":list_downsampling_res_conv, "list_index":list_index_downsampling, "list_weight":list_weight_downsampling}
            return dict_graphs

        if self.cutGraph:  # cut the graph at the patch border: border nodes lose their connectivity to nodes outside the patch
            dict_graphs = dict()
            dict_graphs["upsampling"] = self.getLayerStructUpsampling(scaling_factor_upsampling, hop_upsampling, lowest_sampling_res, patch_resolution=lowest_patch_res, patch_id=patch_id, inputHopFromDownsampling=lowest_res_aggregated_hop)
            l_first = next((i for i in reversed(range(nlayers_downsampling)) if list_downsampling_res_conv[-1] != list_downsampling_res_conv[i]), -1) + 1
            if list_downsampling_res_conv[-1] == lowest_sampling_res:
                index = dict_graphs["upsampling"]["list_index"][0]
                weight = dict_graphs["upsampling"]["list_weight"][0]
            else:
                aggregated_K = np.sum(hop_downsampling[l_first:])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
                index, weight, _, _, _ = self.getStruct(sampling_res=list_downsampling_res_conv[-1], num_hops=aggregated_K, patch_res=list_downsampling_patch_res_conv[-1], patch_id=patch_id)

            list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
            for l in reversed(range(nlayers_downsampling - 1)):
                if list_downsampling_res_conv[l] != list_downsampling_res_conv[l + 1]:
                    l_first = next((i for i in reversed(range(l + 1)) if list_downsampling_res_conv[l] != list_downsampling_res_conv[i]), -1) + 1
                    aggregated_K = np.sum(hop_downsampling[l_first:l + 1])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
                    index, weight, _, _, _ = self.getStruct(sampling_res=list_downsampling_res_conv[l], num_hops=aggregated_K, patch_res=list_downsampling_patch_res_conv[l], patch_id=patch_id)
                    list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight


            _, nPixPerPatch = self.getPatchesInfo(upsampled_resolution, patch_upsampled_resolution)
            range_downsampling_input_to_patch = (int(patch_id*nPixPerPatch), int((patch_id+1)*nPixPerPatch))

            dict_graphs["downsampling"] = {"list_sampling_res":list_downsampling_res_conv, "list_patch_res":list_downsampling_patch_res_conv,
                                           "list_index": list_index_downsampling, "list_weight": list_weight_downsampling,
                                           "range_downsampling_input_to_patch":range_downsampling_input_to_patch}
            return dict_graphs



        # TODO: This part has not been checked for bugs
        dict_graphs = dict()
        dict_graphs["upsampling"] = self.getLayerGraphUpsampling(scaling_factor_upsampling, hop_upsampling, lowest_sampling_res, patch_resolution=lowest_patch_res, patch_id=patch_id, inputHopFromDownsampling=lowest_res_aggregated_hop)

        # print("starting downsampling graph construction", flush=True)

        nodes = dict_graphs["upsampling"]["input_nodes"]
        index = dict_graphs["upsampling"]["list_index"][0]
        weight = dict_graphs["upsampling"]["list_weight"][0]

        _, nPixPerPatch = self.getPatchesInfo(lowest_sampling_res, lowest_patch_res)
        ind_start = (nodes == patch_id*nPixPerPatch).nonzero().item()  # to find index of the node==patch_id*nPixPerPatch
        # Maybe later I can remove the next assert check.
        assert torch.all(torch.eq(nodes.narrow(dim=0, start=ind_start, length=nPixPerPatch), torch.arange(patch_id*nPixPerPatch, (patch_id+1)*nPixPerPatch, dtype=nodes.dtype))), "patch nodes from upsampling must already contain the last-resolution patch nodes in sorted order"
        range_downsampling_output_to_patch = (ind_start, ind_start+nPixPerPatch)


        if list_downsampling_res_conv[-1] == lowest_sampling_res:   # the last downsampling conv layer has the same resolution as the first upsampling conv layer
            l_first = next((i for i in reversed(range(nlayers_downsampling)) if list_downsampling_res_conv[-1] != list_downsampling_res_conv[i]), -1) + 1
            list_mapping_downsampling[-1] = None  # we are in the middle of the network, so no mapping is needed
            list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
        else:
            n_bitshift = 2 * (list_downsampling_res_conv[-1] - lowest_sampling_res)
            n_children = 1 << n_bitshift
            interested_nodes = nodes << n_bitshift
            interested_nodes = interested_nodes.unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)
            interested_nodes = interested_nodes.flatten()

            l_first = next((i for i in reversed(range(nlayers_downsampling)) if list_downsampling_res_conv[-1] != list_downsampling_res_conv[i]), -1) + 1
            aggregated_K = np.sum(hop_downsampling[l_first:])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
            index, weight, valid_neighbors = self.getStruct(sampling_res=list_downsampling_res_conv[-1], num_hops=aggregated_K)

            index = index.index_select(0, interested_nodes)
            weight = weight.index_select(0, interested_nodes)
            valid_neighbors = valid_neighbors.index_select(0, interested_nodes)

            nodes, inv = index[valid_neighbors].unique(return_inverse=True)
            index[valid_neighbors] = inv
            mapping = (nodes.unsqueeze(1) == interested_nodes).nonzero()[:, 0]

            index[~valid_neighbors] = 0
            weight[~valid_neighbors] = 0

            interested_nodes = nodes
            list_mapping_downsampling[-1] = mapping
            list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight

        for l in reversed(range(nlayers_downsampling - 1)):
            if list_downsampling_res_conv[l] != list_downsampling_res_conv[l + 1]:
                n_bitshift = 2 * (list_downsampling_res_conv[l] - list_downsampling_res_conv[l+1])
                n_children = 1 << n_bitshift
                interested_nodes = nodes << n_bitshift
                interested_nodes = interested_nodes.unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)
                interested_nodes = interested_nodes.flatten()

                l_first = next((i for i in reversed(range(l + 1)) if list_downsampling_res_conv[l] != list_downsampling_res_conv[i]), -1) + 1
                aggregated_K = np.sum(hop_downsampling[l_first:l + 1])  # a cascade of conv layers at the same resolution has an effective hop equal to the sum of their hops
                index, weight, valid_neighbors = self.getStruct(sampling_res=list_downsampling_res_conv[l], num_hops=aggregated_K)

                index = index.index_select(0, interested_nodes)
                weight = weight.index_select(0, interested_nodes)
                valid_neighbors = valid_neighbors.index_select(0, interested_nodes)

                nodes, inv = index[valid_neighbors].unique(return_inverse=True)
                index[valid_neighbors] = inv
                mapping = (nodes.unsqueeze(1) == interested_nodes).nonzero()[:, 0]

                index[~valid_neighbors] = 0
                weight[~valid_neighbors] = 0

                interested_nodes = nodes

                list_mapping_downsampling[l] = mapping
                list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight

        _, nPixPerPatch = self.getPatchesInfo(upsampled_resolution, patch_upsampled_resolution)
        ind_start = (nodes == patch_id * nPixPerPatch).nonzero().item()  # to find index of the node==patch_id*nPixPerPatch
        # Maybe later I can remove the next assert check.
        assert torch.all(torch.eq(nodes.narrow(dim=0, start=ind_start, length=nPixPerPatch), torch.arange(patch_id * nPixPerPatch, (patch_id + 1) * nPixPerPatch, dtype=nodes.dtype))), "patch nodes from upsampling must already contain the last-resolution patch nodes in sorted order"
        range_downsampling_input_to_patch = (ind_start, ind_start+nPixPerPatch)

        # print("ending downsampling graph construction", flush=True)
        dict_graphs["downsampling"] = {"list_sampling_res":list_downsampling_res_conv, "list_patch_res":list_downsampling_patch_res_conv,
                                       "list_index":list_index_downsampling, "list_weight":list_weight_downsampling,
                                       "input_nodes":nodes, "list_mapping":list_mapping_downsampling,
                                       "range_downsampling_output_to_patch":range_downsampling_output_to_patch,
                                       "range_downsampling_input_to_patch":range_downsampling_input_to_patch}

        return dict_graphs
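

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file). The argument values
    # below, e.g. weight_type="gaussian", are illustrative assumptions; the set of
    # valid choices depends on utils.healpix.k_hop_healpix_weightmatrix.
    loader = HealpixSdpaStructLoader(weight_type="gaussian",
                                     use_geodesic=True,
                                     use_4connectivity=False,
                                     normalization_method="sym",
                                     cutGraphForPatchOutside=True,
                                     load_save_folder=None)
    # Full-sphere structure at HEALPix order 3 (nside=8, 768 pixels), 1-hop neighborhood.
    index, weight, valid_neighbors = loader.getStruct(sampling_res=3, num_hops=1)
    print(index.shape, weight.shape, valid_neighbors.shape)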