torchdistill.core


torchdistill.core.forward_hook


torchdistill.core.forward_hook.get_device_index(data)[source]

Gets the device index of a tensor in the given data.

Parameters:

data (torch.Tensor or abc.Mapping or tuple or list) – tensor or data structure containing tensor.

Returns:

device index.

Return type:

int or str or None
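
A minimal usage sketch (for a CPU tensor the returned value is not an integer index; per the return type above it may be a string or None):
>>> import torch
>>> from torchdistill.core.forward_hook import get_device_index
>>> batch = [torch.rand(2, 3, 32, 32), torch.tensor([0, 1])]
>>> device_index = get_device_index(batch)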

torchdistill.core.forward_hook.register_forward_hook_with_dict(root_module, module_path, requires_input, requires_output, io_dict)[source]

Registers a forward hook for a child module to store its input and/or output in io_dict.

Parameters:
  • root_module (nn.Module) – root module (e.g., model).

  • module_path (str) – path to target child module.

  • requires_input (bool) – if True, stores input to the target child module.

  • requires_output (bool) – if True, stores output from the target child module.

  • io_dict (dict) – dict to store the target child module’s input and/or output.

Returns:

removable forward hook handle.

Return type:

torch.utils.hooks.RemovableHandle

class torchdistill.core.forward_hook.ForwardHookManager(target_device)[source]

A forward hook manager for PyTorch modules.

Parameters:

target_device (torch.device or str) – target device.

Example:
>>> import torch
>>> from torchvision import models
>>> from torchdistill.core.forward_hook import ForwardHookManager
>>> device = torch.device('cpu')
>>> forward_hook_manager = ForwardHookManager(device)
>>> model = models.resnet18()
>>> forward_hook_manager.add_hook(model, 'layer2')
>>> x = torch.rand(16, 3, 224, 224)
>>> y = model(x)
>>> io_dict = forward_hook_manager.pop_io_dict()
>>> layer2_input_tensor = io_dict['layer2']['input']
>>> layer2_output_tensor = io_dict['layer2']['output']
add_hook(root_module, module_path, requires_input=True, requires_output=True)[source]

Registers a forward hook for a child module to store its input and/or output.

Parameters:
  • root_module (nn.Module) – root module (e.g., model).

  • module_path (str) – path to target child module.

  • requires_input (bool) – if True, stores input to the target child module.

  • requires_output (bool) – if True, stores output from the target child module.

pop_io_dict()[source]

Pops I/O dict after gathering tensors on self.target_device.

Returns:

I/O dict that contains input and/or output tensors with a module path as a key.

Return type:

dict

pop_io_dict_from_device(device)[source]

Pops I/O dict for a specified device.

Parameters:

device (torch.device) – device whose I/O dict is popped.

Returns:

I/O dict that contains input and/or output tensors with a module path as a key.

Return type:

dict

change_target_device(target_device)[source]

Updates the target device with a new target_device.

Parameters:

target_device (torch.device or str) – new target device.

clear()[source]

Clears I/O dict and forward hooks registered in the instance.

torchdistill.core.interfaces


torchdistill.core.interfaces.forward_proc

torchdistill.core.interfaces.forward_proc.forward_all(model, *args, **kwargs)[source]

Performs forward computation using *args and **kwargs.

Parameters:
  • model (nn.Module) – model.

  • args (tuple) – variable-length arguments for forward.

  • kwargs (dict) – kwargs for forward.

Returns:

model’s forward output.

Return type:

Any

torchdistill.core.interfaces.forward_proc.forward_batch_only(model, sample_batch, targets=None, supp_dict=None, **kwargs)[source]

Performs forward computation using sample_batch only.

Parameters:
  • model (nn.Module) – model.

  • sample_batch (Any) – sample batch.

  • targets (Any) – training targets (won’t be passed to forward).

  • supp_dict (dict) – supplementary dict (won’t be passed to forward).

Returns:

model’s forward output.

Return type:

Any
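
A minimal sketch (per the description above, targets and supp_dict are accepted but not passed to the model’s forward):
>>> import torch
>>> from torchvision import models
>>> from torchdistill.core.interfaces.forward_proc import forward_batch_only
>>> model = models.resnet18()
>>> sample_batch = torch.rand(2, 3, 224, 224)
>>> output = forward_batch_only(model, sample_batch, targets=torch.tensor([0, 1]))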

torchdistill.core.interfaces.forward_proc.forward_batch_target(model, sample_batch, targets, supp_dict=None, **kwargs)[source]

Performs forward computation using sample_batch and targets only.

Parameters:
  • model (nn.Module) – model.

  • sample_batch (Any) – sample batch.

  • targets (Any) – training targets.

  • supp_dict (dict) – supplementary dict (won’t be passed to forward).

Returns:

model’s forward output.

Return type:

Any
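
This suits models whose forward takes both the batch and the targets, e.g. torchvision detection models in training mode. A hedged sketch, assuming a recent torchvision (the weights/weights_backbone arguments below belong to torchvision’s API, not torchdistill’s):
>>> import torch
>>> from torchvision import models
>>> from torchdistill.core.interfaces.forward_proc import forward_batch_target
>>> model = models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)
>>> images = [torch.rand(3, 256, 256)]
>>> targets = [{'boxes': torch.tensor([[10.0, 10.0, 50.0, 50.0]]), 'labels': torch.tensor([1])}]
>>> loss_dict = forward_batch_target(model, images, targets)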

torchdistill.core.interfaces.forward_proc.forward_batch_supp_dict(model, sample_batch, targets, supp_dict=None, **kwargs)[source]

Performs forward computation using sample_batch and supp_dict only.

Parameters:
  • model (nn.Module) – model.

  • sample_batch (Any) – sample batch.

  • targets (Any) – training targets (won’t be passed to forward).

  • supp_dict (dict) – supplementary dict.

Returns:

model’s forward output.

Return type:

Any

torchdistill.core.interfaces.forward_proc.forward_batch4sskd(model, sample_batch, targets=None, supp_dict=None, **kwargs)[source]

Performs forward computation using sample_batch only for the SSKD method.

Guodong Xu, Ziwei Liu, Xiaoxiao Li, Chen Change Loy: “Knowledge Distillation Meets Self-Supervision” @ ECCV 2020 (2020)

Parameters:
  • model (nn.Module) – model.

  • sample_batch (Any) – sample batch.

  • targets (Any) – training targets (won’t be passed to forward).

  • supp_dict (dict) – supplementary dict (won’t be passed to forward).

Returns:

model’s forward output.

Return type:

Any


torchdistill.core.interfaces.pre_epoch_proc

torchdistill.core.interfaces.pre_epoch_proc.default_pre_epoch_process_with_teacher(self, epoch=None, **kwargs)[source]

Performs pre-epoch process for distillation box.

torchdistill.core.interfaces.pre_epoch_proc.default_pre_epoch_process_without_teacher(self, epoch=None, **kwargs)[source]

Performs pre-epoch process for training box.


torchdistill.core.interfaces.pre_forward_proc


torchdistill.core.interfaces.post_forward_proc

torchdistill.core.interfaces.post_forward_proc.default_post_forward_process(self, loss, metrics=None, **kwargs)[source]

Performs post-forward process for distillation box.


torchdistill.core.interfaces.post_epoch_proc

torchdistill.core.interfaces.post_epoch_proc.default_post_epoch_process_with_teacher(self, metrics=None, **kwargs)[source]

Performs post-epoch process for distillation box.

torchdistill.core.interfaces.post_epoch_proc.default_post_epoch_process_without_teacher(self, metrics=None, **kwargs)[source]

Performs post-epoch process for training box.


torchdistill.core.interfaces.registry

torchdistill.core.interfaces.registry.register_pre_epoch_proc_func(arg=None, **kwargs)[source]

Registers a pre-epoch process function for torchdistill.core.distillation.DistillationBox and torchdistill.core.training.TrainingBox.

Parameters:

arg (Callable or None) – function to be registered as a pre-epoch process function.

Returns:

registered pre-epoch process function.

Return type:

Callable

Note

The function will be registered as an option for the pre-epoch process function. You can choose the registered function by specifying its name or the key you used for registration in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the function with a key of your choice, add key to the decorator as below:

>>> from torchdistill.core.interfaces.registry import register_pre_epoch_proc_func
>>> @register_pre_epoch_proc_func(key='my_custom_pre_epoch_proc_func')
>>> def new_pre_epoch_proc(self, epoch=None, **kwargs):
>>>     print('This is my custom pre-epoch process function')

In the example, the new_pre_epoch_proc function is registered with the key “my_custom_pre_epoch_proc_func”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the new_pre_epoch_proc function by specifying “my_custom_pre_epoch_proc_func”.

torchdistill.core.interfaces.registry.register_pre_forward_proc_func(arg=None, **kwargs)[source]

Registers a pre-forward process function for torchdistill.core.distillation.DistillationBox and torchdistill.core.training.TrainingBox.

Parameters:

arg (Callable or None) – function to be registered as a pre-forward process function.

Returns:

registered pre-forward process function.

Return type:

Callable

Note

The function will be registered as an option for the pre-forward process function. You can choose the registered function by specifying its name or the key you used for registration in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the function with a key of your choice, add key to the decorator as below:

>>> from torchdistill.core.interfaces.registry import register_pre_forward_proc_func
>>> @register_pre_forward_proc_func(key='my_custom_pre_forward_proc_func')
>>> def new_pre_forward_proc(self, *args, **kwargs):
>>>     print('This is my custom pre-forward process function')

In the example, the new_pre_forward_proc function is registered with the key “my_custom_pre_forward_proc_func”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the new_pre_forward_proc function by specifying “my_custom_pre_forward_proc_func”.

torchdistill.core.interfaces.registry.register_forward_proc_func(arg=None, **kwargs)[source]

Registers a forward process function for torchdistill.core.distillation.DistillationBox and torchdistill.core.training.TrainingBox.

Parameters:

arg (Callable or None) – function to be registered as a forward process function.

Returns:

registered forward process function.

Return type:

Callable

Note

The function will be registered as an option for the forward process function. You can choose the registered function by specifying its name or the key you used for registration in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the function with a key of your choice, add key to the decorator as below:

>>> from torchdistill.core.interfaces.registry import register_forward_proc_func
>>> @register_forward_proc_func(key='my_custom_forward_proc_func')
>>> def new_forward_proc(model, sample_batch, targets=None, supp_dict=None, **kwargs):
>>>     print('This is my custom forward process function')

In the example, the new_forward_proc function is registered with the key “my_custom_forward_proc_func”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the new_forward_proc function by specifying “my_custom_forward_proc_func”.

torchdistill.core.interfaces.registry.register_post_forward_proc_func(arg=None, **kwargs)[source]

Registers a post-forward process function for torchdistill.core.distillation.DistillationBox and torchdistill.core.training.TrainingBox.

Parameters:

arg (Callable or None) – function to be registered as a post-forward process function.

Returns:

registered post-forward process function.

Return type:

Callable

Note

The function will be registered as an option for the post-forward process function. You can choose the registered function by specifying its name or the key you used for registration in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the function with a key of your choice, add key to the decorator as below:

>>> from torchdistill.core.interfaces.registry import register_post_forward_proc_func
>>> @register_post_forward_proc_func(key='my_custom_post_forward_proc_func')
>>> def new_post_forward_proc(self, loss, metrics=None, **kwargs):
>>>     print('This is my custom post-forward process function')

In the example, the new_post_forward_proc function is registered with the key “my_custom_post_forward_proc_func”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the new_post_forward_proc function by specifying “my_custom_post_forward_proc_func”.

torchdistill.core.interfaces.registry.register_post_epoch_proc_func(arg=None, **kwargs)[source]

Registers a post-epoch process function for torchdistill.core.distillation.DistillationBox and torchdistill.core.training.TrainingBox.

Parameters:

arg (Callable or None) – function to be registered as a post-epoch process function.

Returns:

registered post-epoch process function.

Return type:

Callable

Note

The function will be registered as an option for the post-epoch process function. You can choose the registered function by specifying its name or the key you used for registration in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the function with a key of your choice, add key to the decorator as below:

>>> from torchdistill.core.interfaces.registry import register_post_epoch_proc_func
>>> @register_post_epoch_proc_func(key='my_custom_post_epoch_proc_func')
>>> def new_post_epoch_proc(self, metrics=None, **kwargs):
>>>     print('This is my custom post-epoch process function')

In the example, the new_post_epoch_proc function is registered with the key “my_custom_post_epoch_proc_func”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the new_post_epoch_proc function by specifying “my_custom_post_epoch_proc_func”.

torchdistill.core.interfaces.registry.get_pre_epoch_proc_func(key)[source]

Gets a registered pre-epoch process function.

Parameters:

key (str) – unique key to identify the registered pre-epoch process function.

Returns:

registered pre-epoch process function.

Return type:

Callable

torchdistill.core.interfaces.registry.get_pre_forward_proc_func(key)[source]

Gets a registered pre-forward process function.

Parameters:

key (str) – unique key to identify the registered pre-forward process function.

Returns:

registered pre-forward process function.

Return type:

Callable

torchdistill.core.interfaces.registry.get_forward_proc_func(key)[source]

Gets a registered forward process function.

Parameters:

key (str) – unique key to identify the registered forward process function.

Returns:

registered forward process function.

Return type:

Callable
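
A minimal sketch, assuming the built-in functions above are registered under their function names (the default when no key is given to register_forward_proc_func):
>>> from torchdistill.core.interfaces.registry import get_forward_proc_func
>>> forward_proc = get_forward_proc_func('forward_batch_only')
>>> # forward_proc can now be called as forward_proc(model, sample_batch, targets, supp_dict)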

torchdistill.core.interfaces.registry.get_post_forward_proc_func(key)[source]

Gets a registered post-forward process function.

Parameters:

key (str) – unique key to identify the registered post-forward process function.

Returns:

registered post-forward process function.

Return type:

Callable

torchdistill.core.interfaces.registry.get_post_epoch_proc_func(key)[source]

Gets a registered post-epoch process function.

Parameters:

key (str) – unique key to identify the registered post-epoch process function.

Returns:

registered post-epoch process function.

Return type:

Callable


torchdistill.core.training


class torchdistill.core.training.TrainingBox(model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]

A single-stage training framework.

Parameters:
  • model (nn.Module) – model.

  • dataset_dict (dict) – dict that contains datasets with IDs of your choice.

  • train_config (dict) – training configuration.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • lr_factor (float or int) – multiplier for learning rate.

  • accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.

setup_data_loaders(train_config)[source]

Sets up training and validation data loaders for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesTrainingBox.advance_to_next_stage().

Parameters:

train_config (dict) – training configuration.

setup_model(model_config)[source]

Sets up a model for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesTrainingBox.advance_to_next_stage().

Parameters:

model_config (dict) – model configuration.

setup_loss(train_config)[source]

Sets up a training loss module for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesTrainingBox.advance_to_next_stage().

Parameters:

train_config (dict) – training configuration.

setup_pre_post_processes(train_config)[source]

Sets up pre/post-epoch/forward processes for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesTrainingBox.advance_to_next_stage().

Parameters:

train_config (dict) – training configuration.

setup(train_config)[source]

Configures a TrainingBox/MultiStagesTrainingBox for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesTrainingBox.advance_to_next_stage().

Parameters:

train_config (dict) – training configuration.

pre_epoch_process(*args, **kwargs)[source]

Performs a pre-epoch process.

This should be overridden by all subclasses or defined through setup_pre_post_processes().

pre_forward_process(*args, **kwargs)[source]

Performs a pre-forward process.

This should be overridden by all subclasses or defined through setup_pre_post_processes().

forward_process(sample_batch, targets=None, supp_dict=None, **kwargs)[source]

Performs forward computations for a model.

Parameters:
  • sample_batch (Any) – sample batch.

  • targets (Any) – training targets.

  • supp_dict (dict) – supplementary dict.

Returns:

loss tensor.

Return type:

torch.Tensor

post_forward_process(*args, **kwargs)[source]

Performs a post-forward process.

This should be overridden by all subclasses or defined through setup_pre_post_processes().

post_epoch_process(*args, **kwargs)[source]

Performs a post-epoch process.

This should be overridden by all subclasses or defined through setup_pre_post_processes().

clean_modules()[source]

Unfreezes all the modules, clears the I/O dict, unregisters forward hook handles, and clears the handle lists.

class torchdistill.core.training.MultiStagesTrainingBox(model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]

A multi-stage training framework. This is a subclass of TrainingBox.

Parameters:
  • model (nn.Module) – model.

  • dataset_dict (dict) – dict that contains datasets with IDs of your choice.

  • train_config (dict) – training configuration.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • lr_factor (float or int) – multiplier for learning rate.

  • accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.

save_stage_ckpt(model, local_model_config)[source]

Saves the checkpoint of model for the current training stage.

Parameters:
  • model (nn.Module) – model to be saved.

  • local_model_config (dict) – model configuration at the current training stage.

advance_to_next_stage()[source]

Reads the next training stage’s configuration in train_config and advances to the next training stage.

post_epoch_process(*args, **kwargs)[source]

Performs a post-epoch process.

The superclass’s post_epoch_process should be overridden by all subclasses or defined through TrainingBox.setup_pre_post_processes().

torchdistill.core.training.get_training_box(model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]

Gets a training box.

Parameters:
  • model (nn.Module) – model.

  • dataset_dict (dict) – dict that contains datasets with IDs of your choice.

  • train_config (dict) – training configuration.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • lr_factor (float or int) – multiplier for learning rate.

  • accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.

Returns:

training box.

Return type:

TrainingBox or MultiStagesTrainingBox
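
A hedged sketch of the typical single-stage loop built on these methods, assuming the box exposes num_epochs and train_data_loader as in the torchdistill example scripts and that the data loader yields (sample_batch, targets) pairs:
>>> training_box = get_training_box(model, dataset_dict, train_config, device, device_ids, distributed, lr_factor)
>>> for epoch in range(training_box.num_epochs):
>>>     training_box.pre_epoch_process(epoch=epoch)
>>>     for sample_batch, targets in training_box.train_data_loader:
>>>         loss = training_box.forward_process(sample_batch, targets=targets, supp_dict=None)
>>>         training_box.post_forward_process(loss=loss)
>>>     training_box.post_epoch_process()
>>> training_box.clean_modules()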


torchdistill.core.distillation


class torchdistill.core.distillation.DistillationBox(teacher_model, student_model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]

A single-stage knowledge distillation framework.

Parameters:
  • teacher_model (nn.Module) – teacher model.

  • student_model (nn.Module) – student model.

  • dataset_dict (dict) – dict that contains datasets with IDs of your choice.

  • train_config (dict) – training configuration.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • lr_factor (float or int) – multiplier for learning rate.

  • accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.

setup_data_loaders(train_config)[source]

Sets up training and validation data loaders for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesDistillationBox.advance_to_next_stage().

Parameters:

train_config (dict) – training configuration.

setup_teacher_student_models(teacher_config, student_config)[source]

Sets up teacher and student models for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesDistillationBox.advance_to_next_stage().

Parameters:
  • teacher_config (dict) – teacher configuration.

  • student_config (dict) – student configuration.

setup_loss(train_config)[source]

Sets up a training loss module for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesDistillationBox.advance_to_next_stage().

Parameters:

train_config (dict) – training configuration.

setup_pre_post_processes(train_config)[source]

Sets up pre/post-epoch/forward processes for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesDistillationBox.advance_to_next_stage().

Parameters:

train_config (dict) – training configuration.

setup(train_config)[source]

Configures a DistillationBox/MultiStagesDistillationBox for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesDistillationBox.advance_to_next_stage().

Parameters:

train_config (dict) – training configuration.

pre_epoch_process(*args, **kwargs)[source]

Performs a pre-epoch process.

This should be overridden by all subclasses or defined through setup_pre_post_processes().

pre_forward_process(*args, **kwargs)[source]

Performs a pre-forward process.

This should be overridden by all subclasses or defined through setup_pre_post_processes().

get_teacher_output(sample_batch, targets, supp_dict, **kwargs)[source]

Gets teacher model’s output.

Parameters:
  • sample_batch (Any) – sample batch.

  • targets (Any) – training targets.

  • supp_dict (dict) – supplementary dict.

Returns:

teacher’s outputs and teacher’s I/O dict.

Return type:

(Any, dict)

forward_process(sample_batch, targets=None, supp_dict=None, **kwargs)[source]

Performs forward computations for teacher and student models.

Parameters:
  • sample_batch (Any) – sample batch.

  • targets (Any) – training targets.

  • supp_dict (dict) – supplementary dict.

Returns:

loss tensor.

Return type:

torch.Tensor

post_forward_process(*args, **kwargs)[source]

Performs a post-forward process.

This should be overridden by all subclasses or defined through setup_pre_post_processes().

post_epoch_process(*args, **kwargs)[source]

Performs a post-epoch process.

This should be overridden by all subclasses or defined through setup_pre_post_processes().

clean_modules()[source]

Unfreezes all the teacher and student modules, clears I/O dicts, unregisters forward hook handles, and clears the handle lists.

class torchdistill.core.distillation.MultiStagesDistillationBox(teacher_model, student_model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]

A multi-stage knowledge distillation framework. This is a subclass of DistillationBox.

Parameters:
  • teacher_model (nn.Module) – teacher model.

  • student_model (nn.Module) – student model.

  • dataset_dict (dict) – dict that contains datasets with IDs of your choice.

  • train_config (dict) – training configuration.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • lr_factor (float or int) – multiplier for learning rate.

  • accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.

save_stage_ckpt(model, local_model_config)[source]

Saves the checkpoint of model for the current training stage.

Parameters:
  • model (nn.Module) – model to be saved.

  • local_model_config (dict) – model configuration at the current training stage.

advance_to_next_stage()[source]

Reads the next training stage’s configuration in train_config and advances to the next training stage.

post_epoch_process(*args, **kwargs)[source]

Performs a post-epoch process.

The superclass’s post_epoch_process should be overridden by all subclasses or defined through DistillationBox.setup_pre_post_processes().

torchdistill.core.distillation.get_distillation_box(teacher_model, student_model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]

Gets a distillation box.

Parameters:
  • teacher_model (nn.Module) – teacher model.

  • student_model (nn.Module) – student model.

  • dataset_dict (dict) – dict that contains datasets with IDs of your choice.

  • train_config (dict) – training configuration.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • lr_factor (float or int) – multiplier for learning rate.

  • accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.

Returns:

distillation box.

Return type:

DistillationBox or MultiStagesDistillationBox
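
The loop mirrors the TrainingBox sketch above (same assumptions about num_epochs and the data loader); both teacher and student forward passes happen inside forward_process:
>>> distillation_box = get_distillation_box(teacher_model, student_model, dataset_dict, train_config, device, device_ids, distributed, lr_factor)
>>> for epoch in range(distillation_box.num_epochs):
>>>     distillation_box.pre_epoch_process(epoch=epoch)
>>>     for sample_batch, targets in distillation_box.train_data_loader:
>>>         loss = distillation_box.forward_process(sample_batch, targets=targets, supp_dict=None)
>>>         distillation_box.post_forward_process(loss=loss)
>>>     distillation_box.post_epoch_process()
>>> distillation_box.clean_modules()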


torchdistill.core.util


torchdistill.core.util.add_kwargs_to_io_dict(io_dict, module_path, **kwargs)[source]

Adds kwargs to an I/O dict.

Parameters:
  • io_dict (dict) – I/O dict.

  • module_path (str) – module path.

  • kwargs (dict) – kwargs to be stored in io_dict.

torchdistill.core.util.set_hooks(model, unwrapped_org_model, model_config, io_dict)[source]

Sets forward hooks for target modules in model.

Parameters:
  • model (nn.Module) – model.

  • unwrapped_org_model (nn.Module) – unwrapped original model.

  • model_config (dict) – model configuration.

  • io_dict (dict) – I/O dict.

Returns:

list of pairs of module path and removable forward hook handle.

Return type:

list[(str, torch.utils.hooks.RemovableHandle)]

torchdistill.core.util.wrap_model(model, model_config, device, device_ids=None, distributed=False, find_unused_parameters=False, any_updatable=True)[source]

Wraps model with either DataParallel or DistributedDataParallel if specified.

Parameters:
  • model (nn.Module) – model.

  • model_config (dict) – model configuration.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • find_unused_parameters (bool) – find_unused_parameters for DistributedDataParallel.

  • any_updatable (bool) – True if model contains any updatable parameters.

Returns:

wrapped model (or model if wrapper is not specified).

Return type:

nn.Module

torchdistill.core.util.change_device(data, device)[source]

Updates the device of tensor(s) stored in data with a new device.

Parameters:
  • data (Any) – data that contain tensor(s).

  • device (torch.device or str) – new device.

Returns:

data on the new device.

Return type:

Any
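
A minimal sketch, assuming change_device recurses into common containers such as dict, list, and tuple:
>>> import torch
>>> from torchdistill.core.util import change_device
>>> batch = {'images': torch.rand(2, 3, 32, 32), 'labels': torch.tensor([0, 1])}
>>> batch = change_device(batch, 'cpu')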

torchdistill.core.util.tensor2numpy2tensor(data, device)[source]

Converts tensor(s) in data to numpy data and converts the numpy data back to tensor(s) on device.

Parameters:
  • data (Any) – data that contain tensor(s).

  • device (torch.device or str) – new device.

Returns:

data that contain recreated tensor(s).

Return type:

Any

torchdistill.core.util.clear_io_dict(model_io_dict)[source]

Clears a model I/O dict’s sub dict(s).

Parameters:

model_io_dict (dict) – model I/O dict.

torchdistill.core.util.extract_io_dict(model_io_dict, target_device)[source]

Extracts I/O dict, gathering tensors on target_device.

Parameters:
  • model_io_dict (dict) – model I/O dict.

  • target_device (torch.device or str) – target device.

Returns:

extracted I/O dict.

Return type:

dict

torchdistill.core.util.update_io_dict(main_io_dict, sub_io_dict)[source]

Updates an I/O dict with a sub I/O dict.

Parameters:
  • main_io_dict (dict) – main I/O dict to be updated.

  • sub_io_dict (dict) – sub I/O dict.

torchdistill.core.util.extract_sub_model_io_dict(model_io_dict, index)[source]

Extracts sub I/O dict from model_io_dict.

Parameters:
  • model_io_dict (dict) – model I/O dict.

  • index (int) – sample index.

Returns:

extracted sub I/O dict.

Return type:

dict