Pruning
The enot.pruning module provides functionality for automatic pruning of user models.
The pruning engine works with many models and model libraries such as Torchvision, Pytorch Image Models (TIMM), OpenMMLab, and others.
This package offers pre-defined and custom pruning procedures for removing the least important filters or neurons. The ENOT pruning module currently supports structured pruning. You can set the pruning ratio (the percentage of channels removed) manually or use one of the pre-defined strategies.
The first, and the simplest, is the equal pruning strategy, which keeps roughly the same percentage of channels in each prunable layer.
The second, which is quite compute-intensive, is the optimal pruning strategy, which searches for the optimal pruning ratio for each prunable layer.
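As a rough illustration of the equal strategy, the sketch below applies one pruning ratio to every layer of a hypothetical convolutional chain and compares a simple multiply-accumulate estimate before and after. The layer widths, the `conv_macs` cost model, and the rounding rule are assumptions made for this example and are not part of ENOT.

```python
from typing import List


def prune_widths_equally(widths: List[int], pruning_ratio: float) -> List[int]:
    """Return the channel count kept in each layer when the same ratio is pruned everywhere."""
    # Rounding to the nearest integer is an assumption of this sketch.
    return [width - round(width * pruning_ratio) for width in widths]


def conv_macs(in_channels: int, widths: List[int], kernel: int = 3, spatial: int = 32 * 32) -> int:
    """Multiply-accumulate estimate for a chain of convolutions with a fixed spatial size."""
    total = 0
    for out_channels in widths:
        total += in_channels * out_channels * kernel * kernel * spatial
        in_channels = out_channels
    return total


widths = [64, 128, 256]  # hypothetical layer widths
pruned = prune_widths_equally(widths, pruning_ratio=0.5)
print(pruned)  # [32, 64, 128]

# Halving both the input and output channels of a layer divides its cost by about 4.
# The first layer keeps its 3 input channels, so it only shrinks by a factor of 2.
ratio = conv_macs(3, widths) / conv_macs(3, pruned)
print(round(ratio, 2))
```

The overall reduction comes out slightly below 4 because the network input is not pruned.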
A few definitions that make the documentation easier to read:
Pruning ratio is the fraction of channels to remove (prune) in a prunable group or prunable layer.
Gate (or channel gate) is a special torch.nn.Module which gathers and saves channel “local importance”. “Local importance” means that you can compare channels by their importance values within a single layer, but not between distinct layers.
Calibration procedure estimates channel importance values according to the user model and data.
Channel label is the global channel index in a network.
Pruning config specifies pruning amount for each prunable layer. See more here.
Pruning API
The pruning functionality available to users is divided into the following sections:
High-level interface: the easiest way to use the pruning functionality.
Low-level interface: low-level API for pruning, which takes moderate time to understand and use.
High-level interface
The high-level interface provides two ways to prune a model. They work in the same manner, but apply different policies for the amount of pruning in each layer:
- calibrate_and_prune_model_equal(model, dataloader, loss_function, pruning_ratio=0.5, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')
Estimates channel importance values and prunes the model with the equal pruning strategy (the same percentage of channels is pruned in each prunable layer).
- Parameters:
model (torch.nn.Module) – Model to prune.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
pruning_ratio (float, optional) – Relative amount of channels to prune in each prunable layer. Default value is 0.5 (prunes 50% of channels, which typically reduces the number of FLOPs and parameters by a factor of 4).
finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.
calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified in the calibration_epochs argument.
calibration_epochs (int, optional) – Number of total calibration epochs. Not used when the calibration_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (the dataloader sample) and return a single integer, the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.
entry_point (str, optional) – Model function for execution. See notes below for the detailed argument description. Default value is “forward”.
- Returns:
Pruned model.
- Return type:
torch.nn.Module
Note
Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
For example, it is your responsibility to call the eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).
Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_equal(...)
entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: forward, execute_model, forward_train. If the function is located in a submodule of the model, write the submodule’s name followed by the function name, separated by dots: submodule.forward, head.predict_features, submodule1.submodule2.forward.
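When a dataloader yields something other than a plain batch, a custom sample_to_n_samples function may be needed. The sketch below assumes a hypothetical dataloader that yields dictionaries with an "image" entry holding the batch; the dictionary layout and the function name are assumptions for illustration. It follows the contract stated in the parameter description: one input (the dataloader sample), one integer output.

```python
from typing import Any, Dict, Sequence


def dict_sample_to_n_samples(sample: Dict[str, Sequence[Any]]) -> int:
    """Return the number of instances in a dict-style dataloader sample."""
    # len() of a batched tensor or list gives the size of its first (batch) dimension.
    return len(sample["image"])


# A stand-in batch with 4 instances; real samples would hold tensors.
batch = {"image": [[0.0], [0.1], [0.2], [0.3]], "label": [0, 1, 1, 0]}
print(dict_sample_to_n_samples(batch))  # 4
```

Such a function would be passed as sample_to_n_samples=dict_sample_to_n_samples.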
- calibrate_and_prune_model_optimal(model, dataloader, loss_function, latency_calculation_function, target_latency, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, channels_selection_constraint=None, n_search_steps=200, entry_point='forward', **kwargs)
Estimates channel importance values, searches for the optimal pruning configuration (percentage of channels to prune in each prunable layer) and prunes model.
- Parameters:
model (torch.nn.Module) – Model to prune.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be a number of FLOPs, MACs, inference time on CPU/GPU or other “speed” criteria.
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.
calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified in the calibration_epochs argument.
calibration_epochs (int, optional) – Number of total calibration epochs. Not used when the calibration_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (the dataloader sample) and return a single integer, the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.
channels_selection_constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that calculates low, high and step for the constraint on the total number of channels in a gate. If None, then DefaultChannelsSelectorConstraint will be used. None by default.
n_search_steps (int, optional) – Number of sampled configurations for pruning to select optimal architecture. Default value is 200.
entry_point (str, optional) – Model function for execution. See notes in calibrate_and_prune_model_equal() for the detailed argument description. Default value is “forward”.
**kwargs – Additional keyword arguments for label selector.
- Returns:
Pruned model.
- Return type:
torch.nn.Module
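A latency_calculation_function only needs to accept the model and return a float. The sketch below is one possible wall-clock version; the helper name make_latency_function, the warm-up call, and the number of runs are assumptions of this example. A FLOPs or MACs counter would satisfy the same contract.

```python
import time
from typing import Any, Callable


def make_latency_function(example_input: Any, n_runs: int = 10) -> Callable[[Callable], float]:
    """Build a function that measures the mean wall-clock time of one model call."""

    def latency_calculation_function(model: Callable) -> float:
        model(example_input)  # warm-up call, excluded from timing
        start = time.perf_counter()
        for _ in range(n_runs):
            model(example_input)
        return (time.perf_counter() - start) / n_runs

    return latency_calculation_function


# Any callable works here; a torch.nn.Module is called the same way.
measure = make_latency_function(example_input=[1.0, 2.0, 3.0])
latency_seconds = measure(lambda xs: [2 * x for x in xs])
print(latency_seconds >= 0.0)  # True
```

With a function like this, target_latency would be given in seconds, since it must use the same units as the function's output.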
Low-level interface
The low-level interface consists of the following parts:
The main class for pruning is ModelPruningInfo. It contains the information necessary for model pruning: the model execution graph, the list of all prunable and non-prunable channels, all channel importance values, etc.
Model pruning info is obtained through the calibration process. Calibration for pruning is performed by the PruningCalibrator class. An instance of this class should be used as a context manager. Inside its context, the user should compute losses and perform backward passes through the model on calibration data. Alternatively, the user can calibrate the model by calling the calibrate_model_for_pruning() function and providing all arguments necessary for proper model execution and loss computation.
Actual model pruning is performed by the prune_model() function, which requires a pruning info object and a list of channel labels to prune.
To select which labels should be pruned and which should not, there is a special interface for label selection. The base class for all label selectors is PruningLabelSelector. This class defines an abstract function PruningLabelSelector.select(), which should select labels to prune and which subclasses should implement to fit the pipelines based on calibrate_and_prune_model().
The abstract class ScorePruningLabelSelector is a straightforward way to select the least important channel labels. It uses the pruning config to select the least important channels in each prunable layer. This is done by sorting channels by their importance values and selecting the top-K least important channels based on the corresponding pruning ratio from the pruning config.
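The score-based selection rule can be sketched in plain Python. Labels are modelled as integers and the pruning config as one ratio per gate; both are simplifications assumed for this example, and the function name is hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def select_least_important(
    labels_by_gates: Sequence[Tuple[int, ...]],
    labels_score: Dict[int, float],
    pruning_ratios: Sequence[float],
) -> List[int]:
    """Pick the K lowest-scoring labels in each gate, with K derived from that gate's ratio."""
    selected: List[int] = []
    for labels, ratio in zip(labels_by_gates, pruning_ratios):
        # Scores are only comparable within one gate, so sort each gate separately.
        ranked = sorted(labels, key=lambda label: labels_score[label])
        n_prune = int(len(ranked) * ratio)  # truncation is an assumption of this sketch
        selected.extend(ranked[:n_prune])
    return selected


labels_by_gates = [(0, 1, 2, 3), (4, 5)]
labels_score = {0: 0.9, 1: 0.1, 2: 0.5, 3: 0.3, 4: 7.0, 5: 2.0}
print(select_least_important(labels_by_gates, labels_score, [0.5, 0.5]))  # [1, 3, 5]
```

Label 5 is selected even though its score (2.0) is higher than every score in the first gate, which reflects that importance values are local to a layer.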
The high-level interface described above uses two label selectors internally: UniformPruningLabelSelector for the equal strategy and OptimalPruningLabelSelector for the optimal strategy.
Pruning info
- class ModelPruningInfo(graph, prunable_groups, unprunable_groups, channel_labels_by_nodes, node_users, gate_container, gate_names_by_locations, prunable_group_gate_names, gate_apply_info, child_labels_by_parents, int_pruners)
Container with model pruning-related information.
Calibration
- class PruningCalibrator(model, entry_point='forward', max_prunable_labels=4096)
Context manager for gathering pruning information.
This context manager walks through the model, collects information about prunable and unprunable channels, and estimates channel importance values. The main purpose of this calibrator is to create a ModelPruningInfo object which can later be used to prune the model.
Model pruning info can be accessed via PruningCalibrator.pruning_info.
Examples
>>> p_calibrator = PruningCalibrator(model=baseline_model)
>>> with p_calibrator:
...     for inputs, labels in dataloader:
...         model_output = baseline_model(inputs)
...         loss = loss_function(model_output, labels)
...         loss.backward()
>>> p_calibrator.pruning_info
See also
calibrate_and_prune_model, calibrate_and_prune_model_equal, calibrate_and_prune_model_optimal
- __init__(model, entry_point='forward', max_prunable_labels=4096)
- Parameters:
model (torch.nn.Module) – Model which user wants to calibrate for pruning.
entry_point (str, optional) – Model function for execution. See note below for the detailed argument description. Default value is “forward”.
max_prunable_labels (int, optional) – Maximum number of labels in group which can be pruned. If there are more labels in group, then this group will be marked as non-prunable. Default value is 4096.
- property pruning_info: ModelPruningInfo | None
Information about model pruning.
- Returns:
Pruning-related information about model. None if calibrator instance was not used as a context manager or model was calibrated incorrectly.
- Return type:
ModelPruningInfo or None
Note
entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: forward, execute_model, forward_train. If the function is located in a submodule of the model, write the submodule’s name followed by the function name, separated by dots: submodule.forward, head.predict_features, submodule1.submodule2.forward.
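The dotted entry_point format can be read as a chain of attribute lookups that starts at the model. The sketch below shows that reading with plain Python objects; resolve_entry_point and the toy classes are hypothetical, and this is not necessarily how ENOT resolves the string internally.

```python
from functools import reduce
from typing import Any, Callable


def resolve_entry_point(model: Any, entry_point: str = "forward") -> Callable:
    """Follow a dotted path such as 'head.predict_features' from the model to a method."""
    return reduce(getattr, entry_point.split("."), model)


class Head:
    def predict_features(self, x: int) -> int:
        return x + 1


class Model:
    def __init__(self) -> None:
        self.head = Head()

    def forward(self, x: int) -> int:
        return self.head.predict_features(x) * 2


model = Model()
print(resolve_entry_point(model, "forward")(1))  # 4
print(resolve_entry_point(model, "head.predict_features")(1))  # 2
```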
- calibrate_model_for_pruning(model, dataloader, loss_function, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward', max_prunable_labels=4096)
Estimates model channel importance values for later pruning.
This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning.
- Parameters:
model (torch.nn.Module) – Model to calibrate pruning gates for.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
n_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs specified in the epochs argument.
epochs (int, optional) – Number of total calibration epochs. Not used when the n_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (the dataloader sample) and return a single integer, the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.
entry_point (str, optional) – Name of model method for tracing. In this method all prunable operations should be forwarded. Default value is ‘forward’.
max_prunable_labels (int, optional) – Maximum number of labels in group which can be pruned. If there are more labels in group, then this group will be marked as non-prunable. Default value is 4096.
- Returns:
Pruning information for later usage in pruning methods.
- Return type:
ModelPruningInfo
Model pruning
- prune_model(model, pruning_info, prune_labels, inplace=True)
Removes (prunes) the least important channels defined by the prune_labels parameter.
- Parameters:
model (torch.nn.Module) – Model for pruning.
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
prune_labels (Sequence of Label) – Sequence of global channel indices (labels) to prune.
inplace (bool) – Whether to perform pruning in-place (modify the model itself without creating a copy).
- Returns:
Pruned model.
- Return type:
torch.nn.Module
The following function combines the functionality of the enot.pruning.calibrate_model_for_pruning() and enot.pruning.prune_model() functions:
- calibrate_and_prune_model(label_selector, model, dataloader, loss_function, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')
Estimates channel importance values and prunes model with user defined strategy.
This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning. After that, it prunes the model by removing the channels chosen by the user-defined channel selection strategy label_selector.
- Parameters:
label_selector (PruningLabelSelector) – Channel selector object. This object should implement the PruningLabelSelector.select() method, which returns a list of channel indices to prune.
model (torch.nn.Module) – Model to prune.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.
calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs specified in the calibration_epochs argument.
calibration_epochs (int, optional) – Number of total calibration epochs. Not used when the calibration_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (the dataloader sample) and return a single integer, the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.
entry_point (str, optional) – Model function for execution. See the Notes section for the detailed argument description. Default value is “forward”.
- Returns:
Pruned model.
- Return type:
torch.nn.Module
Note
Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
For example, it is your responsibility to call the eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).
Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model(...)
entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: forward, execute_model, forward_train. If the function is located in a submodule of the model, write the submodule’s name followed by the function name, separated by dots: submodule.forward, head.predict_features, submodule1.submodule2.forward.
Label selector
To select which labels should be pruned, use one of the following selectors:
UniformPruningLabelSelector – removes an equal percentage of channels in each prunable layer.
GlobalPruningLabelSelectorByChannels – selects the least important channels within the network.
LatencyPruningLabelSelector – finds the model (labels) with latency as close as possible to the target latency parameter.
OptimalPruningLabelSelector – finds the optimal pruning ratios for each prunable layer.
There are also two useful classes to help you implement your own label selector:
PruningLabelSelector – base class for all label selectors.
ScorePruningLabelSelector – base class for selecting least important channel labels.
- class UniformPruningLabelSelector(pruning_ratio)
Bases:
ScorePruningLabelSelector
Label selector for uniform (equal) pruning.
Removes an equal percentage of channels in each prunable layer.
- class GlobalPruningLabelSelectorByChannels(n_channels_or_ratio)
Bases:
ScorePruningLabelSelector
Label selector for global pruning.
Selects the least important channels within network.
- __init__(n_channels_or_ratio)
- Parameters:
n_channels_or_ratio (Union[int, float]) – Number of channels or ratio of channels to remove across the whole network. If the parameter is in the range (0, 1), it is interpreted as a fraction of all unique labels. If the parameter is greater than or equal to 1, it is interpreted as the number of labels to remove.
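The two interpretations of n_channels_or_ratio can be sketched as follows. The helper name resolve_n_labels is hypothetical, and truncating the fractional case to an integer is an assumption of this example; the library may round differently.

```python
from typing import Union


def resolve_n_labels(n_channels_or_ratio: Union[int, float], n_unique_labels: int) -> int:
    """Turn the constructor argument into a number of labels to remove."""
    if n_channels_or_ratio <= 0:
        raise ValueError("expected a value in (0, 1) or a count >= 1")
    if n_channels_or_ratio < 1:
        # Fraction of all unique labels; truncation is an assumption of this sketch.
        return int(n_channels_or_ratio * n_unique_labels)
    # A value of 1 or more is an absolute number of labels.
    return int(n_channels_or_ratio)


print(resolve_n_labels(0.25, n_unique_labels=1000))  # 250
print(resolve_n_labels(300, n_unique_labels=1000))   # 300
```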
- class LatencyPruningLabelSelector(target_latency, latency_calculation_function, selector_cb=None)
Bases:
PruningLabelSelector
Latency-based label selector.
It finds the model with latency as close as possible to the target latency parameter, and always selects the labels with the lowest scores.
- __init__(target_latency, latency_calculation_function, selector_cb=None)
- Parameters:
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be a number of FLOPs, MACs, inference time on CPU/GPU or other “speed” criteria.
selector_cb (Optional[Callable[[float, float], bool]]) – An optional callback that is called at each iteration of the search process. The callback should take the current latency as the first parameter and the target latency as the second parameter and return True if the search procedure should be stopped, False otherwise. Can be used for logging and early stop.
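A selector_cb callback matching the described signature might look like the sketch below. The function name and the 1% stopping tolerance are arbitrary choices made for illustration.

```python
def log_and_stop_early(current_latency: float, target_latency: float) -> bool:
    """Log progress and stop once the latency is within 1% of the target."""
    print(f"current latency: {current_latency:.3f}, target: {target_latency:.3f}")
    tolerance = 0.01 * target_latency  # the 1% tolerance is an arbitrary choice
    return abs(current_latency - target_latency) <= tolerance


print(log_and_stop_early(120.0, 100.0))  # False
print(log_and_stop_early(100.5, 100.0))  # True
```

It would be passed as selector_cb=log_and_stop_early when constructing the selector.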
- class OptimalPruningLabelSelector(target_latency, latency_calculation_function, n_search_steps, latency_penalty=300, channels_constraint=None, n_jobs=1, show_progress_bar=False, verbose=True, seed=None)
- __init__(target_latency, latency_calculation_function, n_search_steps, latency_penalty=300, channels_constraint=None, n_jobs=1, show_progress_bar=False, verbose=True, seed=None)
- Parameters:
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be a number of FLOPs, MACs, inference time on CPU/GPU or other “speed” criteria.
n_search_steps (int) – Number of configurations (samples) for pruning to find optimal architecture.
latency_penalty (float) – Weight of the latency in objective function. Default value is 300.
channels_constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that calculates low, high and step for the constraint on the total number of channels in a gate. If None, then DefaultChannelsSelectorConstraint will be used. None by default.
n_jobs (int) – Number of parallel jobs. Default value is 1.
show_progress_bar (bool) – Show progress bar or not. False by default.
verbose (bool) – If True then default logger callback will duplicate messages in stdout. Default value is True.
seed (Optional[int]) – Optional seed for sampler. Default value is None.
- class PruningLabelSelector
Base class for all pruning label selectors.
This class defines an abstract method PruningLabelSelector.select() that should return labels to prune.
- static extract_labels_info(pruning_info, sort_labels=True, drop_shared_gates=True)
Extract labels information from pruning_info.
- Parameters:
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
sort_labels (bool) – Whether to sort labels in ascending order of the pruning criterion (score) within each group (sort_labels=True) or leave them unsorted (sort_labels=False). True by default.
drop_shared_gates (bool) – Whether to drop the groups of labels for repeated (shared) gates (drop_shared_gates=True) or add them to the output as is (drop_shared_gates=False). True by default.
- Returns:
First position of tuple contains labels grouped by gates. Second position of tuple contains mapping of labels and pruning criteria value (score).
- Return type:
Tuple[List[Tuple[Label, …]], Dict[Label, float]]
- static get_most_important_labels(sorted_labels_by_gates)
Return set of the most important labels. The most important labels are labels with the highest score from every gate.
- Parameters:
sorted_labels_by_gates (List[Tuple[Label, ...]]) – Labels grouped by gates and sorted by score (ascending order) within every gate group.
- Returns:
Set of the most important labels.
- Return type:
Set[Label]
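The description of get_most_important_labels can be illustrated with a small stand-alone function. Integers stand in for Label objects here, and the function is a sketch of the described behavior, not the library's implementation.

```python
from typing import List, Set, Tuple


def most_important_labels(sorted_labels_by_gates: List[Tuple[int, ...]]) -> Set[int]:
    """Collect the highest-scoring label of every gate."""
    # Each group is sorted by score in ascending order, so the last label scores highest.
    return {labels[-1] for labels in sorted_labels_by_gates if labels}


print(most_important_labels([(3, 1, 2), (5, 4)]) == {2, 4})  # True
```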
- abstract select(model, pruning_info)
Method that chooses which labels should be pruned based on current label selector policy.
Warning
Depending on label selector implementation, this function may have significant execution time.
- Parameters:
model (torch.nn.Module) – Model for pruning.
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Returns:
List of channel labels which should be pruned.
- Return type:
list of Label
- class ScorePruningLabelSelector
Bases:
PruningLabelSelector
Base class for pruning label selectors, which select labels based on PruningLabelSelector.extract_labels_info() output.
This class defines an abstract method _select() that takes PruningLabelSelector.extract_labels_info() output and should return list of labels to prune.
- abstract _select(labels_by_gates, labels_score)
select() implementation.
- Parameters:
labels_by_gates (List[Tuple[Label, ...]]) – Labels grouped by gates.
labels_score (Dict[Label, float]) – Mapping of labels and pruning criteria value (score).
- Returns:
List of labels to prune.
- Return type:
List[Label]
- select(model, pruning_info)
Method that chooses which labels should be pruned based on current label selector policy.
Warning
Depending on label selector implementation, this function may have significant execution time.
- Parameters:
model (torch.nn.Module) – Model for pruning.
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Returns:
List of channel labels which should be pruned.
- Return type:
list of Label
Here is an example of a simple label selector that only selects even labels:
import operator
from functools import reduce
from typing import List

import torch

from enot.pruning import ModelPruningInfo
from enot.pruning import PruningLabelSelector


class EvenPruningLabelSelector(PruningLabelSelector):
    def select(self, model: torch.nn.Module, pruning_info: ModelPruningInfo) -> List[int]:
        labels_by_gates, labels_by_score = self.extract_labels_info(pruning_info)
        # we don't need model and labels_by_score for our label selector
        del model, labels_by_score
        is_even = lambda label: label.label % 2 == 0
        # flatten the per-gate tuples into a single tuple of labels
        labels = reduce(operator.concat, labels_by_gates)
        even_labels = [*filter(is_even, labels)]
        return even_labels
Channels Selector Constraint
- class ChannelsSelectorConstraint
Base class for all channel selection constraints.
To implement a new channel selection constraint, write an implementation of _get().
- abstract _get(number_of_channels)
Implementation of get(). It should return low, high and step for the constraint by total number of channels in gate.
low - minimal number of channels that should be pruned
high - maximal number of channels allowed to prune
step - step of discretization
- apply(prune_n, number_of_channels)
Apply channels selector constraint.
- get(number_of_channels)
Calculate low, high and step for the constraint by total number of channels in gate.
low - minimal number of channels that should be pruned
high - maximal number of channels allowed to prune
step - a step of discretization
- class DefaultChannelsSelectorConstraint(absolute_min_channels=16, relative_min_channels=0.25, fixed_step=None)
Bases:
ChannelsSelectorConstraint
Default channels selector constraint.
This class allows setting two lower bounds on the number of channels kept in each layer: absolute and relative. The absolute bound is a minimal number of channels, and the relative bound is a fraction of the number of channels in the layer. The larger of the two is used.
You can also specify a fixed value for the discretization step. If no fixed step is specified, the following rule applies:
get() returns a step equal to 1 if the channel group has size less than or equal to 64
get() returns a step equal to 4 if the channel group has size greater than 64
- __init__(absolute_min_channels=16, relative_min_channels=0.25, fixed_step=None)
- Parameters:
absolute_min_channels (int) – Absolute lower bound. If it is larger than number_of_channels, then number_of_channels is returned. Default value is 16.
relative_min_channels (float) – Relative lower bound. Default value is 0.25.
fixed_step (Optional[int]) – Optional fixed value for the discretization step. Default value is None.
Here is the implementation of the _get() method of DefaultChannelsSelectorConstraint:
def _get(self, number_of_channels: int) -> Tuple[int, int, int]:
    high = max(self._absolute_min_channels, int(self._relative_min_channels * number_of_channels))
    high = number_of_channels - high
    high = max(0, high)
    if self._fixed_step is None:
        step = 4 if number_of_channels > 64 else 1
    else:
        step = self._fixed_step
    return 0, high, step
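To see what this method returns for concrete layer sizes, the same logic can be copied into a stand-alone function. The function name default_constraint is hypothetical; the arithmetic mirrors the _get() method shown above with the documented default arguments.

```python
from typing import Optional, Tuple


def default_constraint(
    number_of_channels: int,
    absolute_min_channels: int = 16,
    relative_min_channels: float = 0.25,
    fixed_step: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Stand-alone copy of the _get() logic: returns (low, high, step)."""
    # Number of channels that must be kept: the larger of the two lower bounds.
    keep = max(absolute_min_channels, int(relative_min_channels * number_of_channels))
    high = max(0, number_of_channels - keep)
    if fixed_step is None:
        step = 4 if number_of_channels > 64 else 1
    else:
        step = fixed_step
    return 0, high, step


print(default_constraint(128))  # (0, 96, 4)
print(default_constraint(32))   # (0, 16, 1)
print(default_constraint(8))    # (0, 0, 1)
```

A layer with 8 channels cannot be pruned at all under the defaults, because the absolute lower bound of 16 already exceeds its size.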