Source code for torchdistill.common.module_util

from collections import OrderedDict

from torch.nn import DataParallel, Sequential, ModuleList, Module, Parameter
from torch.nn.parallel import DistributedDataParallel

from .constant import def_logger

logger = def_logger.getChild(__name__)


def check_if_wrapped(model):
    """
    Tells whether a model is wrapped by one of the data-parallel containers.

    :param model: model to inspect.
    :type model: nn.Module
    :return: True if ``model`` is an instance of DataParallel or DistributedDataParallel, False otherwise.
    :rtype: bool
    """
    if isinstance(model, DataParallel):
        return True
    return isinstance(model, DistributedDataParallel)
def count_params(module):
    """
    Counts the scalar elements held by all the parameters of a module.

    :param module: module whose ``parameters()`` are counted.
    :type module: nn.Module
    :return: total number of elements over all the parameters (0 if the module has none).
    :rtype: int
    """
    total = 0
    for tensor in module.parameters():
        total += tensor.numel()
    return total
def freeze_module_params(module):
    """
    Disables gradient tracking by setting ``requires_grad = False``.

    A Module has every parameter returned by ``parameters()`` frozen; a bare Parameter is frozen itself.
    Any other object is left untouched.

    :param module: module (or single parameter) to freeze.
    :type module: nn.Module or nn.Parameter
    """
    if isinstance(module, Module):
        targets = module.parameters()
    elif isinstance(module, Parameter):
        targets = (module,)
    else:
        return

    for target in targets:
        target.requires_grad = False
def unfreeze_module_params(module):
    """
    Enables gradient tracking by setting ``requires_grad = True``.

    A Module has every parameter returned by ``parameters()`` unfrozen; a bare Parameter is unfrozen itself.
    Any other object is left untouched.

    :param module: module (or single parameter) to unfreeze.
    :type module: nn.Module or nn.Parameter
    """
    if isinstance(module, Module):
        targets = module.parameters()
    elif isinstance(module, Parameter):
        targets = (module,)
    else:
        return

    for target in targets:
        target.requires_grad = True
def get_updatable_param_names(module):
    """
    Lists the names of the parameters that still require gradients.

    :param module: module whose ``named_parameters()`` are inspected.
    :type module: nn.Module
    :return: names of the parameters with ``requires_grad == True``, in ``named_parameters()`` order.
    :rtype: list[str]
    """
    updatable_names = []
    for param_name, tensor in module.named_parameters():
        if tensor.requires_grad:
            updatable_names.append(param_name)
    return updatable_names
def get_frozen_param_names(module):
    """
    Lists the names of the parameters that do not require gradients.

    :param module: module whose ``named_parameters()`` are inspected.
    :type module: nn.Module
    :return: names of the parameters with ``requires_grad == False``, in ``named_parameters()`` order.
    :rtype: list[str]
    """
    frozen_names = []
    for param_name, tensor in module.named_parameters():
        if tensor.requires_grad:
            continue
        frozen_names.append(param_name)
    return frozen_names
def get_module(root_module, module_path):
    """
    Gets a module specified by ``module_path``.

    ``module_path`` is split on ``'.'`` and resolved one name at a time, starting from ``root_module``.
    At each step:

    - if the current module is a DataParallel / DistributedDataParallel wrapper that does not itself
      expose the name, the lookup continues on the wrapped ``.module``;
    - a name exposed as an attribute is followed with ``getattr``;
    - otherwise, a (possibly negative) integer name indexes into a Sequential or ModuleList.

    If a name cannot be resolved (including an out-of-range index), a warning is logged and None is returned.

    :param root_module: module.
    :type root_module: nn.Module
    :param module_path: module path for extracting the module from ``root_module``.
    :type module_path: str
    :return: module extracted from ``root_module`` if exists.
    :rtype: nn.Module or None
    """
    module = root_module
    for module_name in module_path.split('.'):
        # Transparently step through a parallel wrapper, but only when the wrapper
        # itself does not expose the requested name (e.g. `module` stays reachable).
        if not hasattr(module, module_name) and isinstance(module, (DataParallel, DistributedDataParallel)):
            module = module.module

        if hasattr(module, module_name):
            module = getattr(module, module_name)
            continue

        # Names such as '-1' are not registered attributes, so index into the container instead.
        if isinstance(module, (Sequential, ModuleList)) and module_name.lstrip('-').isnumeric():
            try:
                module = module[int(module_name)]
                continue
            except (IndexError, ValueError):
                # Out-of-range index or a string `int` cannot parse (e.g. '--1'): treat as unreachable.
                pass

        logger.warning('`{}` of `{}` could not be reached in `{}`'.format(
            module_name, module_path, type(root_module).__name__)
        )
        return None
    return module
