torchdistill.models


torchdistill.models.registry

torchdistill.models.registry.register_model(arg=None, **kwargs)[source]

Registers a model class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class or function to be registered as a model.

Returns:

registered model class or function to instantiate it.

Return type:

class or Callable

Note

The model will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.models.registry import register_model
>>>
>>> @register_model(key='my_custom_model')
... class CustomModel(nn.Module):
...     def __init__(self, **kwargs):
...         super().__init__()  # nn.Module must be initialized before use
...         print('This is my custom model class')

In the example, CustomModel class is registered with a key “my_custom_model”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomModel class by “my_custom_model”.
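The optional-key decorator pattern described in this note can be sketched in plain Python. This is only an illustration of the pattern, not torchdistill's implementation: `TOY_REGISTRY` and `register_toy` are hypothetical names, and the only behavior taken from the text is that an object is registered under its own name unless `key` is given.

```python
# Hypothetical stand-in for a model registry; not torchdistill's internals.
TOY_REGISTRY = {}


def register_toy(arg=None, **kwargs):
    """Registers a class/function under its own name or under kwargs['key']."""
    def _register(obj):
        key = kwargs.get('key')
        if key is None:
            key = obj.__name__  # default key: name of the class/function
        TOY_REGISTRY[key] = obj
        return obj  # hand the object back unchanged so it stays usable

    # Bare usage (@register_toy): the decorated object arrives as `arg`.
    if callable(arg):
        return _register(arg)
    # Parameterized usage (@register_toy(key=...)): return the real decorator.
    return _register


@register_toy
class PlainModel:
    pass


@register_toy(key='my_custom_model')
class CustomModel:
    pass
```

With this sketch, `PlainModel` is reachable by its class name, while `CustomModel` is reachable only by the key passed to the decorator.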

torchdistill.models.registry.register_adaptation_module(arg=None, **kwargs)[source]

Registers an adaptation module class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class or function to be registered as an adaptation module.

Returns:

registered adaptation module class or function to instantiate it.

Return type:

class or Callable

Note

The adaptation module will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.models.registry import register_adaptation_module
>>>
>>> @register_adaptation_module(key='my_custom_adaptation_module')
... class CustomAdaptationModule(nn.Module):
...     def __init__(self, **kwargs):
...         super().__init__()  # nn.Module must be initialized before use
...         print('This is my custom adaptation module class')

In the example, CustomAdaptationModule class is registered with a key “my_custom_adaptation_module”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomAdaptationModule class by “my_custom_adaptation_module”.

torchdistill.models.registry.register_auxiliary_model_wrapper(arg=None, **kwargs)[source]

Registers an auxiliary model wrapper class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class or function to be registered as an auxiliary model wrapper.

Returns:

registered auxiliary model wrapper class or function to instantiate it.

Return type:

class or Callable

Note

The auxiliary model wrapper will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.models.registry import register_auxiliary_model_wrapper
>>>
>>> @register_auxiliary_model_wrapper(key='my_custom_auxiliary_model_wrapper')
... class CustomAuxiliaryModelWrapper(nn.Module):
...     def __init__(self, **kwargs):
...         super().__init__()  # nn.Module must be initialized before use
...         print('This is my custom auxiliary model wrapper class')

In the example, CustomAuxiliaryModelWrapper class is registered with a key “my_custom_auxiliary_model_wrapper”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomAuxiliaryModelWrapper class by “my_custom_auxiliary_model_wrapper”.

torchdistill.models.registry.get_model(key, repo_or_dir=None, *args, **kwargs)[source]

Gets a model from the model registry.

Parameters:
  • key (str) – model key.

  • repo_or_dir (str or None) – repo_or_dir for torch.hub.load.

Returns:

model.

Return type:

nn.Module
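The signature of get_model suggests a two-step lookup: the registry first, then a hub-style loader when `repo_or_dir` is given. The sketch below illustrates that order under stated assumptions. `TOY_MODELS`, `get_toy_model`, and the injected `hub_loader` callable (standing in for torch.hub.load) are hypothetical, and the actual fallback and error behavior should be checked against the library source.

```python
# Hypothetical registry mapping keys to factories; not torchdistill's internals.
TOY_MODELS = {
    'tiny': lambda num_classes=10: {'name': 'tiny', 'num_classes': num_classes},
}


def get_toy_model(key, repo_or_dir=None, *args, hub_loader=None, **kwargs):
    """Looks up `key` in the registry, then falls back to a hub-style loader."""
    if key in TOY_MODELS:
        # Registered entry: instantiate it with the remaining arguments.
        return TOY_MODELS[key](*args, **kwargs)
    if repo_or_dir is not None and hub_loader is not None:
        # Assumed fallback: delegate to a loader with a torch.hub.load-like call shape.
        return hub_loader(repo_or_dir, key, *args, **kwargs)
    raise ValueError(f'model key `{key}` is not registered')
```

Injecting the loader keeps the sketch free of network access; in real use the registry lookup would return an nn.Module rather than a dict.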

torchdistill.models.registry.get_adaptation_module(key, *args, **kwargs)[source]

Gets an adaptation module from the adaptation module registry.

Parameters:

key (str) – adaptation module key.

Returns:

adaptation module.

Return type:

nn.Module

torchdistill.models.registry.get_auxiliary_model_wrapper(key, *args, **kwargs)[source]

Gets an auxiliary model wrapper from the auxiliary model wrapper registry.

Parameters:

key (str) – auxiliary model wrapper key.

Returns:

auxiliary model wrapper.

Return type:

nn.Module


torchdistill.models.classification


To reproduce the test results for CIFAR datasets, the following repositories were referred to for training methods:

Accuracy (%) of models pretrained on CIFAR-10/100 datasets

Model                            CIFAR-10    CIFAR-100
ResNet-20                        91.92       N/A
ResNet-32                        93.03       N/A
ResNet-44                        93.20       N/A
ResNet-56                        93.57       N/A
ResNet-110                       93.50       N/A
WRN-40-4                         95.24       79.44
WRN-28-10                        95.53       81.27
WRN-16-8                         94.76       79.26
DenseNet-BC (k=12, depth=100)    95.53       77.14

Those results are reported in the following paper:


torchdistill.models.classification.densenet

class torchdistill.models.classification.densenet.DenseNet4Cifar(growth_rate: int = 32, block_config: Tuple[int, int, int] = (12, 12, 12), num_init_features: int = 64, bn_size: int = 4, drop_rate: float = 0, num_classes: int = 10, memory_efficient: bool = False)[source]

DenseNet-BC model for CIFAR datasets. Refactored https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py for CIFAR datasets, referring to https://github.com/liuzhuang13/DenseNet

Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger: “Densely Connected Convolutional Networks” @ CVPR 2017 (2017).

Parameters:
  • growth_rate (int) – number of filters to add each layer (k in paper).

  • block_config (list[int]) – number of layers in each of the three pooling blocks.

  • num_init_features (int) – number of filters to learn in the first convolution layer.

  • bn_size (int) – multiplicative factor for number of bottleneck layers. (i.e. bn_size * k features in the bottleneck layer)

  • drop_rate (float) – dropout rate after each dense layer.

  • num_classes (int) – number of classification classes.

  • memory_efficient (bool) – if True, uses checkpointing. Much more memory efficient, but slower. Refer to “the paper” for details.

torchdistill.models.classification.densenet.densenet(growth_rate: int, depth: int, num_init_features: int, bottleneck: bool, num_classes: int, pretrained: bool, progress: bool, **kwargs: Any)[source]

Instantiates a DenseNet model for CIFAR datasets.

Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger: “Densely Connected Convolutional Networks” @ CVPR 2017 (2017).

Parameters:
  • growth_rate (int) – number of filters to add each layer (k in paper).

  • depth (int) – depth.

  • num_init_features (int) – number of filters to learn in the first convolution layer.

  • bottleneck (bool) – if True, uses bottleneck.

  • num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

DenseNet model.

Return type:

DenseNet4Cifar
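The relation between depth and the three-element block_config of DenseNet4Cifar can be sketched as follows. This assumes the convention of the DenseNet paper and its reference implementation for CIFAR (depth = 3n + 4 without bottleneck layers, depth = 6n + 4 with them); `dense_layers_per_block` is a hypothetical helper, not a torchdistill function.

```python
def dense_layers_per_block(depth, bottleneck):
    """Returns the number of dense layers in each of the three dense blocks."""
    # A bottleneck layer consists of two convolutions, so it consumes two units of depth.
    convs_per_layer = 2 if bottleneck else 1
    n, remainder = divmod(depth - 4, 3 * convs_per_layer)
    if remainder != 0 or n < 1:
        raise ValueError(f'depth={depth} does not fit the assumed CIFAR DenseNet layout')
    return (n, n, n)
```

Under this assumption, DenseNet-BC with depth 100 has 16 layers per block, depth 250 has 41, and depth 190 has 31.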

torchdistill.models.classification.densenet.densenet_bc_k12_depth100(num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]

DenseNet-BC (k=12, depth=100) model.

Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger: “Densely Connected Convolutional Networks” @ CVPR 2017 (2017).

Parameters:
  • num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

DenseNet-BC (k=12, depth=100) model.

Return type:

DenseNet4Cifar

torchdistill.models.classification.densenet.densenet_bc_k24_depth250(num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]

DenseNet-BC (k=24, depth=250) model.

Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger: “Densely Connected Convolutional Networks” @ CVPR 2017 (2017).

Parameters:
  • num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

DenseNet-BC (k=24, depth=250) model.

Return type:

DenseNet4Cifar

torchdistill.models.classification.densenet.densenet_bc_k40_depth190(num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]

DenseNet-BC (k=40, depth=190) model.

Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger: “Densely Connected Convolutional Networks” @ CVPR 2017 (2017).

Parameters:
  • num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

DenseNet-BC (k=40, depth=190) model.

Return type:

DenseNet4Cifar


torchdistill.models.classification.resnet

class torchdistill.models.classification.resnet.ResNet4Cifar(block: Type[BasicBlock], layers: List[int], num_classes: int = 10, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: List[bool] | None = None, norm_layer: Callable[[...], Module] | None = None)[source]

ResNet model for CIFAR datasets. Refactored https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py for CIFAR datasets, referring to https://github.com/facebookarchive/fb.resnet.torch

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).

Parameters:
  • block (BasicBlock) – block class.

  • layers (list[int]) – number of layers in each of the three pooling blocks.

  • num_classes (int) – number of classification classes.

  • zero_init_residual (bool) – if True, zero-initializes the last BN in each residual branch.

  • groups (int) – groups for Conv2d.

  • width_per_group (int) – base width for Conv2d.

  • replace_stride_with_dilation (list[bool] or None) – indicates if we should replace the 2x2 stride with a dilated convolution instead.

  • norm_layer (Callable or nn.Module or None) – normalization module class or callable object.

torchdistill.models.classification.resnet.resnet(depth: int, num_classes: int, pretrained: bool, progress: bool, **kwargs: Any)[source]

Instantiates a ResNet model for CIFAR datasets.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).

Parameters:
  • depth (int) – depth.

  • num_classes (int) – number of classification classes.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

ResNet model.

Return type:

ResNet4Cifar
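For the CIFAR variant of ResNet in He et al., depth and the per-stage block count are tied by depth = 6n + 2, which is how names such as ResNet-20 or ResNet-110 map to a `layers` list. A minimal sketch of that arithmetic, with `basic_blocks_per_stage` as a hypothetical helper rather than part of torchdistill:

```python
def basic_blocks_per_stage(depth):
    """Returns the number of basic blocks in each of the three stages."""
    # 6n + 2: three stages of n two-convolution blocks, plus the first
    # convolution and the final fully-connected layer.
    n, remainder = divmod(depth - 2, 6)
    if remainder != 0 or n < 1:
        raise ValueError(f'depth={depth} is not of the form 6n + 2')
    return [n, n, n]
```

For example, depth 20 gives three blocks per stage and depth 110 gives eighteen.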

torchdistill.models.classification.resnet.resnet20(num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]

ResNet-20 model.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).

Parameters:
  • num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

ResNet-20 model.

Return type:

ResNet4Cifar

torchdistill.models.classification.resnet.resnet32(num_classes=10, pretrained=False, progress=True, **kwargs: Any) → ResNet4Cifar[source]

ResNet-32 model.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).

Parameters:
  • num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

ResNet-32 model.

Return type:

ResNet4Cifar

torchdistill.models.classification.resnet.resnet44(num_classes=10, pretrained=False, progress=True, **kwargs: Any) → ResNet4Cifar[source]

ResNet-44 model.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).

Parameters:
  • num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

ResNet-44 model.

Return type:

ResNet4Cifar

torchdistill.models.classification.resnet.resnet56(num_classes=10, pretrained=False, progress=True, **kwargs: Any) → ResNet4Cifar[source]

ResNet-56 model.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).

Parameters:
  • num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

ResNet-56 model.

Return type:

ResNet4Cifar

torchdistill.models.classification.resnet.resnet110(num_classes=10, pretrained=False, progress=True, **kwargs: Any) → ResNet4Cifar[source]

ResNet-110 model.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).

Parameters:
  • num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

ResNet-110 model.

Return type:

ResNet4Cifar

torchdistill.models.classification.resnet.resnet1202(num_classes=10, pretrained=False, progress=True, **kwargs: Any) → ResNet4Cifar[source]

ResNet-1202 model.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).

Parameters:
  • num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

ResNet-1202 model.

Return type:

ResNet4Cifar


torchdistill.models.classification.wide_resnet

class torchdistill.models.classification.wide_resnet.WideBasicBlock(in_planes, planes, dropout_rate, stride=1)[source]

A basic block of Wide ResNet for CIFAR datasets.

Parameters:
  • in_planes (int) – number of input feature planes.

  • planes (int) – number of output feature planes.

  • dropout_rate (float) – dropout rate.

  • stride (int) – stride for Conv2d.

class torchdistill.models.classification.wide_resnet.WideResNet4Cifar(depth, k, dropout_p, block, num_classes, norm_layer=None)[source]

Wide ResNet (WRN) model for CIFAR datasets. Refactored https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py for CIFAR datasets, referring to https://github.com/szagoruyko/wide-residual-networks

Sergey Zagoruyko, Nikos Komodakis: “Wide Residual Networks” @ BMVC 2016 (2016)

Parameters:
  • depth (int) – depth.

  • k (int) – widening factor.

  • dropout_p (float) – dropout rate.

  • block (WideBasicBlock) – block class.

  • num_classes (int) – number of classification classes.

  • norm_layer (Callable or nn.Module or None) – normalization module class or callable object.

torchdistill.models.classification.wide_resnet.wide_resnet(depth: int, k: int, dropout_p: float, num_classes: int, pretrained: bool, progress: bool, **kwargs: Any)[source]

Instantiates a Wide ResNet model for CIFAR datasets.

Sergey Zagoruyko, Nikos Komodakis: “Wide Residual Networks” @ BMVC 2016 (2016)

Parameters:
  • depth (int) – depth.

  • k (int) – widening factor.

  • dropout_p (float) – dropout rate.

  • num_classes (int) – number of classification classes.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

Wide ResNet model.

Return type:

WideResNet4Cifar
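The WRN-d-k naming can be unpacked with the convention from the Wide ResNet paper: depth d = 6n + 4 with n blocks per group, and the widening factor k scaling the base widths 16, 32, and 64. The sketch below only illustrates that arithmetic; `wrn_layout` is a hypothetical helper, and whether WideResNet4Cifar computes its widths exactly this way is an assumption.

```python
def wrn_layout(depth, k):
    """Returns (blocks per group, channel widths) for a WRN-depth-k model."""
    n, remainder = divmod(depth - 4, 6)
    if remainder != 0 or n < 1:
        raise ValueError(f'depth={depth} is not of the form 6n + 4')
    # The first convolution keeps 16 channels; the three groups are widened by k.
    widths = [16, 16 * k, 32 * k, 64 * k]
    return n, widths
```

Under this convention, WRN-40-4 has six blocks per group with widths 16, 64, 128, and 256.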

torchdistill.models.classification.wide_resnet.wide_resnet40_4(dropout_p=0.3, num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]

WRN-40-4 model.

Sergey Zagoruyko, Nikos Komodakis: “Wide Residual Networks” @ BMVC 2016 (2016)

Parameters:
  • dropout_p (float) – dropout rate.

  • num_classes (int) – number of classification classes.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

WRN-40-4 model.

Return type:

WideResNet4Cifar

torchdistill.models.classification.wide_resnet.wide_resnet28_10(dropout_p=0.3, num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]

WRN-28-10 model.

Sergey Zagoruyko, Nikos Komodakis: “Wide Residual Networks” @ BMVC 2016 (2016)

Parameters:
  • dropout_p (float) – dropout rate.

  • num_classes (int) – number of classification classes.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

WRN-28-10 model.

Return type:

WideResNet4Cifar

torchdistill.models.classification.wide_resnet.wide_resnet16_8(dropout_p=0.3, num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]

WRN-16-8 model.

Sergey Zagoruyko, Nikos Komodakis: “Wide Residual Networks” @ BMVC 2016 (2016)

Parameters:
  • dropout_p (float) – dropout rate.

  • num_classes (int) – number of classification classes.

  • pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.

  • progress (bool) – if True, displays a progress bar of the download to stderr.

Returns:

WRN-16-8 model.

Return type:

WideResNet4Cifar


torchdistill.models.official

torchdistill.models.official.get_image_classification_model(model_config, distributed=False)[source]

Gets an image classification model from torchvision.

Parameters:
  • model_config (dict) – image classification model configuration.

  • distributed (bool) – whether to be in distributed training mode.

Returns:

image classification model.

Return type:

nn.Module

torchdistill.models.official.get_object_detection_model(model_config)[source]

Gets an object detection model from torchvision.

Parameters:

model_config (dict) – object detection model configuration.

Returns:

object detection model.

Return type:

nn.Module

torchdistill.models.official.get_semantic_segmentation_model(model_config)[source]

Gets a semantic segmentation model from torchvision.

Parameters:

model_config (dict) – semantic segmentation model configuration.

Returns:

semantic segmentation model.

Return type:

nn.Module

torchdistill.models.official.get_vision_model(model_config)[source]

Gets a computer vision model from torchvision.

Parameters:

model_config (dict) – model configuration.

Returns:

computer vision model.

Return type:

nn.Module


torchdistill.models.adaptation

class torchdistill.models.adaptation.ConvReg(num_input_channels, num_output_channels, kernel_size, stride, padding, uses_relu=True)[source]

A convolutional regressor for FitNets, as used in “Contrastive Representation Distillation” (CRD).

Parameters:
  • num_input_channels (int) – in_channels for Conv2d.

  • num_output_channels (int) – out_channels for Conv2d.

  • kernel_size ((int, int) or int) – kernel_size for Conv2d.

  • stride (int) – stride for Conv2d.

  • padding (int) – padding for Conv2d.

  • uses_relu (bool) – if True, uses ReLU as the last module.


torchdistill.models.wrapper

class torchdistill.models.wrapper.AuxiliaryModelWrapper[source]

An abstract auxiliary model wrapper.

forward(), secondary_forward(), and post_epoch_process() should be overridden by all subclasses.

class torchdistill.models.wrapper.EmptyModule(**kwargs)[source]

An empty auxiliary model wrapper. This module returns its input as output and is useful when you want to replace your teacher/student model with an empty model to save inference time. For example, multi-stage knowledge distillation may have some stages that do not require either the teacher or the student model.

class torchdistill.models.wrapper.Paraphraser4FactorTransfer(k, num_input_channels, kernel_size=3, stride=1, padding=1, uses_bn=True, uses_decoder=True)[source]

Paraphraser for factor transfer (FT). This module is used at the 1st and 2nd stages of the FT method.

Jangho Kim, Seonguk Park, Nojun Kwak: “Paraphrasing Complex Network: Network Compression via Factor Transfer” @ NeurIPS 2018 (2018)

Parameters:
  • k (float) – paraphrase rate.

  • num_input_channels (int) – number of input channels.

  • kernel_size (int) – kernel_size for Conv2d.

  • stride (int) – stride for Conv2d.

  • padding (int) – padding for Conv2d.

  • uses_bn (bool) – if True, uses BatchNorm2d.

  • uses_decoder (bool) – if True, uses decoder in forward().

class torchdistill.models.wrapper.Translator4FactorTransfer(num_input_channels, num_output_channels, kernel_size=3, stride=1, padding=1, uses_bn=True)[source]

Translator for factor transfer (FT). This module is used at the 2nd stage of the FT method. Note that “the student translator has the same three convolution layers as the paraphraser”.

Jangho Kim, Seonguk Park, Nojun Kwak: “Paraphrasing Complex Network: Network Compression via Factor Transfer” @ NeurIPS 2018 (2018)

Parameters:
  • num_input_channels (int) – number of input channels.

  • num_output_channels (int) – number of output channels.

  • kernel_size (int) – kernel_size for Conv2d.

  • stride (int) – stride for Conv2d.

  • padding (int) – padding for Conv2d.

  • uses_bn (bool) – if True, uses BatchNorm2d.

class torchdistill.models.wrapper.Teacher4FactorTransfer(teacher_model, minimal, input_module_path, paraphraser_kwargs, paraphraser_ckpt, uses_decoder, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]

An auxiliary teacher model wrapper for factor transfer (FT), including paraphraser Paraphraser4FactorTransfer.

Jangho Kim, Seonguk Park, Nojun Kwak: “Paraphrasing Complex Network: Network Compression via Factor Transfer” @ NeurIPS 2018 (2018)

Parameters:
  • teacher_model (nn.Module) – teacher model.

  • minimal (dict or None) – model_config for build_auxiliary_model_wrapper() if you want to.

  • input_module_path (str) – path of module whose output is used as input to paraphraser.

  • paraphraser_kwargs (dict) – kwargs to instantiate Paraphraser4FactorTransfer.

  • paraphraser_ckpt (str) – paraphraser checkpoint file path.

  • uses_decoder (bool) – uses_decoder for Paraphraser4FactorTransfer.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

class torchdistill.models.wrapper.Student4FactorTransfer(student_model, input_module_path, translator_kwargs, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]

An auxiliary student model wrapper for factor transfer (FT), including translator Translator4FactorTransfer.

Jangho Kim, Seonguk Park, Nojun Kwak: “Paraphrasing Complex Network: Network Compression via Factor Transfer” @ NeurIPS 2018 (2018)

Parameters:
  • student_model (nn.Module) – student model.

  • input_module_path (str) – path of module whose output is used as input to translator.

  • translator_kwargs (dict) – kwargs to instantiate Translator4FactorTransfer.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

class torchdistill.models.wrapper.Connector4DAB(student_model, connectors, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]

An auxiliary student model wrapper with connector for distillation of activation boundaries (DAB).

Byeongho Heo, Minsik Lee, Sangdoo Yun, Jin Young Choi: “Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons” @ AAAI 2019 (2019)

Parameters:
  • student_model (nn.Module) – student model.

  • connectors (dict) – connector keys and configurations.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

class torchdistill.models.wrapper.Regressor4VID(in_channels, middle_channels, out_channels, eps, init_pred_var, **kwargs)[source]

An auxiliary module for variational information distillation (VID).

Sungsoo Ahn, Shell Xu Hu, Andreas Damianou, Neil D. Lawrence, Zhenwen Dai: “Variational Information Distillation for Knowledge Transfer” @ CVPR 2019 (2019)

Parameters:
  • in_channels (int) – number of input channels for the first convolution layer.

  • middle_channels (int) – number of output/input channels for the first/second convolution layer.

  • out_channels (int) – number of output channels for the third convolution layer.

  • eps (float) – eps.

  • init_pred_var (float) – minimum variance introduced for numerical stability.

class torchdistill.models.wrapper.VariationalDistributor4VID(student_model, regressors, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]

An auxiliary student model wrapper for variational information distillation (VID), including regressor Regressor4VID.

Sungsoo Ahn, Shell Xu Hu, Andreas Damianou, Neil D. Lawrence, Zhenwen Dai: “Variational Information Distillation for Knowledge Transfer” @ CVPR 2019 (2019)

Parameters:
  • student_model (nn.Module) – student model.

  • regressors (dict) – regressor keys and configurations.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

class torchdistill.models.wrapper.Linear4CCKD(input_module, linear_kwargs, device, device_ids, distributed, teacher_model=None, student_model=None, find_unused_parameters=None, **kwargs)[source]

An auxiliary teacher/student model wrapper for correlation congruence for knowledge distillation (CCKD). Fully-connected layers cope with a mismatch of feature representations of teacher and student models.

Baoyun Peng, Xiao Jin, Jiaheng Liu, Dongsheng Li, Yichao Wu, Yu Liu, Shunfeng Zhou, Zhaoning Zhang: “Correlation Congruence for Knowledge Distillation” @ ICCV 2019 (2019)

Parameters:
  • input_module (dict) – input module configuration.

  • linear_kwargs (dict) – kwargs for Linear.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • teacher_model (nn.Module or None) – teacher model.

  • student_model (nn.Module or None) – student model.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

class torchdistill.models.wrapper.Normalizer4CRD(linear, power=2)[source]

An auxiliary module for contrastive representation distillation (CRD).

Yonglong Tian, Dilip Krishnan, Phillip Isola: “Contrastive Representation Distillation” @ ICLR 2020 (2020)

Parameters:
  • linear (nn.Module) – linear module.

  • power (int) – the exponent.
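The role of `power` can be illustrated with a plain-Python Lp normalization. This is a sketch under the assumption that Normalizer4CRD normalizes the output of `linear` in the same way as the normalization module in the RepDistiller reference code, that is, x / (sum of x^p)^(1/p). `lp_normalize` is a hypothetical helper operating on a single feature vector.

```python
def lp_normalize(features, power=2):
    """Divides a feature vector by its Lp norm computed with exponent `power`."""
    # Mirrors x / (sum(x ** p)) ** (1 / p); with an even `power` the sum is non-negative.
    norm = sum(value ** power for value in features) ** (1.0 / power)
    return [value / norm for value in features]
```

With the default `power=2`, the result has unit Euclidean length, so features from teacher and student are compared on the same scale.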

class torchdistill.models.wrapper.Linear4CRD(input_module_path, linear_kwargs, device, device_ids, distributed, power=2, teacher_model=None, student_model=None, find_unused_parameters=None, **kwargs)[source]

An auxiliary teacher/student model wrapper for contrastive representation distillation (CRD), including normalizer Normalizer4CRD. Refactored https://github.com/HobbitLong/RepDistiller/blob/master/crd/memory.py

Yonglong Tian, Dilip Krishnan, Phillip Isola: “Contrastive Representation Distillation” @ ICLR 2020 (2020)

Parameters:
  • input_module_path (str) – path of module whose output will be flattened and then used as input to normalizer.

  • linear_kwargs (dict) – kwargs for Linear.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • power (int) – power for Normalizer4CRD.

  • teacher_model (nn.Module or None) – teacher model.

  • student_model (nn.Module or None) – student model.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

class torchdistill.models.wrapper.HeadRCNN(head_rcnn, **kwargs)[source]

An auxiliary teacher/student model wrapper for head network distillation (HND) and generalized head network distillation (GHND).

Parameters:
  • head_rcnn (dict) – head R-CNN configuration as model_config in torchdistill.models.util.redesign_model().

  • kwargs (dict) – teacher_model or student_model keys must be included. If both teacher_model and student_model are provided, student_model will be prioritized.
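The priority rule for `kwargs` can be sketched as a small selection helper. `select_wrapped_model` is a hypothetical function written for illustration; in particular, raising ValueError when neither key is present is an assumption, not documented behavior.

```python
def select_wrapped_model(**kwargs):
    """Picks the model to wrap, preferring student_model over teacher_model."""
    student_model = kwargs.get('student_model')
    if student_model is not None:
        return student_model  # student_model wins when both are provided
    teacher_model = kwargs.get('teacher_model')
    if teacher_model is None:
        raise ValueError('either teacher_model or student_model must be provided')
    return teacher_model
```

The same rule is stated for SSWrapper4SSKD, so one wrapper class can serve either side of the distillation depending on which keyword argument is passed.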

class torchdistill.models.wrapper.SSWrapper4SSKD(input_module, feat_dim, ss_module_ckpt, device, device_ids, distributed, freezes_ss_module=False, teacher_model=None, student_model=None, find_unused_parameters=None, **kwargs)[source]

An auxiliary teacher/student model wrapper for self-supervision knowledge distillation (SSKD). If both teacher_model and student_model are provided, student_model will be prioritized.

Guodong Xu, Ziwei Liu, Xiaoxiao Li, Chen Change Loy: “Knowledge Distillation Meets Self-Supervision” @ ECCV 2020 (2020)

Parameters:
  • input_module (dict) – input module configuration.

  • feat_dim (int) – number of input/output features for self-supervision module.

  • ss_module_ckpt (str) – self-supervision module checkpoint file path.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • freezes_ss_module (bool) – if True, freezes self-supervision module.

  • teacher_model (nn.Module or None) – teacher model.

  • student_model (nn.Module or None) – student model.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

class torchdistill.models.wrapper.VarianceBranch4PAD(student_model, input_module, feat_dim, var_estimator_ckpt, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]

An auxiliary student model wrapper for prime-aware adaptive distillation (PAD).

Youcai Zhang, Zhonghao Lan, Yuchen Dai, Fangao Zeng, Yan Bai, Jie Chang, Yichen Wei: “Prime-Aware Adaptive Distillation” @ ECCV 2020 (2020)

Parameters:
  • student_model (nn.Module) – student model.

  • input_module (dict) – input module configuration.

  • feat_dim (int) – number of input/output features for variance estimator module.

  • var_estimator_ckpt (str) – variance estimator module checkpoint file path.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

class torchdistill.models.wrapper.AttentionBasedFusion(in_channels, mid_channels, out_channels, uses_attention)[source]

An auxiliary module for knowledge review (KR). Refactored https://github.com/dvlab-research/ReviewKD/blob/master/ImageNet/models/reviewkd.py

Pengguang Chen, Shu Liu, Hengshuang Zhao, Jiaya Jia: “Distilling Knowledge via Knowledge Review” @ CVPR 2021 (2021)

Parameters:
  • in_channels (int) – number of input channels for the first convolution layer.

  • mid_channels (int) – number of output/input channels for the first/second convolution layer.

  • out_channels (int) – number of output channels for the third convolution layer.

  • uses_attention (bool) – if True, uses attention in the fusion.

class torchdistill.models.wrapper.Student4KnowledgeReview(student_model, abfs, device, device_ids, distributed, sizes=None, find_unused_parameters=None, **kwargs)[source]

An auxiliary student model wrapper for knowledge review (KR). Refactored https://github.com/dvlab-research/ReviewKD/blob/master/ImageNet/models/reviewkd.py

Pengguang Chen, Shu Liu, Hengshuang Zhao, Jiaya Jia: “Distilling Knowledge via Knowledge Review” @ CVPR 2021 (2021)

Parameters:
  • student_model (nn.Module) – student model.

  • abfs (list[dict]) – attention based fusion configurations.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

class torchdistill.models.wrapper.Student4KTAAD(student_model, input_module_path, feature_adapter_config, affinity_adapter_config, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]

An auxiliary student model wrapper for knowledge translation and adaptation + affinity distillation (KTAAD). Refactored https://github.com/dvlab-research/ReviewKD/blob/master/ImageNet/models/reviewkd.py

Tong He, Chunhua Shen, Zhi Tian, Dong Gong, Changming Sun, Youliang Yan: “Knowledge Adaptation for Efficient Semantic Segmentation” @ CVPR 2019 (2019)

Parameters:
  • student_model (nn.Module) – student model.

  • input_module_path (str) – path of module whose output is used as input to feature adapter and affinity adapter.

  • feature_adapter_config (dict) – feature adapter configuration.

  • affinity_adapter_config (dict) – affinity adapter configuration.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

class torchdistill.models.wrapper.ChannelSimilarityEmbed(in_channels=512, out_channels=128, **kwargs)[source]

An auxiliary module for Inter-Channel Correlation for Knowledge Distillation (ICKD). Refactored https://github.com/ADLab-AutoDrive/ICKD/blob/main/ImageNet/torchdistill/models/special.py

Li Liu, Qingle Huang, Sihao Lin, Hongwei Xie, Bing Wang, Xiaojun Chang, Xiaodan Liang: “Inter-Channel Correlation for Knowledge Distillation” @ ICCV 2021 (2021)

Parameters:
  • in_channels (int) – number of input channels for the convolution layer.

  • out_channels (int) – number of output channels for the convolution layer.

class torchdistill.models.wrapper.Student4ICKD(student_model, embeddings, device, device_ids, distributed, **kwargs)[source]

An auxiliary student model wrapper for Inter-Channel Correlation for Knowledge Distillation (ICKD). Referred to https://github.com/ADLab-AutoDrive/ICKD/blob/main/ImageNet/torchdistill/models/special.py

Li Liu, Qingle Huang, Sihao Lin, Hongwei Xie, Bing Wang, Xiaojun Chang, Xiaodan Liang: “Inter-Channel Correlation for Knowledge Distillation” @ ICCV 2021 (2021)

Parameters:
  • student_model (nn.Module) – student model.

  • embeddings (dict) – embeddings keys and configuration.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

class torchdistill.models.wrapper.SRDModelWrapper(input_module, norm_kwargs, device, device_ids, distributed, linear_kwargs=None, teacher_model=None, student_model=None, find_unused_parameters=None, **kwargs)[source]

An auxiliary model wrapper for Understanding the Role of the Projector in Knowledge Distillation. Referred to https://github.com/roymiles/Simple-Recipe-Distillation/blob/main/imagenet/torchdistill/losses/single.py

Roy Miles, Krystian Mikolajczyk: “Understanding the Role of the Projector in Knowledge Distillation” @ AAAI 2024 (2024)

Parameters:
  • model (nn.Module) – model.

  • input_module (dict) – input module configuration.

  • linear_kwargs (dict or None) – nn.Linear keyword arguments.

  • norm_kwargs (dict) – nn.BatchNorm1d keyword arguments.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • teacher_model (nn.Module or None) – teacher model.

  • student_model (nn.Module or None) – student model.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

torchdistill.models.wrapper.build_auxiliary_model_wrapper(model_config, **kwargs)[source]

Builds an auxiliary model wrapper for either teacher or student models.

Parameters:

model_config (dict) – configuration to build the auxiliary model wrapper. Should contain either ‘teacher_model’ or ‘student_model’.

Returns:

auxiliary model wrapper.

Return type:

nn.Module


torchdistill.models.util

torchdistill.models.util.wrap_if_distributed(module, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]

Wraps module with DistributedDataParallel if distributed = True and module has any updatable parameters.

Parameters:
  • module (nn.Module) – module to be wrapped.

  • device (torch.device) – target device.

  • device_ids (list[int]) – target device IDs.

  • distributed (bool) – whether to be in distributed training mode.

  • find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.

Returns:

wrapped module if distributed = True and it contains any updatable parameters.

Return type:

nn.Module
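
The wrapping rule described above can be sketched roughly as follows. This is a minimal illustration rather than the library's source: the helper name `wrap_if_distributed_sketch` is hypothetical, and both moving the module to `device` and returning the module unwrapped when the condition fails are assumptions. The example only exercises the non-wrapping paths, because DistributedDataParallel requires an initialized process group.

```python
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel


def wrap_if_distributed_sketch(module, device, device_ids, distributed, find_unused_parameters=None):
    # Hypothetical helper mirroring the documented behavior, not torchdistill's code.
    module.to(device)  # assumption: the module is moved to the target device first
    has_updatable_params = any(p.requires_grad for p in module.parameters())
    if distributed and has_updatable_params:
        # This branch needs torch.distributed.init_process_group() to have been called.
        ddp_kwargs = {}
        if find_unused_parameters is not None:
            ddp_kwargs['find_unused_parameters'] = find_unused_parameters
        return DistributedDataParallel(module, device_ids=device_ids, **ddp_kwargs)
    # Assumption: the module is returned as-is when it is not wrapped.
    return module


device = torch.device('cpu')
linear = nn.Linear(4, 2)
same_module = wrap_if_distributed_sketch(linear, device, device_ids=None, distributed=False)

# A fully frozen module has no updatable parameters, so the sketch returns it
# unwrapped even when distributed=True.
frozen = nn.Linear(4, 2)
for param in frozen.parameters():
    param.requires_grad = False
still_plain = wrap_if_distributed_sketch(frozen, device, device_ids=None, distributed=True)
```

Skipping the wrap for frozen modules matters because DistributedDataParallel rejects modules that have no parameters requiring gradients.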

torchdistill.models.util.load_module_ckpt(module, map_location, ckpt_file_path)[source]

Loads checkpoint for module.

Parameters:
  • module (nn.Module) – module to load checkpoint.

  • map_location (torch.device or str or dict or Callable) – map_location for torch.load.

  • ckpt_file_path (str) – file path to load checkpoint.

torchdistill.models.util.save_module_ckpt(module, ckpt_file_path)[source]

Saves checkpoint of module’s state dict.

Parameters:
  • module (nn.Module) – module whose state dict is saved as a checkpoint.

  • ckpt_file_path (str) – file path to save checkpoint.
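
As a rough illustration of the save/load round trip that these two helpers cover, the sketch below uses only `torch.save`, `torch.load`, and `load_state_dict`. The function names `save_state_dict_sketch` and `load_state_dict_sketch` are hypothetical stand-ins, and the real helpers may handle additional cases (for example, modules already wrapped for distributed training).

```python
import os
import tempfile

import torch
from torch import nn


def save_state_dict_sketch(module, ckpt_file_path):
    # Hypothetical stand-in: persists only the module's state dict.
    torch.save(module.state_dict(), ckpt_file_path)


def load_state_dict_sketch(module, map_location, ckpt_file_path):
    # map_location is forwarded to torch.load, e.g. 'cpu' or a torch.device.
    state_dict = torch.load(ckpt_file_path, map_location=map_location)
    module.load_state_dict(state_dict)


source = nn.Linear(3, 2)
target = nn.Linear(3, 2)

# Save the source module's parameters, then load them into the target module.
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_file_path = os.path.join(tmp_dir, 'linear.pt')
    save_state_dict_sketch(source, ckpt_file_path)
    load_state_dict_sketch(target, 'cpu', ckpt_file_path)

weights_match = torch.equal(source.weight, target.weight)
```

Passing `'cpu'` as `map_location` lets a checkpoint written on a GPU machine be restored on a machine without one.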

torchdistill.models.util.add_submodule(module, module_path, module_dict)[source]

Recursively adds submodules to module_dict.

Parameters:
  • module (nn.Module) – module.

  • module_path (str) – module path.

  • module_dict (nn.ModuleDict or dict) – module dict.
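
One way to picture the recursion is the sketch below. It is an assumption-laden illustration, not the library's code: `add_submodule_sketch` is a hypothetical name, module paths are assumed to be dot-separated (e.g., 'backbone.layer1.conv'), and plain strings stand in for nn.Module instances so that the example runs without PyTorch.

```python
from collections import OrderedDict


def add_submodule_sketch(module, module_path, module_dict):
    # Hypothetical re-implementation; assumes '.'-separated module paths.
    name, _, rest = module_path.partition('.')
    if not rest:
        # Last path component: store the module itself.
        module_dict[name] = module
        return
    # Create the intermediate level on first use, then recurse into it.
    child_dict = module_dict.setdefault(name, OrderedDict())
    add_submodule_sketch(module, rest, child_dict)


# Strings stand in for nn.Module instances in this sketch.
module_dict = OrderedDict()
add_submodule_sketch('conv', 'backbone.layer1.conv', module_dict)
add_submodule_sketch('bn', 'backbone.layer1.bn', module_dict)
add_submodule_sketch('fc', 'head', module_dict)
```

After these calls, `module_dict` holds a nested structure in which 'backbone' maps to a dict containing 'layer1', which in turn maps to the 'conv' and 'bn' entries.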

torchdistill.models.util.build_sequential_container(module_dict)[source]

Builds sequential container (nn.Sequential) from module_dict.

Parameters:

module_dict (nn.ModuleDict or collections.OrderedDict) – module dict to build a sequential container.

Returns:

sequential container.

Return type:

nn.Sequential
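
A minimal sketch of turning an ordered module dict into a container is shown below. It relies on the fact that nn.Sequential accepts an OrderedDict of named modules. The helper name `build_sequential_sketch` is hypothetical, and converting nested ordered dicts into nested containers is an assumption about how nested entries would be handled.

```python
from collections import OrderedDict

import torch
from torch import nn


def build_sequential_sketch(module_dict):
    # Hypothetical helper: nested ordered dicts become nested nn.Sequential containers.
    modules = OrderedDict()
    for name, value in module_dict.items():
        if isinstance(value, OrderedDict):
            value = build_sequential_sketch(value)
        modules[name] = value
    return nn.Sequential(modules)


module_dict = OrderedDict([
    ('encoder', OrderedDict([('fc1', nn.Linear(8, 4)), ('relu', nn.ReLU())])),
    ('head', nn.Linear(4, 2)),
])
container = build_sequential_sketch(module_dict)

# Modules run in insertion order: encoder.fc1 -> encoder.relu -> head.
output = container(torch.randn(5, 8))
```

Because the dict is ordered, the insertion order of its keys determines the execution order of the resulting container.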

torchdistill.models.util.redesign_model(org_model, model_config, model_label, model_type='original')[source]

Redesigns org_model and returns a new separate model. For example, it:

  • prunes some modules from org_model,

  • freezes parameters of some modules in org_model, and

  • adds adaptation module(s) to org_model as a new separate model.

Note

The parameters and states of modules in org_model will be kept in a new redesigned model.

Parameters:
  • org_model (nn.Module) – original model to be redesigned.

  • model_config (dict) – configuration to redesign org_model.

  • model_label (str) – model label (e.g., ‘teacher’, ‘student’) to be printed just for debugging purposes.

  • model_type (str) – model type (e.g., ‘original’, name of model class, etc.) to be printed just for debugging purposes.

Returns:

redesigned model.

Return type:

nn.Module
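
The three kinds of redesign listed above (pruning, freezing, adding adaptation modules) can be sketched in plain PyTorch as follows. This is an illustration under stated assumptions, not the library's implementation: `redesign_sketch` and its arguments `kept_module_names`, `frozen_module_names`, and `adaptation_modules` are hypothetical, whereas the real function reads its choices from model_config, whose schema is not shown in this section.

```python
from collections import OrderedDict

import torch
from torch import nn


def redesign_sketch(org_model, kept_module_names, frozen_module_names, adaptation_modules=None):
    # Hypothetical helper; argument names are illustrative only.
    modules = OrderedDict()
    for name in kept_module_names:
        # Reuse the original submodule objects so parameters and states carry over.
        modules[name] = getattr(org_model, name)
    for name in frozen_module_names:
        for param in modules[name].parameters():
            param.requires_grad = False
    # Append any new adaptation modules after the kept ones.
    for name, module in (adaptation_modules or {}).items():
        modules[name] = module
    return nn.Sequential(modules)


org_model = nn.Sequential(OrderedDict([
    ('features', nn.Linear(8, 6)),
    ('act', nn.ReLU()),
    ('classifier', nn.Linear(6, 3)),
]))

# Prune 'classifier', freeze 'features', and append a new adaptation layer.
new_model = redesign_sketch(
    org_model,
    kept_module_names=['features', 'act'],
    frozen_module_names=['features'],
    adaptation_modules={'adapter': nn.Linear(6, 4)},
)
output = new_model(torch.randn(2, 8))
```

Note that the sketch shares submodule objects with org_model instead of copying them, which is one way to keep the original parameters and states in the redesigned model.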