from collections import abc
import torch
from torch.nn.parallel.scatter_gather import gather
from ..common.module_util import check_if_wrapped, get_module
[docs]
def get_device_index(data):
    """
    Gets device index of tensor in given data.

    Containers (mappings and non-string sequences such as tuples and lists) are searched
    depth-first, and the first tensor found determines the result.

    :param data: tensor or data structure containing tensor.
    :type data: torch.Tensor or abc.Mapping or tuple or list
    :return: ``'cpu'`` for a CPU tensor, otherwise the tensor's ``device.index``; None if no tensor is found.
    :rtype: int or str or None
    """
    if isinstance(data, torch.Tensor):
        device = data.device
        return 'cpu' if device.type == 'cpu' else device.index

    if isinstance(data, abc.Mapping):
        children = data.values()
    elif isinstance(data, abc.Sequence) and not isinstance(data, (str, bytes, bytearray)):
        # Strings are Sequences whose items are strings again (a 1-char str contains itself),
        # so recursing into them would never terminate; they can never hold tensors anyway.
        children = data
    else:
        return None

    for child in children:
        result = get_device_index(child)
        if result is not None:
            return result
    return None
[docs]
def register_forward_hook_with_dict(root_module, module_path, requires_input, requires_output, io_dict):
    """
    Registers a forward hook on ``root_module`` to store its input and/or output in `io_dict`.

    The hook is attached to ``root_module`` itself; ``module_path`` is only used as the key under
    which captured data is stored (``io_dict[module_path]['input' | 'output'][device_index]``).
    Callers that target a child module should resolve it first and pass it as ``root_module``.

    :param root_module: module to attach the hook to (e.g., an already-resolved child module).
    :type root_module: nn.Module
    :param module_path: key under which the module's input and/or output is stored in ``io_dict``.
    :type module_path: str
    :param requires_input: if True, stores input to the module.
    :type requires_input: bool
    :param requires_output: if True, stores output from the module.
    :type requires_output: bool
    :param io_dict: dict to store the module's input and/or output; ``io_dict[module_path]`` is reset to an empty dict.
    :type io_dict: dict
    :return: removable forward hook handle.
    :rtype: torch.utils.hooks.RemovableHandle
    :raises ValueError: if both ``requires_input`` and ``requires_output`` are False (``io_dict`` is left unchanged).
    """
    # Validate before touching io_dict so a bad call has no side effect on the caller's dict.
    if not requires_input and not requires_output:
        raise ValueError('Either requires_input or requires_output should be True')

    io_dict[module_path] = dict()

    def forward_hook(module, func_input, func_output):
        # Single-element tuples are unwrapped so the common one-tensor case is stored as a bare object.
        if requires_input and isinstance(func_input, tuple) and len(func_input) == 1:
            func_input = func_input[0]
        if requires_output and isinstance(func_output, tuple) and len(func_output) == 1:
            func_output = func_output[0]
        # Entries are keyed by the output's device so data recorded on different devices is kept apart.
        device_index = get_device_index(func_output)
        sub_io_dict = io_dict[module_path]
        if requires_input:
            sub_io_dict.setdefault('input', dict())[device_index] = func_input
        if requires_output:
            sub_io_dict.setdefault('output', dict())[device_index] = func_output

    return root_module.register_forward_hook(forward_hook)
