Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
import os
import torch
import numpy as np
import healpy as hp
import torch_geometric
from utils import healpix as hp_utils
class HealpixGraphLoader:
    """Build HEALPix pixel-adjacency graphs and optionally cache them on disk.

    A graph is returned as an ``index`` tensor (edge index) and a ``weight``
    tensor (edge weights), taken from ``hp_utils.healpix_weightmatrix`` for the
    whole sphere. Patch variants keep only the pixels of one patch plus their
    k-hop neighbourhood (``torch_geometric.utils.k_hop_subgraph``).

    Pixel ids use HEALPix NESTED ordering (``nest=True``). The code relies on it
    in two ways: a patch is a contiguous range of pixel ids, and parent/child
    pixels across orders are obtained with bit shifts.

    Args:
        weight_type: Forwarded to ``hp_utils.healpix_weightmatrix``.
        use_geodesic: Forwarded to ``hp_utils.healpix_weightmatrix``.
        use_4connectivity: Forwarded to ``hp_utils.healpix_weightmatrix`` as ``use_4``.
        load_save_folder: Optional cache directory (created if missing). Computed
            graphs and hop-to-hop mappings are saved there as ``.pth`` files and
            reloaded on later requests with the same parameters.
    """

    def __init__(self, weight_type, use_geodesic, use_4connectivity, load_save_folder=None):
        self.weight_type = weight_type
        self.use_geodesic = use_geodesic
        self.use_4connectivity = use_4connectivity
        self.isNest = True  # NESTED ordering; the bit-shift arithmetic below depends on it
        self.folder = load_save_folder
        if self.folder:
            os.makedirs(self.folder, exist_ok=True)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _runStart(res_list, l):
        """Return the index of the first layer of the run of equal resolutions ending at layer ``l``."""
        start = l
        while start > 0 and res_list[start - 1] == res_list[l]:
            start -= 1
        return start

    @staticmethod
    def _expandToChildren(values, n_bitshift):
        """Expand every entry ``x`` into the contiguous block ``[x << s, (x << s) + 2**s)``.

        With NESTED pixel ids this yields the children of each pixel at an order
        ``s / 2`` levels finer. Applied to positions, it yields the positions of
        those children in an array where each parent was replaced by its children.
        """
        n_children = 1 << n_bitshift
        expanded = (values << n_bitshift).unsqueeze(1).repeat(1, n_children) + torch.arange(n_children)
        return expanded.flatten()

    def _getRunGraphs(self, res_list, patch_res_list=None, patch_id=None):
        """Fetch one graph per run of consecutive layers that share a resolution.

        Each run's graph is stored at the index of the run's first layer; other
        entries stay None. Runs are visited from the last layer backwards. When
        ``patch_res_list`` is given, patch graphs without halo (``num_hops=0``)
        are fetched instead of whole-sphere graphs.

        Returns:
            ``(list_index, list_weight, nodes)``. ``nodes`` holds the global pixel
            ids of the patch graph fetched for the run containing layer 0, or is
            None for whole-sphere graphs.
        """
        n_layers = len(res_list)
        list_index = [None] * n_layers
        list_weight = [None] * n_layers
        nodes = None
        for l in reversed(range(n_layers)):
            if l < n_layers - 1 and res_list[l] == res_list[l + 1]:
                continue  # same run as layer l + 1, already fetched
            if patch_res_list is None:
                index, weight = self.getGraph(sampling_res=res_list[l])
            else:
                index, weight, nodes, _ = self.getGraph(sampling_res=res_list[l], patch_res=patch_res_list[l],
                                                        num_hops=0, patch_id=patch_id)
            l_first = self._runStart(res_list, l)
            list_index[l_first], list_weight[l_first] = index, weight
        return list_index, list_weight, nodes

    # --------------------------------------------------------------- public API

    def getGraph(self, sampling_res, patch_res=None, num_hops=None, patch_id=None):
        """Return the pixel graph at HEALPix order ``sampling_res``.

        Args:
            sampling_res: HEALPix order of the sampling (nside = 2 ** sampling_res).
            patch_res: None for the whole-sphere graph. Otherwise the order that
                defines the patch size (see ``getPatchesInfo``); 0 is a valid
                patch size (one pixel per patch).
            num_hops: Hops of neighbourhood kept around the patch. Required when
                ``patch_res`` is given.
            patch_id: Index of the patch, in ``[0, n_patches)``.

        Returns:
            ``(index, weight)`` when ``patch_res`` is None. Otherwise
            ``(index, weight, nodes, mapping)``: the subgraph's relabelled edge
            index and weights, the global pixel ids of its nodes, and the
            positions of the patch pixels inside ``nodes``.

        Raises:
            ValueError: ``patch_res`` is given but ``num_hops`` is None.
            AssertionError: ``patch_id`` is out of range.
        """
        file_address = None
        if self.folder:
            filename = "{}_{}_{}_{}".format(self.weight_type, self.use_geodesic, self.use_4connectivity, sampling_res)
            # `is not None` rather than truthiness: patch_res == 0 is a real patch size and
            # must not share the whole-sphere cache file.
            if patch_res is not None:
                filename = filename + "_{}_{}_{}".format(patch_res, num_hops, patch_id)
            filename += ".pth"
            file_address = os.path.join(self.folder, filename)
            if os.path.isfile(file_address):
                data_dict = torch.load(file_address)
                index = data_dict.get("index", None)
                weight = data_dict.get("weight", None)
                if patch_res is None:
                    return index, weight
                return index, weight, data_dict.get("nodes", None), data_dict.get("mapping", None)

        if patch_res is None:
            nside = hp.order2nside(sampling_res)  # == 2 ** sampling_res
            n_pix = hp.nside2npix(nside)
            # np.int was removed in NumPy 1.24; it was only an alias of the builtin int.
            pixel_id = np.arange(0, n_pix, dtype=int)
            graph_structure, _ = hp_utils.healpix_weightmatrix(resolution=sampling_res,
                                                               weight_type=self.weight_type,
                                                               use_geodesic=self.use_geodesic,
                                                               use_4=self.use_4connectivity,
                                                               nodes_id=pixel_id,
                                                               dtype=np.float32,
                                                               nest=self.isNest,
                                                               )
            index = torch.from_numpy(graph_structure[0])
            weight = torch.from_numpy(graph_structure[1])
            if file_address is not None:
                print("Saving file {}".format(file_address))
                torch.save({"index": index, "weight": weight}, file_address)
            return index, weight

        # Patch requested: cut the patch plus its num_hops neighbourhood out of the full graph.
        if num_hops is None:
            raise ValueError("num_hops must be given when we are splitting the graph to patches")
        index, weight = self.getGraph(sampling_res=sampling_res)
        n_patches, nPix_per_patch = self.getPatchesInfo(sampling_res, patch_res)
        assert 0 <= patch_id < n_patches, "patch_id={} is not in valid range [0, {})".format(patch_id, n_patches)
        # In NESTED ordering a patch is a contiguous range of pixel ids.
        # https://github.com/rusty1s/pytorch_geometric/issues/1205
        # https://github.com/rusty1s/pytorch_geometric/issues/973
        interested_nodes = torch.arange(nPix_per_patch * patch_id, nPix_per_patch * (patch_id + 1), dtype=torch.long)
        subset, sub_edge_index, mapping, edge_mask = torch_geometric.utils.k_hop_subgraph(interested_nodes,
                                                                                           edge_index=index,
                                                                                           num_hops=num_hops,
                                                                                           relabel_nodes=True)
        sub_edge_weight = weight[edge_mask]
        if file_address is not None:
            print("Saving file {}".format(file_address))
            torch.save({"index": sub_edge_index,
                        "weight": sub_edge_weight,
                        "nodes": subset,
                        "mapping": mapping},
                       file_address)
        return sub_edge_index, sub_edge_weight, subset, mapping

    def getMapHopToHop(self, sampling_res, patch_res, exterior_hop_number, patch_id, interior_hop_number):
        """Return the positions of the interior-hop patch nodes inside a larger node set.

        Takes the nodes of the patch subgraph built with ``interior_hop_number``
        hops and expands them by ``exterior_hop_number - interior_hop_number``
        further hops on the whole-sphere graph. The result is the ``mapping``
        returned by ``k_hop_subgraph``: where each interior-hop node sits in the
        expanded, relabelled node set.

        Raises:
            AssertionError: ``exterior_hop_number < interior_hop_number``.
        """
        assert exterior_hop_number >= interior_hop_number, \
            "exterior_hop_number must be greater than or equal to interior_hop_number"
        file_address = None
        if self.folder:
            # self.weight_type and self.use_geodesic do not change the mapping, so they are
            # deliberately left out of the cache key. Both hop numbers must be in it.
            filename = "hopToHop_{}_{}_{}_{}_{}_{}".format(self.use_4connectivity, sampling_res, patch_res, patch_id,
                                                          exterior_hop_number, interior_hop_number)
            filename += ".pth"
            file_address = os.path.join(self.folder, filename)
            if os.path.isfile(file_address):
                data_dict = torch.load(file_address)
                return data_dict.get("map_hop_to_hop", None)

        _, _, interested_nodes, _ = self.getGraph(sampling_res=sampling_res, patch_res=patch_res,
                                                  num_hops=interior_hop_number, patch_id=patch_id)
        index, _ = self.getGraph(sampling_res=sampling_res)
        _, _, mapping, _ = torch_geometric.utils.k_hop_subgraph(interested_nodes, edge_index=index,
                                                                num_hops=exterior_hop_number - interior_hop_number,
                                                                relabel_nodes=True)
        if file_address is not None:
            print("Saving hop to hop file {}".format(file_address))
            torch.save({"map_hop_to_hop": mapping}, file_address)
        return mapping

    def getPatchesInfo(self, sampling_res, patch_res):
        """Return ``(n_patches, pixels_per_patch)`` for order ``sampling_res``.

        A patch is a square of ``2**patch_res`` x ``2**patch_res`` pixels. A
        ``patch_res`` of None or a negative value means the whole sphere is one
        patch. Presumably ``patch_res <= sampling_res``; otherwise
        ``nside // patch_width`` is 0. TODO confirm callers guarantee this.
        """
        nside = hp.order2nside(sampling_res)  # == 2 ** sampling_res
        if patch_res is None or patch_res < 0:  # negative value means that the whole sphere is desired
            return 1, hp.nside2npix(nside)
        patch_width = hp.order2nside(patch_res)
        nPix_per_patch = patch_width * patch_width
        nside_patch = nside // patch_width
        n_patches = hp.nside2npix(nside_patch)
        return n_patches, nPix_per_patch

    def getLayerGraphUpsampling(self, scaling_factor_upsampling, hop_upsampling, resolution, patch_resolution=None,
                                patch_id=None, inputHopFromDownsampling=None):
        """Build the graphs of a stack of conv layers whose resolution increases.

        Layer 0 works at order ``resolution``. After layer ``l`` the order is
        raised by ``scaling_factor_upsampling[l]`` via
        ``hp_utils.healpix_getResolutionUpsampled``. Consecutive layers at the
        same order share one graph, stored at the index of the run's first layer;
        other entries stay None.

        There are three modes:

        * No patching: ``patch_resolution`` or ``patch_id`` is None, or
          ``patch_resolution <= 0``. Whole-sphere graphs are used.
        * All hops negative: each patch is cut out without a halo (``num_hops=0``).
        * Otherwise: starting from the output patch, node sets are grown backwards
          layer by layer. Each run keeps (sum of its hops + 1) hops of neighbourhood.

        Args:
            scaling_factor_upsampling: Per-layer scaling factor applied after the layer.
            hop_upsampling: Per-layer hop count of the conv.
            resolution: HEALPix order of layer 0.
            patch_resolution: Patch order at layer 0.
            patch_id: Patch index.
            inputHopFromDownsampling: Extra hops added to layer 0's hop count.

        Returns:
            A dict with keys ``list_sampling_res``, ``list_index``,
            ``list_weight`` and ``output_sampling_res``. Patch modes also include
            ``list_patch_res`` and ``output_patch_res``. The halo mode also
            includes ``list_mapping`` and ``input_nodes``:

            * ``list_mapping[l]`` indexes the upsampled output of layer ``l`` to
              select the nodes used next (the patch pixels for the last layer).
              It is None where the order does not change.
            * ``input_nodes`` are the global pixel ids that layer 0 operates on.
        """
        assert len(scaling_factor_upsampling) == len(hop_upsampling), \
            "list size for scaling factor and hop numbers must be equal"
        nconv_layers = len(scaling_factor_upsampling)
        patching = bool(all(v is not None for v in (patch_resolution, patch_id)) and patch_resolution > 0)

        list_sampling_res_conv = [None] * nconv_layers
        list_patch_res_conv = [None] * nconv_layers
        list_sampling_res_conv[0] = resolution
        list_patch_res_conv[0] = patch_resolution
        for l in range(1, nconv_layers):
            list_sampling_res_conv[l] = hp_utils.healpix_getResolutionUpsampled(list_sampling_res_conv[l - 1],
                                                                                 scaling_factor_upsampling[l - 1])
            if patching:
                list_patch_res_conv[l] = hp_utils.healpix_getResolutionUpsampled(list_patch_res_conv[l - 1],
                                                                                  scaling_factor_upsampling[l - 1])
        highest_sampling_res = hp_utils.healpix_getResolutionUpsampled(list_sampling_res_conv[-1],
                                                                       scaling_factor_upsampling[-1])

        if not patching:
            list_index, list_weight, _ = self._getRunGraphs(list_sampling_res_conv)
            return {"list_sampling_res": list_sampling_res_conv, "list_index": list_index,
                    "list_weight": list_weight, "output_sampling_res": highest_sampling_res}

        highest_patch_res = hp_utils.healpix_getResolutionUpsampled(list_patch_res_conv[-1],
                                                                    scaling_factor_upsampling[-1])

        if all(v < 0 for v in hop_upsampling):
            # Cut the graph at the patch border: border nodes lose their connectivity with outside of the patch.
            list_index, list_weight, _ = self._getRunGraphs(list_sampling_res_conv, list_patch_res_conv, patch_id)
            return {"list_sampling_res": list_sampling_res_conv, "list_patch_res": list_patch_res_conv,
                    "list_index": list_index, "list_weight": list_weight,
                    "output_sampling_res": highest_sampling_res, "output_patch_res": highest_patch_res}

        K = list(hop_upsampling)  # local copy; the caller's sequence is not modified
        if inputHopFromDownsampling is not None:
            K[0] += inputHopFromDownsampling

        list_index = [None] * nconv_layers
        list_weight = [None] * nconv_layers
        list_mapping_upsampling = [None] * nconv_layers

        # Last run: the output patch plus enough halo for every conv in the run.
        # A cascade of conv layers at the same resolution has an effective hop equal to the sum of each hop.
        l_first = self._runStart(list_sampling_res_conv, nconv_layers - 1)
        aggregated_K = np.sum(K[l_first:])
        index, weight, nodes, mapping = self.getGraph(sampling_res=list_sampling_res_conv[-1],
                                                      patch_res=list_patch_res_conv[-1],
                                                      num_hops=aggregated_K + 1, patch_id=patch_id)
        if highest_sampling_res != list_sampling_res_conv[-1]:
            # The final upsampling replaces every node by its children; select the patch pixels' children.
            mapping = self._expandToChildren(mapping, 2 * (highest_sampling_res - list_sampling_res_conv[-1]))
        list_mapping_upsampling[-1] = mapping
        list_index[l_first], list_weight[l_first] = index, weight

        # Earlier runs: take the parents of the next run's nodes and grow them by the run's hops.
        for l in reversed(range(nconv_layers - 1)):
            if list_sampling_res_conv[l] == list_sampling_res_conv[l + 1]:
                continue  # same node set as layer l + 1, no upsampling between them
            n_bitshift = 2 * (list_sampling_res_conv[l + 1] - list_sampling_res_conv[l])
            parent_nodes = nodes >> n_bitshift
            index, weight = self.getGraph(sampling_res=list_sampling_res_conv[l])
            l_first = self._runStart(list_sampling_res_conv, l)
            aggregated_K = np.sum(K[l_first:l + 1])
            parent_nodes, index, _, edge_mask = torch_geometric.utils.k_hop_subgraph(parent_nodes, edge_index=index,
                                                                                      num_hops=aggregated_K + 1,
                                                                                      relabel_nodes=True)
            weight = weight[edge_mask]
            children = self._expandToChildren(parent_nodes, n_bitshift)
            # Position, among the upsampled outputs of layer l, of every node layer l + 1 needs.
            mapping = (nodes.unsqueeze(1) == children).nonzero()[:, 1]
            nodes = parent_nodes
            list_mapping_upsampling[l] = mapping
            list_index[l_first], list_weight[l_first] = index, weight

        return {"list_sampling_res": list_sampling_res_conv, "list_patch_res": list_patch_res_conv,
                "list_index": list_index, "list_weight": list_weight,
                "list_mapping": list_mapping_upsampling,
                "input_nodes": nodes,
                "output_sampling_res": highest_sampling_res, "output_patch_res": highest_patch_res}

    def getLayerGraphs(self, scaling_factor_downsampling, hop_downsampling, scaling_factor_upsampling, hop_upsampling,
                       upsampled_resolution, patch_upsampled_resolution=None, patch_id=None):
        """Build graphs for a downsampling stack followed by an upsampling stack.

        The downsampling stack starts at order ``upsampled_resolution``. After
        layer ``l`` the order is lowered by ``scaling_factor_downsampling[l]``
        via ``hp_utils.healpix_getResolutionDownsampled``. The upsampling stack
        starts at the resulting lowest order; see ``getLayerGraphUpsampling``.
        The same three modes apply: no patching, cut patches (all hops
        negative), and patches with halo.

        Returns:
            ``{"upsampling": <getLayerGraphUpsampling result>, "downsampling": dict}``.

            The downsampling dict has keys ``list_sampling_res``, ``list_index``
            and ``list_weight``. Patch modes add ``list_patch_res`` and
            ``range_downsampling_input_to_patch``. The halo mode also adds:

            * ``input_nodes`` and ``list_mapping``.
            * ``range_downsampling_output_to_patch``: the slice of the upsampling
              input nodes covering the lowest-order patch.
            * ``range_downsampling_input_to_patch``: the slice of the
              downsampling input nodes covering the input patch.
        """
        assert len(scaling_factor_downsampling) == len(hop_downsampling), \
            "number of layers between scale factor and hops must be equal"
        nlayers_downsampling = len(scaling_factor_downsampling)
        assert len(scaling_factor_upsampling) == len(hop_upsampling), \
            "number of layers between scale factor and hops must be equal"
        patching = bool(all(v is not None for v in (patch_upsampled_resolution, patch_id))
                        and patch_upsampled_resolution > 0)

        list_downsampling_res_conv = [None] * nlayers_downsampling
        list_downsampling_patch_res_conv = [None] * nlayers_downsampling
        list_downsampling_res_conv[0] = upsampled_resolution
        list_downsampling_patch_res_conv[0] = patch_upsampled_resolution
        for l in range(1, nlayers_downsampling):
            list_downsampling_res_conv[l] = hp_utils.healpix_getResolutionDownsampled(
                list_downsampling_res_conv[l - 1], scaling_factor_downsampling[l - 1])
            if patching:
                list_downsampling_patch_res_conv[l] = hp_utils.healpix_getResolutionDownsampled(
                    list_downsampling_patch_res_conv[l - 1], scaling_factor_downsampling[l - 1])
        lowest_sampling_res = hp_utils.healpix_getResolutionDownsampled(list_downsampling_res_conv[-1],
                                                                        scaling_factor_downsampling[-1])
        lowest_patch_res = None
        if patching:
            lowest_patch_res = hp_utils.healpix_getResolutionDownsampled(list_downsampling_patch_res_conv[-1],
                                                                         scaling_factor_downsampling[-1])

        assert all(v < 0 for v in hop_downsampling) == all(v < 0 for v in hop_upsampling), \
            "for cutting graph both downsampling and upsamling hops must have negative elements"

        dict_graphs = dict()

        if not patching:
            dict_graphs["upsampling"] = self.getLayerGraphUpsampling(scaling_factor_upsampling, hop_upsampling,
                                                                     lowest_sampling_res)
            list_index_downsampling, list_weight_downsampling, _ = self._getRunGraphs(list_downsampling_res_conv)
            dict_graphs["downsampling"] = {"list_sampling_res": list_downsampling_res_conv,
                                           "list_index": list_index_downsampling,
                                           "list_weight": list_weight_downsampling}
            return dict_graphs

        if all(v < 0 for v in hop_downsampling):
            # Cut the graph at the patch border: border nodes lose their connectivity with outside of the patch.
            dict_graphs["upsampling"] = self.getLayerGraphUpsampling(scaling_factor_upsampling, hop_upsampling,
                                                                     lowest_sampling_res,
                                                                     patch_resolution=lowest_patch_res,
                                                                     patch_id=patch_id)
            list_index_downsampling, list_weight_downsampling, node_ids = self._getRunGraphs(
                list_downsampling_res_conv, list_downsampling_patch_res_conv, patch_id)
            _, nPixPerPatch = self.getPatchesInfo(upsampled_resolution, patch_upsampled_resolution)
            range_downsampling_input_to_patch = (int(patch_id * nPixPerPatch), int((patch_id + 1) * nPixPerPatch))
            # Sanity check: without halo, the input-resolution patch is exactly its contiguous NESTED id range.
            assert torch.all(torch.eq(node_ids, torch.arange(range_downsampling_input_to_patch[0],
                                                             range_downsampling_input_to_patch[1],
                                                             dtype=node_ids.dtype))), "node_ids must match range"
            dict_graphs["downsampling"] = {"list_sampling_res": list_downsampling_res_conv,
                                           "list_patch_res": list_downsampling_patch_res_conv,
                                           "list_index": list_index_downsampling,
                                           "list_weight": list_weight_downsampling,
                                           "range_downsampling_input_to_patch": range_downsampling_input_to_patch}
            return dict_graphs

        # Patches with halo. If the last downsampling run already works at the upsampling's first
        # resolution, its hops must be added to the upsampling's first layer.
        last = nlayers_downsampling - 1
        l_first = self._runStart(list_downsampling_res_conv, last)
        lowest_res_aggregated_hop = 0
        if list_downsampling_res_conv[-1] == lowest_sampling_res:
            lowest_res_aggregated_hop = np.sum(hop_downsampling[l_first:])
        # NOTE(review): if lowest_patch_res <= 0 the upsampling result has no "input_nodes" and the
        # lookup below raises KeyError -- confirm callers never hit this.
        dict_graphs["upsampling"] = self.getLayerGraphUpsampling(scaling_factor_upsampling, hop_upsampling,
                                                                 lowest_sampling_res,
                                                                 patch_resolution=lowest_patch_res, patch_id=patch_id,
                                                                 inputHopFromDownsampling=lowest_res_aggregated_hop)

        nodes = dict_graphs["upsampling"]["input_nodes"]
        index = dict_graphs["upsampling"]["list_index"][0]
        weight = dict_graphs["upsampling"]["list_weight"][0]
        _, nPixPerPatch = self.getPatchesInfo(lowest_sampling_res, lowest_patch_res)
        ind_start = (nodes == patch_id * nPixPerPatch).nonzero().item()  # index of the node == patch_id*nPixPerPatch
        assert torch.all(torch.eq(nodes.narrow(dim=0, start=ind_start, length=nPixPerPatch),
                                  torch.arange(patch_id * nPixPerPatch, (patch_id + 1) * nPixPerPatch,
                                               dtype=nodes.dtype))), \
            "patch nodes from upsampling must already contains last resolution patch nodes in a sorted order"
        range_downsampling_output_to_patch = (ind_start, ind_start + nPixPerPatch)

        list_index_downsampling = [None] * nlayers_downsampling
        list_weight_downsampling = [None] * nlayers_downsampling
        list_mapping_downsampling = [None] * nlayers_downsampling

        if list_downsampling_res_conv[-1] == lowest_sampling_res:
            # Last downsampling layer has the same order as the first upsampling layer: share its
            # graph, and no mapping is needed.
            list_mapping_downsampling[-1] = None
            list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight
        else:
            nodes = self._expandToChildren(nodes, 2 * (list_downsampling_res_conv[-1] - lowest_sampling_res))
            index, weight = self.getGraph(sampling_res=list_downsampling_res_conv[-1])
            # A cascade of conv layers at the same resolution has an effective hop equal to the sum of each hop.
            aggregated_K = np.sum(hop_downsampling[l_first:])
            nodes, index, mapping, edge_mask = torch_geometric.utils.k_hop_subgraph(nodes, edge_index=index,
                                                                                    num_hops=aggregated_K + 1,
                                                                                    relabel_nodes=True)
            weight = weight[edge_mask]
            list_mapping_downsampling[-1] = mapping
            list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight

        for l in reversed(range(nlayers_downsampling - 1)):
            if list_downsampling_res_conv[l] == list_downsampling_res_conv[l + 1]:
                continue  # same node set as layer l + 1
            nodes = self._expandToChildren(nodes, 2 * (list_downsampling_res_conv[l] - list_downsampling_res_conv[l + 1]))
            index, weight = self.getGraph(sampling_res=list_downsampling_res_conv[l])
            l_first = self._runStart(list_downsampling_res_conv, l)
            aggregated_K = np.sum(hop_downsampling[l_first:l + 1])
            nodes, index, mapping, edge_mask = torch_geometric.utils.k_hop_subgraph(nodes, edge_index=index,
                                                                                    num_hops=aggregated_K + 1,
                                                                                    relabel_nodes=True)
            weight = weight[edge_mask]
            list_mapping_downsampling[l] = mapping
            list_index_downsampling[l_first], list_weight_downsampling[l_first] = index, weight

        _, nPixPerPatch = self.getPatchesInfo(upsampled_resolution, patch_upsampled_resolution)
        ind_start = (nodes == patch_id * nPixPerPatch).nonzero().item()  # index of the node == patch_id*nPixPerPatch
        assert torch.all(torch.eq(nodes.narrow(dim=0, start=ind_start, length=nPixPerPatch),
                                  torch.arange(patch_id * nPixPerPatch, (patch_id + 1) * nPixPerPatch,
                                               dtype=nodes.dtype))), \
            "patch nodes from upsampling must already contains last resolution patch nodes in a sorted order"
        range_downsampling_input_to_patch = (ind_start, ind_start + nPixPerPatch)

        dict_graphs["downsampling"] = {"list_sampling_res": list_downsampling_res_conv,
                                       "list_patch_res": list_downsampling_patch_res_conv,
                                       "list_index": list_index_downsampling,
                                       "list_weight": list_weight_downsampling,
                                       "input_nodes": nodes, "list_mapping": list_mapping_downsampling,
                                       "range_downsampling_output_to_patch": range_downsampling_output_to_patch,
                                       "range_downsampling_input_to_patch": range_downsampling_input_to_patch}
        return dict_graphs