torchdistill.losses


torchdistill.losses.registry

torchdistill.losses.registry.register_low_level_loss(arg=None, **kwargs)[source]

Registers a low-level loss class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class or function to be registered as a low-level loss.

Returns:

registered low-level loss class or function to instantiate it.

Return type:

class or Callable

Note

The low-level loss will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.losses.registry import register_low_level_loss
>>>
>>> @register_low_level_loss(key='my_custom_low_level_loss')
>>> class CustomLowLevelLoss(nn.Module):
>>>     def __init__(self, **kwargs):
>>>         super().__init__()
>>>         print('This is my custom low-level loss class')

In the example, CustomLowLevelLoss class is registered with a key “my_custom_low_level_loss”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomLowLevelLoss class by “my_custom_low_level_loss”.

torchdistill.losses.registry.register_mid_level_loss(arg=None, **kwargs)[source]

Registers a middle-level loss class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class or function to be registered as a middle-level loss.

Returns:

registered middle-level loss class or function to instantiate it.

Return type:

class or Callable

Note

The middle-level loss will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.losses.registry import register_mid_level_loss
>>>
>>> @register_mid_level_loss(key='my_custom_mid_level_loss')
>>> class CustomMidLevelLoss(nn.Module):
>>>     def __init__(self, **kwargs):
>>>         super().__init__()
>>>         print('This is my custom middle-level loss class')

In the example, CustomMidLevelLoss class is registered with a key “my_custom_mid_level_loss”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomMidLevelLoss class by “my_custom_mid_level_loss”.

torchdistill.losses.registry.register_high_level_loss(arg=None, **kwargs)[source]

Registers a high-level loss class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class or function to be registered as a high-level loss.

Returns:

registered high-level loss class or function to instantiate it.

Return type:

class or Callable

Note

The high-level loss will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.losses.registry import register_high_level_loss
>>>
>>> @register_high_level_loss(key='my_custom_high_level_loss')
>>> class CustomHighLevelLoss(nn.Module):
>>>     def __init__(self, **kwargs):
>>>         super().__init__()
>>>         print('This is my custom high-level loss class')

In the example, CustomHighLevelLoss class is registered with a key “my_custom_high_level_loss”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomHighLevelLoss class by “my_custom_high_level_loss”.

torchdistill.losses.registry.register_loss_wrapper(arg=None, **kwargs)[source]

Registers a loss wrapper class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class or function to be registered as a loss wrapper.

Returns:

registered loss wrapper class or function to instantiate it.

Return type:

class or Callable

Note

The loss wrapper will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.losses.registry import register_loss_wrapper
>>>
>>> @register_loss_wrapper(key='my_custom_loss_wrapper')
>>> class CustomLossWrapper(nn.Module):
>>>     def __init__(self, **kwargs):
>>>         super().__init__()
>>>         print('This is my custom loss wrapper class')

In the example, CustomLossWrapper class is registered with a key “my_custom_loss_wrapper”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomLossWrapper class by “my_custom_loss_wrapper”.

torchdistill.losses.registry.register_func2extract_model_output(arg=None, **kwargs)[source]

Registers a function to extract model output.

Parameters:

arg (Callable or None) – function to be registered for extracting model output.

Returns:

registered function.

Return type:

Callable

Note

The function to extract model output will be registered as an option. You can choose the registered function by specifying the name of the function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.

If you want to register the function with a key of your choice, add key to the decorator as below:

>>> from torchdistill.losses.registry import register_func2extract_model_output
>>>
>>> @register_func2extract_model_output(key='my_custom_function2extract_model_output')
>>> def custom_func2extract_model_output(batch, label):
>>>     print('This is my custom function to extract model output')
>>>     return batch, label

In the example, custom_func2extract_model_output function is registered with a key “my_custom_function2extract_model_output”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the custom_func2extract_model_output function by “my_custom_function2extract_model_output”.

torchdistill.losses.registry.get_low_level_loss(key, **kwargs)[source]

Gets a registered (low-level) loss module.

Parameters:

key (str) – unique key to identify the registered loss class/function.

Returns:

low-level loss module instantiated from the registered class/function with the given kwargs.

Return type:

nn.Module
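
For example, assuming torch.nn.CrossEntropyLoss is registered under the key ‘CrossEntropyLoss’ (as suggested by the YAML examples on this page), you can fetch an instantiated loss module as below:

>>> from torchdistill.losses.registry import get_low_level_loss
>>>
>>> criterion = get_low_level_loss('CrossEntropyLoss', reduction='mean')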

torchdistill.losses.registry.get_mid_level_loss(mid_level_criterion_config, criterion_wrapper_config=None)[source]

Gets a registered middle-level loss module.

Parameters:
  • mid_level_criterion_config (dict) – middle-level loss configuration to identify and instantiate the registered middle-level loss class.

  • criterion_wrapper_config (dict or None) – loss wrapper configuration to identify and instantiate the registered loss wrapper class.

Returns:

middle-level loss module, wrapped by the specified loss wrapper if criterion_wrapper_config is given.

Return type:

nn.Module
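
For example, the following sketch builds a KDLoss instance (documented under torchdistill.losses.mid_level below) from a Python dict mirroring the key/kwargs layout of the YAML criterion blocks on this page; the module paths ('.') and hyperparameter values are placeholders, not recommended settings.

>>> from torchdistill.losses.registry import get_mid_level_loss
>>>
>>> mid_level_criterion_config = {
>>>     'key': 'KDLoss',
>>>     'kwargs': {'student_module_path': '.', 'student_module_io': 'output',
>>>                'teacher_module_path': '.', 'teacher_module_io': 'output',
>>>                'temperature': 4.0, 'alpha': 0.5, 'reduction': 'batchmean'}
>>> }
>>> criterion = get_mid_level_loss(mid_level_criterion_config)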

torchdistill.losses.registry.get_high_level_loss(criterion_config)[source]

Gets a registered high-level loss module.

Parameters:

criterion_config (dict) – high-level loss configuration to identify and instantiate the registered high-level loss class.

Returns:

high-level loss module instantiated from the registered class/function.

Return type:

nn.Module
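
For example, the following sketch instantiates WeightedSumLoss (documented under torchdistill.losses.high_level below) with one cross-entropy sub-term, assuming criterion_config follows the same key/kwargs layout as the YAML examples on this page:

>>> from torchdistill.losses.registry import get_high_level_loss
>>>
>>> ce_sub_term = {
>>>     'criterion': {'key': 'CrossEntropyLoss', 'kwargs': {'reduction': 'mean'}},
>>>     'criterion_wrapper': {'key': 'SimpleLossWrapper', 'kwargs': {
>>>         'input': {'is_from_teacher': False, 'module_path': '.', 'io': 'output'},
>>>         'target': {'uses_label': True}}},
>>>     'weight': 1.0
>>> }
>>> criterion_config = {'key': 'WeightedSumLoss',
>>>                     'kwargs': {'model_term': {'weight': 1.0}, 'sub_terms': {'ce': ce_sub_term}}}
>>> criterion = get_high_level_loss(criterion_config)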

torchdistill.losses.registry.get_loss_wrapper(mid_level_loss, criterion_wrapper_config)[source]

Gets a registered loss wrapper module.

Parameters:
  • mid_level_loss (nn.Module) – middle-level loss module.

  • criterion_wrapper_config (dict) – loss wrapper configuration to identify and instantiate the registered loss wrapper class.

Returns:

loss wrapper module wrapping mid_level_loss.

Return type:

nn.Module

torchdistill.losses.registry.get_func2extract_model_output(key)[source]

Gets a registered function to extract model output.

Parameters:

key (str) – unique key to identify the registered function to extract model output.

Returns:

registered function to extract model output.

Return type:

Callable
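
For example, the following sketch retrieves a registered extraction function, assuming torchdistill.losses.util.extract_model_loss_dict (documented at the bottom of this page) is registered under its function name:

>>> from torchdistill.losses.registry import get_func2extract_model_output
>>>
>>> extract_func = get_func2extract_model_output('extract_model_loss_dict')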


torchdistill.losses.high_level

class torchdistill.losses.high_level.AbstractLoss(sub_terms=None, **kwargs)[source]

An abstract loss module.

forward() and __str__() should be overridden by all subclasses.

Parameters:

sub_terms (dict or None) – loss module configurations.

An example YAML of sub_terms.
 sub_terms:
   ce:
     criterion:
       key: 'CrossEntropyLoss'
       kwargs:
         reduction: 'mean'
     criterion_wrapper:
       key: 'SimpleLossWrapper'
       kwargs:
         input:
           is_from_teacher: False
           module_path: '.'
           io: 'output'
         target:
           uses_label: True
     weight: 1.0
class torchdistill.losses.high_level.WeightedSumLoss(model_term=None, sub_terms=None, **kwargs)[source]

A weighted sum (linear combination) of mid-/low-level loss modules.

If model_term contains a numerical value under its weight key, it is used as the multiplier \(W_{model}\) for the sum of model-driven loss values \(\sum_{i} L_{model, i}\).

\[L_{total} = W_{model} \cdot (\sum_{i} L_{model, i}) + \sum_{k} W_{sub, k} \cdot L_{sub, k}\]
Parameters:
  • model_term (dict or None) – model-driven loss module configurations.

  • sub_terms (dict or None) – loss module configurations.
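
For instance (an illustrative calculation, not taken from the library), with \(W_{model} = 1.0\), a single model-driven loss value \(L_{model, 1} = 0.5\), and one sub-term with \(W_{sub, ce} = 0.1\) and \(L_{sub, ce} = 2.0\):

\[L_{total} = 1.0 \cdot 0.5 + 0.1 \cdot 2.0 = 0.7\]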


torchdistill.losses.mid_level

class torchdistill.losses.mid_level.SimpleLossWrapper(low_level_loss, **kwargs)[source]

A simple loss wrapper module designed to use low-level loss modules (e.g., loss modules in PyTorch) in torchdistill’s pipelines.

Parameters:
  • low_level_loss (nn.Module) – low-level loss module e.g., torch.nn.CrossEntropyLoss.

  • kwargs (dict or None) – kwargs to configure what the wrapper passes to low_level_loss.

An example YAML to instantiate SimpleLossWrapper.
 criterion_wrapper:
   key: 'SimpleLossWrapper'
   kwargs:
     input:
       is_from_teacher: False
       module_path: '.'
       io: 'output'
     target:
       uses_label: True
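Equivalently, a hedged sketch of instantiating SimpleLossWrapper directly in Python, mirroring the YAML kwargs above (the input/target keyword names follow that example).
 from torch import nn
 from torchdistill.losses.mid_level import SimpleLossWrapper

 # Wrap a low-level PyTorch loss; 'input'/'target' mirror the YAML kwargs above.
 criterion_wrapper = SimpleLossWrapper(
     nn.CrossEntropyLoss(reduction='mean'),
     input={'is_from_teacher': False, 'module_path': '.', 'io': 'output'},
     target={'uses_label': True}
 )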
class torchdistill.losses.mid_level.DictLossWrapper(low_level_loss, weights, **kwargs)[source]

A dict-based wrapper module designed to use low-level loss modules (e.g., loss modules in PyTorch) in torchdistill’s pipelines. This is a subclass of SimpleLossWrapper and useful for models whose forward output is dict.

Parameters:
  • low_level_loss (nn.Module) – low-level loss module e.g., torch.nn.CrossEntropyLoss.

  • weights (dict) – dict whose keys match the model’s output dict keys, mapped to the corresponding loss weights.

  • kwargs (dict or None) – kwargs to configure what the wrapper passes to low_level_loss.

An example YAML to instantiate DictLossWrapper for deeplabv3_resnet50 in torchvision, whose default output is a dict of outputs from its main and auxiliary branches with keys ‘out’ and ‘aux’ respectively.
 criterion_wrapper:
   key: 'DictLossWrapper'
   kwargs:
     input:
       is_from_teacher: False
       module_path: '.'
       io: 'output'
     target:
       uses_label: True
     weights:
       out: 1.0
       aux: 0.5
class torchdistill.losses.mid_level.KDLoss(student_module_path, student_module_io, teacher_module_path, teacher_module_io, temperature, alpha=None, beta=None, reduction='batchmean', **kwargs)[source]

A standard knowledge distillation (KD) loss module.

\[L_{KD} = \alpha \cdot L_{CE} + (1 - \alpha) \cdot \tau^2 \cdot L_{KL}\]

Geoffrey Hinton, Oriol Vinyals, Jeff Dean: “Distilling the Knowledge in a Neural Network” @ NIPS 2014 Deep Learning and Representation Learning Workshop (2014)

Parameters:
  • student_module_path (str) – student model’s logit module path.

  • student_module_io (str) – ‘input’ or ‘output’ of the module in the student model.

  • teacher_module_path (str) – teacher model’s logit module path.

  • teacher_module_io (str) – ‘input’ or ‘output’ of the module in the teacher model.

  • temperature (float) – hyperparameter \(\tau\) to soften class-probability distributions.

  • alpha (float) – balancing factor for \(L_{CE}\), cross-entropy.

  • beta (float or None) – balancing factor (default: \(1 - \alpha\)) for \(L_{KL}\), KL divergence between class-probability distributions softened by \(\tau\).

  • reduction (str or None) – reduction for KLDivLoss. If reduction = ‘batchmean’, CrossEntropyLoss’s reduction will be ‘mean’.
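
As a reference for the equation above, the following self-contained sketch computes the same objective from raw logits; it illustrates the formula only and is not the internal implementation of KDLoss.
 import torch
 import torch.nn.functional as F

 def kd_objective(student_logits, teacher_logits, targets, temperature=4.0, alpha=0.5):
     # Hard-label cross-entropy term.
     ce_loss = F.cross_entropy(student_logits, targets)
     # KL divergence between temperature-softened distributions.
     kl_loss = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),
                        F.softmax(teacher_logits / temperature, dim=1),
                        reduction='batchmean')
     return alpha * ce_loss + (1 - alpha) * (temperature ** 2) * kl_loss

 # Usage with random tensors (batch of 8, 100 classes).
 loss = kd_objective(torch.randn(8, 100), torch.randn(8, 100), torch.randint(0, 100, (8,)))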

class torchdistill.losses.mid_level.FSPLoss(fsp_pairs, **kwargs)[source]

A loss module for the flow of solution procedure (FSP) matrix.

Junho Yim, Donggyu Joo, Jihoon Bae, Junmo Kim: “A Gift From Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning” @ CVPR 2017 (2017)

Parameters:

fsp_pairs (dict) – configuration of teacher-student module pairs to compute the loss for the FSP matrix.

An example YAML to instantiate FSPLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision.
 criterion:
   key: 'FSPLoss'
   kwargs:
     fsp_pairs:
       pair1:
         teacher_first:
           io: 'input'
           path: 'layer1'
         teacher_second:
           io: 'output'
           path: 'layer1'
         student_first:
           io: 'input'
           path: 'layer1'
         student_second:
           io: 'output'
           path: 'layer1'
         weight: 1
       pair2:
         teacher_first:
           io: 'input'
           path: 'layer2.1'
         teacher_second:
           io: 'output'
           path: 'layer2'
         student_first:
           io: 'input'
           path: 'layer2.1'
         student_second:
           io: 'output'
           path: 'layer2'
         weight: 1
class torchdistill.losses.mid_level.ATLoss(at_pairs, mode='code', **kwargs)[source]

A loss module for attention transfer (AT). Referred to https://github.com/szagoruyko/attention-transfer/blob/master/utils.py

Sergey Zagoruyko, Nikos Komodakis: “Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer” @ ICLR 2017 (2017)

Parameters:
  • at_pairs (dict) – configuration of teacher-student module pairs to compute the loss for attention transfer.

  • mode (str) – reference to follow: ‘paper’ or ‘code’.

Warning

There is a discrepancy between Eq. (2) in the paper and the authors’ implementation as pointed out in a paper and an issue at the repository. Use mode = ‘paper’ instead of ‘code’ if you want to follow the equations in the paper.

An example YAML to instantiate ATLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision.
 criterion:
   key: 'ATLoss'
   kwargs:
     at_pairs:
       pair1:
         teacher:
           io: 'output'
           path: 'layer3'
         student:
           io: 'output'
           path: 'layer3'
         weight: 1
       pair2:
         teacher:
           io: 'output'
           path: 'layer4'
         student:
           io: 'output'
           path: 'layer4'
         weight: 1
     mode: 'code'
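For reference, a minimal sketch of the attention-map distance used in the authors’ code (mode = ‘code’): the channel-wise mean of squared activations is flattened, L2-normalized per sample, and compared with a mean squared error. It only illustrates the computed quantity, not the ATLoss module itself.
 import torch
 import torch.nn.functional as F

 def attention_map(feature_map):
     # Channel-wise mean of squared activations, flattened and L2-normalized per sample.
     return F.normalize(feature_map.pow(2).mean(dim=1).flatten(start_dim=1), dim=1)

 def at_distance(student_feature, teacher_feature):
     return (attention_map(student_feature) - attention_map(teacher_feature)).pow(2).mean()

 # Feature maps only need matching spatial sizes; channel counts may differ.
 loss = at_distance(torch.randn(8, 64, 28, 28), torch.randn(8, 256, 28, 28))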
class torchdistill.losses.mid_level.PKTLoss(student_module_path, student_module_io, teacher_module_path, teacher_module_io, eps=1e-07)[source]

A loss module for probabilistic knowledge transfer (PKT). Refactored https://github.com/passalis/probabilistic_kt/blob/master/nn/pkt.py

Nikolaos Passalis, Anastasios Tefas: “Learning Deep Representations with Probabilistic Knowledge Transfer” @ ECCV 2018 (2018)

Parameters:
  • student_module_path (str) – student model’s logit module path.

  • student_module_io (str) – ‘input’ or ‘output’ of the module in the student model.

  • teacher_module_path (str) – teacher model’s logit module path.

  • teacher_module_io (str) – ‘input’ or ‘output’ of the module in the teacher model.

  • eps (float) – constant to avoid zero division.

An example YAML to instantiate PKTLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision.
 criterion:
   key: 'PKTLoss'
   kwargs:
     student_module_path: 'fc'
     student_module_io: 'input'
     teacher_module_path: 'fc'
     teacher_module_io: 'input'
     eps: 0.0000001
class torchdistill.losses.mid_level.FTLoss(p=1, reduction='mean', paraphraser_path='paraphraser', translator_path='translator', **kwargs)[source]

A loss module for factor transfer (FT). This loss module is used at the 2nd stage of FT method.

Jangho Kim, Seonguk Park, Nojun Kwak: “Paraphrasing Complex Network: Network Compression via Factor Transfer” @ NeurIPS 2018 (2018)

Parameters:
  • p (int) – the order of norm.

  • reduction (str) – loss reduction type.

  • paraphraser_path (str) – teacher model’s paraphrase module path.

  • translator_path (str) – student model’s translator module path.

An example YAML to instantiate FTLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision, using auxiliary modules torchdistill.models.wrapper.Teacher4FactorTransfer and torchdistill.models.wrapper.Student4FactorTransfer.
 criterion:
   key: 'FTLoss'
   kwargs:
     p: 1
     reduction: 'mean'
     paraphraser_path: 'paraphraser'
     translator_path: 'translator'
class torchdistill.losses.mid_level.AltActTransferLoss(feature_pairs, margin, reduction, **kwargs)[source]

A loss module for distillation of activation boundaries (DAB). Refactored https://github.com/bhheo/AB_distillation/blob/master/cifar10_AB_distillation.py

Byeongho Heo, Minsik Lee, Sangdoo Yun, Jin Young Choi: “Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons” @ AAAI 2019 (2019)

Parameters:
  • feature_pairs (dict) – configuration of teacher-student module pairs to compute the loss for distillation of activation boundaries.

  • margin (float) – margin.

  • reduction (str) – loss reduction type.

An example YAML to instantiate AltActTransferLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.Connector4DAB.
 criterion:
   key: 'AltActTransferLoss'
   kwargs:
     feature_pairs:
       pair1:
         teacher:
           io: 'output'
           path: 'layer1'
         student:
           io: 'output'
           path: 'connector_dict.connector1'
         weight: 1
       pair2:
         teacher:
           io: 'output'
           path: 'layer2'
         student:
           io: 'output'
           path: 'connector_dict.connector2'
         weight: 1
       pair3:
         teacher:
           io: 'output'
           path: 'layer3'
         student:
           io: 'output'
           path: 'connector_dict.connector3'
         weight: 1
       pair4:
         teacher:
           io: 'output'
           path: 'layer4'
         student:
           io: 'output'
           path: 'connector_dict.connector4'
         weight: 1
     margin: 1.0
     reduction: 'mean'
class torchdistill.losses.mid_level.RKDLoss(student_output_path, teacher_output_path, dist_factor, angle_factor, reduction, **kwargs)[source]

A loss module for relational knowledge distillation (RKD). Refactored https://github.com/lenscloth/RKD/blob/master/metric/loss.py

Wonpyo Park, Dongju Kim, Yan Lu, Minsu Cho: “Relational Knowledge Distillation” @ CVPR 2019 (2019)

Parameters:
  • student_output_path (str) – student module path whose output is used in this loss module.

  • teacher_output_path (str) – teacher module path whose output is used in this loss module.

  • dist_factor (float) – weight on distance-based RKD loss.

  • angle_factor (float) – weight on angle-based RKD loss.

  • reduction (str) – reduction for SmoothL1Loss.

An example YAML to instantiate RKDLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision.
 criterion:
   key: 'RKDLoss'
   kwargs:
     teacher_output_path: 'layer4'
     student_output_path: 'layer4'
     dist_factor: 1.0
     angle_factor: 2.0
     reduction: 'mean'
class torchdistill.losses.mid_level.VIDLoss(feature_pairs, **kwargs)[source]

A loss module for variational information distillation (VID). Referred to https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/VID.py

Sungsoo Ahn, Shell Xu Hu, Andreas Damianou, Neil D. Lawrence, Zhenwen Dai: “Variational Information Distillation for Knowledge Transfer” @ CVPR 2019 (2019)

Parameters:

feature_pairs (dict) – configuration of teacher-student module pairs to compute the loss for variational information distillation.

An example YAML to instantiate VIDLoss for a teacher-student pair of ResNet-50 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.VariationalDistributor4VID for the student model.
 criterion:
   key: 'VIDLoss'
   kwargs:
     feature_pairs:
       pair1:
         teacher:
           io: 'output'
           path: 'layer1'
         student:
           io: 'output'
           path: 'regressor_dict.regressor1'
         weight: 1
       pair2:
         teacher:
           io: 'output'
           path: 'layer2'
         student:
           io: 'output'
           path: 'regressor_dict.regressor2'
         weight: 1
       pair3:
         teacher:
           io: 'output'
           path: 'layer3'
         student:
           io: 'output'
           path: 'regressor_dict.regressor3'
         weight: 1
       pair4:
         teacher:
           io: 'output'
           path: 'layer4'
         student:
           io: 'output'
           path: 'regressor_dict.regressor4'
         weight: 1
     margin: 1.0
class torchdistill.losses.mid_level.CCKDLoss(student_linear_path, teacher_linear_path, kernel_config, reduction, **kwargs)[source]

A loss module for correlation congruence for knowledge distillation (CCKD).

Baoyun Peng, Xiao Jin, Jiaheng Liu, Dongsheng Li, Yichao Wu, Yu Liu, Shunfeng Zhou, Zhaoning Zhang: “Correlation Congruence for Knowledge Distillation” @ ICCV 2019 (2019)

Parameters:
  • student_linear_path (str) – student model’s linear module path in an auxiliary wrapper torchdistill.models.wrapper.Linear4CCKD.

  • teacher_linear_path (str) – teacher model’s linear module path in an auxiliary wrapper torchdistill.models.wrapper.Linear4CCKD.

  • kernel_config (dict) – kernel (‘gaussian’ or ‘bilinear’) configuration.

  • reduction (str) – loss reduction type.

An example YAML to instantiate CCKDLoss for a teacher-student pair of ResNet-50 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.Linear4CCKD for the teacher and student models.
 criterion:
   key: 'CCKDLoss'
   kwargs:
     teacher_linear_path: 'linear'
     student_linear_path: 'linear'
     kernel_config:
       key: 'gaussian'
       gamma: 0.4
       max_p: 2
     reduction: 'batchmean'
class torchdistill.losses.mid_level.SPKDLoss(student_output_path, teacher_output_path, reduction, **kwargs)[source]

A loss module for similarity-preserving knowledge distillation (SPKD).

Frederick Tung, Greg Mori: “Similarity-Preserving Knowledge Distillation” @ ICCV2019 (2019)

Parameters:
  • student_output_path (str) – student module path whose output is used in this loss module.

  • teacher_output_path (str) – teacher module path whose output is used in this loss module.

  • reduction (str) – loss reduction type.

An example YAML to instantiate SPKDLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision.
 criterion:
   key: 'SPKDLoss'
   kwargs:
     teacher_output_path: 'layer4'
     student_output_path: 'layer4'
     reduction: 'batchmean'
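For reference, a minimal sketch of the similarity-preserving objective from the paper: row-normalized batch similarity (Gram) matrices of student and teacher features are compared with a squared Frobenius norm scaled by the squared batch size. It illustrates the computed quantity, not the SPKDLoss internals.
 import torch
 import torch.nn.functional as F

 def similarity_matrix(features):
     flat = features.flatten(start_dim=1)             # (batch_size, num_features)
     return F.normalize(flat @ flat.t(), p=2, dim=1)  # row-normalized Gram matrix

 def spkd_objective(student_features, teacher_features):
     batch_size = student_features.size(0)
     diff = similarity_matrix(student_features) - similarity_matrix(teacher_features)
     return diff.pow(2).sum() / (batch_size ** 2)

 loss = spkd_objective(torch.randn(8, 512, 7, 7), torch.randn(8, 512, 7, 7))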
class torchdistill.losses.mid_level.CRDLoss(student_norm_module_path, student_empty_module_path, teacher_norm_module_path, input_size, output_size, num_negative_samples, num_samples, temperature=0.07, momentum=0.5, eps=1e-07)[source]

A loss module for contrastive representation distillation (CRD). Refactored https://github.com/HobbitLong/RepDistiller/blob/master/crd/criterion.py

Yonglong Tian, Dilip Krishnan, Phillip Isola: “Contrastive Representation Distillation” @ ICLR 2020 (2020)

Parameters:
  • student_norm_module_path (str) – student model’s normalizer module path in an auxiliary wrapper torchdistill.models.wrapper.Linear4CRD.

  • student_empty_module_path (str) – student model’s empty module path in an auxiliary wrapper torchdistill.models.wrapper.Linear4CRD.

  • teacher_norm_module_path (str) – teacher model’s normalizer module path in an auxiliary wrapper torchdistill.models.wrapper.Linear4CRD.

  • input_size (int) – dimension of the feature embeddings used for the contrastive objective.

  • output_size (int) – number of training samples (num_samples in the example below).

  • num_negative_samples (int) – number of negative samples drawn for the contrastive objective.

  • num_samples (int) – number of training samples.

  • temperature (float) – temperature to adjust the concentration level.

  • momentum (float) – momentum for updating the memory buffer.

  • eps (float) – constant to avoid zero division.

An example YAML to instantiate CRDLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.Linear4CRD for the teacher and student models.
 criterion:
   key: 'CRDLoss'
   kwargs:
     teacher_norm_module_path: 'normalizer'
     student_norm_module_path: 'normalizer'
     student_empty_module_path: 'empty'
     input_size: *feature_dim
     output_size: &num_samples 1281167
     num_negative_samples: *num_negative_samples
     num_samples: *num_samples
     temperature: 0.07
     momentum: 0.5
     eps: 0.0000001
class torchdistill.losses.mid_level.AuxSSKDLoss(module_path='ss_module', module_io='output', reduction='mean', **kwargs)[source]

A loss module for self-supervision knowledge distillation (SSKD) that treats contrastive prediction as a self-supervision task (auxiliary task). This loss module is used at the 1st stage of SSKD method. Refactored https://github.com/xuguodong03/SSKD/blob/master/student.py

Guodong Xu, Ziwei Liu, Xiaoxiao Li, Chen Change Loy: “Knowledge Distillation Meets Self-Supervision” @ ECCV 2020 (2020)

Parameters:
  • module_path (str) – model’s self-supervision module path.

  • module_io (str) – ‘input’ or ‘output’ of the module in the model.

  • reduction (str) – reduction for CrossEntropyLoss.

An example YAML to instantiate AuxSSKDLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.SSWrapper4SSKD for teacher model.
 criterion:
   key: 'AuxSSKDLoss'
   kwargs:
     module_path: 'ss_module'
     module_io: 'output'
     reduction: 'mean'
class torchdistill.losses.mid_level.SSKDLoss(student_linear_module_path, teacher_linear_module_path, student_ss_module_path, teacher_ss_module_path, kl_temp, ss_temp, tf_temp, ss_ratio, tf_ratio, student_linear_module_io='output', teacher_linear_module_io='output', student_ss_module_io='output', teacher_ss_module_io='output', loss_weights=None, reduction='batchmean', **kwargs)[source]

A loss module for self-supervision knowledge distillation (SSKD). This loss module is used at the 2nd stage of SSKD method. Refactored https://github.com/xuguodong03/SSKD/blob/master/student.py

Guodong Xu, Ziwei Liu, Xiaoxiao Li, Chen Change Loy: “Knowledge Distillation Meets Self-Supervision” @ ECCV 2020 (2020)

Parameters:
  • student_linear_module_path (str) – student model’s linear module path in an auxiliary wrapper torchdistill.models.wrapper.SSWrapper4SSKD.

  • teacher_linear_module_path (str) – teacher model’s linear module path in an auxiliary wrapper torchdistill.models.wrapper.SSWrapper4SSKD.

  • student_ss_module_path (str) – student model’s self-supervision module path.

  • teacher_ss_module_path (str) – teacher model’s self-supervision module path.

  • kl_temp (float) – temperature to soften teacher and student’s class-probability distributions for KL divergence given original data.

  • ss_temp (float) – temperature to soften teacher and student’s self-supervision cosine similarities for KL divergence.

  • tf_temp (float) – temperature to soften teacher and student’s class-probability distributions for KL divergence given augmented data by transform.

  • ss_ratio (float) – ratio of samples with the smallest error levels used for self-supervision.

  • tf_ratio (float) – ratio of samples with the smallest error levels used for transform.

  • student_linear_module_io (str) – ‘input’ or ‘output’ of the linear module in the student model.

  • teacher_linear_module_io (str) – ‘input’ or ‘output’ of the linear module in the teacher model.

  • student_ss_module_io (str) – ‘input’ or ‘output’ of the self-supervision module in the student model.

  • teacher_ss_module_io (str) – ‘input’ or ‘output’ of the self-supervision module in the teacher model.

  • loss_weights (list[float] or None) – weights for 1) cross-entropy, 2) KL divergence for the original data, 3) KL divergence for self-supervision cosine similarities, and 4) KL divergence for the augmented data by transform.

  • reduction (str or None) – reduction for KLDivLoss. If reduction = ‘batchmean’, CrossEntropyLoss’s reduction will be ‘mean’.

An example YAML to instantiate SSKDLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.SSWrapper4SSKD for the teacher and student models.
 criterion:
   key: 'SSKDLoss'
   kwargs:
     student_linear_module_path: 'model.fc'
     teacher_linear_module_path: 'model.fc'
     student_ss_module_path: 'ss_module'
     teacher_ss_module_path: 'ss_module'
     kl_temp: 4.0
     ss_temp: 0.5
     tf_temp: 4.0
     ss_ratio: 0.75
     tf_ratio: 1.0
     loss_weights: [1.0, 0.9, 10.0, 2.7]
     reduction: 'batchmean'
class torchdistill.losses.mid_level.PADL2Loss(student_embed_module_path, teacher_embed_module_path, student_embed_module_io='output', teacher_embed_module_io='output', module_path='var_estimator', module_io='output', eps=1e-06, reduction='mean', **kwargs)[source]

A loss module for prime-aware adaptive distillation (PAD) with L2 loss. This loss module is used at the 2nd stage of PAD method.

Youcai Zhang, Zhonghao Lan, Yuchen Dai, Fangao Zeng, Yan Bai, Jie Chang, Yichen Wei: “Prime-Aware Adaptive Distillation” @ ECCV 2020 (2020)

Parameters:
  • student_embed_module_path (str) – student model’s embedding module path in an auxiliary wrapper torchdistill.models.wrapper.VarianceBranch4PAD.

  • teacher_embed_module_path (str) – teacher model’s embedding module path.

  • student_embed_module_io (str) – ‘input’ or ‘output’ of the embedding module in the student model.

  • teacher_embed_module_io (str) – ‘input’ or ‘output’ of the embedding module in the teacher model.

  • module_path (str) – student model’s variance estimator module path in an auxiliary wrapper torchdistill.models.wrapper.VarianceBranch4PAD.

  • module_io (str) – ‘input’ or ‘output’ of the variance estimator module in the student model.

  • eps (float) – constant to avoid zero division.

  • reduction (str) – loss reduction type.

An example YAML to instantiate PADL2Loss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.VarianceBranch4PAD for the student model.
 criterion:
   key: 'PADL2Loss'
   kwargs:
     student_embed_module_path: 'student_model.avgpool'
     student_embed_module_io: 'output'
     teacher_embed_module_path: 'avgpool'
     teacher_embed_module_io: 'output'
     module_path: 'var_estimator'
     module_io: 'output'
     eps: 0.000001
     reduction: 'mean'
class torchdistill.losses.mid_level.HierarchicalContextLoss(student_module_path, student_module_io, teacher_module_path, teacher_module_io, reduction='mean', output_sizes=None, **kwargs)[source]

A loss module for knowledge review (KR) method. Referred to https://github.com/dvlab-research/ReviewKD/blob/master/ImageNet/models/reviewkd.py

Pengguang Chen, Shu Liu, Hengshuang Zhao, Jiaya Jia: “Distilling Knowledge via Knowledge Review” @ CVPR 2021 (2021)

Parameters:
  • student_module_path (str) – student model’s module path in an auxiliary wrapper torchdistill.models.wrapper.Student4KnowledgeReview.

  • student_module_io (str) – ‘input’ or ‘output’ of the module in the student model.

  • teacher_module_path (str) – teacher model’s module path.

  • teacher_module_io (str) – ‘input’ or ‘output’ of the module in the teacher model.

  • reduction (str or None) – reduction for MSELoss.

  • output_sizes (list[int] or None) – output sizes of adaptive_avg_pool2d.

An example YAML to instantiate HierarchicalContextLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.Student4KnowledgeReview for the student model.
 criterion:
   key: 'HierarchicalContextLoss'
   kwargs:
     student_module_path: 'abf_modules.4'
     student_module_io: 'output'
     teacher_module_path: 'layer1.-1.relu'
     teacher_module_io: 'input'
     reduction: 'mean'
     output_sizes: [4, 2, 1]
class torchdistill.losses.mid_level.RegularizationLoss(module_path, io_type='output', is_from_teacher=False, p=1, **kwargs)[source]

A regularization loss module.

Parameters:
  • module_path (str) – module path.

  • io_type (str) – ‘input’ or ‘output’ of the module.

  • is_from_teacher (bool) – True if you use teacher’s I/O dict. Otherwise, you use student’s I/O dict.

  • p (int) – the order of norm.
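
For illustration only, a sketch of a p-norm regularization term computed on a module’s stored output; the io_dict access pattern below is an assumption for illustration, not the module’s actual implementation.
 import torch

 def p_norm_regularization(io_dict, module_path, io_type='output', p=1):
     # Pick the stored tensor for the given module and I/O type, then take its p-norm.
     value = io_dict[module_path][io_type]
     return torch.norm(value, p=p)

 io_dict = {'fc': {'output': torch.randn(8, 100)}}
 reg_loss = p_norm_regularization(io_dict, 'fc', 'output', p=1)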

class torchdistill.losses.mid_level.KTALoss(p=1, q=2, reduction='mean', knowledge_translator_path='paraphraser', feature_adapter_path='feature_adapter', **kwargs)[source]

A loss module for knowledge translation and adaptation (KTA). This loss module is used at the 2nd stage of KTAAD method.

Tong He, Chunhua Shen, Zhi Tian, Dong Gong, Changming Sun, Youliang Yan.: “Knowledge Adaptation for Efficient Semantic Segmentation” @ CVPR 2019 (2019)

Parameters:
  • p (int) – the order of norm for differences between normalized feature adapter’s (flattened) output and knowledge translator’s (flattened) output.

  • q (int) – the order of norm for the denominator to normalize feature adapter (flattened) output.

  • reduction (str) – loss reduction type.

  • knowledge_translator_path (str) – knowledge translator module path.

  • feature_adapter_path (str) – feature adapter module path.

An example YAML to instantiate KTALoss for a teacher-student pair of DeepLabv3 with ResNet50 and LRASPP with MobileNet v3 (Large) in torchvision, using an auxiliary module torchdistill.models.wrapper.Teacher4FactorTransfer and torchdistill.models.wrapper.Student4KTAAD for the teacher and student models.
 criterion:
   key: 'KTALoss'
   kwargs:
     p: 1
     q: 2
     reduction: 'mean'
     knowledge_translator_path: 'paraphraser.encoder'
     feature_adapter_path: 'feature_adapter'
class torchdistill.losses.mid_level.AffinityLoss(student_module_path, teacher_module_path, student_module_io='output', teacher_module_io='output', reduction='mean', **kwargs)[source]

A loss module for affinity distillation in KTA. This loss module is used at the 2nd stage of KTAAD method.

Tong He, Chunhua Shen, Zhi Tian, Dong Gong, Changming Sun, Youliang Yan.: “Knowledge Adaptation for Efficient Semantic Segmentation” @ CVPR 2019 (2019)

Parameters:
  • student_module_path (str) – student model’s module path in an auxiliary wrapper torchdistill.models.wrapper.Student4KTAAD.

  • student_module_io (str) – ‘input’ or ‘output’ of the module in the student model.

  • teacher_module_path (str) – teacher model’s module path in an auxiliary wrapper torchdistill.models.wrapper.Teacher4FactorTransfer.

  • teacher_module_io (str) – ‘input’ or ‘output’ of the module in the teacher model.

  • reduction (str or None) – loss reduction type.

An example YAML to instantiate AffinityLoss for a teacher-student pair of DeepLabv3 with ResNet50 and LRASPP with MobileNet v3 (Large) in torchvision, using an auxiliary module torchdistill.models.wrapper.Teacher4FactorTransfer and torchdistill.models.wrapper.Student4KTAAD for the teacher and student models.
 criterion:
   key: 'AffinityLoss'
   kwargs:
     student_module_path: 'affinity_adapter'
     student_module_io: 'output'
     teacher_module_path: 'paraphraser.encoder'
     teacher_module_io: 'output'
     reduction: 'mean'
class torchdistill.losses.mid_level.ChSimLoss(feature_pairs, **kwargs)[source]

A loss module for Inter-Channel Correlation for Knowledge Distillation (ICKD). Refactored https://github.com/ADLab-AutoDrive/ICKD/blob/main/ImageNet/torchdistill/losses/single.py

Li Liu, Qingle Huang, Sihao Lin, Hongwei Xie, Bing Wang, Xiaojun Chang, Xiaodan Liang: “Inter-Channel Correlation for Knowledge Distillation” @ ICCV 2021 (2021)

Parameters:

feature_pairs (dict) – configuration of teacher-student module pairs to compute the L2 distance between the inter-channel correlation matrices of the student and the teacher.

An example YAML to instantiate ChSimLoss for a teacher-student pair of ResNet-34 and ResNet-18 in torchvision, using an auxiliary module torchdistill.models.wrapper.Student4ICKD.
 criterion:
   key: 'ChSimLoss'
   kwargs:
     feature_pairs:
       pair1:
         teacher:
           io: 'output'
           path: 'layer4'
         student:
           io: 'output'
           path: 'embed_dict.embed1'
         weight: 1
class torchdistill.losses.mid_level.DISTLoss(student_module_path, student_module_io, teacher_module_path, teacher_module_io, beta=1.0, gamma=1.0, tau=1.0, eps=1e-08, **kwargs)[source]

A loss module for Knowledge Distillation from A Stronger Teacher (DIST). Referred to https://github.com/hunto/image_classification_sota/blob/main/lib/models/losses/dist_kd.py

Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu: “Knowledge Distillation from A Stronger Teacher” @ NeurIPS 2022 (2022)

Parameters:
  • student_module_path (str) – student model’s logit module path.

  • student_module_io (str) – ‘input’ or ‘output’ of the module in the student model.

  • teacher_module_path (str) – teacher model’s logit module path.

  • teacher_module_io (str) – ‘input’ or ‘output’ of the module in the teacher model.

  • beta (float) – balancing factor for inter-loss.

  • gamma (float) – balancing factor for intra-loss.

  • tau (float) – hyperparameter \(\tau\) to soften class-probability distributions.
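
For reference, a minimal sketch of the inter-/intra-class relation terms described in the DIST paper (Pearson correlations over temperature-softened probability rows and columns); scaling details may differ from the registered DISTLoss module.
 import torch
 import torch.nn.functional as F

 def pearson_distance(a, b, eps=1e-08):
     # 1 - Pearson correlation along the last dimension, averaged over the remaining one.
     a = a - a.mean(dim=-1, keepdim=True)
     b = b - b.mean(dim=-1, keepdim=True)
     return 1.0 - F.cosine_similarity(a, b, dim=-1, eps=eps).mean()

 def dist_objective(student_logits, teacher_logits, beta=1.0, gamma=1.0, tau=1.0):
     p_s = F.softmax(student_logits / tau, dim=1)
     p_t = F.softmax(teacher_logits / tau, dim=1)
     inter_loss = pearson_distance(p_s, p_t)          # per-sample (row-wise) relations
     intra_loss = pearson_distance(p_s.t(), p_t.t())  # per-class (column-wise) relations
     return beta * inter_loss + gamma * intra_loss

 loss = dist_objective(torch.randn(8, 100), torch.randn(8, 100))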

class torchdistill.losses.mid_level.SRDLoss(student_feature_module_path, student_feature_module_io, teacher_feature_module_path, teacher_feature_module_io, student_linear_module_path, student_linear_module_io, teacher_linear_module_path, teacher_linear_module_io, exponent=1.0, temperature=1.0, reduction='batchmean', **kwargs)[source]

A loss module for Understanding the Role of the Projector in Knowledge Distillation. Referred to https://github.com/roymiles/Simple-Recipe-Distillation/blob/main/imagenet/torchdistill/losses/single.py

Roy Miles, Krystian Mikolajczyk: “Understanding the Role of the Projector in Knowledge Distillation” @ AAAI 2024 (2024)

Parameters:
  • student_feature_module_path (str) – student model’s feature module path in an auxiliary wrapper torchdistill.models.wrapper.SRDModelWrapper.

  • student_feature_module_io (str) – ‘input’ or ‘output’ of the feature module in the student model.

  • teacher_feature_module_path (str) – teacher model’s feature module path in an auxiliary wrapper torchdistill.models.wrapper.SRDModelWrapper.

  • teacher_feature_module_io (str) – ‘input’ or ‘output’ of the feature module in the teacher model.

  • student_linear_module_path (str) – student model’s linear module path.

  • student_linear_module_io (str) – ‘input’ or ‘output’ of the linear module in the student model.

  • teacher_linear_module_path (str) – teacher model’s linear module path.

  • teacher_linear_module_io (str) – ‘input’ or ‘output’ of the linear module in the teacher model.

  • exponent (float) – exponent for feature distillation loss.

  • temperature (float) – hyperparameter \(\tau\) to soften class-probability distributions.

  • reduction (str or None) – loss reduction type.

class torchdistill.losses.mid_level.LogitStdKDLoss(student_module_path, student_module_io, teacher_module_path, teacher_module_io, temperature, eps=1e-07, alpha=None, beta=None, reduction='batchmean', **kwargs)[source]

A standard knowledge distillation (KD) loss module with logits standardization.

Shangquan Sun, Wenqi Ren, Jingzhi Li, Rui Wang, Xiaochun Cao: “Logit Standardization in Knowledge Distillation” @ CVPR 2024 (2024)

Parameters:
  • student_module_path (str) – student model’s logit module path.

  • student_module_io (str) – ‘input’ or ‘output’ of the module in the student model.

  • teacher_module_path (str) – teacher model’s logit module path.

  • teacher_module_io (str) – ‘input’ or ‘output’ of the module in the teacher model.

  • temperature (float) – hyperparameter \(\tau\) to soften class-probability distributions.

  • eps (float) – value added to the denominator for numerical stability.

  • alpha (float) – balancing factor for \(L_{CE}\), cross-entropy.

  • beta (float or None) – balancing factor (default: \(1 - \alpha\)) for \(L_{KL}\), KL divergence between class-probability distributions softened by \(\tau\).

  • reduction (str or None) – reduction for KLDivLoss. If reduction = ‘batchmean’, CrossEntropyLoss’s reduction will be ‘mean’.


torchdistill.losses.util

torchdistill.losses.util.extract_model_loss_dict(student_outputs, targets, **kwargs)[source]

Extracts model’s loss dict.

Parameters:
  • student_outputs (Any) – student model’s output.

  • targets (Any) – training targets (won’t be used).

Returns:

model’s loss dict.

Return type:

dict
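
For illustration, a hedged sketch of what such an extractor typically does: if the model already returns a dict of losses (e.g., torchvision detection models in training mode), forward it as the model-driven loss dict; otherwise return an empty dict. This is an assumption about typical behavior, not necessarily the exact implementation.
 def extract_model_loss_dict_sketch(student_outputs, targets, **kwargs):
     # Forward a loss dict produced by the model itself; targets are intentionally unused.
     model_loss_dict = dict()
     if isinstance(student_outputs, dict):
         model_loss_dict.update(student_outputs)
     return model_loss_dict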