Source code for sc2bench.analysis

import numpy as np
import torch
from torch import nn
from torchdistill.common.constant import def_logger
from torchdistill.common.file_util import get_binary_object_size

# Module-level logger, namespaced under torchdistill's root logger.
logger = def_logger.getChild(__name__)
# Global registry of analyzer classes keyed by class name;
# populated by the `register_analysis_class` decorator and read by `get_analyzer`.
ANALYZER_CLASS_DICT = dict()


def register_analysis_class(cls):
    """
    Registers an analyzer class in the global analyzer registry.

    The class is stored under its own ``__name__`` so it can later be
    looked up by :func:`get_analyzer`.

    :param cls: analyzer class to be registered
    :type cls: class
    :return: registered analyzer class (unchanged, so this works as a decorator)
    :rtype: class
    """
    ANALYZER_CLASS_DICT[cls.__name__] = cls
    return cls
class AnalyzableModule(nn.Module):
    """
    A base module to analyze and summarize the wrapped modules and intermediate representations.

    Internal analyzers are instantiated from ``analyzer_configs`` via ``get_analyzer``
    and only run while analysis is explicitly activated.

    :param analyzer_configs: list of analysis configurations
    :type analyzer_configs: list[dict] or None
    """
    def __init__(self, analyzer_configs=None):
        super().__init__()
        configs = list() if analyzer_configs is None else analyzer_configs
        # Each config entry supplies a registered class name ('key') and its kwargs.
        self.analyzers = [
            get_analyzer(config['key'], **config['kwargs']) for config in configs
        ]
        self.activated_analysis = False

    def forward(self, *args, **kwargs):
        raise NotImplementedError()

    def activate_analysis(self):
        """
        Makes internal analyzers ready to run.
        """
        self.activated_analysis = True

    def deactivate_analysis(self):
        """
        Turns internal analyzers off.
        """
        self.activated_analysis = False

    def analyze(self, compressed_obj):
        """
        Analyzes a compressed object using internal analyzers.

        No-op unless analysis has been activated via :meth:`activate_analysis`.

        :param compressed_obj: compressed object
        :type compressed_obj: Any
        """
        if self.activated_analysis:
            for analyzer in self.analyzers:
                analyzer.analyze(compressed_obj)

    def summarize(self):
        """
        Shows each of internal analyzers' summary of results.
        """
        for analyzer in self.analyzers:
            analyzer.summarize()

    def clear_analysis(self):
        """
        Clears each of internal analyzers' results.
        """
        for analyzer in self.analyzers:
            analyzer.clear()
class BaseAnalyzer(object):
    """
    A base analyzer to analyze and summarize the wrapped modules and intermediate representations.

    Subclasses must implement :meth:`analyze`, :meth:`summarize`, and :meth:`clear`.
    """
    def analyze(self, *args, **kwargs):
        """
        Analyzes a compressed object. This should be overridden by all subclasses.
        """
        raise NotImplementedError()

    def summarize(self):
        """
        Shows the summary of results. This should be overridden by all subclasses.
        """
        raise NotImplementedError()

    def clear(self):
        """
        Clears the results. This should be overridden by all subclasses.
        """
        raise NotImplementedError()
@register_analysis_class
class FileSizeAnalyzer(BaseAnalyzer):
    """
    An analyzer to measure file size of compressed object in the designated unit.

    :param unit: unit of data size in bytes ('B', 'KB', 'MB')
    :type unit: str
    """
    UNIT_DICT = {'B': 1, 'KB': 1024, 'MB': 1024 * 1024}

    def __init__(self, unit='KB', **kwargs):
        # Resolve the unit label once; a KeyError here means an unsupported unit.
        self.unit = unit
        self.unit_size = self.UNIT_DICT[unit]
        self.kwargs = kwargs
        self.file_size_list = list()

    def analyze(self, compressed_obj):
        """
        Computes and appends binary object size of the compressed object.

        :param compressed_obj: compressed object
        :type compressed_obj: Any
        """
        measured_size = get_binary_object_size(compressed_obj, unit_size=self.unit_size)
        self.file_size_list.append(measured_size)

    def summarize(self):
        """
        Computes and shows mean and std of the stored file sizes and the number of samples.
        """
        file_sizes = np.array(self.file_size_list)
        logger.info('Bottleneck size [{}]: mean {} std {} for {} samples'.format(
            self.unit, file_sizes.mean(), file_sizes.std(), len(file_sizes)))

    def clear(self):
        """
        Clears the file size list.
        """
        self.file_size_list.clear()
