from typing import Dict, Optional, List
from torch import nn, Tensor
from torchdistill.common.constant import def_logger
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
from ..backbone import FeatureExtractionBackbone
from ...analysis import AnalyzableModule
# Module-level logger: a child of torchdistill's default logger, named after this module.
logger = def_logger.getChild(__name__)
class UpdatableDetectionModel(AnalyzableModule):
    """
    Abstract base for object detection models whose compression-specific parameters
    can be updated after training, in the style of CompressAI models.

    :param analyzer_configs: list of analysis configurations
    :type analyzer_configs: list[dict]
    """

    def __init__(self, analyzer_configs=None):
        super(UpdatableDetectionModel, self).__init__(analyzer_configs)
        # Tracks whether ``update`` has been run; subclasses are expected to set it to True there.
        self.bottleneck_updated = False

    def forward(self, *args, **kwargs):
        """
        Forward pass; every concrete subclass must override this.
        """
        raise NotImplementedError

    def update(self, **kwargs):
        """
        Updates compression-specific parameters like `CompressAI models do <https://interdigitalinc.github.io/CompressAI/models.html#compressai.models.CompressionModel.update>`_.

        Concrete subclasses must override this.
        """
        raise NotImplementedError

    def get_aux_module(self, **kwargs):
        """
        Returns an auxiliary module to compute auxiliary loss if necessary like `CompressAI models do <https://interdigitalinc.github.io/CompressAI/models.html#compressai.models.CompressionModel.aux_loss>`_.

        Concrete subclasses must override this.
        """
        raise NotImplementedError
class UpdatableBackboneWithFPN(UpdatableDetectionModel):
    """
    An updatable backbone model with feature pyramid network (FPN).

    :param backbone: backbone model (usually a classification model)
    :type backbone: nn.Module
    :param return_layer_dict: mapping from name of module to return its output to a specified key
    :type return_layer_dict: dict
    :param in_channels_list: number of channels for each feature map that is passed to the module for FPN
    :type in_channels_list: list[int]
    :param out_channels: number of channels of the FPN representation
    :type out_channels: int
    :param extra_blocks: if provided, extra operations will be performed. It is expected to take the fpn
        features, the original features and the names of the original features as input, and returns a new
        list of feature maps and their corresponding names. ``LastLevelMaxPool()`` is used when None.
    :type extra_blocks: ExtraFPNBlock or None
    :param analyzer_configs: list of analysis configurations (handed to the feature extraction backbone)
    :type analyzer_configs: list[dict] or None
    :param analyzes_after_compress: run analysis with `analyzer_configs` if True
    :type analyzes_after_compress: bool
    :param analyzable_layer_key: key of analyzable layer
    :type analyzable_layer_key: str or None
    """

    # Referred to https://github.com/pytorch/vision/blob/main/torchvision/models/detection/backbone_utils.py
    def __init__(
            self,
            backbone: nn.Module,
            return_layer_dict: Dict[str, str],
            in_channels_list: List[int],
            out_channels: int,
            extra_blocks: Optional[ExtraFPNBlock] = None,
            analyzer_configs: Optional[List[Dict]] = None,
            analyzes_after_compress: bool = False,
            analyzable_layer_key: Optional[str] = None
    ) -> None:
        super().__init__()
        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()
        if analyzer_configs is None:
            analyzer_configs = list()

        self.body = FeatureExtractionBackbone(backbone, return_layer_dict=return_layer_dict,
                                              analyzer_configs=analyzer_configs,
                                              analyzes_after_compress=analyzes_after_compress,
                                              analyzable_layer_key=analyzable_layer_key)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
        )
        self.out_channels = out_channels
        # Also keep the key on this wrapper: ``check_if_updatable`` reads it, and nn.Module raises
        # AttributeError for attributes that were never assigned.
        self.analyzable_layer_key = analyzable_layer_key

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """
        Extracts intermediate features with the backbone and feeds them to the FPN.

        :param x: input passed to the feature extraction backbone
        :type x: torch.Tensor
        :return: FPN feature maps keyed by name, as produced by ``self.fpn``
        :rtype: dict[str, torch.Tensor]
        """
        x = self.body(x)
        x = self.fpn(x)
        return x

    def check_if_updatable(self):
        """
        Checks if this module is updatable with respect to CompressAI modules.

        :return: True if the model is updatable, False otherwise
        :rtype: bool
        """
        # getattr keeps instances pickled before the attribute existed working (they report False).
        key = getattr(self, 'analyzable_layer_key', None)
        if key is None:
            return False
        # This wrapper's own ``_modules`` only contain ``body`` and ``fpn``; the analyzable layer lives in
        # the feature extraction backbone.
        # NOTE(review): assumes FeatureExtractionBackbone registers backbone layers under their own names.
        return key in self.body._modules

    def update(self, **kwargs):
        """
        Updates compression-specific parameters like `CompressAI models do <https://interdigitalinc.github.io/CompressAI/models.html#compressai.models.CompressionModel.update>`_.

        Needs to be called once after training to be able to later perform the evaluation with an actual
        entropy coder.

        :param kwargs: forwarded to ``self.body.update``; accepted so the signature matches
            ``UpdatableDetectionModel.update``
        """
        self.body.update(**kwargs)
        self.bottleneck_updated = True

    def get_aux_module(self, **kwargs):
        """
        Returns an auxiliary module to compute auxiliary loss if necessary like `CompressAI models do <https://interdigitalinc.github.io/CompressAI/models.html#compressai.models.CompressionModel.aux_loss>`_.

        :param kwargs: forwarded to ``self.body.get_aux_module``; accepted so the signature matches
            ``UpdatableDetectionModel.get_aux_module``
        :return: auxiliary module, as returned by the feature extraction backbone
        :rtype: nn.Module
        """
        return self.body.get_aux_module(**kwargs)
def check_if_updatable_detection_model(model):
    """
    Tells whether an object detection model supports CompressAI-style updates.

    :param model: object detection model
    :type model: nn.Module
    :return: True if ``model`` is an ``UpdatableDetectionModel`` (or subclass) instance, False otherwise
    :rtype: bool
    """
    is_updatable = isinstance(model, UpdatableDetectionModel)
    return is_updatable