Source code for sc2bench.models.detection.wrapper
import torch
from torchdistill.common.main_util import load_ckpt
from .registry import load_detection_model
from .transform import RCNNTransformWithCompression
from ..registry import get_compression_model
from ..wrapper import register_wrapper_class, WRAPPER_CLASS_DICT
from ...analysis import AnalyzableModule, check_if_analyzable
[docs]
@register_wrapper_class
class InputCompressionDetectionModel(AnalyzableModule):
    """
    A wrapper module for input compression model followed by a detection model.

    On construction, the ``transform`` attribute of the given detection model is replaced in place
    with a :class:`RCNNTransformWithCompression` built around the original transform, so the caller's
    ``detection_model`` object is modified. The analysis-related methods of this wrapper forward to
    that transform whenever ``check_if_analyzable`` accepts it.

    :param detection_model: object detection model
    :type detection_model: nn.Module
    :param device: torch device
    :type device: torch.device or str
    :param codec_encoder_decoder: transform sequence configuration for codec
    :type codec_encoder_decoder: nn.Module or None
    :param compression_model: compression model
    :type compression_model: nn.Module or None
    :param uses_cpu4compression_model: whether to use CPU instead of GPU for `compression_model`
    :type uses_cpu4compression_model: bool
    :param pre_transform: pre-transform
    :type pre_transform: nn.Module or None
    :param post_transform: post-transform
    :type post_transform: nn.Module or None
    :param analysis_config: analysis configuration
    :type analysis_config: dict or None
    :param adaptive_pad_kwargs: keyword arguments for AdaptivePad
    :type adaptive_pad_kwargs: dict or None
    :param kwargs: extra keyword arguments; accepted but not used by this constructor
    :type kwargs: dict
    """

    def __init__(self, detection_model, device, codec_encoder_decoder=None, compression_model=None,
                 uses_cpu4compression_model=False, pre_transform=None, post_transform=None,
                 analysis_config=None, adaptive_pad_kwargs=None, **kwargs):
        if analysis_config is None:
            analysis_config = dict()

        # The base initializer is called without arguments; the analyzer configs are handed
        # to the transform below rather than to this wrapper.
        super().__init__()
        # Swap the detector's transform for one that also carries the codec / compression model.
        detection_model.transform = \
            RCNNTransformWithCompression(
                detection_model.transform, device, codec_encoder_decoder,
                analysis_config.get('analyzer_configs', list()),
                analyzes_after_compress=analysis_config.get('analyzes_after_compress', False),
                compression_model=compression_model, uses_cpu4compression_model=uses_cpu4compression_model,
                pre_transform=pre_transform, post_transform=post_transform,
                adaptive_pad_kwargs=adaptive_pad_kwargs
            )
        self.device = device
        self.uses_cpu4compression_model = uses_cpu4compression_model
        self.detection_model = detection_model

    def use_cpu4compression(self):
        """
        Changes the device of the compression model to CPU.

        Does nothing unless `uses_cpu4compression_model` is set and the transform holds a compression model.
        """
        if self.uses_cpu4compression_model and self.detection_model.transform.compression_model is not None:
            self.detection_model.transform.compression_model = self.detection_model.transform.compression_model.cpu()

    def forward(self, x):
        """
        Feeds the input to the wrapped detection model and returns its output unchanged.

        :param x: input passed as-is to the detection model
        :return: whatever the detection model returns
        """
        return self.detection_model(x)

    def activate_analysis(self):
        """
        Turns on the analysis flag of this module and, if analyzable, of the detection model's transform.
        """
        self.activated_analysis = True
        if check_if_analyzable(self.detection_model.transform):
            self.detection_model.transform.activate_analysis()

    def deactivate_analysis(self):
        """
        Turns off the analysis flag of this module and, if analyzable, of the detection model's transform.
        """
        self.activated_analysis = False
        # Only the guarded call is kept: calling the transform unconditionally as well would bypass
        # the analyzability check used by every other analysis method and invoke it twice.
        if check_if_analyzable(self.detection_model.transform):
            self.detection_model.transform.deactivate_analysis()

    def analyze(self, compressed_obj):
        """
        Passes a compressed object to this module's analyzers and, if analyzable, to the transform.

        Returns immediately when analysis is not activated.

        :param compressed_obj: compressed object handed to each analyzer
        """
        if not self.activated_analysis:
            return

        for analyzer in self.analyzers:
            analyzer.analyze(compressed_obj)

        if check_if_analyzable(self.detection_model.transform):
            self.detection_model.transform.analyze(compressed_obj)

    def summarize(self):
        """
        Calls ``summarize`` on this module's analyzers and, if analyzable, on the transform.
        """
        for analyzer in self.analyzers:
            analyzer.summarize()

        if check_if_analyzable(self.detection_model.transform):
            self.detection_model.transform.summarize()

    def clear_analysis(self):
        """
        Calls ``clear`` on this module's analyzers and, if analyzable, ``clear_analysis`` on the transform.
        """
        for analyzer in self.analyzers:
            analyzer.clear()

        if check_if_analyzable(self.detection_model.transform):
            self.detection_model.transform.clear_analysis()
[docs]
def get_wrapped_detection_model(wrapper_model_config, device):
    """
    Builds an object detection model wrapped by the wrapper class registered under the configured key.

    The configuration is read through the keys `key`, `detection_model` and `kwargs` (required) and
    `compression_model` and `src_ckpt` (optional). When `src_ckpt` is given, the checkpoint is loaded
    into the wrapped model non-strictly.

    :param wrapper_model_config: wrapper model configuration
    :type wrapper_model_config: dict
    :param device: torch device
    :type device: torch.device
    :return: wrapped object detection model
    :rtype: nn.Module
    :raises ValueError: if the configured key is not a registered wrapper class
    """
    wrapper_key = wrapper_model_config['key']
    if wrapper_key not in WRAPPER_CLASS_DICT:
        error_message = 'wrapper_model_name `{}` is not expected'.format(wrapper_key)
        raise ValueError(error_message)

    # Components are created in a fixed order: compression model first, then the detector.
    compressor = get_compression_model(wrapper_model_config.get('compression_model', None), device)
    detector = load_detection_model(wrapper_model_config['detection_model'], device)

    wrapper_cls = WRAPPER_CLASS_DICT[wrapper_key]
    wrapped = wrapper_cls(detector, compression_model=compressor, device=device,
                          **wrapper_model_config['kwargs'])

    ckpt_file_path = wrapper_model_config.get('src_ckpt', None)
    if ckpt_file_path is None:
        return wrapped

    # Non-strict load: the checkpoint is applied to the whole wrapped model.
    load_ckpt(ckpt_file_path, model=wrapped, strict=False)
    return wrapped