Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 2730e99a authored by PaulWawerek-L's avatar PaulWawerek-L
Browse files

corrected cm_ar (bias and convolution)

parent 13b5bb73
No related branches found
No related tags found
1 merge request!2OSLO-IC
......@@ -88,7 +88,7 @@ class SphereScaleHyperprior(SphereCompressionModel):
if context_model:
mask = 'full' if context_model=='full' else 'A'
self.autoregressive = SLB_Downsample(conv_name, M, M, hop=2, skip_conn_aggr=skip_conn_aggr,
activation="ReLU", activation_args={"inplace": True}, mask=mask)
activation=None, mask=mask)
self.combine_ar_hp = torch.nn.Sequential(
SLB_Downsample('SDPAConv', 2*M, M+256, hop=1, skip_conn_aggr=skip_conn_aggr,
activation="ReLU", activation_args={"inplace": True}, conv1x1=True),
......@@ -279,11 +279,14 @@ class SphereScaleHyperprior(SphereCompressionModel):
# buffer for SDPA conv, should only be calculated at current node n
xs = [torch.zeros(y_hat.size(0), n_neighbors+1 if i==0 else 1, convs[i].weight.size(-1), dtype=y_hat.dtype, device=y_hat.device) for i in range(len(convs))]
for k, k_ind in enumerate([n]+neighbors_indices[n].tolist()): # iteration over current node and its neighbors
if k_ind > n: continue
y_neighbors = torch.mul(y_hat[:, neighbors_indices[k_ind].tolist(), :], neighbors_weights[k_ind].view(1,-1,1))
xs[0][:, k, :] = torch.matmul(y_neighbors.flatten(1,2), convs[0].weight.data[1:].flatten(0,1))
if convs[0].bias is not None:
xs[0] += convs[0].bias.data
# compute second convolution at current node
xs[1][:, 0, :] = torch.matmul(xs[0][:, 1:, :].flatten(1,2), convs[1].weight.data[1:].flatten(0,1))
if convs[1].bias is not None:
xs[1] += convs[1].bias.data
# only keep current node
xs[0] = xs[0][:, 0:1, :]
ctx_p = skipconn(xs) if skipconn is not None else xs[-1]
......@@ -335,8 +338,7 @@ class SphereScaleHyperprior(SphereCompressionModel):
scales_hat = self.h_s[i](scales_hat, dict_index[conv_res], dict_weight[conv_res], valid_index=dict_valid_index[conv_res] if dict_valid_index is not None else None)
if self.context_model:
s = 4**2 # scaling factor between z and y
n_nodes = z_hat.size(1) * s
y_hat = torch.zeros(z_hat.size(0), n_nodes, self.M, dtype=z_hat.dtype, device=z_hat.device)
y_hat = torch.zeros(z_hat.size(0), s*shape, self.M, dtype=z_hat.dtype, device=z_hat.device)
conv_res = type(data_res)(np.add(data_res, self._g_s_offset[0]))
for i, y_string in enumerate(strings[0]):
self._decompress_ar(
......@@ -381,11 +383,14 @@ class SphereScaleHyperprior(SphereCompressionModel):
for n in range(n_nodes):
xs = [torch.zeros(y_hat.size(0), n_neighbors+1 if i==0 else 1, convs[i].weight.size(-1), dtype=y_hat.dtype, device=y_hat.device) for i in range(len(convs))]
for k, k_ind in enumerate([n]+neighbors_indices[n].tolist()): # iteration over current node and its neighbors
if k_ind > n: continue
y_neighbors = torch.mul(y_hat[:, neighbors_indices[k_ind].tolist(), :], neighbors_weights[k_ind].view(1,-1,1))
xs[0][:, k, :] = torch.matmul(y_neighbors.flatten(1,2), convs[0].weight.data[1:].flatten(0,1))
if convs[0].bias is not None:
xs[0] += convs[0].bias.data
# compute second convolution at current node
xs[1][:, 0, :] = torch.matmul(xs[0][:, 1:, :].flatten(1,2), convs[1].weight.data[1:].flatten(0,1))
if convs[1].bias is not None:
xs[1] += convs[1].bias.data
# only keep current node
xs[0] = xs[0][:, 0:1, :]
ctx_p = skipconn(xs) if skipconn is not None else xs[-1]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment