Source code for sc2bench.models.detection.transform

from typing import List, Tuple, Dict, Optional

from torch import Tensor
from torchdistill.datasets.util import build_transform
from torchvision.models.detection.image_list import ImageList
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.transforms.functional import to_pil_image, to_tensor, crop

from ...analysis import AnalyzableModule
from ...transforms.misc import AdaptivePad


class RCNNTransformWithCompression(GeneralizedRCNNTransform, AnalyzableModule):
    """
    An R-CNN Transform with codec-based or model-based compression.

    Each image is resized, compressed and decompressed (by a codec when
    `compression_model` is None, otherwise by `compression_model`), normalized,
    and finally batched.

    :param transform: performs the data transformation from the inputs to feed into the model.
        Its `min_size`, `max_size`, `image_mean`, and `image_std` attributes are used
        to initialize this transform
    :type transform: nn.Module
    :param device: torch device
    :type device: torch.device or str
    :param codec_params: codec parameters
    :type codec_params: dict
    :param analyzer_configs: list of analysis configurations
    :type analyzer_configs: list[dict]
    :param analyzes_after_compress: run analysis with `analyzer_configs` if True
    :type analyzes_after_compress: bool
    :param compression_model: compression model
    :type compression_model: nn.Module or None
    :param uses_cpu4compression_model: whether to use CPU instead of GPU for `compression_model`
    :type uses_cpu4compression_model: bool
    :param pre_transform_params: pre-transform parameters
    :type pre_transform_params: dict or None
    :param post_transform_params: post-transform parameters
    :type post_transform_params: dict or None
    :param adaptive_pad_kwargs: keyword arguments for AdaptivePad
    :type adaptive_pad_kwargs: dict or None
    """
    # Referred to https://github.com/pytorch/vision/blob/main/torchvision/models/detection/transform.py
    def __init__(self, transform, device, codec_params, analyzer_configs, analyzes_after_compress=False,
                 compression_model=None, uses_cpu4compression_model=False,
                 pre_transform_params=None, post_transform_params=None, adaptive_pad_kwargs=None):
        GeneralizedRCNNTransform.__init__(self, transform.min_size, transform.max_size,
                                          transform.image_mean, transform.image_std)
        AnalyzableModule.__init__(self, analyzer_configs)
        self.device = device
        self.codec_encoder_decoder = build_transform(codec_params)
        self.analyzes_after_compress = analyzes_after_compress
        self.pre_transform = build_transform(pre_transform_params)
        self.post_transform = build_transform(post_transform_params)
        # `compression_model` is optional (codec-only setups pass None), so only
        # move it to CPU when there actually is a model to move.
        if uses_cpu4compression_model and compression_model is not None:
            compression_model = compression_model.cpu()

        self.compression_model = compression_model
        self.uses_cpu4compression_model = uses_cpu4compression_model
        # AdaptivePad is only enabled when its keyword arguments are given as a dict.
        self.adaptive_pad = AdaptivePad(**adaptive_pad_kwargs) if isinstance(adaptive_pad_kwargs, dict) else None

    def compress_by_codec(self, org_img):
        """
        Convert a tensor to an image and compress-decompress it by codec.

        When the module is not in training mode, the file size reported by the codec
        is passed to `analyze`.

        :param org_img: image tensor
        :type org_img: torch.Tensor
        :return: compressed-and-decompressed image tensor
        :rtype: torch.Tensor
        """
        pil_img = to_pil_image(org_img, mode='RGB')
        # The codec transform returns both the reconstructed image and its compressed file size.
        pil_img, file_size = self.codec_encoder_decoder(pil_img)
        if not self.training:
            self.analyze(file_size)
        # Put the reconstructed tensor back on the device of the input tensor.
        return to_tensor(pil_img).to(org_img.device)

    def compress_by_model(self, org_img):
        """
        Convert a tensor to an image and compress-decompress it by model.

        If `adaptive_pad` is set, the input is padded before compression and the
        reconstruction is cropped back to the original height and width.

        :param org_img: image tensor
        :type org_img: torch.Tensor
        :return: compressed-and-decompressed image tensor
        :rtype: torch.Tensor
        """
        # Add a leading batch dimension for the compression model.
        org_img = org_img.unsqueeze(0)
        org_height, org_width = None, None
        if self.adaptive_pad is not None:
            # Remember the unpadded size so the padding can be cropped off after decompression.
            org_height, org_width = org_img.shape[-2:]
            org_img = self.adaptive_pad(org_img)

        compressed_obj = self.compression_model.compress(org_img)
        if not self.training and self.analyzes_after_compress:
            # When padding was applied, the original size is analyzed together with the compressed object.
            compressed_data = compressed_obj if org_height is None or org_width is None \
                else (compressed_obj, org_height, org_width)
            self.analyze(compressed_data)

        # `compressed_obj` is unpacked as keyword arguments, so it must be a mapping.
        decompressed_obj = self.compression_model.decompress(**compressed_obj)
        decompressed_obj = decompressed_obj['x_hat']
        if org_height is not None and org_width is not None:
            # Crop from the top-left corner back to the original size.
            decompressed_obj = crop(decompressed_obj, 0, 0, org_height, org_width)
        # Drop the batch dimension added above.
        return decompressed_obj.squeeze(0)

    def compress(self, org_img):
        """
        Apply `pre_transform` to an image tensor, compress and decompress it,
        and apply `post_transform` to the compressed-decompressed image tensor.

        :param org_img: image tensor
        :type org_img: torch.Tensor
        :return: compressed-and-decompressed image tensor
        :rtype: torch.Tensor
        """
        if self.pre_transform is not None:
            org_img = self.pre_transform(org_img)

        # Remember the device after `pre_transform` so the result can be moved back to it.
        org_device = org_img.device
        if self.uses_cpu4compression_model:
            org_img = org_img.cpu()

        # Without a compression model, fall back to the codec.
        org_img = self.compress_by_codec(org_img) if self.compression_model is None \
            else self.compress_by_model(org_img)
        if self.uses_cpu4compression_model:
            org_img = org_img.to(org_device)

        if self.post_transform is not None:
            org_img = self.post_transform(org_img)
        return org_img

    def forward(
            self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
    ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
        """
        Resize, compress-decompress, normalize, and batch images.

        :param images: list of image tensors, each of shape [C, H, W]
        :type images: list[torch.Tensor]
        :param targets: optional list of target dicts, one per image
        :type targets: list[dict[str, torch.Tensor]] or None
        :return: batched images with their per-image sizes, and the (copied) targets
        :rtype: (ImageList, list[dict[str, torch.Tensor]] or None)
        :raises ValueError: if an image is not a 3d tensor
        :raises AssertionError: if compression changes the shape of an image
        """
        # Work on copies so neither the caller's image list nor its target dicts are modified in-place.
        images = list(images)
        if targets is not None:
            targets = [dict(t) for t in targets]

        for i, image in enumerate(images):
            target_index = targets[i] if targets is not None else None
            if image.dim() != 3:
                raise ValueError("images is expected to be a list of 3d tensors "
                                 "of shape [C, H, W], got {}".format(image.shape))

            image, target_index = self.resize(image, target_index)
            # Compression happens between resize and normalize, and must be shape-preserving.
            shape_before_compression = image.shape
            image = self.compress(image)
            shape_after_compression = image.shape
            assert shape_after_compression == shape_before_compression, \
                'Compression should not change tensor shape {} -> {}'.format(shape_before_compression,
                                                                             shape_after_compression)
            image = self.normalize(image)
            images[i] = image
            if targets is not None and target_index is not None:
                targets[i] = target_index

        # Record per-image sizes before batching, then batch the images into a single tensor.
        image_sizes = [img.shape[-2:] for img in images]
        images = self.batch_images(images)
        image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            assert len(image_size) == 2
            image_sizes_list.append((image_size[0], image_size[1]))

        image_list = ImageList(images, image_sizes_list)
        return image_list, targets