def get_hierarchized_dict(module_paths):
    """
    Groups dotted module paths by their first element into a nested structure.

    For each first element, the resulting value is one of:

    - the path itself (str), when the path has a single element and the key is new;
    - a list with the one remaining sub-path, when only one dotted path shares the key;
    - a nested OrderedDict built recursively, when two or more sub-paths share the key.

    :param module_paths: module paths.
    :type module_paths: list[str]
    :return: ordered mapping from first path elements to the hierarchy underneath them.
    :rtype: collections.OrderedDict
    """
    hierarchy = OrderedDict()
    for path in module_paths:
        head, *rest = path.split('.')
        if head not in hierarchy:
            if not rest:
                # Single-element path seen for the first time: store it as a plain string.
                hierarchy[head] = path
                continue
            hierarchy[head] = []
        hierarchy[head].append('.'.join(rest))

    # Snapshot the items so that values can be replaced while walking over them.
    for head, entry in list(hierarchy.items()):
        if isinstance(entry, list) and len(entry) > 1:
            hierarchy[head] = get_hierarchized_dict(entry)
    return hierarchy
def decompose(ordered_dict):
    """
    Flattens a (possibly nested) ordered dict into a list of components.

    Each entry becomes ``(key, decompose(value))`` when the value is an OrderedDict,
    ``(key, value)`` when the value is a list, and just ``key`` for any other value.

    :param ordered_dict: ordered dict.
    :type ordered_dict: collections.OrderedDict
    :return: list of keys and key-value pairs, in the iteration order of ``ordered_dict``.
    :rtype: list[str or (str, Any)]
    """
    def _to_component(name, entry):
        if isinstance(entry, OrderedDict):
            return name, decompose(entry)
        if isinstance(entry, list):
            return name, entry
        return name

    return [_to_component(name, entry) for name, entry in ordered_dict.items()]
def get_components(module_paths):
    """
    Converts module paths into a list of components.

    The paths are first grouped with ``get_hierarchized_dict`` and the resulting
    structure is then flattened with ``decompose``.

    :param module_paths: module paths.
    :type module_paths: list[str]
    :return: components as produced by ``decompose``, i.e. names and pairs of parent name and children.
    :rtype: list[str or (str, Any)]
    """
    return decompose(get_hierarchized_dict(module_paths))
def extract_target_modules(parent_module, target_class, module_list):
    """
    Collects ``parent_module`` and its descendants that are instances of ``target_class``.

    Matching modules are appended to ``module_list`` in place, in depth-first pre-order
    (a module is visited before its children, children in ``children()`` order).

    :param parent_module: parent module.
    :type parent_module: nn.Module
    :param target_class: target class.
    :type target_class: class
    :param module_list: (empty) list to be filled with modules that are instances of ``target_class``.
    :type module_list: list[nn.Module]
    """
    pending = [parent_module]
    while pending:
        current = pending.pop()
        if isinstance(current, target_class):
            module_list.append(current)
        # Push children reversed so that the first child is popped (visited) first.
        pending.extend(reversed(list(current.children())))
def extract_all_child_modules(parent_module, module_list):
    """
    Collects every module under ``parent_module`` that has no children of its own.

    Such modules are appended to ``module_list`` in place, in depth-first order.
    If ``parent_module`` itself has no children, it is the only module appended.

    :param parent_module: parent module.
    :type parent_module: nn.Module
    :param module_list: (empty) list to be filled with child modules.
    :type module_list: list[nn.Module]
    """
    pending = [parent_module]
    while pending:
        current = pending.pop()
        sub_modules = list(current.children())
        if sub_modules:
            # Push children reversed so that the first child is popped (visited) first.
            pending.extend(reversed(sub_modules))
        else:
            module_list.append(current)