@register_analysis_class
class FileSizeAccumulator(FileSizeAnalyzer):
    """
    An accumulator to store pre-computed file size in the designated unit.

    Unlike :class:`FileSizeAnalyzer`, it does not measure anything itself;
    callers hand it sizes already expressed in bytes.

    :param unit: unit of data size in bytes ('B', 'KB', 'MB')
    :type unit: str
    """
    UNIT_DICT = {'B': 1, 'KB': 1024, 'MB': 1024 * 1024}

    def __init__(self, unit='KB', **kwargs):
        super().__init__(unit=unit, **kwargs)

    def analyze(self, file_size):
        """
        Appends a file size, converted from bytes into the configured unit.

        :param file_size: pre-computed file size
        :type file_size: int or float
        """
        self.file_size_list.append(file_size / self.unit_size)
def get_analyzer(cls_name, **kwargs):
    """
    Gets an analyzer module.

    :param cls_name: analyzer class name
    :type cls_name: str
    :param kwargs: kwargs for the analyzer class
    :type kwargs: dict
    :return: analyzer module, or None if ``cls_name`` is not registered
    :rtype: BaseAnalyzer or None
    """
    analyzer_cls = ANALYZER_CLASS_DICT.get(cls_name)
    return None if analyzer_cls is None else analyzer_cls(**kwargs)
def check_if_analyzable(module):
    """
    Checks if a module is an instance of `AnalyzableModule`.

    :param module: module
    :type module: Any
    :return: True if the module is an instance of `AnalyzableModule`. False otherwise
    :rtype: bool
    """
    return isinstance(module, AnalyzableModule)
def analyze_model_size(model, encoder_paths=None, additional_rest_paths=None, ignores_dtype_error=True):
    """
    Approximates numbers of bits used for parameters of the whole model, encoder, and the rest of the model.

    :param model: model
    :type model: nn.Module
    :param encoder_paths: list of module paths for the model to be considered as part of encoder's parameters
    :type encoder_paths: list[str] or None
    :param additional_rest_paths: list of additional rest module paths whose parameters should be shared with encoder
        e.g., module path of entropy bottleneck in the model if applied
    :type additional_rest_paths: list[str] or None
    :param ignores_dtype_error: if False, raise an error when any unexpected dtypes are found
    :type ignores_dtype_error: bool
    :return: model size (sum of param x num_bits) with three keys: model (whole model), encoder, and the rest
    :rtype: dict
    :raises TypeError: if an unexpected dtype is found and ``ignores_dtype_error`` is False
    """
    # Bits per element for every dtype the approximation supports.
    # NOTE(review): torch stores bool tensors as 1 byte per element; the 2-bit
    # figure here looks like an assumed compressed size, kept for backward
    # compatibility — confirm against the original intent.
    dtype_bit_dict = {
        torch.int64: 64, torch.float64: 64,
        torch.int32: 32, torch.float32: 32,
        torch.int16: 16, torch.float16: 16, torch.bfloat16: 16,
        torch.int8: 8, torch.uint8: 8, torch.qint8: 8, torch.quint8: 8,
        torch.bool: 2,
    }
    if encoder_paths is None:
        encoder_paths = list()
    if additional_rest_paths is None:
        additional_rest_paths = list()

    encoder_path_set = set(encoder_paths)
    additional_rest_path_set = set(additional_rest_paths)
    model_size = 0
    encoder_size = 0
    rest_size = 0
    for k, v in model.state_dict().items():
        num_bits = dtype_bit_dict.get(v.dtype)
        if num_bits is None:
            error_message = f'For {k}, dtype `{v.dtype}` is not expected'
            if not ignores_dtype_error:
                raise TypeError(error_message)
            # Use the module logger instead of print for consistency with the
            # rest of this module's diagnostics.
            logger.warning(error_message)
            continue

        # numel() is 1 for 0-dim tensors, so no scalar special-casing is needed.
        param_size = num_bits * v.numel()
        model_size += param_size
        if any(k.startswith(encoder_path) for encoder_path in encoder_path_set):
            encoder_size += param_size
            # Parameters listed in additional_rest_paths are counted for BOTH
            # encoder and rest (e.g., a shared entropy bottleneck).
            if k in additional_rest_path_set:
                rest_size += param_size
        else:
            rest_size += param_size
    return {'model': model_size, 'encoder': encoder_size, 'rest': rest_size}