Source code for sc2bench.loss

from torch import nn
from torchdistill.losses.single import register_single_loss
from torchdistill.losses.util import register_func2extract_org_output


[docs]@register_func2extract_org_output def extract_org_loss_dict(org_criterion, student_outputs, teacher_outputs, targets, uses_teacher_output, **kwargs): """ Extracts loss(es) from student_outputs inside `TrainingBox` or `DistillationBox` in `torchdistill`. :param org_criterion: not used :type org_criterion: nn.Module :param student_outputs: student models' output :type student_outputs: dict or Any :param teacher_outputs: not used :type teacher_outputs: Any :param targets: not used :type targets: Any :param uses_teacher_output: not used :type uses_teacher_output: bool :return: original loss dict :rtype: class """ org_loss_dict = dict() if isinstance(student_outputs, dict): org_loss_dict.update(student_outputs) return org_loss_dict
[docs]@register_func2extract_org_output def extract_org_segment_loss(org_criterion, student_outputs, teacher_outputs, targets, uses_teacher_output, **kwargs): """ Computes loss(es) using the original loss module inside `TrainingBox` or `DistillationBox` in `torchdistill` for semantic segmentation models in `torchvision`. :param org_criterion: original loss module :type org_criterion: nn.Module :param student_outputs: student models' output :type student_outputs: dict or Any :param teacher_outputs: not used :type teacher_outputs: Any :param targets: targets :type targets: Any :param uses_teacher_output: not used :type uses_teacher_output: bool :return: original loss dict :rtype: class """ org_loss_dict = dict() if isinstance(student_outputs, dict): sub_loss_dict = dict() for key, outputs in student_outputs.items(): sub_loss_dict[key] = org_criterion(outputs, targets) org_loss = sub_loss_dict['out'] if 'aux' in sub_loss_dict: org_loss += 0.5 * sub_loss_dict['aux'] org_loss_dict['total'] = org_loss return org_loss_dict
[docs]@register_single_loss class BppLoss(nn.Module): """ Bit-per-pixel (or rate) loss. :param entropy_module_path: entropy module path to extract its output from io_dict :type entropy_module_path: str :param reduction: reduction type ('sum', 'batchmean', or 'mean') :type reduction: str or None """ def __init__(self, entropy_module_path, reduction='mean'): super().__init__() self.entropy_module_path = entropy_module_path self.reduction = reduction
[docs] def forward(self, student_io_dict, *args, **kwargs): """ Computes a rate loss. :param student_io_dict: io_dict of model to be trained :type student_io_dict: dict """ entropy_module_dict = student_io_dict[self.entropy_module_path] intermediate_features, likelihoods = entropy_module_dict['output'] n, _, h, w = intermediate_features.shape num_pixels = n * h * w if self.reduction == 'sum': bpp = -likelihoods.log2().sum() elif self.reduction == 'batchmean': bpp = -likelihoods.log2().sum() / n else: bpp = -likelihoods.log2().sum() / num_pixels return bpp