# Source code for sc2bench.models.segmentation.wrapper
import torch
from torchdistill.common.main_util import load_ckpt
from torchvision.transforms.functional import crop
from .registry import load_segmentation_model
from ..registry import get_compression_model
from ..wrapper import register_wrapper_class, WRAPPER_CLASS_DICT
from ...analysis import AnalyzableModule


@register_wrapper_class
class CodecInputCompressionSegmentationModel(AnalyzableModule):
"""
A wrapper module for codec input compression model followed by a segmentation model.
:param segmentation_model: semantic segmentation model
:type segmentation_model: nn.Module
:param device: torch device
:type device: torch.device or str
:param codec_encoder_decoder: transform sequence configuration for codec
:type codec_encoder_decoder: nn.Module or None
:param post_transform: post-transform
:type post_transform: nn.Module or None
:param analysis_config: analysis configuration
:type analysis_config: dict or None
"""
def __init__(self, segmentation_model, device, codec_encoder_decoder=None,
post_transform=None, analysis_config=None, **kwargs):
if analysis_config is None:
analysis_config = dict()
super().__init__(analysis_config.get('analyzer_configs', list()))
self.codec_encoder_decoder = codec_encoder_decoder
self.device = device
self.segmentation_model = segmentation_model
self.post_transform = post_transform

    def forward(self, x):
tmp_list = list()
for sub_x in x:
if self.codec_encoder_decoder is not None:
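                # the codec encoder-decoder returns the decoded sample together with its compressed file size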
sub_x, file_size = self.codec_encoder_decoder(sub_x)
if not self.training:
self.analyze(file_size)
if self.post_transform is not None:
sub_x = self.post_transform(sub_x)
tmp_list.append(sub_x.unsqueeze(0))

        # concatenate the per-sample tensors along the batch dimension
        x = torch.vstack(tmp_list).to(self.device)
return self.segmentation_model(x)
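
# A minimal usage sketch (assumptions: `my_segmentation_model` is any nn.Module
# producing segmentation outputs, and `my_codec_encoder_decoder` is any transform
# returning a (decoded_sample, file_size) pair; both names are hypothetical):
#
#   model = CodecInputCompressionSegmentationModel(
#       my_segmentation_model, device='cuda',
#       codec_encoder_decoder=my_codec_encoder_decoder
#   )
#   model.eval()
#   outputs = model([sample1, sample2])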


@register_wrapper_class
class NeuralInputCompressionSegmentationModel(AnalyzableModule):
"""
A wrapper module for neural input compression model followed by a segmentation model.
:param segmentation_model: semantic segmentation model
:type segmentation_model: nn.Module
:param pre_transform: pre-transform
:type pre_transform: nn.Module or None
:param compression_model: compression model
:type compression_model: nn.Module or None
:param uses_cpu4compression_model: whether to use CPU instead of GPU for `compression_model`
:type uses_cpu4compression_model: bool
:param post_transform: post-transform
:type post_transform: nn.Module or None
:param analysis_config: analysis configuration
:type analysis_config: dict or None
"""
def __init__(self, segmentation_model, pre_transform=None, compression_model=None,
uses_cpu4compression_model=False, post_transform=None, analysis_config=None, **kwargs):
if analysis_config is None:
analysis_config = dict()
super().__init__(analysis_config.get('analyzer_configs', list()))
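        # flags controlling where analysis hooks run: right after the pre-transform and/or right after compression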
self.analyzes_after_pre_transform = analysis_config.get('analyzes_after_pre_transform', False)
self.analyzes_after_compress = analysis_config.get('analyzes_after_compress', False)
self.pre_transform = pre_transform
self.compression_model = compression_model
self.uses_cpu4compression_model = uses_cpu4compression_model
self.segmentation_model = segmentation_model
self.post_transform = post_transform
def use_cpu4compression(self):
"""
Changes the device of the compression model to CPU.
"""
if self.uses_cpu4compression_model and self.compression_model is not None:
self.compression_model = self.compression_model.cpu()

    def forward(self, x):
org_patch_size = None
if self.pre_transform is not None:
x = self.pre_transform(x)
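            # pre_transform may return a (transformed_tensor, original_patch_size) pair,
            # e.g., when the input is padded to a size the compression model accepts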
if isinstance(x, tuple) and len(x) == 2 and isinstance(x[1], tuple):
org_patch_size = x[1]
x = x[0]
if not self.training and self.analyzes_after_pre_transform:
self.analyze(x)
if self.compression_model is not None:
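            # round-trip through the learned codec: compress the input, optionally
            # analyze the compressed representation, then decompress it back to a tensor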
compressed_obj = self.compression_model.compress(x)
if not self.training and self.analyzes_after_compress:
compressed_data = compressed_obj if org_patch_size is None else (compressed_obj, org_patch_size)
self.analyze(compressed_data)
x = self.compression_model.decompress(**compressed_obj)
if isinstance(x, dict):
x = x['x_hat']
if self.post_transform is not None:
if org_patch_size is not None:
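                # crop away the padding introduced by pre_transform before post-processing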
x = crop(x, 0, 0, org_patch_size[0], org_patch_size[1])
x = self.post_transform(x)
return self.segmentation_model(x)
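
# A minimal usage sketch (assumption: `my_compression_model` exposes the
# compress()/decompress() interface used above, e.g., a CompressAI-style model
# whose decompress() returns a dict with key 'x_hat'; all names are illustrative):
#
#   model = NeuralInputCompressionSegmentationModel(
#       my_segmentation_model, compression_model=my_compression_model,
#       uses_cpu4compression_model=True,
#       analysis_config={'analyzes_after_compress': True}
#   )
#   model.use_cpu4compression()
#   model.eval()
#   outputs = model(batch_tensor)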


def get_wrapped_segmentation_model(wrapper_model_config, device):
"""
Gets a wrapped semantic segmentation model.
:param wrapper_model_config: wrapper model configuration
:type wrapper_model_config: dict
:param device: torch device
:type device: torch.device
:return: wrapped semantic segmentation model
:rtype: nn.Module
"""
wrapper_model_name = wrapper_model_config['key']
if wrapper_model_name not in WRAPPER_CLASS_DICT:
raise ValueError('wrapper_model_name `{}` is not expected'.format(wrapper_model_name))
compression_model_config = wrapper_model_config.get('compression_model', None)
compression_model = get_compression_model(compression_model_config, device)
segmentation_model_config = wrapper_model_config['segmentation_model']
model = load_segmentation_model(segmentation_model_config, device)
wrapped_model = WRAPPER_CLASS_DICT[wrapper_model_name](model, compression_model=compression_model, device=device,
**wrapper_model_config['kwargs'])
src_ckpt_file_path = wrapper_model_config.get('src_ckpt', None)
if src_ckpt_file_path is not None:
load_ckpt(src_ckpt_file_path, model=wrapped_model, strict=False)
return wrapped_model
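
# A configuration sketch for get_wrapped_segmentation_model (assumption: only the
# keys read above are required; the nested configs and kwargs are illustrative):
#
#   wrapper_model_config = {
#       'key': 'NeuralInputCompressionSegmentationModel',
#       'compression_model': {...},   # forwarded to get_compression_model
#       'segmentation_model': {...},  # forwarded to load_segmentation_model
#       'kwargs': {'uses_cpu4compression_model': False},
#       'src_ckpt': None              # optional checkpoint path for load_ckpt
#   }
#   wrapped = get_wrapped_segmentation_model(wrapper_model_config, torch.device('cuda'))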