Pruning¶
The enot.pruning package provides functionality for automatic pruning of user models.
Supported models¶
The pruning engine works with many models, including (but not limited to):
- (from torchvision)
  - classification
    - efficientnet_b0
    - resnet18
    - resnet34
    - resnet50
    - wide_resnet50_2
    - densenet161
    - mobilenet_v2
    - mobilenet_v3_large
  - detection
    - ssd300_vgg16
    - ssdlite320_mobilenet_v3_large
  - segmentation
    - fcn_resnet50
    - fcn_resnet101
    - deeplabv3_resnet50
    - deeplabv3_resnet101
    - deeplabv3_mobilenet_v3_large
    - lraspp_mobilenet_v3_large
- (from PyTorch Image Models (timm))
  - efficientnet_b0
  - resnet18
  - resnet34
  - resnet50
  - wide_resnet50_2
  - densenet161
  - mobilenetv2_100
  - mobilenetv3_large_100
Introduction¶
This package provides pre-defined and custom pruning procedures that remove the least important filters or neurons. The enot.pruning package currently supports structured pruning. The user can set the pruning ratio (the fraction of channels removed) manually or use any of the pre-defined strategies.
The first (and simplest) is the equal pruning strategy, which keeps roughly the same percentage of channels in each prunable layer.
The second (and quite compute-intensive) is the optimal pruning strategy, which searches for the optimal pruning ratio for each prunable layer.
A few definitions to make the documentation easier to read:
Pruning ratio is the fraction of channels to remove (prune) in a prunable group or prunable layer.
Gate (or channel gate) is a special torch.nn.Module which gathers and saves channel “local importance”. “Local importance” means that channels can be compared by their importance values within a single layer, but not between distinct layers.
Calibration procedure estimates channel importance values according to the user model and data.
Channel label is the global channel index in a network.
Pruning config specifies the pruning amount for each prunable layer. See more here.
Pruning API¶
Pruning functionality accessible to the user is divided into three sections:
High-level interface: the easiest way to use the pruning functionality;
Low-level interface: low-level API for pruning, which takes moderate time to understand and use;
Utility functions: helper functions for the low-level API.
High-level interface¶
The high-level interface provides two ways to prune a model. They work in the same manner but apply different policies for the amount of pruning in each layer.
- calibrate_and_prune_model_equal(model, dataloader, loss_function, pruning_ratio=0.5, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')[source]¶
Estimates channel importance values and prunes model with equal pruning strategy (same percentage of channels are pruned in each prunable layer).
- Parameters
model (torch.nn.Module) – Model to prune.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
pruning_ratio (float, optional) – Relative amount of channels to prune in each prunable layer. Default value is 0.5 (prunes 50% of channels, which typically reduces the number of FLOPs and parameters by a factor of 4).
finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.
calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs specified in the calibration_epochs argument.
calibration_epochs (int, optional) – Number of total calibration epochs. Not used when the calibration_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.
entry_point (str, optional) – Model function for execution. See the Notes section for the detailed argument description. Default value is “forward”.
- Returns
Pruned model.
- Return type
Notes
Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).
Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during the backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_equal(...)
entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If such a function is located in a submodule of the model, first write the submodule’s name followed by the function name, all separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.
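The factor-of-4 estimate quoted for the default pruning_ratio of 0.5 follows from how convolution cost scales with channel counts. The sketch below is a back-of-the-envelope calculation and is not part of enot.pruning: the helper conv_macs and the layer sizes are illustrative assumptions, and the result only holds for layers whose input and output channels are both pruned by the same ratio.

```python
def conv_macs(in_channels: int, out_channels: int, kernel_size: int,
              out_h: int, out_w: int) -> int:
    """Multiply-accumulate count of a dense 2D convolution (bias ignored)."""
    return in_channels * out_channels * kernel_size * kernel_size * out_h * out_w


pruning_ratio = 0.5
in_ch, out_ch = 128, 256  # hypothetical layer widths

# Channels that survive pruning in the producing layer and in this layer.
kept_in = int(in_ch * (1 - pruning_ratio))
kept_out = int(out_ch * (1 - pruning_ratio))

baseline = conv_macs(in_ch, out_ch, 3, 56, 56)
pruned = conv_macs(kept_in, kept_out, 3, 56, 56)

# Both channel dimensions shrink by 2, so the cost shrinks by 2 * 2.
reduction = baseline / pruned
print(reduction)
```

The weight count of the layer is also proportional to in_channels * out_channels, which is why the parameter count tends to shrink by a similar factor. Layers with an unpruned input or output (for example the first convolution or the classifier head) reduce less.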
- calibrate_and_prune_model_optimal(model, dataloader, loss_function, latency_calculation_function, target_latency, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, min_channels_to_leave_fn=<function default_min_channels_to_leave>, channel_search_step_fn=<function default_channel_search_step>, n_search_steps=200, entry_point='forward', **kwargs)[source]¶
Estimates channel importance values, searches for the optimal pruning configuration (percentage of channels to prune in each prunable layer) and prunes model.
- Parameters
model (torch.nn.Module) – Model to prune.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function which calculates sample model latency. This function should take sample pruned model (torch.nn.Module) and measure its execution “speed” (float). It could be a number of FLOPs, MACs, inference speed on CPU/GPU and other “speed” criteria.
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.
calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs specified in the calibration_epochs argument.
calibration_epochs (int, optional) – Number of total calibration epochs. Not used when the calibration_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.
min_channels_to_leave_fn (Callable[[int], int], optional) – Function to construct minimal pruning configuration. This function should take the number of channels in a specific layer and return the minimal number of channels to keep. Default value is default_min_channels_to_leave().
channel_search_step_fn (Callable[[int], int], optional) – Function to select the search step for each group of channels. This search step defines the number of possible pruning ratios for a specific group. If the search step is equal to 1, then every number of channels after pruning is possible (with a lower bound set by min_channels_to_leave_fn). When the search step is equal to 2, the number of options is reduced by a factor of 2. This function should take the number of channels in a specific layer and return the search step. Default value is default_channel_search_step().
n_search_steps (int, optional) – Number of sampled configurations for pruning (equal to the number of latency_calculation_function executions) to select the optimal architecture. Default value is 200.
entry_point (str, optional) – Model function for execution. See the Notes section for the detailed argument description. Default value is “forward”.
**kwargs – Additional keyword arguments for label selector.
- Returns
Pruned model.
- Return type
Notes
Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).
Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during the backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_optimal(...)
entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If such a function is located in a submodule of the model, first write the submodule’s name followed by the function name, all separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.
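The latency_calculation_function argument only has to map a model to a float, so one possible choice is mean wall-clock time per call. The sketch below is an illustration under that assumption and is not taken from enot.pruning: make_latency_fn, sample_input and n_runs are hypothetical names, and the model is treated as any callable.

```python
import time
from typing import Any, Callable


def make_latency_fn(sample_input: Any, n_runs: int = 10) -> Callable[[Any], float]:
    """Build a function returning mean wall-clock seconds per model(sample_input) call."""

    def latency_fn(model: Callable[[Any], Any]) -> float:
        model(sample_input)  # warm-up call, excluded from the measurement
        start = time.perf_counter()
        for _ in range(n_runs):
            model(sample_input)
        return (time.perf_counter() - start) / n_runs

    return latency_fn


# Stand-in "model" so the sketch runs without a deep learning framework.
latency_fn = make_latency_fn(sample_input=list(range(1000)), n_runs=5)
measured = latency_fn(sum)
print(f"{measured:.6f} s per call")
```

With this choice, target_latency must also be given in seconds. For a real PyTorch model, one would typically run the timing loop under torch.no_grad() and, on GPU, call torch.cuda.synchronize() before reading the clock. A deterministic proxy such as a MAC count is usually less noisy than wall-clock time.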
Low-level interface¶
The main class for pruning is ModelPruningInfo. It contains the information necessary for model pruning: the model execution graph, the list of all prunable and non-prunable channels, all channel importance values, etc.
Model pruning info is obtained through the calibration process. Calibration for pruning is performed by the PruningCalibrator class. An instance of this class should be used as a context manager. Inside its context, the user should compute losses and perform backward passes through the model on calibration data. The user can also calibrate the model by calling the calibrate_model_for_pruning() function and providing all arguments necessary for proper model execution and loss computation.
Actual model pruning is performed by the prune_model() function, which requires a pruning info object and a list of channel labels to prune.
To select which labels should be pruned and which should not, we created a special interface for label selection. The base class for all label selectors is PruningLabelSelector. This class defines an abstract method PruningLabelSelector.select(), which should select the labels to prune and which subclasses should implement to fit our pipelines based on calibrate_and_prune_model().
The abstract class TopKPruningLabelSelector is a straightforward way to select the least important channel labels. It uses the pruning config to select the least important channels in each prunable layer. This is done by sorting channels by their importance values and selecting the top-K least important channels based on the corresponding pruning ratio from the pruning config.
Our label selectors based on this class are UniformPruningLabelSelector and OptimalPruningLabelSelector.
They are used internally in the high-level interface for pruning described above.
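The top-K selection described above can be sketched in plain Python. This is an illustration of the idea only, not the library's implementation: the function select_least_important and the dictionary layout (group name mapped to labels and importance values) are assumptions made for the example.

```python
from typing import Dict, List, Sequence, Tuple


def select_least_important(
    groups: Dict[str, Tuple[Sequence[int], Sequence[float]]],
    pruning_ratio: float,
) -> List[int]:
    """Pick the k least important channel labels in every group."""
    labels_to_prune: List[int] = []
    for labels, importance in groups.values():
        k = int(len(labels) * pruning_ratio)
        # Rank channel positions by ascending importance within this group only.
        order = sorted(range(len(labels)), key=lambda i: importance[i])
        labels_to_prune.extend(labels[i] for i in order[:k])
    return labels_to_prune


# Hypothetical calibration result: two groups with very different value scales.
groups = {
    "conv1": ([0, 1, 2, 3], [0.9, 0.1, 0.5, 0.3]),
    "conv2": ([4, 5, 6, 7], [10.0, 40.0, 20.0, 30.0]),
}
selected = select_least_important(groups, pruning_ratio=0.5)
print(selected)
```

Ranking is done per group, which matches the notion of "local importance": the values in conv2 are much larger than those in conv1, yet each group still gives up the same share of channels.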
Pruning info class:
- class ModelPruningInfo(graph, prunable_groups, unprunable_groups, channel_labels_by_nodes, node_users, gate_container, gate_names_by_locations, prunable_group_gate_names, gate_apply_info, child_labels_by_parents, int_pruners)¶
Container with model pruning-related information.
- property n_prunable_groups: int¶
Number of prunable channel groups.
- Returns
Number of prunable channel groups.
- Return type
Calibration:
- class PruningCalibrator(model, entry_point='forward', max_prunable_labels=4096)¶
Context manager for gathering pruning information.
This context manager walks through the model, collects information about prunable and unprunable channels, and estimates channel importance values. The main purpose of this calibrator is to create a ModelPruningInfo object which can later be used to prune the model.
Model pruning info can be accessed via PruningCalibrator.pruning_info.
Examples
>>> p_calibrator = PruningCalibrator(
...     model=baseline_model,
... )
>>> with p_calibrator:
...     for inputs, labels in dataloader:
...         model_output = baseline_model(inputs)
...         loss = loss_function(model_output, labels)
...         loss.backward()
>>> p_calibrator.pruning_info
See also
calibrate_and_prune_model, calibrate_and_prune_model_equal, calibrate_and_prune_model_optimal
- __init__(model, entry_point='forward', max_prunable_labels=4096)¶
- Parameters
model (torch.nn.Module) – Model which user wants to calibrate for pruning.
entry_point (str, optional) – Model function for execution. See the Notes section for the detailed argument description. Default value is “forward”.
max_prunable_labels (int, optional) – Maximum number of labels in a group which can be pruned. If a group contains more labels, it will be marked as non-prunable. Default value is 4096.
Notes
entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If such a function is located in a submodule of the model, first write the submodule’s name followed by the function name, all separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.
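To make the dotted entry_point notation concrete, the sketch below resolves such a string to a bound method with getattr. How enot.pruning resolves the string internally is not stated here, so treat resolve_entry_point and the toy Model and Head classes as assumptions for illustration.

```python
from functools import reduce
from typing import Any, Callable


def resolve_entry_point(model: Any, entry_point: str) -> Callable[..., Any]:
    """Follow a dotted path such as "head.predict_features" attribute by attribute."""
    return reduce(getattr, entry_point.split("."), model)


class Head:
    def predict_features(self, x: int) -> int:
        return x * 2


class Model:
    def __init__(self) -> None:
        self.head = Head()

    def forward(self, x: int) -> int:
        return self.head.predict_features(x) + 1


model = Model()
top_level = resolve_entry_point(model, "forward")
nested = resolve_entry_point(model, "head.predict_features")
print(top_level(3), nested(3))
```

The first path names a method on the model itself, while the second walks into the head submodule before picking the method, which mirrors the “submodule.forward” form described in the note.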
- property pruning_info: Optional[enot.pruning.pruning_info.ModelPruningInfo]¶
Information about model pruning.
- Returns
Pruning-related information about model. None if calibrator instance was not used as a context manager or model was calibrated incorrectly.
- Return type
- calibrate_model_for_pruning(model, dataloader, loss_function, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward', max_prunable_labels=4096)[source]¶
Estimates model channel importance values for later pruning.
This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning.
- Parameters
model (torch.nn.Module) – Model to calibrate pruning gates for.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
n_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs specified in the epochs argument.
epochs (int, optional) – Number of total calibration epochs. Not used when the n_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.
entry_point (str, optional) – Name of model method for tracing. In this method all prunable operations should be forwarded. Default value is ‘forward’.
max_prunable_labels (int, optional) – Maximum number of labels in group which can be pruned. If there are more labels in group, then this group will be marked as non-prunable. Default value is 4096.
- Returns
Pruning information for later usage in pruning methods.
- Return type
Notes
Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).
Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during the backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_model_for_pruning(...)
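When a dataloader yields dictionaries rather than (inputs, labels) tuples, a custom sample_to_n_samples function may be needed. The sketch below shows one way to satisfy the documented contract (one sample in, one integer out). The function name and the "image" key are hypothetical, and plain lists stand in for tensors.

```python
from typing import Any, Mapping


def dict_sample_to_n_samples(sample: Mapping[str, Any]) -> int:
    """Return the batch size of a dict-style sample; the "image" key is an assumption."""
    return len(sample["image"])


# Hypothetical batch of three instances.
batch = {"image": [[0.0], [0.1], [0.2]], "target": [1, 0, 1]}
n_instances = dict_sample_to_n_samples(batch)
print(n_instances)
```

The same function works for PyTorch tensors, because len() on a tensor returns the size of its first dimension.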
Pruning:
- prune_model(model, pruning_info, prune_labels, inplace=True)¶
Remove (prune) least important channels defined by the prune_labels parameter.
- Parameters
model (torch.nn.Module) – Model for pruning.
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
prune_labels (Sequence of Label) – Sequence of global channel indices to prune.
inplace (bool) – Whether to perform pruning in-place (modify the model itself without creating a copy).
- Returns
Pruned model.
- Return type
The following function combines the functionality of the enot.pruning.calibrate_model_for_pruning() and enot.pruning.prune_model() functions:
- calibrate_and_prune_model(label_selector, model, dataloader, loss_function, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')[source]¶
Estimates channel importance values and prunes model with user defined strategy.
This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning. After that, it prunes the model by removing the channels provided by the user-defined channel selection strategy label_selector.
- Parameters
label_selector (PruningLabelSelector) – Channel selector object. This object should implement the PruningLabelSelector.select() method, which returns a list with channel indices to prune.
model (torch.nn.Module) – Model to prune.
dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.
loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.
finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.
calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs specified in the calibration_epochs argument.
calibration_epochs (int, optional) – Number of total calibration epochs. Not used when the calibration_steps argument is not None. Default value is 1.
sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.
sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.
show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.
entry_point (str, optional) – Model function for execution. See the Notes section for the detailed argument description. Default value is “forward”.
- Returns
Pruned model.
- Return type
Notes
Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).
Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).
If you encounter errors during the backward call, you can wrap this function call with the following statement:
with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model(...)
An example of user-defined strategy:
from typing import List

import numpy as np

from enot.pruning import ModelPruningInfo, PruningLabelSelector, iterate_over_gate_criteria


class CustomPruningSelector(PruningLabelSelector):
    def __init__(self, pruning_ratio: float):
        self.pruning_ratio: float = pruning_ratio
        super().__init__()

    def select(self, pruning_info: ModelPruningInfo) -> List[int]:
        labels_to_prune: List[int] = []
        for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
            criteria: np.ndarray = np.array(criteria)
            # Number of channels to remove from this gate.
            prune_channels_num = int(len(criteria) * self.pruning_ratio)
            # Positions of the least important channels (ascending criteria).
            index_for_pruning = np.argsort(criteria)[:prune_channels_num]
            labels_to_prune += np.array(labels)[index_for_pruning].tolist()
        return labels_to_prune
entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If such a function is located in a submodule of the model, first write the submodule’s name followed by the function name, all separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.
Label selection:
- class PruningLabelSelector[source]¶
Base class for all label selectors for pruning.
This class defines an abstract method PruningLabelSelector.select() which should return labels to prune.
- __init__()¶
- property labels_for_pruning: List[enot.pruning.labels.label.Label]¶
Labels to prune.
- Returns
List of all channel labels to prune.
- Return type
list of int
- Raises
RuntimeError – If labels were not calculated (PruningLabelSelector.select() method was not called).
- abstract select(pruning_info)¶
Method that chooses which labels should be pruned based on current label selector policy and saves them in label selector instance.
Warning
Depending on label selector implementation, this function may have significant execution time.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Returns
List of channel labels which should be pruned.
- Return type
list of int
- class TopKPruningLabelSelector[source]¶
Bases:
enot.pruning.label_selector.PruningLabelSelector
Base class for label selectors based on pruning config.
For description of pruning configs see here.
This class implements the PruningLabelSelector.select() method by calling the TopKPruningLabelSelector.get_config_for_pruning() abstract method to generate a pruning config, and selects the top-k least important channels from each group with k values based on the pruning config.
- __init__()¶
- abstract get_config_for_pruning(pruning_info)¶
Abstract method which should implement Pruning config selection strategy.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Returns
Config for pruning.
- Return type
- static get_labels_by_config(pruning_info, pruning_config)¶
Transforms Pruning config into channel labels to prune.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
pruning_config (PruningConfig) – Config for pruning.
- Returns
Channel indices to prune (remove from the model).
- Return type
list of int
- property pruning_cfg: Optional[Union[Sequence[int], Sequence[float]]]¶
Returns current pruning config.
- Returns
pruning_config – Config for pruning.
- Return type
- select(pruning_info)¶
Method that chooses which labels should be pruned based on current label selector policy and saves them in label selector instance.
Warning
Depending on label selector implementation, this function may have significant execution time.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Returns
List of channel labels which should be pruned.
- Return type
list of int
- class UniformPruningLabelSelector(pruning_ratio)[source]¶
Bases:
enot.pruning.label_selector.TopKPruningLabelSelector
Label selector for uniform (equal) pruning.
Removes an equal percentage of channels in each prunable layer.
- __init__(pruning_ratio)¶
- Parameters
pruning_ratio (float) – Percentage of prunable channels to remove in each group.
- get_config_for_pruning(pruning_info)¶
Abstract method which should implement Pruning config selection strategy.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Returns
Config for pruning.
- Return type
- class OptimalPruningLabelSelector(model, latency_calculation_function, target_latency, *, min_channels_to_leave_fn=<function default_min_channels_to_leave>, channel_search_step_fn=<function default_channel_search_step>, n_search_steps=200, **kwargs)¶
Label selector based on estimation of optimal Pruning config through bayesian optimization.
- __init__(model, latency_calculation_function, target_latency, *, min_channels_to_leave_fn=<function default_min_channels_to_leave>, channel_search_step_fn=<function default_channel_search_step>, n_search_steps=200, **kwargs)¶
- Parameters
model (torch.nn.Module) – Model to prune.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function which calculates sample model latency. This function should take sample pruned model (torch.nn.Module) and measure its execution “speed” (float). It could be a number of FLOPs, MACs, inference speed on CPU/GPU and other “speed” criteria.
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function. The latency value should be larger than the latency of the minimal model constructed according to the min_channels_to_leave_fn argument.
min_channels_to_leave_fn (Callable[[int], int], optional) – Function to construct minimal pruning configuration. This function should take the number of channels in a specific layer and return the minimal number of channels to keep. Default value is default_min_channels_to_leave().
channel_search_step_fn (Callable[[int], int], optional) – Function to select the search step for each group of channels. This search step defines the number of possible pruning ratios for a specific group. If the search step is equal to 1, then every number of channels after pruning is possible (with a lower bound set by min_channels_to_leave_fn). When the search step is equal to 2, the number of options is reduced by a factor of 2. This function should take the number of channels in a specific layer and return the search step. Default value is default_channel_search_step().
n_search_steps (int, optional) – Number of sampled configurations for pruning (equal to the number of latency_calculation_function executions) to select the optimal architecture. Default value is 200.
**kwargs – Additional keyword arguments for label selector.
- property baseline_latency: float¶
Baseline model latency.
- Returns
Baseline model latency. Baseline model is the model provided to the constructor.
- Return type
- get_config_for_pruning(pruning_info)¶
Abstract method which should implement Pruning config selection strategy.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Returns
Config for pruning.
- Return type
- static get_labels_by_config(pruning_info, pruning_config)¶
Transforms Pruning config into channel labels to prune.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
pruning_config (PruningConfig) – Config for pruning.
- Returns
Channel indices to prune (remove from the model).
- Return type
list of int
- property labels_for_pruning: List[enot.pruning.labels.label.Label]¶
Labels to prune.
- Returns
List of all channel labels to prune.
- Return type
list of int
- Raises
RuntimeError – If labels were not calculated (PruningLabelSelector.select() method was not called).
- property minimal_network_latency: float¶
Minimal model latency.
- Returns
Minimal model latency. The minimal model is the model constructed according to the min_channels_to_leave_fn argument in this class’s constructor.
- Return type
- property n_search_steps: int¶
Number of models to evaluate.
- Returns
Number of models to evaluate during search procedure.
- Return type
- property pruning_cfg: Optional[Union[Sequence[int], Sequence[float]]]¶
Returns current pruning config.
- Returns
pruning_config – Config for pruning.
- Return type
- select(pruning_info)¶
Method that chooses which labels should be pruned based on current label selector policy and saves them in label selector instance.
Warning
Depending on label selector implementation, this function may have significant execution time.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Returns
List of channel labels which should be pruned.
- Return type
list of int
Utility functions¶
ModelPruningInfo utilities:
- iterate_over_gate_criteria(pruning_info)¶
Iterates over pruning gates and yields channel labels and channel criteria.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Yields
group_id (int) – Global index number of a gate.
labels (list of Labels) – Channel labels of a gate.
criteria (list of float) – Channel criteria of a gate.
- Returns
Generator with items specified above.
- Return type
Generator
- get_criteria_label_dict(pruning_info)¶
Collects information about channel importance values in human-readable format.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
- Returns
Dictionary with channel labels and importance values. Its keys are channel group names, and its values store both global channel indices (labels) and channel importance values.
- Return type
Dict[str, Dict[str, numpy.ndarray]]
- get_labels_for_uniform_pruning(pruning_info, pruning_ratio=0.5)¶
Returns labels of least important channels of every prunable layer.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
pruning_ratio (float) – Fraction of least important channels to select for pruning in each prunable layer.
- Returns
List of least important channel labels.
- Return type
list of int
Pruning config utilities:
- check_pruning_config(pruning_config)¶
Validates pruning config.
- Parameters
pruning_config (PruningConfig) – Config for pruning.
- Raises
ValueError – If pruning config is not valid.
- Return type
- get_least_important_labels_by_config(pruning_info, pruning_config)¶
Extracts least important channel labels based on pruning config.
- Parameters
pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
pruning_config (PruningConfig) – Config for pruning.
- Returns
Least important labels from all layers.
- Return type
list of Label
OptimalPruningLabelSelector utilities:
- default_min_channels_to_leave(current_layer_channels, absolute_min_channels=16, relative_min_channels=0.25)[source]¶
Default function to specify the number of channels to keep by the total number of channels in a layer.
This function allows setting two boundaries for the minimal number of channels in each layer: absolute and relative. The absolute boundary is the minimal number of channels, and the relative boundary is a fraction of the number of channels in a layer. The result is the maximum of these two values.
- Parameters
current_layer_channels (int) – Number of channels in a layer.
absolute_min_channels (int, optional) – Absolute lower bound. If it is larger than current_layer_channels, then current_layer_channels is returned. Default value is 16.
relative_min_channels (float, optional) – Relative lower bound. Default value is 0.25.
- Returns
The minimal number of channels to keep among current_layer_channels.
- Return type
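The rule described for default_min_channels_to_leave() can be restated as a few lines of Python. This is a re-implementation sketch written from the description above, not the library's source: the name sketch_min_channels_to_leave is hypothetical, and rounding the relative bound up with math.ceil is an assumption, since the rounding mode is not documented here.

```python
import math


def sketch_min_channels_to_leave(
    current_layer_channels: int,
    absolute_min_channels: int = 16,
    relative_min_channels: float = 0.25,
) -> int:
    """Lower bound on kept channels: max(absolute, relative), capped by the layer width."""
    # Assumption: the relative bound is rounded up to a whole channel.
    relative_bound = math.ceil(current_layer_channels * relative_min_channels)
    lower_bound = max(absolute_min_channels, relative_bound)
    # A layer narrower than the bound cannot keep more channels than it has.
    return min(lower_bound, current_layer_channels)


wide = sketch_min_channels_to_leave(256)    # relative bound dominates
medium = sketch_min_channels_to_leave(32)   # absolute bound dominates
narrow = sketch_min_channels_to_leave(8)    # capped by the layer width
print(wide, medium, narrow)
```

With the default arguments, the relative bound takes over for layers wider than 64 channels, and layers with fewer than 16 channels are left untouched.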
- default_channel_search_step(current_layer_channels)[source]¶
Default function to specify search step for channel group by its size.
It returns a step equal to 1 if the channel group has a size less than or equal to 64, and a step equal to 4 otherwise.
- Parameters
current_layer_channels (int) – Number of channels in a layer.
- Returns
Search step for Pruning config generation.
- Return type
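The effect of the search step on the size of the search space can be sketched as follows. The step rule is taken from the description above. The way candidate channel counts are enumerated (counting down from the full width) is an assumption made for illustration, and both function names are hypothetical.

```python
from typing import List


def sketch_channel_search_step(current_layer_channels: int) -> int:
    """Step 1 for groups of at most 64 channels, step 4 for larger groups."""
    return 1 if current_layer_channels <= 64 else 4


def candidate_channel_counts(total: int, min_channels: int, step: int) -> List[int]:
    """Assumed enumeration: count down from the full width in increments of `step`."""
    return list(range(total, min_channels - 1, -step))


total, min_channels = 128, 32  # hypothetical group width and lower bound
step = sketch_channel_search_step(total)
coarse = candidate_channel_counts(total, min_channels, step)
fine = candidate_channel_counts(total, min_channels, 1)
print(step, len(coarse), len(fine))
```

For this 128-channel group, a step of 4 leaves 25 candidate widths instead of 97, which keeps the number of configurations explored within n_search_steps more manageable.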