Source code for sc2bench.models.detection.rcnn

import torch
from torch.hub import load_state_dict_from_url
from torchdistill.common.main_util import load_ckpt
from torchvision.models.detection._utils import overwrite_eps
from torchvision.models.detection.faster_rcnn import FasterRCNN, model_urls as faster_rcnn_model_urls
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import LastLevelMaxPool

from .base import UpdatableDetectionModel, UpdatableBackboneWithFPN
from .registry import register_detection_model_func
from ..backbone import check_if_updatable
from ..registry import load_classification_model
from ...analysis import check_if_analyzable


[docs]class BaseRCNN(GeneralizedRCNN, UpdatableDetectionModel): """ A base, updatable R-CNN model. :param rcnn_model: backbone model (usually a classification model) :type rcnn_model: nn.Module :param analysis_config: analysis configuration :type analysis_config: dict or None """ # Referred to https://github.com/pytorch/vision/blob/main/torchvision/models/detection/generalized_rcnn.py def __init__(self, rcnn_model, analysis_config=None): if analysis_config is None: analysis_config = dict() UpdatableDetectionModel.__init__(self, analysis_config.get('analyzer_configs', list())) GeneralizedRCNN.__init__(self, rcnn_model.backbone, rcnn_model.rpn, rcnn_model.roi_heads, rcnn_model.transform)
[docs] def update(self, **kwargs): """ Updates compression-specific parameters like `CompressAI models do <https://interdigitalinc.github.io/CompressAI/models.html#compressai.models.CompressionModel.update>`_. Needs to be called once after training to be able to later perform the evaluation with an actual entropy coder. """ if not check_if_updatable(self.backbone.body): raise KeyError(f'`backbone` {type(self)} is not updatable') self.backbone.body.update()
[docs] def get_aux_module(self, **kwargs): """ Returns an auxiliary module to compute auxiliary loss if necessary like `CompressAI models do <https://interdigitalinc.github.io/CompressAI/models.html#compressai.models.CompressionModel.aux_loss>`_. :return: auxiliary module :rtype: nn.Module """ return self.backbone.body.get_aux_module()
[docs] def activate_analysis(self): """ Activates the analysis mode. Should be called after training model. """ self.activated_analysis = True if check_if_analyzable(self.backbone.body): self.backbone.body.activate_analysis()
[docs] def deactivate_analysis(self): """ Deactivates the analysis mode. """ self.activated_analysis = False self.backbone.body.deactivate_analysis() if check_if_analyzable(self.backbone.body): self.backbone.body.deactivate_analysis()
[docs] def analyze(self, compressed_obj): """ Analyzes a given compressed object (e.g., file size of the compressed object). :param compressed_obj: compressed object to be analyzed :type compressed_obj: Any """ if not self.activated_analysis: return for analyzer in self.analyzers: analyzer.analyze(compressed_obj) if check_if_analyzable(self.backbone.body): self.backbone.body.analyze(compressed_obj)
[docs] def summarize(self): """ Summarizes the results that the configured analyzers store. """ for analyzer in self.analyzers: analyzer.summarize() if check_if_analyzable(self.backbone.body): self.backbone.body.summarize()
[docs] def clear_analysis(self): """ Clears the results that the configured analyzers store. """ for analyzer in self.analyzers: analyzer.clear() if check_if_analyzable(self.backbone.body): self.backbone.body.clear_analysis()
[docs]def create_faster_rcnn_fpn(backbone, extra_blocks=None, return_layer_dict=None, in_channels_list=None, in_channels_stage2=None, out_channels=256, returned_layers=None, num_classes=91, analysis_config=None, analyzable_layer_key=None, **kwargs): """ Builds Faster R-CNN model using a given updatable backbone model. :param backbone: backbone model (usually a classification model) :type backbone: nn.Module :param extra_blocks: if provided, extra operations will be performed. It is expected to take the fpn features, the original features and the names of the original features as input, and returns a new list of feature maps and their corresponding names :type extra_blocks: ExtraFPNBlock or None :param return_layer_dict: mapping from name of module to return its output to a specified key :type return_layer_dict: dict :param in_channels_list: number of channels for each feature map that is passed to the module for FPN :type in_channels_list: list[int] or None :param in_channels_stage2: base number of channels used to define `in_channels_list` if `in_channels_list` is `None` :type in_channels_stage2: int or None :param out_channels: number of channels of the FPN representation :type out_channels: int :param returned_layers: list of layer numbers to define `return_layer_dict` if `return_layer_dict` is `None` :type returned_layers: list[int] or None :param num_classes: number of output classes of the model (including the background) :type num_classes: int :param analysis_config: analysis configuration :type analysis_config: dict or None :param analyzable_layer_key: key of analyzable layer :type analyzable_layer_key: str or None :return: Faster R-CNN model with backbone model with FPN :rtype: torchvision.models.detection.faster_rcnn.FasterRCNN """ if analysis_config is None: analysis_config = dict() if extra_blocks is None: extra_blocks = LastLevelMaxPool() if returned_layers is None: returned_layers = [1, 2, 3, 4] if return_layer_dict is None: return_layer_dict = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)} if in_channels_stage2 is None: in_channels_stage2 = backbone.inplanes // 8 if in_channels_list is None: in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers] backbone_with_fpn = \ UpdatableBackboneWithFPN(backbone, return_layer_dict, in_channels_list, out_channels, extra_blocks=extra_blocks, analyzable_layer_key=analyzable_layer_key, **analysis_config) return FasterRCNN(backbone_with_fpn, num_classes, **kwargs)
def _process_torchvision_pretrained_weights(model, pretrained_backbone_name, progress): base_backbone_name = 'resnet50' if pretrained_backbone_name == 'mobilenet_v3_large_320': base_backbone_name = 'mobilenet_v3_large_320' elif pretrained_backbone_name == 'mobilenet_v3_large': base_backbone_name = 'mobilenet_v3_large' state_dict = \ load_state_dict_from_url(faster_rcnn_model_urls['fasterrcnn_{}_fpn_coco'.format(base_backbone_name)], progress=progress) model.load_state_dict(state_dict, strict=False) if pretrained_backbone_name == 'resnet50': overwrite_eps(model, 0.0)
[docs]@register_detection_model_func def faster_rcnn_model(backbone_config, pretrained=True, pretrained_backbone_name=None, progress=True, backbone_fpn_kwargs=None, num_classes=91, analysis_config=None, start_ckpt_file_path=None, **kwargs): """ Builds Faster R-CNN model. :param backbone_config: backbone configuration :type backbone_config: dict :param pretrained: if True, returns a model pre-trained on COCO train2017 (torchvision) :type pretrained: bool :param pretrained_backbone_name: pretrained backbone name such as `'resnet50'`, `'mobilenet_v3_large_320'`, and `'mobilenet_v3_large'` :type pretrained_backbone_name: str :param progress: if True, displays a progress bar of the download to stderr :type progress: bool :param backbone_fpn_kwargs: keyword arguments for `create_faster_rcnn_fpn` :type backbone_fpn_kwargs: dict :param num_classes: number of output classes of the model (including the background) :type num_classes: int :param analysis_config: analysis configuration :type analysis_config: dict or None :param start_ckpt_file_path: checkpoint file path to be loaded for the built Faster R-CNN model :type start_ckpt_file_path: str or None :return: Faster R-CNN model with splittable backbone model and FPN :rtype: BaseRCNN """ if backbone_fpn_kwargs is None: backbone_fpn_kwargs = dict() if analysis_config is None: analysis_config = dict() backbone_config['params']['norm_layer'] = misc_nn_ops.FrozenBatchNorm2d backbone = load_classification_model(backbone_config, torch.device('cpu'), False, strict=False) rcnn_model = create_faster_rcnn_fpn(backbone, num_classes=num_classes, **backbone_fpn_kwargs, **kwargs) model = BaseRCNN(rcnn_model, analysis_config=analysis_config) if pretrained and pretrained_backbone_name in ('resnet50', 'mobilenet_v3_large_320', 'mobilenet_v3_large'): _process_torchvision_pretrained_weights(model, pretrained_backbone_name, progress) if start_ckpt_file_path is not None: load_ckpt(start_ckpt_file_path, model=model, strict=False) return model