Pruning

The enot.pruning package contains functionality for automatic pruning of user models.

This package features predefined (developer-defined) and custom pruning procedures for removing the least important filters or neurons. In simple cases, enot.pruning provides a structured pruning procedure. In more complex cases, it provides a special mechanism for estimating global channel importance, which leads to optimal sub-architecture selection (will be released soon).

A couple of definitions to make this documentation easier to read:

  • Gates / channel gates are special torch.nn.Module instances which gather and store channel “local importance”. “Local importance” means that you can compare channels within a layer, but not between layers.

  • Calibration procedure is the procedure which estimates channel importance with respect to your dataset.

  • Channel label is the global channel index in a network.
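To make these terms concrete, here is a small sketch in plain Python. The `Gate` dataclass, its field names, and the helper function are hypothetical stand-ins for illustration and are not part of enot.pruning. The sketch only shows that labels are global indices, while importance values are ranked within a single gate.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Gate:
    """Hypothetical stand-in for a channel gate (not the enot.pruning class)."""

    labels: List[int]        # global channel indices, unique across the network
    importance: List[float]  # local importance, comparable only inside this gate


def least_important_label(gate: Gate) -> int:
    """Return the label of the least important channel within a single gate."""
    position = min(range(len(gate.importance)), key=gate.importance.__getitem__)
    return gate.labels[position]


# Two gates with different importance scales: ranking is done per gate only.
conv1 = Gate(labels=[0, 1, 2], importance=[0.9, 0.1, 0.5])
conv2 = Gate(labels=[3, 4], importance=[120.0, 80.0])
weakest = [least_important_label(conv1), least_important_label(conv2)]
```

Note that the value 80.0 in the second gate is larger than every value in the first gate, yet its channel is still the weakest one in its own layer.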

enot-lite-plus pruning

calibrate_and_prune_model_equal(model, dataloader, loss_function, pruning_ratio=0.5, finetune_bn=False, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, verbose=0, **kwargs)

Estimates channel importances and prunes the model with the equal pruning strategy (the same fraction of channels is pruned in each prunable layer).

Parameters
  • model (torch.nn.Module) – Model to calibrate pruning gates for.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model’s channel importances.

  • loss_function (Callable) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • pruning_ratio (float) – Relative amount of channels to prune in each prunable layer.

  • finetune_bn (bool) – Finetune running mean and running variance for better model quality.

  • n_steps (int or None, optional) – Number of total threshold calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs in epochs argument.

  • epochs (int, optional) – Number of total threshold calibration epochs. Not used when n_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in single dataloader batch (dataloader sample). This function should take single input (dataloader sample) and return single integer - the number of instances. Default value is default_sample_to_n_samples(). See more in Converting dataloader items to PyTorch model inputs.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more in Converting dataloader items to PyTorch model inputs.

  • verbose (int, optional) – Procedure verbosity level. 0 disables all messages, 1 enables tqdm progress bar logging, 2 gives additional information about calibration. Default value is 0.
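As an illustration of the sample_to_n_samples contract described above, the sketch below assumes that a dataloader sample is either an (inputs, labels) pair or a dict with an "image" entry. Both layouts and both function names are assumptions made for this example; they do not describe the behaviour of default_sample_to_n_samples().

```python
from typing import Any


def pair_sample_to_n_samples(sample: Any) -> int:
    """Count instances in a batch laid out as (inputs, labels)."""
    inputs, _labels = sample
    return len(inputs)


def dict_sample_to_n_samples(sample: Any) -> int:
    """Count instances in a batch laid out as {"image": ..., "target": ...}."""
    return len(sample["image"])


# Plain lists stand in for batched tensors here; len() of a tensor
# likewise returns the size of its first (batch) dimension.
batch_pair = ([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [0, 1, 1])
batch_dict = {"image": [[0.1], [0.2]], "target": [1, 0]}
n_pair = pair_sample_to_n_samples(batch_pair)
n_dict = dict_sample_to_n_samples(batch_dict)
```

A function of this kind would be passed as the sample_to_n_samples argument when the dataloader does not yield the layout expected by the default.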

Returns

Pruned model.

Return type

torch.nn.Module

Notes

Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).

Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_equal(...)

get_criteria_label_dict(pruning_info)

Collects information about channel importances in human-readable format.

Parameters

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns

Dictionary with channel labels and importances. Its keys are channel group names, and its values store both global channel indices (labels) and channel importances.

Return type

Dict[str, Dict[str, numpy.ndarray]]

get_labels_for_equal_pruning(pruning_info, pruning_ratio=0.5)[source]

Returns labels of least important channels of every prunable layer.

Parameters
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • pruning_ratio (float) – Fraction of the least important channels to select for pruning in every prunable layer.

Returns

List of least important channel labels.

Return type

list of int
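The selection rule can be sketched in plain Python as follows. This is an illustration under assumptions, not the library's implementation: the function name and the (labels, criteria) input layout are hypothetical, and rounding the per-layer count down is borrowed from the user-defined strategy example elsewhere in this documentation.

```python
from typing import Iterable, List, Sequence, Tuple


def select_equal_ratio(
    gates: Iterable[Tuple[Sequence[int], Sequence[float]]],
    pruning_ratio: float = 0.5,
) -> List[int]:
    """Pick the same fraction of least important channels from every gate."""
    selected: List[int] = []
    for labels, criteria in gates:
        n_prune = int(len(criteria) * pruning_ratio)  # assumption: round down
        # Positions inside the gate, ordered from least to most important.
        order = sorted(range(len(criteria)), key=lambda i: criteria[i])
        selected.extend(labels[i] for i in order[:n_prune])
    return selected


gates = [
    ([0, 1, 2, 3], [0.4, 0.1, 0.9, 0.2]),
    ([4, 5, 6], [5.0, 1.0, 3.0]),
]
labels_to_prune = select_equal_ratio(gates, pruning_ratio=0.5)
```

With a ratio of 0.5, two of the four channels are taken from the first gate and one of the three from the second, regardless of how the importance scales of the two gates compare.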

iterate_over_gate_criteria(pruning_info)

Iterates over pruning gates and yields channel labels and channel criteria.

Parameters

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Yields
  • group_id (int) – Global index number of a gate.

  • labels (list of int) – Channel labels of a gate.

  • criteria (list of float) – Channel criteria of a gate.

Returns

Generator with items specified above.

Return type

Generator
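The yielded (group_id, labels, criteria) triples are enough to build a custom selection rule. The sketch below is one hypothetical rule, not something provided by enot.pruning: it selects channels whose criterion falls below a fraction of their own gate's maximum and always keeps at least min_keep channels per gate. The `fake_gate_items` generator only imitates the item layout described above so that the example runs without the library.

```python
from typing import Iterable, List, Sequence, Tuple

GateItem = Tuple[int, Sequence[int], Sequence[float]]  # (group_id, labels, criteria)


def select_below_relative_threshold(
    gate_items: Iterable[GateItem],
    relative_threshold: float = 0.1,
    min_keep: int = 1,
) -> List[int]:
    """Select channels whose criterion is below a fraction of their gate's maximum."""
    selected: List[int] = []
    for _group_id, labels, criteria in gate_items:
        cutoff = relative_threshold * max(criteria)
        # Weakest channels first, so the cap below drops the strongest candidates.
        candidates = sorted(
            (i for i, value in enumerate(criteria) if value < cutoff),
            key=lambda i: criteria[i],
        )
        max_prunable = max(len(criteria) - min_keep, 0)
        selected.extend(labels[i] for i in candidates[:max_prunable])
    return selected


def fake_gate_items():
    """Stand-in generator with the same item layout as described above."""
    yield 0, [0, 1, 2], [1.0, 0.05, 0.5]
    yield 1, [3, 4], [0.0, 0.0]


picked = select_below_relative_threshold(fake_gate_items())
```

Because the threshold is relative to each gate's own maximum, the rule never compares criteria between layers, which matches the "local importance" definition.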

class EnotPruningCalibrator(model, entry_point='forward')

Context manager for pruning information gathering.

This context manager walks through the model and collects information about prunable and unprunable channels and about channel importance. The main purpose of this calibrator is to create a ModelPruningInfo object which can later be used to prune the model.

Model pruning info can be accessed via EnotPruningCalibrator.pruning_info.

Examples

>>> p_calibrator = EnotPruningCalibrator(
...     model=baseline_model,
... )
>>> with p_calibrator:
...     for inputs, labels in dataloader:
...         model_output = baseline_model(inputs)
...         loss = loss_function(model_output, labels)
...         loss.backward()
>>> p_calibrator.pruning_info

__init__(model, entry_point='forward')
Parameters
  • model (torch.nn.Module) – Model which user wants to calibrate for pruning.

  • entry_point (str) – Model function for execution. See Notes section for the detailed argument description. Default value is “forward”.

Notes

entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If such a function is located in the model’s submodule - then you should first write the submodule’s name followed by the function name, all separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.
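The dotted-path convention can be illustrated with a short sketch. How EnotPruningCalibrator resolves entry_point internally is not documented here, so `resolve_entry_point` and the toy `Model` and `Head` classes below are assumptions for illustration only.

```python
from functools import reduce
from typing import Any, Callable


def resolve_entry_point(model: Any, entry_point: str = "forward") -> Callable:
    """Follow a dotted path such as "head.predict_features" down to a bound method."""
    target = reduce(getattr, entry_point.split("."), model)
    if not callable(target):
        raise TypeError(f"{entry_point!r} does not resolve to a callable")
    return target


class Head:
    def predict_features(self, x):
        return x * 2


class Model:
    def __init__(self):
        self.head = Head()

    def forward(self, x):
        return x + 1


model = Model()
top_level = resolve_entry_point(model, "forward")(1)
nested = resolve_entry_point(model, "head.predict_features")(3)
```

Each path component before the last names a submodule attribute, and the last component names the function to execute.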

property pruning_info: Optional[enot.pruning.pruning_info.ModelPruningInfo]

Returns the information necessary for pruning.

Returns

Return type

ModelPruningInfo or None

calibrate_and_prune_model(label_selection_fn, model, dataloader, loss_function, finetune_bn=False, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, verbose=0, **kwargs)

Estimates channel importances and prunes the model with a user-defined strategy.

This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importances for later pruning. It then prunes the model by removing the channels selected by the user-defined strategy label_selection_fn.
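The three stages described above (calibrate, select labels, prune) can be sketched as a simple composition. This is an assumption about the overall flow, not the library's source code: `calibrate_then_prune` is a hypothetical helper, and the stages are passed in as callables so that the example runs with toy stand-ins.

```python
from typing import Any, Callable, List


def calibrate_then_prune(
    model: Any,
    calibrate: Callable[[Any], Any],
    label_selection_fn: Callable[..., List[int]],
    prune: Callable[[Any, Any, List[int]], Any],
    **kwargs: Any,
) -> Any:
    """Run the three stages in order: calibrate, select labels, prune."""
    pruning_info = calibrate(model)
    labels = label_selection_fn(model, pruning_info, **kwargs)
    return prune(model, pruning_info, labels)


# Toy stand-ins: the "model" is a dict mapping channel label -> importance.
toy_model = {0: 0.9, 1: 0.1, 2: 0.5, 3: 0.05}
pruned = calibrate_then_prune(
    toy_model,
    calibrate=lambda m: dict(m),
    label_selection_fn=lambda m, info, threshold: [k for k, v in info.items() if v < threshold],
    prune=lambda m, info, labels: {k: v for k, v in m.items() if k not in labels},
    threshold=0.2,
)
```

The extra keyword argument (threshold in this toy case) is forwarded to the selection function, mirroring how **kwargs reach label_selection_fn.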

Parameters
  • label_selection_fn (Callable[[torch.nn.Module, ModelPruningInfo], List[int]]) – Channel selection strategy for pruning. This function should typically return labels of the least important channels.

  • model (torch.nn.Module) – Model to calibrate pruning gates for.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model’s channel importances.

  • loss_function (Callable) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • finetune_bn (bool) – Finetune running mean and running variance for better model quality.

  • n_steps (int or None, optional) – Number of total threshold calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs in epochs argument.

  • epochs (int, optional) – Number of total threshold calibration epochs. Not used when n_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in single dataloader batch (dataloader sample). This function should take single input (dataloader sample) and return single integer - the number of instances. Default value is default_sample_to_n_samples(). See more in Converting dataloader items to PyTorch model inputs.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more in Converting dataloader items to PyTorch model inputs.

  • verbose (int, optional) – Procedure verbosity level. 0 disables all messages, 1 enables tqdm progress bar logging, 2 gives additional information about calibration. Default value is 0.

Returns

Pruned model.

Return type

torch.nn.Module

Notes

Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).

Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model(...)

label_selection_fn is a function that defines the model pruning strategy. It should return a list of integers: the labels of the channels to remove from the model. The input arguments of label_selection_fn are:

  • model (torch.nn.Module) – Model to calibrate pruning gates for.

  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • kwargs (Dict[str, Any]) – Keyword arguments for label_selection_fn.

An example of user-defined strategy:

from typing import List

import numpy as np
import torch

from enot.pruning import ModelPruningInfo, iterate_over_gate_criteria

def label_selection_fn(model: torch.nn.Module, pruning_info: ModelPruningInfo, **kwargs) -> List[int]:
    # Select the same fraction of least important channels in every gate.
    labels_to_prune: List[int] = []
    pruning_ratio = kwargs.pop('pruning_ratio')
    for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
        criteria = np.array(criteria)
        prune_channels_num = int(len(criteria) * pruning_ratio)
        # Positions of the channels with the lowest criteria within this gate.
        index_for_pruning = np.argsort(criteria)[:prune_channels_num]
        labels_to_prune += np.array(labels)[index_for_pruning].tolist()
    return labels_to_prune

prune_model(model, pruning_info, prune_labels, inplace=True)

Removes (prunes) the channels specified by the prune_labels parameter; these are typically the least important channels.

Parameters
  • model (torch.nn.Module) – Model for pruning.

  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • prune_labels (Sequence of int) – Sequence of global channel indices (labels) to prune.

  • inplace (bool) – Whether to perform pruning in-place (modify the model itself without creating a copy).

Returns

Pruned model.

Return type

torch.nn.Module
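The inplace flag follows a common pattern: either the passed object is modified, or a copy is modified and the original is left untouched. The sketch below is a generic illustration of that pattern with toy data. It is not prune_model's implementation, and `apply_pruning` and `drop_last_channel` are hypothetical names.

```python
import copy
from typing import Any, Callable


def apply_pruning(model: Any, prune_in_place: Callable[[Any], None], inplace: bool = True) -> Any:
    """Mutate the model itself, or a deep copy of it when inplace is False."""
    target = model if inplace else copy.deepcopy(model)
    prune_in_place(target)
    return target


def drop_last_channel(model: dict) -> None:
    """Toy pruning step: remove the last channel of the single toy layer."""
    model["conv"].pop()


original = {"conv": [0.9, 0.5, 0.1]}
pruned_copy = apply_pruning(original, drop_last_channel, inplace=False)
original_after_copy = copy.deepcopy(original)  # snapshot: still three channels
pruned_same = apply_pruning(original, drop_last_channel, inplace=True)
```

Working on a copy costs extra memory but keeps the baseline model available, for example for comparing accuracy before and after pruning.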

calibrate_model_for_pruning(model, dataloader, loss_function, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, verbose=0)

Estimates model channel importances for later pruning.

This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importances for later pruning.

Parameters
  • model (torch.nn.Module) – Model to calibrate pruning gates for.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model’s channel importances.

  • loss_function (Callable) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • n_steps (int or None, optional) – Number of total threshold calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs in epochs argument.

  • epochs (int, optional) – Number of total threshold calibration epochs. Not used when n_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in single dataloader batch (dataloader sample). This function should take single input (dataloader sample) and return single integer - the number of instances. Default value is default_sample_to_n_samples(). See more in Converting dataloader items to PyTorch model inputs.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more in Converting dataloader items to PyTorch model inputs.

  • verbose (int, optional) – Procedure verbosity level. 0 disables all messages, 1 enables tqdm progress bar logging, 2 gives additional information about calibration. Default value is 0.
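The interaction between n_steps and epochs can be summarised with a small helper. The sketch assumes that one calibration step processes one dataloader batch, which the parameter descriptions above suggest but do not state explicitly; `total_calibration_steps` is a hypothetical name and not part of enot.pruning.

```python
from typing import Optional


def total_calibration_steps(
    batches_per_epoch: int,
    n_steps: Optional[int] = None,
    epochs: int = 1,
) -> int:
    """Resolve the number of calibration steps from n_steps and epochs."""
    if n_steps is not None:
        return n_steps  # epochs is not used when n_steps is given
    return batches_per_epoch * epochs


full_pass = total_calibration_steps(batches_per_epoch=100)
three_epochs = total_calibration_steps(batches_per_epoch=100, epochs=3)
capped = total_calibration_steps(batches_per_epoch=100, n_steps=20, epochs=3)
```

Setting n_steps is a convenient way to bound calibration time on large datasets, since epochs is then not used.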

Returns

Pruning information for later usage in pruning methods.

Return type

ModelPruningInfo

Notes

Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).

Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_model_for_pruning(...)