import collections

import numpy as np
import torch
from PIL.Image import Image
from torch import nn
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from torchdistill.common import tensor_util
from torchdistill.datasets.registry import register_collate_func, register_transform
from torchvision.transforms import functional as F
from torchvision.transforms.functional import pad
MISC_TRANSFORM_MODULE_DICT = dict()


# NOTE: this decorator is used below but its definition was missing from this listing;
# the implementation here is reconstructed from that usage and the imported
# `register_transform` helper.
def register_misc_transform_module(cls):
    """
    Registers a miscellaneous transform module class.

    :param cls: miscellaneous transform module class to be registered
    :type cls: class
    :return: registered miscellaneous transform module class
    :rtype: class
    """
    MISC_TRANSFORM_MODULE_DICT[cls.__name__] = cls
    register_transform(cls)
    return cls

@register_collate_func
def default_collate_w_pil(batch):
"""
Puts each data field into a tensor or PIL Image with outer dimension batch size.
:param batch: single batch to be collated
:return: collated batch
"""
    # Extended `default_collate` function in PyTorch
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem._typed_storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate_w_pil([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate_w_pil([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate_w_pil(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate_w_pil(samples) for samples in transposed]
    elif isinstance(elem, Image):
        return batch

    raise TypeError(default_collate_err_msg_format.format(elem_type))
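
# --- Usage sketch (editor's addition, not part of the original module) ---
# `default_collate_w_pil` can be passed to a DataLoader as `collate_fn` so that
# PIL Images survive collation as a plain list instead of being stacked.
# `my_dataset` below is a hypothetical dataset yielding (PIL.Image.Image, int) pairs.
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(my_dataset, batch_size=4, collate_fn=default_collate_w_pil)
#   images, targets = next(iter(loader))
#   # images: list of 4 PIL Images, targets: int64 tensor of shape (4,)
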
@register_misc_transform_module
class AdaptivePad(nn.Module):
"""
A transform module that adaptively determines the size of padded sample.
:param fill: padded value
:type fill: int
:param padding_position: 'hw' (default) to pad left and right for padding horizontal size // 2 and top and
bottom for padding vertical size // 2; 'right_bottom' to pad bottom and right only
:type padding_position: str
:param padding_mode: padding mode passed to pad module
:type padding_mode: str
:param factor: factor value for the padded input sample
:type factor: int
:param returns_org_patch_size: if True, returns the patch size of the original input
:type returns_org_patch_size: bool
"""
    def __init__(self, fill=0, padding_position='hw', padding_mode='constant',
                 factor=128, returns_org_patch_size=False):
        super().__init__()
        self.fill = fill
        self.padding_position = padding_position
        self.padding_mode = padding_mode
        self.factor = factor
        self.returns_org_patch_size = returns_org_patch_size
    def forward(self, x):
        """
        Pads an image or image tensor so that its height and width become multiples of ``factor``.

        :param x: image or image tensor
        :type x: PIL.Image.Image or torch.Tensor
        :return: padded image or image tensor, paired with the patch size of the original input
                (height, width) if returns_org_patch_size=True
        :rtype: PIL.Image.Image or torch.Tensor or (PIL.Image.Image or torch.Tensor, tuple[int, int])
        """
        # `get_image_size` (torchvision >= 0.11) handles both PIL Images and tensors,
        # unlike `x.shape[-2:]`, which would fail for PIL Images
        width, height = F.get_image_size(x)
        vertical_pad_size = 0 if height % self.factor == 0 else int((height // self.factor + 1) * self.factor - height)
        horizontal_pad_size = 0 if width % self.factor == 0 else int((width // self.factor + 1) * self.factor - width)
        padded_vertical_size = vertical_pad_size + height
        padded_horizontal_size = horizontal_pad_size + width
        assert padded_vertical_size % self.factor == 0 and padded_horizontal_size % self.factor == 0, \
            'padded vertical and horizontal sizes ({}, {}) should be ' \
            'multiples of {}'.format(padded_vertical_size, padded_horizontal_size, self.factor)
        # 'hw' pads both sides as documented (the original condition checked 'equal_side',
        # which contradicted the docstring and the default value); any other value,
        # e.g. 'right_bottom', pads right and bottom only
        padding = [horizontal_pad_size // 2, vertical_pad_size // 2] if self.padding_position == 'hw' \
            else [0, 0, horizontal_pad_size, vertical_pad_size]
        x = pad(x, padding, self.fill, self.padding_mode)
        if self.returns_org_patch_size:
            return x, (height, width)
        return x
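
# --- Usage sketch (editor's addition, not part of the original module) ---
# Pads a tensor so its spatial size becomes a multiple of `factor`; with
# returns_org_patch_size=True, the original (height, width) is returned as well.
#
#   pad_module = AdaptivePad(factor=128, returns_org_patch_size=True)
#   x = torch.rand(3, 300, 500)
#   padded_x, org_size = pad_module(x)
#   # padded_x.shape == (3, 384, 512), org_size == (300, 500)
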
@register_misc_transform_module
class CustomToTensor(nn.Module):
"""
A customized ToTensor module that can be applied to sample and target selectively.
:param converts_sample: if True, applies to_tensor to sample
:type converts_sample: bool
:param converts_target: if True, applies torch.as_tensor to target
:type converts_target: bool
"""
    def __init__(self, converts_sample=True, converts_target=True):
        super().__init__()
        self.converts_sample = converts_sample
        self.converts_target = converts_target

    def __call__(self, image, target):
        if self.converts_sample:
            image = F.to_tensor(image)
        if self.converts_target:
            target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target
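
# --- Usage sketch (editor's addition, not part of the original module) ---
# Typical use in a segmentation pipeline: the sample becomes a float tensor in [0, 1],
# while the target mask is converted to an int64 tensor without rescaling.
# `pil_image` and `pil_mask` are hypothetical PIL Images.
#
#   transform = CustomToTensor(converts_sample=True, converts_target=True)
#   image_tensor, target_tensor = transform(pil_image, pil_mask)
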
@register_misc_transform_module
class SimpleQuantizer(nn.Module):
"""
A module to quantize tensor with its half() function if num_bits=16 (FP16) or
Jacob et al.'s method if num_bits=8 (INT8 + one FP32 scale parameter).
Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko: `"Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" <https://openaccess.thecvf.com/content_cvpr_2018/html/Jacob_Quantization_and_Training_CVPR_2018_paper.html>`_ @ CVPR 2018 (2018)
:param num_bits: number of bits for quantization
:type num_bits: int
"""
    def __init__(self, num_bits):
        super().__init__()
        self.num_bits = num_bits
    def forward(self, z):
        """
        Quantizes a tensor.

        :param z: tensor to be quantized
        :type z: torch.Tensor
        :return: quantized tensor
        :rtype: torch.Tensor or torchdistill.common.tensor_util.QuantizedTensor
        """
        return z.half() if self.num_bits == 16 else tensor_util.quantize_tensor(z, self.num_bits)
@register_misc_transform_module
class SimpleDequantizer(nn.Module):
"""
A module to dequantize quantized tensor in FP32. If num_bits=8, it uses Jacob et al.'s method.
Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko: `"Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" <https://openaccess.thecvf.com/content_cvpr_2018/html/Jacob_Quantization_and_Training_CVPR_2018_paper.html>`_ @ CVPR 2018 (2018)
:param num_bits: number of bits used for quantization
:type num_bits: int
"""
    def __init__(self, num_bits):
        super().__init__()
        self.num_bits = num_bits
    def forward(self, z):
        """
        Dequantizes a quantized tensor.

        :param z: quantized tensor
        :type z: torch.Tensor or torchdistill.common.tensor_util.QuantizedTensor
        :return: dequantized tensor
        :rtype: torch.Tensor
        """
        return z.float() if self.num_bits == 16 else tensor_util.dequantize_tensor(z)
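
# --- Usage sketch (editor's addition, not part of the original module) ---
# Quantize/dequantize round trip; the INT8 path goes through
# `torchdistill.common.tensor_util.quantize_tensor` / `dequantize_tensor`:
#
#   quantizer, dequantizer = SimpleQuantizer(num_bits=8), SimpleDequantizer(num_bits=8)
#   z = torch.rand(1, 64, 32, 32)
#   z_hat = dequantizer(quantizer(z))  # FP32 tensor approximating z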