Source code for torchdistill.core.distillation

import copy

import torch
from torch import nn

from .interfaces.post_epoch_proc import default_post_epoch_process_with_teacher
from .interfaces.post_forward_proc import default_post_forward_process
from .interfaces.pre_epoch_proc import default_pre_epoch_process_with_teacher
from .interfaces.pre_forward_proc import default_pre_forward_process
from .interfaces.registry import get_pre_epoch_proc_func, get_pre_forward_proc_func, get_forward_proc_func, \
    get_post_forward_proc_func, get_post_epoch_proc_func
from .util import set_hooks, wrap_model, change_device, tensor2numpy2tensor, extract_io_dict, update_io_dict, \
    extract_sub_model_io_dict
from ..common.constant import SELF_MODULE_PATH, def_logger
from ..common.file_util import make_parent_dirs
from ..common.main_util import load_ckpt, save_on_master
from ..common.module_util import check_if_wrapped, freeze_module_params, get_module, \
    unfreeze_module_params, get_updatable_param_names
from ..datasets.util import build_data_loaders
from ..losses.registry import get_high_level_loss, get_func2extract_model_output
from ..models.util import redesign_model
from ..models.wrapper import AuxiliaryModelWrapper, build_auxiliary_model_wrapper
from ..optim.registry import get_optimizer, get_scheduler

logger = def_logger.getChild(__name__)


class DistillationBox(object):
    """
    A single-stage knowledge distillation framework.

    :param teacher_model: teacher model.
    :type teacher_model: nn.Module
    :param student_model: student model.
    :type student_model: nn.Module
    :param dataset_dict: dict that contains datasets with IDs of your choice.
    :type dataset_dict: dict
    :param train_config: training configuration.
    :type train_config: dict
    :param device: target device.
    :type device: torch.device
    :param device_ids: target device IDs.
    :type device_ids: list[int]
    :param distributed: whether to be in distributed training mode.
    :type distributed: bool
    :param lr_factor: multiplier for learning rate.
    :type lr_factor: float or int
    :param accelerator: Hugging Face accelerator.
    :type accelerator: accelerate.Accelerator or None
    """

    def setup_data_loaders(self, train_config):
        """
        Sets up training and validation data loaders for the current training stage.

        This method will be internally called when instantiating this class and when calling
        :meth:`MultiStagesDistillationBox.advance_to_next_stage`.

        :param train_config: training configuration.
        :type train_config: dict
        """
        train_data_loader_config = train_config.get('train_data_loader', dict())
        if 'requires_supp' not in train_data_loader_config:
            train_data_loader_config['requires_supp'] = True

        val_data_loader_config = train_config.get('val_data_loader', dict())
        train_data_loader, val_data_loader = \
            build_data_loaders(self.dataset_dict, [train_data_loader_config, val_data_loader_config],
                               self.distributed, self.accelerator)
        if train_data_loader is not None:
            self.train_data_loader = train_data_loader

        if val_data_loader is not None:
            self.val_data_loader = val_data_loader

    def setup_teacher_student_models(self, teacher_config, student_config):
        """
        Sets up teacher and student models for the current training stage.

        This method will be internally called when instantiating this class and when calling
        :meth:`MultiStagesDistillationBox.advance_to_next_stage`.

        :param teacher_config: teacher configuration.
        :type teacher_config: dict
        :param student_config: student configuration.
        :type student_config: dict
        """
        unwrapped_org_teacher_model = \
            self.org_teacher_model.module if check_if_wrapped(self.org_teacher_model) else self.org_teacher_model
        unwrapped_org_student_model = \
            self.org_student_model.module if check_if_wrapped(self.org_student_model) else self.org_student_model
        self.target_teacher_pairs.clear()
        self.target_student_pairs.clear()
        teacher_ref_model = unwrapped_org_teacher_model
        student_ref_model = unwrapped_org_student_model
        if len(teacher_config) > 0 or (len(teacher_config) == 0 and self.teacher_model is None):
            logger.info('[teacher model]')
            model_type = 'original'
            auxiliary_teacher_model_wrapper = \
                build_auxiliary_model_wrapper(teacher_config, teacher_model=unwrapped_org_teacher_model,
                                              device=self.device, device_ids=self.device_ids,
                                              distributed=self.distributed)
            if auxiliary_teacher_model_wrapper is not None:
                teacher_ref_model = auxiliary_teacher_model_wrapper
                model_type = type(teacher_ref_model).__name__

            self.teacher_model = redesign_model(teacher_ref_model, teacher_config, 'teacher', model_type)
            src_teacher_ckpt_file_path = teacher_config.get('src_ckpt', None)
            if src_teacher_ckpt_file_path is not None:
                load_ckpt(src_teacher_ckpt_file_path, self.teacher_model)

        if len(student_config) > 0 or (len(student_config) == 0 and self.student_model is None):
            logger.info('[student model]')
            model_type = 'original'
            auxiliary_student_model_wrapper = \
                build_auxiliary_model_wrapper(student_config, student_model=unwrapped_org_student_model,
                                              device=self.device, device_ids=self.device_ids,
                                              distributed=self.distributed)
            if auxiliary_student_model_wrapper is not None:
                student_ref_model = auxiliary_student_model_wrapper
                model_type = type(student_ref_model).__name__

            self.student_model = redesign_model(student_ref_model, student_config, 'student', model_type)
            src_student_ckpt_file_path = student_config.get('src_ckpt', None)
            if src_student_ckpt_file_path is not None:
                load_ckpt(src_student_ckpt_file_path, self.student_model)

        self.teacher_any_frozen = \
            len(teacher_config.get('frozen_modules', list())) > 0 or not teacher_config.get('requires_grad', True)
        self.student_any_frozen = \
            len(student_config.get('frozen_modules', list())) > 0 or not student_config.get('requires_grad', True)
        self.target_teacher_pairs.extend(set_hooks(self.teacher_model, teacher_ref_model,
                                                   teacher_config, self.teacher_io_dict))
        self.target_student_pairs.extend(set_hooks(self.student_model, student_ref_model,
                                                   student_config, self.student_io_dict))
        self.teacher_forward_proc = get_forward_proc_func(teacher_config.get('forward_proc', None))
        self.student_forward_proc = get_forward_proc_func(student_config.get('forward_proc', None))

    def setup_loss(self, train_config):
        """
        Sets up a training loss module for the current training stage.

        This method will be internally called when instantiating this class and when calling
        :meth:`MultiStagesDistillationBox.advance_to_next_stage`.

        :param train_config: training configuration.
        :type train_config: dict
        """
        criterion_config = train_config['criterion']
        self.criterion = get_high_level_loss(criterion_config)
        logger.info(self.criterion)
        self.extract_model_loss = get_func2extract_model_output(criterion_config.get('func2extract_model_loss', None))

    def setup_pre_post_processes(self, train_config):
        """
        Sets up pre/post-epoch/forward processes for the current training stage.

        This method will be internally called when instantiating this class and when calling
        :meth:`MultiStagesDistillationBox.advance_to_next_stage`.

        :param train_config: training configuration.
        :type train_config: dict
        """
        pre_epoch_process = default_pre_epoch_process_with_teacher
        if 'pre_epoch_process' in train_config:
            pre_epoch_process = get_pre_epoch_proc_func(train_config['pre_epoch_process'])

        setattr(DistillationBox, 'pre_epoch_process', pre_epoch_process)
        pre_forward_process = default_pre_forward_process
        if 'pre_forward_process' in train_config:
            pre_forward_process = get_pre_forward_proc_func(train_config['pre_forward_process'])

        setattr(DistillationBox, 'pre_forward_process', pre_forward_process)
        post_forward_process = default_post_forward_process
        if 'post_forward_process' in train_config:
            post_forward_process = get_post_forward_proc_func(train_config['post_forward_process'])

        setattr(DistillationBox, 'post_forward_process', post_forward_process)
        post_epoch_process = default_post_epoch_process_with_teacher
        if 'post_epoch_process' in train_config:
            post_epoch_process = get_post_epoch_proc_func(train_config['post_epoch_process'])

        setattr(DistillationBox, 'post_epoch_process', post_epoch_process)

    def setup(self, train_config):
        """
        Configures a :class:`DistillationBox`/:class:`MultiStagesDistillationBox` for the current training stage.

        This method will be internally called when instantiating this class and when calling
        :meth:`MultiStagesDistillationBox.advance_to_next_stage`.

        :param train_config: training configuration.
        :type train_config: dict
        """
        # Set up train and val data loaders
        self.setup_data_loaders(train_config)

        # Define teacher and student models used in this stage
        teacher_config = train_config.get('teacher', dict())
        student_config = train_config.get('student', dict())
        self.setup_teacher_student_models(teacher_config, student_config)

        # Define loss function used in this stage
        self.setup_loss(train_config)

        # Freeze parameters if specified
        self.teacher_updatable = True
        if not teacher_config.get('requires_grad', True):
            logger.info('Freezing the whole teacher model')
            freeze_module_params(self.teacher_model)
            self.teacher_updatable = False

        if not student_config.get('requires_grad', True):
            logger.info('Freezing the whole student model')
            freeze_module_params(self.student_model)

        # Wrap models if necessary
        teacher_any_updatable = len(get_updatable_param_names(self.teacher_model)) > 0
        self.teacher_model = \
            wrap_model(self.teacher_model, teacher_config, self.device, self.device_ids, self.distributed,
                       self.teacher_any_frozen, teacher_any_updatable)
        student_any_updatable = len(get_updatable_param_names(self.student_model)) > 0
        self.student_model = \
            wrap_model(self.student_model, student_config, self.device, self.device_ids, self.distributed,
                       self.student_any_frozen, student_any_updatable)

        # Set up optimizer and scheduler
        optim_config = train_config.get('optimizer', dict())
        optimizer_reset = False
        if len(optim_config) > 0:
            optim_kwargs = optim_config['kwargs']
            if 'lr' in optim_kwargs:
                optim_kwargs['lr'] *= self.lr_factor

            module_wise_configs = optim_config.get('module_wise_configs', list())
            if len(module_wise_configs) > 0:
                trainable_module_list = list()
                for module_wise_config in module_wise_configs:
                    module_wise_kwargs = dict()
                    if isinstance(module_wise_config.get('kwargs', None), dict):
                        module_wise_kwargs.update(module_wise_config['kwargs'])

                    if 'lr' in module_wise_kwargs:
                        module_wise_kwargs['lr'] *= self.lr_factor

                    target_model = \
                        self.teacher_model if module_wise_config.get('is_teacher', False) else self.student_model
                    module = get_module(target_model, module_wise_config['module'])
                    module_wise_kwargs['params'] = module.parameters() if isinstance(module, nn.Module) else [module]
                    trainable_module_list.append(module_wise_kwargs)
            else:
                trainable_module_list = nn.ModuleList([self.student_model])
                if self.teacher_updatable:
                    logger.info('Note that you are training some/all of the modules in the teacher model')
                    trainable_module_list.append(self.teacher_model)

            filters_params = optim_config.get('filters_params', True)
            self.optimizer = \
                get_optimizer(trainable_module_list, optim_config['key'], **optim_kwargs,
                              filters_params=filters_params)
            self.optimizer.zero_grad()
            self.max_grad_norm = optim_config.get('max_grad_norm', None)
            self.grad_accum_step = optim_config.get('grad_accum_step', 1)
            optimizer_reset = True

        scheduler_config = train_config.get('scheduler', None)
        if scheduler_config is not None and len(scheduler_config) > 0:
            self.lr_scheduler = get_scheduler(self.optimizer, scheduler_config['key'], **scheduler_config['kwargs'])
            self.scheduling_step = scheduler_config.get('scheduling_step', 0)
        elif optimizer_reset:
            self.lr_scheduler = None
            self.scheduling_step = None

        # Set up accelerator if necessary
        if self.accelerator is not None:
            if self.teacher_updatable:
                self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
                    self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
                                             self.train_data_loader, self.val_data_loader)
            else:
                self.teacher_model = self.teacher_model.to(self.accelerator.device)
                if self.accelerator.state.use_fp16:
                    self.teacher_model = self.teacher_model.half()

                self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
                    self.accelerator.prepare(self.student_model, self.optimizer,
                                             self.train_data_loader, self.val_data_loader)

        # Set up {pre,post}-{epoch,forward} processes
        self.setup_pre_post_processes(train_config)
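
    # Illustrative sketch (an assumption, not part of this module): a single-stage train_config dict
    # covering the keys that setup() reads above. The registry keys and kwargs below ('SGD',
    # 'MultiStepLR', 'WeightedSumLoss', loader settings, etc.) are hypothetical placeholders; valid
    # values depend on the loss, optimizer, scheduler, and dataset registries used in your project.
    #
    # example_train_config = {
    #     'num_epochs': 20,
    #     'train_data_loader': {'dataset_id': 'my_train_split', 'kwargs': {'batch_size': 32, 'num_workers': 4}},
    #     'val_data_loader': {'dataset_id': 'my_val_split', 'kwargs': {'batch_size': 32, 'num_workers': 4}},
    #     'teacher': {'forward_proc': None, 'requires_grad': False, 'frozen_modules': []},
    #     'student': {'forward_proc': None, 'requires_grad': True, 'frozen_modules': []},
    #     'criterion': {'key': 'WeightedSumLoss', 'kwargs': {'sub_terms': {}}, 'func2extract_model_loss': None},
    #     'optimizer': {'key': 'SGD', 'kwargs': {'lr': 0.1, 'momentum': 0.9}, 'filters_params': True,
    #                   'grad_accum_step': 1, 'max_grad_norm': None},
    #     'scheduler': {'key': 'MultiStepLR', 'kwargs': {'milestones': [10, 15]}, 'scheduling_step': 0}
    # }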

    def __init__(self, teacher_model, student_model, dataset_dict, train_config, device, device_ids, distributed,
                 lr_factor, accelerator=None):
        # Key attributes (should not be modified)
        self.org_teacher_model = teacher_model
        self.org_student_model = student_model
        self.dataset_dict = dataset_dict
        self.device = device
        self.device_ids = device_ids
        self.distributed = distributed
        self.lr_factor = lr_factor
        self.accelerator = accelerator
        # Local attributes (can be updated at each stage)
        self.teacher_model = None
        self.student_model = None
        self.teacher_forward_proc, self.student_forward_proc = None, None
        self.target_teacher_pairs, self.target_student_pairs = list(), list()
        self.teacher_io_dict, self.student_io_dict = dict(), dict()
        self.train_data_loader, self.val_data_loader, self.optimizer, self.lr_scheduler = None, None, None, None
        self.criterion, self.extract_model_loss = None, None
        self.teacher_updatable, self.teacher_any_frozen, self.student_any_frozen = None, None, None
        self.grad_accum_step = None
        self.max_grad_norm = None
        self.scheduling_step = 0
        self.stage_grad_count = 0
        self.setup(train_config)
        self.num_epochs = train_config['num_epochs']

    def pre_epoch_process(self, *args, **kwargs):
        """
        Performs a pre-epoch process.

        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
        """
        raise NotImplementedError()

    def pre_forward_process(self, *args, **kwargs):
        """
        Performs a pre-forward process.

        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
        """
        raise NotImplementedError()

    def get_teacher_output(self, sample_batch, targets, supp_dict, **kwargs):
        """
        Gets teacher model's output.

        :param sample_batch: sample batch.
        :type sample_batch: Any
        :param targets: training targets.
        :type targets: Any
        :param supp_dict: supplementary dict.
        :type supp_dict: dict
        :return: teacher's outputs and teacher's I/O dict.
        :rtype: (Any, dict)
        """
        if supp_dict is None:
            supp_dict = dict()

        cached_data = supp_dict.get('cached_data', None)
        cache_file_paths = supp_dict.get('cache_file_path', None)
        teacher_outputs = None
        cached_extracted_teacher_output_dict = None
        # Use cached data if available
        if cached_data is not None and isinstance(cached_data, dict):
            teacher_outputs = cached_data['teacher_outputs']
            cached_extracted_teacher_output_dict = cached_data['extracted_outputs']
            if self.device.type != 'cpu':
                teacher_outputs = change_device(teacher_outputs, self.device)
                cached_extracted_teacher_output_dict = change_device(cached_extracted_teacher_output_dict, self.device)

            if not self.teacher_updatable:
                return teacher_outputs, cached_extracted_teacher_output_dict

        # If no cached data
        if teacher_outputs is None:
            if self.teacher_updatable:
                teacher_outputs = \
                    self.teacher_forward_proc(self.teacher_model, sample_batch, targets, supp_dict, **kwargs)
            else:
                with torch.no_grad():
                    teacher_outputs = \
                        self.teacher_forward_proc(self.teacher_model, sample_batch, targets, supp_dict, **kwargs)

        if cached_extracted_teacher_output_dict is not None:
            if isinstance(self.teacher_model, AuxiliaryModelWrapper) or \
                    (check_if_wrapped(self.teacher_model)
                     and isinstance(self.teacher_model.module, AuxiliaryModelWrapper)):
                self.teacher_io_dict.update(cached_extracted_teacher_output_dict)
                if isinstance(self.teacher_model, AuxiliaryModelWrapper):
                    self.teacher_model.secondary_forward(self.teacher_io_dict)

            extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
            return teacher_outputs, extracted_teacher_io_dict

        # Deep copy of teacher I/O dict if the teacher is updatable and cache file paths are given
        teacher_io_dict4cache = copy.deepcopy(self.teacher_io_dict) \
            if self.teacher_updatable and isinstance(cache_file_paths, (list, tuple)) else None
        extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
        extracted_teacher_io_dict[SELF_MODULE_PATH]['output'] = teacher_outputs
        if isinstance(self.teacher_model, AuxiliaryModelWrapper):
            self.teacher_model.secondary_forward(extracted_teacher_io_dict)

        update_io_dict(extracted_teacher_io_dict, extract_io_dict(self.teacher_io_dict, self.device))
        # Write cache files if output file paths (cache_file_paths) are given
        if isinstance(cache_file_paths, (list, tuple)):
            if teacher_io_dict4cache is None:
                teacher_io_dict4cache = extracted_teacher_io_dict

            cpu_device = torch.device('cpu')
            for i, (teacher_output, cache_file_path) \
                    in enumerate(zip(teacher_outputs.cpu().numpy(), cache_file_paths)):
                sub_dict = extract_sub_model_io_dict(teacher_io_dict4cache, i)
                sub_dict = tensor2numpy2tensor(sub_dict, cpu_device)
                cache_dict = {'teacher_outputs': torch.Tensor(teacher_output), 'extracted_outputs': sub_dict}
                make_parent_dirs(cache_file_path)
                torch.save(cache_dict, cache_file_path)
        return teacher_outputs, extracted_teacher_io_dict
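
    # Illustrative note (an assumption about the caching contract read by get_teacher_output above):
    # a dataset or loader that supports teacher-output caching is expected to supply supp_dict entries
    # shaped roughly as follows; the file paths below are hypothetical placeholders.
    #
    #   supp_dict = {'cache_file_path': ['./cache/sample_0.pt', './cache/sample_1.pt']}  # triggers cache writing
    #   supp_dict = {'cached_data': {'teacher_outputs': cached_logits,                   # reuses cached outputs
    #                                'extracted_outputs': cached_io_dict}}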

    def forward_process(self, sample_batch, targets=None, supp_dict=None, **kwargs):
        """
        Performs forward computations for teacher and student models.

        :param sample_batch: sample batch.
        :type sample_batch: Any
        :param targets: training targets.
        :type targets: Any
        :param supp_dict: supplementary dict.
        :type supp_dict: dict
        :return: loss tensor.
        :rtype: torch.Tensor
        """
        teacher_outputs, extracted_teacher_io_dict = \
            self.get_teacher_output(sample_batch=sample_batch, targets=targets, supp_dict=supp_dict, **kwargs)
        student_outputs = self.student_forward_proc(self.student_model, sample_batch, targets, supp_dict, **kwargs)
        extracted_student_io_dict = extract_io_dict(self.student_io_dict, self.device)
        extracted_student_io_dict[SELF_MODULE_PATH]['output'] = student_outputs
        if isinstance(self.student_model, AuxiliaryModelWrapper):
            self.student_model.secondary_forward(extracted_student_io_dict)

        model_loss_dict = self.extract_model_loss(student_outputs, targets, supp_dict=supp_dict)
        update_io_dict(extracted_student_io_dict, extract_io_dict(self.student_io_dict, self.device))
        io_dict = {'teacher': extracted_teacher_io_dict, 'student': extracted_student_io_dict}
        total_loss = self.criterion(io_dict, model_loss_dict, targets)
        return total_loss

    def post_forward_process(self, *args, **kwargs):
        """
        Performs a post-forward process.

        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
        """
        raise NotImplementedError()

    def post_epoch_process(self, *args, **kwargs):
        """
        Performs a post-epoch process.

        This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
        """
        raise NotImplementedError()

    def clean_modules(self):
        """
        Unfreezes all the teacher and student modules, clears I/O dicts, unregisters forward hook handles,
        and clears the handle lists.
        """
        unfreeze_module_params(self.org_teacher_model)
        unfreeze_module_params(self.org_student_model)
        self.teacher_io_dict.clear()
        self.student_io_dict.clear()
        for _, module_handle in self.target_teacher_pairs + self.target_student_pairs:
            module_handle.remove()

        self.target_teacher_pairs.clear()
        self.target_student_pairs.clear()


class MultiStagesDistillationBox(DistillationBox):
    """
    A multi-stage knowledge distillation framework. This is a subclass of :class:`DistillationBox`.

    :param teacher_model: teacher model.
    :type teacher_model: nn.Module
    :param student_model: student model.
    :type student_model: nn.Module
    :param dataset_dict: dict that contains datasets with IDs of your choice.
    :type dataset_dict: dict
    :param train_config: training configuration.
    :type train_config: dict
    :param device: target device.
    :type device: torch.device
    :param device_ids: target device IDs.
    :type device_ids: list[int]
    :param distributed: whether to be in distributed training mode.
    :type distributed: bool
    :param lr_factor: multiplier for learning rate.
    :type lr_factor: float or int
    :param accelerator: Hugging Face accelerator.
    :type accelerator: accelerate.Accelerator or None
    """
    def __init__(self, teacher_model, student_model, dataset_dict, train_config, device, device_ids, distributed,
                 lr_factor, accelerator=None):
        stage1_config = train_config['stage1']
        super().__init__(teacher_model, student_model, dataset_dict, stage1_config, device, device_ids,
                         distributed, lr_factor, accelerator)
        self.train_config = train_config
        self.stage_number = 1
        self.stage_end_epoch = stage1_config['num_epochs']
        self.num_epochs = sum(train_config[key]['num_epochs'] for key in train_config.keys()
                              if key.startswith('stage'))
        self.current_epoch = 0
        logger.info('Started stage {}'.format(self.stage_number))

    def save_stage_ckpt(self, model, local_model_config):
        """
        Saves the checkpoint of ``model`` for the current training stage.

        :param model: model to be saved.
        :type model: nn.Module
        :param local_model_config: model configuration at the current training stage.
        :type local_model_config: dict
        """
        dst_ckpt_file_path = local_model_config.get('dst_ckpt', None)
        if dst_ckpt_file_path is not None:
            model_state_dict = model.module.state_dict() if check_if_wrapped(model) else model.state_dict()
            make_parent_dirs(dst_ckpt_file_path)
            save_on_master(model_state_dict, dst_ckpt_file_path)

    def advance_to_next_stage(self):
        """
        Reads the next training stage's configuration in ``train_config`` and advances to the next training stage.
        """
        self.save_stage_ckpt(self.teacher_model, self.train_config.get('teacher', dict()))
        self.save_stage_ckpt(self.student_model, self.train_config.get('student', dict()))
        self.clean_modules()
        self.stage_grad_count = 0
        self.stage_number += 1
        next_stage_config = self.train_config['stage{}'.format(self.stage_number)]
        self.setup(next_stage_config)
        self.stage_end_epoch += next_stage_config['num_epochs']
        logger.info('Advanced to stage {}'.format(self.stage_number))

    def post_epoch_process(self, *args, **kwargs):
        """
        Performs a post-epoch process.

        The superclass's post_epoch_process should be overridden by all subclasses or defined through
        :meth:`DistillationBox.setup_pre_post_processes`.
        """
        super().post_epoch_process(*args, **kwargs)
        self.current_epoch += 1
        if self.current_epoch == self.stage_end_epoch and self.current_epoch < self.num_epochs:
            self.advance_to_next_stage()


def get_distillation_box(teacher_model, student_model, dataset_dict, train_config, device, device_ids,
                         distributed, lr_factor, accelerator=None):
    """
    Gets a distillation box.

    :param teacher_model: teacher model.
    :type teacher_model: nn.Module
    :param student_model: student model.
    :type student_model: nn.Module
    :param dataset_dict: dict that contains datasets with IDs of your choice.
    :type dataset_dict: dict
    :param train_config: training configuration.
    :type train_config: dict
    :param device: target device.
    :type device: torch.device
    :param device_ids: target device IDs.
    :type device_ids: list[int]
    :param distributed: whether to be in distributed training mode.
    :type distributed: bool
    :param lr_factor: multiplier for learning rate.
    :type lr_factor: float or int
    :param accelerator: Hugging Face accelerator.
    :type accelerator: accelerate.Accelerator or None
    :return: distillation box.
    :rtype: DistillationBox or MultiStagesDistillationBox
    """
    if 'stage1' in train_config:
        return MultiStagesDistillationBox(teacher_model, student_model, dataset_dict, train_config,
                                          device, device_ids, distributed, lr_factor, accelerator)
    return DistillationBox(teacher_model, student_model, dataset_dict, train_config,
                           device, device_ids, distributed, lr_factor, accelerator)
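

# Illustrative usage sketch (an assumption, not part of this module): builds a distillation box and
# drives a training run through the pre/forward/post process hooks set up above. The keyword arguments
# passed to the hooks match the default process functions; adjust them if you register custom processes.
# The batch unpacking (sample_batch, targets, supp_dict) assumes a data loader built with
# requires_supp=True, as done in setup_data_loaders.
def _example_distillation_run(teacher_model, student_model, dataset_dict, train_config,
                              device, device_ids=None):
    distillation_box = get_distillation_box(teacher_model, student_model, dataset_dict, train_config,
                                            device=device, device_ids=device_ids,
                                            distributed=False, lr_factor=1)
    for epoch in range(distillation_box.num_epochs):
        distillation_box.pre_epoch_process(epoch=epoch)
        for sample_batch, targets, supp_dict in distillation_box.train_data_loader:
            loss = distillation_box.forward_process(sample_batch, targets, supp_dict=supp_dict)
            distillation_box.post_forward_process(loss)
        distillation_box.post_epoch_process()
    distillation_box.clean_modules()
    return distillation_box.student_model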