Source code for torchdistill.models.util

from collections import OrderedDict

import torch
from torch import nn
from torch.nn import Module, Sequential
from torch.nn.parallel import DistributedDataParallel

from .registry import get_adaptation_module
from ..common.constant import def_logger
from ..common.file_util import make_parent_dirs
from ..common.main_util import is_main_process, save_on_master
from ..common.module_util import check_if_wrapped, get_module, get_frozen_param_names, get_updatable_param_names,\
    freeze_module_params

logger = def_logger.getChild(__name__)


def wrap_if_distributed(module, device, device_ids, distributed, find_unused_parameters=None, **kwargs):
    """
    Wraps ``module`` with DistributedDataParallel if ``distributed`` = True and ``module`` has any updatable parameters.

    :param module: module to be wrapped.
    :type module: nn.Module
    :param device: target device.
    :type device: torch.device
    :param device_ids: target device IDs.
    :type device_ids: list[int]
    :param distributed: whether to be in distributed training mode.
    :type distributed: bool
    :param find_unused_parameters: ``find_unused_parameters`` for DistributedDataParallel.
    :type find_unused_parameters: bool or None
    :return: wrapped module if ``distributed`` = True and it contains any updatable parameters; otherwise, the original module.
    :rtype: nn.Module
    """
    module.to(device)
    if distributed and len(get_updatable_param_names(module)) > 0:
        any_frozen = len(get_frozen_param_names(module)) > 0
        if find_unused_parameters is None:
            find_unused_parameters = any_frozen
        return DistributedDataParallel(module, device_ids=device_ids,
                                       find_unused_parameters=find_unused_parameters, **kwargs)
    return module
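# Example (illustrative sketch, not part of torchdistill): with distributed=False the helper
# simply moves the module to the device and returns it unchanged, so callers do not need to
# branch on the distributed flag. In an actual DDP run, `device`, `device_ids` (typically
# [local_rank]) and `distributed` would come from the launcher's distributed setup
# (e.g., torch.distributed.init_process_group); the toy model below is a placeholder.

toy_model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
toy_model = wrap_if_distributed(toy_model, device=torch.device('cpu'),
                                device_ids=None, distributed=False)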
def load_module_ckpt(module, map_location, ckpt_file_path):
    """
    Loads checkpoint for ``module``.

    :param module: module to load the checkpoint into.
    :type module: nn.Module
    :param map_location: ``map_location`` for torch.load.
    :type map_location: torch.device or str or dict or typing.Callable
    :param ckpt_file_path: file path to load checkpoint.
    :type ckpt_file_path: str
    """
    state_dict = torch.load(ckpt_file_path, map_location=map_location)
    if check_if_wrapped(module):
        module.module.load_state_dict(state_dict)
    else:
        module.load_state_dict(state_dict)
def save_module_ckpt(module, ckpt_file_path):
    """
    Saves checkpoint of ``module``'s state dict.

    :param module: module whose checkpoint is saved.
    :type module: nn.Module
    :param ckpt_file_path: file path to save checkpoint.
    :type ckpt_file_path: str
    """
    if is_main_process():
        make_parent_dirs(ckpt_file_path)
    state_dict = module.module.state_dict() if check_if_wrapped(module) else module.state_dict()
    save_on_master(state_dict, ckpt_file_path)
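# Example (illustrative sketch, not part of torchdistill): a save/load round trip with the two
# helpers above. The file path and toy module are placeholders. Both helpers transparently
# unwrap DistributedDataParallel via check_if_wrapped, and in a single-process run the
# main-process/master checks simply pass through.

toy_model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
ckpt_file_path = './toy_ckpt.pt'  # placeholder path
save_module_ckpt(toy_model, ckpt_file_path)
load_module_ckpt(toy_model, torch.device('cpu'), ckpt_file_path)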
def add_submodule(module, module_path, module_dict):
    """
    Recursively adds submodules to ``module_dict``.

    :param module: module.
    :type module: nn.Module
    :param module_path: module path.
    :type module_path: str
    :param module_dict: module dict.
    :type module_dict: nn.ModuleDict or dict
    """
    module_names = module_path.split('.')
    module_name = module_names.pop(0)
    if len(module_names) == 0:
        if module_name in module_dict:
            raise KeyError('module_name `{}` is already used.'.format(module_name))

        module_dict[module_name] = module
        return

    next_module_path = '.'.join(module_names)
    sub_module_dict = module_dict.get(module_name, None)
    if module_name not in module_dict:
        sub_module_dict = OrderedDict()
        module_dict[module_name] = sub_module_dict
    add_submodule(module, next_module_path, sub_module_dict)
def build_sequential_container(module_dict):
    """
    Builds a sequential container (nn.Sequential) from ``module_dict``.

    :param module_dict: module dict to build a sequential container.
    :type module_dict: nn.ModuleDict or collections.OrderedDict
    :return: sequential container.
    :rtype: nn.Sequential
    """
    for key in module_dict.keys():
        value = module_dict[key]
        if isinstance(value, OrderedDict):
            value = build_sequential_container(value)
            module_dict[key] = value
        elif not isinstance(value, Module):
            raise ValueError('module type `{}` is not expected'.format(type(value)))
    return Sequential(module_dict)
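# Example (illustrative sketch, not part of torchdistill): how the two helpers above cooperate.
# Dotted module paths are expanded by add_submodule into nested OrderedDicts, and
# build_sequential_container then collapses that structure into nested nn.Sequential
# containers. The toy model and paths are hypothetical.

toy_model = nn.Sequential(OrderedDict([
    ('backbone', nn.Sequential(OrderedDict([('conv1', nn.Conv2d(3, 8, 3)),
                                            ('conv2', nn.Conv2d(8, 16, 3))]))),
    ('head', nn.Linear(16, 10))
]))
module_dict = OrderedDict()
add_submodule(toy_model.backbone.conv1, 'backbone.conv1', module_dict)
add_submodule(toy_model.backbone.conv2, 'backbone.conv2', module_dict)
add_submodule(toy_model.head, 'head', module_dict)
partial_model = build_sequential_container(module_dict)
# partial_model is an nn.Sequential whose 'backbone' child is itself an nn.Sequential
# containing conv1 and conv2, followed by 'head'.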
def redesign_model(org_model, model_config, model_label, model_type='original'):
    """
    Redesigns ``org_model`` and returns a new separate model e.g.,

    * prunes some modules from ``org_model``,
    * freezes parameters of some modules in ``org_model``, and
    * adds adaptation module(s) to ``org_model`` as a new separate model.

    .. note::
        The parameters and states of modules in ``org_model`` will be kept in a new redesigned model.

    :param org_model: original model to be redesigned.
    :type org_model: nn.Module
    :param model_config: configuration to redesign ``org_model``.
    :type model_config: dict
    :param model_label: model label (e.g., 'teacher', 'student') to be printed just for debugging purposes.
    :type model_label: str
    :param model_type: model type (e.g., 'original', name of model class, etc.) to be printed just for debugging purposes.
    :type model_type: str
    :return: redesigned model.
    :rtype: nn.Module
    """
    frozen_module_path_set = set(model_config.get('frozen_modules', list()))
    module_paths = model_config.get('sequential', list())
    if not isinstance(module_paths, list) or len(module_paths) == 0:
        logger.info('Using the {} model'.format(model_type))
        if len(frozen_module_path_set) > 0:
            logger.info('Frozen module(s): {}'.format(frozen_module_path_set))

        isinstance_str = 'instance('
        for frozen_module_path in frozen_module_path_set:
            if frozen_module_path.startswith(isinstance_str) and frozen_module_path.endswith(')'):
                target_cls = nn.__dict__[frozen_module_path[len(isinstance_str):-1]]
                for m in org_model.modules():
                    if isinstance(m, target_cls):
                        freeze_module_params(m)
            else:
                module = get_module(org_model, frozen_module_path)
                freeze_module_params(module)
        return org_model

    logger.info('Redesigning the {} model with {}'.format(model_label, module_paths))
    if len(frozen_module_path_set) > 0:
        logger.info('Frozen module(s): {}'.format(frozen_module_path_set))

    module_dict = OrderedDict()
    adaptation_dict = model_config.get('adaptations', dict())
    for frozen_module_path in frozen_module_path_set:
        module = get_module(org_model, frozen_module_path)
        freeze_module_params(module)

    for module_path in module_paths:
        if module_path.startswith('+'):
            module_path = module_path[1:]
            adaptation_config = adaptation_dict[module_path]
            module = get_adaptation_module(adaptation_config['key'], **adaptation_config['kwargs'])
        else:
            module = get_module(org_model, module_path)

        if module_path in frozen_module_path_set:
            freeze_module_params(module)

        add_submodule(module, module_path, module_dict)
    return build_sequential_container(module_dict)
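# Example (illustrative sketch, not part of torchdistill): a model_config in the shape this
# function expects ('sequential', 'frozen_modules', and optionally 'adaptations'), applied to a
# torchvision ResNet-18 assuming torchvision is available. The chosen module names and the
# resulting partial model are illustrative only; '+'-prefixed entries in 'sequential' would be
# looked up in 'adaptations' as {'key': ..., 'kwargs': {...}} and built via get_adaptation_module,
# which requires a registered adaptation module and is omitted here.

from torchvision import models

org_model = models.resnet18()
model_config = {
    'sequential': ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2'],
    'frozen_modules': ['conv1', 'bn1']
}
partial_student = redesign_model(org_model, model_config, model_label='student',
                                 model_type='resnet18')
# partial_student is an nn.Sequential over the listed modules of org_model (sharing their
# parameters), with conv1 and bn1 frozen.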