torchdistill.core
torchdistill.core.forward_hook
- torchdistill.core.forward_hook.get_device_index(data)[source]
Gets device index of tensor in given data.
- Parameters:
data (torch.Tensor or abc.Mapping or tuple or list) – tensor or data structure containing tensor.
- Returns:
device index.
- Return type:
int or str or None
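A minimal usage sketch (not part of the original docstring), assuming a CPU tensor yields a string identifier and a CUDA tensor yields an integer index, as the return type above suggests; nested containers are searched for a tensor.
>>> import torch
>>> from torchdistill.core.forward_hook import get_device_index
>>> # a tensor nested in a dict should still be found
>>> device_index = get_device_index({'input': torch.rand(2, 3)})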
- torchdistill.core.forward_hook.register_forward_hook_with_dict(root_module, module_path, requires_input, requires_output, io_dict)[source]
Registers a forward hook for a child module to store its input and/or output in io_dict.
- Parameters:
root_module (nn.Module) – root module (e.g., model).
module_path (str) – path to target child module.
requires_input (bool) – if True, stores input to the target child module.
requires_output (bool) – if True, stores output from the target child module.
io_dict (dict) – dict to store the target child module's input and/or output.
- Returns:
removable forward hook handle.
- Return type:
torch.utils.hook.RemovableHandle
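A hedged usage sketch (not from the original docs). In practice, ForwardHookManager.add_hook wires this function up together with torchdistill.core.util.add_kwargs_to_io_dict; the sub dict for the module path is assumed here to be prepared by the caller.
>>> import torch
>>> from torchvision import models
>>> from torchdistill.core.forward_hook import register_forward_hook_with_dict
>>> model = models.resnet18()
>>> io_dict = {'layer1': dict()}  # sub dict normally created via add_kwargs_to_io_dict
>>> handle = register_forward_hook_with_dict(model, 'layer1', True, True, io_dict)
>>> _ = model(torch.rand(1, 3, 224, 224))  # forward pass lets the hook populate io_dict['layer1']
>>> handle.remove()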
- class torchdistill.core.forward_hook.ForwardHookManager(target_device)[source]
A forward hook manager for PyTorch modules.
- Parameters:
target_device (torch.device or str) – target device.
- Example:
>>> import torch
>>> from torchvision import models
>>> from torchdistill.core.forward_hook import ForwardHookManager
>>> device = torch.device('cpu')
>>> forward_hook_manager = ForwardHookManager(device)
>>> model = models.resnet18()
>>> forward_hook_manager.add_hook(model, 'layer2')
>>> x = torch.rand(16, 3, 224, 224)
>>> y = model(x)
>>> io_dict = forward_hook_manager.pop_io_dict()
>>> layer2_input_tensor = io_dict['layer2']['input']
>>> layer2_output_tensor = io_dict['layer2']['output']
- add_hook(root_module, module_path, requires_input=True, requires_output=True)[source]
Registers a forward hook for a child module to store its input and/or output.
- Parameters:
root_module (nn.Module) – root module (e.g., model).
module_path (str) – path to target child module.
requires_input (bool) – if True, stores input to the target child module.
requires_output (bool) – if True, stores output from the target child module.
- pop_io_dict()[source]
Pops I/O dict after gathering tensors on self.target_device.
- Returns:
I/O dict that contains input and/or output tensors with a module path as a key.
- Return type:
dict
- pop_io_dict_from_device(device)[source]
Pops I/O dict for a specified device.
- Parameters:
device (torch.device) – device whose I/O dict is popped.
- Returns:
I/O dict that contains input and/or output tensors with a module path as a key.
- Return type:
dict
torchdistill.core.interfaces
torchdistill.core.interfaces.forward_proc
- torchdistill.core.interfaces.forward_proc.forward_all(model, *args, **kwargs)[source]
Performs forward computation using *args and **kwargs.
- Parameters:
model (nn.Module) – model.
args (tuple) – variable-length arguments for forward.
kwargs (dict) – kwargs for forward.
- Returns:
model's forward output.
- Return type:
Any
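A quick illustration (not part of the original docstring), assuming a torchvision classifier: all given positional and keyword arguments are passed through to the model's forward.
>>> import torch
>>> from torchvision import models
>>> from torchdistill.core.interfaces.forward_proc import forward_all
>>> model = models.resnet18()
>>> output = forward_all(model, torch.rand(2, 3, 224, 224))
>>> output.shape
torch.Size([2, 1000])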
- torchdistill.core.interfaces.forward_proc.forward_batch_only(model, sample_batch, targets=None, supp_dict=None, **kwargs)[source]
Performs forward computation using sample_batch only.
- Parameters:
model (nn.Module) – model.
sample_batch (Any) – sample batch.
targets (Any) – training targets (won't be passed to forward).
supp_dict (dict) – supplementary dict (won't be passed to forward).
- Returns:
model's forward output.
- Return type:
Any
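A short sketch (illustrative, assuming an image classification model that only needs the sample batch): targets and supp_dict are accepted but not forwarded to the model.
>>> import torch
>>> from torchvision import models
>>> from torchdistill.core.interfaces.forward_proc import forward_batch_only
>>> model = models.resnet18()
>>> output = forward_batch_only(model, torch.rand(2, 3, 224, 224), targets=torch.tensor([0, 1]))
>>> output.shape
torch.Size([2, 1000])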
- torchdistill.core.interfaces.forward_proc.forward_batch_target(model, sample_batch, targets, supp_dict=None, **kwargs)[source]
Performs forward computation using sample_batch and targets only.
- Parameters:
model (nn.Module) – model.
sample_batch (Any) – sample batch.
targets (Any) – training targets.
supp_dict (dict) – supplementary dict (won't be passed to forward).
- Returns:
model's forward output.
- Return type:
Any
- torchdistill.core.interfaces.forward_proc.forward_batch_supp_dict(model, sample_batch, targets, supp_dict=None, **kwargs)[source]
Performs forward computation using sample_batch and supp_dict only.
- Parameters:
model (nn.Module) – model.
sample_batch (Any) – sample batch.
targets (Any) – training targets (won't be passed to forward).
supp_dict (dict) – supplementary dict.
- Returns:
model's forward output.
- Return type:
Any
- torchdistill.core.interfaces.forward_proc.forward_batch4sskd(model, sample_batch, targets=None, supp_dict=None, **kwargs)[source]
Performs forward computation using sample_batch only for the SSKD method.
Guodong Xu, Ziwei Liu, Xiaoxiao Li, Chen Change Loy: "Knowledge Distillation Meets Self-Supervision" @ ECCV 2020 (2020)
- Parameters:
model (nn.Module) – model.
sample_batch (Any) – sample batch.
targets (Any) – training targets (won't be passed to forward).
supp_dict (dict) – supplementary dict (won't be passed to forward).
- Returns:
model's forward output.
- Return type:
Any
torchdistill.core.interfaces.pre_epoch_proc
- torchdistill.core.interfaces.pre_epoch_proc.default_pre_epoch_process_with_teacher(self, epoch=None, **kwargs)[source]
Performs pre-epoch process for distillation box.
- Parameters:
self (torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox) – distillation box.
epoch (int) – epoch for DistributedSampler.
- torchdistill.core.interfaces.pre_epoch_proc.default_pre_epoch_process_without_teacher(self, epoch=None, **kwargs)[source]
Performs pre-epoch process for training box.
- Parameters:
self (torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox) – training box.
epoch (int) – epoch for DistributedSampler.
torchdistill.core.interfaces.pre_forward_proc
torchdistill.core.interfaces.post_forward_proc
- torchdistill.core.interfaces.post_forward_proc.default_post_forward_process(self, loss, metrics=None, **kwargs)[source]
Performs post-forward process for distillation box.
- Parameters:
self (torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox) – distillation box.
loss (torch.Tensor) – loss tensor.
metrics (Any) – metric for ReduceLROnPlateau.step.
torchdistill.core.interfaces.post_epoch_proc
- torchdistill.core.interfaces.post_epoch_proc.default_post_epoch_process_with_teacher(self, metrics=None, **kwargs)[source]
Performs post-epoch process for distillation box.
- Parameters:
self (torchdistill.core.distillation.DistillationBox) – distillation box.
metrics (Any) – metric for ReduceLROnPlateau.step.
- torchdistill.core.interfaces.post_epoch_proc.default_post_epoch_process_without_teacher(self, metrics=None, **kwargs)[source]
Performs post-epoch process for training box.
- Parameters:
self (torchdistill.core.training.TrainingBox) – training box.
metrics (Any) – metric for ReduceLROnPlateau.step.
torchdistill.core.interfaces.registry
- torchdistill.core.interfaces.registry.register_pre_epoch_proc_func(arg=None, **kwargs)[source]
Registers a pre-epoch process function for torchdistill.core.distillation.DistillationBox and torchdistill.core.training.TrainingBox.
- Parameters:
arg (Callable or None) – function to be registered as a pre-epoch process function.
- Returns:
registered pre-epoch process function.
- Return type:
Callable
Note
The function will be registered as an option of the pre-epoch process function. You can choose the registered function by specifying the name of the function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the function with a key of your choice, add key to the decorator as below:
>>> from torchdistill.core.interfaces.registry import register_pre_epoch_proc_func
>>> @register_pre_epoch_proc_func(key='my_custom_pre_epoch_proc_func')
>>> def new_pre_epoch_proc(self, epoch=None, **kwargs):
>>>     print('This is my custom pre-epoch process function')
In the example, the new_pre_epoch_proc function is registered with a key 'my_custom_pre_epoch_proc_func'. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the new_pre_epoch_proc function by 'my_custom_pre_epoch_proc_func'.
- torchdistill.core.interfaces.registry.register_pre_forward_proc_func(arg=None, **kwargs)[source]
Registers a pre-forward process function for torchdistill.core.distillation.DistillationBox and torchdistill.core.training.TrainingBox.
- Parameters:
arg (Callable or None) – function to be registered as a pre-forward process function.
- Returns:
registered pre-forward process function.
- Return type:
Callable
Note
The function will be registered as an option of the pre-forward process function. You can choose the registered function by specifying the name of the function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the function with a key of your choice, add key to the decorator as below:
>>> from torchdistill.core.interfaces.registry import register_pre_forward_proc_func
>>> @register_pre_forward_proc_func(key='my_custom_pre_forward_proc_func')
>>> def new_pre_forward_proc(self, *args, **kwargs):
>>>     print('This is my custom pre-forward process function')
In the example, the new_pre_forward_proc function is registered with a key 'my_custom_pre_forward_proc_func'. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the new_pre_forward_proc function by 'my_custom_pre_forward_proc_func'.
- torchdistill.core.interfaces.registry.register_forward_proc_func(arg=None, **kwargs)[source]
Registers a forward process function for torchdistill.core.distillation.DistillationBox and torchdistill.core.training.TrainingBox.
- Parameters:
arg (Callable or None) – function to be registered as a forward process function.
- Returns:
registered forward process function.
- Return type:
Callable
Note
The function will be registered as an option of the forward process function. You can choose the registered function by specifying the name of the function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the function with a key of your choice, add key to the decorator as below:
>>> from torchdistill.core.interfaces.registry import register_forward_proc_func
>>> @register_forward_proc_func(key='my_custom_forward_proc_func')
>>> def new_forward_proc(model, sample_batch, targets=None, supp_dict=None, **kwargs):
>>>     print('This is my custom forward process function')
In the example, the new_forward_proc function is registered with a key 'my_custom_forward_proc_func'. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the new_forward_proc function by 'my_custom_forward_proc_func'.
- torchdistill.core.interfaces.registry.register_post_forward_proc_func(arg=None, **kwargs)[source]
Registers a post-forward process function for torchdistill.core.distillation.DistillationBox and torchdistill.core.training.TrainingBox.
- Parameters:
arg (Callable or None) – function to be registered as a post-forward process function.
- Returns:
registered post-forward process function.
- Return type:
Callable
Note
The function will be registered as an option of the post-forward process function. You can choose the registered function by specifying the name of the function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the function with a key of your choice, add key to the decorator as below:
>>> from torchdistill.core.interfaces.registry import register_post_forward_proc_func
>>> @register_post_forward_proc_func(key='my_custom_post_forward_proc_func')
>>> def new_post_forward_proc(self, loss, metrics=None, **kwargs):
>>>     print('This is my custom post-forward process function')
In the example, the new_post_forward_proc function is registered with a key 'my_custom_post_forward_proc_func'. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the new_post_forward_proc function by 'my_custom_post_forward_proc_func'.
- torchdistill.core.interfaces.registry.register_post_epoch_proc_func(arg=None, **kwargs)[source]
Registers a post-epoch process function for torchdistill.core.distillation.DistillationBox and torchdistill.core.training.TrainingBox.
- Parameters:
arg (Callable or None) – function to be registered as a post-epoch process function.
- Returns:
registered post-epoch process function.
- Return type:
Callable
Note
The function will be registered as an option of the post-epoch process function. You can choose the registered function by specifying the name of the function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the function with a key of your choice, add key to the decorator as below:
>>> from torchdistill.core.interfaces.registry import register_post_epoch_proc_func
>>> @register_post_epoch_proc_func(key='my_custom_post_epoch_proc_func')
>>> def new_post_epoch_proc(self, metrics=None, **kwargs):
>>>     print('This is my custom post-epoch process function')
In the example, the new_post_epoch_proc function is registered with a key 'my_custom_post_epoch_proc_func'. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the new_post_epoch_proc function by 'my_custom_post_epoch_proc_func'.
- torchdistill.core.interfaces.registry.get_pre_epoch_proc_func(key)[source]
Gets a registered pre-epoch process function.
- Parameters:
key (str) – unique key to identify the registered pre-epoch process function.
- Returns:
registered pre-epoch process function.
- Return type:
Callable
- torchdistill.core.interfaces.registry.get_pre_forward_proc_func(key)[source]
Gets a registered pre-forward process function.
- Parameters:
key (str) – unique key to identify the registered pre-forward process function.
- Returns:
registered pre-forward process function.
- Return type:
Callable
- torchdistill.core.interfaces.registry.get_forward_proc_func(key)[source]
Gets a registered forward process function.
- Parameters:
key (str) – unique key to identify the registered forward process function.
- Returns:
registered forward process function.
- Return type:
Callable
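A hedged lookup example (assuming forward_batch_only above is registered under its function name, as the registry notes describe):
>>> from torchdistill.core.interfaces.registry import get_forward_proc_func
>>> forward_proc = get_forward_proc_func('forward_batch_only')
>>> callable(forward_proc)
True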
torchdistill.core.training
- class torchdistill.core.training.TrainingBox(model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]
A single-stage training framework.
- Parameters:
model (nn.Module) – model.
dataset_dict (dict) – dict that contains datasets with IDs of your choice.
train_config (dict) – training configuration.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
lr_factor (float or int) – multiplier for learning rate.
accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.
- setup_data_loaders(train_config)[source]
Sets up training and validation data loaders for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesTrainingBox.advance_to_next_stage().
- Parameters:
train_config (dict) – training configuration.
- setup_model(model_config)[source]
Sets up a model for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesTrainingBox.advance_to_next_stage().
- Parameters:
model_config (dict) – model configuration.
- setup_loss(train_config)[source]
Sets up a training loss module for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesTrainingBox.advance_to_next_stage().
- Parameters:
train_config (dict) – training configuration.
- setup_pre_post_processes(train_config)[source]
Sets up pre/post-epoch/forward processes for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesTrainingBox.advance_to_next_stage().
- Parameters:
train_config (dict) – training configuration.
- setup(train_config)[source]
Configures a TrainingBox/MultiStagesTrainingBox for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesTrainingBox.advance_to_next_stage().
- Parameters:
train_config (dict) – training configuration.
- pre_epoch_process(*args, **kwargs)[source]
Performs a pre-epoch process.
This should be overridden by all subclasses or defined through setup_pre_post_processes().
- pre_forward_process(*args, **kwargs)[source]
Performs a pre-forward process.
This should be overridden by all subclasses or defined through setup_pre_post_processes().
- forward_process(sample_batch, targets=None, supp_dict=None, **kwargs)[source]
Performs forward computations for a model.
- Parameters:
sample_batch (Any) – sample batch.
targets (Any) – training targets.
supp_dict (dict) – supplementary dict.
- Returns:
loss tensor.
- Return type:
torch.Tensor
- post_forward_process(*args, **kwargs)[source]
Performs a post-forward process.
This should be overridden by all subclasses or defined through setup_pre_post_processes().
- post_epoch_process(*args, **kwargs)[source]
Performs a post-epoch process.
This should be overridden by all subclasses or defined through setup_pre_post_processes().
- class torchdistill.core.training.MultiStagesTrainingBox(model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]
A multi-stage training framework. This is a subclass of TrainingBox.
- Parameters:
model (nn.Module) – model.
dataset_dict (dict) – dict that contains datasets with IDs of your choice.
train_config (dict) – training configuration.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
lr_factor (float or int) – multiplier for learning rate.
accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.
- save_stage_ckpt(model, local_model_config)[source]
Saves the checkpoint of model for the current training stage.
- Parameters:
model (nn.Module) – model to be saved.
local_model_config (dict) – model configuration at the current training stage.
- advance_to_next_stage()[source]
Reads the next training stage's configuration in train_config and advances to the next training stage.
- post_epoch_process(*args, **kwargs)[source]
Performs a post-epoch process.
The superclass's post_epoch_process should be overridden by all subclasses or defined through TrainingBox.setup_pre_post_processes().
- torchdistill.core.training.get_training_box(model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]
Gets a training box.
- Parameters:
model (nn.Module) – model.
dataset_dict (dict) – dict that contains datasets with IDs of your choice.
train_config (dict) – training configuration.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
lr_factor (float or int) – multiplier for learning rate.
accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.
- Returns:
training box.
- Return type:
TrainingBox or MultiStagesTrainingBox
torchdistill.core.distillation
- class torchdistill.core.distillation.DistillationBox(teacher_model, student_model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]
A single-stage knowledge distillation framework.
- Parameters:
teacher_model (nn.Module) – teacher model.
student_model (nn.Module) – student model.
dataset_dict (dict) – dict that contains datasets with IDs of your choice.
train_config (dict) – training configuration.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
lr_factor (float or int) – multiplier for learning rate.
accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.
- setup_data_loaders(train_config)[source]
Sets up training and validation data loaders for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesDistillationBox.advance_to_next_stage().
- Parameters:
train_config (dict) – training configuration.
- setup_teacher_student_models(teacher_config, student_config)[source]
Sets up teacher and student models for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesDistillationBox.advance_to_next_stage().
- Parameters:
teacher_config (dict) – teacher configuration.
student_config (dict) – student configuration.
- setup_loss(train_config)[source]
Sets up a training loss module for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesDistillationBox.advance_to_next_stage().
- Parameters:
train_config (dict) – training configuration.
- setup_pre_post_processes(train_config)[source]
Sets up pre/post-epoch/forward processes for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesDistillationBox.advance_to_next_stage().
- Parameters:
train_config (dict) – training configuration.
- setup(train_config)[source]
Configures a DistillationBox/MultiStagesDistillationBox for the current training stage. This method will be internally called when instantiating this class and when calling MultiStagesDistillationBox.advance_to_next_stage().
- Parameters:
train_config (dict) – training configuration.
- pre_epoch_process(*args, **kwargs)[source]
Performs a pre-epoch process.
This should be overridden by all subclasses or defined through setup_pre_post_processes().
- pre_forward_process(*args, **kwargs)[source]
Performs a pre-forward process.
This should be overridden by all subclasses or defined through setup_pre_post_processes().
- get_teacher_output(sample_batch, targets, supp_dict, **kwargs)[source]
Gets teacher model's output.
- Parameters:
sample_batch (Any) – sample batch.
targets (Any) – training targets.
supp_dict (dict) – supplementary dict.
- Returns:
teacher's outputs and teacher's I/O dict.
- Return type:
(Any, dict)
- forward_process(sample_batch, targets=None, supp_dict=None, **kwargs)[source]
Performs forward computations for teacher and student models.
- Parameters:
sample_batch (Any) – sample batch.
targets (Any) – training targets.
supp_dict (dict) – supplementary dict.
- Returns:
loss tensor.
- Return type:
torch.Tensor
- post_forward_process(*args, **kwargs)[source]
Performs a post-forward process.
This should be overridden by all subclasses or defined through setup_pre_post_processes().
- post_epoch_process(*args, **kwargs)[source]
Performs a post-epoch process.
This should be overridden by all subclasses or defined through setup_pre_post_processes().
- class torchdistill.core.distillation.MultiStagesDistillationBox(teacher_model, student_model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]
A multi-stage knowledge distillation framework. This is a subclass of DistillationBox.
- Parameters:
teacher_model (nn.Module) – teacher model.
student_model (nn.Module) – student model.
dataset_dict (dict) – dict that contains datasets with IDs of your choice.
train_config (dict) – training configuration.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
lr_factor (float or int) – multiplier for learning rate.
accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.
- save_stage_ckpt(model, local_model_config)[source]
Saves the checkpoint of model for the current training stage.
- Parameters:
model (nn.Module) – model to be saved.
local_model_config (dict) – model configuration at the current training stage.
- advance_to_next_stage()[source]
Reads the next training stage's configuration in train_config and advances to the next training stage.
- post_epoch_process(*args, **kwargs)[source]
Performs a post-epoch process.
The superclass's post_epoch_process should be overridden by all subclasses or defined through DistillationBox.setup_pre_post_processes().
- torchdistill.core.distillation.get_distillation_box(teacher_model, student_model, dataset_dict, train_config, device, device_ids, distributed, lr_factor, accelerator=None)[source]
Gets a distillation box.
- Parameters:
teacher_model (nn.Module) – teacher model.
student_model (nn.Module) – student model.
dataset_dict (dict) – dict that contains datasets with IDs of your choice.
train_config (dict) – training configuration.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
lr_factor (float or int) – multiplier for learning rate.
accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.
- Returns:
distillation box.
- Return type:
DistillationBox or MultiStagesDistillationBox
torchdistill.core.util
- torchdistill.core.util.add_kwargs_to_io_dict(io_dict, module_path, **kwargs)[source]
Adds kwargs to an I/O dict.
- Parameters:
io_dict (dict) – I/O dict.
module_path (str) – module path.
kwargs (dict) – kwargs to be stored in io_dict.
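A small sketch (not from the original docs), assuming the kwargs become the sub dict stored under module_path, as the parameters above suggest:
>>> from torchdistill.core.util import add_kwargs_to_io_dict
>>> io_dict = dict()
>>> add_kwargs_to_io_dict(io_dict, 'layer1', uses_aux=True)
>>> 'layer1' in io_dict
True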
- torchdistill.core.util.set_hooks(model, unwrapped_org_model, model_config, io_dict)[source]
Sets forward hooks for target modules in model.
- Parameters:
model (nn.Module) – model.
unwrapped_org_model (nn.Module) – unwrapped original model.
model_config (dict) – model configuration.
io_dict (dict) – I/O dict.
- Returns:
list of pairs of module path and removable forward hook handle.
- Return type:
list[(str, torch.utils.hook.RemovableHandle)]
- torchdistill.core.util.wrap_model(model, model_config, device, device_ids=None, distributed=False, find_unused_parameters=False, any_updatable=True)[source]
Wraps model with either DataParallel or DistributedDataParallel if specified.
- Parameters:
model (nn.Module) – model.
model_config (dict) – model configuration.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
find_unused_parameters (bool) – find_unused_parameters for DistributedDataParallel.
any_updatable (bool) – True if model contains any updatable parameters.
- Returns:
wrapped model (or model if wrapper is not specified).
- Return type:
nn.Module
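A minimal, hedged sketch (an assumption, not the library's documented example): with distributed=False and a model configuration that requests no wrapper, the model is expected to be returned unwrapped on the target device.
>>> import torch
>>> from torchvision import models
>>> from torchdistill.core.util import wrap_model
>>> model = models.resnet18()
>>> model = wrap_model(model, dict(), torch.device('cpu'))  # empty config: assumed to skip DataParallel/DDP wrapping
>>> isinstance(model, torch.nn.Module)
True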
- torchdistill.core.util.change_device(data, device)[source]
Updates the device of tensor(s) stored in data with a new device.
- Parameters:
data (Any) – data that contain tensor(s).
device (torch.device or str) – new device.
- Returns:
data on the new device.
- Return type:
Any
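A brief illustration (moving to 'cpu' here so it runs anywhere; nested tensors are assumed to be relocated recursively, per the description above):
>>> import torch
>>> from torchdistill.core.util import change_device
>>> batch = {'input': torch.rand(2, 3), 'masks': [torch.ones(2)]}
>>> cpu_batch = change_device(batch, 'cpu')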
- torchdistill.core.util.tensor2numpy2tensor(data, device)[source]
Converts tensor to numpy data and re-converts the numpy data to tensor.
- Parameters:
data (Any) – data that contain tensor(s).
device (torch.device or str) – new device.
- Returns:
data that contain recreated tensor(s).
- Return type:
Any
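A hedged one-liner (assuming the numpy round trip simply recreates the tensor on the given device):
>>> import torch
>>> from torchdistill.core.util import tensor2numpy2tensor
>>> recreated = tensor2numpy2tensor(torch.rand(2, 3), 'cpu')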
- torchdistill.core.util.clear_io_dict(model_io_dict)[source]
Clears a model I/O dict's sub dict(s).
- Parameters:
model_io_dict (dict) – model I/O dict.
- torchdistill.core.util.extract_io_dict(model_io_dict, target_device)[source]
Extracts I/O dict, gathering tensors on target_device.
- Parameters:
model_io_dict (dict) – model I/O dict.
target_device (torch.device or str) – target device.
- Returns:
extracted I/O dict.
- Return type:
dict