Pruning
The enot.pruning module contains functionality for automatic pruning of user models.
The pruning engine works with models from many frameworks, including Torchvision, PyTorch Image Models (TIMM), OpenMMLab, and others.
This package provides pre-defined and custom pruning procedures for removing the least important filters or neurons. The ENOT pruning module currently supports structured pruning. You can set the pruning ratio (the percentage of channels to remove) manually or use one of the pre-defined strategies.
The first (and the simplest) is the equal pruning strategy, which keeps roughly the same percentage of channels in each prunable layer.
The second (and more compute-intensive) is the optimal pruning strategy, which searches for the optimal pruning ratio for each prunable layer.
A few definitions to simplify reading this documentation:
Pruning ratio is the fraction of channels to remove (prune) in a prunable group or prunable layer.
Gate (or channel gate) is a special torch.nn.Module which gathers and saves channel “local importance”. “Local importance” means that you can compare channels by their importance values within a single layer, but not between distinct layers.
Calibration procedure estimates channel importance values according to the user model and data.
Channel label is the global channel index in a network.
Pruning config specifies the pruning amount for each prunable layer. See more here.
Pruning API
Pruning functionality accessible to the user is divided into two sections:
High-level interface: the easiest way to use the pruning functionality.
Low-level interface: a lower-level pruning API that takes moderate time to understand and use.
High-level interface
High-level interface provides two ways to prune a model. They work in the same manner but have different policies for the amount of pruning applied to each layer:
- calibrate_and_prune_model_equal(model, dataloader, loss_function, pruning_ratio=0.5, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')[source]
Estimates channel importance values and prunes the model with the equal pruning strategy (the same percentage of channels is pruned in each prunable layer).
- Parameters:
model (torch.nn.Module) – Model to prune.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes the model output and a dataloader sample and computes the loss tensor. This function should take two inputs and return a scalar PyTorch tensor.
pruning_ratio (float, optional) – Relative amount of channels to prune in each prunable layer. Default value is 0.5 (prunes 50% of channels, which typically reduces the number of FLOPs and parameters by a factor of 4).
finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.
calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified in the calibration_epochs argument.
calibration_epochs (int, optional) – Number of total calibration epochs. Not used when the calibration_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to the model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log the calibration procedure with a tqdm progress bar. Default value is False.
entry_point (str, optional) – Model function for execution. See the notes below for the detailed argument description. Default value is “forward”.
- Returns:
Pruned model.
- Return type:
torch.nn.Module
Note
Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
For example, it is your responsibility to call the eval method of your model if your inference requires it (e.g. when the model contains dropout layers).
Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during the backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_equal(...)
entry_point is a string that specifies which function of the user’s model will be traced.
Simple examples: forward, execute_model, forward_train.
If such a function is located in the model’s submodule, then you should first write the submodule’s name followed by the function name, all separated by dots:
submodule.forward, head.predict_features, submodule1.submodule2.forward.
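A minimal usage sketch (a hedged illustration: the toy data and the torchvision MobileNetV2 below are stand-ins for your own calibration set and model):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import mobilenet_v2

from enot.pruning import calibrate_and_prune_model_equal

# Toy calibration data: 32 random "images" with 10-class labels.
images = torch.randn(32, 3, 224, 224)
labels = torch.randint(0, 10, (32,))
dataloader = DataLoader(TensorDataset(images, labels), batch_size=8)

model = mobilenet_v2(num_classes=10).eval()  # prepare the model as for inference

def loss_function(model_output, sample):
    # `sample` is a whole dataloader item: (images, labels).
    return torch.nn.functional.cross_entropy(model_output, sample[1])

pruned_model = calibrate_and_prune_model_equal(
    model=model,
    dataloader=dataloader,
    loss_function=loss_function,
    pruning_ratio=0.5,  # remove 50% of channels in each prunable layer
    show_tqdm=True,
)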
- calibrate_and_prune_model_optimal(model, dataloader, loss_function, latency_calculation_function, target_latency, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, channels_selection_constraint=None, n_search_steps=200, entry_point='forward', **kwargs)[source]
Estimates channel importance values, searches for the optimal pruning configuration (the percentage of channels to prune in each prunable layer), and prunes the model.
- Parameters:
model (torch.nn.Module) – Model to prune.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes the model output and a dataloader sample and computes the loss tensor. This function should take two inputs and return a scalar PyTorch tensor.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take a model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be the number of FLOPs, MACs, inference time on CPU/GPU, or another “speed” criterion.
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.
calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified in the calibration_epochs argument.
calibration_epochs (int, optional) – Number of total calibration epochs. Not used when the calibration_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to the model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log the calibration procedure with a tqdm progress bar. Default value is False.
channels_selection_constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that calculates low, high, and step for the channels constraint for the total number of channels in a gate. If None, then DefaultChannelsSelectorConstraint will be used. None by default.
n_search_steps (int, optional) – Number of sampled pruning configurations used to select the optimal architecture. Default value is 200.
entry_point (str, optional) – Model function for execution. See the notes in calibrate_and_prune_model_equal() for the detailed argument description. Default value is “forward”.
**kwargs – Additional keyword arguments for the label selector.
- Returns:
Pruned model.
- Return type:
torch.nn.Module
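A hedged sketch of one possible latency_calculation_function based on wall-clock CPU timing, together with a call to calibrate_and_prune_model_optimal; it reuses model, dataloader, and loss_function from the previous example, and the timing scheme is only an illustration:

import time

import torch

from enot.pruning import calibrate_and_prune_model_optimal

def latency_calculation_function(model: torch.nn.Module) -> float:
    # Measure mean CPU inference time in milliseconds on a fixed dummy input.
    dummy_input = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        for _ in range(3):  # warm-up runs
            model(dummy_input)
        start = time.perf_counter()
        for _ in range(10):
            model(dummy_input)
    return (time.perf_counter() - start) / 10 * 1000.0

pruned_model = calibrate_and_prune_model_optimal(
    model=model,                  # the prepared model from the example above
    dataloader=dataloader,
    loss_function=loss_function,
    latency_calculation_function=latency_calculation_function,
    target_latency=latency_calculation_function(model) / 2,  # aim for a 2x speed-up
    show_tqdm=True,
)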
Low-level interface
The low-level interface consists of the following components.
The main class for pruning is ModelPruningInfo. It contains the necessary information for model pruning: the model execution graph, the list of all prunable and non-prunable channels, all channel importance values, etc.
Model pruning info is obtained through the calibration process. Calibration for pruning is performed by the PruningCalibrator class. An instance of this class should be used as a context manager; inside its context, you should compute losses and perform backward passes through your model on calibration data.
You can also calibrate your model by calling the calibrate_model_for_pruning() function and providing all the arguments necessary for proper model execution and loss computation.
Actual model pruning is performed by the prune_model() function, which requires a pruning info object and a list of channel labels to prune.
To select which labels should be pruned and which should not, we created a special interface for label selection. The base class for all label selectors is PruningLabelSelector. This class defines an abstract function, PruningLabelSelector.select(), which should select labels to prune; subclasses should implement it to fit our pipelines based on calibrate_and_prune_model().
The abstract class ScorePruningLabelSelector is a straightforward way to select the least important channel labels. It utilises the pruning config to select the least important channels in each prunable layer. This is done by sorting channels by their importance values and selecting the top-K least important channels based on the corresponding pruning ratio from the pruning config.
These two label selectors are used internally by the high-level pruning interface described above.
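A hedged end-to-end sketch of this low-level flow (it reuses model, dataloader, and loss_function from the high-level examples; the enot.pruning import paths are assumed):

from enot.pruning import UniformPruningLabelSelector  # import path assumed
from enot.pruning import calibrate_model_for_pruning
from enot.pruning import prune_model

# 1. Calibration: estimate channel importance values.
pruning_info = calibrate_model_for_pruning(
    model=model,
    dataloader=dataloader,
    loss_function=loss_function,
)

# 2. Label selection: pick the 50% least important channels in each layer.
label_selector = UniformPruningLabelSelector(pruning_ratio=0.5)
prune_labels = label_selector.select(model=model, pruning_info=pruning_info)

# 3. Pruning: remove the selected channels from the model.
pruned_model = prune_model(model=model, pruning_info=pruning_info, prune_labels=prune_labels)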
Pruning info
- class ModelPruningInfo(graph, prunable_groups, unprunable_groups, channel_labels_by_nodes, node_users, gate_container, gate_names_by_locations, prunable_group_gate_names, gate_apply_info, child_labels_by_parents, int_pruners)
Container with model pruning-related information.
Calibration
- class PruningCalibrator(model, entry_point='forward', max_prunable_labels=4096)
Context manager for gathering pruning information.
This context manager walks through the model and collects information about prunable and unprunable channels, and estimates channel importance values. The main purpose of this calibrator is to create a ModelPruningInfo object which can later be used to prune the model.
Model pruning info can be accessed via PruningCalibrator.pruning_info.
Examples
>>> p_calibrator = PruningCalibrator(model=baseline_model)
>>> with p_calibrator:
>>>     for inputs, labels in dataloader:
>>>         model_output = model(inputs)
>>>         loss = loss_function(model_output, labels)
>>>         loss.backward()
>>> p_calibrator.pruning_info
See also
calibrate_and_prune_model, calibrate_and_prune_model_equal, calibrate_and_prune_model_optimal
- __init__(model, entry_point='forward', max_prunable_labels=4096)
- Parameters:
model (torch.nn.Module) – Model which the user wants to calibrate for pruning.
entry_point (str, optional) – Model function for execution. See the note below for the detailed argument description. Default value is “forward”.
max_prunable_labels (int, optional) – Maximum number of labels in a group that can be pruned. If there are more labels in a group, the group will be marked as non-prunable. Default value is 4096.
- property pruning_info: ModelPruningInfo | None
Information about model pruning.
- Returns:
Pruning-related information about the model. None if the calibrator instance was not used as a context manager or the model was calibrated incorrectly.
- Return type:
ModelPruningInfo or None
Note
entry_point is a string that specifies which function of the user’s model will be traced.
Simple examples: forward, execute_model, forward_train.
If such a function is located in the model’s submodule, then you should first write the submodule’s name followed by the function name, all separated by dots:
submodule.forward, head.predict_features, submodule1.submodule2.forward.
- calibrate_model_for_pruning(model, dataloader, loss_function, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward', max_prunable_labels=4096)[source]
Estimates model channel importance values for later pruning.
This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning.
- Parameters:
model (torch.nn.Module) – Model to calibrate pruning gates for.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes the model output and a dataloader sample and computes the loss tensor. This function should take two inputs and return a scalar PyTorch tensor.
n_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified in the epochs argument.
epochs (int, optional) – Number of total calibration epochs. Not used when the n_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to the model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log the calibration procedure with a tqdm progress bar. Default value is False.
entry_point (str, optional) – Name of the model method to trace. All prunable operations should be forwarded in this method. Default value is ‘forward’.
max_prunable_labels (int, optional) – Maximum number of labels in a group that can be pruned. If there are more labels in a group, the group will be marked as non-prunable. Default value is 4096.
- Returns:
Pruning information for later usage in pruning methods.
- Return type:
ModelPruningInfo
Model pruning
- prune_model(model, pruning_info, prune_labels, inplace=True)
Remove (prune) the least important channels defined by the prune_labels parameter.
- Parameters:
model (torch.nn.Module) – Model for pruning.
pruning_info (ModelPruningInfo) – Pruning-related information collected during the calibration stage (see more in calibrate_model_for_pruning()).
prune_labels (Sequence of Label) – Sequence of global channel indices to prune.
inplace (bool) – Whether to perform pruning in-place (modify the model itself without creating a copy).
- Returns:
Pruned model.
- Return type:
torch.nn.Module
The following function combines the functionality of the enot.pruning.calibrate_model_for_pruning() and enot.pruning.prune_model() functions:
- calibrate_and_prune_model(label_selector, model, dataloader, loss_function, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')[source]
Estimates channel importance values and prunes the model with a user-defined strategy.
This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning. It then prunes the model by removing channels provided by the user-defined channel selection strategy label_selector.
- Parameters:
label_selector (PruningLabelSelector) – Channel selector object. This object should implement the PruningLabelSelector.select() method, which returns a list with channel indices to prune.
model (torch.nn.Module) – Model to prune.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes the model output and a dataloader sample and computes the loss tensor. This function should take two inputs and return a scalar PyTorch tensor.
finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.
calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified in the calibration_epochs argument.
calibration_epochs (int, optional) – Number of total calibration epochs. Not used when the calibration_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to the model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log the calibration procedure with a tqdm progress bar. Default value is False.
entry_point (str, optional) – Model function for execution. See the Notes section for the detailed argument description. Default value is “forward”.
- Returns:
Pruned model.
- Return type:
torch.nn.Module
Note
Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
For example, it is your responsibility to call the eval method of your model if your inference requires it (e.g. when the model contains dropout layers).
Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during the backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model(...)
entry_point is a string that specifies which function of the user’s model will be traced.
Simple examples: forward, execute_model, forward_train.
If such a function is located in the model’s submodule, then you should first write the submodule’s name followed by the function name, all separated by dots:
submodule.forward, head.predict_features, submodule1.submodule2.forward.
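A hedged sketch that plugs a label selector into calibrate_and_prune_model, reusing model, dataloader, and loss_function from the earlier examples (the selector import path is assumed):

from enot.pruning import GlobalPruningLabelSelectorByChannels  # import path assumed
from enot.pruning import calibrate_and_prune_model

# Remove the 30% least important channels across the whole network.
label_selector = GlobalPruningLabelSelectorByChannels(0.3)

pruned_model = calibrate_and_prune_model(
    label_selector=label_selector,
    model=model,
    dataloader=dataloader,
    loss_function=loss_function,
    show_tqdm=True,
)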
Label selector
To select which labels should be pruned, use one of the following selectors:
UniformPruningLabelSelector — removes an equal percentage of channels in each prunable layer.
GlobalPruningLabelSelectorByChannels — selects the least important channels within the network.
GlobalLatencyPruningLabelSelector — finds the model (labels) with latency as close as possible to the target latency parameter.
KnapsackPruningLabelSelector — maximizes the total score of the model under a latency constraint, following the classical knapsack algorithm.
OptimalPruningLabelSelector — finds the optimal pruning ratio for each prunable layer.
There are also two useful classes to help you implement your own label selector:
PruningLabelSelector — base class for all label selectors.
ScorePruningLabelSelector — base class for selecting the least important channel labels.
- class UniformPruningLabelSelector(pruning_ratio)
Bases:
ScorePruningLabelSelector
Label selector for uniform (equal) pruning.
Removes an equal percentage of channels in each prunable layer.
- class GlobalPruningLabelSelectorByChannels(n_channels_or_ratio)
Bases:
ScorePruningLabelSelector
Label selector for global pruning.
Selects the least important channels within the network.
- __init__(n_channels_or_ratio)
- Parameters:
n_channels_or_ratio (Union[int, float]) – Number of channels or channel ratio to remove within the whole network. If the parameter is in the range (0, 1), it is interpreted as a fraction of all unique labels. If the parameter is greater than or equal to 1, it is interpreted as the number of labels to remove.
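A short sketch of the two interpretations of n_channels_or_ratio (the import path is assumed):

from enot.pruning import GlobalPruningLabelSelectorByChannels  # import path assumed

# Fraction in (0, 1): remove the 30% least important channels network-wide.
selector_by_ratio = GlobalPruningLabelSelectorByChannels(0.3)

# Integer >= 1: remove the 1000 least important channels network-wide.
selector_by_count = GlobalPruningLabelSelectorByChannels(1000)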
- class GlobalLatencyPruningLabelSelector(target_latency, latency_calculation_function, selector_cb=None)
Bases:
PruningLabelSelector
,TargetLatencyMixin
Latency-based label selector.
It finds the model with latency as close as possible to the target latency parameter, and always selects the labels with the lowest scores.
- __init__(target_latency, latency_calculation_function, selector_cb=None)
- Parameters:
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take a model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be the number of FLOPs, MACs, inference time on CPU/GPU, or another “speed” criterion.
selector_cb (Optional[Callable[[float, float], bool]]) – An optional callback that is called at each iteration of the search process. The callback should take the current latency as the first parameter and the target latency as the second parameter, and return True if the search procedure should be stopped, False otherwise. Can be used for logging and early stopping.
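A hedged sketch of a selector_cb that logs progress and stops the search early once the latency is within 1% of the target:

def selector_cb(current_latency: float, target_latency: float) -> bool:
    print(f'current latency: {current_latency:.2f}, target latency: {target_latency:.2f}')
    # Stop the search once the latency is within 1% of the target.
    return abs(current_latency - target_latency) / target_latency < 0.01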
- class KnapsackPruningLabelSelector(target_latency, latency_calculation_function, strict_label_order=True, channels_constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)
Bases:
PruningLabelSelector
,TargetLatencyMixin
Label selector based on knapsack algorithm.
We assume that the label score is the value and the latency is the weight from the classical algorithm, so KnapsackPruningLabelSelector maximizes the total score of the model under a latency constraint. This algorithm cannot find an exact solution because label latencies have a non-linear dependency, but the iterative approach gives good results. At each iteration, the label selector recalculates the latency estimate for each label according to the selection from the previous iteration and solves the knapsack problem with the target latency constraint, until the problem converges to this constraint.
- __init__(target_latency, latency_calculation_function, strict_label_order=True, channels_constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)
- Parameters:
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
latency_calculation_function (Callable[[nn.Module], float]) – Function that calculates model latency. It should take a model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be the number of FLOPs, MACs, inference time on CPU/GPU, or another “speed” criterion.
strict_label_order (bool) – If True, then the labels within a gate are always selected in ascending order of score; otherwise any selection order is acceptable. Default value is True.
channels_constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that calculates low, high, and step for the channels constraint for the total number of channels in a gate. If None, then DummyChannelsSelectorConstraint will be used. None by default.
max_iterations (int) – Maximum number of iterations. Default value is 15.
solver_execution_time_limit (Optional[float]) – Time limit in seconds for solving the linear problem at each iteration. Default value is 60.
verbose (bool) – Whether to print selection statistics and show a progress bar. True by default.
- class OptimalPruningLabelSelector(target_latency, latency_calculation_function, n_search_steps, n_starting_points=None, batch_size=4, channels_constraint=None, additional_starting_points_generator=None, device=None, dtype=torch.float64, verbose=True, handle_ctrl_c=False, seed=None, scbo_kwargs=None)
- __init__(target_latency, latency_calculation_function, n_search_steps, n_starting_points=None, batch_size=4, channels_constraint=None, additional_starting_points_generator=None, device=None, dtype=torch.float64, verbose=True, handle_ctrl_c=False, seed=None, scbo_kwargs=None)
- Parameters:
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take a model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be the number of FLOPs, MACs, inference time on CPU/GPU, or another “speed” criterion. If latency measurement cannot be performed for a particular model configuration due to technical reasons (e.g., problems with measurements on a server), then latency_calculation_function should raise LatencyMeasurementError from enot.pruning.label_selector.
n_search_steps (Optional[int]) – Number of sampled pruning configurations used to find the optimal architecture. If None, then the search loop will stop when SCBO converges.
n_starting_points (Optional[int]) – Number of sampled configurations used for the startup step. If None, then n_starting_points is calculated as min(total_number_of_gates_in_model * 2, n_search_steps - 1). None by default.
batch_size (int) – Batch size for the SCBO algorithm. 4 by default.
channels_constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that calculates low, high, and step for the channels constraint for the total number of channels in a gate. If None, then DefaultChannelsSelectorConstraint will be used. None by default.
additional_starting_points_generator (Optional[StartingPointsGenerator]) – User-defined generator of additional starting points for the SCBO algorithm. None by default.
device (Optional[Union[str, torch.device]]) – The desired device for the SCBO algorithm. None by default (CUDA if possible).
dtype (torch.dtype) – The desired dtype for the SCBO algorithm. torch.float64 by default.
verbose (bool) – Whether to show a progress bar. True by default.
handle_ctrl_c (bool) – Whether to allow early stopping via Ctrl+C. False by default.
seed (Optional[int]) – Optional seed for the SCBO algorithm. Default value is None.
scbo_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments for SCBO (developers only).
- class PruningLabelSelector
Base class for all pruning label selectors.
This class defines an abstract method, PruningLabelSelector.select(), that should return labels to prune.
- static extract_labels_info(pruning_info, sort_labels=True, drop_shared_gates=True, normalize_score=True, eps=1e-07)
Extract label information from pruning_info.
- Parameters:
pruning_info (ModelPruningInfo) – Pruning-related information collected during the calibration stage (see more in calibrate_model_for_pruning()).
sort_labels (bool) – Whether to sort labels in ascending order according to the pruning criterion (score) within each group (sort_labels=True) or leave them unsorted (sort_labels=False). True by default.
drop_shared_gates (bool) – Whether to drop groups of labels for repeated (shared) gates (drop_shared_gates=True) or add them to the output as is (drop_shared_gates=False). True by default.
normalize_score (bool) – Whether to normalize label scores by the total score. True by default.
eps (float) – Coefficient for numerical stability in score normalization. Only used when normalize_score=True. Default is 1e-7.
- Returns:
The first position of the tuple contains labels grouped by gates. The second position contains a mapping from labels to the pruning criterion value (score).
- Return type:
Tuple[List[Tuple[Label, …]], Dict[Label, float]]
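A short sketch of how the two outputs can be consumed (pruning_info is assumed to come from a calibration run, as in the examples above):

labels_by_gates, labels_score = PruningLabelSelector.extract_labels_info(pruning_info)

for gate_labels in labels_by_gates:
    # Within each gate group, labels are sorted by score in ascending order,
    # so the first entry is the best pruning candidate in that group.
    least_important = gate_labels[0]
    print(least_important, labels_score[least_important])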
- static get_most_important_labels(sorted_labels_by_gates)
Return the set of the most important labels. The most important labels are the labels with the highest score in every gate.
- Parameters:
sorted_labels_by_gates (List[Tuple[Label, ...]]) – Labels grouped by gates and sorted by score (ascending order) within every gate group.
- Returns:
Set of the most important labels.
- Return type:
Set[Label]
- abstract select(model, pruning_info)
Method that chooses which labels should be pruned based on the current label selector policy.
Warning
Depending on label selector implementation, this function may have significant execution time.
- Parameters:
model (torch.nn.Module) – Model for pruning.
pruning_info (ModelPruningInfo) – Pruning-related information collected during the calibration stage (see more in calibrate_model_for_pruning()).
- Returns:
List of channel labels which should be pruned.
- Return type:
list of Label
- class ScorePruningLabelSelector
Bases:
PruningLabelSelector
Base class for pruning label selectors which select labels based on PruningLabelSelector.extract_labels_info() output.
This class defines an abstract method, _select(), that takes PruningLabelSelector.extract_labels_info() output and should return a list of labels to prune.
- abstract _select(labels_by_gates, labels_score)
Implementation of select().
- Parameters:
labels_by_gates (List[Tuple[Label, ...]]) – Labels grouped by gates.
labels_score (Dict[Label, float]) – Mapping from labels to the pruning criterion value (score).
- Returns:
List of labels to prune.
- Return type:
List[Label]
- select(model, pruning_info)
Method that chooses which labels should be pruned based on the current label selector policy.
Warning
Depending on label selector implementation, this function may have significant execution time.
- Parameters:
model (torch.nn.Module) – Model for pruning.
pruning_info (ModelPruningInfo) – Pruning-related information collected during the calibration stage (see more in calibrate_model_for_pruning()).
- Returns:
List of channel labels which should be pruned.
- Return type:
list of Label
Here is an example of a simple label selector that selects only even labels:
import operator
from functools import reduce
from typing import List
import torch
from enot.pruning import ModelPruningInfo
from enot.pruning import PruningLabelSelector
class EvenPruningLabelSelector(PruningLabelSelector):
    def select(self, model: torch.nn.Module, pruning_info: ModelPruningInfo) -> List[int]:
        labels_by_gates, labels_by_score = self.extract_labels_info(pruning_info)

        # We don't need model and labels_by_score for our label selector.
        del model, labels_by_score

        is_even = lambda label: label.label % 2 == 0

        labels = reduce(operator.concat, labels_by_gates)
        even_labels = [*filter(is_even, labels)]

        return even_labels
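Similarly, a selector built on ScorePruningLabelSelector only needs to implement _select(). Below is a hedged sketch that drops the two least important labels in every gate group, relying on the ascending score order documented in extract_labels_info() (the import path is assumed):

from enot.pruning import ScorePruningLabelSelector  # import path assumed

class DropTwoPruningLabelSelector(ScorePruningLabelSelector):
    def _select(self, labels_by_gates, labels_score):
        # labels_by_gates: labels grouped by gates, sorted by score (ascending).
        # labels_score: mapping from labels to importance scores (unused here).
        del labels_score
        prune_labels = []
        for gate_labels in labels_by_gates:
            prune_labels.extend(gate_labels[:2])  # two least important labels per gate
        return prune_labels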
Channels Selector Constraint
- class ChannelsSelectorConstraint
Base class for all channel selection constraints.
To implement a new channel selection constraint, you should write an implementation for _get().
- abstract _get(number_of_channels)
Implementation of get(). It should return low, high, and step for the constraint, given the total number of channels in a gate.
low - minimal number of channels that should be pruned
high - maximal number of channels allowed to prune
step - step of discretization
- apply(prune_n, number_of_channels)
Apply channels selector constraint.
- get(number_of_channels)
Calculate low, high, and step for the constraint, given the total number of channels in a gate.
low - minimal number of channels that should be pruned
high - maximal number of channels allowed to prune
step - step of discretization
- class DefaultChannelsSelectorConstraint(absolute_min_channels=16, relative_min_channels=0.25, fixed_step=None)
Bases:
ChannelsSelectorConstraint
Default channels selector constraint.
This class allows you to set two lower bounds for the minimal number of channels in each layer: absolute and relative. The absolute bound is a minimal number of channels, and the relative bound is a fraction of the number of channels in a layer. The result is the maximum of these two.
You can also specify a fixed value for the discretization step. If no fixed step is specified, we do the following:
get() returns a step equal to 1 if the channel group size is less than or equal to 64
get() returns a step equal to 4 if the channel group size is greater than 64
- __init__(absolute_min_channels=16, relative_min_channels=0.25, fixed_step=None)
- Parameters:
absolute_min_channels (int) – Absolute lower bound. If it is larger than number_of_channels, then number_of_channels is returned. Default value is 16.
relative_min_channels (float) – Relative lower bound. Default value is 0.25.
fixed_step (Optional[int]) – Optional fixed value for the discretization step. Default value is None.
Here you can see the implementation of the _get() method for DefaultChannelsSelectorConstraint:
def _get(self, number_of_channels: int) -> Tuple[int, int, int]:
    high = max(self._absolute_min_channels, int(self._relative_min_channels * number_of_channels))
    high = number_of_channels - high
    high = max(0, high)

    if self._fixed_step is None:
        step = 4 if number_of_channels > 64 else 1
    else:
        step = self._fixed_step

    return 0, high, step
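As an illustration, here is a hedged custom constraint that allows pruning at most half of a gate's channels, one channel at a time (the import path is assumed):

from typing import Tuple

from enot.pruning import ChannelsSelectorConstraint  # import path assumed

class HalfChannelsSelectorConstraint(ChannelsSelectorConstraint):
    def _get(self, number_of_channels: int) -> Tuple[int, int, int]:
        # Allow pruning at most half of the channels, one channel at a time.
        return 0, number_of_channels // 2, 1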
Tools
Below you can find a set of tools that help you inspect the pruning structure.
Label printing conventions:
[--]:960:(75436, 76395) — group of 960 prunable labels with IDs from 75436 to 76395
[N-]:448:(3, 450) — group of 448 non-prunable labels with IDs from 3 to 450
The second symbol inside the brackets (- here) indicates whether the label group is sequential or not.
- tabulate_module_dependencies(module_name, pruning_info, model=None, show_op_with_weights_only=True, show_parameter_nodes=False, show_all_nodes=False)
Return a string containing all of the tracing/pruning information for the node corresponding to the module with the given name.
- Parameters:
module_name (str) – Name of the module in torch model, for example: “features.5.conv.1.0”.
pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.
model (Optional[nn.Module]) – PyTorch model used for calibration.
show_op_with_weights_only (bool) – Include only nodes with weights (modules). Default value is True.
show_parameter_nodes (bool) – Whether to include Parameter nodes. Default value is False.
show_all_nodes (bool) – Whether to include all types of nodes. Default value is False.
- Returns:
Tracing/pruning information for the node corresponding to the module with the given name.
- Return type:
str
An example of usage:
print(tabulate_module_dependencies(module_name='features.17.conv.1.0', pruning_info=pruning_info, model=model))
+-------------------+-----------------------------------------------+
| module name | features.17.conv.1.0 |
+-------------------+-----------------------------------------------+
| node name | operation_node_conv2d_456 |
+-------------------+-----------------------------------------------+
| node operation | conv2d |
+-------------------+-----------------------------------------------+
| prunable groups | 1 |
+-------------------+-----------------------------------------------+
| unprunable groups | 0 |
+-------------------+-----------------------------------------------+
| node labels | [None, '[--]:960:(75436, 76395)', None, None] |
+-------------------+-----------------------------------------------+
Prunable group 0 with labels [--]:960:(75436, 76395)
+----------------------+---------------------------------------------------------------------------------------------+
| features.17.conv.1.0 | Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False) |
+----------------------+---------------------------------------------------------------------------------------------+
| features.17.conv.0.0 | Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False) |
+----------------------+---------------------------------------------------------------------------------------------+
Prunable group 0 users:
+--------------------+-----------------------------------------------------------------+
| features.17.conv.2 | Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False) |
+--------------------+-----------------------------------------------------------------+
- tabulate_unprunable_groups(pruning_info, model=None, show_op_with_weights_only=True, show_parameter_nodes=False, show_all_nodes=False)
Return a string containing all of the tracing/pruning information for all unprunable groups.
- Parameters:
pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.
model (Optional[nn.Module]) – PyTorch model used for calibration.
show_op_with_weights_only (bool) – Include only nodes with weights (modules). Default value is True.
show_parameter_nodes (bool) – Whether to include Parameter nodes. Default value is False.
show_all_nodes (bool) – Whether to include all types of nodes. Default value is False.
- Returns:
Tracing/pruning information for all unprunable groups.
- Return type:
str
An example of usage:
print(tabulate_unprunable_groups(pruning_info=pruning_info, model=model, show_all_nodes=True))
Unprunable groups:
Unprunable group 0 with labels [N-]:3:(0, 2)
+----------------------+---------------------+
| parameter_node_199 | features.0.0.weight |
+----------------------+---------------------+
| placeholder_node_314 | |
+----------------------+---------------------+
Unprunable group 0 users:
+--------------+------------------------------------------------------------------------------+
| features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) |
+--------------+------------------------------------------------------------------------------+
Unprunable group 1 with labels [N-]:448:(3, 450)
+----------------------+--+
| placeholder_node_314 | |
+----------------------+--+
Unprunable group 1 users:
+--------------+------------------------------------------------------------------------------+
| features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) |
+--------------+------------------------------------------------------------------------------+
Unprunable group 2 with labels [N-]:6:(486, 491)
+--------------------+---------------------+
| parameter_node_199 | features.0.0.weight |
+--------------------+---------------------+
Unprunable group 2 users:
+--------------+------------------------------------------------------------------------------+
| features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) |
+--------------+------------------------------------------------------------------------------+
...
- tabulate_label_dependencies(label, pruning_info, model=None, show_op_with_weights_only=True, show_parameter_nodes=False, show_all_nodes=False)
Return a string containing all of the tracing/pruning information for the label.
- Parameters:
label (Union[Label, int]) – Label for which information is to be printed.
pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.
model (Optional[nn.Module]) – PyTorch model used for calibration.
show_op_with_weights_only (bool) – Include only nodes with weights (modules). Default value is True.
show_parameter_nodes (bool) – Whether to include Parameter nodes. Default value is False.
show_all_nodes (bool) – Whether to include all types of nodes. Default value is False.
- Returns:
Tracing/pruning information for the label.
- Return type:
str
An example of usage:
print(tabulate_label_dependencies(label=88964, pruning_info=pruning_info, model=model, show_all_nodes=True))
label: 88964
+---------------------+------------------------------------------------------------------+
| parameter_node_94 | features.18.0.weight |
+---------------------+------------------------------------------------------------------+
| features.18.0 | Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False) |
+---------------------+------------------------------------------------------------------+
| parameter_node_197 | features.18.1.running_mean |
+---------------------+------------------------------------------------------------------+
| parameter_node_151 | features.18.1.running_var |
+---------------------+------------------------------------------------------------------+
| parameter_node_81 | features.18.1.weight |
+---------------------+------------------------------------------------------------------+
| parameter_node_113 | features.18.1.bias |
+---------------------+------------------------------------------------------------------+
| batch_norm | |
+---------------------+------------------------------------------------------------------+
| hardtanh | |
+---------------------+------------------------------------------------------------------+
| adaptive_avg_pool2d | |
+---------------------+------------------------------------------------------------------+
| flatten | |
+---------------------+------------------------------------------------------------------+
| dropout | |
+---------------------+------------------------------------------------------------------+
| parameter_node_42 | classifier.1.weight |
+---------------------+------------------------------------------------------------------+
- tabulate_label_filters(label, pruning_info, model)
Return a string containing all of the label filter information.
- Parameters:
label (Union[Label, int]) – The label for which filters are to be found.
pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.
model (nn.Module) – PyTorch model used for calibration.
- Returns:
Label filters information.
- Return type:
str
An example of usage:
print(tabulate_label_filters(label=88964, pruning_info=pruning_info, model=model))
label: 88964
+----------------------------+----------------+---------------------------------+
| tensor name | shape | mapping |
+============================+================+=================================+
| features.18.0.weight | (1, 320, 1, 1) | [[(1202, 0)], None, None, None] |
+----------------------------+----------------+---------------------------------+
| features.18.1.running_mean | (1,) | [[(1202, 0)]] |
+----------------------------+----------------+---------------------------------+
| features.18.1.running_var | (1,) | [[(1202, 0)]] |
+----------------------------+----------------+---------------------------------+
| features.18.1.weight | (1,) | [[(1202, 0)]] |
+----------------------------+----------------+---------------------------------+
| features.18.1.bias | (1,) | [[(1202, 0)]] |
+----------------------------+----------------+---------------------------------+
| classifier.1.weight | (10, 1) | [None, [(1202, 0)]] |
+----------------------------+----------------+---------------------------------+
- get_label_filters(label, pruning_info, model, *, with_tensors=False)
Return the filters corresponding to the passed label.
- Parameters:
label (Union[Label, int]) – The label for which filters are to be found.
pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.
model (nn.Module) – PyTorch model used for calibration.
with_tensors (bool) – Whether to include filter values (tensors). Default value is False.
- Returns:
List of containers with the filter tensor name, shape, mapping, and (optionally) the filter value.
- Return type:
List[LabelFilterInfo]
An example of usage:
for filter_ in get_label_filters(label=label, pruning_info=pruning_info, model=model):
    print(filter_)
LabelFilterInfo(tensor_name='features.18.0.weight', shape=(1, 320, 1, 1), mapping=[[(1202, 0)], None, None, None], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.running_mean', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.running_var', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.weight', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.bias', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='classifier.1.weight', shape=(10, 1), mapping=[None, [(1202, 0)]], tensor=None)
Here, shape is the shape of the filter, and mapping can be used to extract the desired filter from the tensor.
- get_tracing_node_by_module_name(name, graph)
Return the tracing node corresponding to the module with the given name.
Note: the current implementation works ONLY for modules with a “weight” parameter.
- Parameters:
name (str) – Name of module in torch model, for example: “features.5.conv.1.0”.
graph (List[Node]) – Tracing graph.
- Returns:
Node (OpNode) corresponding to the module with the given name.
- Return type:
OpNode
- Raises:
ValueError – If node corresponding to the module name is not found.
- get_module_name_by_tracing_node(node)
Return the module name corresponding to the node.
Note: the current implementation works ONLY for modules with a “weight” parameter. Also, you can easily get a module by its name with the help of getattr_complex.
- Parameters:
node (OpNode) – The node for which the module name should be found.
- Returns:
Module name corresponding to the node.
- Return type:
str
- Raises:
ValueError – If module name corresponding to the node is not found.
Converting dataloader items to PyTorch model inputs
Short summary: when ENOT needs to pass items from your dataloader to your model, it requires a special mapping function. This function will be used as follows:
# Your code:
your_model = ...
your_dataloader = ...
your_mapping_function = ...
# In the ENOT framework:
sample = next(your_dataloader)
model_args, model_kwargs = your_mapping_function(sample)
model_result = your_model(*model_args, **model_kwargs)
When working with the ENOT framework, it is sometimes necessary to write functions which transform dataloader items into the model input format. These functions are later used to convert dataloader items (extracted with the __next__ method) into your model's input format.
Such a function should take a single input of type Any: a single item from the user dataloader. It should return a tuple with two elements. The first element is a tuple with model positional arguments, which will be passed to the model's __call__ method with the unpacking operator *args. The second element is a dictionary with string keys and any-type values, which defines model keyword arguments and will be passed to its __call__ method with the unpacking operator **kwargs.
Let’s see the most basic example:
DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor)
def my_conversion_function(x):
    return (x[0],), {}  # Single positional argument for model, no keyword arguments.
The same functionality is provided by our function default_sample_to_model_inputs().
Example function which changes model’s default forward keyword argument value:
DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor, should_profile: bool = True)
def my_conversion_function(x):
    return (x[0],), {'should_profile': False}  # We do not want to profile!
Example function which also performs data pre-processing and moves it to the GPU:
DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor)
def my_conversion_function(x):
    # Normalize images, cast them to float32 format, and move them to CUDA.
    return ((x[0].float() / 255).cuda(),), {}
Advanced case with complex dataloader item structure and model positional and keyword arguments:
DataLoader returns: {'sequence': torch.Tensor, 'translation': torch.Tensor, 'masks': List[torch.Tensor]}
Model expects: (sequence: torch.Tensor, mask: Optional[torch.Tensor] = None, unroll_attention: bool = False)
def my_conversion_function(x):
    sequence = x['sequence']
    mask = x['masks'][0]
    return (sequence,), {'mask': mask, 'unroll_attention': True}
- default_sample_to_model_inputs(sample)
Default function for sample to model input conversion.
This function covers the simplest case, when the dataloader returns pairs of images and labels in the form (images, labels). If the model receives a single positional argument, then you can use this function to convert dataloader output to model input.
- Parameters:
sample (tuple) – Tuple with one or more elements, of which only the first item will be passed to the model. Thus, the user-defined model must have a single positional argument in its forward function definition.
- Returns:
Model input args and kwargs, see more here.
- Return type:
tuple with two items - tuple and dict with str keys
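For reference, a minimal equivalent of this function under the (images, labels) layout described above (a sketch, not the library source):

def default_sample_to_model_inputs(sample):
    # Pass the first element (images) as the model's only positional argument.
    return (sample[0],), {}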
Extracting the number of samples from dataloader item
Sometimes ENOT needs to know the number of samples of interest in a dataloader item. A sample of interest could be an image, one second of audio, a bounding box, etc. What a single sample means is entirely up to you. To provide this information to ENOT, you should define a special function. This function will be used as follows:
# Your code:
your_model = ...
your_dataloader = ...
your_function = ...
# In the ENOT framework:
sample = next(your_dataloader)
n_base_samples: int = your_function(sample)
# process n_base_samples
Such a function should take a single input of type Any: a single item from the user dataloader. It should return a single integer value: the number of samples of interest in the dataloader sample.
Let’s see the most basic example:
DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
def my_function(x):
    # Extract the number of images from the batch dimension.
    return x[0].shape[0]
The same functionality is provided by our function default_sample_to_n_samples().
Advanced case with complex dataloader item structure:
DataLoader returns: {'sequence': torch.Tensor, 'translation': torch.Tensor, 'mask': torch.Tensor}
def my_conversion_function(x):
    # sequence is a tensor of shape (N, T1) with int64 dtype.
    # translation is a tensor of shape (N, T2) with int64 dtype.
    # mask is a tensor of shape (N, T1) with int64 dtype.
    # mask is equal to 1 where sequence values are correct (not padded) and 0 elsewhere.
    total_time_steps = x['mask'].sum()
    total_seconds = total_time_steps // 100  # Let's suppose that one time step is 10 ms.
    return total_seconds  # Total audio length in seconds over all samples.
- default_sample_to_n_samples(sample)
Default function to extract the number of samples of interest from dataloader item.
This function covers the simplest case, when the dataloader returns pairs of images and labels in the form (images, labels).
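For reference, a minimal equivalent under the same (images, labels) assumption (a sketch, not the library source):

def default_sample_to_n_samples(sample):
    # The batch dimension of the images tensor gives the number of samples.
    return sample[0].shape[0]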