class SDPAConv(torch.nn.Module):
    r"""Sphere Directional and Position-Aware (SDPA) convolution.

    For every node the output is the center term ``x_node @ weight[0]`` plus
    ``sum_k w_k * (x_{nbr_k} @ weight[k])`` over the first ``kernel_size - 1``
    neighbor columns, where ``w_k`` are the supplied neighbor weights.

    Args:
        in_channels (int): number of input features per node.
        out_channels (int): number of output features per node.
        kernel_size (int): number of kernel slots; slot 0 is the center node and
            slots 1..kernel_size-1 map to neighbor columns 0..kernel_size-2.
        node_dim (int): dimension of ``x`` that neighbor indices select from.
            Default 0.
        bias (bool): if True, learn an additive bias of size ``out_channels``.
            Default True.
        mask (str): 'A' (only lower-indexed neighbors, center excluded),
            'B' (only lower-indexed neighbors, center included) or
            'full' (no masking). Default 'full'.

    Raises:
        NotImplementedError: if ``mask`` is not one of 'A', 'B', 'full'.
    """
    def __init__(self, in_channels, out_channels, kernel_size, node_dim=0, bias=True, mask:str='full'):
        if mask not in ['A', 'B', 'full']: raise NotImplementedError("masked convolution either 'A', 'B' or 'full'")
        super(SDPAConv, self).__init__()
        assert node_dim >= 0
        self.node_dim = node_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # forward() reads self.mask, so it must be stored (it was previously only validated).
        self.mask = mask
        # weight[k] is the (in_channels, out_channels) matrix applied for kernel slot k.
        self.weight = torch.nn.Parameter(torch.Tensor(kernel_size, in_channels, out_channels))
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initialize weight and bias following ``torch.nn.Conv2d``'s scheme."""
        # NOTE(review): for a (kernel_size, in_channels, out_channels) tensor PyTorch treats
        # dim 1 as input maps and dims 2+ as the receptive field, so the fan_in used below
        # (and inside kaiming_uniform_) is in_channels * out_channels rather than
        # kernel_size * in_channels. Kept unchanged to preserve the existing initialization.
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, neighbors_indices=None, neighbors_weights=None, valid_index=None):
        r"""Apply the convolution.

        Args:
            x (tensor): node features. The masked and ``valid_index`` paths index
                nodes on dim 1 (``x.size(1)``, ``out[:, rows, :]``), so a 3-D
                ``(batch, nodes, in_channels)`` layout with ``node_dim=1`` is
                presumably expected there — TODO confirm.
            neighbors_indices (tensor|None): neighbor indices, one row per node;
                only the first ``kernel_size - 1`` columns are used. Not needed
                when ``kernel_size == 1``.
            neighbors_weights (tensor|None): weights matching ``neighbors_indices``.
                Not needed when ``kernel_size == 1``.
            valid_index (tensor|None): bool tensor selecting, per column, the rows
                whose neighbor entry is used. Default None (all rows used).

        Returns:
            tensor: convolved features with ``out_channels`` in the last dim.
        """
        # A 1x1 convolution (kernel_size == 1) needs no neighbor information.
        if self.kernel_size > 1:
            # Check for None first so a missing tensor yields the intended message.
            assert (neighbors_indices is not None) and (neighbors_weights is not None), "neighbors_indices and _weights must be provided for 2-hop convolution"
            assert (self.kernel_size-1) <= neighbors_weights.size(1), "size does not match"
            neighbors_indices = neighbors_indices[:, :self.kernel_size-1]
            neighbors_weights = neighbors_weights[:, :self.kernel_size-1]
            if valid_index is not None:
                valid_index = valid_index[:, :self.kernel_size-1]
            if self.mask in ['A', 'B']:
                # Zero the weight of every neighbor whose index is not below the node's own index.
                node_ids = torch.arange(x.size(1), device=neighbors_weights.device).view(-1, 1)
                neighbors_weights = torch.mul(neighbors_weights, (neighbors_indices < node_ids))
        if self.mask != 'A':
            # Center node included in the convolution.
            out = torch.matmul(x, self.weight[0])
        else:
            out = torch.zeros(x.size(0), x.size(1), self.out_channels, dtype=x.dtype, device=x.device)
        for k in range(1, self.kernel_size):
            col = k - 1  # kernel slot k uses neighbor column k-1
            if valid_index is None:
                s = torch.mul(neighbors_weights.narrow(dim=1, start=col, length=1), x.index_select(self.node_dim, neighbors_indices[:, col]))
                out += torch.matmul(s, self.weight[k])
            else:
                valid_rows = valid_index[:, col]
                s = torch.mul(neighbors_weights[valid_rows, col].view(-1, 1), x.index_select(self.node_dim, neighbors_indices[valid_rows, col]))
                out[:, valid_rows, :] += torch.matmul(s, self.weight[k])
        if self.bias is not None:
            out += self.bias
        return out
