torchdistill.datasets


torchdistill.datasets.registry

torchdistill.datasets.registry.register_dataset(arg=None, **kwargs)[source]

Registers a dataset class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class or function to be registered as a dataset.

Returns:

registered dataset class or function to instantiate it.

Return type:

class or Callable

Note

The dataset will be registered as an option. In a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the registered class/function by specifying either its name or the key you used for the registration.

If you want to register the class/function with a key of your choice, add key to the decorator as below:

>>> from torch.utils.data import Dataset
>>> from torchdistill.datasets.registry import register_dataset
>>> @register_dataset(key='my_custom_dataset')
... class CustomDataset(Dataset):
...     def __init__(self, **kwargs):
...         print('This is my custom dataset class')

In the example, CustomDataset class is registered with a key “my_custom_dataset”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomDataset class by “my_custom_dataset”.
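
The registered class can later be retrieved with the same key via torchdistill.datasets.registry.get_dataset (documented below). A minimal sketch, continuing the example above:

>>> from torchdistill.datasets.registry import get_dataset
>>> dataset_cls = get_dataset('my_custom_dataset')
>>> dataset = dataset_cls()
This is my custom dataset class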

torchdistill.datasets.registry.register_collate_func(arg=None, **kwargs)[source]

Registers a collate function.

Parameters:

arg (Callable or None) – function to be registered as a collate function.

Returns:

registered function.

Return type:

Callable

Note

The collate function will be registered as an option. In a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the registered function by specifying either its name or the key you used for the registration.

If you want to register the function with a key of your choice, add key to the decorator as below:

>>> from torchdistill.datasets.registry import register_collate_func
>>> @register_collate_func(key='my_custom_collate')
... def custom_collate(batch, label):
...     print('This is my custom collate function')
...     return batch, label

In the example, custom_collate function is registered with a key “my_custom_collate”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the custom_collate function by “my_custom_collate”.
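
Outside of a configuration file, the registered function can also be retrieved programmatically via torchdistill.datasets.registry.get_collate_func (documented below) and handed to a PyTorch DataLoader. A minimal sketch; note that DataLoader calls its collate_fn with a single batch argument:

>>> from torchdistill.datasets.registry import get_collate_func
>>> collate_fn = get_collate_func('my_custom_collate')
>>> # e.g., DataLoader(some_dataset, batch_size=32, collate_fn=collate_fn)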

torchdistill.datasets.registry.register_sample_loader(arg=None, **kwargs)[source]

Registers a sample loader class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class or function to be registered as a sample loader.

Returns:

registered sample loader class or function to instantiate it.

Return type:

class or Callable

Note

The sample loader will be registered as an option. In a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the registered class/function by specifying either its name or the key you used for the registration.

If you want to register the class with a key of your choice, add key to the decorator as below:

>>> from torch.utils.data import Sampler
>>> from torchdistill.datasets.registry import register_sample_loader
>>> @register_sample_loader(key='my_custom_sample_loader')
... class CustomSampleLoader(Sampler):
...     def __init__(self, **kwargs):
...         print('This is my custom sample loader class')

In the example, CustomSampleLoader class is registered with a key “my_custom_sample_loader”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomSampleLoader class by “my_custom_sample_loader”.

torchdistill.datasets.registry.register_batch_sampler(arg=None, **kwargs)[source]

Registers a batch sampler class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class or function to be registered as a batch sampler.

Returns:

registered batch sampler class or function to instantiate it.

Return type:

class or Callable

Note

The batch sampler will be registered as an option. In a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the registered class/function by specifying either its name or the key you used for the registration.

If you want to register the class with a key of your choice, add key to the decorator as below:

>>> from torch.utils.data import Sampler
>>> from torchdistill.datasets.registry import register_batch_sampler
>>> @register_batch_sampler(key='my_custom_batch_sampler')
... class CustomBatchSampler(Sampler):
...     def __init__(self, **kwargs):
...         print('This is my custom batch sampler class')

In the example, CustomBatchSampler class is registered with a key “my_custom_batch_sampler”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomBatchSampler class by “my_custom_batch_sampler”.

torchdistill.datasets.registry.register_transform(arg=None, **kwargs)[source]

Registers a transform class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class/function to be registered as a transform.

Returns:

registered transform class/function.

Return type:

class or Callable

Note

The transform will be registered as an option. In a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the registered class/function by specifying either its name or the key you used for the registration.

If you want to register the class with a key of your choice, add key to the decorator as below:

>>> from torch import nn
>>> from torchdistill.datasets.registry import register_transform
>>> @register_transform(key='my_custom_transform')
... class CustomTransform(nn.Module):
...     def __init__(self, **kwargs):
...         print('This is my custom transform class')

In the example, CustomTransform class is registered with a key “my_custom_transform”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomTransform class by “my_custom_transform”.
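
Because arg may be a class or function, an existing callable can also be registered by passing it to register_transform directly instead of using decorator syntax; it is then selectable by its class/function name. A sketch, assuming torchvision is installed:

>>> from torchvision.transforms import RandomHorizontalFlip
>>> from torchdistill.datasets.registry import register_transform
>>> _ = register_transform(RandomHorizontalFlip)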

torchdistill.datasets.registry.register_dataset_wrapper(arg=None, **kwargs)[source]

Registers a dataset wrapper class or function to instantiate it.

Parameters:

arg (class or Callable or None) – class/function to be registered as a dataset wrapper.

Returns:

registered dataset wrapper class/function.

Return type:

class or Callable

Note

The dataset wrapper will be registered as an option. In a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the registered class/function by specifying either its name or the key you used for the registration.

If you want to register the class with a key of your choice, add key to the decorator as below:

>>> from torch.utils.data import Dataset
>>> from torchdistill.datasets.registry import register_dataset_wrapper
>>> @register_dataset_wrapper(key='my_custom_dataset_wrapper')
... class CustomDatasetWrapper(Dataset):
...     def __init__(self, **kwargs):
...         print('This is my custom dataset wrapper class')

In the example, CustomDatasetWrapper class is registered with a key “my_custom_dataset_wrapper”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomDatasetWrapper class by “my_custom_dataset_wrapper”.

torchdistill.datasets.registry.get_dataset(key)[source]

Gets a registered dataset class or function to instantiate it.

Parameters:

key (str) – unique key to identify the registered dataset class/function.

Returns:

registered dataset class or function to instantiate it.

Return type:

class or Callable

torchdistill.datasets.registry.get_collate_func(key)[source]

Gets a registered collate function.

Parameters:

key (str) – unique key to identify the registered collate function.

Returns:

registered collate function.

Return type:

Callable

torchdistill.datasets.registry.get_sample_loader(key)[source]

Gets a registered sample loader class or function to instantiate it.

Parameters:

key (str) – unique key to identify the registered sample loader class or function to instantiate it.

Returns:

registered sample loader class or function to instantiate it.

Return type:

class or Callable

torchdistill.datasets.registry.get_batch_sampler(key)[source]

Gets a registered batch sampler class or function to instantiate it.

Parameters:

key (str) – unique key to identify the registered batch sampler class or function to instantiate it.

Returns:

registered batch sampler class or function to instantiate it.

Return type:

class or Callable

torchdistill.datasets.registry.get_transform(key)[source]

Gets a registered transform class or function to instantiate it.

Parameters:

key (str) – unique key to identify the registered transform class or function to instantiate it.

Returns:

registered transform class or function to instantiate it.

Return type:

class or Callable

torchdistill.datasets.registry.get_dataset_wrapper(key)[source]

Gets a registered dataset wrapper class or function to instantiate it.

Parameters:

key (str) – unique key to identify the registered dataset wrapper class or function to instantiate it.

Returns:

registered dataset wrapper class or function to instantiate it.

Return type:

class or Callable


torchdistill.datasets.sample_loader

class torchdistill.datasets.sample_loader.JpegCompressionLoader(jpeg_quality=None)[source]

A sample loader with JPEG compression.

Parameters:

jpeg_quality (int) – quality for JPEG compression with PIL.
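
A minimal instantiation sketch; the quality value 75 is an arbitrary choice, and how the loader is applied to individual samples (e.g., image file paths) depends on the torchdistill version:

>>> from torchdistill.datasets.sample_loader import JpegCompressionLoader
>>> sample_loader = JpegCompressionLoader(jpeg_quality=75)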


torchdistill.datasets.util

torchdistill.datasets.util.split_dataset(dataset, lengths=None, generator_seed=None, sub_splits_configs=None, dataset_id=None)[source]

Randomly splits dataset into sub-datasets.

Parameters:
  • dataset (torch.utils.data.Dataset) – dataset to be split.

  • lengths (list[int] or None) – length ratios, e.g., (9, 1), which is used by default if None.

  • generator_seed (int or None) – random seed for torch.Generator().manual_seed().

  • sub_splits_configs (list[dict] or None) – sub-split configurations.

  • dataset_id (str or None) – dataset ID to be printed just for debugging purposes.

Returns:

sub-splits of dataset.

Return type:

list[torch.utils.data.Subset]
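
A minimal sketch of a 9:1 random split with a fixed seed; the TensorDataset here is just a stand-in for a real dataset:

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from torchdistill.datasets.util import split_dataset
>>> toy_dataset = TensorDataset(torch.arange(100).float())
>>> sub_splits = split_dataset(toy_dataset, lengths=[9, 1], generator_seed=42)
>>> # sub_splits is a list of torch.utils.data.Subset, here roughly a 90/10 split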

torchdistill.datasets.util.build_data_loader(dataset, data_loader_config, distributed, accelerator=None)[source]

Builds a data loader for dataset.

Parameters:
  • dataset (torch.utils.data.Dataset) – dataset.

  • data_loader_config (dict) – data loader configuration.

  • distributed (bool) – whether to be in distributed training mode.

  • accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.

Returns:

data loader.

Return type:

torch.utils.data.DataLoader
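
A sketch of a non-distributed call; the keys inside data_loader_config shown here (batch_size, num_workers) are assumptions, since the accepted schema depends on the torchdistill version and its configuration format:

>>> from torchdistill.datasets.util import build_data_loader
>>> data_loader_config = {'batch_size': 32, 'num_workers': 4}  # hypothetical keys
>>> # data_loader = build_data_loader(some_dataset, data_loader_config, distributed=False)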

torchdistill.datasets.util.build_data_loaders(dataset_dict, data_loader_configs, distributed, accelerator=None)[source]

Builds data loaders for dataset_dict.

Parameters:
  • dataset_dict (dict) – dict of datasets keyed by dataset ID.

  • data_loader_configs (list[dict]) – data loader configurations.

  • distributed (bool) – whether to be in distributed training mode.

  • accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.

Returns:

data loaders.

Return type:

list[torch.utils.data.DataLoader]


torchdistill.datasets.wrapper

torchdistill.datasets.wrapper.default_idx2subpath(index)[source]

Converts an index into a file path that includes a parent directory name, which consists of the last four digits of the index.

Parameters:

index (int) – index.

Returns:

file path with a parent directory.

Return type:

str
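
For example (the exact output format depends on the implementation):

>>> from torchdistill.datasets.wrapper import default_idx2subpath
>>> sub_file_path = default_idx2subpath(123456)
>>> # a relative file path whose parent directory name is derived from the last four digits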

class torchdistill.datasets.wrapper.BaseDatasetWrapper(org_dataset)[source]

A base dataset wrapper. This is a subclass of torch.utils.data.Dataset.

Parameters:

org_dataset (torch.utils.data.Dataset) – original dataset to be wrapped.

class torchdistill.datasets.wrapper.CacheableDataset(org_dataset, cache_dir_path, idx2subpath_func=None, ext='.pt')[source]

A dataset wrapper that additionally loads cached files from cache_dir_path if they exist.

Parameters:
  • org_dataset (torch.utils.data.Dataset) – original dataset to be wrapped.

  • cache_dir_path (str) – cache directory path.

  • idx2subpath_func (Callable or None) – function to convert a sample index to a file path.

  • ext (str) – cache file extension.
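
A minimal wrapping sketch reusing default_idx2subpath above; the toy TensorDataset and cache directory path are arbitrary stand-ins:

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from torchdistill.datasets.wrapper import CacheableDataset, default_idx2subpath
>>> org_dataset = TensorDataset(torch.randn(8, 3))
>>> cacheable_dataset = CacheableDataset(org_dataset, './cache', idx2subpath_func=default_idx2subpath)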

class torchdistill.datasets.wrapper.CRDDatasetWrapper(org_dataset, num_negative_samples, mode, ratio)[source]

A dataset wrapper for Contrastive Representation Distillation (CRD).

Yonglong Tian, Dilip Krishnan, Phillip Isola: “Contrastive Representation Distillation” @ ICLR 2020 (2020)

Parameters:
  • org_dataset (torch.utils.data.Dataset) – original dataset to be wrapped.

  • num_negative_samples (int) – number of negative samples for CRD.

  • mode (str) – either ‘exact’ or ‘relax’.

  • ratio (float) – ratio of class-wise negative samples.
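
A construction sketch; the argument values below (4096 negative samples, 'exact' mode, ratio 1.0) are illustrative assumptions rather than recommended settings, and org_dataset stands in for a labeled dataset:

>>> from torchdistill.datasets.wrapper import CRDDatasetWrapper
>>> # crd_dataset = CRDDatasetWrapper(org_dataset, num_negative_samples=4096,
>>> #                                 mode='exact', ratio=1.0)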

class torchdistill.datasets.wrapper.SSKDDatasetWrapper(org_dataset)[source]

A dataset wrapper for Self-Supervised Knowledge Distillation (SSKD).

Guodong Xu, Ziwei Liu, Xiaoxiao Li, Chen Change Loy: “Knowledge Distillation Meets Self-Supervision” @ ECCV 2020 (2020)

Parameters:

org_dataset (torch.utils.data.Dataset) – original dataset to be wrapped.