import warnings

import torch
from torch.hub import load_state_dict_from_url
from torchdistill.common.main_util import load_ckpt
from torchvision.models.segmentation.deeplabv3 import model_urls, DeepLabHead
from torchvision.models.segmentation.fcn import FCNHead

from .base import BaseSegmentationModel
from .registry import register_segmentation_model_func
from ..backbone import FeatureExtractionBackbone
from ..registry import load_classification_model


def create_deeplabv3(backbone, num_input_channels=2048, uses_aux=False, num_aux_channels=1024, num_classes=21):
    """
    Assembles a DeepLabv3 segmentation model on top of an (updatable) backbone.

    A :class:`DeepLabHead` is always attached as the main classifier; an :class:`FCNHead`
    is attached as the auxiliary classifier only when ``uses_aux`` is True.

    :param backbone: backbone model (usually a classification model)
    :type backbone: nn.Module
    :param num_input_channels: number of input channels for the main (DeepLabv3) head
    :type num_input_channels: int
    :param uses_aux: if True, also build an auxiliary FCN head
    :type uses_aux: bool
    :param num_aux_channels: number of input channels for the auxiliary head
    :type num_aux_channels: int
    :param num_classes: number of output classes of the model (including the background)
    :type num_classes: int
    :return: DeepLabv3 model
    :rtype: BaseSegmentationModel
    """
    # The auxiliary head is built before the main head on purpose: module construction
    # draws from the torch RNG, so this order keeps seeded initializations reproducible.
    aux_head = FCNHead(num_aux_channels, num_classes) if uses_aux else None
    main_head = DeepLabHead(num_input_channels, num_classes)
    return BaseSegmentationModel(backbone, main_head, aux_head)


@register_segmentation_model_func
def deeplabv3_model(backbone_config, pretrained=True, pretrained_backbone_name=None, progress=True,
                    num_input_channels=2048, uses_aux=False, num_aux_channels=1024, return_layer_dict=None,
                    num_classes=21, analysis_config=None, analyzable_layer_key=None, start_ckpt_file_path=None,
                    **kwargs):
    """
    Builds DeepLabv3 model using a given updatable backbone model.

    :param backbone_config: backbone configuration
    :type backbone_config: dict
    :param pretrained: if True, loads torchvision's DeepLabv3 weights pre-trained on COCO train2017.
        Such weights are only loaded when ``pretrained_backbone_name`` is `'resnet50'` or `'resnet101'`.
        For any other non-None name a warning is issued and no torchvision weights are loaded.
        When the name is None, no torchvision weights are loaded.
    :type pretrained: bool
    :param pretrained_backbone_name: name of the backbone whose COCO-pretrained DeepLabv3 weights
        should be loaded (`'resnet50'` or `'resnet101'`)
    :type pretrained_backbone_name: str or None
    :param progress: if True, displays a progress bar of the download to stderr
    :type progress: bool
    :param num_input_channels: number of input channels for classification head
    :type num_input_channels: int
    :param uses_aux: If True, add an auxiliary branch
    :type uses_aux: bool
    :param num_aux_channels: number of input channels for auxiliary classification head
    :type num_aux_channels: int
    :param return_layer_dict: mapping from name of module to return its output to a specified key.
        Defaults to ``{'layer4': 'out'}`` plus ``{'layer3': 'aux'}`` when ``uses_aux`` is True.
        A caller-supplied dict is used as-is and is not modified.
    :type return_layer_dict: dict or None
    :param num_classes: number of output classes of the model (including the background)
    :type num_classes: int
    :param analysis_config: analysis configuration; keys read here are `'analyzer_configs'`
        and `'analyzes_after_compress'`
    :type analysis_config: dict or None
    :param analyzable_layer_key: key of analyzable layer
    :type analyzable_layer_key: str or None
    :param start_ckpt_file_path: checkpoint file path to be loaded for the built DeepLabv3 model
        (loaded after, and therefore on top of, any torchvision pretrained weights)
    :type start_ckpt_file_path: str or None
    :return: DeepLabv3 model with splittable backbone model
    :rtype: BaseSegmentationModel
    """
    if analysis_config is None:
        analysis_config = dict()

    if return_layer_dict is None:
        # Default feature taps: 'out' feeds the main head, 'aux' (if enabled) the auxiliary head.
        return_layer_dict = {'layer4': 'out'}
        if uses_aux:
            return_layer_dict['layer3'] = 'aux'

    # NOTE(review): the positional False is presumably the "distributed" flag of
    # load_classification_model — confirm against its signature.
    backbone = load_classification_model(backbone_config, torch.device('cpu'), False, strict=False)
    backbone_model = \
        FeatureExtractionBackbone(backbone, return_layer_dict, analysis_config.get('analyzer_configs', list()),
                                  analysis_config.get('analyzes_after_compress', False),
                                  analyzable_layer_key=analyzable_layer_key)
    model = create_deeplabv3(backbone_model, num_input_channels=num_input_channels,
                             uses_aux=uses_aux, num_aux_channels=num_aux_channels, num_classes=num_classes)

    coco_pretrained_backbones = ('resnet50', 'resnet101')
    if pretrained and pretrained_backbone_name in coco_pretrained_backbones:
        state_dict = \
            load_state_dict_from_url(model_urls['deeplabv3_{}_coco'.format(pretrained_backbone_name)],
                                     progress=progress)
        # strict=False ignores missing/unexpected keys (e.g. an absent aux head);
        # tensor shape mismatches still raise.
        model.load_state_dict(state_dict, strict=False)
    elif pretrained and pretrained_backbone_name is not None:
        # Previously this case was skipped silently, leaving randomly initialized heads.
        warnings.warn('COCO-pretrained DeepLabv3 weights are only available for {}; '
                      'got pretrained_backbone_name={!r}, so no pretrained weights are loaded'
                      .format(coco_pretrained_backbones, pretrained_backbone_name))

    if start_ckpt_file_path is not None:
        load_ckpt(start_ckpt_file_path, model=model, strict=False)
    return model