Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch.nn as nn
from spherical_models import SphereEntropyBottleneck
from compressai.models.utils import update_registered_buffers
class SphereCompressionModel(nn.Module):
    """Base class for constructing an auto-encoder with at least one entropy
    bottleneck module.

    Args:
        entropy_bottleneck_channels (int): Number of channels of the entropy
            bottleneck
        init_weights (bool): apply Kaiming initialization to all (transposed)
            conv layers at construction time (default: True)
    """

    def __init__(self, entropy_bottleneck_channels, init_weights=True):
        super().__init__()
        self.entropy_bottleneck = SphereEntropyBottleneck(entropy_bottleneck_channels)
        if init_weights:
            self._initialize_weights()

    def aux_loss(self):
        """Return the aggregated loss over the auxiliary entropy bottleneck
        module(s).
        """
        # Recursively scan all submodules so subclasses that own several
        # bottlenecks contribute every one of them to the auxiliary loss.
        return sum(
            m.loss() for m in self.modules() if isinstance(m, SphereEntropyBottleneck)
        )

    def _initialize_weights(self):
        # Kaiming (He) normal initialization for every (transposed) conv
        # layer; biases, when present, are zeroed.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, *args):
        # Subclasses must implement the actual encoder/decoder pipeline.
        raise NotImplementedError()

    def update(self, force=False):
        """Updates the entropy bottleneck(s) CDF values.

        Needs to be called once after training to be able to later perform the
        evaluation with an actual entropy coder.

        Args:
            force (bool): overwrite previous values (default: False)

        Returns:
            updated (bool): True if one of the EntropyBottlenecks was updated.
        """
        updated = False
        # NOTE: only direct children are inspected (self.children()), matching
        # the original behavior; bottlenecks nested deeper must be updated by
        # their owning submodule.
        for m in self.children():
            if not isinstance(m, SphereEntropyBottleneck):
                continue
            updated |= m.update(force=force)
        return updated

    def load_state_dict(self, state_dict, strict=True):
        """Load parameters after resizing the entropy bottleneck CDF buffers.

        Fixes vs. original: accepts and forwards ``strict`` (the base-class
        keyword was previously dropped by this override) and returns the
        missing/unexpected-keys result from ``nn.Module.load_state_dict``
        instead of discarding it. Both changes are backward-compatible.

        Args:
            state_dict (dict): a state dict as produced by ``state_dict()``
            strict (bool): enforce exact key matching (default: True)

        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys``.
        """
        # Dynamically resize the CDF-related buffers first: their shapes
        # depend on the trained model, so they must match the incoming
        # state_dict before the base-class shape checks run.
        update_registered_buffers(
            self.entropy_bottleneck,
            "entropy_bottleneck",
            ["_quantized_cdf", "_offset", "_cdf_length"],
            state_dict,
        )
        return super().load_state_dict(state_dict, strict=strict)