def sdpaconv(x, weights:list, biases=None, skip_conn_aggr:str='sum', mask:str='full', neighbors_indices=None, neighbors_weights=None, valid_index=None):
    r"""SDPA convolution stacked over ``len(weights)`` hops.

    Hop 0 convolves ``x``; every following hop convolves the output of the
    previous hop. The per-hop outputs are then combined with
    ``_sphere_skip_connection`` (or only the last one is returned).

    Args:
        x (tensor): input tensor; nodes are indexed on dim 1.
        weights (list|tuple): weight tensors in list or tuple, len(weights) = n_hops.
            ``weights[i][0]`` is the center matrix of hop ``i`` and
            ``weights[i][k]`` the matrix for neighbor column ``k-1``.
        biases (list|None): bias tensors, len(biases) = n_hops. Default None (no bias)
        skip_conn_aggr (str): skip connection aggregation method from ['', 'cat', 'max', 'sum'].
            '' returns only the last hop's output. Default 'sum'.
        mask (str): mask type from ['A', 'B', 'full'], applied at every hop.
            Default 'full' (no masked convolution).
            'A': masked convolution with center excluded,
            'B': masked convolution with center included,
            'full': no masked convolution.
        neighbors_indices (tensor|None): neighbor indices. Not needed for 1x1 convolution.
            Default None.
        neighbors_weights (tensor|None): neighbor weights. Not needed for 1x1 convolution.
            Default None.
        valid_index (tensor|None): bool tensor of valid neighbors. Not needed for 1x1 convolution.
            Default None.

    Returns:
        tensor: aggregated hop outputs, or the last hop's output.
    """
    node_dim = 1
    kernel_size, num_conv = len(weights[0]), len(weights)
    # A 1x1 convolution (kernel_size == 1) does not need neighbor information.
    if kernel_size > 1:
        assert (neighbors_indices is not None) and (neighbors_weights is not None), "neighbors_indices and _weights must be provided for kernel_size > 1"
        assert (kernel_size-1) == neighbors_weights.size(1), "size does not match"
    if biases is not None: assert len(biases) == len(weights)
    # Move neighbor tensors to x's device before they are combined with anything.
    if neighbors_indices is not None: neighbors_indices = neighbors_indices.to(x.device)
    if neighbors_weights is not None: neighbors_weights = neighbors_weights.to(x.device)
    if valid_index is not None: valid_index = valid_index.to(x.device)
    if mask in ['A', 'B']:
        assert (neighbors_indices is not None) and (neighbors_weights is not None), "neighbors_indices and _weights must be provided for masked convolution"
        # Zero the weight of every neighbor whose index is not below the node's own index.
        node_ids = torch.arange(x.size(1), device=neighbors_weights.device).view(-1, 1)
        neighbors_weights = torch.mul(neighbors_weights, (neighbors_indices < node_ids))
    xs = []
    h = x  # input of the current hop
    for i in range(num_conv):
        w = weights[i]
        if mask != 'A':
            out = torch.matmul(h, w[0])
        else:
            out = torch.zeros(h.size(0), h.size(1), w.size(-1), dtype=h.dtype, device=h.device)
        for k in range(1, kernel_size):
            col = k - 1  # kernel slot k uses neighbor column k-1
            if valid_index is None:
                s = torch.mul(neighbors_weights.narrow(dim=1, start=col, length=1), h.index_select(node_dim, neighbors_indices[:, col]))
                out += torch.matmul(s, w[k])
            else:
                valid_rows = valid_index[:, col]
                s = torch.mul(neighbors_weights[valid_rows, col].view(-1, 1), h.index_select(node_dim, neighbors_indices[valid_rows, col]))
                out[:, valid_rows, :] += torch.matmul(s, w[k])
        if biases is not None: out += biases[i]
        xs.append(out)
        # The next hop convolves this hop's output (cf. hop 2 in sdpaconv_node_n).
        h = out
    return _sphere_skip_connection(xs, skip_conn_aggr) if num_conv and skip_conn_aggr else xs[-1]
