torchdistill.common


torchdistill.common.file_util

torchdistill.common.file_util.check_if_exists(file_path)[source]

Checks if a file/dir exists.

Parameters:

file_path (str) – file/dir path

Returns:

True if the given file/dir exists

Return type:

bool

torchdistill.common.file_util.get_file_path_list(dir_path, is_recursive=False, is_sorted=False)[source]

Gets file paths for a given dir path.

Parameters:
  • dir_path (str) – dir path

  • is_recursive (bool) – if True, get file paths recursively

  • is_sorted (bool) – if True, sort file paths in ascending order

Returns:

list of file paths

Return type:

list[str]
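The documented behavior can be sketched with the standard library alone; this is an illustrative stand-in, not the library's implementation:

```python
import os

def get_file_path_list(dir_path, is_recursive=False, is_sorted=False):
    """Collect file paths under dir_path (a stdlib sketch of the documented behavior)."""
    file_paths = []
    if is_recursive:
        # Walk the whole tree and collect every file path.
        for root, _, file_names in os.walk(dir_path):
            file_paths.extend(os.path.join(root, name) for name in file_names)
    else:
        # Only the immediate children that are files.
        for name in os.listdir(dir_path):
            path = os.path.join(dir_path, name)
            if os.path.isfile(path):
                file_paths.append(path)
    return sorted(file_paths) if is_sorted else file_paths
```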

torchdistill.common.file_util.get_dir_path_list(dir_path, is_recursive=False, is_sorted=False)[source]

Gets dir paths for a given dir path.

Parameters:
  • dir_path (str) – dir path

  • is_recursive (bool) – if True, get dir paths recursively

  • is_sorted (bool) – if True, sort dir paths in ascending order

Returns:

list of dir paths

Return type:

list[str]

torchdistill.common.file_util.make_dirs(dir_path)[source]

Makes a directory and its parent directories.

Parameters:

dir_path (str) – dir path

torchdistill.common.file_util.make_parent_dirs(file_path)[source]

Makes parent directories.

Parameters:

file_path (str) – file path

torchdistill.common.file_util.save_pickle(obj, file_path)[source]

Serializes an object and saves it as a file.

Parameters:
  • obj (Any) – object to be serialized

  • file_path (str) – output file path

torchdistill.common.file_util.load_pickle(file_path)[source]

Loads a serialized object from a file and deserializes it.

Parameters:

file_path (str) – serialized file path

Returns:

deserialized object

Return type:

Any
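A round trip through these two functions can be sketched with the pickle module; creating parent directories before writing matches make_parent_dirs above, though the exact implementation details are assumed:

```python
import os
import pickle

def save_pickle(obj, file_path):
    """Serialize obj with pickle, creating parent directories first (sketch)."""
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'wb') as fp:
        pickle.dump(obj, fp)

def load_pickle(file_path):
    """Deserialize and return the object stored at file_path (sketch)."""
    with open(file_path, 'rb') as fp:
        return pickle.load(fp)
```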

torchdistill.common.file_util.get_binary_object_size(obj, unit_size=1024)[source]

Computes the size of an object in bytes after serialization.

Parameters:
  • obj (Any) – object

  • unit_size (int or float) – unit file size

Returns:

size of the serialized object in bytes, divided by unit_size

Return type:

float
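The documented behavior (serialized size divided by unit_size) can be approximated in one line; with the default unit_size=1024 the result is in KiB rather than bytes:

```python
import pickle

def get_binary_object_size(obj, unit_size=1024):
    """Approximate the pickled size of obj, divided by unit_size (sketch)."""
    return len(pickle.dumps(obj)) / unit_size
```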


torchdistill.common.main_util

torchdistill.common.main_util.import_dependencies(dependencies=None)[source]

Imports specified packages.

Parameters:

dependencies (list[dict or list[str] or (str, str) or str] or str or None) – package names.

torchdistill.common.main_util.import_get(key, package=None, **kwargs)[source]

Imports a module and gets its attribute.

Parameters:
  • key (str) – attribute name or package path separated by periods (.).

  • package (str or None) – package path if key is just an attribute name.

Returns:

attribute of the imported module.

Return type:

Any
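The resolution logic can be sketched with importlib; this is a simplified stand-in and may differ from the library's handling of edge cases:

```python
import importlib

def import_get(key, package=None):
    """Resolve a dotted path to a module attribute (sketch of import_get)."""
    if package is None:
        # Split 'math.pi' into package 'math' and attribute 'pi'.
        package, _, attr_name = key.rpartition('.')
    else:
        attr_name = key
    module = importlib.import_module(package)
    return getattr(module, attr_name)
```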

torchdistill.common.main_util.import_call(key, package=None, init=None, **kwargs)[source]

Imports a module and calls the module/function, e.g., for instantiation.

Parameters:
  • key (str) – module name or package path separated by periods (.).

  • package (str or None) – package path if key is just an attribute name.

  • init (dict) – dict of arguments and/or keyword arguments to instantiate the imported module.

Returns:

object imported and called.

Return type:

Any
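A sketch of the import-then-call pattern; the init dict is assumed to hold positional args under 'args' and keyword args under 'kwargs', which should be checked against the implementation:

```python
import importlib

def import_call(key, package=None, init=None):
    """Import a callable by dotted path and call it (sketch).

    'args'/'kwargs' keys in init are assumptions about the config format.
    """
    if package is None:
        package, _, name = key.rpartition('.')
    else:
        name = key
    obj = getattr(importlib.import_module(package), name)
    init = init or {}
    return obj(*init.get('args', []), **init.get('kwargs', {}))
```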

torchdistill.common.main_util.import_call_method(package, class_name=None, method_name=None, init=None, **kwargs)[source]

Imports a module and calls its method.

Parameters:
  • package (str) – package path.

  • class_name (str) – class name under package.

  • method_name (str) – method name of class_name class under package.

  • init (dict) – dict of arguments and/or keyword arguments to instantiate the imported module.

Returns:

object imported and called.

Return type:

Any
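The class-method variant can be sketched the same way; the 'args'/'kwargs' layout of init is again an assumption:

```python
import importlib

def import_call_method(package, class_name, method_name, init=None):
    """Import a class from package and call one of its methods (sketch)."""
    cls = getattr(importlib.import_module(package), class_name)
    method = getattr(cls, method_name)
    init = init or {}
    return method(*init.get('args', []), **init.get('kwargs', {}))
```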

torchdistill.common.main_util.setup_for_distributed(is_master)[source]

Disables logging when not in master process.

Parameters:

is_master (bool) – True if it is the master process.

torchdistill.common.main_util.set_seed(seed)[source]

Sets a random seed for random, numpy, and torch (torch.manual_seed, torch.cuda.manual_seed_all).

Parameters:

seed (int) – random seed.
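A dependency-free sketch of the idea; the documented function additionally seeds numpy and torch (torch.manual_seed, torch.cuda.manual_seed_all), which is omitted here:

```python
import random

def set_seed(seed):
    """Seed Python's random module (sketch; numpy/torch seeding omitted)."""
    random.seed(seed)
```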

torchdistill.common.main_util.is_dist_avail_and_initialized()[source]

Checks if distributed mode is available and initialized.

Returns:

True if distributed mode is available and initialized.

Return type:

bool

torchdistill.common.main_util.get_world_size()[source]

Gets world size.

Returns:

world size.

Return type:

int

torchdistill.common.main_util.get_rank()[source]

Gets the rank of the current process in the provided group or the default group if none was provided.

Returns:

rank of the current process in the provided group or the default group if none was provided.

Return type:

int

torchdistill.common.main_util.is_main_process()[source]

Checks if this is the main process.

Returns:

True if this is the main process.

Return type:

bool
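These helpers normally query torch.distributed; reading the RANK environment variable set by launchers such as torchrun is a simplified, torch-free stand-in for the same idea:

```python
import os

def get_rank():
    """Rank of the current process, 0 when not running distributed (sketch)."""
    return int(os.environ.get('RANK', 0))

def is_main_process():
    """True for the rank-0 process (sketch)."""
    return get_rank() == 0
```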

torchdistill.common.main_util.save_on_master(*args, **kwargs)[source]

Applies torch.save to the given arguments only if this is the main process.

Returns:

True if this is the main process.

Return type:

bool

torchdistill.common.main_util.init_distributed_mode(world_size=1, dist_url='env://')[source]

Initializes distributed mode.

Parameters:
  • world_size (int) – world size.

  • dist_url (str) – URL specifying how to initialize the process group.

Returns:

tuple of 1) whether distributed mode is initialized, 2) world size, and 3) list of device IDs.

Return type:

(bool, int, list[int] or None)

torchdistill.common.main_util.load_ckpt(ckpt_file_path, model=None, optimizer=None, lr_scheduler=None, strict=True)[source]

Loads a checkpoint file and restores the states of the given model, optimizer, and/or lr_scheduler.

Parameters:
  • ckpt_file_path (str) – checkpoint file path.

  • model (nn.Module) – model.

  • optimizer (torch.optim.Optimizer) – optimizer.

  • lr_scheduler (torch.optim.lr_scheduler.LRScheduler) – learning rate scheduler.

  • strict (bool) – strict as a keyword argument of load_state_dict.

Returns:

tuple of best value (e.g., best validation result) and parsed args.

Return type:

(float or None, argparse.Namespace or None)

torchdistill.common.main_util.save_ckpt(model, optimizer, lr_scheduler, best_value, args, output_file_path)[source]

Saves a checkpoint file including the model, optimizer, best value, parsed args, and learning rate scheduler.

Parameters:
  • model (nn.Module) – model.

  • optimizer (torch.optim.Optimizer) – optimizer.

  • lr_scheduler (torch.optim.lr_scheduler.LRScheduler) – learning rate scheduler.

  • best_value (float) – best value e.g., best validation result.

  • args (argparse.Namespace) – parsed args.

  • output_file_path (str) – output file path.
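The checkpoint round trip can be sketched without torch by pickling a dict of states; the real functions use torch.save/torch.load on state_dicts, and the key names below are assumptions for illustration:

```python
import os
import pickle

def save_ckpt(model_state, optimizer_state, lr_scheduler_state,
              best_value, args, output_file_path):
    """Bundle training states into one checkpoint file (sketch).

    Key names and the use of pickle instead of torch.save are assumptions.
    """
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
    ckpt = {'model': model_state, 'optimizer': optimizer_state,
            'lr_scheduler': lr_scheduler_state,
            'best_value': best_value, 'args': args}
    with open(output_file_path, 'wb') as fp:
        pickle.dump(ckpt, fp)

def load_ckpt(ckpt_file_path):
    """Load the checkpoint and return (best_value, args) (sketch)."""
    with open(ckpt_file_path, 'rb') as fp:
        ckpt = pickle.load(fp)
    return ckpt.get('best_value'), ckpt.get('args')
```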


torchdistill.common.misc_util

torchdistill.common.misc_util.check_if_plottable()[source]

Checks if the DISPLAY environment variable is valid.

Returns:

True if DISPLAY variable is valid.

Return type:

bool

torchdistill.common.misc_util.get_classes(package_name, require_names=False)[source]

Gets classes in a given package.

Parameters:
  • package_name (str) – package name.

  • require_names (bool) – if True, pairs each class with its member name.

Returns:

collection of classes defined in the given package.

Return type:

list[(str, class)] or list[class]
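The introspection can be sketched with inspect.getmembers; filtering on __module__ restricts the result to classes defined in the package rather than merely imported into it, which is an assumption about the intended behavior:

```python
import importlib
import inspect

def get_classes(package_name, require_names=False):
    """Classes defined in (not merely imported into) a module (sketch)."""
    module = importlib.import_module(package_name)
    members = [(name, cls)
               for name, cls in inspect.getmembers(module, inspect.isclass)
               if cls.__module__ == package_name]
    return members if require_names else [cls for _, cls in members]
```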

torchdistill.common.misc_util.get_classes_as_dict(package_name, is_lower=False)[source]

Gets classes in a given package as a dict.

Parameters:
  • package_name (str) – package name.

  • is_lower (bool) – if True, use lowercase module names.

Returns:

dict of classes defined in the given package.

Return type:

dict

torchdistill.common.misc_util.get_functions(package_name, require_names=False)[source]

Gets functions in a given package.

Parameters:
  • package_name (str) – package name.

  • require_names (bool) – if True, pairs each function with its name.

Returns:

collection of functions defined in the given package.

Return type:

list[(str, Callable)] or list[Callable]

torchdistill.common.misc_util.get_functions_as_dict(package_name, is_lower=False)[source]

Gets functions in a given package as a dict.

Parameters:
  • package_name (str) – package name.

  • is_lower (bool) – if True, use lowercase module names.

Returns:

dict of functions defined in the given package.

Return type:

dict


torchdistill.common.module_util

torchdistill.common.module_util.check_if_wrapped(model)[source]

Checks if a given model is wrapped by DataParallel or DistributedDataParallel.

Parameters:

model (nn.Module) – model.

Returns:

True if model is wrapped by either DataParallel or DistributedDataParallel.

Return type:

bool

torchdistill.common.module_util.count_params(module)[source]

Returns the number of module parameters.

Parameters:

module (nn.Module) – module.

Returns:

number of model parameters.

Return type:

int

torchdistill.common.module_util.freeze_module_params(module)[source]

Freezes parameters by setting requires_grad=False for all the parameters.

Parameters:

module (nn.Module) – module.

torchdistill.common.module_util.unfreeze_module_params(module)[source]

Unfreezes parameters by setting requires_grad=True for all the parameters.

Parameters:

module (nn.Module) – module.

torchdistill.common.module_util.get_updatable_param_names(module)[source]

Gets collection of updatable parameter names.

Parameters:

module (nn.Module) – module.

Returns:

names of updatable parameters.

Return type:

list[str]

torchdistill.common.module_util.get_frozen_param_names(module)[source]

Gets collection of frozen parameter names.

Parameters:

module (nn.Module) – module.

Returns:

names of frozen parameters.

Return type:

list[str]

torchdistill.common.module_util.get_module(root_module, module_path)[source]

Gets a module specified by module_path.

Parameters:
  • root_module (nn.Module) – module.

  • module_path (str) – module path for extracting the module from root_module.

Returns:

module extracted from root_module if exists.

Return type:

nn.Module or None
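Resolving a dotted module path can be sketched as repeated attribute access; torchdistill navigates nn.Module hierarchies, but plain getattr makes the sketch work on any nested object:

```python
from functools import reduce

def get_module(root_module, module_path):
    """Resolve a dotted attribute path such as 'encoder.layer1' (sketch)."""
    try:
        return reduce(getattr, module_path.split('.'), root_module)
    except AttributeError:
        # Return None when any segment of the path is missing.
        return None
```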

torchdistill.common.module_util.get_hierarchized_dict(module_paths)[source]

Gets a hierarchical structure from module paths.

Parameters:

module_paths (list[str]) – module paths.

Returns:

nested dict representing the hierarchy of the given module paths.

Return type:

dict

torchdistill.common.module_util.decompose(ordered_dict)[source]

Converts an ordered dict into a list of key-value pairs.

Parameters:

ordered_dict (collections.OrderedDict) – ordered dict.

Returns:

list of key-value pairs.

Return type:

list[(str, Any)]

torchdistill.common.module_util.get_components(module_paths)[source]

Converts module paths into a list of pairs of parent module and child module names.

Parameters:

module_paths (list[str]) – module paths.

Returns:

list of pairs of parent module and child module names.

Return type:

list[(str, str)]

torchdistill.common.module_util.extract_target_modules(parent_module, target_class, module_list)[source]

Extracts modules that are instances of target_class and updates module_list with the extracted modules.

Parameters:
  • parent_module (nn.Module) – parent module.

  • target_class (class) – target class.

  • module_list (list[nn.Module]) – (empty) list to be filled with modules that are instances of target_class.

torchdistill.common.module_util.extract_all_child_modules(parent_module, module_list)[source]

Extracts all the child modules and updates module_list with the extracted modules.

Parameters:
  • parent_module (nn.Module) – parent module.

  • module_list (list[nn.Module]) – (empty) list to be filled with child modules.


torchdistill.common.tensor_util

class torchdistill.common.tensor_util.QuantizedTensor(tensor, scale, zero_point)
scale

Alias for field number 1

tensor

Alias for field number 0

zero_point

Alias for field number 2

torchdistill.common.tensor_util.quantize_tensor(x, num_bits=8)[source]

Quantizes a tensor into num_bits-bit integers with a floating-point scale and zero point.

Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko: β€œQuantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference” @ CVPR 2018 (2018)

Parameters:
  • x (torch.Tensor) – tensor to be quantized.

  • num_bits (int) – the number of bits for quantization.

Returns:

quantized tensor.

Return type:

QuantizedTensor

torchdistill.common.tensor_util.dequantize_tensor(q_x)[source]

Dequantizes a quantized tensor.

Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko: β€œQuantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference” @ CVPR 2018 (2018)

Parameters:

q_x (QuantizedTensor) – quantized tensor to be dequantized.

Returns:

dequantized tensor.

Return type:

torch.Tensor
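The affine (asymmetric) scheme of Jacob et al. (CVPR 2018) can be sketched in plain Python; lists of floats stand in for torch tensors, and the rounding/clamping details are assumptions to be checked against the implementation:

```python
from collections import namedtuple

QuantizedTensor = namedtuple('QuantizedTensor', ['tensor', 'scale', 'zero_point'])

def quantize_values(values, num_bits=8):
    """Affine quantization of a list of floats (sketch of quantize_tensor)."""
    qmin, qmax = 0, 2 ** num_bits - 1
    min_val, max_val = min(values), max(values)
    # Map [min_val, max_val] onto the integer range [qmin, qmax].
    scale = (max_val - min_val) / (qmax - qmin) or 1.0
    zero_point = int(round(qmin - min_val / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    q = [max(qmin, min(qmax, int(round(v / scale + zero_point)))) for v in values]
    return QuantizedTensor(q, scale, zero_point)

def dequantize_values(q_x):
    """Invert quantize_values up to rounding error (sketch of dequantize_tensor)."""
    return [q_x.scale * (q - q_x.zero_point) for q in q_x.tensor]
```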


torchdistill.common.yaml_util

torchdistill.common.yaml_util.yaml_join(loader, node)[source]

Joins a sequence of strings.

Parameters:
  • loader (yaml.loader.FullLoader) – yaml loader.

  • node (yaml.nodes.Node) – node.

Returns:

joined string.

Return type:

str

torchdistill.common.yaml_util.yaml_pathjoin(loader, node)[source]

Joins a sequence of strings as a (file) path.

Parameters:
  • loader (yaml.loader.FullLoader) – yaml loader.

  • node (yaml.nodes.Node) – node.

Returns:

joined (file) path.

Return type:

str

torchdistill.common.yaml_util.yaml_expanduser(loader, node)[source]

Applies os.path.expanduser to a (file) path.

Parameters:
  • loader (yaml.loader.FullLoader) – yaml loader.

  • node (yaml.nodes.Node) – node.

Returns:

(file) path.

Return type:

str

torchdistill.common.yaml_util.yaml_abspath(loader, node)[source]

Applies os.path.abspath to a (file) path.

Parameters:
  • loader (yaml.loader.FullLoader) – yaml loader.

  • node (yaml.nodes.Node) – node.

Returns:

(file) path.

Return type:

str

torchdistill.common.yaml_util.yaml_import_get(loader, node)[source]

Imports a module and gets its attribute.

Parameters:
  • loader (yaml.loader.FullLoader) – yaml loader.

  • node (yaml.nodes.Node) – node.

Returns:

module attribute.

Return type:

Any

torchdistill.common.yaml_util.yaml_import_call(loader, node)[source]

Imports a module and calls the module/function, e.g., for instantiation.

Parameters:
  • loader (yaml.loader.FullLoader) – yaml loader.

  • node (yaml.nodes.Node) – node.

Returns:

result of callable module.

Return type:

Any

torchdistill.common.yaml_util.yaml_import_call_method(loader, node)[source]

Imports a module and calls its method.

Parameters:
  • loader (yaml.loader.FullLoader) – yaml loader.

  • node (yaml.nodes.Node) – node.

Returns:

result of callable module.

Return type:

Any

torchdistill.common.yaml_util.yaml_getattr(loader, node)[source]

Gets an attribute of the first argument.

Parameters:
  • loader (yaml.loader.FullLoader) – yaml loader.

  • node (yaml.nodes.Node) – node.

Returns:

module attribute.

Return type:

Any

torchdistill.common.yaml_util.yaml_setattr(loader, node)[source]

Sets an attribute to the first argument.

Parameters:
  • loader (yaml.loader.FullLoader) – yaml loader.

  • node (yaml.nodes.Node) – node.

Returns:

module attribute.

Return type:

Any

torchdistill.common.yaml_util.yaml_access_by_index_or_key(loader, node)[source]

Obtains a value from the given data by index or key.

Parameters:
  • loader (yaml.loader.FullLoader) – yaml loader.

  • node (yaml.nodes.Node) – node.

Returns:

accessed object.

Return type:

Any

torchdistill.common.yaml_util.load_yaml_file(yaml_file_path, custom_mode=True)[source]

Loads a yaml file optionally with convenient constructors.

Parameters:
  • yaml_file_path (str) – yaml file path.

  • custom_mode (bool) – if True, uses convenient constructors.

Returns:

loaded PyYAML object.

Return type:

Any
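For illustration, a config file could combine these constructors as below. The tag names are assumed to mirror the function names above (e.g., !join for yaml_join) and the mapping layout for !import_call follows the import_call signature; both should be verified against the registered constructors before use:

```yaml
dataset_root: &root '~/datasets'
# yaml_pathjoin: join path segments, here via an anchor/alias pair
train_dir: !pathjoin [*root, 'train']
# yaml_join: plain string concatenation
experiment: !join ['resnet18-', 'cifar10']
# yaml_import_call: import a callable and instantiate it with init kwargs
model: !import_call
  key: 'torchvision.models.resnet18'
  init:
    kwargs:
      num_classes: 10
```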