torchdistill.models
torchdistill.models.registry
- torchdistill.models.registry.register_model(arg=None, **kwargs)[source]
Registers a model class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class or function to be registered as a model.
- Returns:
registered model class or function to instantiate it.
- Return type:
class or Callable
Note
The model will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or the key you used for the registration in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class/function with a key of your choice, add key to the decorator as below:
>>> from torch import nn
>>> from torchdistill.models.registry import register_model
>>>
>>> @register_model(key='my_custom_model')
... class CustomModel(nn.Module):
...     def __init__(self, **kwargs):
...         super().__init__()
...         print('This is my custom model class')
In the example, the CustomModel class is registered with the key “my_custom_model”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomModel class by “my_custom_model”.
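The registration behaviour described above can be sketched in plain Python. This is a minimal, hypothetical illustration of a decorator that works both bare and with a key argument, as the signature register_model(arg=None, **kwargs) suggests. MODEL_REGISTRY and register_model_sketch are names invented for this example. They are not torchdistill's implementation.

```python
# Hypothetical registry sketch; not torchdistill's actual implementation.
MODEL_REGISTRY = {}


def register_model_sketch(arg=None, **kwargs):
    """Register a class/function under its own name or under kwargs['key']."""

    def _register(cls_or_func):
        # Use the explicit key if one was given, otherwise the object's __name__.
        key = kwargs.get('key', cls_or_func.__name__)
        MODEL_REGISTRY[key] = cls_or_func
        return cls_or_func

    # Used bare (@register_model_sketch): arg is the decorated object itself.
    if callable(arg):
        return _register(arg)
    # Used with arguments (@register_model_sketch(key=...)): return the decorator.
    return _register


@register_model_sketch
class PlainModel:
    pass


@register_model_sketch(key='my_custom_model')
class CustomModel:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


# A configuration would then refer to the model by its registered name.
model = MODEL_REGISTRY['my_custom_model'](num_classes=10)
print(sorted(MODEL_REGISTRY))  # ['PlainModel', 'my_custom_model']
```

When a key is given, only that key is registered. The class name "CustomModel" is not also added.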
- torchdistill.models.registry.register_adaptation_module(arg=None, **kwargs)[source]
Registers an adaptation module class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class or function to be registered as an adaptation module.
- Returns:
registered adaptation module class or function to instantiate it.
- Return type:
class or Callable
Note
The adaptation module will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or the key you used for the registration in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class/function with a key of your choice, add key to the decorator as below:
>>> from torch import nn
>>> from torchdistill.models.registry import register_adaptation_module
>>>
>>> @register_adaptation_module(key='my_custom_adaptation_module')
... class CustomAdaptationModule(nn.Module):
...     def __init__(self, **kwargs):
...         super().__init__()
...         print('This is my custom adaptation module class')
In the example, the CustomAdaptationModule class is registered with the key “my_custom_adaptation_module”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomAdaptationModule class by “my_custom_adaptation_module”.
- torchdistill.models.registry.register_auxiliary_model_wrapper(arg=None, **kwargs)[source]
Registers an auxiliary model wrapper class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class or function to be registered as an auxiliary model wrapper.
- Returns:
registered auxiliary model wrapper class or function to instantiate it.
- Return type:
class or Callable
Note
The auxiliary model wrapper will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or the key you used for the registration in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class/function with a key of your choice, add key to the decorator as below:
>>> from torch import nn
>>> from torchdistill.models.registry import register_auxiliary_model_wrapper
>>>
>>> @register_auxiliary_model_wrapper(key='my_custom_auxiliary_model_wrapper')
... class CustomAuxiliaryModelWrapper(nn.Module):
...     def __init__(self, **kwargs):
...         super().__init__()
...         print('This is my custom auxiliary model wrapper class')
In the example, the CustomAuxiliaryModelWrapper class is registered with the key “my_custom_auxiliary_model_wrapper”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomAuxiliaryModelWrapper class by “my_custom_auxiliary_model_wrapper”.
- torchdistill.models.registry.get_model(key, repo_or_dir=None, *args, **kwargs)[source]
Gets a model from the model registry.
- Parameters:
key (str) – model key.
repo_or_dir (str or None) – repo_or_dir for torch.hub.load.
- Returns:
model.
- Return type:
nn.Module
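The following sketches the lookup that get_model is documented to perform. It assumes the registry is checked first and torch.hub.load is used only when the key is not registered and repo_or_dir is given. That ordering is an assumption, not confirmed by the text above. get_model_sketch, MODEL_REGISTRY, and the hub_load parameter are hypothetical. hub_load exists only so the example runs without downloading anything.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical registry contents; a real one would be filled by a decorator.
MODEL_REGISTRY: Dict[str, Callable[..., Any]] = {
    'tiny_model': lambda num_classes=10: {'name': 'tiny_model', 'num_classes': num_classes},
}


def get_model_sketch(key: str, repo_or_dir: Optional[str] = None, *args: Any,
                     hub_load: Optional[Callable[..., Any]] = None, **kwargs: Any) -> Any:
    """Look up key in the registry first; otherwise defer to a hub loader."""
    if key in MODEL_REGISTRY:
        return MODEL_REGISTRY[key](*args, **kwargs)
    if repo_or_dir is not None:
        if hub_load is None:
            # With PyTorch installed: torch.hub.load(repo_or_dir, key, *args, **kwargs).
            import torch
            hub_load = torch.hub.load
        return hub_load(repo_or_dir, key, *args, **kwargs)
    raise ValueError(f'model key `{key}` is not registered and no repo_or_dir was given')


print(get_model_sketch('tiny_model', num_classes=100))
# {'name': 'tiny_model', 'num_classes': 100}

# A stand-in loader replaces torch.hub.load in this example.
fake_hub = lambda repo, name, *a, **kw: (repo, name, kw)
print(get_model_sketch('resnet18', 'pytorch/vision', hub_load=fake_hub, weights=None))
# ('pytorch/vision', 'resnet18', {'weights': None})
```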
torchdistill.models.classification
To reproduce the test results for CIFAR datasets, the following repositories were referred for training methods:
WRN (Wide ResNet): https://github.com/szagoruyko/wide-residual-networks
DenseNet-BC: https://github.com/liuzhuang13/DenseNet
| Model | CIFAR-10 | CIFAR-100 |
|---|---|---|
| | 91.92 | N/A |
| | 93.03 | N/A |
| | 93.20 | N/A |
| | 93.57 | N/A |
| | 93.50 | N/A |
| | 95.24 | 79.44 |
| | 95.53 | 81.27 |
| | 94.76 | 79.26 |
| | 95.53 | 77.14 |
Those results are reported in the following paper:
Yoshitomo Matsubara: “torchdistill Meets Hugging Face Libraries for Reproducible, Coding-Free Deep Learning Studies: A Case Study on NLP”
torchdistill.models.classification.densenet
- class torchdistill.models.classification.densenet.DenseNet4Cifar(growth_rate: int = 32, block_config: Tuple[int, int, int] = (12, 12, 12), num_init_features: int = 64, bn_size: int = 4, drop_rate: float = 0, num_classes: int = 10, memory_efficient: bool = False)[source]
DenseNet-BC model for CIFAR datasets. Refactored https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py for CIFAR datasets, referring to https://github.com/liuzhuang13/DenseNet
Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger: “Densely Connected Convolutional Networks” @ CVPR 2017 (2017).
- Parameters:
growth_rate (int) – number of filters to add each layer (k in paper).
block_config (tuple[int, int, int]) – numbers of layers in each of the three dense blocks.
num_init_features (int) – number of filters to learn in the first convolution layer.
bn_size (int) – multiplicative factor for the number of bottleneck layers (i.e., bn_size * k features in the bottleneck layer).
drop_rate (float) – dropout rate after each dense layer.
num_classes (int) – number of classification classes.
memory_efficient (bool) – if True, uses checkpointing, which is much more memory-efficient but slower. Refer to “Memory-Efficient Implementation of DenseNets” for details.
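For CIFAR-style DenseNets, the three entries of block_config are usually derived from the total depth, following the DenseNet paper. Four layers sit outside the dense blocks: the first convolution, two transition convolutions, and the classifier. The remaining layers are split evenly over three blocks. With bottleneck layers (DenseNet-BC), each layer contains two convolutions, so the count is halved. The helper below is a hypothetical sketch of that arithmetic. It is not guaranteed to match how densenet() computes block_config internally.

```python
def densenet_block_config(depth: int, bottleneck: bool) -> tuple:
    """Return the per-block layer counts for a 3-block CIFAR DenseNet."""
    # 4 layers are outside the dense blocks; the rest are split over 3 blocks.
    layers_per_block, remainder = divmod(depth - 4, 3)
    if remainder != 0:
        raise ValueError(f'depth - 4 must be divisible by 3, got depth={depth}')
    if bottleneck:
        # Each bottleneck layer holds two convolutions (1x1 followed by 3x3).
        if layers_per_block % 2 != 0:
            raise ValueError(f'depth={depth} is incompatible with bottleneck layers')
        layers_per_block //= 2
    return (layers_per_block,) * 3


print(densenet_block_config(100, bottleneck=True))  # (16, 16, 16)
print(densenet_block_config(250, bottleneck=True))  # (41, 41, 41)
print(densenet_block_config(190, bottleneck=True))  # (31, 31, 31)
print(densenet_block_config(40, bottleneck=False))  # (12, 12, 12)
```

The non-bottleneck case with depth 40 gives (12, 12, 12), which is the default block_config of DenseNet4Cifar.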
- torchdistill.models.classification.densenet.densenet(growth_rate: int, depth: int, num_init_features: int, bottleneck: bool, num_classes: int, pretrained: bool, progress: bool, **kwargs: Any)[source]
Instantiates a DenseNet model for CIFAR datasets.
Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger: “Densely Connected Convolutional Networks” @ CVPR 2017 (2017).
- Parameters:
growth_rate (int) – number of filters to add each layer (k in paper).
depth (int) – depth.
num_init_features (int) – number of filters to learn in the first convolution layer.
bottleneck (bool) – if True, uses bottleneck layers.
num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
DenseNet model.
- Return type:
DenseNet4Cifar
- torchdistill.models.classification.densenet.densenet_bc_k12_depth100(num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]
DenseNet-BC (k=12, depth=100) model.
Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger: “Densely Connected Convolutional Networks” @ CVPR 2017 (2017).
- Parameters:
num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
DenseNet-BC (k=12, depth=100) model.
- Return type:
DenseNet4Cifar
- torchdistill.models.classification.densenet.densenet_bc_k24_depth250(num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]
DenseNet-BC (k=24, depth=250) model.
Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger: “Densely Connected Convolutional Networks” @ CVPR 2017 (2017).
- Parameters:
num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
DenseNet-BC (k=24, depth=250) model.
- Return type:
DenseNet4Cifar
- torchdistill.models.classification.densenet.densenet_bc_k40_depth190(num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]
DenseNet-BC (k=40, depth=190) model.
Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger: “Densely Connected Convolutional Networks” @ CVPR 2017 (2017).
- Parameters:
num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
DenseNet-BC (k=40, depth=190) model.
- Return type:
DenseNet4Cifar
torchdistill.models.classification.resnet
- class torchdistill.models.classification.resnet.ResNet4Cifar(block: Type[BasicBlock], layers: List[int], num_classes: int = 10, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: List[bool] | None = None, norm_layer: Callable[[...], Module] | None = None)[source]
ResNet model for CIFAR datasets. Refactored https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py for CIFAR datasets, referring to https://github.com/facebookarchive/fb.resnet.torch
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).
- Parameters:
block (BasicBlock) – block class.
layers (list[int]) – numbers of blocks in each of the three stages.
num_classes (int) – number of classification classes.
zero_init_residual (bool) – if True, zero-initializes the last BN in each residual branch.
groups (int) – groups for Conv2d.
width_per_group (int) – base width for Conv2d.
replace_stride_with_dilation (list[bool] or None) – indicates if we should replace the 2x2 stride with a dilated convolution instead.
norm_layer (Callable or nn.Module or None) – normalization module class or callable object.
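The CIFAR ResNets listed below follow the 6n + 2 layout from the ResNet paper. There are three stages of n basic blocks, each block holds two 3x3 convolutions, and the first convolution and the final fully connected layer add 2 more. A hypothetical helper that recovers the layers argument from a depth could look like this. Whether resnet() computes it exactly this way is an assumption.

```python
def cifar_resnet_layers(depth: int) -> list:
    """Return [n, n, n] for a CIFAR ResNet of the given depth (depth = 6n + 2)."""
    n, remainder = divmod(depth - 2, 6)
    if remainder != 0 or n < 1:
        raise ValueError(f'depth must be 6n + 2 with n >= 1, got {depth}')
    return [n, n, n]


for depth in (20, 32, 44, 56, 110, 1202):
    # e.g., ResNet-20 -> [3, 3, 3], ResNet-110 -> [18, 18, 18]
    print(depth, cifar_resnet_layers(depth))
```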
- torchdistill.models.classification.resnet.resnet(depth: int, num_classes: int, pretrained: bool, progress: bool, **kwargs: Any)[source]
Instantiates a ResNet model for CIFAR datasets.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).
- Parameters:
depth (int) – depth.
num_classes (int) – number of classification classes.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
ResNet model.
- Return type:
ResNet4Cifar
- torchdistill.models.classification.resnet.resnet20(num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]
ResNet-20 model.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).
- Parameters:
num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
ResNet-20 model.
- Return type:
ResNet4Cifar
- torchdistill.models.classification.resnet.resnet32(num_classes=10, pretrained=False, progress=True, **kwargs: Any) ResNet4Cifar [source]
ResNet-32 model.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).
- Parameters:
num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
ResNet-32 model.
- Return type:
ResNet4Cifar
- torchdistill.models.classification.resnet.resnet44(num_classes=10, pretrained=False, progress=True, **kwargs: Any) ResNet4Cifar [source]
ResNet-44 model.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).
- Parameters:
num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
ResNet-44 model.
- Return type:
ResNet4Cifar
- torchdistill.models.classification.resnet.resnet56(num_classes=10, pretrained=False, progress=True, **kwargs: Any) ResNet4Cifar [source]
ResNet-56 model.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).
- Parameters:
num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
ResNet-56 model.
- Return type:
ResNet4Cifar
- torchdistill.models.classification.resnet.resnet110(num_classes=10, pretrained=False, progress=True, **kwargs: Any) ResNet4Cifar [source]
ResNet-110 model.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).
- Parameters:
num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
ResNet-110 model.
- Return type:
ResNet4Cifar
- torchdistill.models.classification.resnet.resnet1202(num_classes=10, pretrained=False, progress=True, **kwargs: Any) ResNet4Cifar [source]
ResNet-1202 model.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun: “Deep Residual Learning for Image Recognition” @ CVPR 2016 (2016).
- Parameters:
num_classes (int) – 10 or 100 for CIFAR-10 or CIFAR-100, respectively.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
ResNet-1202 model.
- Return type:
ResNet4Cifar
torchdistill.models.classification.wide_resnet
- class torchdistill.models.classification.wide_resnet.WideBasicBlock(in_planes, planes, dropout_rate, stride=1)[source]
A basic block of Wide ResNet for CIFAR datasets.
- Parameters:
in_planes (int) – number of input feature planes.
planes (int) – number of output feature planes.
dropout_rate (float) – dropout rate.
stride (int) – stride for Conv2d.
- class torchdistill.models.classification.wide_resnet.WideResNet4Cifar(depth, k, dropout_p, block, num_classes, norm_layer=None)[source]
Wide ResNet (WRN) model for CIFAR datasets. Refactored https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py for CIFAR datasets, referring to https://github.com/szagoruyko/wide-residual-networks
Sergey Zagoruyko, Nikos Komodakis: “Wide Residual Networks” @ BMVC 2016 (2016)
- Parameters:
depth (int) – depth.
k (int) – widening factor.
dropout_p (float) – dropout rate.
block (WideBasicBlock) – block class.
num_classes (int) – number of classification classes.
norm_layer (Callable or nn.Module or None) – normalization module class or callable object.
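In the WRN-d-k naming used below, d is the depth and k the widening factor. Following the Wide ResNet paper, a CIFAR WRN has a 16-channel stem convolution followed by three groups. Each group has (d - 4) / 6 WideBasicBlocks, and the groups output 16k, 32k, and 64k channels. The helper is a hypothetical sketch of this arithmetic, not WideResNet4Cifar's code.

```python
def wrn_layout(depth: int, k: int) -> dict:
    """Return blocks per group and channel widths for a CIFAR WRN-depth-k."""
    n, remainder = divmod(depth - 4, 6)
    if remainder != 0 or n < 1:
        raise ValueError(f'depth must be 6n + 4 with n >= 1, got {depth}')
    return {
        'blocks_per_group': n,
        # The stem convolution outputs 16 channels; the three groups widen by k.
        'widths': (16, 16 * k, 32 * k, 64 * k),
    }


print(wrn_layout(40, 4))   # {'blocks_per_group': 6, 'widths': (16, 64, 128, 256)}
print(wrn_layout(28, 10))  # {'blocks_per_group': 4, 'widths': (16, 160, 320, 640)}
print(wrn_layout(16, 8))   # {'blocks_per_group': 2, 'widths': (16, 128, 256, 512)}
```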
- torchdistill.models.classification.wide_resnet.wide_resnet(depth: int, k: int, dropout_p: float, num_classes: int, pretrained: bool, progress: bool, **kwargs: Any)[source]
Instantiates a Wide ResNet model for CIFAR datasets.
Sergey Zagoruyko, Nikos Komodakis: “Wide Residual Networks” @ BMVC 2016 (2016)
- Parameters:
depth (int) – depth.
k (int) – widening factor.
dropout_p (float) – dropout rate.
num_classes (int) – number of classification classes.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
Wide ResNet model.
- Return type:
WideResNet4Cifar
- torchdistill.models.classification.wide_resnet.wide_resnet40_4(dropout_p=0.3, num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]
WRN-40-4 model.
Sergey Zagoruyko, Nikos Komodakis: “Wide Residual Networks” @ BMVC 2016 (2016)
- Parameters:
dropout_p (float) – dropout rate.
num_classes (int) – number of classification classes.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
WRN-40-4 model.
- Return type:
WideResNet4Cifar
- torchdistill.models.classification.wide_resnet.wide_resnet28_10(dropout_p=0.3, num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]
WRN-28-10 model.
Sergey Zagoruyko, Nikos Komodakis: “Wide Residual Networks” @ BMVC 2016 (2016)
- Parameters:
dropout_p (float) – dropout rate.
num_classes (int) – number of classification classes.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
WRN-28-10 model.
- Return type:
WideResNet4Cifar
- torchdistill.models.classification.wide_resnet.wide_resnet16_8(dropout_p=0.3, num_classes=10, pretrained=False, progress=True, **kwargs: Any)[source]
WRN-16-8 model.
Sergey Zagoruyko, Nikos Komodakis: “Wide Residual Networks” @ BMVC 2016 (2016)
- Parameters:
dropout_p (float) – dropout rate.
num_classes (int) – number of classification classes.
pretrained (bool) – if True, returns a model pre-trained on CIFAR dataset.
progress (bool) – if True, displays a progress bar of the download to stderr.
- Returns:
WRN-16-8 model.
- Return type:
WideResNet4Cifar
torchdistill.models.official
- torchdistill.models.official.get_image_classification_model(model_config, distributed=False)[source]
Gets an image classification model from torchvision.
- Parameters:
model_config (dict) – image classification model configuration.
distributed (bool) – whether to be in distributed training mode.
- Returns:
image classification model.
- Return type:
nn.Module
- torchdistill.models.official.get_object_detection_model(model_config)[source]
Gets an object detection model from torchvision.
- Parameters:
model_config (dict) – object detection model configuration.
- Returns:
object detection model.
- Return type:
nn.Module
torchdistill.models.adaptation
- class torchdistill.models.adaptation.ConvReg(num_input_channels, num_output_channels, kernel_size, stride, padding, uses_relu=True)[source]
A convolutional regressor for FitNets, as used in “Contrastive Representation Distillation” (CRD).
- Parameters:
num_input_channels (int) – in_channels for Conv2d.
num_output_channels (int) – out_channels for Conv2d.
kernel_size ((int, int) or int) – kernel_size for Conv2d.
stride (int) – stride for Conv2d.
padding (int) – padding for Conv2d.
uses_relu (bool) – if True, uses ReLU as the last module.
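ConvReg maps a student feature map onto the teacher's, so kernel_size, stride, and padding are typically chosen to make the spatial sizes match. The standard Conv2d output-size formula (with dilation 1) makes this easy to check. conv2d_output_size is a hypothetical helper, not part of torchdistill.

```python
def conv2d_output_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of nn.Conv2d along one dimension (dilation=1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Keep a 32x32 student map at 32x32 to match a teacher map of the same size.
print(conv2d_output_size(32, kernel_size=3, stride=1, padding=1))  # 32
# Halve an 8x8 student map to match a 4x4 teacher map.
print(conv2d_output_size(8, kernel_size=3, stride=2, padding=1))   # 4
```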
torchdistill.models.wrapper
- class torchdistill.models.wrapper.AuxiliaryModelWrapper[source]
An abstract auxiliary model wrapper.
forward(), secondary_forward(), and post_epoch_process() should be overridden by all subclasses.
- class torchdistill.models.wrapper.EmptyModule(**kwargs)[source]
An empty auxiliary model wrapper. This module returns its input as output. It is useful when you want to replace your teacher/student model with an empty model to save inference time. For example, multi-stage knowledge distillation may have some stages that do not require the teacher or the student model.
- class torchdistill.models.wrapper.Paraphraser4FactorTransfer(k, num_input_channels, kernel_size=3, stride=1, padding=1, uses_bn=True, uses_decoder=True)[source]
Paraphraser for factor transfer (FT). This module is used at the 1st and 2nd stages of FT method.
Jangho Kim, Seonguk Park, Nojun Kwak: “Paraphrasing Complex Network: Network Compression via Factor Transfer” @ NeurIPS 2018 (2018)
- Parameters:
k (float) – paraphrase rate.
num_input_channels (int) – number of input channels.
kernel_size (int) – kernel_size for Conv2d.
stride (int) – stride for Conv2d.
padding (int) – padding for Conv2d.
uses_bn (bool) – if True, uses BatchNorm2d.
uses_decoder (bool) – if True, uses decoder in forward().
- class torchdistill.models.wrapper.Translator4FactorTransfer(num_input_channels, num_output_channels, kernel_size=3, stride=1, padding=1, uses_bn=True)[source]
Translator for factor transfer (FT). This module is used at the 2nd stage of the FT method. Note that “the student translator has the same three convolution layers as the paraphraser”.
Jangho Kim, Seonguk Park, Nojun Kwak: “Paraphrasing Complex Network: Network Compression via Factor Transfer” @ NeurIPS 2018 (2018)
- Parameters:
num_input_channels (int) – number of input channels.
num_output_channels (int) – number of output channels.
kernel_size (int) – kernel_size for Conv2d.
stride (int) – stride for Conv2d.
padding (int) – padding for Conv2d.
uses_bn (bool) – if True, uses BatchNorm2d.
- class torchdistill.models.wrapper.Teacher4FactorTransfer(teacher_model, minimal, input_module_path, paraphraser_kwargs, paraphraser_ckpt, uses_decoder, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]
An auxiliary teacher model wrapper for factor transfer (FT), including paraphraser Paraphraser4FactorTransfer.
Jangho Kim, Seonguk Park, Nojun Kwak: “Paraphrasing Complex Network: Network Compression via Factor Transfer” @ NeurIPS 2018 (2018)
- Parameters:
teacher_model (nn.Module) – teacher model.
minimal (dict or None) – optional model_config for build_auxiliary_model_wrapper().
input_module_path (str) – path of module whose output is used as input to paraphraser.
paraphraser_kwargs (dict) – kwargs to instantiate Paraphraser4FactorTransfer.
uses_decoder (bool) – uses_decoder for Paraphraser4FactorTransfer.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
- class torchdistill.models.wrapper.Student4FactorTransfer(student_model, input_module_path, translator_kwargs, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]
An auxiliary student model wrapper for factor transfer (FT), including translator Translator4FactorTransfer.
Jangho Kim, Seonguk Park, Nojun Kwak: “Paraphrasing Complex Network: Network Compression via Factor Transfer” @ NeurIPS 2018 (2018)
- Parameters:
student_model (nn.Module) – student model.
input_module_path (str) – path of module whose output is used as input to translator.
translator_kwargs (dict) – kwargs to instantiate Translator4FactorTransfer.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
- class torchdistill.models.wrapper.Connector4DAB(student_model, connectors, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]
An auxiliary student model wrapper with connector for distillation of activation boundaries (DAB).
Byeongho Heo, Minsik Lee, Sangdoo Yun, Jin Young Choi: “Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons” @ AAAI 2019 (2019)
- Parameters:
student_model (nn.Module) – student model.
connectors (dict) – connector keys and configurations.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
- class torchdistill.models.wrapper.Regressor4VID(in_channels, middle_channels, out_channels, eps, init_pred_var, **kwargs)[source]
An auxiliary module for variational information distillation (VID).
Sungsoo Ahn, Shell Xu Hu, Andreas Damianou, Neil D. Lawrence, Zhenwen Dai: “Variational Information Distillation for Knowledge Transfer” @ CVPR 2019 (2019)
- Parameters:
in_channels (int) – number of input channels for the first convolution layer.
middle_channels (int) – number of output/input channels for the first/second convolution layer.
out_channels (int) – number of output channels for the third convolution layer.
eps (float) – minimum variance added to the predicted variance for numerical stability.
init_pred_var (float) – initial value of the predicted variance.
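In VID, the regressor predicts a mean, and a learned per-channel variance is kept positive with a softplus and shifted by eps. The sketch below assumes the parameterization used in the RepDistiller VID reference code. There, the softplus parameter is initialized as the inverse softplus of (init_pred_var - eps), so the predicted variance starts at init_pred_var. All function names are hypothetical, and scalars stand in for per-channel tensors.

```python
import math


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def init_softplus_param(init_pred_var: float, eps: float) -> float:
    """Inverse softplus of (init_pred_var - eps), so the initial variance is init_pred_var."""
    if init_pred_var <= eps:
        raise ValueError('init_pred_var must be larger than eps')
    return math.log(math.expm1(init_pred_var - eps))


def predicted_variance(param: float, eps: float) -> float:
    # softplus keeps the learned part positive; eps bounds the variance from below.
    return softplus(param) + eps


eps, init_pred_var = 1e-5, 5.0
param = init_softplus_param(init_pred_var, eps)
print(round(predicted_variance(param, eps), 6))  # 5.0
# Even for a strongly negative parameter, the variance here stays above eps.
print(predicted_variance(-10.0, eps) > eps)      # True
```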
- class torchdistill.models.wrapper.VariationalDistributor4VID(student_model, regressors, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]
An auxiliary student model wrapper for variational information distillation (VID), including regressors Regressor4VID.
Sungsoo Ahn, Shell Xu Hu, Andreas Damianou, Neil D. Lawrence, Zhenwen Dai: “Variational Information Distillation for Knowledge Transfer” @ CVPR 2019 (2019)
- Parameters:
student_model (nn.Module) – student model.
in_channels (int) – number of input channels for the first convolution layer.
regressors (dict) – regressor keys and configurations.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
- class torchdistill.models.wrapper.Linear4CCKD(input_module, linear_kwargs, device, device_ids, distributed, teacher_model=None, student_model=None, find_unused_parameters=None, **kwargs)[source]
An auxiliary teacher/student model wrapper for correlation congruence for knowledge distillation (CCKD). Fully-connected layers cope with a mismatch of feature representations of teacher and student models.
Baoyun Peng, Xiao Jin, Jiaheng Liu, Dongsheng Li, Yichao Wu, Yu Liu, Shunfeng Zhou, Zhaoning Zhang: “Correlation Congruence for Knowledge Distillation” @ ICCV 2019 (2019)
- Parameters:
input_module (dict) – input module configuration.
linear_kwargs (dict) – kwargs for Linear.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
teacher_model (nn.Module or None) – teacher model.
student_model (nn.Module or None) – student model.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
- class torchdistill.models.wrapper.Normalizer4CRD(linear, power=2)[source]
An auxiliary module for contrastive representation distillation (CRD).
Yonglong Tian, Dilip Krishnan, Phillip Isola: “Contrastive Representation Distillation” @ ICLR 2020 (2020)
- Parameters:
linear (nn.Module) – linear module.
power (int) – exponent p of the L_p norm used to normalize the output of linear.
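Normalizer4CRD applies the linear module and then normalizes the result. This assumes power is the exponent p of an L_p norm taken over the feature dimension, as in the Normalize module of the CRD reference code. Under that assumption, the normalization step for a single embedding looks like the code below. lp_normalize is a hypothetical helper that works on plain lists instead of tensors.

```python
def lp_normalize(vector: list, power: int = 2) -> list:
    """Divide a feature vector by its L_p norm (power = p); a zero vector raises ZeroDivisionError."""
    norm = sum(abs(v) ** power for v in vector) ** (1.0 / power)
    return [v / norm for v in vector]


print(lp_normalize([3.0, 4.0]))                   # [0.6, 0.8]
print(lp_normalize([1.0, -1.0, 2.0], power=1))    # [0.25, -0.25, 0.5]
```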
- class torchdistill.models.wrapper.Linear4CRD(input_module_path, linear_kwargs, device, device_ids, distributed, power=2, teacher_model=None, student_model=None, find_unused_parameters=None, **kwargs)[source]
An auxiliary teacher/student model wrapper for contrastive representation distillation (CRD), including normalizer Normalizer4CRD. Refactored https://github.com/HobbitLong/RepDistiller/blob/master/crd/memory.py
Yonglong Tian, Dilip Krishnan, Phillip Isola: “Contrastive Representation Distillation” @ ICLR 2020 (2020)
- Parameters:
input_module_path (str) – path of module whose output will be flattened and then used as input to normalizer.
linear_kwargs (dict) – kwargs for Linear.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
power (int) – power for Normalizer4CRD.
teacher_model (nn.Module or None) – teacher model.
student_model (nn.Module or None) – student model.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
- class torchdistill.models.wrapper.HeadRCNN(head_rcnn, **kwargs)[source]
An auxiliary teacher/student model wrapper for head network distillation (HND) and generalized head network distillation (GHND).
Yoshitomo Matsubara, Sabur Baidya, Davide Callegaro, Marco Levorato, Sameer Singh: “Distilled Split Deep Neural Networks for Edge-Assisted Real-Time Systems” @ MobiCom 2019 Workshop on Hot Topics in Video Analytics and Intelligent Edges (2019)
Yoshitomo Matsubara, Marco Levorato: “Neural Compression and Filtering for Edge-assisted Real-time Object Detection in Challenged Networks” @ ICPR 2020 (2021)
- Parameters:
head_rcnn (dict) – head R-CNN configuration as model_config in torchdistill.models.util.redesign_model().
kwargs (dict) – teacher_model or student_model keys must be included. If both teacher_model and student_model are provided, student_model will be prioritized.
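The rule above can be read as a small selection step: student_model takes precedence over teacher_model when both keys are passed. The function below is a hypothetical sketch of that logic, with plain strings standing in for models. It compares against None explicitly because some modules evaluate as False; an empty nn.Sequential is one example, since it defines __len__.

```python
from typing import Any


def select_model(**kwargs: Any) -> Any:
    """Pick student_model if given, else teacher_model; one of them is required."""
    student_model = kwargs.get('student_model')
    teacher_model = kwargs.get('teacher_model')
    if student_model is not None:
        return student_model
    if teacher_model is not None:
        return teacher_model
    raise ValueError('either teacher_model or student_model must be provided')


print(select_model(teacher_model='teacher', student_model='student'))  # student
print(select_model(teacher_model='teacher'))                             # teacher
```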
- class torchdistill.models.wrapper.SSWrapper4SSKD(input_module, feat_dim, ss_module_ckpt, device, device_ids, distributed, freezes_ss_module=False, teacher_model=None, student_model=None, find_unused_parameters=None, **kwargs)[source]
An auxiliary teacher/student model wrapper for self-supervision knowledge distillation (SSKD). If both teacher_model and student_model are provided, student_model will be prioritized.
Guodong Xu, Ziwei Liu, Xiaoxiao Li, Chen Change Loy: “Knowledge Distillation Meets Self-Supervision” @ ECCV 2020 (2020)
- Parameters:
input_module (dict) – input module configuration.
feat_dim (int) – number of input/output features for self-supervision module.
ss_module_ckpt (str) – self-supervision module checkpoint file path.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
freezes_ss_module (bool) – if True, freezes self-supervision module.
teacher_model (nn.Module or None) – teacher model.
student_model (nn.Module or None) – student model.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
- class torchdistill.models.wrapper.VarianceBranch4PAD(student_model, input_module, feat_dim, var_estimator_ckpt, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]
An auxiliary teacher/student model wrapper for prime-aware adaptive distillation (PAD).
Youcai Zhang, Zhonghao Lan, Yuchen Dai, Fangao Zeng, Yan Bai, Jie Chang, Yichen Wei: “Prime-Aware Adaptive Distillation” @ ECCV 2020 (2020)
- Parameters:
student_model (nn.Module) – student model.
input_module (dict) – input module configuration.
feat_dim (int) – number of input/output features for the variance estimator module.
var_estimator_ckpt (str) – variance estimator module checkpoint file path.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
- class torchdistill.models.wrapper.AttentionBasedFusion(in_channels, mid_channels, out_channels, uses_attention)[source]
An auxiliary module for knowledge review (KR). Refactored https://github.com/dvlab-research/ReviewKD/blob/master/ImageNet/models/reviewkd.py
Pengguang Chen, Shu Liu, Hengshuang Zhao, Jiaya Jia: “Distilling Knowledge via Knowledge Review” @ CVPR 2021 (2021)
- Parameters:
in_channels (int) – number of input channels for the first convolution layer.
mid_channels (int) – number of output/input channels for the first/second convolution layer.
out_channels (int) – number of output channels for the third convolution layer.
- class torchdistill.models.wrapper.Student4KnowledgeReview(student_model, abfs, device, device_ids, distributed, sizes=None, find_unused_parameters=None, **kwargs)[source]
An auxiliary student model wrapper for knowledge review (KR). Refactored https://github.com/dvlab-research/ReviewKD/blob/master/ImageNet/models/reviewkd.py
Pengguang Chen, Shu Liu, Hengshuang Zhao, Jiaya Jia: “Distilling Knowledge via Knowledge Review” @ CVPR 2021 (2021)
- Parameters:
student_model (nn.Module) – student model.
abfs (list[dict]) – attention based fusion configurations.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
- class torchdistill.models.wrapper.Student4KTAAD(student_model, input_module_path, feature_adapter_config, affinity_adapter_config, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]
An auxiliary student model wrapper for knowledge translation and adaptation + affinity distillation (KTAAD). Refactored from https://github.com/dvlab-research/ReviewKD/blob/master/ImageNet/models/reviewkd.py
Tong He, Chunhua Shen, Zhi Tian, Dong Gong, Changming Sun, Youliang Yan: “Knowledge Adaptation for Efficient Semantic Segmentation” @ CVPR 2019 (2019)
- Parameters:
student_model (nn.Module) – student model.
input_module_path (str) – path of the module whose output is used as input to the feature adapter and the affinity adapter.
feature_adapter_config (dict) – feature adapter configuration.
affinity_adapter_config (dict) – affinity adapter configuration.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
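Affinity distillation compares how strongly each spatial position relates to every other position in the student and teacher feature maps. The sketch below assumes that affinity is the cosine similarity between the L2-normalized C-dimensional vectors at each pair of positions, and that the two resulting affinity matrices are compared with a mean squared error. Both function names are hypothetical. The sketch does not reproduce the adapters built from feature_adapter_config and affinity_adapter_config.

```python
import torch
import torch.nn.functional as F


def spatial_affinity(features: torch.Tensor) -> torch.Tensor:
    """Cosine-similarity affinity between all spatial positions of (N, C, H, W) features."""
    n, c, h, w = features.shape
    flat = features.reshape(n, c, h * w)
    # L2-normalize the C-dimensional vector at each spatial position
    flat = F.normalize(flat, p=2, dim=1)
    # (N, HW, C) @ (N, C, HW) -> (N, HW, HW)
    return torch.bmm(flat.transpose(1, 2), flat)


def affinity_loss(student_features: torch.Tensor, teacher_features: torch.Tensor) -> torch.Tensor:
    """Hypothetical loss: mean squared difference between the two affinity matrices."""
    # Spatial sizes must match; channel counts may differ
    return F.mse_loss(spatial_affinity(student_features), spatial_affinity(teacher_features))
```

Each affinity matrix is HW x HW regardless of the channel count, so student and teacher features with different numbers of channels can be compared directly.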
- class torchdistill.models.wrapper.ChannelSimilarityEmbed(in_channels=512, out_channels=128, **kwargs)[source]
An auxiliary module for Inter-Channel Correlation for Knowledge Distillation (ICKD). Refactored from https://github.com/ADLab-AutoDrive/ICKD/blob/main/ImageNet/torchdistill/models/special.py
Li Liu, Qingle Huang, Sihao Lin, Hongwei Xie, Bing Wang, Xiaojun Chang, Xiaodan Liang: “Inter-Channel Correlation for Knowledge Distillation” @ ICCV 2021 (2021)
- Parameters:
in_channels (int) – number of input channels for the convolution layer.
out_channels (int) – number of output channels for the convolution layer.
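As an illustration of how such an embedding can feed an inter-channel correlation, here is a minimal sketch. It assumes the embedding is a single bias-free 1x1 convolution whose output is flattened over the spatial dimensions, and that the correlation is the channel-by-channel Gram matrix divided by H*W. `ChannelEmbedSketch` and `inter_channel_correlation` are hypothetical names, and the normalization may differ from ICKD's.

```python
import torch
from torch import nn


class ChannelEmbedSketch(nn.Module):
    """Hypothetical 1x1-conv embedding followed by flattening of the spatial dims."""

    def __init__(self, in_channels=512, out_channels=128):
        super().__init__()
        self.embed = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.embed(x)                # (N, out_channels, H, W)
        return x.flatten(start_dim=2)    # (N, out_channels, H*W)


def inter_channel_correlation(embedded: torch.Tensor) -> torch.Tensor:
    """Channel-by-channel Gram matrix of (N, C, HW) features, scaled by 1/HW."""
    hw = embedded.shape[-1]
    return torch.bmm(embedded, embedded.transpose(1, 2)) / hw  # (N, C, C)
```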
- class torchdistill.models.wrapper.Student4ICKD(student_model, embeddings, device, device_ids, distributed, **kwargs)[source]
An auxiliary student model wrapper for Inter-Channel Correlation for Knowledge Distillation (ICKD). Based on https://github.com/ADLab-AutoDrive/ICKD/blob/main/ImageNet/torchdistill/models/special.py
Li Liu, Qingle Huang, Sihao Lin, Hongwei Xie, Bing Wang, Xiaojun Chang, Xiaodan Liang: “Inter-Channel Correlation for Knowledge Distillation” @ ICCV 2021 (2021)
- Parameters:
student_model (nn.Module) – student model.
embeddings (dict) – embedding keys and configurations.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
- class torchdistill.models.wrapper.SRDModelWrapper(input_module, norm_kwargs, device, device_ids, distributed, linear_kwargs=None, teacher_model=None, student_model=None, find_unused_parameters=None, **kwargs)[source]
An auxiliary model wrapper for Understanding the Role of the Projector in Knowledge Distillation. Based on https://github.com/roymiles/Simple-Recipe-Distillation/blob/main/imagenet/torchdistill/losses/single.py
Roy Miles, Krystian Mikolajczyk: “Understanding the Role of the Projector in Knowledge Distillation” @ AAAI 2024 (2024)
- Parameters:
model (nn.Module) – model.
input_module (dict) – input module configuration.
linear_kwargs (dict or None) – nn.Linear keyword arguments.
norm_kwargs (dict) – nn.BatchNorm1d keyword arguments.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
teacher_model (nn.Module or None) – teacher model.
student_model (nn.Module or None) – student model.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
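The linear_kwargs and norm_kwargs parameters suggest a projector made of an optional nn.Linear followed by nn.BatchNorm1d. The sketch below assumes exactly that. It applies the projector to already pooled or flattened features of shape (N, features) and skips the linear layer when linear_kwargs is None. The class name `ProjectorSketch` is hypothetical, and how torchdistill selects the input module output is not reproduced.

```python
import torch
from torch import nn


class ProjectorSketch(nn.Module):
    """Hypothetical projector: optional nn.Linear followed by nn.BatchNorm1d."""

    def __init__(self, norm_kwargs, linear_kwargs=None):
        super().__init__()
        # Without linear_kwargs, features go straight to the normalization layer
        self.linear = nn.Linear(**linear_kwargs) if linear_kwargs is not None else nn.Identity()
        self.norm = nn.BatchNorm1d(**norm_kwargs)

    def forward(self, x):
        # x: (N, in_features) pooled/flattened features
        return self.norm(self.linear(x))
```

For example, a student-side projector could map 4-dimensional features to 8 dimensions with `ProjectorSketch({'num_features': 8}, {'in_features': 4, 'out_features': 8})`. A teacher-side one could only normalize with `ProjectorSketch({'num_features': 8, 'affine': False})`.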
- torchdistill.models.wrapper.build_auxiliary_model_wrapper(model_config, **kwargs)[source]
Builds an auxiliary model wrapper for either teacher or student models.
- Parameters:
model_config (dict) – configuration to build the auxiliary model wrapper. Should contain either ‘teacher_model’ or ‘student_model’.
- Returns:
auxiliary model wrapper.
- Return type:
nn.Module
torchdistill.models.util
- torchdistill.models.util.wrap_if_distributed(module, device, device_ids, distributed, find_unused_parameters=None, **kwargs)[source]
Wraps module with DistributedDataParallel if distributed = True and module has any updatable parameters.
- Parameters:
module (nn.Module) – module to be wrapped.
device (torch.device) – target device.
device_ids (list[int]) – target device IDs.
distributed (bool) – whether to be in distributed training mode.
find_unused_parameters (bool or None) – find_unused_parameters for DistributedDataParallel.
- Returns:
wrapped module if distributed = True and it contains any updatable parameters; otherwise, module itself.
- Return type:
nn.Module
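The wrapping rule can be sketched as follows. This is a minimal sketch, not torchdistill's code. It assumes "updatable" means `requires_grad=True`, moves the module to `device` first, and forwards find_unused_parameters to DistributedDataParallel only when it is not None, leaving DDP's own default otherwise. The function name `wrap_if_distributed_sketch` is hypothetical. Actually wrapping requires an initialized process group.

```python
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel


def wrap_if_distributed_sketch(module, device, device_ids, distributed,
                               find_unused_parameters=None, **kwargs):
    """Hypothetical: wrap with DDP only if distributed and something is trainable."""
    module.to(device)
    has_updatable = any(p.requires_grad for p in module.parameters())
    if distributed and has_updatable:
        ddp_kwargs = dict(kwargs)
        if find_unused_parameters is not None:
            ddp_kwargs['find_unused_parameters'] = find_unused_parameters
        # Requires torch.distributed.init_process_group(...) to have been called
        return DistributedDataParallel(module, device_ids=device_ids, **ddp_kwargs)
    # Not distributed, or fully frozen: DDP would have no gradients to sync
    return module
```

A fully frozen module (e.g., a teacher) is returned unwrapped even in distributed mode, since there are no gradients to synchronize.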
- torchdistill.models.util.load_module_ckpt(module, map_location, ckpt_file_path)[source]
Loads checkpoint for module.
- Parameters:
module (nn.Module) – module to load the checkpoint into.
map_location (torch.device or str or dict or Callable) – map_location for torch.load.
ckpt_file_path (str) – file path to load checkpoint.
- torchdistill.models.util.save_module_ckpt(module, ckpt_file_path)[source]
Saves checkpoint of module’s state dict.
- Parameters:
module (nn.Module) – module whose state dict is saved.
ckpt_file_path (str) – file path to save checkpoint.
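A state-dict round trip with torch.save and torch.load looks roughly like the sketch below. It assumes that modules wrapped in DataParallel or DistributedDataParallel should be unwrapped first, so that saved keys carry no `module.` prefix. The helper names are hypothetical and not part of torchdistill.

```python
import torch
from torch import nn
from torch.nn.parallel import DataParallel, DistributedDataParallel


def _unwrap(module: nn.Module) -> nn.Module:
    # DP/DDP keep the real model in .module; their state_dict keys gain a 'module.' prefix
    if isinstance(module, (DataParallel, DistributedDataParallel)):
        return module.module
    return module


def save_ckpt_sketch(module: nn.Module, ckpt_file_path: str) -> None:
    """Hypothetical: save only the parameters/buffers, not the module object."""
    torch.save(_unwrap(module).state_dict(), ckpt_file_path)


def load_ckpt_sketch(module: nn.Module, map_location, ckpt_file_path: str) -> None:
    """Hypothetical: load a saved state dict into module, remapping devices."""
    state_dict = torch.load(ckpt_file_path, map_location=map_location)
    _unwrap(module).load_state_dict(state_dict)
```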
- torchdistill.models.util.add_submodule(module, module_path, module_dict)[source]
Recursively adds submodules to module_dict.
- Parameters:
module (nn.Module) – module.
module_path (str) – module path.
module_dict (nn.ModuleDict or dict) – module dict.
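One plausible reading of "recursively adds" is that a dotted module_path such as `features.0` places module in nested dicts, one level per path component. The sketch below assumes this behavior, handles only plain dicts such as OrderedDict, and raises KeyError on a duplicate leaf name. The function name is hypothetical.

```python
from collections import OrderedDict

from torch import nn


def add_submodule_sketch(module, module_path, module_dict):
    """Hypothetical: store module in nested dicts following the dotted module_path."""
    head, _, rest = module_path.partition('.')
    if not rest:
        # Last path component: store the module itself
        if head in module_dict:
            raise KeyError(f'module name `{head}` is already used')
        module_dict[head] = module
        return
    # Create the nested dict for this component if needed, then recurse
    if head not in module_dict:
        module_dict[head] = OrderedDict()
    add_submodule_sketch(module, rest, module_dict[head])
```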
- torchdistill.models.util.build_sequential_container(module_dict)[source]
Builds a sequential container (nn.Sequential) from module_dict.
- Parameters:
module_dict (nn.ModuleDict or collections.OrderedDict) – module dict to build a sequential container.
- Returns:
sequential container.
- Return type:
nn.Sequential
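A sketch of the conversion, assuming nested dicts become nested nn.Sequential containers and that insertion order defines execution order. The function name is hypothetical. nn.Sequential does accept an OrderedDict of named modules.

```python
from collections import OrderedDict

import torch
from torch import nn


def build_sequential_sketch(module_dict):
    """Hypothetical: turn (nested) dicts of modules into (nested) nn.Sequential."""
    children = OrderedDict()
    for name, value in module_dict.items():
        # Recurse into nested plain dicts; keep modules as-is
        children[name] = build_sequential_sketch(value) if isinstance(value, dict) else value
    return nn.Sequential(children)
```

For example, `{'features': {'0': nn.Linear(4, 3), '1': nn.ReLU()}, 'fc': nn.Linear(3, 2)}` becomes a container whose `features` child is itself an nn.Sequential. The outer container runs `features` and then `fc`.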
- torchdistill.models.util.redesign_model(org_model, model_config, model_label, model_type='original')[source]
Redesigns org_model and returns a new separate model, e.g., by pruning some modules from org_model, freezing parameters of some modules in org_model, and adding adaptation module(s) to org_model.
Note
The parameters and states of modules in org_model will be kept in the new redesigned model.
- Parameters:
org_model (nn.Module) – original model to be redesigned.
model_config (dict) – configuration to redesign org_model.
model_label (str) – model label (e.g., ‘teacher’, ‘student’) to be printed just for debugging purposes.
model_type (str) – model type (e.g., ‘original’, name of model class, etc.) to be printed just for debugging purposes.
- Returns:
redesigned model.
- Return type:
nn.Module
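The pruning and freezing steps can be illustrated with a minimal sketch. It takes hypothetical `kept_paths` and `frozen_paths` arguments, not torchdistill's model_config keys, and keeps only the listed submodules in order inside a new nn.Sequential. The listed submodules must therefore be sequentially compatible. Parameters of the frozen paths get `requires_grad=False`. The sketch reuses the original submodule objects, so their parameters and states carry over and are shared with org_model. Adding adaptation modules is omitted.

```python
import torch
from torch import nn


def redesign_sketch(org_model, kept_paths, frozen_paths=()):
    """Hypothetical: keep listed submodules (in order) and freeze some of them."""
    new_model = nn.Sequential()
    for path in kept_paths:
        # Reuse the original submodule so its parameters/buffers carry over
        submodule = org_model.get_submodule(path)
        if path in frozen_paths:
            for param in submodule.parameters():
                param.requires_grad_(False)
        # Module names cannot contain '.', so flatten the dotted path
        new_model.add_module(path.replace('.', '_'), submodule)
    return new_model
```

For example, `redesign_sketch(nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2)), ['0', '1'], frozen_paths={'0'})` drops the last layer and freezes the first one.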