torchdistill.losses
torchdistill.losses.registry
- torchdistill.losses.registry.register_low_level_loss(arg=None, **kwargs)[source]
Registers a low-level loss class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class or function to be registered as a low-level loss.
- Returns:
registered low-level loss class or function to instantiate it.
- Return type:
class or Callable
Note
The low-level loss will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.losses.registry import register_low_level_loss
>>>
>>> @register_low_level_loss(key='my_custom_low_level_loss')
>>> class CustomLowLevelLoss(nn.Module):
>>>     def __init__(self, **kwargs):
>>>         print('This is my custom low-level loss class')

In the example, CustomLowLevelLoss class is registered with a key 'my_custom_low_level_loss'. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomLowLevelLoss class by 'my_custom_low_level_loss'.
- torchdistill.losses.registry.register_mid_level_loss(arg=None, **kwargs)[source]
Registers a middle-level loss class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class or function to be registered as a middle-level loss.
- Returns:
registered middle-level loss class or function to instantiate it.
- Return type:
class or Callable
Note
The middle-level loss will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.losses.registry import register_mid_level_loss
>>>
>>> @register_mid_level_loss(key='my_custom_mid_level_loss')
>>> class CustomMidLevelLoss(nn.Module):
>>>     def __init__(self, **kwargs):
>>>         print('This is my custom middle-level loss class')

In the example, CustomMidLevelLoss class is registered with a key 'my_custom_mid_level_loss'. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomMidLevelLoss class by 'my_custom_mid_level_loss'.
- torchdistill.losses.registry.register_high_level_loss(arg=None, **kwargs)[source]
Registers a high-level loss class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class or function to be registered as a high-level loss.
- Returns:
registered high-level loss class or function to instantiate it.
- Return type:
class or Callable
Note
The high-level loss will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.losses.registry import register_high_level_loss
>>>
>>> @register_high_level_loss(key='my_custom_high_level_loss')
>>> class CustomHighLevelLoss(nn.Module):
>>>     def __init__(self, **kwargs):
>>>         print('This is my custom high-level loss class')

In the example, CustomHighLevelLoss class is registered with a key 'my_custom_high_level_loss'. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomHighLevelLoss class by 'my_custom_high_level_loss'.
- torchdistill.losses.registry.register_loss_wrapper(arg=None, **kwargs)[source]
Registers a loss wrapper class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class or function to be registered as a loss wrapper.
- Returns:
registered loss wrapper class or function to instantiate it.
- Return type:
class or Callable
Note
The loss wrapper will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.losses.registry import register_loss_wrapper
>>>
>>> @register_loss_wrapper(key='my_custom_loss_wrapper')
>>> class CustomLossWrapper(nn.Module):
>>>     def __init__(self, **kwargs):
>>>         print('This is my custom loss wrapper class')

In the example, CustomLossWrapper class is registered with a key 'my_custom_loss_wrapper'. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomLossWrapper class by 'my_custom_loss_wrapper'.
- torchdistill.losses.registry.register_func2extract_model_output(arg=None, **kwargs)[source]
Registers a function to extract model output.
- Parameters:
arg (Callable or None) – function to be registered for extracting model output.
- Returns:
registered function.
- Return type:
Callable
Note
The function to extract model output will be registered as an option. You can choose the registered function by specifying the name of the function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the function with a key of your choice, add key to the decorator as below:

>>> from torchdistill.losses.registry import register_func2extract_model_output
>>>
>>> @register_func2extract_model_output(key='my_custom_function2extract_model_output')
>>> def custom_func2extract_model_output(batch, label):
>>>     print('This is my custom function to extract model output')
>>>     return batch, label

In the example, custom_func2extract_model_output function is registered with a key 'my_custom_function2extract_model_output'. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the custom_func2extract_model_output function by 'my_custom_function2extract_model_output'.
- torchdistill.losses.registry.get_low_level_loss(key, **kwargs)[source]
Gets a registered (low-level) loss module.
- Parameters:
key (str) – unique key to identify the registered loss class/function.
- Returns:
low-level loss module instantiated from the registered class/function.
- Return type:
nn.Module
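For example, a registered PyTorch loss module can be instantiated by its key and keyword arguments; a minimal sketch, assuming 'CrossEntropyLoss' is available in the low-level loss registry as in the configuration examples below:

from torchdistill.losses.registry import get_low_level_loss

# Instantiate a registered low-level loss by its key, forwarding constructor kwargs.
criterion = get_low_level_loss('CrossEntropyLoss', reduction='mean')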
- torchdistill.losses.registry.get_mid_level_loss(mid_level_criterion_config, criterion_wrapper_config=None)[source]
Gets a registered middle-level loss module.
- Parameters:
mid_level_criterion_config (dict) – middle-level loss configuration to identify and instantiate the registered middle-level loss class.
criterion_wrapper_config (dict) – loss wrapper configuration to identify and instantiate the registered loss wrapper class.
- Returns:
middle-level loss module instantiated from the registered class/function.
- Return type:
nn.Module
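A minimal sketch of the expected configuration structure, mirroring the key/kwargs layout of the criterion entries in the YAML examples under torchdistill.losses.mid_level (the module paths and values here are illustrative assumptions):

from torchdistill.losses.registry import get_mid_level_loss

# Middle-level loss configuration; hypothetical paths/values for illustration only.
kd_criterion_config = {
    'key': 'KDLoss',
    'kwargs': {
        'student_module_path': 'fc', 'student_module_io': 'output',
        'teacher_module_path': 'fc', 'teacher_module_io': 'output',
        'temperature': 4.0, 'alpha': 0.5, 'reduction': 'batchmean'
    }
}
criterion = get_mid_level_loss(kd_criterion_config)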
- torchdistill.losses.registry.get_high_level_loss(criterion_config)[source]
Gets a registered high-level loss module.
- Parameters:
criterion_config (dict) – high-level loss configuration to identify and instantiate the registered high-level loss class.
- Returns:
high-level loss module instantiated from the registered class/function.
- Return type:
nn.Module
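A sketch of a criterion_config for WeightedSumLoss, assembled from the sub_terms layout shown under torchdistill.losses.high_level (the key/kwargs dict structure and values are illustrative assumptions):

from torchdistill.losses.registry import get_high_level_loss

# High-level loss configuration mirroring the sub_terms YAML example below.
criterion_config = {
    'key': 'WeightedSumLoss',
    'kwargs': {
        'sub_terms': {
            'ce': {
                'criterion': {'key': 'CrossEntropyLoss', 'kwargs': {'reduction': 'mean'}},
                'criterion_wrapper': {
                    'key': 'SimpleLossWrapper',
                    'kwargs': {
                        'input': {'is_from_teacher': False, 'module_path': '.', 'io': 'output'},
                        'target': {'uses_label': True}
                    }
                },
                'weight': 1.0
            }
        }
    }
}
criterion = get_high_level_loss(criterion_config)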
- torchdistill.losses.registry.get_loss_wrapper(mid_level_loss, criterion_wrapper_config)[source]
Gets a registered loss wrapper module.
- Parameters:
mid_level_loss (nn.Module) – middle-level loss module.
criterion_wrapper_config (dict) – loss wrapper configuration to identify and instantiate the registered loss wrapper class.
- Returns:
loss wrapper module instantiated from the registered class/function, wrapping mid_level_loss.
- Return type:
nn.Module
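A minimal sketch, assuming a plain PyTorch loss module stands in for the middle-level loss and reusing the SimpleLossWrapper configuration shown under torchdistill.losses.mid_level:

from torch import nn
from torchdistill.losses.registry import get_loss_wrapper

criterion_wrapper_config = {
    'key': 'SimpleLossWrapper',
    'kwargs': {
        'input': {'is_from_teacher': False, 'module_path': '.', 'io': 'output'},
        'target': {'uses_label': True}
    }
}
# Wrap the loss so it can consume torchdistill's I/O dicts during training.
wrapped_criterion = get_loss_wrapper(nn.CrossEntropyLoss(), criterion_wrapper_config)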
- torchdistill.losses.registry.get_func2extract_model_output(key)[source]
Gets a registered function to extract model output.
- Parameters:
key (str) – unique key to identify the registered function to extract model output.
- Returns:
registered function to extract model output.
- Return type:
Callable
torchdistill.losses.high_level
- class torchdistill.losses.high_level.AbstractLoss(sub_terms=None, **kwargs)[source]
An abstract loss module.
forward() and __str__() should be overridden by all subclasses.
- Parameters:
sub_terms (dict or None) – loss module configurations.

sub_terms:
  ce:
    criterion:
      key: 'CrossEntropyLoss'
      kwargs:
        reduction: 'mean'
    criterion_wrapper:
      key: 'SimpleLossWrapper'
      kwargs:
        input:
          is_from_teacher: False
          module_path: '.'
          io: 'output'
        target:
          uses_label: True
    weight: 1.0
- class torchdistill.losses.high_level.WeightedSumLoss(model_term=None, sub_terms=None, **kwargs)[source]
A weighted sum (linear combination) of mid-/low-level loss modules.
If model_term contains a numerical value with a weight key, it will be a multiplier \(W_{model}\) for the sum of model-driven loss values \(\sum_{i} L_{model, i}\).
\[L_{total} = W_{model} \cdot (\sum_{i} L_{model, i}) + \sum_{k} W_{sub, k} \cdot L_{sub, k}\]
- Parameters:
model_term (dict or None) – model-driven loss module configurations.
sub_terms (dict or None) – loss module configurations.
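As a concrete reading of the formula: with \(W_{model} = 1.0\), one model-driven loss value of 0.3, and two sub terms weighted 1.0 and 0.5, the total is \(1.0 \cdot 0.3 + 1.0 \cdot L_{sub,1} + 0.5 \cdot L_{sub,2}\). A minimal sketch of that combination (illustrative only, not torchdistill's internals):

import torch

def weighted_sum_sketch(model_loss_values, model_weight, weighted_sub_losses):
    # model_loss_values: model-driven loss tensors (L_model,i); model_weight: W_model;
    # weighted_sub_losses: (W_sub,k, L_sub,k) pairs.
    total = model_weight * sum(model_loss_values)
    for weight, loss_value in weighted_sub_losses:
        total = total + weight * loss_value
    return total

# 1.0 * 0.3 + 1.0 * 0.7 + 0.5 * 0.2 = 1.1
total = weighted_sum_sketch([torch.tensor(0.3)], 1.0,
                            [(1.0, torch.tensor(0.7)), (0.5, torch.tensor(0.2))])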
torchdistill.losses.mid_level
- class torchdistill.losses.mid_level.SimpleLossWrapper(low_level_loss, **kwargs)[source]
A simple loss wrapper module designed to use low-level loss modules (e.g., loss modules in PyTorch) in torchdistill's pipelines.
- Parameters:
low_level_loss (nn.Module) – low-level loss module e.g., torch.nn.CrossEntropyLoss.
kwargs (dict or None) – kwargs to configure what the wrapper passes to low_level_loss.

criterion_wrapper:
  key: 'SimpleLossWrapper'
  kwargs:
    input:
      is_from_teacher: False
      module_path: '.'
      io: 'output'
    target:
      uses_label: True
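To clarify what the input/target entries above control, here is a simplified sketch of the wrapper's behavior (an illustration, not the actual implementation): the tensor fed to the low-level loss is looked up in the student's or teacher's I/O dict by module_path and io, and the target is either the batch labels (uses_label: True) or another looked-up tensor.

# Simplified sketch of SimpleLossWrapper-style extraction; the dict layouts are assumptions.
def _extract(entry_config, student_io_dict, teacher_io_dict, labels):
    # Use the labels directly if requested; otherwise look up the recorded tensor.
    if entry_config.get('uses_label', False):
        return labels
    io_dict = teacher_io_dict if entry_config.get('is_from_teacher', False) else student_io_dict
    return io_dict[entry_config['module_path']][entry_config['io']]

def simple_loss_wrapper_sketch(low_level_loss, input_config, target_config,
                               student_io_dict, teacher_io_dict, labels):
    loss_input = _extract(input_config, student_io_dict, teacher_io_dict, labels)
    loss_target = _extract(target_config, student_io_dict, teacher_io_dict, labels)
    return low_level_loss(loss_input, loss_target)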
- class torchdistill.losses.mid_level.DictLossWrapper(low_level_loss, weights, **kwargs)[source]
A dict-based wrapper module designed to use low-level loss modules (e.g., loss modules in PyTorch) in torchdistill's pipelines. This is a subclass of SimpleLossWrapper and useful for models whose forward output is a dict.
- Parameters:
low_level_loss (nn.Module) – low-level loss module e.g., torch.nn.CrossEntropyLoss.
weights (dict) – dict whose keys match the model's output dict keys, mapped to the corresponding loss weights.
kwargs (dict or None) – kwargs to configure what the wrapper passes to low_level_loss.

criterion_wrapper:
  key: 'DictLossWrapper'
  kwargs:
    input:
      is_from_teacher: False
      module_path: '.'
      io: 'output'
    target:
      uses_label: True
    weights:
      out: 1.0
      aux: 0.5
- class torchdistill.losses.mid_level.KDLoss(student_module_path, student_module_io, teacher_module_path, teacher_module_io, temperature, alpha=None, beta=None, reduction='batchmean', **kwargs)[source]
A standard knowledge distillation (KD) loss module.
\[L_{KD} = \alpha \cdot L_{CE} + (1 - \alpha) \cdot \tau^2 \cdot L_{KL}\]
Geoffrey Hinton, Oriol Vinyals, Jeff Dean: "Distilling the Knowledge in a Neural Network" @ NIPS 2014 Deep Learning and Representation Learning Workshop (2014)
- Parameters:
student_module_path (str) – student model's logit module path.
student_module_io (str) – 'input' or 'output' of the module in the student model.
teacher_module_path (str) – teacher model's logit module path.
teacher_module_io (str) – 'input' or 'output' of the module in the teacher model.
temperature (float) – hyperparameter \(\tau\) to soften class-probability distributions.
alpha (float) – balancing factor for \(L_{CE}\), cross-entropy.
beta (float or None) – balancing factor (default: \(1 - \alpha\)) for \(L_{KL}\), KL divergence between class-probability distributions softened by \(\tau\).
reduction (str or None) – reduction for KLDivLoss. If reduction = 'batchmean', CrossEntropyLoss's reduction will be 'mean'.
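The formula above maps directly onto PyTorch functionals; the following is an illustrative sketch of the computation (not torchdistill's implementation), assuming the student/teacher logits have already been extracted from the configured module paths:

import torch.nn.functional as F

def kd_loss_sketch(student_logits, teacher_logits, targets, temperature, alpha):
    # L_KL: KL divergence between class-probability distributions softened by tau (temperature).
    soft_loss = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),
                         F.softmax(teacher_logits / temperature, dim=1),
                         reduction='batchmean')
    # L_CE: standard cross-entropy against the hard labels.
    hard_loss = F.cross_entropy(student_logits, targets)
    # L_KD = alpha * L_CE + (1 - alpha) * tau^2 * L_KL
    return alpha * hard_loss + (1.0 - alpha) * (temperature ** 2) * soft_loss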
- class torchdistill.losses.mid_level.FSPLoss(fsp_pairs, **kwargs)[source]
A loss module for the flow of solution procedure (FSP) matrix.
Junho Yim, Donggyu Joo, Jihoon Bae, Junmo Kim: "A Gift From Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning" @ CVPR 2017 (2017)
- Parameters:
fsp_pairs (dict) – configuration of teacher-student module pairs to compute the loss for the FSP matrix.

criterion:
  key: 'FSPLoss'
  kwargs:
    fsp_pairs:
      pair1:
        teacher_first:
          io: 'input'
          path: 'layer1'
        teacher_second:
          io: 'output'
          path: 'layer1'
        student_first:
          io: 'input'
          path: 'layer1'
        student_second:
          io: 'output'
          path: 'layer1'
        weight: 1
      pair2:
        teacher_first:
          io: 'input'
          path: 'layer2.1'
        teacher_second:
          io: 'output'
          path: 'layer2'
        student_first:
          io: 'input'
          path: 'layer2.1'
        student_second:
          io: 'output'
          path: 'layer2'
        weight: 1
- class torchdistill.losses.mid_level.ATLoss(at_pairs, mode='code', **kwargs)[source]
A loss module for attention transfer (AT). Referred to https://github.com/szagoruyko/attention-transfer/blob/master/utils.py
Sergey Zagoruyko, Nikos Komodakis: "Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer" @ ICLR 2017 (2017)
- Parameters:
at_pairs (dict) – configuration of teacher-student module pairs to compute the loss for attention transfer.
mode (str) – reference to follow, 'paper' or 'code'.
Warning
There is a discrepancy between Eq. (2) in the paper and the authors' implementation, as pointed out in a paper and an issue at the repository. Use mode = 'paper' instead of 'code' if you want to follow the equations in the paper.

criterion:
  key: 'ATLoss'
  kwargs:
    at_pairs:
      pair1:
        teacher:
          io: 'output'
          path: 'layer3'
        student:
          io: 'output'
          path: 'layer3'
        weight: 1
      pair2:
        teacher:
          io: 'output'
          path: 'layer4'
        student:
          io: 'output'
          path: 'layer4'
        weight: 1
    mode: 'code'
- class torchdistill.losses.mid_level.PKTLoss(student_module_path, student_module_io, teacher_module_path, teacher_module_io, eps=1e-07)[source]
A loss module for probabilistic knowledge transfer (PKT). Refactored https://github.com/passalis/probabilistic_kt/blob/master/nn/pkt.py
Nikolaos Passalis, Anastasios Tefas: "Learning Deep Representations with Probabilistic Knowledge Transfer" @ ECCV 2018 (2018)
- Parameters:
student_module_path (str) – student model's logit module path.
student_module_io (str) – 'input' or 'output' of the module in the student model.
teacher_module_path (str) – teacher model's logit module path.
teacher_module_io (str) – 'input' or 'output' of the module in the teacher model.
eps (float) – constant to avoid zero division.

criterion:
  key: 'PKTLoss'
  kwargs:
    student_module_path: 'fc'
    student_module_io: 'input'
    teacher_module_path: 'fc'
    teacher_module_io: 'input'
    eps: 0.0000001
- class torchdistill.losses.mid_level.FTLoss(p=1, reduction='mean', paraphraser_path='paraphraser', translator_path='translator', **kwargs)[source]
A loss module for factor transfer (FT). This loss module is used at the 2nd stage of FT method.
Jangho Kim, Seonguk Park, Nojun Kwak: "Paraphrasing Complex Network: Network Compression via Factor Transfer" @ NeurIPS 2018 (2018)
- Parameters:
p (int) – the order of norm.
reduction (str) – loss reduction type.
paraphraser_path (str) – teacher model's paraphraser module path.
translator_path (str) – student model's translator module path.

criterion:
  key: 'FTLoss'
  kwargs:
    p: 1
    reduction: 'mean'
    paraphraser_path: 'paraphraser'
    translator_path: 'translator'
- class torchdistill.losses.mid_level.AltActTransferLoss(feature_pairs, margin, reduction, **kwargs)[source]
A loss module for distillation of activation boundaries (DAB). Refactored https://github.com/bhheo/AB_distillation/blob/master/cifar10_AB_distillation.py
Byeongho Heo, Minsik Lee, Sangdoo Yun, Jin Young Choi: "Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons" @ AAAI 2019 (2019)
- Parameters:
feature_pairs (dict) – configuration of teacher-student module pairs to compute the loss for distillation of activation boundaries.
margin (float) – margin.
reduction (str) – loss reduction type.

criterion:
  key: 'AltActTransferLoss'
  kwargs:
    feature_pairs:
      pair1:
        teacher:
          io: 'output'
          path: 'layer1'
        student:
          io: 'output'
          path: 'connector_dict.connector1'
        weight: 1
      pair2:
        teacher:
          io: 'output'
          path: 'layer2'
        student:
          io: 'output'
          path: 'connector_dict.connector2'
        weight: 1
      pair3:
        teacher:
          io: 'output'
          path: 'layer3'
        student:
          io: 'output'
          path: 'connector_dict.connector3'
        weight: 1
      pair4:
        teacher:
          io: 'output'
          path: 'layer4'
        student:
          io: 'output'
          path: 'connector_dict.connector4'
        weight: 1
    margin: 1.0
    reduction: 'mean'
- class torchdistill.losses.mid_level.RKDLoss(student_output_path, teacher_output_path, dist_factor, angle_factor, reduction, **kwargs)[source]
A loss module for relational knowledge distillation (RKD). Refactored https://github.com/lenscloth/RKD/blob/master/metric/loss.py
Wonpyo Park, Dongju Kim, Yan Lu, Minsu Cho: "Relational Knowledge Distillation" @ CVPR 2019 (2019)
- Parameters:
student_output_path (str) – student module path whose output is used in this loss module.
teacher_output_path (str) – teacher module path whose output is used in this loss module.
dist_factor (float) – weight on distance-based RKD loss.
angle_factor (float) – weight on angle-based RKD loss.
reduction (str) – reduction for SmoothL1Loss.

criterion:
  key: 'RKDLoss'
  kwargs:
    teacher_output_path: 'layer4'
    student_output_path: 'layer4'
    dist_factor: 1.0
    angle_factor: 2.0
    reduction: 'mean'
- class torchdistill.losses.mid_level.VIDLoss(feature_pairs, **kwargs)[source]
A loss module for variational information distillation (VID). Referred to https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/VID.py
Sungsoo Ahn, Shell Xu Hu, Andreas Damianou, Neil D. Lawrence, Zhenwen Dai: "Variational Information Distillation for Knowledge Transfer" @ CVPR 2019 (2019)
- Parameters:
feature_pairs (dict) – configuration of teacher-student module pairs to compute the loss for variational information distillation.

criterion:
  key: 'VIDLoss'
  kwargs:
    feature_pairs:
      pair1:
        teacher:
          io: 'output'
          path: 'layer1'
        student:
          io: 'output'
          path: 'regressor_dict.regressor1'
        weight: 1
      pair2:
        teacher:
          io: 'output'
          path: 'layer2'
        student:
          io: 'output'
          path: 'regressor_dict.regressor2'
        weight: 1
      pair3:
        teacher:
          io: 'output'
          path: 'layer3'
        student:
          io: 'output'
          path: 'regressor_dict.regressor3'
        weight: 1
      pair4:
        teacher:
          io: 'output'
          path: 'layer4'
        student:
          io: 'output'
          path: 'regressor_dict.regressor4'
        weight: 1
    margin: 1.0
- class torchdistill.losses.mid_level.CCKDLoss(student_linear_path, teacher_linear_path, kernel_config, reduction, **kwargs)[source]
A loss module for correlation congruence for knowledge distillation (CCKD).
Baoyun Peng, Xiao Jin, Jiaheng Liu, Dongsheng Li, Yichao Wu, Yu Liu, Shunfeng Zhou, Zhaoning Zhang: "Correlation Congruence for Knowledge Distillation" @ ICCV 2019 (2019)
- Parameters:
student_linear_path (str) – student model's linear module path in an auxiliary wrapper torchdistill.models.wrapper.Linear4CCKD.
teacher_linear_path (str) – teacher model's linear module path in an auxiliary wrapper torchdistill.models.wrapper.Linear4CCKD.
kernel_config (dict) – kernel ('gaussian' or 'bilinear') configuration.
reduction (str) – loss reduction type.

criterion:
  key: 'CCKDLoss'
  kwargs:
    teacher_linear_path: 'linear'
    student_linear_path: 'linear'
    kernel_params:
      key: 'gaussian'
      gamma: 0.4
      max_p: 2
    reduction: 'batchmean'
- class torchdistill.losses.mid_level.SPKDLoss(student_output_path, teacher_output_path, reduction, **kwargs)[source]
A loss module for similarity-preserving knowledge distillation (SPKD).
Frederick Tung, Greg Mori: "Similarity-Preserving Knowledge Distillation" @ ICCV 2019 (2019)
- Parameters:
student_output_path (str) – student module path whose output is used in this loss module.
teacher_output_path (str) – teacher module path whose output is used in this loss module.
reduction (str) – loss reduction type.

criterion:
  key: 'SPKDLoss'
  kwargs:
    teacher_output_path: 'layer4'
    student_output_path: 'layer4'
    reduction: 'batchmean'
- class torchdistill.losses.mid_level.CRDLoss(student_norm_module_path, student_empty_module_path, teacher_norm_module_path, input_size, output_size, num_negative_samples, num_samples, temperature=0.07, momentum=0.5, eps=1e-07)[source]
A loss module for contrastive representation distillation (CRD). Refactored https://github.com/HobbitLong/RepDistiller/blob/master/crd/criterion.py
Yonglong Tian, Dilip Krishnan, Phillip Isola: "Contrastive Representation Distillation" @ ICLR 2020 (2020)
- Parameters:
student_norm_module_path (str) – student model's normalizer module path (torchdistill.models.wrapper.Normalizer4CRD in an auxiliary wrapper torchdistill.models.wrapper.Linear4CRD).
student_empty_module_path (str) – student model's empty module path in an auxiliary wrapper torchdistill.models.wrapper.Linear4CRD.
teacher_norm_module_path (str) – teacher model's normalizer module path (torchdistill.models.wrapper.Normalizer4CRD in an auxiliary wrapper torchdistill.models.wrapper.Linear4CRD).
input_size (int) – number of input features.
output_size (int) – number of output features.
num_negative_samples (int) – number of negative samples.
num_samples (int) – number of samples.
temperature (float) – temperature to adjust concentration level (not the temperature for KDLoss).
momentum (float) – momentum.
eps (float) – eps.

criterion:
  key: 'CRDLoss'
  kwargs:
    teacher_norm_module_path: 'normalizer'
    student_norm_module_path: 'normalizer'
    student_empty_module_path: 'empty'
    input_size: *feature_dim
    output_size: &num_samples 1281167
    num_negative_samples: *num_negative_samples
    num_samples: *num_samples
    temperature: 0.07
    momentum: 0.5
    eps: 0.0000001
- class torchdistill.losses.mid_level.AuxSSKDLoss(module_path='ss_module', module_io='output', reduction='mean', **kwargs)[source]
A loss module for self-supervision knowledge distillation (SSKD) that treats contrastive prediction as a self-supervision task (auxiliary task). This loss module is used at the 1st stage of SSKD method. Refactored https://github.com/xuguodong03/SSKD/blob/master/student.py
Guodong Xu, Ziwei Liu, Xiaoxiao Li, Chen Change Loy: "Knowledge Distillation Meets Self-Supervision" @ ECCV 2020 (2020)
- Parameters:
module_path (str) – model's self-supervision module path.
module_io (str) – 'input' or 'output' of the module in the model.
reduction (str) – reduction for CrossEntropyLoss.

criterion:
  key: 'AuxSSKDLoss'
  kwargs:
    module_path: 'ss_module'
    module_io: 'output'
    reduction: 'mean'
- class torchdistill.losses.mid_level.SSKDLoss(student_linear_module_path, teacher_linear_module_path, student_ss_module_path, teacher_ss_module_path, kl_temp, ss_temp, tf_temp, ss_ratio, tf_ratio, student_linear_module_io='output', teacher_linear_module_io='output', student_ss_module_io='output', teacher_ss_module_io='output', loss_weights=None, reduction='batchmean', **kwargs)[source]
A loss module for self-supervision knowledge distillation (SSKD). This loss module is used at the 2nd stage of SSKD method. Refactored https://github.com/xuguodong03/SSKD/blob/master/student.py
Guodong Xu, Ziwei Liu, Xiaoxiao Li, Chen Change Loy: "Knowledge Distillation Meets Self-Supervision" @ ECCV 2020 (2020)
- Parameters:
student_linear_module_path (str) – student model's linear module path in an auxiliary wrapper torchdistill.models.wrapper.SSWrapper4SSKD.
teacher_linear_module_path (str) – teacher model's linear module path in an auxiliary wrapper torchdistill.models.wrapper.SSWrapper4SSKD.
student_ss_module_path (str) – student model's self-supervision module path.
teacher_ss_module_path (str) – teacher model's self-supervision module path.
kl_temp (float) – temperature to soften teacher and student's class-probability distributions for KL divergence given original data.
ss_temp (float) – temperature to soften teacher and student's self-supervision cosine similarities for KL divergence.
tf_temp (float) – temperature to soften teacher and student's class-probability distributions for KL divergence given augmented data by transform.
ss_ratio (float) – ratio of samples with the smallest error levels used for self-supervision.
tf_ratio (float) – ratio of samples with the smallest error levels used for transform.
student_linear_module_io (str) – 'input' or 'output' of the linear module in the student model.
teacher_linear_module_io (str) – 'input' or 'output' of the linear module in the teacher model.
student_ss_module_io (str) – 'input' or 'output' of the self-supervision module in the student model.
teacher_ss_module_io (str) – 'input' or 'output' of the self-supervision module in the teacher model.
loss_weights (list[float] or None) – weights for 1) cross-entropy, 2) KL divergence for the original data, 3) KL divergence for self-supervision cosine similarities, and 4) KL divergence for the augmented data by transform.
reduction (str or None) – reduction for KLDivLoss. If reduction = 'batchmean', CrossEntropyLoss's reduction will be 'mean'.

criterion:
  key: 'SSKDLoss'
  kwargs:
    student_linear_module_path: 'model.fc'
    teacher_linear_module_path: 'model.fc'
    student_ss_module_path: 'ss_module'
    teacher_ss_module_path: 'ss_module'
    kl_temp: 4.0
    ss_temp: 0.5
    tf_temp: 4.0
    ss_ratio: 0.75
    tf_ratio: 1.0
    loss_weights: [1.0, 0.9, 10.0, 2.7]
    reduction: 'batchmean'
- class torchdistill.losses.mid_level.PADL2Loss(student_embed_module_path, teacher_embed_module_path, student_embed_module_io='output', teacher_embed_module_io='output', module_path='var_estimator', module_io='output', eps=1e-06, reduction='mean', **kwargs)[source]
A loss module for prime-aware adaptive distillation (PAD) with L2 loss. This loss module is used at the 2nd stage of PAD method.
Youcai Zhang, Zhonghao Lan, Yuchen Dai, Fangao Zeng, Yan Bai, Jie Chang, Yichen Wei: "Prime-Aware Adaptive Distillation" @ ECCV 2020 (2020)
- Parameters:
student_embed_module_path (str) – student model's embedding module path in an auxiliary wrapper torchdistill.models.wrapper.VarianceBranch4PAD.
teacher_embed_module_path (str) – teacher model's embedding module path.
student_embed_module_io (str) – 'input' or 'output' of the embedding module in the student model.
teacher_embed_module_io (str) – 'input' or 'output' of the embedding module in the teacher model.
module_path (str) – student model's variance estimator module path in an auxiliary wrapper torchdistill.models.wrapper.VarianceBranch4PAD.
module_io (str) – 'input' or 'output' of the variance estimator module in the student model.
eps (float) – constant to avoid zero division.
reduction (str) – loss reduction type.

criterion:
  key: 'PADL2Loss'
  kwargs:
    student_embed_module_path: 'student_model.avgpool'
    student_embed_module_io: 'output'
    teacher_embed_module_path: 'avgpool'
    teacher_embed_module_io: 'output'
    module_path: 'var_estimator'
    module_io: 'output'
    eps: 0.000001
    reduction: 'mean'
- class torchdistill.losses.mid_level.HierarchicalContextLoss(student_module_path, student_module_io, teacher_module_path, teacher_module_io, reduction='mean', output_sizes=None, **kwargs)[source]
A loss module for knowledge review (KR) method. Referred to https://github.com/dvlab-research/ReviewKD/blob/master/ImageNet/models/reviewkd.py
Pengguang Chen, Shu Liu, Hengshuang Zhao, Jiaya Jia: "Distilling Knowledge via Knowledge Review" @ CVPR 2021 (2021)
- Parameters:
student_module_path (str) – student model's module path in an auxiliary wrapper torchdistill.models.wrapper.Student4KnowledgeReview.
student_module_io (str) – 'input' or 'output' of the module in the student model.
teacher_module_path (str) – teacher model's module path.
teacher_module_io (str) – 'input' or 'output' of the module in the teacher model.
reduction (str or None) – reduction for MSELoss.
output_sizes (list[int] or None) – output sizes of adaptive_avg_pool2d.

criterion:
  key: 'HierarchicalContextLoss'
  kwargs:
    student_module_path: 'abf_modules.4'
    student_module_io: 'output'
    teacher_module_path: 'layer1.-1.relu'
    teacher_module_io: 'input'
    reduction: 'mean'
    output_sizes: [4, 2, 1]
- class torchdistill.losses.mid_level.RegularizationLoss(module_path, io_type='output', is_from_teacher=False, p=1, **kwargs)[source]
A regularization loss module.
- Parameters:
module_path (str) – module path.
io_type (str) – 'input' or 'output' of the module.
is_from_teacher (bool) – True if the teacher's I/O dict should be used; otherwise, the student's I/O dict is used.
p (int) – the order of norm.
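In essence, this term adds the p-norm of the tensor recorded for module_path in the chosen I/O dict to the training objective; a minimal sketch (illustrative, assuming the tensor has already been pulled from the I/O dict):

import torch

def regularization_loss_sketch(hooked_tensor, p=1):
    # p-norm of the module's recorded input or output.
    return torch.norm(hooked_tensor, p=p)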
- class torchdistill.losses.mid_level.KTALoss(p=1, q=2, reduction='mean', knowledge_translator_path='paraphraser', feature_adapter_path='feature_adapter', **kwargs)[source]
A loss module for knowledge translation and adaptation (KTA). This loss module is used at the 2nd stage of KTAAD method.
Tong He, Chunhua Shen, Zhi Tian, Dong Gong, Changming Sun, Youliang Yan: "Knowledge Adaptation for Efficient Semantic Segmentation" @ CVPR 2019 (2019)
- Parameters:
p (int) – the order of norm for differences between the normalized feature adapter's (flattened) output and the knowledge translator's (flattened) output.
q (int) – the order of norm for the denominator to normalize the feature adapter's (flattened) output.
reduction (str) – loss reduction type.
knowledge_translator_path (str) – knowledge translator module path.
feature_adapter_path (str) – feature adapter module path.

criterion:
  key: 'KTALoss'
  kwargs:
    p: 1
    q: 2
    reduction: 'mean'
    knowledge_translator_path: 'paraphraser.encoder'
    feature_adapter_path: 'feature_adapter'
- class torchdistill.losses.mid_level.AffinityLoss(student_module_path, teacher_module_path, student_module_io='output', teacher_module_io='output', reduction='mean', **kwargs)[source]
A loss module for affinity distillation in KTA. This loss module is used at the 2nd stage of KTAAD method.
Tong He, Chunhua Shen, Zhi Tian, Dong Gong, Changming Sun, Youliang Yan: "Knowledge Adaptation for Efficient Semantic Segmentation" @ CVPR 2019 (2019)
- Parameters:
student_module_path (str) – student model's module path in an auxiliary wrapper torchdistill.models.wrapper.Student4KTAAD.
student_module_io (str) – 'input' or 'output' of the module in the student model.
teacher_module_path (str) – teacher model's module path in an auxiliary wrapper torchdistill.models.wrapper.Teacher4FactorTransfer.
teacher_module_io (str) – 'input' or 'output' of the module in the teacher model.
reduction (str or None) – loss reduction type.

criterion:
  key: 'AffinityLoss'
  kwargs:
    student_module_path: 'affinity_adapter'
    student_module_io: 'output'
    teacher_module_path: 'paraphraser.encoder'
    teacher_module_io: 'output'
    reduction: 'mean'
- class torchdistill.losses.mid_level.ChSimLoss(feature_pairs, **kwargs)[source]
A loss module for Inter-Channel Correlation for Knowledge Distillation (ICKD). Refactored https://github.com/ADLab-AutoDrive/ICKD/blob/main/ImageNet/torchdistill/losses/single.py
Li Liu, Qingle Huang, Sihao Lin, Hongwei Xie, Bing Wang, Xiaojun Chang, Xiaodan Liang: "Inter-Channel Correlation for Knowledge Distillation" @ ICCV 2021 (2021)
- Parameters:
feature_pairs (dict) – configuration of teacher-student module pairs to compute the L2 distance between the inter-channel correlation matrices of the student and the teacher.

criterion:
  key: 'ChSimLoss'
  kwargs:
    feature_pairs:
      pair1:
        teacher:
          io: 'output'
          path: 'layer4'
        student:
          io: 'output'
          path: 'embed_dict.embed1'
        weight: 1
- class torchdistill.losses.mid_level.DISTLoss(student_module_path, student_module_io, teacher_module_path, teacher_module_io, beta=1.0, gamma=1.0, tau=1.0, eps=1e-08, **kwargs)[source]
A loss module for Knowledge Distillation from A Stronger Teacher (DIST). Referred to https://github.com/hunto/image_classification_sota/blob/main/lib/models/losses/dist_kd.py
Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu: "Knowledge Distillation from A Stronger Teacher" @ NeurIPS 2022 (2022)
- Parameters:
student_module_path (str) – student model's logit module path.
student_module_io (str) – 'input' or 'output' of the module in the student model.
teacher_module_path (str) – teacher model's logit module path.
teacher_module_io (str) – 'input' or 'output' of the module in the teacher model.
beta (float) – balancing factor for inter-loss.
gamma (float) – balancing factor for intra-loss.
tau (float) – hyperparameter \(\tau\) to soften class-probability distributions.
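No configuration example is given above, so here is a sketch of an equivalent criterion configuration built only from the documented parameters (the module paths and values are illustrative assumptions, not from the docs):

from torchdistill.losses.registry import get_mid_level_loss

# Hypothetical DISTLoss configuration; paths/values are placeholders.
dist_criterion_config = {
    'key': 'DISTLoss',
    'kwargs': {
        'student_module_path': 'fc', 'student_module_io': 'output',
        'teacher_module_path': 'fc', 'teacher_module_io': 'output',
        'beta': 1.0, 'gamma': 1.0, 'tau': 1.0
    }
}
criterion = get_mid_level_loss(dist_criterion_config)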
- class torchdistill.losses.mid_level.SRDLoss(student_feature_module_path, student_feature_module_io, teacher_feature_module_path, teacher_feature_module_io, student_linear_module_path, student_linear_module_io, teacher_linear_module_path, teacher_linear_module_io, exponent=1.0, temperature=1.0, reduction='batchmean', **kwargs)[source]
A loss module for Understanding the Role of the Projector in Knowledge Distillation. Referred to https://github.com/roymiles/Simple-Recipe-Distillation/blob/main/imagenet/torchdistill/losses/single.py
Roy Miles, Krystian Mikolajczyk: "Understanding the Role of the Projector in Knowledge Distillation" @ AAAI 2024 (2024)
- Parameters:
student_feature_module_path (str) – student model's feature module path in an auxiliary wrapper torchdistill.models.wrapper.SRDModelWrapper.
student_feature_module_io (str) – 'input' or 'output' of the feature module in the student model.
teacher_feature_module_path (str) – teacher model's feature module path in an auxiliary wrapper torchdistill.models.wrapper.SRDModelWrapper.
teacher_feature_module_io (str) – 'input' or 'output' of the feature module in the teacher model.
student_linear_module_path (str) – student model's linear module path.
student_linear_module_io (str) – 'input' or 'output' of the linear module in the student model.
teacher_linear_module_path (str) – teacher model's linear module path.
teacher_linear_module_io (str) – 'input' or 'output' of the linear module in the teacher model.
exponent (float) – exponent for feature distillation loss.
temperature (float) – hyperparameter \(\tau\) to soften class-probability distributions.
reduction (str or None) – loss reduction type.
- class torchdistill.losses.mid_level.LogitStdKDLoss(student_module_path, student_module_io, teacher_module_path, teacher_module_io, temperature, eps=1e-07, alpha=None, beta=None, reduction='batchmean', **kwargs)[source]
A standard knowledge distillation (KD) loss module with logit standardization.
Shangquan Sun, Wenqi Ren, Jingzhi Li, Rui Wang, Xiaochun Cao: "Logit Standardization in Knowledge Distillation" @ CVPR 2024 (2024)
- Parameters:
student_module_path (str) – student model's logit module path.
student_module_io (str) – 'input' or 'output' of the module in the student model.
teacher_module_path (str) – teacher model's logit module path.
teacher_module_io (str) – 'input' or 'output' of the module in the teacher model.
temperature (float) – hyperparameter \(\tau\) to soften class-probability distributions.
eps (float) – value added to the denominator for numerical stability.
alpha (float) – balancing factor for \(L_{CE}\), cross-entropy.
beta (float or None) – balancing factor (default: \(1 - \alpha\)) for \(L_{KL}\), KL divergence between class-probability distributions softened by \(\tau\).
reduction (str or None) – reduction for KLDivLoss. If reduction = 'batchmean', CrossEntropyLoss's reduction will be 'mean'.
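The core idea of logit standardization is to z-score the student and teacher logits per sample before softening them with \(\tau\) and applying the usual KD objective; a simplified sketch based on the cited paper (an illustration, not torchdistill's implementation):

import torch.nn.functional as F

def standardize_logits(logits, eps=1e-7):
    # Per-sample z-score: subtract the mean and divide by the standard deviation (plus eps).
    mean = logits.mean(dim=-1, keepdim=True)
    std = logits.std(dim=-1, keepdim=True)
    return (logits - mean) / (std + eps)

def logit_std_kd_sketch(student_logits, teacher_logits, targets, temperature, alpha, eps=1e-7):
    soft_student = F.log_softmax(standardize_logits(student_logits, eps) / temperature, dim=1)
    soft_teacher = F.softmax(standardize_logits(teacher_logits, eps) / temperature, dim=1)
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
    hard_loss = F.cross_entropy(student_logits, targets)
    return alpha * hard_loss + (1.0 - alpha) * (temperature ** 2) * soft_loss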
torchdistill.losses.util
- torchdistill.losses.util.extract_model_loss_dict(student_outputs, targets, **kwargs)[source]
Extracts model's loss dict.
- Parameters:
student_outputs (Any) – student model's output.
targets (Any) – training targets (won't be used).
- Returns:
model's loss dict.
- Return type:
dict