torchdistill.datasets
torchdistill.datasets.registry
- torchdistill.datasets.registry.register_dataset(arg=None, **kwargs)[source]
Registers a dataset class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class or function to be registered as a dataset.
- Returns:
registered dataset class or function to instantiate it.
- Return type:
class or Callable
Note
The dataset will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class/function with a key of your choice, add key to the decorator as below:
>>> from torch.utils.data import Dataset
>>> from torchdistill.datasets.registry import register_dataset
>>> @register_dataset(key='my_custom_dataset')
... class CustomDataset(Dataset):
...     def __init__(self, **kwargs):
...         print('This is my custom dataset class')
In the example, the CustomDataset class is registered with the key “my_custom_dataset”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomDataset class by “my_custom_dataset”.
- torchdistill.datasets.registry.register_collate_func(arg=None, **kwargs)[source]
Registers a collate function.
- Parameters:
arg (Callable or None) – function to be registered as a collate function.
- Returns:
registered function.
- Return type:
Callable
Note
The collate function will be registered as an option. You can choose the registered function by specifying the name of the function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the function with a key of your choice, add key to the decorator as below:
>>> from torchdistill.datasets.registry import register_collate_func
>>>
>>> @register_collate_func(key='my_custom_collate')
... def custom_collate(batch, label):
...     print('This is my custom collate function')
...     return batch, label
In the example, the custom_collate function is registered with the key “my_custom_collate”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the custom_collate function by “my_custom_collate”.
- torchdistill.datasets.registry.register_sample_loader(arg=None, **kwargs)[source]
Registers a sample loader class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class or function to be registered as a sample loader.
- Returns:
registered sample loader class or function to instantiate it.
- Return type:
class or Callable
Note
The sample loader will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class with a key of your choice, add key to the decorator as below:
>>> from torch.utils.data import Sampler
>>> from torchdistill.datasets.registry import register_sample_loader
>>> @register_sample_loader(key='my_custom_sample_loader')
... class CustomSampleLoader(Sampler):
...     def __init__(self, **kwargs):
...         print('This is my custom sample loader class')
In the example, the CustomSampleLoader class is registered with the key “my_custom_sample_loader”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomSampleLoader class by “my_custom_sample_loader”.
- torchdistill.datasets.registry.register_batch_sampler(arg=None, **kwargs)[source]
Registers a batch sampler class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class or function to be registered as a batch sampler.
- Returns:
registered batch sampler class or function to instantiate it.
- Return type:
class or Callable
Note
The batch sampler will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class with a key of your choice, add key to the decorator as below:
>>> from torch.utils.data import Sampler
>>> from torchdistill.datasets.registry import register_batch_sampler
>>> @register_batch_sampler(key='my_custom_batch_sampler')
... class CustomBatchSampler(Sampler):
...     def __init__(self, **kwargs):
...         print('This is my custom batch sampler class')
In the example, the CustomBatchSampler class is registered with the key “my_custom_batch_sampler”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomBatchSampler class by “my_custom_batch_sampler”.
- torchdistill.datasets.registry.register_transform(arg=None, **kwargs)[source]
Registers a transform class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class/function to be registered as a transform.
- Returns:
registered transform class/function.
- Return type:
class or Callable
Note
The transform will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class with a key of your choice, add key to the decorator as below:
>>> from torch import nn
>>> from torchdistill.datasets.registry import register_transform
>>> @register_transform(key='my_custom_transform')
... class CustomTransform(nn.Module):
...     def __init__(self, **kwargs):
...         print('This is my custom transform class')
In the example, the CustomTransform class is registered with the key “my_custom_transform”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomTransform class by “my_custom_transform”.
- torchdistill.datasets.registry.register_dataset_wrapper(arg=None, **kwargs)[source]
Registers a dataset wrapper class or function to instantiate it.
- Parameters:
arg (class or Callable or None) – class/function to be registered as a dataset wrapper.
- Returns:
registered dataset wrapper class/function.
- Return type:
class or Callable
Note
The dataset wrapper will be registered as an option. You can choose the registered class/function by specifying the name of the class/function or key you used for the registration, in a training configuration used for torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox.
If you want to register the class with a key of your choice, add key to the decorator as below:
>>> from torch.utils.data import Dataset
>>> from torchdistill.datasets.registry import register_dataset_wrapper
>>> @register_dataset_wrapper(key='my_custom_dataset_wrapper')
... class CustomDatasetWrapper(Dataset):
...     def __init__(self, **kwargs):
...         print('This is my custom dataset wrapper class')
In the example, the CustomDatasetWrapper class is registered with the key “my_custom_dataset_wrapper”. When you configure torchdistill.core.distillation.DistillationBox or torchdistill.core.training.TrainingBox, you can choose the CustomDatasetWrapper class by “my_custom_dataset_wrapper”.
- torchdistill.datasets.registry.get_dataset(key)[source]
Gets a registered dataset class or function to instantiate it.
- Parameters:
key (str) – unique key to identify the registered dataset class/function.
- Returns:
registered dataset class or function to instantiate it.
- Return type:
class or Callable
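As a minimal sketch, assuming the CustomDataset registration shown in the register_dataset note above has already been executed, the registered class can be retrieved and instantiated by its key:
>>> from torchdistill.datasets.registry import get_dataset
>>> dataset_cls = get_dataset('my_custom_dataset')
>>> dataset = dataset_cls()
This is my custom dataset class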
- torchdistill.datasets.registry.get_collate_func(key)[source]
Gets a registered collate function.
- Parameters:
key (str or None) – unique key to identify the registered collate function.
- Returns:
registered collate function.
- Return type:
Callable
- torchdistill.datasets.registry.get_sample_loader(key)[source]
Gets a registered sample loader class or function to instantiate it.
- Parameters:
key (str) – unique key to identify the registered sample loader class or function to instantiate it.
- Returns:
registered sample loader class or function to instantiate it.
- Return type:
class or Callable
- torchdistill.datasets.registry.get_batch_sampler(key)[source]
Gets a registered batch sampler class or function to instantiate it.
- Parameters:
key (str) – unique key to identify the registered batch sampler class or function to instantiate it.
- Returns:
registered batch sampler class or function to instantiate it.
- Return type:
class or Callable
- torchdistill.datasets.registry.get_transform(key)[source]
Gets a registered transform class or function to instantiate it.
- Parameters:
key (str) – unique key to identify the registered transform class or function to instantiate it.
- Returns:
registered transform class or function to instantiate it.
- Return type:
class or Callable
- torchdistill.datasets.registry.get_dataset_wrapper(key)[source]
Gets a registered dataset wrapper class or function to instantiate it.
- Parameters:
key (str) – unique key to identify the registered dataset wrapper class or function to instantiate it.
- Returns:
registered dataset wrapper class or function to instantiate it.
- Return type:
class or Callable
torchdistill.datasets.sample_loader
torchdistill.datasets.util
- torchdistill.datasets.util.split_dataset(dataset, lengths=None, generator_seed=None, sub_splits_configs=None, dataset_id=None)[source]
Randomly splits dataset into sub datasets.
- Parameters:
dataset (torch.utils.data.Dataset) – dataset to be split.
lengths (list[int]) – length ratios of the sub datasets e.g., (9, 1), which is used as the default if None.
generator_seed (int or None) – random seed for torch.Generator().manual_seed().
sub_splits_configs (list[dict] or None) – sub-split configurations.
dataset_id (str or None) – dataset ID to be printed just for debugging purposes.
- Returns:
sub-splits of dataset.
- Return type:
list[torch.utils.data.Subset]
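A minimal usage sketch based on the signature above; the TensorDataset and the (9, 1) ratio are illustrative assumptions, not part of the API:
>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from torchdistill.datasets.util import split_dataset
>>> dataset = TensorDataset(torch.arange(100))
>>> # with ratio (9, 1), roughly a 90/10 split of the 100 samples is expected
>>> train_split, val_split = split_dataset(dataset, lengths=(9, 1), generator_seed=42)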
- torchdistill.datasets.util.build_data_loader(dataset, data_loader_config, distributed, accelerator=None)[source]
Builds a data loader for dataset.
- Parameters:
dataset (torch.utils.data.Dataset) – dataset.
data_loader_config (dict) – data loader configuration.
distributed (bool) – whether to be in distributed training mode.
accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.
- Returns:
data loader.
- Return type:
torch.utils.data.DataLoader
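A hedged sketch of a call; the config keys below are assumptions for illustration only, and the exact schema of data_loader_config should be taken from torchdistill's example configurations:
>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from torchdistill.datasets.util import build_data_loader
>>> dataset = TensorDataset(torch.arange(100))
>>> # hypothetical config keys; consult torchdistill example configs for the real schema
>>> data_loader_config = {'batch_size': 32, 'num_workers': 4}
>>> data_loader = build_data_loader(dataset, data_loader_config, distributed=False)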
- torchdistill.datasets.util.build_data_loaders(dataset_dict, data_loader_configs, distributed, accelerator=None)[source]
Builds data loaders for dataset_dict.
- Parameters:
dataset_dict (dict) – dict of datasets, each tied with its dataset ID as a key.
data_loader_configs (list[dict]) – data loader configurations.
distributed (bool) – whether to be in distributed training mode.
accelerator (accelerate.Accelerator or None) – Hugging Face accelerator.
- Returns:
data loaders.
- Return type:
list[torch.utils.data.DataLoader]
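A sketch of the multi-dataset variant, reusing the sub-splits from the split_dataset sketch above; it assumes, for illustration only, that each configuration selects its dataset via a dataset_id entry matching a key in dataset_dict:
>>> from torchdistill.datasets.util import build_data_loaders
>>> dataset_dict = {'train': train_split, 'val': val_split}
>>> # the 'dataset_id' matching and the other keys are assumptions for illustration
>>> data_loader_configs = [
...     {'dataset_id': 'train', 'batch_size': 32, 'num_workers': 4},
...     {'dataset_id': 'val', 'batch_size': 128, 'num_workers': 4}
... ]
>>> train_loader, val_loader = build_data_loaders(dataset_dict, data_loader_configs, distributed=False)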
torchdistill.datasets.wrapper
- torchdistill.datasets.wrapper.default_idx2subpath(index)[source]
Converts an index to a file path that includes a parent directory name, which consists of the last four digits of the index.
- Parameters:
index (int) – index.
- Returns:
file path with a parent directory.
- Return type:
str
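For example (the exact string below is illustrative of the described behavior rather than a guaranteed output):
>>> from torchdistill.datasets.wrapper import default_idx2subpath
>>> default_idx2subpath(123456)  # parent dir from the last four digits, e.g. '3456/123456'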
- class torchdistill.datasets.wrapper.BaseDatasetWrapper(org_dataset)[source]
A base dataset wrapper. This is a subclass of torch.utils.data.Dataset.
- Parameters:
org_dataset (torch.utils.data.Dataset) – original dataset to be wrapped.
- class torchdistill.datasets.wrapper.CacheableDataset(org_dataset, cache_dir_path, idx2subpath_func=None, ext='.pt')[source]
A dataset wrapper that additionally loads cached files from cache_dir_path if they exist.
- Parameters:
org_dataset (torch.utils.data.Dataset) – original dataset to be wrapped.
cache_dir_path (str) – cache directory path.
idx2subpath_func (Callable or None) – function to convert a sample index to a file path.
ext (str) – cache file extension.
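A minimal instantiation sketch following the signature above; org_dataset is a placeholder for any torch.utils.data.Dataset, and the cache directory path is illustrative:
>>> from torchdistill.datasets.wrapper import CacheableDataset
>>> # org_dataset is a placeholder; cached files under './cache' are loaded if they exist
>>> cacheable_dataset = CacheableDataset(org_dataset, cache_dir_path='./cache', ext='.pt')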
- class torchdistill.datasets.wrapper.CRDDatasetWrapper(org_dataset, num_negative_samples, mode, ratio)[source]
A dataset wrapper for Contrastive Representation Distillation (CRD).
Yonglong Tian, Dilip Krishnan, Phillip Isola: “Contrastive Representation Distillation” @ ICLR 2020 (2020)
- Parameters:
org_dataset (torch.utils.data.Dataset) – original dataset to be wrapped.
num_negative_samples (int) – number of negative samples for CRD.
mode (str) – either ‘exact’ or ‘relax’.
ratio (float) – ratio of class-wise negative samples.
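A minimal instantiation sketch following the signature above; org_dataset is a placeholder and the argument values are illustrative, not recommendations:
>>> from torchdistill.datasets.wrapper import CRDDatasetWrapper
>>> # org_dataset is a placeholder; values below are illustrative only
>>> crd_dataset = CRDDatasetWrapper(org_dataset, num_negative_samples=4096, mode='exact', ratio=1.0)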
- class torchdistill.datasets.wrapper.SSKDDatasetWrapper(org_dataset)[source]
A dataset wrapper for Self-Supervised Knowledge Distillation (SSKD).
Guodong Xu, Ziwei Liu, Xiaoxiao Li, Chen Change Loy: “Knowledge Distillation Meets Self-Supervision” @ ECCV 2020 (2020)
- Parameters:
org_dataset (torch.utils.data.Dataset) – original dataset to be wrapped.