import math
import warnings

import numpy as np
import torch

from compressai.ans import BufferedRansEncoder, RansDecoder
from compressai.models.utils import update_registered_buffers

from spherical_models import SLB_Downsample, SLB_Upsample, SphereGaussianConditional
from spherical_models.compression_models import SphereCompressionModel
from spherical_models.sdpa_conv import sdpaconv_node_n
# From Balle's tensorflow compression examples
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64
def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS): # pylint: disable=W0622
return torch.exp(torch.linspace(math.log(min), math.log(max), levels))
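
# Usage sketch (values approximate): the table holds `levels` log-spaced
# scales, so the first/last entries land near SCALES_MIN and SCALES_MAX.
#   table = get_scale_table()   # tensor of shape (64,)
#   assert table[0] < table[-1]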
class SphereScaleHyperprior(SphereCompressionModel):
r"""Scale Hyperprior model for spherical data.
Args:
N (int): Number of channels
M (int): Number of channels in the expansion layers (last layer of the
encoder and last layer of the hyperprior decoder)
"""
    def __init__(self, N, M, conv_name, skip_conn_aggr, pool_func, unpool_func,
                 attention: bool = False, activation: str = 'GDN', single_conv: bool = False,
                 context_model: str = '', arrelu: bool = False, arabs: bool = False,
                 arhops: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.attention = attention
        self.activation = activation
        self.arrelu = arrelu
        self.arabs = arabs
        a_pos = (2 if attention else -1)  # attention after the second conv; -1 disables it
        ####################### g_a #######################
        self.g_a = torch.nn.ModuleList()
        self.g_a.append(SLB_Downsample(conv_name, 3, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       activation=activation,
                                       pool_func=pool_func, pool_size_sqrt=2))
        self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       activation=activation,
                                       pool_func=pool_func, pool_size_sqrt=2, attention_pos=a_pos))
        self.g_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       activation=activation,
                                       pool_func=pool_func, pool_size_sqrt=2))
        # For the last layer there is no GDN anymore:
        self.g_a.append(SLB_Downsample(conv_name, N, M, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       pool_func=pool_func, pool_size_sqrt=2, attention_pos=a_pos))
        ####################### g_s #######################
        self.g_s = torch.nn.ModuleList()
        self.g_s.append(SLB_Upsample(conv_name, M, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     activation=activation, activation_args={"inverse": True},
                                     unpool_func=unpool_func, unpool_size_sqrt=2, attention_pos=a_pos))
        self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     activation=activation, activation_args={"inverse": True},
                                     unpool_func=unpool_func, unpool_size_sqrt=2, attention_pos=a_pos))
        self.g_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     activation=activation, activation_args={"inverse": True},
                                     unpool_func=unpool_func, unpool_size_sqrt=2))
        # For the last layer there is no GDN anymore:
        self.g_s.append(SLB_Upsample(conv_name, N, 3, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     unpool_func=unpool_func, unpool_size_sqrt=2))
        ####################### h_a #######################
        self.h_a = torch.nn.ModuleList()
        # effective hop=1 => num_conv = 1, and there is no downsampling
        self.h_a.append(SLB_Downsample(conv_name, M, N, hop=1, single_conv=single_conv, skip_conn_aggr=None,
                                       activation="ReLU", activation_args={"inplace": True}))
        self.h_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       activation="ReLU", activation_args={"inplace": True},
                                       pool_func=pool_func, pool_size_sqrt=2))
        # For the last layer there is no ReLU anymore:
        self.h_a.append(SLB_Downsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                       pool_func=pool_func, pool_size_sqrt=2))
        ####################### h_s #######################
        self.h_s = torch.nn.ModuleList()
        self.h_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     activation="ReLU", activation_args={"inplace": True},
                                     unpool_func=unpool_func, unpool_size_sqrt=2))
        self.h_s.append(SLB_Upsample(conv_name, N, N, hop=2, single_conv=single_conv, skip_conn_aggr=skip_conn_aggr,
                                     activation="ReLU", activation_args={"inplace": True},
                                     unpool_func=unpool_func, unpool_size_sqrt=2))
        # effective hop=1 => num_conv = 1, and there is no upsampling
        self.h_s.append(SLB_Upsample(conv_name, N, M, hop=1, single_conv=single_conv, skip_conn_aggr=None,
                                     activation="ReLU", activation_args={"inplace": True}))
        ################## Context Model ##################
        self.context_model = context_model
        if context_model:
            mask = True  # causal mask: each node only sees already-coded neighbors
            if arrelu:
                self.autoregressive = SLB_Downsample(conv_name, M, M, hop=arhops, skip_conn_aggr=skip_conn_aggr,
                                                     activation="ReLU", activation_args={"inplace": True}, mask=mask)
            else:  # default: no activation on the context convolution
                self.autoregressive = SLB_Downsample(conv_name, M, M, hop=arhops, skip_conn_aggr=skip_conn_aggr,
                                                     mask=mask)
            self.combine_ar_hp = torch.nn.Sequential(
                SLB_Downsample('SDPAConv', 2*M, M+256, hop=1, skip_conn_aggr=skip_conn_aggr,
                               activation="ReLU", activation_args={"inplace": True}, conv1x1=True),
                SLB_Downsample('SDPAConv', M+256, M+128, hop=1, skip_conn_aggr=skip_conn_aggr,
                               activation="ReLU", activation_args={"inplace": True}, conv1x1=True),
                SLB_Downsample('SDPAConv', M+128, M, hop=1, skip_conn_aggr=skip_conn_aggr, conv1x1=True)
            )
###################################################
self.gaussian_conditional = SphereGaussianConditional(None)
self.N = int(N)
self.M = int(M)
self._computeResOffset()
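
    # Hypothetical construction sketch (argument values are illustrative only,
    # not repo defaults; conv/pool names must match what spherical_models registers):
    #   model = SphereScaleHyperprior(N=128, M=192, conv_name="SDPAConv",
    #                                 skip_conn_aggr="sum", pool_func="max",
    #                                 unpool_func="nearest", context_model="ar")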
def _computeResOffset(self):
# compute convolution resolution offset
g_a_output = list(np.cumsum([layerBlock.get_output_res_offset() for layerBlock in self.g_a]))
self._g_a_offset = [self.g_a[0].get_conv_input_res_offset()]
self._g_a_offset.extend([self.g_a[i].get_conv_input_res_offset() + g_a_output[i-1] for i in range(1, len(self.g_a))])
h_a_output = list(np.cumsum([layerBlock.get_output_res_offset() for layerBlock in self.h_a]))
h_a_output = [res+g_a_output[-1] for res in h_a_output]
self._h_a_offset = [self.h_a[0].get_conv_input_res_offset() + g_a_output[-1]]
self._h_a_offset.extend([self.h_a[i].get_conv_input_res_offset() + h_a_output[i-1] for i in range(1, len(self.h_a))])
h_s_output = list(np.cumsum([layerBlock.get_output_res_offset() for layerBlock in self.h_s]))
h_s_output = [res+h_a_output[-1] for res in h_s_output]
self._h_s_offset = [self.h_s[0].get_conv_input_res_offset() + h_a_output[-1]]
self._h_s_offset.extend([self.h_s[i].get_conv_input_res_offset() + h_s_output[i-1] for i in range(1, len(self.h_s))])
assert h_s_output[-1] == g_a_output[-1], "resolutions do not match"
g_s_output = list(np.cumsum([layerBlock.get_output_res_offset() for layerBlock in self.g_s]))
g_s_output = [res+g_a_output[-1] for res in g_s_output]
self._g_s_offset = [self.g_s[0].get_conv_input_res_offset() + g_a_output[-1]]
self._g_s_offset.extend([self.g_s[i].get_conv_input_res_offset() + g_s_output[i-1] for i in range(1, len(self.g_s))])
def get_resOffset(self):
return set(self._g_a_offset + self._h_a_offset + self._h_s_offset + self._g_s_offset)
def forward(self, x, dict_index, dict_weight, res, patch_res=None, dict_valid_index=None): # x is a tensor of size [batch_size, num_nodes, num_features]
data_res = res if patch_res is None else (res, patch_res)
########### apply g_a ###########
y = x
for i in range(len(self.g_a)):
conv_res = type(data_res)(np.add(data_res, self._g_a_offset[i]))
if (self.activation.upper()=='RB') or (self.attention and (self.g_a[i].attention_pos > 0)): # for attention after convolution or RBs as nonlinearity
conv_res_ = type(data_res)(np.add(data_res, self._g_a_offset[i+1] if i<(len(self.g_a)-1) else self._g_s_offset[0]))
index_, weight_, valid_index_ = dict_index[conv_res_], dict_weight[conv_res_], dict_valid_index[conv_res_] if dict_valid_index is not None else None
else:
index_, weight_, valid_index_ = None, None, None
y = self.g_a[i](y, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None,
index_=index_, weight_=weight_, valid_index_=valid_index_)
# print("applying g_a")
# print("y.mean()=", y.mean(), "x.mean()=", x.mean())
# print("y.max()=", y.max(), "y.min()=", y.min())
# print("x.max()=", x.max(), "x.min()=", x.min())
########### apply h_a ###########
z = torch.abs(y)
for i in range(len(self.h_a)):
conv_res = type(data_res)(np.add(data_res, self._h_a_offset[i]))
z = self.h_a[i](z, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
# print("applying h_a")
# print("z.mean()=", z.mean(), "torch.abs(y).mean()=", torch.abs(y).mean())
# print("z.max()=", z.max(), "z.min()=", z.min())
# print("torch.abs(y).max()=", torch.abs(y).max(), "torch.abs(y).min()=", torch.abs(y).min())
z_hat, z_likelihoods = self.entropy_bottleneck(z)
########### apply h_s ###########
scales_hat = z_hat
for i in range(len(self.h_s)):
conv_res = type(data_res)(np.add(data_res, self._h_s_offset[i]))
scales_hat = self.h_s[i](scales_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
# print("applying h_s")
# print("scales_hat.mean()=", scales_hat.mean(), "z_hat.mean()=", z_hat.mean())
# print("scales_hat.max()=", scales_hat.max(), "scales_hat.min()=", scales_hat.min())
# print("z_hat.max()=", z_hat.max(), "z_hat.min()=", z_hat.min())
skip_quant = bool(self.context_model)
if skip_quant: # take quantization out of self.gaussian_conditional.forward()
y_hat = self.gaussian_conditional.quantize(torch.abs(y) if self.arabs else y, 'noise' if self.gaussian_conditional.training else 'dequantize', means=None)
conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
phi_sp = self.autoregressive(y_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
scales_hat = self.combine_ar_hp(torch.cat([scales_hat, phi_sp], dim=-1))
# print("applying context model")
# print("scales_hat.mean()=", scales_hat.mean(), "y_hat.mean()=", y_hat.mean())
# print("scales_hat.max()=", scales_hat.max(), "scales_hat.min()=", scales_hat.min())
# print("y_hat.max()=", y_hat.max(), "y_hat.min()=", y_hat.min())
y_hat, y_likelihoods = self.gaussian_conditional(y if not skip_quant else y_hat, scales_hat, skip_quant=skip_quant)
########### apply g_s ###########
x_hat = y_hat
for i in range(len(self.g_s)):
conv_res = type(data_res)(np.add(data_res, self._g_s_offset[i]))
if (self.activation.upper()=='RB') or (self.attention and (self.g_s[i].attention_pos > 0)): # for attention after convolution
conv_res_ = type(data_res)(np.add(data_res, self._g_s_offset[i+1] if i<(len(self.g_s)-1) else self._g_a_offset[0]))
index_, weight_, valid_index_ = dict_index[conv_res_], dict_weight[conv_res_], dict_valid_index[conv_res_] if dict_valid_index is not None else None
else:
index_, weight_, valid_index_ = None, None, None
x_hat = self.g_s[i](x_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None,
index_=index_, weight_=weight_, valid_index_=valid_index_)
# print("applying g_s")
# print("x_hat.mean()=", x_hat.mean(), "y_hat.mean()=", y_hat.mean())
# print("x_hat.max()=", x_hat.max(), "x_hat.min()=", x_hat.min())
# print("y_hat.max()=", y_hat.max(), "y_hat.min()=", y_hat.min())
# with torch.no_grad():
# print("input/out mean ratio=", x.mean()/x_hat.mean())
return {
'x_hat': x_hat,
'likelihoods': {
'y': y_likelihoods,
'z': z_likelihoods
},
}
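
    # Hedged rate-term sketch from forward()'s output (not a method of this
    # class; num_nodes plays the role num_pixels has for planar images):
    #   out = model(x, dict_index, dict_weight, res)
    #   num_nodes = x.size(0) * x.size(1)
    #   bpp_like = sum(-torch.log2(l).sum() / num_nodes
    #                  for l in out["likelihoods"].values())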
def load_state_dict(self, state_dict):
update_registered_buffers(
self.gaussian_conditional, "gaussian_conditional",
["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
state_dict)
super().load_state_dict(state_dict)
def update(self, scale_table=None, force=False):
if scale_table is None:
scale_table = get_scale_table()
updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
updated |= super().update(force=force)
return updated
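
    # Usage sketch: update() must run once (e.g. right after load_state_dict)
    # so the entropy coder CDFs exist before compress()/decompress():
    #   model.update(force=True)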
def compress(self, x, dict_index, dict_weight, res, patch_res=None, dict_valid_index=None):
if self.context_model:
if next(self.parameters()).device != torch.device("cpu"):
warnings.warn(
"Inference on GPU is not recommended for the autoregressive "
"models (the entropy coder is run sequentially on CPU).",
stacklevel=2)
data_res = res if patch_res is None else (res, patch_res)
########### apply g_a ###########
y = x
for i in range(len(self.g_a)):
conv_res = type(data_res)(np.add(data_res, self._g_a_offset[i]))
if (self.activation.upper()=='RB') or (self.attention and (self.g_a[i].attention_pos > 0)):
conv_res_ = type(data_res)(np.add(data_res, self._g_a_offset[i+1] if i<(len(self.g_a)-1) else self._g_s_offset[0]))
index_, weight_, valid_index_ = dict_index[conv_res_], dict_weight[conv_res_], dict_valid_index[conv_res_] if dict_valid_index is not None else None
else:
index_, weight_, valid_index_ = None, None, None
y = self.g_a[i](y, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None,
index_=index_, weight_=weight_, valid_index_=valid_index_)
########### apply h_a ###########
z = torch.abs(y)
for i in range(len(self.h_a)):
conv_res = type(data_res)(np.add(data_res, self._h_a_offset[i]))
z = self.h_a[i](z, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
z_strings = self.entropy_bottleneck.compress(z)
z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[1])
########### apply h_s ###########
scales_hat = z_hat
for i in range(len(self.h_s)):
conv_res = type(data_res)(np.add(data_res, self._h_s_offset[i]))
scales_hat = self.h_s[i](scales_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
if not self.context_model:
indexes = self.gaussian_conditional.build_indexes(scales_hat)
y_strings = self.gaussian_conditional.compress(y, indexes)
else:
y_strings = []
conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
for i in range(y.size(0)):
string = self._compress_ar(
y[i:i+1],
scales_hat[i:i+1],
dict_index[conv_res],
dict_weight[conv_res],
self.autoregressive.skipconn,
)
y_strings.append(string)
return {"strings": [y_strings, z_strings], "shape": z.size()[1]}
def _compress_ar(self, y_hat, params, neighbors_indices, neighbors_weights, skipconn):
assert len(self.autoregressive.list_conv) in [1,2], "only 1 or 2-hop convolution supported for autoregressive context model"
neighbors_indices = neighbors_indices.to(y_hat.device)
neighbors_weights = neighbors_weights.to(y_hat.device)
cdf = self.gaussian_conditional._quantized_cdf.tolist()
cdf_lengths = self.gaussian_conditional._cdf_length.tolist()
offsets = self.gaussian_conditional._offset.tolist()
encoder = BufferedRansEncoder()
symbols_list = []
indexes_list = []
n_nodes = y_hat.size(1)
weights = [conv.weight for conv in self.autoregressive.list_conv]
bias = [conv.bias for conv in self.autoregressive.list_conv]
        # sequentially code each node; the masked convolution provides the causal context
        for n in range(n_nodes):
            ctx_p = sdpaconv_node_n(
                torch.abs(y_hat) if self.arabs else y_hat,
weights,
bias,
n,
neighbors_indices,
neighbors_weights,
skipconn,
masked=True,
)
if self.arrelu:
ctx_p = torch.nn.functional.relu(ctx_p)
# 1x1 conv for the entropy parameters prediction network, so
# we only keep the elements in the "center"
p = params[:, n:n+1, :]
scales_hat = self.combine_ar_hp(torch.cat((p, ctx_p), dim=-1))
scales_hat = scales_hat.squeeze(1)
indexes = self.gaussian_conditional.build_indexes(scales_hat)
y_q = self.gaussian_conditional.quantize(y_hat[:, n, :], "symbols")
y_hat[:, n, :] = y_q
symbols_list.extend(y_q.squeeze().tolist())
indexes_list.extend(indexes.squeeze().tolist())
encoder.encode_with_indexes(
symbols_list,
indexes_list,
cdf,
cdf_lengths,
offsets,
)
string = encoder.flush()
return string
def decompress(self, strings, shape, dict_index, dict_weight, res, patch_res=None, dict_valid_index=None):
assert isinstance(strings, list) and len(strings) == 2
if self.context_model:
if next(self.parameters()).device != torch.device("cpu"):
warnings.warn(
"Inference on GPU is not recommended for the autoregressive "
"models (the entropy coder is run sequentially on CPU).",
stacklevel=2,
)
z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
data_res = res if patch_res is None else (res, patch_res)
########### apply h_s ###########
scales_hat = z_hat
for i in range(len(self.h_s)):
conv_res = type(data_res)(np.add(data_res, self._h_s_offset[i]))
scales_hat = self.h_s[i](scales_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
if self.context_model:
            s = 4**2  # y has 16x more nodes than z: two pooling stages of size 4 (pool_size_sqrt=2) in h_a
y_hat = torch.zeros(z_hat.size(0), s*shape, self.M, dtype=z_hat.dtype, device=z_hat.device)
conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
for i, y_string in enumerate(strings[0]):
self._decompress_ar(
y_string,
y_hat[i:i+1],
scales_hat[i:i+1],
dict_index[conv_res],
dict_weight[conv_res],
self.autoregressive.skipconn,
)
else:
indexes = self.gaussian_conditional.build_indexes(scales_hat)
y_hat = self.gaussian_conditional.decompress(strings[0], indexes)
########### apply g_s ###########
x_hat = y_hat
for i in range(len(self.g_s)):
conv_res = type(data_res)(np.add(data_res, self._g_s_offset[i]))
if (self.activation.upper()=='RB') or (self.attention and (self.g_s[i].attention_pos > 0)):
conv_res_ = type(data_res)(np.add(data_res, self._g_s_offset[i+1] if i<(len(self.g_s)-1) else self._g_a_offset[0]))
index_, weight_, valid_index_ = dict_index[conv_res_], dict_weight[conv_res_], dict_valid_index[conv_res_] if dict_valid_index is not None else None
else:
index_, weight_, valid_index_ = None, None, None
            x_hat = self.g_s[i](x_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None,
                                index_=index_, weight_=weight_, valid_index_=valid_index_)
        return {"x_hat": x_hat}
def _decompress_ar(self, y_string, y_hat, params, neighbors_indices, neighbors_weights, skipconn):
cdf = self.gaussian_conditional._quantized_cdf.tolist()
cdf_lengths = self.gaussian_conditional._cdf_length.tolist()
offsets = self.gaussian_conditional._offset.tolist()
neighbors_indices = neighbors_indices.to(y_hat.device)
neighbors_weights = neighbors_weights.to(y_hat.device)
decoder = RansDecoder()
decoder.set_stream(y_string)
n_nodes = y_hat.size(1)
weights = [conv.weight for conv in self.autoregressive.list_conv]
bias = [conv.bias for conv in self.autoregressive.list_conv]
# Warning: this is slow due to the auto-regressive nature of the decoding
        for n in range(n_nodes):
            ctx_p = sdpaconv_node_n(  # future nodes are still zero, so masked=False below
                torch.abs(y_hat) if self.arabs else y_hat,
weights,
bias,
n,
neighbors_indices,
neighbors_weights,
skipconn,
masked=False,
)
if self.arrelu:
ctx_p = torch.nn.functional.relu(ctx_p)
# 1x1 conv for the entropy parameters prediction network, so
# we only keep the elements in the "center"
p = params[:, n:n+1, :]
scales_hat = self.combine_ar_hp(torch.cat((p, ctx_p), dim=-1))
indexes = self.gaussian_conditional.build_indexes(scales_hat)
rv = decoder.decode_stream(
indexes.squeeze().tolist(),
cdf,
cdf_lengths,
offsets,
)
rv = torch.Tensor(rv).reshape(1, 1, -1)
rv = self.gaussian_conditional.dequantize(rv)