from collections import OrderedDict
from torch.nn import DataParallel, Sequential, ModuleList, Module, Parameter
from torch.nn.parallel import DistributedDataParallel
from .constant import def_logger
# Module-level logger: a child of the package's default logger, named after this module.
logger = def_logger.getChild(__name__)
[docs]
def check_if_wrapped(model):
    """
    Tells whether a model is wrapped by a parallel wrapper.

    :param model: model to inspect.
    :type model: nn.Module
    :return: True if `model` is an instance of DataParallel or DistributedDataParallel, False otherwise.
    :rtype: bool
    """
    wrapper_types = (DataParallel, DistributedDataParallel)
    return isinstance(model, wrapper_types)
[docs]
def count_params(module):
    """
    Counts the elements of all parameters yielded by ``module.parameters()``.

    :param module: module whose parameters are counted.
    :type module: nn.Module
    :return: total number of parameter elements (0 if the module has no parameters).
    :rtype: int
    """
    total = 0
    for tensor in module.parameters():
        total += tensor.numel()
    return total
[docs]
def freeze_module_params(module):
    """
    Freezes parameters by setting ``requires_grad = False``.

    Accepts either a module (all of its parameters are frozen) or a single
    parameter. Any other object is left untouched.

    :param module: module or parameter to freeze.
    :type module: nn.Module or nn.Parameter
    """
    if isinstance(module, Module):
        targets = module.parameters()
    elif isinstance(module, Parameter):
        targets = (module,)
    else:
        # Neither a module nor a parameter: nothing to freeze.
        return

    for target in targets:
        target.requires_grad = False
[docs]
def unfreeze_module_params(module):
    """
    Unfreezes parameters by setting ``requires_grad = True``.

    Accepts either a module (all of its parameters are unfrozen) or a single
    parameter. Any other object is left untouched.

    :param module: module or parameter to unfreeze.
    :type module: nn.Module or nn.Parameter
    """
    if isinstance(module, Module):
        targets = module.parameters()
    elif isinstance(module, Parameter):
        targets = (module,)
    else:
        # Neither a module nor a parameter: nothing to unfreeze.
        return

    for target in targets:
        target.requires_grad = True
[docs]
def get_updatable_param_names(module):
    """
    Collects the names of parameters that are updatable, i.e. whose ``requires_grad`` is truthy.

    :param module: module to inspect.
    :type module: nn.Module
    :return: names of updatable parameters, in ``named_parameters()`` order.
    :rtype: list[str]
    """
    updatable_names = []
    for param_name, tensor in module.named_parameters():
        if tensor.requires_grad:
            updatable_names.append(param_name)
    return updatable_names
[docs]
def get_frozen_param_names(module):
    """
    Collects the names of parameters that are frozen, i.e. whose ``requires_grad`` is falsy.

    :param module: module to inspect.
    :type module: nn.Module
    :return: names of frozen parameters, in ``named_parameters()`` order.
    :rtype: list[str]
    """
    frozen_names = []
    for param_name, tensor in module.named_parameters():
        if tensor.requires_grad:
            continue
        frozen_names.append(param_name)
    return frozen_names
[docs]
def get_module(root_module, module_path):
    """
    Gets a module specified by ``module_path``.

    The path is split on ``'.'`` and resolved one name at a time, starting
    from ``root_module``. For each name:

    1. If the current module is a DataParallel / DistributedDataParallel wrapper
       that does not itself expose the name, resolution continues from the
       wrapped ``.module``.
    2. If the current module has an attribute of that name, it is followed.
    3. Otherwise, if the current module is a Sequential or ModuleList and the
       name is numeric (optionally prefixed with ``-``), it is used as an index.
    4. Otherwise a warning is logged and ``None`` is returned.

    :param root_module: module.
    :type root_module: nn.Module
    :param module_path: module path for extracting the module from ``root_module``.
    :type module_path: str
    :return: module extracted from ``root_module`` if exists.
    :rtype: nn.Module or None
    """
    module = root_module
    for module_name in module_path.split('.'):
        # Look through a parallel wrapper only when the wrapper itself cannot resolve the name,
        # so that explicit paths such as ``module.layer1`` keep working.
        if not hasattr(module, module_name) and isinstance(module, (DataParallel, DistributedDataParallel)):
            module = module.module

        if hasattr(module, module_name):
            module = getattr(module, module_name)
        elif isinstance(module, (Sequential, ModuleList)) and module_name.lstrip('-').isnumeric():
            # Index access covers names that are not attributes, e.g. negative indices like ``-1``.
            module = module[int(module_name)]
        else:
            # Wrapped and non-wrapped models fail the same way: warn and return None
            # instead of continuing the walk from an unrelated module.
            logger.warning('`{}` of `{}` could not be reached in `{}`'.format(
                module_name, module_path, type(root_module).__name__)
            )
            return None
    return module
[docs]
def get_hierarchized_dict(module_paths):
    """
    Builds a hierarchical structure from dotted module paths.

    Paths are grouped by their first element. A first element seen for the
    first time as a single-element path maps to that path string; otherwise
    it maps to a list of the remaining sub-paths. Lists holding more than one
    sub-path are hierarchized recursively, while single-item lists are kept as lists.

    :param module_paths: module paths.
    :type module_paths: list[str]
    :return: ordered dict mapping each first path element to a path string,
        a list of sub-paths, or a nested ordered dict.
    :rtype: collections.OrderedDict
    """
    hierarchy = OrderedDict()
    for path in module_paths:
        head, *tail = path.split('.')
        if head not in hierarchy:
            if not tail:
                # Leaf path seen for the first time: store the path itself.
                hierarchy[head] = path
                continue
            hierarchy[head] = []
        hierarchy[head].append('.'.join(tail))

    # Iterate over a snapshot of the items since values are replaced along the way.
    for head, children in list(hierarchy.items()):
        if isinstance(children, list) and len(children) > 1:
            hierarchy[head] = get_hierarchized_dict(children)
    return hierarchy
[docs]
def decompose(ordered_dict):
    """
    Converts an ordered dict into a list of components.

    Each entry becomes ``(key, decompose(value))`` when the value is an
    ordered dict, ``(key, value)`` when the value is a list, and just ``key``
    for any other value.

    :param ordered_dict: ordered dict.
    :type ordered_dict: collections.OrderedDict
    :return: list of keys and key-value pairs, in the dict's order.
    :rtype: list[str or (str, Any)]
    """
    def _to_component(name, child):
        if isinstance(child, OrderedDict):
            return name, decompose(child)
        if isinstance(child, list):
            return name, child
        return name

    return [_to_component(name, child) for name, child in ordered_dict.items()]
[docs]
def get_components(module_paths):
    """
    Converts module paths into a list of components.

    The paths are first arranged into a hierarchy with ``get_hierarchized_dict``
    and the result is then flattened into a list with ``decompose``.

    :param module_paths: module paths.
    :type module_paths: list[str]
    :return: list of components as produced by ``decompose``.
    :rtype: list[str or (str, Any)]
    """
    return decompose(get_hierarchized_dict(module_paths))