Pruning
The enot.pruning package contains functionality for automatic pruning of user models.
This package features developer-defined and custom pruning procedures for removing the least important filters or neurons. In simple cases, the enot.pruning package provides a structured pruning procedure. In more complex cases, it provides a special mechanism for estimating global channel importance, which leads to optimal sub-architecture selection (will be released soon).
A couple of definitions to make the documentation easier to read:
Gates / channel gates are special torch.nn.Module instances which gather and store channel “local importance”. “Local importance” means that you can compare channels within a layer, but not between layers.
Calibration procedure is the procedure that estimates channel importance relative to your dataset.
Channel label is the global channel index in a network.
enot-lite-plus pruning
- calibrate_and_prune_model_equal(model, dataloader, loss_function, pruning_ratio=0.5, finetune_bn=False, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, verbose=0, **kwargs)
Estimates channel importances and prunes the model with the equal pruning strategy (the same fraction of channels is pruned in each prunable layer).
- Parameters
model (torch.nn.Module) – Model to calibrate pruning gates for.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model’s channel importances.
loss_function (Callable) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
pruning_ratio (float) – Relative amount of channels to prune in each prunable layer.
finetune_bn (bool) – Whether to fine-tune batch normalization running mean and running variance for better model quality.
n_steps (int or None, optional) – Total number of threshold calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs given in the epochs argument.
epochs (int, optional) – Total number of threshold calibration epochs. Not used when the n_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more in Converting dataloader items to PyTorch model inputs.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more in Converting dataloader items to PyTorch model inputs.
verbose (int, optional) – Procedure verbosity level. 0 disables all messages, 1 enables tqdm progress bar logging, 2 gives additional information about calibration. Default value is 0.
- Returns
Pruned model.
- Return type
Notes
Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).
Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during the backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_equal(...)
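The interaction between n_steps and epochs described in the parameter list above can be pictured with a small stand-alone sketch. This is not the library's code: calibration_batches is a hypothetical helper, and cycling through the dataloader again when n_steps exceeds its length is an assumption made only for illustration.

```python
from typing import Iterable, Iterator, Optional


def calibration_batches(dataloader: Iterable, n_steps: Optional[int] = None, epochs: int = 1) -> Iterator:
    """Yield batches following the documented n_steps / epochs rule (illustrative only)."""
    if n_steps is None:
        # No step budget: walk the whole dataloader `epochs` times.
        for _ in range(epochs):
            yield from dataloader
        return

    # Step budget given: `epochs` is ignored and exactly n_steps batches are produced.
    # ASSUMPTION: the dataloader is restarted if it runs out before the budget is spent.
    produced = 0
    while produced < n_steps:
        saw_batch = False
        for batch in dataloader:
            saw_batch = True
            yield batch
            produced += 1
            if produced == n_steps:
                return
        if not saw_batch:
            # Empty dataloader: stop instead of looping forever.
            return
```

With a three-batch dataloader, epochs=2 yields six batches, while n_steps=4 yields four batches regardless of the epochs value.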
- get_criteria_label_dict(pruning_info)
Collects information about channel importances in human-readable format.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Returns
Dictionary with channel labels and importances. Its keys are channel group names, and its values store both global channel indices (labels) and channel importances.
- Return type
Dict[str, Dict[str, numpy.ndarray]]
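Since the return type is a nested dictionary, a generic traversal is often enough to get a first overview. The sketch below works on any Dict[str, Dict[str, array-like]]; the inner key names "labels" and "criteria" and the group names in the mock data are hypothetical, because the documentation above does not name them.

```python
from typing import Dict, Mapping, Sequence


def summarize_criteria(criteria_label_dict: Mapping[str, Mapping[str, Sequence]]) -> Dict[str, Dict[str, int]]:
    """Return {group name: {field name: number of entries}} for a nested dictionary."""
    return {
        group_name: {field_name: len(values) for field_name, values in fields.items()}
        for group_name, fields in criteria_label_dict.items()
    }


# Mock data with HYPOTHETICAL group and field names; plain lists stand in for numpy arrays.
mock_dict = {
    "group_0": {"labels": [0, 1, 2, 3], "criteria": [0.9, 0.1, 0.5, 0.3]},
    "group_1": {"labels": [4, 5], "criteria": [0.2, 0.8]},
}
summary = summarize_criteria(mock_dict)
```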
- get_labels_for_equal_pruning(pruning_info, pruning_ratio=0.5)
Returns labels of the least important channels of every prunable layer.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
pruning_ratio (float) – Fraction of the least important channels to select for pruning in each prunable layer.
- Returns
List of least important channel labels.
- Return type
list of int
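The idea of equal pruning can be sketched without the library: in every gate, sort channels by importance and take the lowest pruning_ratio share. The helper below is a hypothetical illustration, not the implementation of get_labels_for_equal_pruning; rounding down with int() is an assumption borrowed from the user-defined strategy example on this page.

```python
from typing import Iterable, List, Sequence, Tuple

# One gate is described by (channel labels, channel criteria); both have equal length.
Gate = Tuple[Sequence[int], Sequence[float]]


def least_important_labels(gates: Iterable[Gate], pruning_ratio: float = 0.5) -> List[int]:
    """Pick the lowest-criteria share of channels from every gate independently."""
    selected: List[int] = []
    for labels, criteria in gates:
        n_prune = int(len(criteria) * pruning_ratio)  # ASSUMPTION: round down
        # Channel positions ordered from least to most important within this gate.
        order = sorted(range(len(criteria)), key=lambda i: criteria[i])
        selected.extend(labels[i] for i in order[:n_prune])
    return selected
```

Because importance is local to a gate, the selection is made per gate rather than over one global ranking.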
- iterate_over_gate_criteria(pruning_info)
Iterates over pruning gates and yields channel labels and channel criteria.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Yields
group_id (int) – Global index number of a gate.
labels (list of int) – Channel labels of a gate.
criteria (list of float) – Channel criteria of a gate.
- Returns
Generator with items specified above.
- Return type
Generator
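A consumer of such a generator simply unpacks three values per gate. The sketch below uses a plain list of (group_id, labels, criteria) tuples as a stand-in for the real generator; weakest_channel_per_gate is a hypothetical helper written for illustration.

```python
from typing import Dict, Iterable, Sequence, Tuple

GateItem = Tuple[int, Sequence[int], Sequence[float]]  # (group_id, labels, criteria)


def weakest_channel_per_gate(gate_items: Iterable[GateItem]) -> Dict[int, Tuple[int, float]]:
    """Map each gate index to the (label, criterion) of its least important channel."""
    weakest: Dict[int, Tuple[int, float]] = {}
    for group_id, labels, criteria in gate_items:
        if len(labels) == 0:
            continue  # nothing to report for an empty gate
        idx = min(range(len(criteria)), key=lambda i: criteria[i])
        weakest[group_id] = (labels[idx], criteria[idx])
    return weakest


# Stand-in data shaped like the yielded items described above.
stand_in_items = [(0, [10, 11, 12], [0.4, 0.05, 0.7]), (1, [13, 14], [0.3, 0.6])]
report = weakest_channel_per_gate(stand_in_items)
```

In real use, the first argument would presumably be iterate_over_gate_criteria(pruning_info).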
- class EnotPruningCalibrator(model, entry_point='forward')
Context manager for pruning information gathering.
This context manager walks through the model and collects information about prunable/unprunable channels and channel importance. The main purpose of this calibrator is to create a ModelPruningInfo object which can later be used to prune the model.
Model pruning info can be accessed via EnotPruningCalibrator.pruning_info.
Examples
>>> p_calibrator = EnotPruningCalibrator(
...     model=baseline_model,
... )
>>> with p_calibrator:
...     for inputs, labels in dataloader:
...         model_output = baseline_model(inputs)
...         loss = loss_function(model_output, labels)
...         loss.backward()
>>> p_calibrator.pruning_info
- Parameters
model (Module) –
entry_point (str) –
- __init__(model, entry_point='forward')
- Parameters
model (torch.nn.Module) – Model which user wants to calibrate for pruning.
entry_point (str) – Model function for execution. See the Notes section for the detailed argument description. Default value is “forward”.
Notes
entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If such a function is located in a submodule of the model, write the submodule’s name followed by the function name, separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.
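To make the dotted entry_point notation concrete, here is a minimal sketch of how such a string can be resolved to a bound method with a chain of getattr calls. It only illustrates the naming scheme; how the calibrator actually locates and traces the function is not documented here, and resolve_entry_point, DemoHead and DemoModel are hypothetical names.

```python
from functools import reduce


def resolve_entry_point(model, entry_point: str = "forward"):
    """Follow a dotted path such as 'head.predict_features' starting from `model`."""
    return reduce(getattr, entry_point.split("."), model)


# Plain Python objects stand in for torch.nn.Module instances in this demo.
class DemoHead:
    def predict_features(self, x):
        return x * 2


class DemoModel:
    def __init__(self):
        self.head = DemoHead()

    def forward(self, x):
        return x + 1


demo_model = DemoModel()
top_level = resolve_entry_point(demo_model)                         # demo_model.forward
nested = resolve_entry_point(demo_model, "head.predict_features")  # demo_model.head.predict_features
```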
- calibrate_and_prune_model(label_selection_fn, model, dataloader, loss_function, finetune_bn=False, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, verbose=0, **kwargs)
Estimates channel importances and prunes the model with a user-defined strategy.
This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importances for later pruning. It then prunes the model by removing the channels specified by the user-defined strategy label_selection_fn.
- Parameters
label_selection_fn (Callable[[torch.nn.Module, ModelPruningInfo], List[int]]) – Channel selection strategy for pruning. This function should typically return labels of the least important channels.
model (torch.nn.Module) – Model to calibrate pruning gates for.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model’s channel importances.
loss_function (Callable) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
finetune_bn (bool) – Whether to fine-tune batch normalization running mean and running variance for better model quality.
n_steps (int or None, optional) – Total number of threshold calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs given in the epochs argument.
epochs (int, optional) – Total number of threshold calibration epochs. Not used when the n_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more in Converting dataloader items to PyTorch model inputs.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more in Converting dataloader items to PyTorch model inputs.
verbose (int, optional) – Procedure verbosity level. 0 disables all messages, 1 enables tqdm progress bar logging, 2 gives additional information about calibration. Default value is 0.
- Returns
Pruned model.
- Return type
Notes
Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).
Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during the backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model(...)
label_selection_fn is a function that defines the model pruning strategy. This function should return a list of integer values: the channel labels to remove from the model. Input arguments of label_selection_fn are:
- model : torch.nn.Module
Model to calibrate pruning gates for.
- pruning_info : ModelPruningInfo
Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- kwargs : Dict[str, Any]
Keyword arguments for label_selection_fn.
An example of a user-defined strategy:
from typing import List

import numpy as np
import torch

from enot.pruning import ModelPruningInfo, iterate_over_gate_criteria


def label_selection_fn(model: torch.nn.Module, pruning_info: ModelPruningInfo, **kwargs) -> List[int]:
    # Fraction of channels to remove from every gate, passed through **kwargs.
    pruning_ratio = kwargs.pop('pruning_ratio')
    labels_to_prune: List[int] = []
    for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
        criteria = np.array(criteria)
        prune_channels_num = int(len(criteria) * pruning_ratio)
        # Indices of the channels with the lowest importance within this gate.
        index_for_pruning = np.argsort(criteria)[:prune_channels_num]
        labels_to_prune += np.array(labels)[index_for_pruning].tolist()
    return labels_to_prune
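A common variation on the strategy above is to keep a minimum number of channels in every gate so that no layer collapses. The sketch below is one possible way to write it, not a library feature: make_label_selection_fn and min_channels_kept are hypothetical names, and the gate iterator is passed in as iterate_fn so the snippet runs without the library (in real use this would presumably be enot.pruning.iterate_over_gate_criteria).

```python
from typing import Callable, Iterable, List, Sequence, Tuple

GateItem = Tuple[int, Sequence[int], Sequence[float]]  # (group_id, labels, criteria)


def make_label_selection_fn(
    iterate_fn: Callable[[object], Iterable[GateItem]],
    pruning_ratio: float = 0.5,
    min_channels_kept: int = 8,
):
    """Build a label_selection_fn that never leaves a gate below min_channels_kept channels."""

    def label_selection_fn(model, pruning_info, **kwargs) -> List[int]:
        labels_to_prune: List[int] = []
        for _, labels, criteria in iterate_fn(pruning_info):
            n_channels = len(criteria)
            n_prune = int(n_channels * pruning_ratio)
            # Clamp so that at least min_channels_kept channels survive in this gate.
            n_prune = max(0, min(n_prune, n_channels - min_channels_kept))
            order = sorted(range(n_channels), key=lambda i: criteria[i])
            labels_to_prune.extend(labels[i] for i in order[:n_prune])
        return labels_to_prune

    return label_selection_fn
```

The returned function has the (model, pruning_info, **kwargs) signature described above, so it can be passed as label_selection_fn.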
- prune_model(model, pruning_info, prune_labels, inplace=True)
Removes (prunes) the least important channels defined by the prune_labels parameter.
- Parameters
model (torch.nn.Module) – Model for pruning.
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
prune_labels (Sequence of int) – Sequence of global channel indices to prune.
inplace (bool) – Whether to perform pruning in-place (modify the model itself without creating a copy).
- Returns
Pruned model.
- Return type
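The functions on this page compose into a three-step flow: calibrate, select labels, prune. The sketch below chains them using only the signatures documented here. To keep it runnable without the library, the functions are looked up on a pruning_api argument (in real use, presumably the enot.pruning module), and the dry run uses stand-in lambdas that are not the real implementations.

```python
from types import SimpleNamespace


def calibrate_then_prune(pruning_api, model, dataloader, loss_function, pruning_ratio=0.5):
    """Calibrate, select labels with the equal strategy, then prune.

    `pruning_api` must expose the three functions documented on this page.
    """
    pruning_info = pruning_api.calibrate_model_for_pruning(model, dataloader, loss_function)
    prune_labels = pruning_api.get_labels_for_equal_pruning(pruning_info, pruning_ratio=pruning_ratio)
    # inplace=False asks for pruning without modifying the passed model itself.
    return pruning_api.prune_model(model, pruning_info, prune_labels, inplace=False)


# Dry run with STAND-IN functions (not the real library) to show the data flow.
calls = []
stub_api = SimpleNamespace(
    calibrate_model_for_pruning=lambda model, dataloader, loss_function: calls.append("calibrate") or "info",
    get_labels_for_equal_pruning=lambda pruning_info, pruning_ratio=0.5: calls.append("select") or [1, 3],
    prune_model=lambda model, pruning_info, prune_labels, inplace=True: calls.append("prune")
    or (model, prune_labels, inplace),
)
result = calibrate_then_prune(stub_api, "model", [], None, pruning_ratio=0.25)
```

Swapping get_labels_for_equal_pruning for a custom selection function gives the same flow that calibrate_and_prune_model is described as performing.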
- calibrate_model_for_pruning(model, dataloader, loss_function, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, verbose=0)
Estimates model channel importances for later pruning.
This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importances for later pruning.
- Parameters
model (torch.nn.Module) – Model to calibrate pruning gates for.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model’s channel importances.
loss_function (Callable) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
n_steps (int or None, optional) – Total number of threshold calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs given in the epochs argument.
epochs (int, optional) – Total number of threshold calibration epochs. Not used when the n_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more in Converting dataloader items to PyTorch model inputs.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more in Converting dataloader items to PyTorch model inputs.
verbose (int, optional) – Procedure verbosity level. 0 disables all messages, 1 enables tqdm progress bar logging, 2 gives additional information about calibration. Default value is 0.
- Returns
Pruning information for later usage in pruning methods.
- Return type
ModelPruningInfo
Notes
Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).
Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during the backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_model_for_pruning(...)
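When the dataloader yields something other than the layout the defaults expect, custom loss_function and sample_to_n_samples callables can be supplied. The sketch below assumes a dict-style batch with hypothetical keys "image" and "label"; it follows only the contracts stated in the parameter list (loss_function takes model output and the dataloader sample, sample_to_n_samples takes the sample and returns an integer). Plain lists stand in for tensors so the snippet runs on its own.

```python
from typing import Callable, Mapping


def dict_sample_to_n_samples(sample: Mapping) -> int:
    """Count instances in a dict-style batch; 'image' is a HYPOTHETICAL key name."""
    return len(sample["image"])


def make_loss_function(criterion: Callable, label_key: str = "label") -> Callable:
    """Adapt criterion(output, target) to the two-argument (model_output, sample) form."""

    def loss_function(model_output, sample):
        return criterion(model_output, sample[label_key])

    return loss_function


# Stand-in batch and criterion; with PyTorch, criterion would typically be a loss module
# and the dictionary values would be tensors.
batch = {"image": [[0.0], [0.1], [0.2], [0.3]], "label": [1, 0, 1, 1]}
absolute_error = lambda output, target: sum(abs(o - t) for o, t in zip(output, target))
loss_fn = make_loss_function(absolute_error)
```

A matching sample_to_model_inputs function would also be needed for such a batch; its expected return format is described in Converting dataloader items to PyTorch model inputs.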