# Source code for sc2bench.transforms.collator

import torch
from torchdistill.datasets.collator import register_collate_func


def cat_list(images, fill_value=0):
    """
    Stacks a sequence of image tensors into one batch tensor, padding each image
    up to the per-dimension maximum size found in the sequence.

    Each image is copied into the top-left corner (last two dimensions) of its slot,
    and the remaining area of the slot keeps the fill value.

    :param images: image tensors to be batched
    :type images: list[torch.Tensor] or tuple[torch.Tensor]
    :param fill_value: value to be filled in the padded area
    :type fill_value: int
    :return: batched tensor of shape (len(images), \\*max_size) with the dtype and device
        of the first image; the input itself is returned unchanged if it consists of
        a single non-tensor item
    :rtype: torch.Tensor
    """
    # A single non-tensor item has no shape to pad against, so it is passed through as is.
    if len(images) == 1 and not isinstance(images[0], torch.Tensor):
        return images

    # Element-wise maximum over all image shapes, e.g. (C, max_H, max_W).
    max_size = tuple(max(dims) for dims in zip(*(img.shape for img in images)))
    batch_shape = (len(images), *max_size)
    # new_full keeps the dtype/device of the first image and pre-fills the padding value.
    batched_imgs = images[0].new_full(batch_shape, fill_value)
    for img, pad_img in zip(images, batched_imgs):
        # pad_img is a view into batched_imgs, so the in-place copy writes into the batch.
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    return batched_imgs
@register_collate_func
def pascal_seg_collate_fn(batch):
    """
    Collates (image, target, supp_dict) samples for PASCAL VOC 2012 segmentation.

    :param batch: list/tuple of triplets (image, target, supp_dict), where supp_dict can be an empty dict
    :type batch: list or tuple
    :return: collated images, collated targets, and the per-sample supplementary dicts
    :rtype: (torch.Tensor, torch.Tensor, tuple[dict])
    """
    # Transpose the batch: one sequence per field instead of one triplet per sample.
    image_seq, target_seq, supp_dicts = zip(*batch)
    # Images are padded with 0 and targets with 255.
    padded_images = cat_list(image_seq, fill_value=0)
    padded_targets = cat_list(target_seq, fill_value=255)
    return padded_images, padded_targets, supp_dicts
@register_collate_func
def pascal_seg_eval_collate_fn(batch):
    """
    Collates (image, target) samples for PASCAL VOC 2012 segmentation in evaluation.

    :param batch: list/tuple of tuples (image, target)
    :type batch: list or tuple
    :return: collated images and collated targets
    :rtype: (torch.Tensor, torch.Tensor)
    """
    # Transpose the batch: one sequence per field instead of one pair per sample.
    image_seq, target_seq = zip(*batch)
    # Images are padded with 0 and targets with 255.
    padded_images = cat_list(image_seq, fill_value=0)
    padded_targets = cat_list(target_seq, fill_value=255)
    return padded_images, padded_targets