from collections import OrderedDict
import torch
from torch import nn
from torch.nn import Module, Sequential
from torch.nn.parallel import DistributedDataParallel
from .registry import get_adaptation_module
from ..common.constant import def_logger
from ..common.file_util import make_parent_dirs
from ..common.main_util import is_main_process, save_on_master
from ..common.module_util import check_if_wrapped, get_module, get_frozen_param_names, get_updatable_param_names,\
    freeze_module_params
logger = def_logger.getChild(__name__)


def wrap_if_distributed(module, device, device_ids, distributed, find_unused_parameters=None, **kwargs):
    """
    Wraps ``module`` with DistributedDataParallel if ``distributed`` is True and ``module`` has any updatable parameters.
    :param module: module to be wrapped.
    :type module: nn.Module
    :param device: target device.
    :type device: torch.device
    :param device_ids: target device IDs.
    :type device_ids: list[int]
    :param distributed: whether to be in distributed training mode.
    :type distributed: bool
    :param find_unused_parameters: ``find_unused_parameters`` for DistributedDataParallel.
    :type find_unused_parameters: bool or None
    :return: wrapped module if ``distributed`` is True and ``module`` has any updatable parameters; otherwise, ``module`` moved to ``device``.
    :rtype: nn.Module
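
    Example (an illustrative sketch; assumes ``torch.distributed.init_process_group`` has already been called)::

        local_rank = 0  # placeholder; normally provided by the distributed launcher
        device = torch.device('cuda', local_rank)
        model = wrap_if_distributed(nn.Linear(10, 10), device, [local_rank], distributed=True)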
    """
    module.to(device)
    if distributed and len(get_updatable_param_names(module)) > 0:
        any_frozen = len(get_frozen_param_names(module)) > 0
        if find_unused_parameters is None:
            find_unused_parameters = any_frozen
        return DistributedDataParallel(module, device_ids=device_ids, find_unused_parameters=find_unused_parameters,
                                       **kwargs)
    return module 


def load_module_ckpt(module, map_location, ckpt_file_path):
    """
    Loads checkpoint for ``module``.
    :param module: module to load the checkpoint into.
    :type module: nn.Module
    :param map_location: ``map_location`` for torch.load.
    :type map_location: torch.device or str or dict or typing.Callable
    :param ckpt_file_path: file path of the checkpoint to load.
    :type ckpt_file_path: str
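
    Example (an illustrative sketch; the checkpoint path is a placeholder)::

        model = nn.Linear(10, 10)
        load_module_ckpt(model, torch.device('cpu'), './ckpt/model.pt')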
    """
    state_dict = torch.load(ckpt_file_path, map_location=map_location)
    if check_if_wrapped(module):
        module.module.load_state_dict(state_dict)
    else:
        module.load_state_dict(state_dict) 


def save_module_ckpt(module, ckpt_file_path):
    """
    Saves a checkpoint of ``module``'s state dict.
    :param module: module to save a checkpoint of.
    :type module: nn.Module
    :param ckpt_file_path: file path to save the checkpoint at.
    :type ckpt_file_path: str
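
    Example (an illustrative sketch; the checkpoint path is a placeholder)::

        model = nn.Linear(10, 10)
        save_module_ckpt(model, './ckpt/model.pt')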
    """
    if is_main_process():
        make_parent_dirs(ckpt_file_path)
    state_dict = module.module.state_dict() if check_if_wrapped(module) else module.state_dict()
    save_on_master(state_dict, ckpt_file_path) 


def add_submodule(module, module_path, module_dict):
    """
    Recursively adds ``module`` to ``module_dict``, creating nested dicts along ``module_path``.
    :param module: module to be added.
    :type module: nn.Module
    :param module_path: dot-separated path under which ``module`` is registered (e.g., 'backbone.layer1').
    :type module_path: str
    :param module_dict: module dict to which ``module`` is added.
    :type module_dict: nn.ModuleDict or dict
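
    Example (an illustrative sketch with toy modules; 'backbone.conv1' and 'backbone.relu1' are placeholder paths)::

        module_dict = OrderedDict()
        add_submodule(nn.Conv2d(3, 64, kernel_size=3), 'backbone.conv1', module_dict)
        add_submodule(nn.ReLU(inplace=True), 'backbone.relu1', module_dict)
        # module_dict == {'backbone': OrderedDict([('conv1', Conv2d(...)), ('relu1', ReLU(...))])}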
    """
    module_names = module_path.split('.')
    module_name = module_names.pop(0)
    if len(module_names) == 0:
        if module_name in module_dict:
            raise KeyError('module_name `{}` is already used.'.format(module_name))
        module_dict[module_name] = module
        return
    next_module_path = '.'.join(module_names)
    sub_module_dict = module_dict.get(module_name, None)
    if module_name not in module_dict:
        sub_module_dict = OrderedDict()
        module_dict[module_name] = sub_module_dict
    add_submodule(module, next_module_path, sub_module_dict) 


def build_sequential_container(module_dict):
    """
    Builds a sequential container (nn.Sequential) from ``module_dict``.
    :param module_dict: module dict to build a sequential container from.
    :type module_dict: nn.ModuleDict or collections.OrderedDict
    :return: sequential container.
    :rtype: nn.Sequential
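
    Example (an illustrative sketch continuing the ``add_submodule`` example)::

        module_dict = OrderedDict()
        add_submodule(nn.Conv2d(3, 64, kernel_size=3), 'backbone.conv1', module_dict)
        add_submodule(nn.ReLU(inplace=True), 'backbone.relu1', module_dict)
        model = build_sequential_container(module_dict)
        # model is an nn.Sequential whose single child 'backbone' is itself an nn.Sequential of conv1 and relu1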
    """
    for key in module_dict.keys():
        value = module_dict[key]
        if isinstance(value, OrderedDict):
            value = build_sequential_container(value)
            module_dict[key] = value
        elif not isinstance(value, Module):
            raise ValueError('module type `{}` is not expected'.format(type(value)))
    return Sequential(module_dict) 


def redesign_model(org_model, model_config, model_label, model_type='original'):
    """
    Redesigns ``org_model`` and returns a new separate model, e.g.,
    * prunes some modules from ``org_model``,
    * freezes parameters of some modules in ``org_model``, and
    * adds adaptation module(s) to ``org_model`` as a new separate model.
    .. note::
        The parameters and states of modules in ``org_model`` are kept in the new redesigned model.
    :param org_model: original model to be redesigned.
    :type org_model: nn.Module
    :param model_config: configuration to redesign ``org_model``.
    :type model_config: dict
    :param model_label: model label (e.g., 'teacher', 'student') printed just for debugging purposes.
    :type model_label: str
    :param model_type: model type (e.g., 'original', name of model class, etc.) printed just for debugging purposes.
    :type model_type: str
    :return: redesigned model.
    :rtype: nn.Module
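
    Example (an illustrative sketch; the module paths assume a torchvision-style ResNet teacher)::

        model_config = {
            'sequential': ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2'],
            'frozen_modules': ['conv1', 'bn1']
        }
        partial_teacher = redesign_model(teacher_model, model_config, 'teacher')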
    """
    frozen_module_path_set = set(model_config.get('frozen_modules', list()))
    module_paths = model_config.get('sequential', list())
    if not isinstance(module_paths, list) or len(module_paths) == 0:
        logger.info('Using the {} model'.format(model_type))
        if len(frozen_module_path_set) > 0:
            logger.info('Frozen module(s): {}'.format(frozen_module_path_set))
        isinstance_str = 'instance('
        for frozen_module_path in frozen_module_path_set:
            if frozen_module_path.startswith(isinstance_str) and frozen_module_path.endswith(')'):
                target_cls = nn.__dict__[frozen_module_path[len(isinstance_str):-1]]
                for m in org_model.modules():
                    if isinstance(m, target_cls):
                        freeze_module_params(m)
            else:
                module = get_module(org_model, frozen_module_path)
                freeze_module_params(module)
        return org_model
    logger.info('Redesigning the {} model with {}'.format(model_label, module_paths))
    if len(frozen_module_path_set) > 0:
        logger.info('Frozen module(s): {}'.format(frozen_module_path_set))
    module_dict = OrderedDict()
    adaptation_dict = model_config.get('adaptations', dict())
    for frozen_module_path in frozen_module_path_set:
        module = get_module(org_model, frozen_module_path)
        freeze_module_params(module)
    for module_path in module_paths:
        if module_path.startswith('+'):
            module_path = module_path[1:]
            adaptation_config = adaptation_dict[module_path]
            module = get_adaptation_module(adaptation_config['key'], **adaptation_config['kwargs'])
        else:
            module = get_module(org_model, module_path)
        if module_path in frozen_module_path_set:
            freeze_module_params(module)
        add_submodule(module, module_path, module_dict)
    return build_sequential_container(module_dict)