from collections import OrderedDict
from torch.nn import functional
from torchdistill.common.constant import def_logger
from ..backbone import check_if_updatable
from ...analysis import AnalyzableModule, check_if_analyzable
logger = def_logger.getChild(__name__)
[docs]class UpdatableSegmentationModel(AnalyzableModule):
"""
An abstract class for updatable semantic segmentation model.
:param analyzer_configs: list of analysis configurations
:type analyzer_configs: list[dict]
"""
def __init__(self, analyzer_configs=None):
super().__init__(analyzer_configs)
self.bottleneck_updated = False
def forward(self, *args, **kwargs):
raise NotImplementedError()
[docs] def update(self, **kwargs):
"""
Updates compression-specific parameters like `CompressAI models do <https://interdigitalinc.github.io/CompressAI/models.html#compressai.models.CompressionModel.update>`_.
This should be overridden by all subclasses.
"""
raise NotImplementedError()
[docs] def get_aux_module(self, **kwargs):
"""
Returns an auxiliary module to compute auxiliary loss if necessary like `CompressAI models do <https://interdigitalinc.github.io/CompressAI/models.html#compressai.models.CompressionModel.aux_loss>`_.
This should be overridden by all subclasses.
"""
raise NotImplementedError()
[docs]class BaseSegmentationModel(UpdatableSegmentationModel):
# Referred to the base implementation at https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/_utils.py
__constants__ = ['aux_classifier']
"""
A base, updatable segmentation model.
:param backbone: backbone model (usually a classification model)
:type backbone: nn.Module
:param classifier: classification model
:type classifier: nn.Module
:param aux_classifier: auxiliary classification model to be used during training only
:type aux_classifier: nn.Module or None
:param analysis_config: analysis configuration
:type analysis_config: dict or None
"""
def __init__(self, backbone, classifier, aux_classifier=None, analysis_config=None):
if analysis_config is None:
analysis_config = dict()
super().__init__(analysis_config.get('analyzer_configs', list()))
self.backbone = backbone
self.classifier = classifier
self.aux_classifier = aux_classifier
def forward(self, x):
input_shape = x.shape[-2:]
# contract: features is a dict of tensors
features = self.backbone(x)
result = OrderedDict()
x = features['out']
x = self.classifier(x)
x = functional.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
result['out'] = x
if self.aux_classifier is not None:
x = features['aux']
x = self.aux_classifier(x)
x = functional.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
result['aux'] = x
return result
[docs] def update(self, **kwargs):
"""
Updates compression-specific parameters like `CompressAI models do <https://interdigitalinc.github.io/CompressAI/models.html#compressai.models.CompressionModel.update>`_.
Needs to be called once after training to be able to later perform the evaluation with an actual entropy coder.
"""
if not check_if_updatable(self.backbone):
raise KeyError(f'`backbone` {type(self)} is not updatable')
self.backbone.update()
[docs] def get_aux_module(self, **kwargs):
"""
Returns an auxiliary module to compute auxiliary loss if necessary like `CompressAI models do <https://interdigitalinc.github.io/CompressAI/models.html#compressai.models.CompressionModel.aux_loss>`_.
:return: auxiliary module
:rtype: nn.Module
"""
return self.backbone.get_aux_module()
[docs] def activate_analysis(self):
"""
Activates the analysis mode.
Should be called after training model.
"""
self.activated_analysis = True
if check_if_analyzable(self.backbone):
self.backbone.activate_analysis()
[docs] def deactivate_analysis(self):
"""
Deactivates the analysis mode.
"""
self.activated_analysis = False
self.backbone.deactivate_analysis()
if check_if_analyzable(self.backbone):
self.backbone.deactivate_analysis()
[docs] def analyze(self, compressed_obj):
"""
Analyzes a given compressed object (e.g., file size of the compressed object).
:param compressed_obj: compressed object to be analyzed
:type compressed_obj: Any
"""
if not self.activated_analysis:
return
for analyzer in self.analyzers:
analyzer.analyze(compressed_obj)
if check_if_analyzable(self.backbone):
self.backbone.analyze(compressed_obj)
[docs] def summarize(self):
"""
Summarizes the results that the configured analyzers store.
"""
for analyzer in self.analyzers:
analyzer.summarize()
if check_if_analyzable(self.backbone):
self.backbone.summarize()
[docs] def clear_analysis(self):
"""
Clears the results that the configured analyzers store.
"""
for analyzer in self.analyzers:
analyzer.clear()
if check_if_analyzable(self.backbone):
self.backbone.clear_analysis()
[docs]def check_if_updatable_segmentation_model(model):
"""
Checks if the given semantic segmentation model is updatable.
:param model: semantic segmentation model
:type model: nn.Module
:return: True if the model is updatable, False otherwise
:rtype: bool
"""
return isinstance(model, UpdatableSegmentationModel)