Source code for torchdistill.datasets.util

import copy

import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.data.distributed import DistributedSampler

from ..common.constant import def_logger
from ..datasets.registry import get_collate_func, get_batch_sampler, get_dataset_wrapper
from ..datasets.wrapper import default_idx2subpath, BaseDatasetWrapper, CacheableDataset

logger = def_logger.getChild(__name__)


def split_dataset(dataset, lengths=None, generator_seed=None, sub_splits_configs=None, dataset_id=None):
    """
    Randomly splits ``dataset`` into sub datasets.

    :param dataset: dataset to be split.
    :type dataset: torch.utils.data.Dataset
    :param lengths: length ratios of the splits, e.g., (9, 1); defaults to (9, 1) if None.
    :type lengths: list[int] or None
    :param generator_seed: random seed for :meth:`torch.Generator.manual_seed`.
    :type generator_seed: int or None
    :param sub_splits_configs: sub-split configurations.
    :type sub_splits_configs: list[dict] or None
    :param dataset_id: dataset ID, printed only for debugging purposes.
    :type dataset_id: str or None
    :return: sub-splits of ``dataset``.
    :rtype: list[torch.utils.data.Subset]
    """
    org_dataset_length = len(dataset)
    if dataset_id is not None:
        logger.info('Splitting `{}` dataset ({} samples in total)'.format(dataset_id, org_dataset_length))

    if lengths is None:
        lengths = (9, 1)

    total_length = sum(lengths)
    if total_length != org_dataset_length:
        # Interpret `lengths` as ratios and convert them to absolute sample counts
        lengths = [int((l / total_length) * org_dataset_length) for l in lengths]
        if len(lengths) > 1 and sum(lengths) != org_dataset_length:
            lengths[-1] = org_dataset_length - sum(lengths[:-1])

    sub_datasets = random_split(dataset, lengths) if generator_seed is None \
        else random_split(dataset, lengths, generator=torch.Generator().manual_seed(generator_seed))
    if sub_splits_configs is None:
        return sub_datasets

    # Deep-copy the underlying dataset so transforms can be configured independently per sub-split,
    # as Subset only keeps a shallow reference to the original dataset
    for sub_dataset in sub_datasets:
        sub_dataset.dataset = copy.deepcopy(sub_dataset.dataset)

    assert len(sub_datasets) == len(sub_splits_configs), \
        'len(lengths) `{}` should be equal to len(sub_splits_configs) `{}`'.format(len(sub_datasets),
                                                                                   len(sub_splits_configs))
    for sub_dataset, sub_split_kwargs in zip(sub_datasets, sub_splits_configs):
        sub_split_kwargs = sub_split_kwargs.copy()
        transform = sub_split_kwargs.pop('transform', None)
        target_transform = sub_split_kwargs.pop('target_transform', None)
        transforms = sub_split_kwargs.pop('transforms', None)
        if hasattr(sub_dataset.dataset, 'transform') and transform is not None:
            sub_dataset.dataset.transform = transform
        if hasattr(sub_dataset.dataset, 'target_transform') and target_transform is not None:
            sub_dataset.dataset.target_transform = target_transform
        if hasattr(sub_dataset.dataset, 'transforms') and transforms is not None:
            sub_dataset.dataset.transforms = transforms
    return sub_datasets
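
A minimal usage sketch for ``split_dataset`` (illustrative only, not part of this module): it assumes a torchvision-style dataset exposing a ``transform`` attribute, here CIFAR-10, and splits it 9:1 into train/validation subsets with a different transform per sub-split.

# Illustrative usage sketch; assumes torchvision is installed and the paths/values are arbitrary
from torchvision import datasets, transforms as T
from torchdistill.datasets.util import split_dataset

base_dataset = datasets.CIFAR10(root='./data', train=True, download=True)
train_subset, val_subset = split_dataset(
    base_dataset,
    lengths=(9, 1),
    generator_seed=42,
    sub_splits_configs=[
        # training split: augmentation + tensor conversion
        {'transform': T.Compose([T.RandomHorizontalFlip(), T.ToTensor()])},
        # validation split: tensor conversion only
        {'transform': T.ToTensor()}
    ],
    dataset_id='cifar10/train'
)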
def build_data_loader(dataset, data_loader_config, distributed, accelerator=None):
    """
    Builds a data loader for ``dataset``.

    :param dataset: dataset.
    :type dataset: torch.utils.data.Dataset
    :param data_loader_config: data loader configuration.
    :type data_loader_config: dict
    :param distributed: whether to be in distributed training mode.
    :type distributed: bool
    :param accelerator: Hugging Face accelerator.
    :type accelerator: accelerate.Accelerator or None
    :return: data loader.
    :rtype: torch.utils.data.DataLoader
    """
    cache_dir_path = data_loader_config.get('cache_output', None)
    dataset_wrapper_config = data_loader_config.get('dataset_wrapper', None)
    if isinstance(dataset_wrapper_config, dict) and len(dataset_wrapper_config) > 0:
        dataset_wrapper_args = dataset_wrapper_config.get('args', None)
        dataset_wrapper_kwargs = dataset_wrapper_config.get('kwargs', None)
        if dataset_wrapper_args is None:
            dataset_wrapper_args = list()
        if dataset_wrapper_kwargs is None:
            dataset_wrapper_kwargs = dict()
        dataset_wrapper_cls_or_func = get_dataset_wrapper(dataset_wrapper_config['key'])
        dataset = dataset_wrapper_cls_or_func(dataset, *dataset_wrapper_args, **dataset_wrapper_kwargs)
    elif cache_dir_path is not None:
        dataset = CacheableDataset(dataset, cache_dir_path, idx2subpath_func=default_idx2subpath)
    elif data_loader_config.get('requires_supp', False):
        dataset = BaseDatasetWrapper(dataset)

    sampler_config = data_loader_config.get('sampler', dict())
    sampler_kwargs = sampler_config.get('kwargs', None)
    if sampler_kwargs is None:
        sampler_kwargs = dict()

    if distributed and accelerator is None:
        sampler = DistributedSampler(dataset, **sampler_kwargs)
    else:
        sampler_cls_or_func = sampler_config['class_or_func']
        sampler = sampler_cls_or_func(dataset, **sampler_kwargs)

    batch_sampler_config = data_loader_config.get('batch_sampler', None)
    batch_sampler_cls_or_func = None if batch_sampler_config is None \
        else get_batch_sampler(batch_sampler_config['key'])
    batch_sampler = None if batch_sampler_cls_or_func is None \
        else batch_sampler_cls_or_func(sampler, **batch_sampler_config['kwargs'])
    collate_fn = get_collate_func(data_loader_config.get('collate_fn', None))
    data_loader_kwargs = data_loader_config['kwargs']
    if batch_sampler is not None:
        return DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, **data_loader_kwargs)
    return DataLoader(dataset, sampler=sampler, collate_fn=collate_fn, **data_loader_kwargs)
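
A minimal configuration sketch for ``build_data_loader`` (illustrative only): the key names mirror what the function reads above, while the sampler choice, batch size, and worker count are arbitrary values for a non-distributed run. It reuses ``train_subset`` from the sketch above.

# Illustrative usage sketch; config keys follow what build_data_loader reads above
import torch
from torchdistill.datasets.util import build_data_loader

train_data_loader_config = {
    # non-distributed run: 'class_or_func' is instantiated as class_or_func(dataset, **kwargs)
    'sampler': {'class_or_func': torch.utils.data.RandomSampler, 'kwargs': None},
    # forwarded verbatim to torch.utils.data.DataLoader
    'kwargs': {'batch_size': 32, 'num_workers': 4, 'drop_last': True}
}
train_loader = build_data_loader(train_subset, train_data_loader_config, distributed=False)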
def build_data_loaders(dataset_dict, data_loader_configs, distributed, accelerator=None):
    """
    Builds data loaders for ``dataset_dict``.

    :param dataset_dict: dict of datasets keyed by dataset ID.
    :type dataset_dict: dict
    :param data_loader_configs: data loader configurations.
    :type data_loader_configs: list[dict]
    :param distributed: whether to be in distributed training mode.
    :type distributed: bool
    :param accelerator: Hugging Face accelerator.
    :type accelerator: accelerate.Accelerator or None
    :return: data loaders.
    :rtype: list[torch.utils.data.DataLoader]
    """
    data_loader_list = list()
    for data_loader_config in data_loader_configs:
        dataset_id = data_loader_config.get('dataset_id', None)
        data_loader = None if dataset_id is None or dataset_id not in dataset_dict \
            else build_data_loader(dataset_dict[dataset_id], data_loader_config, distributed, accelerator)
        data_loader_list.append(data_loader)
    return data_loader_list
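
A short illustrative sketch for ``build_data_loaders`` (the dataset IDs and config values are hypothetical, building on the sketches above): each config is matched to a dataset via ``dataset_id``, and ``None`` is returned in place of a loader for any config whose ID is missing from ``dataset_dict``.

# Illustrative usage sketch; 'cifar10/train' and 'cifar10/val' are hypothetical dataset IDs
from torchdistill.datasets.util import build_data_loaders

dataset_dict = {'cifar10/train': train_subset, 'cifar10/val': val_subset}
data_loader_configs = [
    {'dataset_id': 'cifar10/train',
     'sampler': {'class_or_func': torch.utils.data.RandomSampler, 'kwargs': None},
     'kwargs': {'batch_size': 32, 'num_workers': 4}},
    {'dataset_id': 'cifar10/val',
     'sampler': {'class_or_func': torch.utils.data.SequentialSampler, 'kwargs': None},
     'kwargs': {'batch_size': 128, 'num_workers': 4}}
]
train_loader, val_loader = build_data_loaders(dataset_dict, data_loader_configs, distributed=False)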