import copy
import torch
from torch import nn
from .interfaces.post_epoch_proc import default_post_epoch_process_with_teacher
from .interfaces.post_forward_proc import default_post_forward_process
from .interfaces.pre_epoch_proc import default_pre_epoch_process_with_teacher
from .interfaces.pre_forward_proc import default_pre_forward_process
from .interfaces.registry import get_pre_epoch_proc_func, get_pre_forward_proc_func, get_forward_proc_func, \
get_post_forward_proc_func, get_post_epoch_proc_func
from .util import set_hooks, wrap_model, change_device, tensor2numpy2tensor, extract_io_dict, update_io_dict, \
extract_sub_model_io_dict
from ..common.constant import SELF_MODULE_PATH, def_logger
from ..common.file_util import make_parent_dirs
from ..common.main_util import load_ckpt, save_on_master
from ..common.module_util import check_if_wrapped, freeze_module_params, get_module, \
unfreeze_module_params, get_updatable_param_names
from ..datasets.util import build_data_loaders
from ..losses.registry import get_high_level_loss, get_func2extract_model_output
from ..models.util import redesign_model
from ..models.wrapper import AuxiliaryModelWrapper, build_auxiliary_model_wrapper
from ..optim.registry import get_optimizer, get_scheduler
logger = def_logger.getChild(__name__)
class DistillationBox(object):
"""
A single-stage knowledge distillation framework.
:param teacher_model: teacher model.
:type teacher_model: nn.Module
:param student_model: student model.
:type student_model: nn.Module
:param dataset_dict: dict that contains datasets with IDs of your choice.
:type dataset_dict: dict
:param train_config: training configuration.
:type train_config: dict
:param device: target device.
:type device: torch.device
:param device_ids: target device IDs.
:type device_ids: list[int]
:param distributed: whether to be in distributed training mode.
:type distributed: bool
:param lr_factor: multiplier for learning rate.
:type lr_factor: float or int
:param accelerator: Hugging Face accelerator.
:type accelerator: accelerate.Accelerator or None
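An illustrative single-stage ``train_config`` skeleton, limited to the top-level keys read by
:meth:`setup`; the concrete values below are assumptions (registry keys and numbers are placeholders)::

    train_config = {
        'num_epochs': 20,
        'train_data_loader': {},   # see setup_data_loaders
        'val_data_loader': {},     # see setup_data_loaders
        'teacher': {},             # see setup_teacher_student_models
        'student': {},             # see setup_teacher_student_models
        'criterion': {},           # see setup_loss
        'optimizer': {'key': 'SGD', 'kwargs': {'lr': 0.1}},  # 'lr' is multiplied by lr_factor
        'scheduler': {'key': 'MultiStepLR', 'kwargs': {'milestones': [10, 15]}}
    }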
"""
def setup_data_loaders(self, train_config):
"""
Sets up training and validation data loaders for the current training stage.
This method will be internally called when instantiating this class and when calling
:meth:`MultiStagesDistillationBox.advance_to_next_stage`.
:param train_config: training configuration.
:type train_config: dict
"""
train_data_loader_config = train_config.get('train_data_loader', dict())
if 'requires_supp' not in train_data_loader_config:
train_data_loader_config['requires_supp'] = True
val_data_loader_config = train_config.get('val_data_loader', dict())
train_data_loader, val_data_loader =\
build_data_loaders(self.dataset_dict, [train_data_loader_config, val_data_loader_config],
self.distributed, self.accelerator)
if train_data_loader is not None:
self.train_data_loader = train_data_loader
if val_data_loader is not None:
self.val_data_loader = val_data_loader
def setup_teacher_student_models(self, teacher_config, student_config):
"""
Sets up teacher and student models for the current training stage.
This method will be internally called when instantiating this class and when calling
:meth:`MultiStagesDistillationBox.advance_to_next_stage`.
:param teacher_config: teacher configuration.
:type teacher_config: dict
:param student_config: student configuration.
:type student_config: dict
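A minimal, illustrative ``student_config`` (or ``teacher_config``) sketch limited to the keys this
method reads directly; options consumed by ``set_hooks``, ``redesign_model``, and
``build_auxiliary_model_wrapper`` are omitted, and the values below are assumptions::

    student_config = {
        'forward_proc': 'forward_batch_only',  # assumed registry key of a forward proc function
        'requires_grad': True,
        'frozen_modules': [],
        'src_ckpt': './resource/ckpt/student_stage0.pt'  # hypothetical checkpoint path
    }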
"""
unwrapped_org_teacher_model =\
self.org_teacher_model.module if check_if_wrapped(self.org_teacher_model) else self.org_teacher_model
unwrapped_org_student_model = \
self.org_student_model.module if check_if_wrapped(self.org_student_model) else self.org_student_model
self.target_teacher_pairs.clear()
self.target_student_pairs.clear()
teacher_ref_model = unwrapped_org_teacher_model
student_ref_model = unwrapped_org_student_model
if len(teacher_config) > 0 or (len(teacher_config) == 0 and self.teacher_model is None):
logger.info('[teacher model]')
model_type = 'original'
auxiliary_teacher_model_wrapper = \
build_auxiliary_model_wrapper(teacher_config, teacher_model=unwrapped_org_teacher_model,
device=self.device, device_ids=self.device_ids,
distributed=self.distributed)
if auxiliary_teacher_model_wrapper is not None:
teacher_ref_model = auxiliary_teacher_model_wrapper
model_type = type(teacher_ref_model).__name__
self.teacher_model = redesign_model(teacher_ref_model, teacher_config, 'teacher', model_type)
src_teacher_ckpt_file_path = teacher_config.get('src_ckpt', None)
if src_teacher_ckpt_file_path is not None:
load_ckpt(src_teacher_ckpt_file_path, self.teacher_model)
if len(student_config) > 0 or (len(student_config) == 0 and self.student_model is None):
logger.info('[student model]')
model_type = 'original'
auxiliary_student_model_wrapper = \
build_auxiliary_model_wrapper(student_config, student_model=unwrapped_org_student_model,
device=self.device, device_ids=self.device_ids,
distributed=self.distributed)
if auxiliary_student_model_wrapper is not None:
student_ref_model = auxiliary_student_model_wrapper
model_type = type(student_ref_model).__name__
self.student_model = redesign_model(student_ref_model, student_config, 'student', model_type)
src_student_ckpt_file_path = student_config.get('src_ckpt', None)
if src_student_ckpt_file_path is not None:
load_ckpt(src_student_ckpt_file_path, self.student_model)
self.teacher_any_frozen = \
len(teacher_config.get('frozen_modules', list())) > 0 or not teacher_config.get('requires_grad', True)
self.student_any_frozen = \
len(student_config.get('frozen_modules', list())) > 0 or not student_config.get('requires_grad', True)
self.target_teacher_pairs.extend(set_hooks(self.teacher_model, teacher_ref_model,
teacher_config, self.teacher_io_dict))
self.target_student_pairs.extend(set_hooks(self.student_model, student_ref_model,
student_config, self.student_io_dict))
self.teacher_forward_proc = get_forward_proc_func(teacher_config.get('forward_proc', None))
self.student_forward_proc = get_forward_proc_func(student_config.get('forward_proc', None))
def setup_loss(self, train_config):
"""
Sets up a training loss module for the current training stage.
This method will be internally called when instantiating this class and when calling
:meth:`MultiStagesDistillationBox.advance_to_next_stage`.
:param train_config: training configuration.
:type train_config: dict
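An illustrative ``train_config['criterion']`` entry; ``'key'`` and ``'kwargs'`` are assumed to follow
the loss registry convention used by ``get_high_level_loss``, and ``'func2extract_model_loss'`` is optional::

    criterion_config = {
        'key': 'WeightedSumLoss',          # assumed registry key of a high-level loss
        'kwargs': {},                      # loss-specific keyword arguments
        'func2extract_model_loss': None    # optional; None selects the default extractor
    }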
"""
criterion_config = train_config['criterion']
self.criterion = get_high_level_loss(criterion_config)
logger.info(self.criterion)
self.extract_model_loss = get_func2extract_model_output(criterion_config.get('func2extract_model_loss', None))
def setup_pre_post_processes(self, train_config):
"""
Sets up pre/post-epoch/forward processes for the current training stage.
This method will be internally called when instantiating this class and when calling
:meth:`MultiStagesDistillationBox.advance_to_next_stage`.
:param train_config: training configuration.
:type train_config: dict
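If any of ``'pre_epoch_process'``, ``'pre_forward_process'``, ``'post_forward_process'``, or
``'post_epoch_process'`` is present in ``train_config``, the function registered under that name
replaces the corresponding default process, e.g. (hypothetical registry key)::

    train_config = {
        # ... other stage-level entries ...
        'post_forward_process': 'my_post_forward_proc'
    }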
"""
pre_epoch_process = default_pre_epoch_process_with_teacher
if 'pre_epoch_process' in train_config:
pre_epoch_process = get_pre_epoch_proc_func(train_config['pre_epoch_process'])
setattr(DistillationBox, 'pre_epoch_process', pre_epoch_process)
pre_forward_process = default_pre_forward_process
if 'pre_forward_process' in train_config:
pre_forward_process = get_pre_forward_proc_func(train_config['pre_forward_process'])
setattr(DistillationBox, 'pre_forward_process', pre_forward_process)
post_forward_process = default_post_forward_process
if 'post_forward_process' in train_config:
post_forward_process = get_post_forward_proc_func(train_config['post_forward_process'])
setattr(DistillationBox, 'post_forward_process', post_forward_process)
post_epoch_process = default_post_epoch_process_with_teacher
if 'post_epoch_process' in train_config:
post_epoch_process = get_post_epoch_proc_func(train_config['post_epoch_process'])
setattr(DistillationBox, 'post_epoch_process', post_epoch_process)
def setup(self, train_config):
"""
Configures a :class:`DistillationBox`/:class:`MultiStagesDistillationBox` for the current training stage.
This method will be internally called when instantiating this class and when calling
:meth:`MultiStagesDistillationBox.advance_to_next_stage`.
:param train_config: training configuration.
:type train_config: dict
"""
# Set up train and val data loaders
self.setup_data_loaders(train_config)
# Define teacher and student models used in this stage
teacher_config = train_config.get('teacher', dict())
student_config = train_config.get('student', dict())
self.setup_teacher_student_models(teacher_config, student_config)
# Define loss function used in this stage
self.setup_loss(train_config)
# Freeze parameters if specified
self.teacher_updatable = True
if not teacher_config.get('requires_grad', True):
logger.info('Freezing the whole teacher model')
freeze_module_params(self.teacher_model)
self.teacher_updatable = False
if not student_config.get('requires_grad', True):
logger.info('Freezing the whole student model')
freeze_module_params(self.student_model)
# Wrap models if necessary
teacher_any_updatable = len(get_updatable_param_names(self.teacher_model)) > 0
self.teacher_model =\
wrap_model(self.teacher_model, teacher_config, self.device, self.device_ids, self.distributed,
self.teacher_any_frozen, teacher_any_updatable)
student_any_updatable = len(get_updatable_param_names(self.student_model)) > 0
self.student_model =\
wrap_model(self.student_model, student_config, self.device, self.device_ids, self.distributed,
self.student_any_frozen, student_any_updatable)
# Set up optimizer and scheduler
optim_config = train_config.get('optimizer', dict())
optimizer_reset = False
if len(optim_config) > 0:
optim_kwargs = optim_config['kwargs']
if 'lr' in optim_kwargs:
optim_kwargs['lr'] *= self.lr_factor
module_wise_configs = optim_config.get('module_wise_configs', list())
if len(module_wise_configs) > 0:
trainable_module_list = list()
for module_wise_config in module_wise_configs:
module_wise_kwargs = dict()
if isinstance(module_wise_config.get('kwargs', None), dict):
module_wise_kwargs.update(module_wise_config['kwargs'])
if 'lr' in module_wise_kwargs:
module_wise_kwargs['lr'] *= self.lr_factor
target_model = \
self.teacher_model if module_wise_config.get('is_teacher', False) else self.student_model
module = get_module(target_model, module_wise_config['module'])
module_wise_kwargs['params'] = module.parameters() if isinstance(module, nn.Module) else [module]
trainable_module_list.append(module_wise_kwargs)
else:
trainable_module_list = nn.ModuleList([self.student_model])
if self.teacher_updatable:
logger.info('Note that you are training some/all of the modules in the teacher model')
trainable_module_list.append(self.teacher_model)
filters_params = optim_config.get('filters_params', True)
self.optimizer = \
get_optimizer(trainable_module_list, optim_config['key'],
**optim_kwargs, filters_params=filters_params)
self.optimizer.zero_grad()
self.max_grad_norm = optim_config.get('max_grad_norm', None)
self.grad_accum_step = optim_config.get('grad_accum_step', 1)
optimizer_reset = True
scheduler_config = train_config.get('scheduler', None)
if scheduler_config is not None and len(scheduler_config) > 0:
self.lr_scheduler = get_scheduler(self.optimizer, scheduler_config['key'], **scheduler_config['kwargs'])
self.scheduling_step = scheduler_config.get('scheduling_step', 0)
elif optimizer_reset:
self.lr_scheduler = None
self.scheduling_step = None
# Set up accelerator if necessary
if self.accelerator is not None:
if self.teacher_updatable:
self.teacher_model, self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
self.accelerator.prepare(self.teacher_model, self.student_model, self.optimizer,
self.train_data_loader, self.val_data_loader)
else:
self.teacher_model = self.teacher_model.to(self.accelerator.device)
if self.accelerator.state.use_fp16:
self.teacher_model = self.teacher_model.half()
self.student_model, self.optimizer, self.train_data_loader, self.val_data_loader = \
self.accelerator.prepare(self.student_model, self.optimizer,
self.train_data_loader, self.val_data_loader)
# Set up {pre,post}-{epoch,forward} processes
self.setup_pre_post_processes(train_config)
def __init__(self, teacher_model, student_model, dataset_dict,
train_config, device, device_ids, distributed, lr_factor, accelerator=None):
# Key attributes (should not be modified)
self.org_teacher_model = teacher_model
self.org_student_model = student_model
self.dataset_dict = dataset_dict
self.device = device
self.device_ids = device_ids
self.distributed = distributed
self.lr_factor = lr_factor
self.accelerator = accelerator
# Local attributes (can be updated at each stage)
self.teacher_model = None
self.student_model = None
self.teacher_forward_proc, self.student_forward_proc = None, None
self.target_teacher_pairs, self.target_student_pairs = list(), list()
self.teacher_io_dict, self.student_io_dict = dict(), dict()
self.train_data_loader, self.val_data_loader, self.optimizer, self.lr_scheduler = None, None, None, None
self.criterion, self.extract_model_loss = None, None
self.teacher_updatable, self.teacher_any_frozen, self.student_any_frozen = None, None, None
self.grad_accum_step = None
self.max_grad_norm = None
self.scheduling_step = 0
self.stage_grad_count = 0
self.setup(train_config)
self.num_epochs = train_config['num_epochs']
def pre_epoch_process(self, *args, **kwargs):
"""
Performs a pre-epoch process.
This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
"""
raise NotImplementedError()
def pre_forward_process(self, *args, **kwargs):
"""
Performs a pre-forward process.
This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
"""
raise NotImplementedError()
def get_teacher_output(self, sample_batch, targets, supp_dict, **kwargs):
"""
Gets teacher model's output.
:param sample_batch: sample batch.
:type sample_batch: Any
:param targets: training targets.
:type targets: Any
:param supp_dict: supplementary dict.
:type supp_dict: dict
:return: teacher's outputs and teacher's I/O dict.
:rtype: (Any, dict)
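The optional teacher-output caching is driven by ``supp_dict``; an illustrative sketch of the keys
read by this method (file paths are hypothetical)::

    supp_dict = {
        'cached_data': {'teacher_outputs': ..., 'extracted_outputs': ...},  # reused if present
        'cache_file_path': ['./cache/sample0.pt', ...]  # per-sample cache files written if given
    }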
"""
if supp_dict is None:
supp_dict = dict()
cached_data = supp_dict.get('cached_data', None)
cache_file_paths = supp_dict.get('cache_file_path', None)
teacher_outputs = None
cached_extracted_teacher_output_dict = None
# Use cached data if available
if cached_data is not None and isinstance(cached_data, dict):
teacher_outputs = cached_data['teacher_outputs']
cached_extracted_teacher_output_dict = cached_data['extracted_outputs']
if self.device.type != 'cpu':
teacher_outputs = change_device(teacher_outputs, self.device)
cached_extracted_teacher_output_dict = change_device(cached_extracted_teacher_output_dict, self.device)
if not self.teacher_updatable:
return teacher_outputs, cached_extracted_teacher_output_dict
# If no cached data
if teacher_outputs is None:
if self.teacher_updatable:
teacher_outputs = self.teacher_forward_proc(self.teacher_model, sample_batch,
targets, supp_dict, **kwargs)
else:
with torch.no_grad():
teacher_outputs = self.teacher_forward_proc(self.teacher_model, sample_batch,
targets, supp_dict, **kwargs)
if cached_extracted_teacher_output_dict is not None:
if isinstance(self.teacher_model, AuxiliaryModelWrapper) or \
(check_if_wrapped(self.teacher_model) and
isinstance(self.teacher_model.module, AuxiliaryModelWrapper)):
self.teacher_io_dict.update(cached_extracted_teacher_output_dict)
if isinstance(self.teacher_model, AuxiliaryModelWrapper):
self.teacher_model.secondary_forward(self.teacher_io_dict)
extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
return teacher_outputs, extracted_teacher_io_dict
# Deep copy of teacher I/O dict if the teacher contains trainable module(s) and cache file paths are given
teacher_io_dict4cache = copy.deepcopy(self.teacher_io_dict) \
if self.teacher_updatable and isinstance(cache_file_paths, (list, tuple)) else None
extracted_teacher_io_dict = extract_io_dict(self.teacher_io_dict, self.device)
extracted_teacher_io_dict[SELF_MODULE_PATH]['output'] = teacher_outputs
if isinstance(self.teacher_model, AuxiliaryModelWrapper):
self.teacher_model.secondary_forward(extracted_teacher_io_dict)
update_io_dict(extracted_teacher_io_dict, extract_io_dict(self.teacher_io_dict, self.device))
# Write cache files if output file paths (cache_file_paths) are given
if isinstance(cache_file_paths, (list, tuple)):
if teacher_io_dict4cache is None:
teacher_io_dict4cache = extracted_teacher_io_dict
cpu_device = torch.device('cpu')
for i, (teacher_output, cache_file_path) in enumerate(zip(teacher_outputs.cpu().numpy(), cache_file_paths)):
sub_dict = extract_sub_model_io_dict(teacher_io_dict4cache, i)
sub_dict = tensor2numpy2tensor(sub_dict, cpu_device)
cache_dict = {'teacher_outputs': torch.Tensor(teacher_output), 'extracted_outputs': sub_dict}
make_parent_dirs(cache_file_path)
torch.save(cache_dict, cache_file_path)
return teacher_outputs, extracted_teacher_io_dict
def forward_process(self, sample_batch, targets=None, supp_dict=None, **kwargs):
"""
Performs forward computations for teacher and student models.
:param sample_batch: sample batch.
:type sample_batch: Any
:param targets: training targets.
:type targets: Any
:param supp_dict: supplementary dict.
:type supp_dict: dict
:return: loss tensor.
:rtype: torch.Tensor
"""
teacher_outputs, extracted_teacher_io_dict =\
self.get_teacher_output(sample_batch=sample_batch, targets=targets, supp_dict=supp_dict, **kwargs)
student_outputs = self.student_forward_proc(self.student_model, sample_batch, targets, supp_dict, **kwargs)
extracted_student_io_dict = extract_io_dict(self.student_io_dict, self.device)
extracted_student_io_dict[SELF_MODULE_PATH]['output'] = student_outputs
if isinstance(self.student_model, AuxiliaryModelWrapper):
self.student_model.secondary_forward(extracted_student_io_dict)
model_loss_dict = self.extract_model_loss(student_outputs, targets, supp_dict=supp_dict)
update_io_dict(extracted_student_io_dict, extract_io_dict(self.student_io_dict, self.device))
io_dict = {'teacher': extracted_teacher_io_dict, 'student': extracted_student_io_dict}
total_loss = self.criterion(io_dict, model_loss_dict, targets)
return total_loss
def post_forward_process(self, *args, **kwargs):
"""
Performs a post-forward process.
This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
"""
raise NotImplementedError()
def post_epoch_process(self, *args, **kwargs):
"""
Performs a post-epoch process.
This should be overridden by all subclasses or defined through :meth:`setup_pre_post_processes`.
"""
raise NotImplementedError()
def clean_modules(self):
"""
Unfreezes all the teacher and student modules, clears I/O dicts, unregisters forward hook handles,
and clears the handle lists.
"""
unfreeze_module_params(self.org_teacher_model)
unfreeze_module_params(self.org_student_model)
self.teacher_io_dict.clear()
self.student_io_dict.clear()
for _, module_handle in self.target_teacher_pairs + self.target_student_pairs:
module_handle.remove()
self.target_teacher_pairs.clear()
self.target_student_pairs.clear()
class MultiStagesDistillationBox(DistillationBox):
"""
A multi-stage knowledge distillation framework. This is a subclass of :class:`DistillationBox`.
:param teacher_model: teacher model.
:type teacher_model: nn.Module
:param student_model: student model.
:type student_model: nn.Module
:param dataset_dict: dict that contains datasets with IDs of your choice.
:type dataset_dict: dict
:param train_config: training configuration.
:type train_config: dict
:param device: target device.
:type device: torch.device
:param device_ids: target device IDs.
:type device_ids: list[int]
:param distributed: whether to be in distributed training mode.
:type distributed: bool
:param lr_factor: multiplier for learning rate.
:type lr_factor: float or int
:param accelerator: Hugging Face accelerator.
:type accelerator: accelerate.Accelerator or None
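``train_config`` is expected to contain per-stage configurations under keys ``'stage1'``,
``'stage2'``, and so on, each shaped like a single-stage :class:`DistillationBox` configuration
(illustrative skeleton)::

    train_config = {
        'stage1': {'num_epochs': 10},  # plus the usual single-stage keys (teacher, student, criterion, ...)
        'stage2': {'num_epochs': 10}
    }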
"""
def __init__(self, teacher_model, student_model, dataset_dict,
train_config, device, device_ids, distributed, lr_factor, accelerator=None):
stage1_config = train_config['stage1']
super().__init__(teacher_model, student_model, dataset_dict,
stage1_config, device, device_ids, distributed, lr_factor, accelerator)
self.train_config = train_config
self.stage_number = 1
self.stage_end_epoch = stage1_config['num_epochs']
self.num_epochs = sum(train_config[key]['num_epochs'] for key in train_config.keys() if key.startswith('stage'))
self.current_epoch = 0
logger.info('Started stage {}'.format(self.stage_number))
def save_stage_ckpt(self, model, local_model_config):
"""
Saves the checkpoint of ``model`` for the current training stage.
:param model: model to be saved.
:type model: nn.Module
:param local_model_config: model configuration at the current training stage.
:type local_model_config: dict
"""
dst_ckpt_file_path = local_model_config.get('dst_ckpt', None)
if dst_ckpt_file_path is not None:
model_state_dict = model.module.state_dict() if check_if_wrapped(model) else model.state_dict()
make_parent_dirs(dst_ckpt_file_path)
save_on_master(model_state_dict, dst_ckpt_file_path)
def advance_to_next_stage(self):
"""
Reads the next training stage's configuration in ``train_config`` and advances to the next training stage.
"""
self.save_stage_ckpt(self.teacher_model, self.train_config.get('teacher', dict()))
self.save_stage_ckpt(self.student_model, self.train_config.get('student', dict()))
self.clean_modules()
self.stage_grad_count = 0
self.stage_number += 1
next_stage_config = self.train_config['stage{}'.format(self.stage_number)]
self.setup(next_stage_config)
self.stage_end_epoch += next_stage_config['num_epochs']
logger.info('Advanced to stage {}'.format(self.stage_number))
def post_epoch_process(self, *args, **kwargs):
"""
Performs a post-epoch process.
The superclass's post_epoch_process should be overridden by all subclasses or
defined through :meth:`DistillationBox.setup_pre_post_processes`.
"""
super().post_epoch_process(*args, **kwargs)
self.current_epoch += 1
if self.current_epoch == self.stage_end_epoch and self.current_epoch < self.num_epochs:
self.advance_to_next_stage()
def get_distillation_box(teacher_model, student_model, dataset_dict,
train_config, device, device_ids, distributed, lr_factor, accelerator=None):
"""
Gets a distillation box.
:param teacher_model: teacher model.
:type teacher_model: nn.Module
:param student_model: student model.
:type student_model: nn.Module
:param dataset_dict: dict that contains datasets with IDs of your choice.
:type dataset_dict: dict
:param train_config: training configuration.
:type train_config: dict
:param device: target device.
:type device: torch.device
:param device_ids: target device IDs.
:type device_ids: list[int]
:param distributed: whether to be in distributed training mode.
:type distributed: bool
:param lr_factor: multiplier for learning rate.
:type lr_factor: float or int
:param accelerator: Hugging Face accelerator.
:type accelerator: accelerate.Accelerator or None
:return: distillation box.
:rtype: DistillationBox or MultiStagesDistillationBox
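A minimal usage sketch (all arguments are assumed to be prepared by the caller); a
:class:`MultiStagesDistillationBox` is returned if and only if ``train_config`` contains a
``'stage1'`` entry::

    distillation_box = get_distillation_box(teacher_model, student_model, dataset_dict,
                                             train_config, device, device_ids=[0],
                                             distributed=False, lr_factor=1)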
"""
if 'stage1' in train_config:
return MultiStagesDistillationBox(teacher_model, student_model, dataset_dict,
train_config, device, device_ids, distributed, lr_factor, accelerator)
return DistillationBox(teacher_model, student_model, dataset_dict, train_config,
device, device_ids, distributed, lr_factor, accelerator)