Mentions légales du service

Skip to content
Snippets Groups Projects
sdpa_conv.py 9.32 KiB
Newer Older
Navid's avatar
Navid committed
import torch
import math
import torch_geometric
# pyright: reportOptionalSubscript=warning
Navid's avatar
Navid committed

class SDPAConv(torch.nn.Module):
    r"""Sphere Directional and Position-Aware (SDPA) convolution.

    Every output node combines the node itself (kernel slot 0, ``weight[0]``)
    with up to ``kernel_size - 1`` neighbors (kernel slots ``1..kernel_size-1``,
    ``weight[k]``). Each neighbor's features are scaled by its scalar entry in
    ``neighbors_weights`` before the slot weight is applied.

    Args:
        in_channels (int): number of input features per node.
        out_channels (int): number of output features per node.
        kernel_size (int): number of kernel slots (center node + neighbors).
        node_dim (int): dimension of ``x`` along which neighbors are gathered
            with ``index_select``. Default 0.
        bias (bool): add a learnable bias of size ``out_channels``. Default True.
        mask (str): 'A' (masked, center excluded), 'B' (masked, center
            included) or 'full' (no masking). Masking zeroes every neighbor
            whose index is not strictly smaller than the node's own index.

    Raises:
        NotImplementedError: if ``mask`` is not one of 'A', 'B', 'full'.
    """

    def __init__(self, in_channels, out_channels, kernel_size, node_dim=0, bias=True, mask: str = 'full'):
        if mask not in ['A', 'B', 'full']:
            raise NotImplementedError("masked convolution either 'A', 'B' or 'full'")
        super().__init__()

        assert node_dim >= 0
        self.node_dim = node_dim

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        # weight[k] maps the features feeding kernel slot k (0 = center node) to the output
        self.weight = torch.nn.Parameter(torch.empty(kernel_size, in_channels, out_channels))
        self.mask = mask
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight and bias the same way ``torch.nn.Conv2d`` does."""
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            # NOTE(review): for a (kernel_size, in, out) tensor PyTorch derives
            # fan_in = in_channels * out_channels (dim 1 times the trailing dims),
            # not in_channels * kernel_size as for Conv2d -- confirm this is intended.
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, neighbors_indices=None, neighbors_weights=None, valid_index=None):
        r"""Apply the convolution.

        Args:
            x (tensor): node features. The center term is ``x @ weight[0]``;
                neighbors are gathered with ``x.index_select(self.node_dim, ...)``.
                The mask ('A'/'B'), the zero-initialised 'A' output and the
                ``valid_index`` path all index nodes along dim 1, so they assume
                a 3-D ``x`` with nodes on dim 1 (i.e. ``node_dim=1``).
            neighbors_indices (tensor|None): neighbor indices per node; column
                ``k-1`` feeds kernel slot ``k``. Only the first
                ``kernel_size - 1`` columns are used. Not needed for 1x1 convolution.
            neighbors_weights (tensor|None): scalar weight per neighbor, same
                layout as ``neighbors_indices``. Not needed for 1x1 convolution.
            valid_index (tensor|None): bool tensor, same layout; neighbors marked
                ``False`` are skipped. Default None (all neighbors used).

        Returns:
            tensor: convolution output whose last dimension is ``out_channels``.
        """
        if self.kernel_size > 1:
            # check for missing inputs before touching them, so the failure is the assertion message
            assert (neighbors_indices is not None) and (neighbors_weights is not None), "neighbors_indices and _weights must be provided for 2-hop convolution"
            assert (self.kernel_size - 1) <= neighbors_weights.size(1), "size does not match"
            n_slots = self.kernel_size - 1
            neighbors_indices = neighbors_indices[:, :n_slots]
            neighbors_weights = neighbors_weights[:, :n_slots]
            if valid_index is not None:
                valid_index = valid_index[:, :n_slots]
            if self.mask in ['A', 'B']:
                # keep only neighbors with a smaller index than the node itself
                node_ids = torch.arange(x.size(1), device=neighbors_indices.device).view(-1, 1)
                neighbors_weights = torch.mul(neighbors_weights, neighbors_indices < node_ids)

        if self.mask != 'A':  # center node contributes through weight[0]
            out = torch.matmul(x, self.weight[0])
        else:
            out = torch.zeros(x.size(0), x.size(1), self.out_channels, dtype=x.dtype, device=x.device)

        for k in range(1, self.kernel_size):
            col = k - 1
            if valid_index is None:
                s = torch.mul(neighbors_weights.narrow(dim=1, start=col, length=1), x.index_select(self.node_dim, neighbors_indices[:, col]))
                out += torch.matmul(s, self.weight[k])
            else:
                valid_rows = valid_index[:, col]
                s = torch.mul(neighbors_weights[valid_rows, col].view(-1, 1), x.index_select(self.node_dim, neighbors_indices[valid_rows, col]))
                out[:, valid_rows, :] += torch.matmul(s, self.weight[k])

        if self.bias is not None:
            out += self.bias

        return out

PaulWawerek-L's avatar
PaulWawerek-L committed

def sdpaconv(x, weights: list, biases=None, skip_conn_aggr: str = 'sum', mask: str = 'full', neighbors_indices=None, neighbors_weights=None, valid_index=None):
    r"""SDPA convolution stacked over ``len(weights)`` hops.

    Hop ``i`` convolves the output of hop ``i - 1`` (hop 0 convolves ``x``),
    so ``n_hops`` convolutions reach ``n_hops``-hop neighbors. The per-hop
    outputs are then combined with ``skip_conn_aggr``.

    Args:
        x (tensor): input tensor; nodes are indexed along dim 1.
        weights (list|tuple): weight tensors in list or tuple, len(weights) = n_hops.
            ``weights[i][0]`` weighs the center node and ``weights[i][k]`` the
            k-th neighbor. The kernel size is taken from ``weights[0]``. Each
            hop's input channels must equal the previous hop's output channels.
        biases (list|None): bias tensors (entries may be None), len(biases) = n_hops.
            Default None (no bias).
        skip_conn_aggr (str): skip connection aggregation method from ['', 'cat', 'max', 'sum'].
            '' returns only the last hop's output. Default 'sum'.
        mask (str): mask type from ['A', 'B', 'full']. Default 'full' (no masked convolution).
            'A': masked convolution with center excluded,
            'B': masked convolution with center included,
            'full': no masked convolution.
            Masking zeroes every neighbor whose index is not smaller than the node's index.
        neighbors_indices (tensor|None): neighbor indices. Not needed for 1x1 convolution.
            Default None.
        neighbors_weights (tensor|None): neighbor weights. Not needed for 1x1 convolution.
            Default None.
        valid_index (tensor|None): bool tensor of valid neighbors. Not needed for 1x1 convolution.
            Default None.

    Returns:
        tensor: aggregated hop outputs, or the last hop's output when ``skip_conn_aggr`` is ''.
    """
    node_dim = 1
    kernel_size = len(weights[0])
    if biases is not None:
        assert len(biases) == len(weights)
    if kernel_size > 1:
        # validate presence before dereferencing, so a missing argument fails with a clear message
        assert (neighbors_indices is not None) and (neighbors_weights is not None), "neighbors_indices and _weights must be provided for kernel_size > 1"
        assert (kernel_size - 1) == neighbors_weights.size(1), "size does not match"

    # move everything to x's device first so the mask comparison below stays on one device
    if neighbors_indices is not None:
        neighbors_indices = neighbors_indices.to(x.device)
    if neighbors_weights is not None:
        neighbors_weights = neighbors_weights.to(x.device)
    if valid_index is not None:
        valid_index = valid_index.to(x.device)
    if mask in ['A', 'B'] and kernel_size > 1:
        node_ids = torch.arange(x.size(node_dim), device=x.device).view(-1, 1)
        neighbors_weights = torch.mul(neighbors_weights, neighbors_indices < node_ids)

    xs = []
    h = x  # input of the current hop
    for i, weight in enumerate(weights):
        if mask != 'A':
            out = torch.matmul(h, weight[0])
        else:
            out = torch.zeros(h.size(0), h.size(1), weight.size(-1), dtype=h.dtype, device=h.device)

        for k in range(1, kernel_size):
            col = k - 1
            if valid_index is None:
                s = torch.mul(neighbors_weights.narrow(dim=1, start=col, length=1), h.index_select(node_dim, neighbors_indices[:, col]))
                out += torch.matmul(s, weight[k])
            else:
                valid_rows = valid_index[:, col]
                s = torch.mul(neighbors_weights[valid_rows, col].view(-1, 1), h.index_select(node_dim, neighbors_indices[valid_rows, col]))
                out[:, valid_rows, :] += torch.matmul(s, weight[k])

        if biases is not None and biases[i] is not None:
            out += biases[i]
        xs.append(out)
        h = out  # next hop convolves this hop's output

    return _sphere_skip_connection(xs, skip_conn_aggr) if skip_conn_aggr else xs[-1]

def sdpaconv_node_n(x, hops: int, weights: list, bias: list, n, neighbors_indices, neighbors_weights, skipconn=None, masked=False):
    r"""Compute the 1- or 2-hop SDPA convolution only at node ``n``.

    For 2 hops, the first convolution is evaluated at ``n`` and at each of its
    neighbors, because the second convolution at ``n`` reads those values.
    Weights and biases are read through ``.data``, so no gradient flows to them.

    Args:
        x (tensor): input features; indexed as ``x[:, node, :]``.
        hops (int): 1 or 2.
        weights (list): one weight tensor per hop; ``weights[i][0]`` weighs the
            center node and ``weights[i][1:]`` the neighbors, in column order.
        bias (list|None): one bias tensor (or None) per hop, or None for no bias.
        n (int): index of the node to evaluate.
        neighbors_indices (tensor): neighbor indices per node.
        neighbors_weights (tensor): scalar weight per neighbor, same layout.
        skipconn (callable|None): aggregates ``[hop1_out, hop2_out]`` when hops == 2.
        masked (bool): in the first convolution, zero neighbors with index >= n
            and drop the center term for evaluated nodes with index >= n.
            NOTE(review): the second convolution uses the unmasked
            ``neighbors_weights[n]`` -- confirm this is intended.

    Returns:
        tensor: output at node ``n`` with shape ``(x.size(0), 1, C)``, or
        ``skipconn(xs)`` when hops == 2 and ``skipconn`` is given.

    Raises:
        NotImplementedError: if ``hops`` is not 1 or 2.
    """
    if hops not in [1, 2]: raise NotImplementedError("Number of hops must be 1 or 2")
    assert len(weights)==hops, "number of weights must match number of hops"
    if bias is not None: assert len(bias)==hops, "number of biases must match number of hops"
    # the signature allows bias=None; normalise so per-hop lookups below never index None
    biases = list(bias) if bias is not None else [None] * hops

    if masked:
        neighbors_weights_in = torch.mul(neighbors_weights, (neighbors_indices < n))
    else:
        neighbors_weights_in = neighbors_weights

    # nodes where the first convolution is needed: n itself, plus its neighbors for a 2-hop conv
    calc_indices = [n]
    if hops == 2:
        calc_indices += neighbors_indices[n].tolist()

    xs = [torch.zeros(x.size(0), len(calc_indices), weights[0].size(-1), dtype=x.dtype, device=x.device)]
    if hops == 2:
        xs.append(torch.zeros(x.size(0), 1, weights[1].size(-1), dtype=x.dtype, device=x.device))

    w0 = weights[0].data
    w0_neighbors = w0[1:].flatten(0, 1)  # (n_neighbors * in, out): all neighbor slots in one matmul
    for k, k_ind in enumerate(calc_indices):
        cur_neighbors = torch.mul(neighbors_weights_in[k_ind].view(1, -1, 1), x[:, neighbors_indices[k_ind].tolist(), :])
        xs[0][:, k, :] = torch.matmul(cur_neighbors.flatten(1, 2), w0_neighbors)
        if not (masked and (k_ind >= n)):
            xs[0][:, k, :] += torch.matmul(x[:, k_ind, :], w0[0])
    if biases[0] is not None:
        xs[0] += biases[0].data

    if hops == 2:  # second convolution, evaluated only at node n
        w1 = weights[1].data
        xs[1][:, 0, :] = torch.matmul(xs[0][:, 0, :], w1[0])
        s = torch.mul(neighbors_weights[n].view(1, -1, 1), xs[0][:, 1:, :])
        xs[1][:, 0, :] += torch.matmul(s.flatten(1, 2), w1[1:].flatten(0, 1))
        if biases[1] is not None:
            xs[1] += biases[1].data
        # only keep the current node for the skip connection
        xs[0] = xs[0][:, 0:1, :]

    return skipconn(xs) if (hops == 2) and (skipconn is not None) else xs[-1]
PaulWawerek-L's avatar
PaulWawerek-L committed
def _sphere_skip_connection(xs, mode:str='sum'):
        r"""Aggregates representations across different layers.

        Args:
            xs (list or tuple): List containing layer-wise representations.
            mode (str): Aggregation mode. Can be one of the following: cat, max, sum
        """
        assert isinstance(xs, list) or isinstance(xs, tuple)
        assert mode in ['cat', 'max', 'sum']

        if mode == 'cat': return torch.cat(xs, dim=-1)
        elif mode == 'max': return torch.stack(xs, dim=-1).max(dim=-1)[0]
        elif mode == 'sum': return torch.stack(xs, dim=-1).sum(dim=-1)