def sdpaconv_node_n(x, hops:int, weights:list, bias:list, n, neighbors_indices, neighbors_weights, skipconn=None, masked=False):
    r"""Compute a 1- or 2-hop SDPA convolution at node ``n`` only.

    Args:
        x (tensor): node features indexed as ``x[:, node, :]``.
        hops (int): number of hops, 1 or 2.
        weights (list): ``hops`` weight tensors; ``weights[i][0]`` is the center
            matrix and ``weights[i][1:]`` the per-neighbor matrices of hop ``i``.
        bias (list|None): ``hops`` bias tensors (entries may be None), or None.
        n (int): index of the node to compute.
        neighbors_indices (tensor): neighbor indices, one row per node.
        neighbors_weights (tensor): neighbor weights matching ``neighbors_indices``.
        skipconn (callable|None): aggregates ``[hop1_out, hop2_out]`` when hops == 2.
        masked (bool): if True, the first hop only uses nodes with index < n.

    Returns:
        tensor: ``skipconn(xs)`` when hops == 2 and ``skipconn`` is given,
        otherwise the last hop's output of shape ``(x.size(0), 1, out_channels)``.

    Raises:
        NotImplementedError: if ``hops`` is not 1 or 2.
    """
    if hops not in [1, 2]: raise NotImplementedError("Number of hops must be 1 or 2")
    assert len(weights) == hops, "number of weights must match number of hops"
    if bias is not None: assert len(bias) == hops, "number of biases must match number of hops"
    n_neighbors = neighbors_indices.size(1)
    if masked:
        # First hop: only neighbors with index < n may contribute.
        neighbors_weights_in = torch.mul(neighbors_weights, (neighbors_indices < n))
    else:
        neighbors_weights_in = neighbors_weights
    # Buffers: row 0 of xs[0] is node n; for hops == 2, rows 1.. hold the first-hop
    # outputs of n's neighbors, which the second hop needs.
    if hops == 1:
        xs = [torch.zeros(x.size(0), 1, weights[0].size(-1), dtype=x.dtype, device=x.device)]
    else:
        xs = [torch.zeros(x.size(0), n_neighbors+1 if i == 0 else 1, weights[i].size(-1), dtype=x.dtype, device=x.device) for i in range(len(weights))]
    calc_indices = [n]
    if hops == 2: calc_indices += neighbors_indices[n].tolist()
    for k, k_ind in enumerate(calc_indices):  # current node, then its neighbors
        cur_neighbors = torch.mul(neighbors_weights_in[k_ind].view(1, -1, 1), x[:, neighbors_indices[k_ind].tolist(), :])
        xs[0][:, k, :] = torch.matmul(cur_neighbors.flatten(1, 2), weights[0].data[1:].flatten(0, 1))
        # When masked, a center node contributes only if it precedes n.
        if not (masked and (k_ind >= n)):
            xs[0][:, k, :] += torch.matmul(x[:, k_ind, :], weights[0].data[0])
    # `bias` itself may be None (allowed above), so guard before subscripting it.
    has_bias = bias is not None
    if has_bias and bias[0] is not None:
        xs[0] += bias[0].data
    if hops == 2:  # second convolution, evaluated at node n only
        xs[1][:, 0, :] = torch.matmul(xs[0][:, 0, :], weights[1].data[0])
        s = torch.mul(neighbors_weights[n].view(1, -1, 1), xs[0][:, 1:, :])
        xs[1][:, 0, :] += torch.matmul(s.flatten(1, 2), weights[1].data[1:].flatten(0, 1))
        if has_bias and bias[1] is not None:
            xs[1] += bias[1].data
    # Only keep the current node.
    xs[0] = xs[0][:, 0:1, :]
    out = skipconn(xs) if (hops == 2) and (skipconn is not None) else xs[-1]
    return out
def _sphere_skip_connection(xs, mode:str='sum'):
r"""Aggregates representations across different layers.
Args:
xs (list or tuple): List containing layer-wise representations.
mode (str): Aggregation mode. Can be one of the following: cat, max, sum
"""
assert isinstance(xs, list) or isinstance(xs, tuple)
assert mode in ['cat', 'max', 'sum']
if mode == 'cat': return torch.cat(xs, dim=-1)
elif mode == 'max': return torch.stack(xs, dim=-1).max(dim=-1)[0]
elif mode == 'sum': return torch.stack(xs, dim=-1).sum(dim=-1)