torchdistill.common
torchdistill.common.file_util
- torchdistill.common.file_util.check_if_exists(file_path)
Checks if a file/dir exists.
- Parameters:
file_path (str) – file/dir path
- Returns:
True if the given file or directory exists
- Return type:
bool
- torchdistill.common.file_util.get_file_path_list(dir_path, is_recursive=False, is_sorted=False)
Gets file paths for a given dir path.
- Parameters:
dir_path (str) – dir path
is_recursive (bool) – if True, get file paths recursively
is_sorted (bool) – if True, sort file paths in ascending order
- Returns:
list of file paths
- Return type:
list[str]
- torchdistill.common.file_util.get_dir_path_list(dir_path, is_recursive=False, is_sorted=False)
Gets dir paths for a given dir path.
- Parameters:
dir_path (str) – dir path
is_recursive (bool) – if True, get dir paths recursively
is_sorted (bool) – if True, sort dir paths in ascending order
- Returns:
list of dir paths
- Return type:
list[str]
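The two listing helpers above can be approximated with the standard library. This is a sketch under assumptions, not torchdistill's implementation: `list_paths` is a hypothetical helper, and it assumes that non-recursive mode returns only the immediate children of `dir_path`.

```python
import os
import tempfile


def list_paths(dir_path, want_files=True, is_recursive=False, is_sorted=False):
    """Hypothetical helper approximating get_file_path_list / get_dir_path_list."""
    paths = []
    if is_recursive:
        # os.walk visits dir_path and every subdirectory below it
        for root, dir_names, file_names in os.walk(dir_path):
            names = file_names if want_files else dir_names
            paths.extend(os.path.join(root, name) for name in names)
    else:
        # Only inspect the immediate children of dir_path
        for name in os.listdir(dir_path):
            path = os.path.join(dir_path, name)
            if os.path.isfile(path) if want_files else os.path.isdir(path):
                paths.append(path)
    return sorted(paths) if is_sorted else paths


# Usage: a throwaway tree with one top-level file and one nested file
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "sub"))
    for rel in ("a.txt", os.path.join("sub", "b.txt")):
        with open(os.path.join(tmp, rel), "w") as f:
            f.write("x")
    top_files = list_paths(tmp)  # only a.txt
    all_files = list_paths(tmp, is_recursive=True, is_sorted=True)  # a.txt and sub/b.txt
    top_dirs = list_paths(tmp, want_files=False)  # only sub
    root = tmp
```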
- torchdistill.common.file_util.make_dirs(dir_path)
Makes a directory and its parent directories.
- Parameters:
dir_path (str) – dir path
- torchdistill.common.file_util.make_parent_dirs(file_path)
Makes parent directories.
- Parameters:
file_path (str) – file path
- torchdistill.common.file_util.save_pickle(obj, file_path)
Serializes an object with pickle and saves it to a file.
- Parameters:
obj (Any) – object to be serialized
file_path (str) – output file path
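A minimal sketch of how `make_parent_dirs` and `save_pickle` fit together, assuming they wrap `os.makedirs` and the standard `pickle` module. The `*_sketch` functions, including the loader, are hypothetical names used only for illustration.

```python
import os
import pickle
import tempfile


def make_parent_dirs_sketch(file_path):
    # Create every missing directory above file_path; no error if they exist
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def save_pickle_sketch(obj, file_path):
    # Ensure the target directory exists, then serialize obj in binary mode
    make_parent_dirs_sketch(file_path)
    with open(file_path, "wb") as f:
        pickle.dump(obj, f)


def load_pickle_sketch(file_path):
    # Counterpart used here only to check the round trip
    with open(file_path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    out_path = os.path.join(tmp, "nested", "dir", "stats.pkl")
    save_pickle_sketch({"epoch": 3, "acc": [0.5, 0.7]}, out_path)
    existed = os.path.exists(out_path)
    restored = load_pickle_sketch(out_path)
```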
torchdistill.common.main_util
- torchdistill.common.main_util.import_dependencies(dependencies=None)
Imports specified packages.
- Parameters:
dependencies (list[dict or list[str] or (str, str) or str] or str or None) – package names.
- torchdistill.common.main_util.import_get(key, package=None, **kwargs)
Imports a module and gets its attribute.
- Parameters:
key (str) – attribute name or package path separated by periods (.).
package (str or None) – package path if key is just an attribute name.
- Returns:
attribute of the imported module.
- Return type:
Any
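The two ways of addressing an attribute described above (a full dotted path in `key`, or an attribute name plus `package`) can be sketched with `importlib`. `import_get_sketch` is a hypothetical stand-in; it assumes the last dotted component of `key` is the attribute name when `package` is omitted, and it ignores `**kwargs`.

```python
import importlib


def import_get_sketch(key, package=None):
    """Hypothetical stand-in for import_get."""
    if package is None:
        # Split 'a.b.c' into module path 'a.b' and attribute name 'c'
        package, _, key = key.rpartition(".")
    module = importlib.import_module(package)
    return getattr(module, key)


join = import_get_sketch("os.path.join")  # full dotted path
ordered_dict_cls = import_get_sketch("OrderedDict", package="collections")  # name + package
```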
- torchdistill.common.main_util.import_call(key, package=None, init=None, **kwargs)
Imports a module and calls the module/function, e.g., for instantiation.
- Parameters:
key (str) – module name or package path separated by periods (.).
package (str or None) – package path if key is just an attribute name.
init (dict) – dict of arguments and/or keyword arguments to instantiate the imported module.
- Returns:
object imported and called.
- Return type:
Any
- torchdistill.common.main_util.import_call_method(package, class_name=None, method_name=None, init=None, **kwargs)
Imports a module and calls its method.
- Parameters:
package (str) – package path.
class_name (str) – class name under package.
method_name (str) – method name of the class_name class under package.
init (dict) – dict of arguments and/or keyword arguments to instantiate the imported module.
- Returns:
object imported and called.
- Return type:
Any
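The "import, then call with `init`" pattern can be sketched as follows. `import_call_sketch` is hypothetical, and the layout of `init` (positional arguments under `"args"`, keyword arguments under `"kwargs"`) is an assumption made for illustration, not a documented format.

```python
import importlib


def import_call_sketch(key, package=None, init=None):
    """Hypothetical stand-in for import_call; assumes init = {'args': [...], 'kwargs': {...}}."""
    if package is None:
        package, _, key = key.rpartition(".")
    target = getattr(importlib.import_module(package), key)
    init = init or {}
    args = init.get("args", [])  # assumed key for positional arguments
    kwargs = init.get("kwargs", {})  # assumed key for keyword arguments
    return target(*args, **kwargs)


frac = import_call_sketch("fractions.Fraction", init={"args": [3, 6]})
counter = import_call_sketch("Counter", package="collections", init={"args": ["abca"]})
```

Keeping arguments in a plain dict is what lets a configuration file describe an object to instantiate without any Python code.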
- torchdistill.common.main_util.setup_for_distributed(is_master)
Disables logging when not in the master process.
- Parameters:
is_master (bool) – True if it is the master process.
- torchdistill.common.main_util.set_seed(seed)
Sets a random seed for random, numpy, and torch (torch.manual_seed, torch.cuda.manual_seed_all).
- Parameters:
seed (int) – random seed.
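A sketch of seeding the three libraries listed above. `set_seed_sketch` is hypothetical; the `try`/`except ImportError` guards are an addition of this sketch so that it runs even when NumPy or PyTorch is not installed.

```python
import random


def set_seed_sketch(seed):
    """Seed Python's RNG, plus NumPy and PyTorch if they happen to be installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    except ImportError:
        pass


set_seed_sketch(42)
first = [random.random() for _ in range(3)]
set_seed_sketch(42)
second = [random.random() for _ in range(3)]  # identical to first after re-seeding
```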
- torchdistill.common.main_util.is_dist_avail_and_initialized()
Checks if distributed mode is available and initialized.
- Returns:
True if distributed mode is available and initialized.
- Return type:
bool
- torchdistill.common.main_util.get_world_size()
Gets the world size.
- Returns:
world size.
- Return type:
int
- torchdistill.common.main_util.get_rank()
Gets the rank of the current process in the default process group.
- Returns:
rank of the current process in the default process group.
- Return type:
int
- torchdistill.common.main_util.is_main_process()
Checks if this is the main process.
- Returns:
True if this is the main process.
- Return type:
bool
- torchdistill.common.main_util.save_on_master(*args, **kwargs)
Calls torch.save with the given arguments if this is the main process.
- Returns:
True if this is the main process.
- Return type:
bool
- torchdistill.common.main_util.init_distributed_mode(world_size=1, dist_url='env://')
Initializes the distributed mode.
- Parameters:
world_size (int) – world size.
dist_url (str) – URL specifying how to initialize the process group.
- Returns:
tuple of 1) whether distributed mode is initialized, 2) world size, and 3) list of device IDs.
- Return type:
(bool, int, list[int] or None)
- torchdistill.common.main_util.load_ckpt(ckpt_file_path, model=None, optimizer=None, lr_scheduler=None, strict=True)
Loads a checkpoint file with model, optimizer, and/or lr_scheduler.
- Parameters:
ckpt_file_path (str) – checkpoint file path.
model (nn.Module) – model.
optimizer (torch.optim.Optimizer) – optimizer.
lr_scheduler (torch.optim.lr_scheduler.LRScheduler) – learning rate scheduler.
strict (bool) – strict as a keyword argument of load_state_dict.
- Returns:
tuple of best value (e.g., best validation result) and parsed args.
- Return type:
(float or None, argparse.Namespace or None)
- torchdistill.common.main_util.save_ckpt(model, optimizer, lr_scheduler, best_value, args, output_file_path)
Saves a checkpoint file including model, optimizer, best value, parsed args, and learning rate scheduler.
- Parameters:
model (nn.Module) – model.
optimizer (torch.optim.Optimizer) – optimizer.
lr_scheduler (torch.optim.lr_scheduler.LRScheduler) – learning rate scheduler.
best_value (float) – best value, e.g., best validation result.
args (argparse.Namespace) – parsed args.
output_file_path (str) – output file path.
torchdistill.common.misc_util
- torchdistill.common.misc_util.check_if_plottable()
Checks if the DISPLAY environment variable is valid.
- Returns:
True if DISPLAY variable is valid.
- Return type:
bool
- torchdistill.common.misc_util.get_classes(package_name, require_names=False)
Gets classes in a given package.
- Parameters:
package_name (str) – package name.
require_names (bool) – whether to preserve member names.
- Returns:
collection of classes defined in the given package.
- Return type:
list[(str, class)] or list[class]
- torchdistill.common.misc_util.get_classes_as_dict(package_name, is_lower=False)
Gets classes in a given package as dict.
- Parameters:
package_name (str) – package name.
is_lower (bool) – if True, use lowercase class names as keys.
- Returns:
dict of classes defined in the given package.
- Return type:
dict
- torchdistill.common.misc_util.get_functions(package_name, require_names=False)
Gets functions in a given package.
- Parameters:
package_name (str) – package name.
require_names (bool) – whether to preserve function names.
- Returns:
collection of functions defined in the given package.
- Return type:
list[(str, Callable)] or list[Callable]
- torchdistill.common.misc_util.get_functions_as_dict(package_name, is_lower=False)
Gets functions in a given package as dict.
- Parameters:
package_name (str) – package name.
is_lower (bool) – if True, use lowercase function names as keys.
- Returns:
dict of functions defined in the given package.
- Return type:
dict
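The `*_as_dict` helpers can be approximated with `inspect.getmembers`. This is a sketch under assumptions: `get_members_as_dict_sketch` is hypothetical, and unlike a strict reading of "defined in", `inspect.getmembers` also returns names a module merely imports (for example, `fractions` imports `Decimal`).

```python
import importlib
import inspect


def get_members_as_dict_sketch(package_name, predicate, is_lower=False):
    """Hypothetical helper: map member names (optionally lowercased) to objects."""
    module = importlib.import_module(package_name)
    return {
        (name.lower() if is_lower else name): obj
        for name, obj in inspect.getmembers(module, predicate)
    }


class_dict = get_members_as_dict_sketch("fractions", inspect.isclass, is_lower=True)
func_dict = get_members_as_dict_sketch("textwrap", inspect.isfunction)
```

Lowercased keys make lookups from configuration strings case-insensitive, at the cost of collisions if two members differ only in case.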
torchdistill.common.module_util
- torchdistill.common.module_util.check_if_wrapped(model)
Checks if a given model is wrapped by DataParallel or DistributedDataParallel.
- Parameters:
model (nn.Module) – model.
- Returns:
True if model is wrapped by either DataParallel or DistributedDataParallel.
- Return type:
bool
- torchdistill.common.module_util.count_params(module)
Returns the number of module parameters.
- Parameters:
module (nn.Module) – module.
- Returns:
number of model parameters.
- Return type:
int
- torchdistill.common.module_util.freeze_module_params(module)
Freezes parameters by setting requires_grad=False for all the parameters.
- Parameters:
module (nn.Module) – module.
- torchdistill.common.module_util.unfreeze_module_params(module)
Unfreezes parameters by setting requires_grad=True for all the parameters.
- Parameters:
module (nn.Module) – module.
- torchdistill.common.module_util.get_updatable_param_names(module)
Gets a collection of updatable parameter names.
- Parameters:
module (nn.Module) – module.
- Returns:
names of updatable parameters.
- Return type:
list[str]
- torchdistill.common.module_util.get_frozen_param_names(module)
Gets a collection of frozen parameter names.
- Parameters:
module (nn.Module) – module.
- Returns:
names of frozen parameters.
- Return type:
list[str]
- torchdistill.common.module_util.get_module(root_module, module_path)
Gets a module specified by module_path.
- Parameters:
root_module (nn.Module) – module.
module_path (str) – module path for extracting the module from root_module.
- Returns:
module extracted from root_module if it exists.
- Return type:
nn.Module or None
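Dotted-path resolution of the kind `get_module` performs can be sketched with plain Python objects standing in for an `nn.Module` tree. `get_module_sketch` is hypothetical; it assumes numeric path parts index into sequences (as with `nn.Sequential`) and returns None for a missing path. In PyTorch itself, `nn.Module.get_submodule` offers a similar lookup but raises `AttributeError` instead of returning None.

```python
from types import SimpleNamespace


def get_module_sketch(root, module_path):
    """Follow a dotted path; numeric parts index into sequences. Returns None if missing."""
    current = root
    for name in module_path.split("."):
        if hasattr(current, name):
            current = getattr(current, name)
        elif name.isdigit() and isinstance(current, (list, tuple)):
            index = int(name)
            if index >= len(current):
                return None
            current = current[index]
        else:
            return None
    return current


# Plain objects standing in for a module tree such as model.layer1[0].conv1
conv1 = SimpleNamespace(kind="conv")
model = SimpleNamespace(layer1=[SimpleNamespace(conv1=conv1)])
found = get_module_sketch(model, "layer1.0.conv1")
missing = get_module_sketch(model, "layer2.0")
```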
- torchdistill.common.module_util.get_hierarchized_dict(module_paths)
Gets a hierarchical structure from module paths.
- Parameters:
module_paths (list[str]) – module paths.
- Returns:
hierarchical dict built from the given module paths.
- Return type:
dict
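The idea of turning flat dotted paths into a hierarchy can be sketched as a nested dict. `hierarchize_sketch` is hypothetical, and its output layout (leaves map to an empty dict) is an assumption chosen for illustration; torchdistill's actual return format may differ.

```python
def hierarchize_sketch(module_paths):
    """Nest dotted paths into dicts; leaves map to an empty dict (an assumed layout)."""
    tree = {}
    for path in module_paths:
        node = tree
        for name in path.split("."):
            # Reuse the existing branch if present, otherwise create it
            node = node.setdefault(name, {})
    return tree


tree = hierarchize_sketch(["backbone.layer1.conv1", "backbone.layer1.bn1", "head.fc"])
```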
- torchdistill.common.module_util.decompose(ordered_dict)
Converts an ordered dict into a list of key-value pairs.
- Parameters:
ordered_dict (collections.OrderedDict) – ordered dict.
- Returns:
list of key-value pairs.
- Return type:
list[(str, Any)]
- torchdistill.common.module_util.get_components(module_paths)
Converts module paths into a list of pairs of parent module and child module names.
- Parameters:
module_paths (list[str]) – module paths.
- Returns:
list of pairs of parent module and child module names.
- Return type:
list[(str, str)]
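A sketch of splitting module paths into (parent, child) pairs. `get_components_sketch` is hypothetical and assumes the parent is the full dotted path minus its last element, with an empty string standing for the root module.

```python
def get_components_sketch(module_paths):
    """Split each dotted path into (parent path, child name); '' means the root module."""
    components = []
    for path in module_paths:
        parent, _, child = path.rpartition(".")
        components.append((parent, child))
    return components


pairs = get_components_sketch(["backbone.layer1.conv1", "fc"])
```

A (parent, child) pair is convenient when a submodule has to be replaced, since `setattr(parent_module, child_name, new_module)` needs exactly those two pieces.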
- torchdistill.common.module_util.extract_target_modules(parent_module, target_class, module_list)
Extracts modules that are instances of target_class and updates module_list with the extracted modules.
- Parameters:
parent_module (nn.Module) – parent module.
target_class (class) – target class.
module_list (list[nn.Module]) – (empty) list to be filled with modules that are instances of target_class.
- torchdistill.common.module_util.extract_all_child_modules(parent_module, module_list)
Extracts all the child modules and updates module_list with the extracted modules.
- Parameters:
parent_module (nn.Module) – parent module.
module_list (list[nn.Module]) – (empty) list to be filled with child modules.
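The in-place filling of `module_list` can be sketched as a depth-first traversal. To stay self-contained, the sketch uses a hypothetical `Node` class that only provides `children()`, standing in for `nn.Module`; `extract_target_modules_sketch` is likewise hypothetical. With real PyTorch modules, `[m for m in model.modules() if isinstance(m, cls)]` yields a comparable pre-order result.

```python
class Node:
    """Minimal stand-in for nn.Module: only provides children()."""

    def __init__(self, *children):
        self._children = list(children)

    def children(self):
        return iter(self._children)


class Conv(Node):
    pass


def extract_target_modules_sketch(parent_module, target_class, module_list):
    # Depth-first: record the module itself, then recurse into its direct children
    if isinstance(parent_module, target_class):
        module_list.append(parent_module)
    for child in parent_module.children():
        extract_target_modules_sketch(child, target_class, module_list)


conv_a, conv_b = Conv(), Conv()
root = Node(conv_a, Node(conv_b))
found_convs = []  # filled in place, mirroring the module_list parameter
extract_target_modules_sketch(root, Conv, found_convs)
```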
torchdistill.common.tensor_util
- class torchdistill.common.tensor_util.QuantizedTensor(tensor, scale, zero_point)
A named tuple holding a quantized tensor together with its scale and zero point.
- scale
Alias for field number 1
- tensor
Alias for field number 0
- zero_point
Alias for field number 2
- torchdistill.common.tensor_util.quantize_tensor(x, num_bits=8)
Quantizes a tensor into num_bits-bit integers using an affine mapping defined by a scale and a zero point.
Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko: "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" @ CVPR 2018 (2018)
- Parameters:
x (torch.Tensor) – tensor to be quantized.
num_bits (int) – the number of bits for quantization.
- Returns:
quantized tensor.
- Return type:
QuantizedTensor
- torchdistill.common.tensor_util.dequantize_tensor(q_x)
Dequantizes a quantized tensor.
Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu, Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko: "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" @ CVPR 2018 (2018)
- Parameters:
q_x (QuantizedTensor) – quantized tensor to be dequantized.
- Returns:
dequantized tensor.
- Return type:
torch.Tensor
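The affine scheme from Jacob et al. maps a real value r to an integer q via r ≈ scale * (q - zero_point). Below is a pure-Python sketch of a quantize/dequantize round trip on a list of floats. The function names are hypothetical, and the exact rounding and clamping choices are assumptions; torchdistill's functions operate on torch.Tensor and return a QuantizedTensor instead of a plain tuple.

```python
def quantize_sketch(values, num_bits=8):
    """Affine (asymmetric) quantization of a list of floats to unsigned num_bits integers."""
    qmin, qmax = 0, 2 ** num_bits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # avoid division by zero for constant input
    # Zero point: the integer that represents real 0.0, clamped to the valid range
    zero_point = int(min(max(round(qmin - lo / scale), qmin), qmax))
    q = [int(min(max(round(v / scale + zero_point), qmin), qmax)) for v in values]
    return q, scale, zero_point


def dequantize_sketch(q, scale, zero_point):
    # Inverse mapping: real value is approximately scale * (q - zero_point)
    return [scale * (qi - zero_point) for qi in q]


x = [-1.0, -0.25, 0.0, 0.5, 1.0]
q_x, scale, zero_point = quantize_sketch(x)
x_hat = dequantize_sketch(q_x, scale, zero_point)
```

Because the zero point is an integer, real 0.0 is represented exactly, which matters for zero padding in convolutions; every other value is recovered to within one quantization step (scale).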
torchdistill.common.yaml_util
- torchdistill.common.yaml_util.yaml_join(loader, node)
Joins a sequence of strings.
- Parameters:
loader (yaml.loader.FullLoader) – yaml loader.
node (yaml.nodes.Node) – node.
- Returns:
joined string.
- Return type:
str
- torchdistill.common.yaml_util.yaml_pathjoin(loader, node)
Joins a sequence of strings as a (file) path.
- Parameters:
loader (yaml.loader.FullLoader) – yaml loader.
node (yaml.nodes.Node) – node.
- Returns:
joined (file) path.
- Return type:
str
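Constructors with the `(loader, node)` signature above are registered against YAML tags. The sketch below registers two join-style constructors on a private `yaml.SafeLoader` subclass using PyYAML's `add_constructor` and `construct_sequence`. The tag names `!join` and `!pathjoin`, the loader class, and the constructor functions are illustrative assumptions, not necessarily how torchdistill wires up its own loader.

```python
import os

import yaml  # PyYAML


def join_constructor(loader, node):
    # Concatenate the YAML sequence items as strings
    return "".join(str(item) for item in loader.construct_sequence(node))


def pathjoin_constructor(loader, node):
    # Join the YAML sequence items with the OS path separator
    return os.path.join(*(str(item) for item in loader.construct_sequence(node)))


class SketchLoader(yaml.SafeLoader):
    """Private loader subclass so the tags do not leak into yaml.SafeLoader itself."""


SketchLoader.add_constructor("!join", join_constructor)
SketchLoader.add_constructor("!pathjoin", pathjoin_constructor)

config = yaml.load(
    "model_name: !join [resnet, 18]\n"
    "root_dir: !pathjoin [./resource, dataset, ilsvrc2012]\n",
    Loader=SketchLoader,
)
```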
- torchdistill.common.yaml_util.yaml_expanduser(loader, node)
Applies os.path.expanduser to a (file) path.
- Parameters:
loader (yaml.loader.FullLoader) – yaml loader.
node (yaml.nodes.Node) – node.
- Returns:
(file) path.
- Return type:
str
- torchdistill.common.yaml_util.yaml_abspath(loader, node)
Applies os.path.abspath to a (file) path.
- Parameters:
loader (yaml.loader.FullLoader) – yaml loader.
node (yaml.nodes.Node) – node.
- Returns:
(file) path.
- Return type:
str
- torchdistill.common.yaml_util.yaml_import_get(loader, node)
Imports a module and gets its attribute.
- Parameters:
loader (yaml.loader.FullLoader) – yaml loader.
node (yaml.nodes.Node) – node.
- Returns:
module attribute.
- Return type:
Any
- torchdistill.common.yaml_util.yaml_import_call(loader, node)
Imports a module and calls the module/function, e.g., for instantiation.
- Parameters:
loader (yaml.loader.FullLoader) – yaml loader.
node (yaml.nodes.Node) – node.
- Returns:
result of callable module.
- Return type:
Any
- torchdistill.common.yaml_util.yaml_import_call_method(loader, node)
Imports a module and calls its method.
- Parameters:
loader (yaml.loader.FullLoader) – yaml loader.
node (yaml.nodes.Node) – node.
- Returns:
result of callable module.
- Return type:
Any
- torchdistill.common.yaml_util.yaml_getattr(loader, node)
Gets an attribute of the first argument.
- Parameters:
loader (yaml.loader.FullLoader) – yaml loader.
node (yaml.nodes.Node) – node.
- Returns:
module attribute.
- Return type:
Any
- torchdistill.common.yaml_util.yaml_setattr(loader, node)
Sets an attribute of the first argument.
- Parameters:
loader (yaml.loader.FullLoader) – yaml loader.
node (yaml.nodes.Node) – node.
- Returns:
module attribute.
- Return type:
Any
- torchdistill.common.yaml_util.yaml_access_by_index_or_key(loader, node)
Obtains a value from the specified data by index or key.
- Parameters:
loader (yaml.loader.FullLoader) – yaml loader.
node (yaml.nodes.Node) – node.
- Returns:
accessed object.
- Return type:
Any
- torchdistill.common.yaml_util.load_yaml_file(yaml_file_path, custom_mode=True)
Loads a yaml file optionally with convenient constructors.
- Parameters:
yaml_file_path (str) – yaml file path.
custom_mode (bool) – if True, uses convenient constructors.
- Returns:
loaded PyYAML object.
- Return type:
Any