
Commit b444a0c3 authored by PaulWawerek-L

function conv_hop_2_node_n for cm_ar

parent c72d3258
1 merge request: !2 OSLO-IC
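
This change replaces the duplicated per-node autoregressive context loops in the compress and decompress methods of SphereFactorizedPrior and SphereScaleHyperprior with calls to a new shared helper, conv_hop_2_node_n, added to spherical_models/sdpa_conv.py. The encoder calls it with masked=True, so each node's context is computed only from neighbors that are already coded (neighbors_indices < n); the decoder calls it with masked=False, since entries of y_hat past the current node are still zero at that point. Judging from the diff, the helper also applies the first convolution's self-term (weights[0][0]) for already-coded nodes (k_ind < n), which the previous inlined loops did not.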
@@ -6,6 +6,7 @@ from spherical_models import SphereGaussianConditional
 from compressai.models.utils import update_registered_buffers
 from compressai.ans import BufferedRansEncoder, RansDecoder
 from spherical_models import SLB_Downsample, SLB_Upsample
+from spherical_models.sdpa_conv import conv_hop_2_node_n
 import numpy as np
 # pyright: reportGeneralTypeIssues=warning
 # From Balle's tensorflow compression examples
@@ -209,29 +210,21 @@ class SphereFactorizedPrior(SphereCompressionModel):
         symbols_list = []
         indexes_list = []
-        n_nodes, n_neighbors = y_hat.size(1), neighbors_indices.size(1)
-        convs = self.autoregressive.list_conv
+        n_nodes = y_hat.size(1)
+        weights = [conv.weight for conv in self.autoregressive.list_conv]
+        bias = [conv.bias for conv in self.autoregressive.list_conv]
         # causal mask for convolution
         for n in range(n_nodes):
-            neighbors_weights_masked = torch.mul(neighbors_weights, (neighbors_indices < n))
-            # buffer for SDPA conv, should only be calculated at current node n
-            xs = [torch.zeros(y_hat.size(0), n_neighbors+1 if i==0 else 1, convs[i].weight.size(-1), dtype=y_hat.dtype, device=y_hat.device) for i in range(len(convs))]
-            y_hat_in = torch.abs(y_hat) if self.arabs else y_hat
-            for k, k_ind in enumerate([n]+neighbors_indices[n].tolist()): # iteration over current node and its neighbors
-                # if k_ind > n: continue # TODO try without this line
-                y_neighbors = torch.mul(neighbors_weights_masked[k_ind].view(1,-1,1), y_hat_in[:, neighbors_indices[k_ind].tolist(), :])
-                xs[0][:, k, :] = torch.matmul(y_neighbors.flatten(1,2), convs[0].weight.data[1:].flatten(0,1))
-            if convs[0].bias is not None:
-                xs[0] += convs[0].bias.data
-            # compute second convolution at current node
-            xs[1][:, 0, :] = torch.matmul(xs[0][:, 0, :], convs[1].weight.data[0])
-            s = torch.mul(neighbors_weights[n].view(1,-1,1), xs[0][:, 1:, :])
-            xs[1][:, 0, :] = torch.matmul(s.flatten(1,2), convs[1].weight.data[1:].flatten(0,1))
-            if convs[1].bias is not None:
-                xs[1] += convs[1].bias.data
-            # only keep current node
-            xs[0] = xs[0][:, 0:1, :]
-            ctx_p = skipconn(xs) if skipconn is not None else xs[-1]
+            ctx_p = conv_hop_2_node_n(
+                torch.abs(y_hat) if self.arabs else y_hat,
+                weights,
+                bias,
+                n,
+                neighbors_indices,
+                neighbors_weights,
+                skipconn,
+                masked=True,
+            )
             if self.arrelu:
                 ctx_p = torch.nn.functional.relu(ctx_p)
             scales_hat = ctx_p.squeeze(1)
@@ -306,27 +299,21 @@ class SphereFactorizedPrior(SphereCompressionModel):
         decoder = RansDecoder()
         decoder.set_stream(y_string)
-        n_nodes, n_neighbors = y_hat.size(1), neighbors_indices.size(1)
-        convs = self.autoregressive.list_conv
+        n_nodes = y_hat.size(1)
+        weights = [conv.weight for conv in self.autoregressive.list_conv]
+        bias = [conv.bias for conv in self.autoregressive.list_conv]
         # Warning: this is slow due to the auto-regressive nature of the decoding
         for n in range(n_nodes):
-            xs = [torch.zeros(y_hat.size(0), n_neighbors+1 if i==0 else 1, convs[i].weight.size(-1), dtype=y_hat.dtype, device=y_hat.device) for i in range(len(convs))]
-            y_hat_in = torch.abs(y_hat) if self.arabs else y_hat
-            for k, k_ind in enumerate([n]+neighbors_indices[n].tolist()): # iteration over current node and its neighbors
-                # if k_ind > n: continue
-                y_neighbors = torch.mul(neighbors_weights[k_ind].view(1,-1,1), y_hat_in[:, neighbors_indices[k_ind].tolist(), :])
-                xs[0][:, k, :] = torch.matmul(y_neighbors.flatten(1,2), convs[0].weight.data[1:].flatten(0,1))
-            if convs[0].bias is not None:
-                xs[0] += convs[0].bias.data
-            # compute second convolution at current node
-            xs[1][:, 0, :] = torch.matmul(xs[0][:, 0, :], convs[1].weight.data[0])
-            s = torch.mul(neighbors_weights[n].view(1,-1,1), xs[0][:, 1:, :])
-            xs[1][:, 0, :] = torch.matmul(s.flatten(1,2), convs[1].weight.data[1:].flatten(0,1))
-            if convs[1].bias is not None:
-                xs[1] += convs[1].bias.data
-            # only keep current node
-            xs[0] = xs[0][:, 0:1, :]
-            ctx_p = skipconn(xs) if skipconn is not None else xs[-1]
+            ctx_p = conv_hop_2_node_n(
+                torch.abs(y_hat) if self.arabs else y_hat,
+                weights,
+                bias,
+                n,
+                neighbors_indices,
+                neighbors_weights,
+                skipconn,
+                masked=False,
+            )
             if self.arrelu:
                 ctx_p = torch.nn.functional.relu(ctx_p)
             # TODO for mean: scales_hat, means_hat = ctx_p.chunk(2, 1)
...
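
The only difference between the encode and decode call sites is the masked flag. During encoding the full y_hat is available, so causality has to be enforced explicitly by zeroing the weights of neighbors that have not been coded yet. A minimal sketch of that mask expression, with hypothetical values:

import torch

n = 2                                            # current node
neighbors_indices = torch.tensor([[0, 3, 5]])    # hypothetical neighbor ids
neighbors_weights = torch.tensor([[0.2, 0.5, 0.3]])

# same expression used by conv_hop_2_node_n when masked=True:
masked = torch.mul(neighbors_weights, (neighbors_indices < n))
print(masked)  # tensor([[0.2000, 0.0000, 0.0000]]) -- only already-coded node 0 survives

During decoding no mask is needed, because the not-yet-decoded entries of y_hat are still zero and contribute nothing to the context.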
@@ -6,6 +6,7 @@ from spherical_models import SphereGaussianConditional
 from compressai.models.utils import update_registered_buffers
 from compressai.ans import BufferedRansEncoder, RansDecoder
 from spherical_models import SLB_Downsample, SLB_Upsample
+from spherical_models.sdpa_conv import conv_hop_2_node_n
 import numpy as np
 # pyright: reportGeneralTypeIssues=warning
 # From Balle's tensorflow compression examples
@@ -285,28 +286,21 @@ class SphereScaleHyperprior(SphereCompressionModel):
         symbols_list = []
         indexes_list = []
-        n_nodes, n_neighbors = y_hat.size(1), neighbors_indices.size(1)
-        convs = self.autoregressive.list_conv
+        n_nodes = y_hat.size(1)
+        weights = [conv.weight for conv in self.autoregressive.list_conv]
+        bias = [conv.bias for conv in self.autoregressive.list_conv]
         # causal mask for convolution
         for n in range(n_nodes):
-            neighbors_weights_masked = torch.mul(neighbors_weights, (neighbors_indices < n))
-            # buffer for SDPA conv, should only be calculated at current node n
-            xs = [torch.zeros(y_hat.size(0), n_neighbors+1 if i==0 else 1, convs[i].weight.size(-1), dtype=y_hat.dtype, device=y_hat.device) for i in range(len(convs))]
-            y_hat_in = torch.abs(y_hat) if self.arabs else y_hat
-            for k, k_ind in enumerate([n]+neighbors_indices[n].tolist()): # iteration over current node and its neighbors
-                y_neighbors = torch.mul(neighbors_weights_masked[k_ind].view(1,-1,1), y_hat_in[:, neighbors_indices[k_ind].tolist(), :])
-                xs[0][:, k, :] = torch.matmul(y_neighbors.flatten(1,2), convs[0].weight.data[1:].flatten(0,1))
-            if convs[0].bias is not None:
-                xs[0] += convs[0].bias.data
-            # compute second convolution at current node
-            xs[1][:, 0, :] = torch.matmul(xs[0][:, 0, :], convs[1].weight.data[0])
-            s = torch.mul(neighbors_weights[n].view(1,-1,1), xs[0][:, 1:, :])
-            xs[1][:, 0, :] += torch.matmul(s.flatten(1,2), convs[1].weight.data[1:].flatten(0,1))
-            if convs[1].bias is not None:
-                xs[1] += convs[1].bias.data
-            # only keep current node
-            xs[0] = xs[0][:, 0:1, :]
-            ctx_p = skipconn(xs) if skipconn is not None else xs[-1]
+            ctx_p = conv_hop_2_node_n(
+                torch.abs(y_hat) if self.arabs else y_hat,
+                weights,
+                bias,
+                n,
+                neighbors_indices,
+                neighbors_weights,
+                skipconn,
+                masked=True,
+            )
             if self.arrelu:
                 ctx_p = torch.nn.functional.relu(ctx_p)
             # 1x1 conv for the entropy parameters prediction network, so
@@ -392,26 +386,21 @@ class SphereScaleHyperprior(SphereCompressionModel):
         decoder = RansDecoder()
         decoder.set_stream(y_string)
-        n_nodes, n_neighbors = y_hat.size(1), neighbors_indices.size(1)
-        convs = self.autoregressive.list_conv
+        n_nodes = y_hat.size(1)
+        weights = [conv.weight for conv in self.autoregressive.list_conv]
+        bias = [conv.bias for conv in self.autoregressive.list_conv]
         # Warning: this is slow due to the auto-regressive nature of the decoding
         for n in range(n_nodes):
-            xs = [torch.zeros(y_hat.size(0), n_neighbors+1 if i==0 else 1, convs[i].weight.size(-1), dtype=y_hat.dtype, device=y_hat.device) for i in range(len(convs))]
-            y_hat_in = torch.abs(y_hat) if self.arabs else y_hat
-            for k, k_ind in enumerate([n]+neighbors_indices[n].tolist()): # iteration over current node and its neighbors
-                y_neighbors = torch.mul(neighbors_weights[k_ind].view(1,-1,1), y_hat_in[:, neighbors_indices[k_ind].tolist(), :])
-                xs[0][:, k, :] = torch.matmul(y_neighbors.flatten(1,2), convs[0].weight.data[1:].flatten(0,1))
-            if convs[0].bias is not None:
-                xs[0] += convs[0].bias.data
-            # compute second convolution at current node
-            xs[1][:, 0, :] = torch.matmul(xs[0][:, 0, :], convs[1].weight.data[0])
-            s = torch.mul(neighbors_weights[n].view(1,-1,1), xs[0][:, 1:, :])
-            xs[1][:, 0, :] += torch.matmul(s.flatten(1,2), convs[1].weight.data[1:].flatten(0,1))
-            if convs[1].bias is not None:
-                xs[1] += convs[1].bias.data
-            # only keep current node
-            xs[0] = xs[0][:, 0:1, :]
-            ctx_p = skipconn(xs) if skipconn is not None else xs[-1]
+            ctx_p = conv_hop_2_node_n(
+                torch.abs(y_hat) if self.arabs else y_hat,
+                weights,
+                bias,
+                n,
+                neighbors_indices,
+                neighbors_weights,
+                skipconn,
+                masked=False,
+            )
             if self.arrelu:
                 ctx_p = torch.nn.functional.relu(ctx_p)
             # 1x1 conv for the entropy parameters prediction network, so
...
@@ -125,6 +125,33 @@ def sdpaconv(x, weights:list, biases=None, skip_conn_aggr:str='sum', mask:str='f
     out = _sphere_skip_connection(xs, skip_conn_aggr) if num_conv and skip_conn_aggr else xs[-1]
     return out
 
+def conv_hop_2_node_n(x, weights:list, bias:list, n, neighbors_indices, neighbors_weights, skipconn, masked=False):
+    'calculate the 2-hop convolution only at node n'
+    n_neighbors = neighbors_indices.size(1)
+    if masked:
+        neighbors_weights_in = torch.mul(neighbors_weights, (neighbors_indices < n))
+    else:
+        neighbors_weights_in = neighbors_weights
+    # buffer for SDPA conv, should only be calculated at current node n
+    xs = [torch.zeros(x.size(0), n_neighbors+1 if i==0 else 1, weights[i].size(-1), dtype=x.dtype, device=x.device) for i in range(len(weights))]
+    for k, k_ind in enumerate([n]+neighbors_indices[n].tolist()): # iteration over current node and its neighbors
+        y_neighbors = torch.mul(neighbors_weights_in[k_ind].view(1,-1,1), x[:, neighbors_indices[k_ind].tolist(), :])
+        xs[0][:, k, :] = torch.matmul(y_neighbors.flatten(1,2), weights[0].data[1:].flatten(0,1))
+        if k_ind < n:
+            xs[0][:, k, :] += torch.matmul(x[:, k_ind, :], weights[0].data[0])
+    if bias[0] is not None:
+        xs[0] += bias[0].data
+    # compute second convolution at current node
+    xs[1][:, 0, :] = torch.matmul(xs[0][:, 0, :], weights[1].data[0])
+    s = torch.mul(neighbors_weights[n].view(1,-1,1), xs[0][:, 1:, :])
+    xs[1][:, 0, :] += torch.matmul(s.flatten(1,2), weights[1].data[1:].flatten(0,1))
+    if bias[1] is not None:
+        xs[1] += bias[1].data
+    # only keep current node
+    xs[0] = xs[0][:, 0:1, :]
+    out = skipconn(xs) if skipconn is not None else xs[-1]
+    return out
+
 def _sphere_skip_connection(xs, mode:str='sum'):
     r"""Aggregates representations across different layers.
...
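
The helper realizes a 2-hop receptive field: the first convolution aggregates the 1-hop neighborhood of node n and of each of its neighbors (rows of xs[0]), and the second convolution combines those aggregates at node n. For reference, a self-contained sketch of how it could be driven during decoding; the tensor shapes (batch, n_nodes, channels for x; one matrix per neighbor plus a self-term per convolution) and all values here are assumptions for illustration, not taken from the repository:

import torch
from spherical_models.sdpa_conv import conv_hop_2_node_n

batch, n_nodes, n_neighbors = 1, 12, 6
ch_in, ch_hidden, ch_out = 8, 16, 4

x = torch.zeros(batch, n_nodes, ch_in)  # decoded latents, filled node by node
neighbors_indices = torch.randint(0, n_nodes, (n_nodes, n_neighbors))
neighbors_weights = torch.rand(n_nodes, n_neighbors)
# index 0 of each weight tensor is the self-term, 1..n_neighbors the neighbor terms
weights = [torch.randn(n_neighbors + 1, ch_in, ch_hidden),
           torch.randn(n_neighbors + 1, ch_hidden, ch_out)]
bias = [torch.zeros(ch_hidden), torch.zeros(ch_out)]

for n in range(n_nodes):
    ctx_p = conv_hop_2_node_n(x, weights, bias, n,
                              neighbors_indices, neighbors_weights,
                              skipconn=None, masked=False)
    # ctx_p has shape (batch, 1, ch_out); entropy-decode node n from it,
    # then write the decoded value back into x[:, n, :]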