[docs]
class ForwardHookManager(object):
    """
    A forward hook manager for PyTorch modules.

    :param target_device: target device.
    :type target_device: torch.device or str

    Example:
        >>> import torch
        >>> from torchvision import models
        >>> from torchdistill.core.forward_hook import ForwardHookManager
        >>> device = torch.device('cpu')
        >>> forward_hook_manager = ForwardHookManager(device)
        >>> model = models.resnet18()
        >>> forward_hook_manager.add_hook(model, 'layer2')
        >>> x = torch.rand(16, 3, 224, 224)
        >>> y = model(x)
        >>> io_dict = forward_hook_manager.pop_io_dict()
        >>> layer2_input_tensor = io_dict['layer2']['input']
        >>> layer2_output_tensor = io_dict['layer2']['output']
    """
    def __init__(self, target_device):
        self.target_device = torch.device(target_device) if isinstance(target_device, str) else target_device
        # Whether per-device entries should be merged with `gather` in pop_io_dict.
        self.uses_cuda = self.target_device.type == 'cuda'
        # module_path -> {'input' | 'output' -> {device_key -> stored object}}
        self.io_dict = dict()
        # (module_path, hook handle) pairs, kept so clear() can remove every hook.
        self.hook_list = list()

    def add_hook(self, root_module, module_path, requires_input=True, requires_output=True):
        """
        Registers a forward hook for a child module to store its input and/or output.

        :param root_module: root module (e.g., model).
        :type root_module: nn.Module
        :param module_path: path to target child module.
        :type module_path: str
        :param requires_input: if True, stores input to the target child module.
        :type requires_input: bool
        :param requires_output: if True, stores output from the target child module.
        :type requires_output: bool
        """
        # Resolve the path against the inner module when the root is a wrapper exposing `.module`.
        unwrapped_module = root_module.module if check_if_wrapped(root_module) else root_module
        sub_module = get_module(unwrapped_module, module_path)
        handle = register_forward_hook_with_dict(sub_module, module_path, requires_input, requires_output, self.io_dict)
        self.hook_list.append((module_path, handle))

    def pop_io_dict(self):
        """
        Pops I/O dict after gathering tensors on ``self.target_device``.

        For each module path and I/O type, the stored per-device values are ordered by device key.
        When the target device is CUDA and more than one device contributed, they are merged with
        ``gather`` onto ``self.target_device``; otherwise the value with the largest key is returned.
        I/O types with no stored values (e.g., already drained by ``pop_io_dict_from_device``) are skipped.

        :return: I/O dict that contains input and/or output tensors with a module path as a key.
        :rtype: dict
        """
        gathered_io_dict = dict()
        for module_path, module_io_dict in self.io_dict.items():
            gathered_io_dict[module_path] = dict()
            for io_type in list(module_io_dict.keys()):
                sub_dict = module_io_dict.pop(io_type)
                if not sub_dict:
                    # Nothing recorded since the last pop; values[-1] would raise IndexError.
                    continue
                values = [sub_dict[key] for key in sorted(sub_dict.keys())]
                gathered_obj = gather(values, self.target_device) if self.uses_cuda and len(values) > 1 else values[-1]
                gathered_io_dict[module_path][io_type] = gathered_obj
        return gathered_io_dict

    def pop_io_dict_from_device(self, device):
        """
        Pops I/O dict for a specified ``device``.

        :param device: device to pop I/O dict.
        :type device: torch.device
        :return: I/O dict that contains input and/or output tensors with a module path as a key.
        :rtype: dict
        :raises KeyError: if an I/O type of a hooked module has no entry for ``device``.
        """
        device_io_dict = dict()
        # Match the keys produced by get_device_index: CUDA index, or the device type (e.g., 'cpu').
        device_key = device.index if device.type == 'cuda' else device.type
        for module_path, module_io_dict in self.io_dict.items():
            device_io_dict[module_path] = dict()
            for io_type in list(module_io_dict.keys()):
                sub_dict = module_io_dict[io_type]
                device_io_dict[module_path][io_type] = sub_dict.pop(device_key)
        return device_io_dict

    def change_target_device(self, target_device):
        """
        Updates the target device with a new ``target_device``.

        If the device type changes, all stored inputs/outputs are discarded.

        :param target_device: new target device.
        :type target_device: torch.device or str
        """
        target_device = torch.device(target_device) if isinstance(target_device, str) else target_device
        if self.target_device.type != target_device.type:
            # Entries were keyed/stored for the previous device type; drop them.
            for sub_dict in self.io_dict.values():
                sub_dict.clear()
        self.target_device = target_device
        # Keep the gather decision in pop_io_dict consistent with the new target device.
        self.uses_cuda = target_device.type == 'cuda'

    def clear(self):
        """
        Clears I/O dict and forward hooks registered in the instance.
        """
        self.io_dict.clear()
        for _, handle in self.hook_list:
            handle.remove()
        self.hook_list.clear()