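# Captured from a GitLab blame page; the page chrome is kept as comments so the module parses:
# "Mentions légales du service" (the service's legal notice), "Skip to content",
# "Snippets Groups Projects", file size 28.6 KiB, "Newer Older".
# Blame: Navid.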
import os
import torch
import numpy as np
import healpy as hp
from utils import healpix as hp_utils
# Blame: PaulWawerek-L
# pyright: reportGeneralTypeIssues=warning
# Blame: Navid

class HealpixSdpaStructLoader:
    def __init__(self, weight_type, use_geodesic, use_4connectivity, normalization_method, cutGraphForPatchOutside, load_save_folder=None):
        self.weight_type = weight_type
        self.use_geodesic = use_geodesic
        self.use_4connectivity = use_4connectivity
        self.isNest = True
        self.folder = load_save_folder
        self.normalization_method = normalization_method
        self.cutGraph = cutGraphForPatchOutside
        if self.folder:
            os.makedirs(self.folder, exist_ok=True)

    # Blame: PaulWawerek-L
    def getStruct(self, sampling_res, num_hops, patch_res=None, patch_id=-1) -> tuple:
        """Return the neighbourhood structure of a HEALPix sampling.

        Without ``patch_res`` the structure of the whole sphere at order
        ``sampling_res`` is returned as ``(index, weight, valid_neighbors)``.
        With ``patch_res`` it is restricted to patch ``patch_id`` and
        ``(index, weight, valid_neighbors, nodes, mapping)`` is returned:

        * ``cutGraph`` True: rows are the patch pixels, neighbours outside the
          patch are masked out, ``index`` is relative to the first patch pixel,
          ``nodes`` are the patch pixels and ``mapping`` is ``None``.
        * ``cutGraph`` False: rows are all pixels referenced by the patch
          neighbourhood (``nodes``, sorted), ``index`` points into ``nodes`` and
          ``mapping`` gives the rows of ``nodes`` that belong to the patch (only
          those rows have valid neighbours).

        Weights are normalised with ``self.normalization_method`` (per patch in
        the patch case). When a cache folder was configured, results are loaded
        from / saved to a ``.pth`` file whose name encodes all parameters.
        """
        if (num_hops is None) or (num_hops <= 0):
            num_hops = 1

        if self.folder:
            filename = f"sdpa_{self.weight_type}_{self.normalization_method}_{self.use_geodesic}_{self.use_4connectivity}_{sampling_res}_{num_hops}"
            # `is not None` so that patch_res=0 does not share the whole-sphere cache file.
            if patch_res is not None:
                filename += f"_{patch_res}_{patch_id}_{self.cutGraph}"
            filename += ".pth"
            file_address = os.path.join(self.folder, filename)
            if os.path.isfile(file_address):
                # NOTE: torch.load unpickles the file; only point load_save_folder at trusted caches.
                data_dict = torch.load(file_address)
                index = data_dict.get("index", None)
                weight = data_dict.get("weight", None)
                valid_neighbors = data_dict.get("mask_valid", None)
                if patch_res is None:
                    return index, weight, valid_neighbors
                nodes = data_dict.get("nodes", None)
                mapping = data_dict.get("mapping", None)
                return index, weight, valid_neighbors, nodes, mapping

        if patch_res is None:
            nside = hp.order2nside(sampling_res)  # == 2 ** sampling_res
            nPix = hp.nside2npix(nside)
            pixel_id = np.arange(0, nPix, dtype=int)

            # NOTE(review): the arguments after `resolution=` were lost in the scraped copy of
            # this file and are reconstructed here; confirm them against the signature of
            # utils.healpix.k_hop_healpix_weightmatrix.
            index, weight, valid_neighbors = hp_utils.k_hop_healpix_weightmatrix(resolution=sampling_res,
                                                                                  weight_type=self.weight_type,
                                                                                  use_geodesic=self.use_geodesic,
                                                                                  use_4=self.use_4connectivity,
                                                                                  nodes_id=pixel_id,
                                                                                  dtype=np.float32,
                                                                                  nest=self.isNest,
                                                                                  num_hops=num_hops)
            index, weight = self.__normalize(index, weight, valid_neighbors, self.normalization_method)
            # __normalize only returns index/weight as tensors; make the mask a tensor as well so
            # callers can use tensor ops (narrow / index_select) on every returned value.
            valid_neighbors = torch.as_tensor(valid_neighbors)
            if self.folder:
                print(f"Saving file {file_address}")
                torch.save({"index": index, "weight": weight, "mask_valid": valid_neighbors}, file_address)

            return index, weight, valid_neighbors

        # For patches, fetch the whole sphere un-normalised: normalisation is applied per patch below.
        tmp_norm = self.normalization_method
        self.normalization_method = "non"
        try:
            index, weight, valid_neighbors = self.getStruct(sampling_res=sampling_res, num_hops=num_hops)
        finally:
            self.normalization_method = tmp_norm  # restore the original normalisation even on error
        valid_neighbors = torch.as_tensor(valid_neighbors)  # an older cache file may hold a numpy mask

        n_patches, nPix_per_patch = self.getPatchesInfo(sampling_res, patch_res)
        assert (patch_id >=0) and (patch_id < n_patches), f"patch_id={patch_id} is not in valid range [0, {n_patches})"

        first_pix = nPix_per_patch * patch_id
        end_pix = first_pix + nPix_per_patch
        interested_nodes = torch.arange(first_pix, end_pix, dtype=torch.long)

        if self.cutGraph:
            index = index.narrow(dim=0, start=first_pix, length=nPix_per_patch).detach().clone()
            weight = weight.narrow(dim=0, start=first_pix, length=nPix_per_patch).detach().clone()
            # A neighbour is kept only if it was valid on the sphere AND lies inside the patch.
            # Invalid slots were zeroed (index 0) by the normalisation, so without the first term
            # they would be taken as valid edges to pixel 0 for patch 0.
            valid_neighbors = (valid_neighbors.narrow(dim=0, start=first_pix, length=nPix_per_patch).to(torch.bool)
                               & (index >= first_pix) & (index < end_pix))
            index -= first_pix

            nodes = interested_nodes
            mapping = None
        else:
            tmp_valid = valid_neighbors.narrow(dim=0, start=first_pix, length=nPix_per_patch).clone().detach()
            # Sorted set of every pixel the patch pixels read from, plus each neighbour's position in it.
            nodes, inv = index.narrow(dim=0, start=first_pix, length=nPix_per_patch)[tmp_valid].unique(return_inverse=True)
            # Rows of `nodes` holding the patch pixels.
            mapping = (nodes.unsqueeze(1) == interested_nodes).nonzero()[:, 0]
            index = index.index_select(dim=0, index=nodes)
            weight = weight.index_select(dim=0, index=nodes)
            valid_neighbors = torch.zeros(len(nodes), valid_neighbors.size(1), dtype=torch.bool)
            valid_neighbors[mapping, :] = tmp_valid
            index[valid_neighbors] = inv

        index, weight = self.__normalize(index, weight, valid_neighbors, self.normalization_method)

        if self.folder:
            print(f"Saving file {file_address}")
            torch.save({"index": index,
                        "weight": weight,
                        "mask_valid": valid_neighbors,
                        "nodes": nodes,
                        "mapping": mapping},
                       file_address)
        return index, weight, valid_neighbors, nodes, mapping

    def __normalize(self, index, weight, valid_neighbors, normalization_method):
        assert normalization_method in ['non', 'sym', "sym8", 'sym_neighbors','global_directional_avg'], 'normalization_method not defined'

        if not isinstance(index, torch.Tensor):
            index = torch.from_numpy(index)
        if not isinstance(weight, torch.Tensor):
            weight = torch.from_numpy(weight)
        if not isinstance(valid_neighbors, torch.Tensor):
            valid_neighbors = torch.from_numpy(valid_neighbors)

        index[~valid_neighbors] = 0
        weight[~valid_neighbors] = 0

        if normalization_method == "non":
            return index, weight

        if normalization_method == "sym":
            weight.div_(weight.sum(dim=1, keepdim=True))
        elif normalization_method == "sym8":
            weight.div_(weight.sum(dim=1, keepdim=True))
            weight *= 8
        elif normalization_method == "sym_neighbors":
            n_neighbors = valid_neighbors.sum(dim=1, keepdim=True)
            weight.div_(weight.sum(dim=1, keepdim=True))
        elif normalization_method == "global_directional_avg":
            for col in range(weight.shape[1]):
                weight_col = weight[:, col]
                if self.weight_type == "distance":
                    weight_col = 2. - weight_col
                    raise NotImplementedError("Not sure about it")

        return index, weight

    def getPatchesInfo(self, sampling_res, patch_res):
        """Return ``(n_patches, nPix_per_patch)`` for splitting the sphere into patches.

        A patch at order ``patch_res`` holds ``order2nside(patch_res) ** 2``
        pixels; ``getStruct`` uses patch ``k`` as the contiguous pixel range
        ``[k * nPix_per_patch, (k + 1) * nPix_per_patch)``.
        ``patch_res`` None or negative means "the whole sphere as one patch".

        Raises:
            AssertionError: if ``patch_res > sampling_res``.
        """
        nside = hp.order2nside(sampling_res)  # == 2 ** sampling_resolution

        # Handle the whole-sphere case first: `None <= int` would raise TypeError in the assert below.
        if (patch_res is None) or (patch_res < 0):   # Negative value means that the whole sphere is desired
            return 1, hp.nside2npix(nside)

        assert patch_res <= sampling_res, "patch_res can not be greater than sampling_res"
        patch_width = hp.order2nside(patch_res)
        nPix_per_patch = patch_width * patch_width
        nside_patch = nside // patch_width
        n_patches = hp.nside2npix(nside_patch)
        return n_patches, nPix_per_patch

    # Blame: PaulWawerek-L
    def getLayerStructUpsampling(self, scaling_factor_upsampling, hop_upsampling, resolution, patch_resolution=None, patch_id=-1, inputHopFromDownsampling=None):
        """Build the neighbourhood structures of a cascade of upsampling conv layers.

        Layer ``l`` runs at order ``list_sampling_res[l]``; after it the order is
        raised according to ``scaling_factor_upsampling[l]``. Consecutive layers at
        the same order share one structure whose hop count is the sum of their
        hops; it is stored at the first layer of that run (other entries stay None).

        Args:
            scaling_factor_upsampling: per-layer upsampling factors.
            hop_upsampling: per-layer hop counts (same length; list or tuple).
            resolution: order of the first layer.
            patch_resolution, patch_id: restrict to one patch when
                ``patch_resolution > 0`` and ``patch_id != -1``.
            inputHopFromDownsampling: extra hops added to the first layer.

        Returns:
            dict with the per-layer orders, "list_index", "list_weight" and the
            output order; patch modes add the patch orders, and the non-cut patch
            mode also "list_mapping" and "input_nodes".
        """
        assert len(scaling_factor_upsampling)==len(hop_upsampling), "list size for scaling factor and hop numbers must be equal"
        nconv_layers = len(scaling_factor_upsampling)
        list_sampling_res_conv = [None] * nconv_layers
        list_patch_res_conv = [None] * nconv_layers
        list_sampling_res_conv[0] = resolution
        list_patch_res_conv[0] = patch_resolution

        patching = (patch_resolution is not None) and (patch_id != -1) and (patch_resolution > 0)

        for l in range(1, nconv_layers):
            list_sampling_res_conv[l] = hp_utils.healpix_getResolutionUpsampled(list_sampling_res_conv[l - 1], scaling_factor_upsampling[l - 1])
            if patching:
                list_patch_res_conv[l] = hp_utils.healpix_getResolutionUpsampled(list_patch_res_conv[l - 1], scaling_factor_upsampling[l - 1])

        highest_sampling_res = hp_utils.healpix_getResolutionUpsampled(list_sampling_res_conv[-1], scaling_factor_upsampling[-1])
        if patching:
            highest_patch_res = hp_utils.healpix_getResolutionUpsampled(list_patch_res_conv[-1], scaling_factor_upsampling[-1])

        list_index, list_weight, list_mapping_upsampling = [[None] * nconv_layers for _ in range(3)]

        # list() rather than .copy(): accepts tuples and never modifies the caller's sequence.
        K = list(hop_upsampling)
        if inputHopFromDownsampling is not None:
            K[0] += inputHopFromDownsampling

        def run_start(l):
            # Index of the first layer of the run of equal orders that ends at layer l.
            return next((i for i in reversed(range(l + 1)) if list_sampling_res_conv[l] != list_sampling_res_conv[i]), -1) + 1

        if (not patching) or self.cutGraph:
            # Last layer of every run, visiting the runs from the output side.
            run_ends = [nconv_layers - 1] + [l for l in reversed(range(nconv_layers - 1))
                                             if list_sampling_res_conv[l] != list_sampling_res_conv[l + 1]]
            for l_last in run_ends:
                l_first = run_start(l_last)
                # A cascade of conv layers at the same order has an effective hop equal to the sum of their hops.
                aggregated_K = np.sum(K[l_first:l_last + 1])
                if patching:
                    # cutGraph: border nodes lose their connectivity with the outside of the patch.
                    index, weight, _, _, _ = self.getStruct(sampling_res=list_sampling_res_conv[l_last], num_hops=aggregated_K, patch_res=list_patch_res_conv[l_last], patch_id=patch_id)
                else:
                    index, weight, _ = self.getStruct(sampling_res=list_sampling_res_conv[l_last], num_hops=aggregated_K)
                list_index[l_first], list_weight[l_first] = index, weight

            if not patching:
                return {"list_sampling_res":list_sampling_res_conv, "list_index":list_index, "list_weight":list_weight, "output_sampling_res":highest_sampling_res}
            return {"list_sampling_res": list_sampling_res_conv, "list_patch_res": list_patch_res_conv,
                    "list_index": list_index, "list_weight": list_weight,
                    "output_sampling_res": highest_sampling_res, "output_patch_res": highest_patch_res}

        # TODO: This part has not been checked for bugs
        l_first = run_start(nconv_layers - 1)
        aggregated_K = np.sum(K[l_first:])
        index, weight, _, nodes, mapping = self.getStruct(sampling_res=list_sampling_res_conv[-1], num_hops=aggregated_K, patch_res=list_patch_res_conv[-1], patch_id=patch_id)

        if highest_sampling_res != list_sampling_res_conv[-1]:
            # Expand each mapped position to the positions of its children after the final upsampling.
            n_bitshift = 2 * (highest_sampling_res - list_sampling_res_conv[-1])
            n_children = 1 << n_bitshift
            mapping = ((mapping << n_bitshift).unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)).flatten()
        list_mapping_upsampling[-1] = mapping
        list_index[l_first], list_weight[l_first] = index, weight

        for l in reversed(range(nconv_layers - 1)):
            if list_sampling_res_conv[l] == list_sampling_res_conv[l + 1]:
                continue
            n_bitshift = 2 * (list_sampling_res_conv[l + 1] - list_sampling_res_conv[l])
            # Parents (at order res[l]) of the nodes required by layer l + 1.
            parent_nodes = (nodes >> n_bitshift).unique()

            l_first = run_start(l)
            aggregated_K = np.sum(K[l_first:l + 1])
            index, weight, valid_neighbors = self.getStruct(sampling_res=list_sampling_res_conv[l], num_hops=aggregated_K)
            valid_neighbors = torch.as_tensor(valid_neighbors)  # the mask may come back as a numpy array

            index = index.index_select(0, parent_nodes)
            weight = weight.index_select(0, parent_nodes)
            valid_neighbors = valid_neighbors.index_select(0, parent_nodes)

            # Re-index the neighbours into the sorted set of nodes this layer reads.
            parent_nodes, inv = index[valid_neighbors].unique(return_inverse=True)
            index[valid_neighbors] = inv

            index[~valid_neighbors] = 0
            weight[~valid_neighbors] = 0

            # NOTE(review): the children are generated from the enlarged input set computed just
            # above, not from the rows this layer outputs; confirm this matches the model.
            n_children = 1 << n_bitshift
            generated_children_nodes_next_layer = ((parent_nodes << n_bitshift).unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)).flatten()
            # Position of every node needed by layer l + 1 among the generated children.
            mapping = (nodes.unsqueeze(1) == generated_children_nodes_next_layer).nonzero()[:, 1]

            nodes = parent_nodes

            list_mapping_upsampling[l] = mapping
            list_index[l_first], list_weight[l_first] = index, weight

        return {"list_sampling_res": list_sampling_res_conv, "list_patch_res": list_patch_res_conv,
                "list_index": list_index, "list_weight": list_weight,
                "list_mapping": list_mapping_upsampling,
                "input_nodes": nodes,
                "output_sampling_res": highest_sampling_res, "output_patch_res": highest_patch_res}

    # Blame: PaulWawerek-L
    def getLayerStructs(self, scaling_factor_downsampling, hop_downsampling, scaling_factor_upsampling, hop_upsampling, upsampled_resolution, patch_upsampled_resolution=None, patch_id=-1):
        """Build the structures of a downsampling cascade followed by an upsampling cascade.

        The downsampling layers start at order ``upsampled_resolution`` and are
        reduced according to ``scaling_factor_downsampling``; the upsampling part
        is built by ``getLayerStructUpsampling`` from the lowest order reached.
        Consecutive layers at the same order share one structure (hops summed),
        stored at the first layer of the run. When the last downsampling run is
        already at the lowest order, it is merged with the first upsampling run
        and reuses that run's structure.

        With ``patch_upsampled_resolution > 0`` and ``patch_id != -1`` everything
        is restricted to one patch (cut or non-cut, see ``getStruct``).

        Returns:
            dict with "upsampling" (result of getLayerStructUpsampling) and
            "downsampling" (per-layer orders, "list_index", "list_weight", plus
            patch information in the patch modes).
        """
        assert len(scaling_factor_downsampling) == len(hop_downsampling), "number of layers between scale factor and hops must be equal"
        nlayers_downsampling = len(scaling_factor_downsampling)

        assert len(scaling_factor_upsampling) == len(hop_upsampling), "number of layers between scale factor and hops must be equal"

        patching = (patch_upsampled_resolution is not None) and (patch_id != -1) and (patch_upsampled_resolution > 0)

        list_downsampling_res_conv = [None] * nlayers_downsampling
        list_downsampling_patch_res_conv = [None] * nlayers_downsampling
        list_downsampling_res_conv[0] = upsampled_resolution
        list_downsampling_patch_res_conv[0] = patch_upsampled_resolution

        for l in range(1, nlayers_downsampling):
            list_downsampling_res_conv[l] = hp_utils.healpix_getResolutionDownsampled(list_downsampling_res_conv[l - 1], scaling_factor_downsampling[l - 1])
            if patching:
                list_downsampling_patch_res_conv[l] = hp_utils.healpix_getResolutionDownsampled(list_downsampling_patch_res_conv[l - 1], scaling_factor_downsampling[l - 1])

        lowest_sampling_res = hp_utils.healpix_getResolutionDownsampled(list_downsampling_res_conv[-1], scaling_factor_downsampling[-1])
        if patching:
            lowest_patch_res = hp_utils.healpix_getResolutionDownsampled(list_downsampling_patch_res_conv[-1], scaling_factor_downsampling[-1])

        list_index_downsampling, list_weight_downsampling, list_mapping_downsampling = [[None] * nlayers_downsampling for _ in range(3)]

        def run_start(l):
            # Index of the first layer of the run of equal orders that ends at layer l.
            return next((i for i in reversed(range(l + 1)) if list_downsampling_res_conv[l] != list_downsampling_res_conv[i]), -1) + 1

        last = nlayers_downsampling - 1
        # True when the last downsampling run runs at the same order as the first upsampling layer.
        shares_lowest_res = list_downsampling_res_conv[-1] == lowest_sampling_res

        lowest_res_aggregated_hop = 0
        if shares_lowest_res:
            lowest_res_aggregated_hop = np.sum(hop_downsampling[run_start(last):])

        if (not patching) or self.cutGraph:
            dict_graphs = dict()
            if patching:
                # cutGraph: border nodes lose their connectivity with the outside of the patch.
                dict_graphs["upsampling"] = self.getLayerStructUpsampling(scaling_factor_upsampling, hop_upsampling, lowest_sampling_res, patch_resolution=lowest_patch_res, patch_id=patch_id, inputHopFromDownsampling=lowest_res_aggregated_hop)
            else:
                dict_graphs["upsampling"] = self.getLayerStructUpsampling(scaling_factor_upsampling, hop_upsampling, lowest_sampling_res, inputHopFromDownsampling=lowest_res_aggregated_hop)

            def struct_for_run(l_first, l_last):
                # A cascade of conv layers at the same order has an effective hop equal to the sum of their hops.
                aggregated_K = np.sum(hop_downsampling[l_first:l_last + 1])
                if patching:
                    index, weight, _, _, _ = self.getStruct(sampling_res=list_downsampling_res_conv[l_last], num_hops=aggregated_K, patch_res=list_downsampling_patch_res_conv[l_last], patch_id=patch_id)
                else:
                    index, weight, _ = self.getStruct(sampling_res=list_downsampling_res_conv[l_last], num_hops=aggregated_K)
                return index, weight

            l_first = run_start(last)
            if shares_lowest_res:
                # Merged with the first upsampling run: reuse its structure.
                index = dict_graphs["upsampling"]["list_index"][0]
                weight = dict_graphs["upsampling"]["list_weight"][0]
            else:
                index, weight = struct_for_run(l_first, last)
            list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight

            for l in reversed(range(nlayers_downsampling - 1)):
                if list_downsampling_res_conv[l] != list_downsampling_res_conv[l + 1]:
                    l_first = run_start(l)
                    list_index_downsampling[l_first], list_weight_downsampling[l_first] = struct_for_run(l_first, l)

            if not patching:
                dict_graphs["downsampling"] = {"list_sampling_res":list_downsampling_res_conv, "list_index":list_index_downsampling, "list_weight":list_weight_downsampling}
                return dict_graphs

            _, nPixPerPatch = self.getPatchesInfo(upsampled_resolution, patch_upsampled_resolution)
            range_downsampling_input_to_patch = (int(patch_id * nPixPerPatch), int((patch_id + 1) * nPixPerPatch))

            # NOTE(review): the closing line of this dict was lost in the scraped copy; the range entry
            # is restored from the otherwise-unused local above — confirm the key name with callers.
            dict_graphs["downsampling"] = {"list_sampling_res":list_downsampling_res_conv, "list_patch_res":list_downsampling_patch_res_conv,
                                           "list_index": list_index_downsampling, "list_weight": list_weight_downsampling,
                                           "range_downsampling_input_to_patch": range_downsampling_input_to_patch}
            return dict_graphs

        # TODO: This part has not been checked for bugs
        dict_graphs = dict()
        dict_graphs["upsampling"] = self.getLayerStructUpsampling(scaling_factor_upsampling, hop_upsampling, lowest_sampling_res, patch_resolution=lowest_patch_res, patch_id=patch_id, inputHopFromDownsampling=lowest_res_aggregated_hop)

        nodes = dict_graphs["upsampling"]["input_nodes"]
        index = dict_graphs["upsampling"]["list_index"][0]
        weight = dict_graphs["upsampling"]["list_weight"][0]

        _, nPixPerPatch = self.getPatchesInfo(lowest_sampling_res, lowest_patch_res)
        ind_start = (nodes == patch_id*nPixPerPatch).nonzero().item()  # to find index of the node==patch_id*nPixPerPatch
        # Maybe later I can remove the next assert check.
        assert torch.all(torch.eq(nodes.narrow(dim=0, start=ind_start, length=nPixPerPatch), torch.arange(patch_id*nPixPerPatch, (patch_id+1)*nPixPerPatch, dtype=nodes.dtype))), "patch nodes from upsampling must already contains last resolution patch nodes in a sorted order"
        # Positions of the patch pixels among the upsampling input nodes.
        range_downsampling_output_to_patch = (ind_start, ind_start + nPixPerPatch)

        def children_of(parents, n_bitshift):
            # Children of pixel p are (p << n_bitshift) + 0 .. 2**n_bitshift - 1 (NESTED ordering).
            n_children = 1 << n_bitshift
            return ((parents << n_bitshift).unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)).flatten()

        def restrict_to(interested_nodes, sampling_res, aggregated_K):
            # Whole-sphere structure restricted to the rows `interested_nodes`. Returns
            # (index, weight, nodes, mapping): `nodes` is the sorted set of inputs those rows read,
            # `index` points into `nodes`, `mapping` gives the positions of interested_nodes in `nodes`.
            index, weight, valid_neighbors = self.getStruct(sampling_res=sampling_res, num_hops=aggregated_K)
            valid_neighbors = torch.as_tensor(valid_neighbors)  # the mask may come back as a numpy array

            index = index.index_select(0, interested_nodes)
            weight = weight.index_select(0, interested_nodes)
            valid_neighbors = valid_neighbors.index_select(0, interested_nodes)

            nodes, inv = index[valid_neighbors].unique(return_inverse=True)
            index[valid_neighbors] = inv
            mapping = (nodes.unsqueeze(1) == interested_nodes).nonzero()[:, 0]

            index[~valid_neighbors] = 0
            weight[~valid_neighbors] = 0
            return index, weight, nodes, mapping

        l_first = run_start(last)
        if shares_lowest_res:   # last downsampling conv layer has the same order as the first upsampling conv layer
            list_mapping_downsampling[-1] = None  # we are in the middle of the network, so no mapping is needed
            list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
        else:
            # The last downsampling run must output the children of the upsampling input nodes.
            interested_nodes = children_of(nodes, 2 * (list_downsampling_res_conv[-1] - lowest_sampling_res))
            index, weight, nodes, mapping = restrict_to(interested_nodes, list_downsampling_res_conv[-1], np.sum(hop_downsampling[l_first:]))
            list_mapping_downsampling[-1] = mapping
            list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight

        for l in reversed(range(nlayers_downsampling - 1)):
            if list_downsampling_res_conv[l] != list_downsampling_res_conv[l + 1]:
                l_first = run_start(l)
                # Layer l must output the children of the nodes read by the run after it.
                interested_nodes = children_of(nodes, 2 * (list_downsampling_res_conv[l] - list_downsampling_res_conv[l + 1]))
                index, weight, nodes, mapping = restrict_to(interested_nodes, list_downsampling_res_conv[l], np.sum(hop_downsampling[l_first:l + 1]))
                list_mapping_downsampling[l] = mapping
                list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight

        _, nPixPerPatch = self.getPatchesInfo(upsampled_resolution, patch_upsampled_resolution)
        ind_start = (nodes == patch_id * nPixPerPatch).nonzero().item()  # to find index of the node==patch_id*nPixPerPatch
        # Maybe later I can remove the next assert check.
        assert torch.all(torch.eq(nodes.narrow(dim=0, start=ind_start, length=nPixPerPatch), torch.arange(patch_id * nPixPerPatch, (patch_id + 1) * nPixPerPatch, dtype=nodes.dtype))), "patch nodes from upsampling must already contains last resolution patch nodes in a sorted order"
        # Positions of the patch pixels among the downsampling input nodes.
        range_downsampling_input_to_patch = (ind_start, ind_start + nPixPerPatch)

        # NOTE(review): the closing line of this dict was lost in the scraped copy; the two range
        # entries are restored from the otherwise-unused locals — confirm the key names with callers.
        dict_graphs["downsampling"] = {"list_sampling_res":list_downsampling_res_conv, "list_patch_res":list_downsampling_patch_res_conv,
                                       "list_index":list_index_downsampling, "list_weight":list_weight_downsampling,
                                       "input_nodes":nodes, "list_mapping":list_mapping_downsampling,
                                       "range_downsampling_output_to_patch": range_downsampling_output_to_patch,
                                       "range_downsampling_input_to_patch": range_downsampling_input_to_patch}

        return dict_graphs