import torch.nn as nn

from spherical_models import SphereEntropyBottleneck
from compressai.models.utils import update_registered_buffers


class SphereCompressionModel(nn.Module):
    """Base class for constructing an auto-encoder with at least one entropy
    bottleneck module.

    Args:
        entropy_bottleneck_channels (int): Number of channels of the entropy
            bottleneck
        init_weights (bool): apply Kaiming-normal initialization to the conv
            layers of the model (default: True)
    """

    def __init__(self, entropy_bottleneck_channels, init_weights=True):
        super().__init__()
        self.entropy_bottleneck = SphereEntropyBottleneck(entropy_bottleneck_channels)

        if init_weights:
            self._initialize_weights()

    def aux_loss(self):
        """Return the aggregated loss over the auxiliary entropy bottleneck
        module(s).
        """
        # Scan all (possibly nested) submodules so subclasses that add extra
        # bottlenecks deeper in the hierarchy are included.
        aux_loss = sum(
            m.loss() for m in self.modules() if isinstance(m, SphereEntropyBottleneck)
        )
        return aux_loss

    def _initialize_weights(self):
        # Kaiming-normal weights / zero biases for every conv layer.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, *args):
        # Subclasses must implement the actual encode/decode pipeline.
        raise NotImplementedError()

    def update(self, force=False):
        """Updates the entropy bottleneck(s) CDF values.

        Needs to be called once after training to be able to later perform the
        evaluation with an actual entropy coder.

        Args:
            force (bool): overwrite previous values (default: False)

        Returns:
            updated (bool): True if one of the EntropyBottlenecks was updated.
        """
        updated = False
        # BUGFIX: was `self.children()`, which only visits *direct* children
        # and silently skips bottlenecks nested inside submodules. `aux_loss`
        # already scans `self.modules()`; do the same here so every bottleneck
        # that contributes to the loss also gets its CDFs updated.
        for m in self.modules():
            if not isinstance(m, SphereEntropyBottleneck):
                continue
            rv = m.update(force=force)
            updated |= rv
        return updated

    def load_state_dict(self, state_dict, strict=True):
        """Load a state dict, resizing the entropy-bottleneck CDF buffers
        first so their shapes match the incoming tensors.

        Args:
            state_dict (dict): state dict to load
            strict (bool): forwarded to ``nn.Module.load_state_dict``
                (default: True). The original override dropped this standard
                keyword; restoring it keeps the base-class contract.
        """
        # Dynamically update the entropy bottleneck buffers related to the CDFs
        update_registered_buffers(
            self.entropy_bottleneck,
            "entropy_bottleneck",
            ["_quantized_cdf", "_offset", "_cdf_length"],
            state_dict,
        )
        return super().load_state_dict(state_dict, strict=strict)