Source code for sc2bench.loss

from torch import nn
from torchdistill.losses.mid_level import register_mid_level_loss


@register_mid_level_loss
class BppLoss(nn.Module):
    """
    Rate loss measured in bits, optionally normalized to bits per pixel (bpp).

    The loss looks up the recorded output of a single entropy module in the io_dict of the
    model being trained and accumulates ``-log2`` of the likelihoods found there.

    :param entropy_module_path: key in io_dict identifying the entropy module whose output is used
    :type entropy_module_path: str
    :param reduction: 'sum' returns the total, 'batchmean' divides the total by the batch size,
        and any other value (including the default 'mean') divides it by the number of pixels
    :type reduction: str or None
    """

    def __init__(self, entropy_module_path, reduction='mean'):
        super().__init__()
        self.entropy_module_path = entropy_module_path
        self.reduction = reduction

    def forward(self, student_io_dict, *args, **kwargs):
        """
        Computes the rate loss from the entropy module's recorded output.

        The entry ``student_io_dict[entropy_module_path]['output']`` is unpacked as a pair
        ``(features, likelihoods)``; the feature tensor must be 4-D since its shape is
        unpacked into four values (treated here as batch, channel, height, width).

        :param student_io_dict: io_dict of model to be trained
        :type student_io_dict: dict
        :return: reduced rate value
        """
        features, likelihoods = student_io_dict[self.entropy_module_path]['output']
        batch_size, _, height, width = features.shape

        # Every reduction mode starts from the same total; only the divisor differs.
        total_bits = -likelihoods.log2().sum()

        if self.reduction == 'sum':
            return total_bits
        if self.reduction == 'batchmean':
            return total_bits / batch_size
        # Fallback for 'mean' and any unrecognized value: normalize per pixel across the batch.
        return total_bits / (batch_size * height * width)