Commit 717de414 authored by PaulWawerek-L

context model for meanScaleHyperprior

parent b444a0c3
1 merge request: !2 OSLO-IC
import warnings
import torch
from spherical_models.compression_models import SphereScaleHyperprior
from compressai.ans import BufferedRansEncoder, RansDecoder
from spherical_models import SLB_Downsample, SLB_Upsample
from spherical_models.sdpa_conv import conv_hop_2_node_n
import numpy as np
# pyright: reportGeneralTypeIssues=warning
@@ -12,8 +15,11 @@ class SphereMeanScaleHyperprior(SphereScaleHyperprior):
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the encoder)
    """
    def __init__(self, N, M, conv_name, skip_conn_aggr, pool_func, unpool_func, attention: bool = False,
                 activation: str = 'GDN', context_model: str = '', arrelu: bool = False, arabs: bool = False, **kwargs):
        if arabs:
            warnings.warn("arabs=True (input abs values to context model) is not recommended for "
                          "SphereMeanScaleHyperprior, setting it to False")
            arabs = False
        super().__init__(N, M, conv_name, skip_conn_aggr, pool_func, unpool_func, attention, activation,
                         context_model, arrelu, arabs, **kwargs)
        ####################### h_s #######################
        self.h_s = torch.nn.ModuleList()
        self.h_s.append(SLB_Upsample(conv_name, N, M, hop=2, skip_conn_aggr=skip_conn_aggr,
@@ -25,6 +31,22 @@ class SphereMeanScaleHyperprior(SphereScaleHyperprior):
        # effective hop=1 => num_conv = 1, and there is no Upsampling
        self.h_s.append(SLB_Upsample(conv_name, M*3//2, M*2, hop=1, skip_conn_aggr=None,
                                     activation="ReLU", activation_args={"inplace": True}))
        ################## Context Model ##################
        if context_model:
            mask = 'full' if context_model == 'full' else 'A'
            if arrelu:
                self.autoregressive = SLB_Downsample(conv_name, M, M*2, hop=2, skip_conn_aggr=skip_conn_aggr,
                                                     activation="ReLU", activation_args={"inplace": True}, mask=mask)
            else:  # default
                self.autoregressive = SLB_Downsample(conv_name, M, M*2, hop=2, skip_conn_aggr=skip_conn_aggr,
                                                     activation=None, mask=mask)
            self.combine_ar_hp = torch.nn.Sequential(
                SLB_Downsample('SDPAConv', M*12//3, M*10//3, hop=1, skip_conn_aggr=skip_conn_aggr,
                               activation="ReLU", activation_args={"inplace": True}, conv1x1=True),
                SLB_Downsample('SDPAConv', M*10//3, M*8//3, hop=1, skip_conn_aggr=skip_conn_aggr,
                               activation="ReLU", activation_args={"inplace": True}, conv1x1=True),
                SLB_Downsample('SDPAConv', M*8//3, M*6//3, hop=1, skip_conn_aggr=skip_conn_aggr, conv1x1=True),
            )
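
For orientation, a hypothetical instantiation of the extended constructor is sketched below; the concrete values (channel counts, conv_name, pooling choices) are illustrative assumptions, not taken from this commit. Any non-empty context_model string enables the autoregressive branch; 'full' selects the full mask, any other value the type-'A' mask.

# Hypothetical usage sketch; all argument values are assumptions for illustration.
model = SphereMeanScaleHyperprior(
    N=192, M=192,                 # channel counts (assumed)
    conv_name='SDPAConv',         # same conv type as used in combine_ar_hp above
    skip_conn_aggr='sum',         # assumed skip-connection aggregation
    pool_func='max_pool',         # assumed pooling / unpooling choices
    unpool_func='nearest',
    context_model='A',            # any non-empty string enables the context model
    arrelu=False,                 # keep the autoregressive conv linear (the default)
)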
    def forward(self, x, dict_index, dict_weight, res, patch_res=None, dict_valid_index=None):
        # x is a tensor of size [batch_size, num_nodes, num_features]
        data_res = res if patch_res is None else (res, patch_res)
@@ -53,20 +75,30 @@ class SphereMeanScaleHyperprior(SphereScaleHyperprior):
# print("z.max()=", z.max(), "z.min()=", z.min()) # print("z.max()=", z.max(), "z.min()=", z.min())
# print("torch.abs(y).max()=", torch.abs(y).max(), "torch.abs(y).min()=", torch.abs(y).min()) # print("torch.abs(y).max()=", torch.abs(y).max(), "torch.abs(y).min()=", torch.abs(y).min())
z_hat, z_likelihoods = self.entropy_bottleneck(z) z_hat, z_likelihoods = self.entropy_bottleneck(z)
########### apply h_s ########### ########### apply h_s ###########
gaussian_params = z_hat params = z_hat
for i in range(len(self.h_s)): for i in range(len(self.h_s)):
conv_res = type(data_res)(np.add(data_res, self._h_s_offset[i])) conv_res = type(data_res)(np.add(data_res, self._h_s_offset[i]))
gaussian_params = self.h_s[i](gaussian_params, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None) params = self.h_s[i](params, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
scales_hat, means_hat = gaussian_params.chunk(2, -1)
# print("applying h_s") # print("applying h_s")
# print("scales_hat.mean()=", scales_hat.mean(), means_hat.mean()=", means_hat.mean(), "z_hat.mean()=", z_hat.mean()) # print("params.mean()=", params.mean(), means_hat.mean()=", means_hat.mean(), "z_hat.mean()=", z_hat.mean())
# print("scales_hat.max()=", scales_hat.max(), "scales_hat.min()=", scales_hat.min()) # print("params.max()=", params.max(), "params.min()=", params.min())
# print("means_hat.max()=", means_hat.max(), "means_hat.min()=", means_hat.min()) # print("means_hat.max()=", means_hat.max(), "means_hat.min()=", means_hat.min())
# print("z_hat.max()=", z_hat.max(), "z_hat.min()=", z_hat.min()) # print("z_hat.max()=", z_hat.max(), "z_hat.min()=", z_hat.min())
        ###### apply context model ######
        if self.context_model:
            y_hat = self.gaussian_conditional.quantize(y, 'noise' if self.gaussian_conditional.training else 'dequantize')
            conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
            phi_sp = self.autoregressive(y_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
            gaussian_params = self.combine_ar_hp(torch.cat([params, phi_sp], dim=-1))
            scales_hat, means_hat = gaussian_params.chunk(2, -1)
            # print("applying context model")
            # print("scales_hat.mean()=", scales_hat.mean(), "y_hat.mean()=", y_hat.mean())
            # print("scales_hat.max()=", scales_hat.max(), "scales_hat.min()=", scales_hat.min())
            # print("y_hat.max()=", y_hat.max(), "y_hat.min()=", y_hat.min())
        else:
            scales_hat, means_hat = params.chunk(2, -1)
        y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)

        ########### apply g_s ###########
        x_hat = y_hat
        for i in range(len(self.g_s)):
@@ -95,6 +127,12 @@ class SphereMeanScaleHyperprior(SphereScaleHyperprior):
        }
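
Assuming forward() still returns the usual compressai-style dict with a "likelihoods" entry for "y" and "z" (as in the parent SphereScaleHyperprior), the training rate term can be estimated from the output; the sketch below is a minimal illustration under that assumption.

# Minimal sketch; the "likelihoods" keys are an assumption based on the parent class.
import math
import torch

def estimate_bits(out):
    # total rate in bits = -log2 of the product of all latent likelihoods
    return sum(torch.log(l).sum() / -math.log(2) for l in out["likelihoods"].values())

# bits_per_node = estimate_bits(out) / num_nodes  # normalize by the number of sphere nodes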
    def compress(self, x, dict_index, dict_weight, res, patch_res=None, dict_valid_index=None):
        if self.context_model:
            if next(self.parameters()).device != torch.device("cpu"):
                warnings.warn(
                    "Inference on GPU is not recommended for the autoregressive "
                    "models (the entropy coder is run sequentially on CPU).",
                    stacklevel=2,
                )
        data_res = res if patch_res is None else (res, patch_res)

        ########### apply g_a ###########
        y = x
@@ -107,7 +145,6 @@ class SphereMeanScaleHyperprior(SphereScaleHyperprior):
            index_, weight_, valid_index_ = None, None, None
            y = self.g_a[i](y, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None,
                            index_=index_, weight_=weight_, valid_index_=valid_index_)

        ########### apply h_a ###########
        z = y
        for i in range(len(self.h_a)):
@@ -116,30 +153,117 @@ class SphereMeanScaleHyperprior(SphereScaleHyperprior):
        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[1])

        ########### apply h_s ###########
        params = z_hat
        for i in range(len(self.h_s)):
            conv_res = type(data_res)(np.add(data_res, self._h_s_offset[i]))
            params = self.h_s[i](params, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)

        if not self.context_model:
            scales_hat, means_hat = params.chunk(2, -1)
            indexes = self.gaussian_conditional.build_indexes(scales_hat)
            y_strings = self.gaussian_conditional.compress(y, indexes, means=means_hat)
        else:
            y_strings = []
            conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
            for i in range(y.size(0)):
                string = self._compress_ar(
                    y[i:i+1],
                    params[i:i+1],
                    dict_index[conv_res],
                    dict_weight[conv_res],
                    self.autoregressive.skipconn,
                )
                y_strings.append(string)

        return {"strings": [y_strings, z_strings], "shape": z.size()[1]}
    def _compress_ar(self, y_hat, params, neighbors_indices, neighbors_weights, skipconn):
        assert len(self.autoregressive.list_conv) == 2, "only 2-hop convolution supported"
        neighbors_indices = neighbors_indices.to(y_hat.device)
        neighbors_weights = neighbors_weights.to(y_hat.device)

        cdf = self.gaussian_conditional._quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional._cdf_length.tolist()
        offsets = self.gaussian_conditional._offset.tolist()

        encoder = BufferedRansEncoder()
        symbols_list = []
        indexes_list = []

        n_nodes = y_hat.size(1)
        weights = [conv.weight for conv in self.autoregressive.list_conv]
        bias = [conv.bias for conv in self.autoregressive.list_conv]
        for n in range(n_nodes):
            # masked 2-hop context: only already coded neighbors of node n contribute
            ctx_p = conv_hop_2_node_n(
                y_hat,
                weights,
                bias,
                n,
                neighbors_indices,
                neighbors_weights,
                skipconn,
                masked=True,
            )
            if self.arrelu:
                ctx_p = torch.nn.functional.relu(ctx_p)
            # 1x1 conv for the entropy parameters prediction network, so
            # we only keep the elements in the "center"
            p = params[:, n:n+1, :]
            gaussian_params = self.combine_ar_hp(torch.cat((p, ctx_p), dim=-1))
            gaussian_params = gaussian_params.squeeze(1)
            scales_hat, means_hat = gaussian_params.chunk(2, -1)

            indexes = self.gaussian_conditional.build_indexes(scales_hat)
            y_q = self.gaussian_conditional.quantize(y_hat[:, n, :], "symbols", means_hat)
            y_hat[:, n, :] = y_q + means_hat

            symbols_list.extend(y_q.squeeze().tolist())
            indexes_list.extend(indexes.squeeze().tolist())

        # encode all buffered symbols into a single rANS stream
        encoder.encode_with_indexes(
            symbols_list,
            indexes_list,
            cdf,
            cdf_lengths,
            offsets,
        )
        string = encoder.flush()
        return string
    def decompress(self, strings, shape, dict_index, dict_weight, res, patch_res=None, dict_valid_index=None):
        assert isinstance(strings, list) and len(strings) == 2
        if self.context_model:
            if next(self.parameters()).device != torch.device("cpu"):
                warnings.warn(
                    "Inference on GPU is not recommended for the autoregressive "
                    "models (the entropy coder is run sequentially on CPU).",
                    stacklevel=2,
                )
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        data_res = res if patch_res is None else (res, patch_res)

        ########### apply h_s ###########
        params = z_hat
        for i in range(len(self.h_s)):
            conv_res = type(data_res)(np.add(data_res, self._h_s_offset[i]))
            params = self.h_s[i](params, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)

        if self.context_model:
            s = 4**2  # scaling factor between z and y  # TODO calc from offsets
            y_hat = torch.zeros(z_hat.size(0), s*shape, self.M, dtype=z_hat.dtype, device=z_hat.device)
            conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
            for i, y_string in enumerate(strings[0]):
                self._decompress_ar(
                    y_string,
                    y_hat[i:i+1],
                    params[i:i+1],
                    dict_index[conv_res],
                    dict_weight[conv_res],
                    self.autoregressive.skipconn,
                )
        else:
            scales_hat, means_hat = params.chunk(2, -1)
            indexes = self.gaussian_conditional.build_indexes(scales_hat)
            y_hat = self.gaussian_conditional.decompress(strings[0], indexes, means=means_hat)
        ########### apply g_s ###########
        x_hat = y_hat
@@ -153,4 +277,47 @@ class SphereMeanScaleHyperprior(SphereScaleHyperprior):
            x_hat = self.g_s[i](x_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None,
                                index_=index_, weight_=weight_, valid_index_=valid_index_)
        x_hat = x_hat.clamp_(0, 1)
        return {"x_hat": x_hat}
    def _decompress_ar(self, y_string, y_hat, params, neighbors_indices, neighbors_weights, skipconn):
        cdf = self.gaussian_conditional._quantized_cdf.tolist()
        cdf_lengths = self.gaussian_conditional._cdf_length.tolist()
        offsets = self.gaussian_conditional._offset.tolist()

        neighbors_indices = neighbors_indices.to(y_hat.device)
        neighbors_weights = neighbors_weights.to(y_hat.device)

        decoder = RansDecoder()
        decoder.set_stream(y_string)

        n_nodes = y_hat.size(1)
        weights = [conv.weight for conv in self.autoregressive.list_conv]
        bias = [conv.bias for conv in self.autoregressive.list_conv]
        # Warning: this is slow due to the auto-regressive nature of the decoding
        for n in range(n_nodes):
            ctx_p = conv_hop_2_node_n(
                y_hat,
                weights,
                bias,
                n,
                neighbors_indices,
                neighbors_weights,
                skipconn,
                masked=False,
            )
            if self.arrelu:
                # apply the same optional ReLU as in _compress_ar to keep encoder/decoder in sync
                ctx_p = torch.nn.functional.relu(ctx_p)
            # 1x1 conv for the entropy parameters prediction network, so
            # we only keep the elements in the "center"
            p = params[:, n:n+1, :]
            gaussian_params = self.combine_ar_hp(torch.cat((p, ctx_p), dim=-1))
            scales_hat, means_hat = gaussian_params.chunk(2, -1)

            indexes = self.gaussian_conditional.build_indexes(scales_hat)
            rv = decoder.decode_stream(
                indexes.squeeze().tolist(),
                cdf,
                cdf_lengths,
                offsets,
            )
            rv = torch.Tensor(rv).reshape(1, 1, -1)
            rv = self.gaussian_conditional.dequantize(rv, means_hat)
            y_hat[:, n:n+1, :] = rv
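
A hedged end-to-end sketch of how compress/decompress might be exercised; the input x, the neighborhood tables dict_index/dict_weight, the resolution res, and the update() call (compressai's CompressionModel API, assumed to be inherited via SphereCompressionModel) all come from the surrounding pipeline and are not part of this commit.

# Hypothetical round-trip; x, dict_index, dict_weight and res are assumed inputs.
model = model.to("cpu").eval()     # the rANS coder runs sequentially on CPU (see warnings above)
model.update(force=True)           # build CDF tables; assumes compressai's CompressionModel.update()
with torch.no_grad():
    out = model.compress(x, dict_index, dict_weight, res)
    rec = model.decompress(out["strings"], out["shape"], dict_index, dict_weight, res)
x_hat = rec["x_hat"]               # reconstruction clamped to [0, 1]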
@@ -171,14 +171,13 @@ class SphereScaleHyperprior(SphereCompressionModel):
# print("scales_hat.mean()=", scales_hat.mean(), "z_hat.mean()=", z_hat.mean()) # print("scales_hat.mean()=", scales_hat.mean(), "z_hat.mean()=", z_hat.mean())
# print("scales_hat.max()=", scales_hat.max(), "scales_hat.min()=", scales_hat.min()) # print("scales_hat.max()=", scales_hat.max(), "scales_hat.min()=", scales_hat.min())
# print("z_hat.max()=", z_hat.max(), "z_hat.min()=", z_hat.min()) # print("z_hat.max()=", z_hat.max(), "z_hat.min()=", z_hat.min())
###### apply context model ######
skip_quant = bool(self.context_model) skip_quant = bool(self.context_model)
if skip_quant: # take quantization out of self.gaussian_conditional.forward() if skip_quant: # take quantization out of self.gaussian_conditional.forward()
y_hat = self.gaussian_conditional.quantize(torch.abs(y) if self.arabs else y, 'noise' if self.gaussian_conditional.training else 'dequantize', means=None) # TODO add mean if needed y_hat = self.gaussian_conditional.quantize(torch.abs(y) if self.arabs else y, 'noise' if self.gaussian_conditional.training else 'dequantize', means=None)
conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0])) conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
phi_sp = self.autoregressive(y_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None) phi_sp = self.autoregressive(y_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
scales_hat = self.combine_ar_hp(torch.cat([scales_hat, phi_sp], dim=-1)) scales_hat = self.combine_ar_hp(torch.cat([scales_hat, phi_sp], dim=-1))
# TODO meanHyperprior: scales_hat, means_hat = scales_hat.chunk(2, -1)
# print("applying context model") # print("applying context model")
# print("scales_hat.mean()=", scales_hat.mean(), "y_hat.mean()=", y_hat.mean()) # print("scales_hat.mean()=", scales_hat.mean(), "y_hat.mean()=", y_hat.mean())
# print("scales_hat.max()=", scales_hat.max(), "scales_hat.min()=", scales_hat.min()) # print("scales_hat.max()=", scales_hat.max(), "scales_hat.min()=", scales_hat.min())
@@ -308,12 +307,11 @@ class SphereScaleHyperprior(SphereCompressionModel):
            p = params[:, n:n+1, :]
            scales_hat = self.combine_ar_hp(torch.cat((p, ctx_p), dim=-1))
            scales_hat = scales_hat.squeeze(1)

            indexes = self.gaussian_conditional.build_indexes(scales_hat)
            y_q = self.gaussian_conditional.quantize(y_hat[:, n, :], "symbols")
            y_hat[:, n, :] = y_q

            symbols_list.extend(y_q.squeeze().tolist())
            indexes_list.extend(indexes.squeeze().tolist())
@@ -407,7 +405,6 @@ class SphereScaleHyperprior(SphereCompressionModel):
            # we only keep the elements in the "center"
            p = params[:, n:n+1, :]
            scales_hat = self.combine_ar_hp(torch.cat((p, ctx_p), dim=-1))

            indexes = self.gaussian_conditional.build_indexes(scales_hat)
            rv = decoder.decode_stream(
@@ -417,7 +414,7 @@ class SphereScaleHyperprior(SphereCompressionModel):
                offsets,
            )
            rv = torch.Tensor(rv).reshape(1, 1, -1)
            rv = self.gaussian_conditional.dequantize(rv)
            y_hat[:, n:n+1, :] = rv