import os
import torch
import numpy as np
import healpy as hp
from utils import healpix as hp_utils
class HealpixSdpaStructLoader:
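    """Builds and optionally caches k-hop HEALPix neighborhood structures
    (neighbor index, edge weight and validity-mask tensors), either for the
    full sphere or restricted to a single patch, for SDPA-style layers."""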
def __init__(self, weight_type, use_geodesic, use_4connectivity, normalization_method, cutGraphForPatchOutside, load_save_folder=None):
self.weight_type = weight_type
self.use_geodesic = use_geodesic
self.use_4connectivity = use_4connectivity
self.isNest = True
self.folder = load_save_folder
self.normalization_method = normalization_method
self.cutGraph = cutGraphForPatchOutside
if self.folder:
os.makedirs(self.folder, exist_ok=True)
def getStruct(self, sampling_res, num_hops, patch_res=None, patch_id=-1) -> tuple:
if (num_hops is None) or (num_hops <= 0):
num_hops = 1
if self.folder:
filename = f"sdpa_{self.weight_type}_{self.normalization_method}_{self.use_geodesic}_{self.use_4connectivity}_{sampling_res}_{num_hops}"
filename += f"_{patch_res}_{patch_id}_{self.cutGraph}"
filename += ".pth"
file_address = os.path.join(self.folder, filename)
if os.path.isfile(file_address):
data_dict = torch.load(file_address)
index = data_dict.get("index", None)
weight = data_dict.get("weight", None)
valid_neighbors = data_dict.get("mask_valid", None)
if patch_res is None:
return index, weight, valid_neighbors
nodes = data_dict.get("nodes", None)
mapping = data_dict.get("mapping", None)
return index, weight, valid_neighbors, nodes, mapping
        if patch_res is None:
            nside = hp.order2nside(sampling_res)  # == 2 ** sampling_res
            nPix = hp.nside2npix(nside)
            pixel_id = np.arange(nPix)  # build the structure for every pixel on the sphere
            index, weight, valid_neighbors = hp_utils.k_hop_healpix_weightmatrix(resolution=sampling_res,
                                                                                 weight_type=self.weight_type,
                                                                                 use_geodesic=self.use_geodesic,
                                                                                 use_4=self.use_4connectivity,
                                                                                 nodes_id=pixel_id,
                                                                                 dtype=np.float32,
                                                                                 nest=self.isNest,
                                                                                 num_hops=num_hops,
                                                                                 )
# print("weights before=", weight[:10,:])
# print("valid neighbor before=", valid_neighbors[:10, :])
index, weight = self.__normalize(index, weight, valid_neighbors, self.normalization_method)
# print("after=", weight[:10, :])
# print("valid neighbor after=", valid_neighbors[:10, :])
# index = torch.from_numpy(index)
# weight = torch.from_numpy(weight)
# valid_neighbors = torch.from_numpy(valid_neighbors)
# index[~valid_neighbors] = 0
# weight[~valid_neighbors] = 0
if self.folder:
torch.save({"index": index, "weight": weight, "mask_valid": valid_neighbors}, file_address)
return index, weight, valid_neighbors
        # For the patch-based case we temporarily disable normalization on the
        # full-sphere structure, because normalization must be applied per patch.
        tmp_norm = self.normalization_method
        self.normalization_method = "non"
        index, weight, valid_neighbors = self.getStruct(sampling_res=sampling_res, num_hops=num_hops)
        self.normalization_method = tmp_norm  # restore the original normalization
n_patches, nPix_per_patch = self.getPatchesInfo(sampling_res, patch_res)
assert (patch_id >=0) and (patch_id < n_patches), f"patch_id={patch_id} is not in valid range [0, {n_patches})"
# https://github.com/rusty1s/pytorch_geometric/issues/1205
# https://github.com/rusty1s/pytorch_geometric/issues/973
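        # The two issues above describe the subgraph-remapping pattern used below:
        # select the patch rows, collect the unique neighbor ids as `nodes`, and
        # remap `index` entries to positions within `nodes` via the inverse
        # indices returned by `unique(return_inverse=True)`.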
interested_nodes = torch.arange(nPix_per_patch * patch_id, nPix_per_patch * (patch_id + 1), dtype=torch.long)
        if self.cutGraph:
            index = index.narrow(dim=0, start=nPix_per_patch * patch_id, length=nPix_per_patch).detach().clone()
            weight = weight.narrow(dim=0, start=nPix_per_patch * patch_id, length=nPix_per_patch).detach().clone()
            valid_neighbors = valid_neighbors.narrow(dim=0, start=nPix_per_patch * patch_id, length=nPix_per_patch).detach().clone()
            # A neighbor is kept only if it also lies inside the patch: border
            # nodes lose their connectivity with the outside of the patch.
            valid_neighbors &= (index >= nPix_per_patch * patch_id) & (index < nPix_per_patch * (patch_id + 1))
            index -= nPix_per_patch * patch_id
            nodes = interested_nodes
            mapping = None
        else:
            tmp_valid = valid_neighbors.narrow(dim=0, start=nPix_per_patch * patch_id, length=nPix_per_patch).detach().clone()
            nodes, inv = index.narrow(dim=0, start=nPix_per_patch * patch_id, length=nPix_per_patch)[tmp_valid].unique(return_inverse=True)
            mapping = (nodes.unsqueeze(1) == interested_nodes).nonzero()[:, 0]
            index = index.index_select(dim=0, index=nodes)  # keep only rows for the retained nodes
            weight = weight.index_select(dim=0, index=nodes)
            valid_neighbors = torch.zeros(len(nodes), valid_neighbors.size(1), dtype=torch.bool)
            valid_neighbors[mapping, :] = tmp_valid
            index[valid_neighbors] = inv
# print("before=", weight[:10, :])
# print("valid neighbor before=", valid_neighbors[:10, :])
index, weight = self.__normalize(index, weight, valid_neighbors, self.normalization_method)
# print("after=", weight[:10, :])
# print("valid neighbor after=", valid_neighbors[:10, :])
# index[~valid_neighbors] = 0
# weight[~valid_neighbors] = 0
if self.folder:
torch.save({"index": index,
"weight": weight,
"mask_valid": valid_neighbors,
"nodes": nodes,
"mapping": mapping},
file_address)
return index, weight, valid_neighbors, nodes, mapping
def __normalize(self, index, weight, valid_neighbors, normalization_method):
        assert normalization_method in ["non", "sym", "sym8", "sym_neighbors", "global_directional_avg"], "normalization_method not defined"
if not isinstance(index, torch.Tensor):
index = torch.from_numpy(index)
if not isinstance(weight, torch.Tensor):
weight = torch.from_numpy(weight)
if not isinstance(valid_neighbors, torch.Tensor):
valid_neighbors = torch.from_numpy(valid_neighbors)
index[~valid_neighbors] = 0
weight[~valid_neighbors] = 0
if normalization_method == "non":
return index, weight
if normalization_method == "sym":
weight.div_(weight.sum(dim=1, keepdim=True))
elif normalization_method == "sym8":
weight.div_(weight.sum(dim=1, keepdim=True))
weight *= 8
elif normalization_method == "sym_neighbors":
n_neighbors = valid_neighbors.sum(dim=1, keepdim=True)
weight.div_(weight.sum(dim=1, keepdim=True))
weight.mul_(n_neighbors)
elif normalization_method == "global_directional_avg":
for col in range(weight.shape[1]):
weight_col = weight[:, col]
weight_col.div_(weight_col.sum())
if self.weight_type == "distance":
weight_col = 2. - weight_col
raise NotImplementedError("Not sure about it")
return index, weight
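    # Normalization example: with "sym_neighbors", a row whose weights sum to s
    # and that has n valid neighbors is rescaled by n / s, so each row sums to
    # its own neighbor count; plain "sym" makes every row sum to 1.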
    def getPatchesInfo(self, sampling_res, patch_res):
        nside = hp.order2nside(sampling_res)  # == 2 ** sampling_res
        if (patch_res is None) or (patch_res < 0):  # a negative value means the whole sphere is desired
            return 1, hp.nside2npix(nside)
        assert patch_res <= sampling_res, "patch_res cannot be greater than sampling_res"
patch_width = hp.order2nside(patch_res)
nPix_per_patch = patch_width * patch_width
nside_patch = nside // patch_width
n_patches = hp.nside2npix(nside_patch)
return n_patches, nPix_per_patch
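    # Worked example: sampling_res=4 gives nside=16, i.e. 12 * 16**2 = 3072
    # pixels; patch_res=2 gives patch_width=4 and nPix_per_patch=16, hence
    # nside_patch=4 and n_patches = 12 * 4**2 = 192, covering 192 * 16 = 3072
    # pixels, the whole sphere.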
def getLayerStructUpsampling(self, scaling_factor_upsampling, hop_upsampling, resolution, patch_resolution=None, patch_id=-1, inputHopFromDownsampling=None):
# print("starting unsampling graph construction", flush=True)
        assert len(scaling_factor_upsampling) == len(hop_upsampling), "scaling-factor and hop lists must have the same length"
nconv_layers = len(scaling_factor_upsampling)
list_sampling_res_conv, list_patch_res_conv = [[None] * nconv_layers for i in range(2)]
list_sampling_res_conv[0] = resolution
list_patch_res_conv[0] = patch_resolution
patching = False
if (patch_resolution is not None) and (patch_id != -1) and (patch_resolution > 0):
patching = True
for l in range(1, nconv_layers):
list_sampling_res_conv[l] = hp_utils.healpix_getResolutionUpsampled(list_sampling_res_conv[l-1], scaling_factor_upsampling[l-1])
if patching:
list_patch_res_conv[l] = hp_utils.healpix_getResolutionUpsampled(list_patch_res_conv[l-1], scaling_factor_upsampling[l-1])
highest_sampling_res = hp_utils.healpix_getResolutionUpsampled(list_sampling_res_conv[-1], scaling_factor_upsampling[-1])
if patching:
highest_patch_res = hp_utils.healpix_getResolutionUpsampled(list_patch_res_conv[-1], scaling_factor_upsampling[-1])
list_index, list_weight, list_mapping_upsampling = [[None] * nconv_layers for i in range(3)]
K = hop_upsampling.copy()
if inputHopFromDownsampling is not None:
K[0] += inputHopFromDownsampling
if not patching:
l_first = next((i for i in reversed(range(nconv_layers)) if list_sampling_res_conv[-1] != list_sampling_res_conv[i]), -1) + 1
            aggregated_K = np.sum(K[l_first:])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
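            # e.g. two consecutive layers at one resolution with hops [2, 1]
            # are served by a single structure built with num_hops=3.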
index, weight, _ = self.getStruct(sampling_res=list_sampling_res_conv[-1], num_hops=aggregated_K)
list_index[l_first], list_weight[l_first] = index, weight
for l in reversed(range(nconv_layers - 1)):
if list_sampling_res_conv[l] != list_sampling_res_conv[l+1]:
l_first = next((i for i in reversed(range(l+1)) if list_sampling_res_conv[l] != list_sampling_res_conv[i]), -1) + 1
                    aggregated_K = np.sum(K[l_first:l + 1])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
index, weight, _ = self.getStruct(sampling_res=list_sampling_res_conv[l], num_hops=aggregated_K)
list_index[l_first], list_weight[l_first] = index, weight
return {"list_sampling_res":list_sampling_res_conv, "list_index":list_index, "list_weight":list_weight, "output_sampling_res":highest_sampling_res}
        if self.cutGraph:  # cut the graph at the patch border: border nodes lose their connectivity with the outside of the patch
l_first = next((i for i in reversed(range(nconv_layers)) if list_sampling_res_conv[-1] != list_sampling_res_conv[i]), -1) + 1
            aggregated_K = np.sum(K[l_first:])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
index, weight, _, _, _ = self.getStruct(sampling_res=list_sampling_res_conv[-1], num_hops=aggregated_K, patch_res=list_patch_res_conv[-1], patch_id=patch_id)
list_index[l_first], list_weight[l_first] = index, weight
for l in reversed(range(nconv_layers - 1)):
if list_sampling_res_conv[l] != list_sampling_res_conv[l + 1]:
l_first = next((i for i in reversed(range(l + 1)) if list_sampling_res_conv[l] != list_sampling_res_conv[i]), -1) + 1
                    aggregated_K = np.sum(K[l_first:l + 1])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
index, weight, _, _, _ = self.getStruct(sampling_res=list_sampling_res_conv[l], num_hops=aggregated_K, patch_res=list_patch_res_conv[l], patch_id=patch_id)
list_index[l_first], list_weight[l_first] = index, weight
return {"list_sampling_res": list_sampling_res_conv, "list_patch_res": list_patch_res_conv,
"list_index": list_index, "list_weight": list_weight,
"output_sampling_res": highest_sampling_res, "output_patch_res": highest_patch_res}
# TODO: This part has not been checked for bugs
l_first = next((i for i in reversed(range(nconv_layers)) if list_sampling_res_conv[-1] != list_sampling_res_conv[i]), -1) + 1
        aggregated_K = np.sum(K[l_first:])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
index, weight, _, nodes, mapping = self.getStruct(sampling_res=list_sampling_res_conv[-1], num_hops=aggregated_K, patch_res=list_patch_res_conv[-1], patch_id=patch_id)
        if highest_sampling_res != list_sampling_res_conv[-1]:
            n_bitshift = 2 * (highest_sampling_res - list_sampling_res_conv[-1])
            n_children = 1 << n_bitshift
            mapping = mapping << n_bitshift
            mapping = mapping.unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)
            mapping = mapping.flatten()
list_mapping_upsampling[-1] = mapping
list_index[l_first], list_weight[l_first] = index, weight
for l in reversed(range(nconv_layers-1)):
if list_sampling_res_conv[l] != list_sampling_res_conv[l+1]:
                n_bitshift = 2 * (list_sampling_res_conv[l+1] - list_sampling_res_conv[l])
                parent_nodes = nodes >> n_bitshift
parent_nodes = parent_nodes.unique()
l_first = next((i for i in reversed(range(l+1)) if list_sampling_res_conv[l] != list_sampling_res_conv[i]), -1) + 1
                aggregated_K = np.sum(K[l_first:l+1])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
index, weight, valid_neighbors = self.getStruct(sampling_res=list_sampling_res_conv[l], num_hops=aggregated_K)
index = index.index_select(0, parent_nodes)
weight = weight.index_select(0, parent_nodes)
valid_neighbors = valid_neighbors.index_select(0, parent_nodes)
parent_nodes, inv = index[valid_neighbors].unique(return_inverse=True)
index[valid_neighbors] = inv
index[~valid_neighbors] = 0
weight[~valid_neighbors] = 0
                n_children = 1 << n_bitshift
                generated_children_nodes_next_layer = parent_nodes << n_bitshift
                generated_children_nodes_next_layer = generated_children_nodes_next_layer.unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)
                generated_children_nodes_next_layer = generated_children_nodes_next_layer.flatten()
mapping = (nodes.unsqueeze(1) == generated_children_nodes_next_layer).nonzero()[:, 1]
nodes = parent_nodes
list_mapping_upsampling[l] = mapping
list_index[l_first], list_weight[l_first] = index, weight
# print("ending unsampling graph construction", flush=True)
return {"list_sampling_res": list_sampling_res_conv, "list_patch_res": list_patch_res_conv,
"list_index": list_index, "list_weight": list_weight,
"list_mapping": list_mapping_upsampling,
"input_nodes": nodes,
"output_sampling_res": highest_sampling_res, "output_patch_res": highest_patch_res}
def getLayerStructs(self, scaling_factor_downsampling, hop_downsampling, scaling_factor_upsampling, hop_upsampling, upsampled_resolution, patch_upsampled_resolution=None, patch_id=-1):
        assert len(scaling_factor_downsampling) == len(hop_downsampling), "scaling-factor and hop lists must have the same length"
        nlayers_downsampling = len(scaling_factor_downsampling)
        assert len(scaling_factor_upsampling) == len(hop_upsampling), "scaling-factor and hop lists must have the same length"
patching = False
if (patch_upsampled_resolution is not None) and (patch_id != -1) and (patch_upsampled_resolution > 0):
patching = True
list_downsampling_res_conv, list_downsampling_patch_res_conv = [[None] * nlayers_downsampling for i in range(2)]
list_downsampling_res_conv[0] = upsampled_resolution
list_downsampling_patch_res_conv[0] = patch_upsampled_resolution
for l in range(1, nlayers_downsampling):
list_downsampling_res_conv[l] = hp_utils.healpix_getResolutionDownsampled(list_downsampling_res_conv[l-1], scaling_factor_downsampling[l-1])
if patching:
list_downsampling_patch_res_conv[l] = hp_utils.healpix_getResolutionDownsampled(list_downsampling_patch_res_conv[l-1], scaling_factor_downsampling[l-1])
lowest_sampling_res = hp_utils.healpix_getResolutionDownsampled(list_downsampling_res_conv[-1], scaling_factor_downsampling[-1])
if patching:
lowest_patch_res = hp_utils.healpix_getResolutionDownsampled(list_downsampling_patch_res_conv[-1], scaling_factor_downsampling[-1])
list_index_downsampling, list_weight_downsampling, list_mapping_downsampling = [[None] * nlayers_downsampling for i in range(3)]
lowest_res_aggregated_hop = 0
if list_downsampling_res_conv[-1] == lowest_sampling_res:
l_first = next((i for i in reversed(range(nlayers_downsampling)) if list_downsampling_res_conv[-1] != list_downsampling_res_conv[i]), -1) + 1
lowest_res_aggregated_hop = np.sum(hop_downsampling[l_first:])
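        # When the last downsampling layers already sit at the lowest
        # resolution, their aggregated hops are forwarded to the upsampling
        # branch (inputHopFromDownsampling) and the first upsampling structure
        # is reused for those layers instead of building a separate one.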
if not patching:
dict_graphs = dict()
dict_graphs["upsampling"] = self.getLayerStructUpsampling(scaling_factor_upsampling, hop_upsampling, lowest_sampling_res, inputHopFromDownsampling=lowest_res_aggregated_hop)
l_first = next((i for i in reversed(range(nlayers_downsampling)) if list_downsampling_res_conv[-1] != list_downsampling_res_conv[i]), -1) + 1
if list_downsampling_res_conv[-1] == lowest_sampling_res:
index = dict_graphs["upsampling"]["list_index"][0]
weight = dict_graphs["upsampling"]["list_weight"][0]
else:
                aggregated_K = np.sum(hop_downsampling[l_first:])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
index, weight, _ = self.getStruct(sampling_res=list_downsampling_res_conv[-1], num_hops=aggregated_K)
list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
for l in reversed(range(nlayers_downsampling - 1)):
if list_downsampling_res_conv[l] != list_downsampling_res_conv[l + 1]:
l_first = next((i for i in reversed(range(l + 1)) if list_downsampling_res_conv[l] != list_downsampling_res_conv[i]), -1) + 1
                    aggregated_K = np.sum(hop_downsampling[l_first:l + 1])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
index, weight, _ = self.getStruct(sampling_res=list_downsampling_res_conv[l], num_hops=aggregated_K)
list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
dict_graphs["downsampling"] = {"list_sampling_res":list_downsampling_res_conv, "list_index":list_index_downsampling, "list_weight":list_weight_downsampling}
return dict_graphs
        if self.cutGraph:  # cut the graph at the patch border: border nodes lose their connectivity with the outside of the patch
dict_graphs = dict()
dict_graphs["upsampling"] = self.getLayerStructUpsampling(scaling_factor_upsampling, hop_upsampling, lowest_sampling_res, patch_resolution=lowest_patch_res, patch_id=patch_id, inputHopFromDownsampling=lowest_res_aggregated_hop)
l_first = next((i for i in reversed(range(nlayers_downsampling)) if list_downsampling_res_conv[-1] != list_downsampling_res_conv[i]), -1) + 1
if list_downsampling_res_conv[-1] == lowest_sampling_res:
index = dict_graphs["upsampling"]["list_index"][0]
weight = dict_graphs["upsampling"]["list_weight"][0]
else:
                aggregated_K = np.sum(hop_downsampling[l_first:])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
index, weight, _, _, _ = self.getStruct(sampling_res=list_downsampling_res_conv[-1], num_hops=aggregated_K, patch_res=list_downsampling_patch_res_conv[-1], patch_id=patch_id)
list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
for l in reversed(range(nlayers_downsampling - 1)):
if list_downsampling_res_conv[l] != list_downsampling_res_conv[l + 1]:
l_first = next((i for i in reversed(range(l + 1)) if list_downsampling_res_conv[l] != list_downsampling_res_conv[i]), -1) + 1
                    aggregated_K = np.sum(hop_downsampling[l_first:l + 1])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
index, weight, _, _, _ = self.getStruct(sampling_res=list_downsampling_res_conv[l], num_hops=aggregated_K, patch_res=list_downsampling_patch_res_conv[l], patch_id=patch_id)
list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
_, nPixPerPatch = self.getPatchesInfo(upsampled_resolution, patch_upsampled_resolution)
range_downsampling_input_to_patch = (int(patch_id*nPixPerPatch), int((patch_id+1)*nPixPerPatch))
dict_graphs["downsampling"] = {"list_sampling_res":list_downsampling_res_conv, "list_patch_res":list_downsampling_patch_res_conv,
"list_index": list_index_downsampling, "list_weight": list_weight_downsampling,
"range_downsampling_input_to_patch":range_downsampling_input_to_patch}
return dict_graphs
# TODO: This part has not been checked for bugs
dict_graphs = dict()
dict_graphs["upsampling"] = self.getLayerGraphUpsampling(scaling_factor_upsampling, hop_upsampling, lowest_sampling_res, patch_resolution=lowest_patch_res, patch_id=patch_id, inputHopFromDownsampling=lowest_res_aggregated_hop)
# print("starting downsampling graph construction", flush=True)
nodes = dict_graphs["upsampling"]["input_nodes"]
index = dict_graphs["upsampling"]["list_index"][0]
weight = dict_graphs["upsampling"]["list_weight"][0]
_, nPixPerPatch = self.getPatchesInfo(lowest_sampling_res, lowest_patch_res)
        ind_start = (nodes == patch_id*nPixPerPatch).nonzero().item()  # find the index of the node == patch_id*nPixPerPatch
        # Maybe later this assert check can be removed.
        assert torch.all(torch.eq(nodes.narrow(dim=0, start=ind_start, length=nPixPerPatch), torch.arange(patch_id*nPixPerPatch, (patch_id+1)*nPixPerPatch, dtype=nodes.dtype))), "patch nodes from upsampling must already contain the last-resolution patch nodes in sorted order"
range_downsampling_output_to_patch = (ind_start, ind_start+nPixPerPatch)
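        # This range marks where the requested patch sits inside the (larger)
        # node set carried by the downsampling output.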
        if list_downsampling_res_conv[-1] == lowest_sampling_res:  # the last downsampling conv layer works at the same resolution as the first upsampling conv layer
            l_first = next((i for i in reversed(range(nlayers_downsampling)) if list_downsampling_res_conv[-1] != list_downsampling_res_conv[i]), -1) + 1
            list_mapping_downsampling[-1] = None  # resolutions match, so no mapping is needed between the two branches
list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
else:
            n_bitshift = 2 * (list_downsampling_res_conv[-1] - lowest_sampling_res)
            n_children = 1 << n_bitshift
            interested_nodes = nodes << n_bitshift
interested_nodes = interested_nodes.unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)
interested_nodes = interested_nodes.flatten()
l_first = next((i for i in reversed(range(nlayers_downsampling)) if list_downsampling_res_conv[-1] != list_downsampling_res_conv[i]), -1) + 1
            aggregated_K = np.sum(hop_downsampling[l_first:])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
index, weight, valid_neighbors = self.getStruct(sampling_res=list_downsampling_res_conv[-1], num_hops=aggregated_K)
index = index.index_select(0, interested_nodes)
weight = weight.index_select(0, interested_nodes)
valid_neighbors = valid_neighbors.index_select(0, interested_nodes)
nodes, inv = index[valid_neighbors].unique(return_inverse=True)
index[valid_neighbors] = inv
mapping = (nodes.unsqueeze(1) == interested_nodes).nonzero()[:, 0]
index[~valid_neighbors] = 0
weight[~valid_neighbors] = 0
interested_nodes = nodes
list_mapping_downsampling[-1] = mapping
list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
for l in reversed(range(nlayers_downsampling - 1)):
if list_downsampling_res_conv[l] != list_downsampling_res_conv[l + 1]:
                n_bitshift = 2 * (list_downsampling_res_conv[l] - list_downsampling_res_conv[l+1])
                n_children = 1 << n_bitshift
                interested_nodes = nodes << n_bitshift  # children of the current node set at the finer resolution
                interested_nodes = interested_nodes.unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)
                interested_nodes = interested_nodes.flatten()
l_first = next((i for i in reversed(range(l + 1)) if list_downsampling_res_conv[l] != list_downsampling_res_conv[i]), -1) + 1
                aggregated_K = np.sum(hop_downsampling[l_first:l + 1])  # a cascade of conv layers at the same resolution has an effective hop count equal to the sum of the per-layer hops
                index, weight, valid_neighbors = self.getStruct(sampling_res=list_downsampling_res_conv[l], num_hops=aggregated_K)
index = index.index_select(0, interested_nodes)
weight = weight.index_select(0, interested_nodes)
valid_neighbors = valid_neighbors.index_select(0, interested_nodes)
nodes, inv = index[valid_neighbors].unique(return_inverse=True)
index[valid_neighbors] = inv
mapping = (nodes.unsqueeze(1) == interested_nodes).nonzero()[:, 0]
index[~valid_neighbors] = 0
weight[~valid_neighbors] = 0
interested_nodes = nodes
list_mapping_downsampling[l] = mapping
list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
_, nPixPerPatch = self.getPatchesInfo(upsampled_resolution, patch_upsampled_resolution)
        ind_start = (nodes == patch_id * nPixPerPatch).nonzero().item()  # find the index of the node == patch_id*nPixPerPatch
        # Maybe later this assert check can be removed.
        assert torch.all(torch.eq(nodes.narrow(dim=0, start=ind_start, length=nPixPerPatch), torch.arange(patch_id * nPixPerPatch, (patch_id + 1) * nPixPerPatch, dtype=nodes.dtype))), "patch nodes from upsampling must already contain the last-resolution patch nodes in sorted order"
range_downsampling_input_to_patch = (ind_start, ind_start+nPixPerPatch)
# print("ending downsampling graph construction", flush=True)
dict_graphs["downsampling"] = {"list_sampling_res":list_downsampling_res_conv, "list_patch_res":list_downsampling_patch_res_conv,
"list_index":list_index_downsampling, "list_weight":list_weight_downsampling,
"input_nodes":nodes, "list_mapping":list_mapping_downsampling,
"range_downsampling_output_to_patch":range_downsampling_output_to_patch,
"range_downsampling_input_to_patch":range_downsampling_input_to_patch}
return dict_graphs
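

if __name__ == "__main__":
    # Minimal usage sketch, not part of the loader itself. It assumes the
    # `utils.healpix` helpers imported above are available; the parameter
    # values below are illustrative only.
    loader = HealpixSdpaStructLoader(weight_type="distance",
                                     use_geodesic=True,
                                     use_4connectivity=False,
                                     normalization_method="sym",
                                     cutGraphForPatchOutside=True,
                                     load_save_folder=None)
    # Full-sphere structure at resolution 3 with 1-hop neighborhoods.
    index, weight, valid_neighbors = loader.getStruct(sampling_res=3, num_hops=1)
    print(index.shape, weight.shape, valid_neighbors.shape)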