Pruning

The enot.pruning package provides functionality for the automatic pruning of user models.

Supported models

The pruning engine works with many models, including (but not limited to):

  • (from torchvision)
    • classification
      • efficientnet_b0

      • resnet18

      • resnet34

      • resnet50

      • wide_resnet50_2

      • densenet161

      • mobilenet_v2

      • mobilenet_v3_large

    • detection
      • ssd300_vgg16

      • ssdlite320_mobilenet_v3_large

    • segmentation
      • fcn_resnet50

      • fcn_resnet101

      • deeplabv3_resnet50

      • deeplabv3_resnet101

      • deeplabv3_mobilenet_v3_large

      • lraspp_mobilenet_v3_large

  • (from PyTorch Image Models (timm))
    • efficientnet_b0

    • resnet18

    • resnet34

    • resnet50

    • wide_resnet50_2

    • densenet161

    • mobilenetv2_100

    • mobilenetv3_large_100

Introduction

This package provides pre-defined and custom pruning procedures for removing the least important filters or neurons. The enot.pruning package currently supports structured pruning. You can set the pruning ratio (the percentage of channels to remove) manually or use one of the pre-defined strategies.

The first (and the simplest) is the equal pruning strategy, which keeps roughly the same percentage of channels in each prunable layer.

The second (and more compute-intensive) is the optimal pruning strategy, which searches for the optimal pruning ratio for each prunable layer.

A few definitions to simplify reading this documentation:

Pruning ratio is the fraction of channels to remove (prune) in a prunable group or prunable layer.

Gate (or channel gate) is a special torch.nn.Module which gathers and saves channel “local importance”. “Local importance” means that you can compare channels by their importance values within a single layer, but not between distinct layers.

The calibration procedure estimates channel importance values based on the user's model and data.

Channel label is the global channel index in a network.

Pruning config specifies the pruning amount for each prunable layer. See more here.
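For illustration, a pruning config can be a sequence of per-group pruning values; the config below is a hypothetical sketch:

# A hypothetical pruning config for a model with four prunable groups:
# one pruning ratio (fraction of channels to remove) per prunable group.
pruning_config = [0.5, 0.25, 0.0, 0.75]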

Pruning API

The pruning functionality accessible to the user is divided into three sections:

High-level interface

The high-level interface provides two ways to prune a model. They work in the same manner, but apply different policies for the amount of pruning in each layer.

calibrate_and_prune_model_equal(model, dataloader, loss_function, pruning_ratio=0.5, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')[source]

Estimates channel importance values and prunes the model with the equal pruning strategy (the same percentage of channels is pruned in each prunable layer).

Parameters:
  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • pruning_ratio (float, optional) – Relative amount of channels to prune in each prunable layer. Default value is 0.5 (prunes 50% of channels, which typically reduces the number of FLOPs and parameters by a factor of 4).

  • finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs specified in calibration_epochs argument.

  • calibration_epochs (int, optional) – Number of total calibration epochs. Not used when calibration_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in single dataloader batch (dataloader sample). This function should take single input (dataloader sample) and return single integer - the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.

  • entry_point (str, optional) – Model function for execution. See Notes section for the detailed argument description. Default value is “forward”.

Returns:

Pruned model.

Return type:

torch.nn.Module

Notes

Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_equal(...)

entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If such a function is located in the model’s submodule - then you should first write the submodule’s name followed by the function name, all separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.
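A minimal usage sketch; the model, dataloader, and the cross-entropy-based loss_function below are illustrative assumptions, not part of the API:

import torch
from enot.pruning import calibrate_and_prune_model_equal

criterion = torch.nn.CrossEntropyLoss()

def loss_function(model_output, sample):
    # In this sketch the dataloader yields (inputs, labels) pairs.
    _, labels = sample
    return criterion(model_output, labels)

model.eval()  # prepare the model as for inference
pruned_model = calibrate_and_prune_model_equal(
    model=model,
    dataloader=dataloader,
    loss_function=loss_function,
    pruning_ratio=0.5,
    show_tqdm=True,
)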

calibrate_and_prune_model_optimal(model, dataloader, loss_function, latency_calculation_function, target_latency, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, architecture_optimization_strategy=<function default_architecture_optimization_strategy>, n_search_steps=200, entry_point='forward', **kwargs)[source]

Estimates channel importance values, searches for the optimal pruning configuration (percentage of channels to prune in each prunable layer), and prunes the model.

Parameters:
  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function which calculates sample model latency. This function should take a sample pruned model (torch.nn.Module) and measure its execution “speed” (float). It could be the number of FLOPs, MACs, inference speed on CPU/GPU, or any other “speed” criterion.

  • target_latency (float) – Target model latency. This argument should have the same units as latency_calculation_function’s output.

  • finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs specified in calibration_epochs argument.

  • calibration_epochs (int, optional) – Number of total calibration epochs. Not used when calibration_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in single dataloader batch (dataloader sample). This function should take single input (dataloader sample) and return single integer - the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.

  • architecture_optimization_strategy (Callable[[int], Tuple[int, int]]) – Function that constructs an architecture pruning configuration. This function should take the number of channels in a specific layer and return the minimal number of channels to keep and the channel search step. The search step defines the number of possible pruning ratios for a specific group: if the search step is 1, any number of channels after pruning is possible (with a lower bound set by the minimal number of channels); if the search step is 2, the number of options is reduced by a factor of 2. Default value is default_architecture_optimization_strategy().

  • n_search_steps (int, optional) – Number of sampled configurations for pruning (equal to the number of latency_calculation_function executions) to select optimal architecture. Default value is 200.

  • entry_point (str, optional) – Model function for execution. See Notes section for the detailed argument description. Default value is “forward”.

  • **kwargs – Additional keyword arguments for label selector.

Returns:

Pruned model.

Return type:

torch.nn.Module

Notes

Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_optimal(...)

entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If such a function is located in the model’s submodule - then you should first write the submodule’s name followed by the function name, all separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.
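A usage sketch with a simple latency proxy (parameter count; any FLOPs/MACs counter or an on-device benchmark can be substituted). Here, model, dataloader, and loss_function are assumed to be defined as in the equal pruning example above:

import torch
from enot.pruning import calibrate_and_prune_model_optimal

def latency_calculation_function(sample_model: torch.nn.Module) -> float:
    # Parameter count as a crude "latency" measure.
    return float(sum(p.numel() for p in sample_model.parameters()))

baseline_latency = latency_calculation_function(model)
pruned_model = calibrate_and_prune_model_optimal(
    model=model,
    dataloader=dataloader,
    loss_function=loss_function,
    latency_calculation_function=latency_calculation_function,
    target_latency=baseline_latency / 2,  # aim for a 2x smaller model
    n_search_steps=200,
)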

Low-level interface

The main class for pruning is ModelPruningInfo. It contains the information necessary for model pruning: the model execution graph, lists of all prunable and non-prunable channels, all channel importance values, etc.

Model pruning info is obtained through the calibration process. Calibration for pruning is performed by the PruningCalibrator class. An instance of this class should be used as a context manager; inside its context, the user should compute losses and perform backward passes through the model on calibration data.

The user can also calibrate the model by calling the calibrate_model_for_pruning() function and providing all necessary arguments for proper model execution and loss computation.

Actual model pruning is performed by the prune_model() function, which requires a pruning info object and a list of channel labels to prune.

To select which labels should be pruned and which should not, we created a special interface for label selection. The base class for all label selectors is PruningLabelSelector. It defines an abstract method, PruningLabelSelector.select(), which selects the labels to prune; subclasses should implement it to fit our pipelines based on calibrate_and_prune_model().

The abstract class TopKPruningLabelSelector is a straightforward way to select the least important channel labels. It utilises the Pruning config to select the least important channels in each prunable layer: channels are sorted by their importance values, and the top-K least important channels are selected based on the corresponding pruning ratio from the pruning config.

Our label selectors based on this class are the following:

  • UniformPruningLabelSelector

  • OptimalPruningLabelSelector

They are used internally in the high-level pruning interface described above.
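A sketch of the typical low-level flow, assuming model, dataloader, and loss_function are defined as in the high-level examples and that these names are importable from enot.pruning (as in the custom selector example below):

from enot.pruning import (
    UniformPruningLabelSelector,
    calibrate_model_for_pruning,
    prune_model,
)

# 1. Calibration: estimate channel importance values.
pruning_info = calibrate_model_for_pruning(model, dataloader, loss_function)

# 2. Label selection: remove 50% of channels in each prunable group.
label_selector = UniformPruningLabelSelector(pruning_ratio=0.5)
labels_to_prune = label_selector.select(pruning_info)

# 3. Pruning: remove the selected channels.
pruned_model = prune_model(model, pruning_info, labels_to_prune, inplace=False)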

Pruning info class:

class ModelPruningInfo(graph, prunable_groups, unprunable_groups, channel_labels_by_nodes, node_users, gate_container, gate_names_by_locations, prunable_group_gate_names, gate_apply_info, child_labels_by_parents, int_pruners)

Container with model pruning-related information.

property n_prunable_groups: int

Number of prunable channel groups.

Returns:

Number of prunable channel groups.

Return type:

int

summary(level=0)

Returns a string representation of useful information from the pruning parser.

Parameters:

level (int, optional) – Degree of detail. Should be from 0 to 2 inclusive. Default value is 0.

Returns:

String with information about pruning.

Return type:

str
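For example, assuming pruning_info was obtained from calibration:

print(pruning_info.summary(level=1))  # medium degree of detail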

Calibration:

class PruningCalibrator(model, entry_point='forward', max_prunable_labels=4096)

Context manager for gathering pruning information.

This context manager walks through the model, collects information about prunable and unprunable channels, and estimates channel importance values. The main purpose of this calibrator is to create a ModelPruningInfo object which can later be used to prune the model.

Model pruning info can be accessed via PruningCalibrator.pruning_info.

Examples

>>> p_calibrator = PruningCalibrator(
...     model=baseline_model,
... )
>>> with p_calibrator:
...     for inputs, labels in dataloader:
...         model_output = baseline_model(inputs)
...         loss = loss_function(model_output, labels)
...         loss.backward()
>>> pruning_info = p_calibrator.pruning_info
__init__(model, entry_point='forward', max_prunable_labels=4096)
Parameters:
  • model (torch.nn.Module) – Model which user wants to calibrate for pruning.

  • entry_point (str, optional) – Model function for execution. See Notes section for the detailed argument description. Default value is “forward”.

  • max_prunable_labels (int, optional) – Maximum number of labels in a group that can be pruned. If a group contains more labels, it will be marked as non-prunable. Default value is 4096.

Notes

entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If such a function is located in the model’s submodule - then you should first write the submodule’s name followed by the function name, all separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.

property pruning_info: ModelPruningInfo | None

Information about model pruning.

Returns:

Pruning-related information about model. None if calibrator instance was not used as a context manager or model was calibrated incorrectly.

Return type:

ModelPruningInfo or None

calibrate_model_for_pruning(model, dataloader, loss_function, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward', max_prunable_labels=4096)[source]

Estimates model channel importance values for later pruning.

This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning.

Parameters:
  • model (torch.nn.Module) – Model to calibrate pruning gates for.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • n_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs specified in epochs argument.

  • epochs (int, optional) – Number of total calibration epochs. Not used when n_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in single dataloader batch (dataloader sample). This function should take single input (dataloader sample) and return single integer - the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.

  • entry_point (str, optional) – Name of model method for tracing. In this method all prunable operations should be forwarded. Default value is ‘forward’.

  • max_prunable_labels (int, optional) – Maximum number of labels in a group that can be pruned. If a group contains more labels, it will be marked as non-prunable. Default value is 4096.

Returns:

Pruning information for later usage in pruning methods.

Return type:

ModelPruningInfo

Notes

Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_model_for_pruning(...)

Pruning:

prune_model(model, pruning_info, prune_labels, inplace=True)

Removes (prunes) the least important channels, as specified by the prune_labels parameter.

Parameters:
  • model (torch.nn.Module) – Model for pruning.

  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • prune_labels (Sequence of Label) – Sequence of global channel indices to prune.

  • inplace (bool) – Whether to perform pruning in-place (modify the model itself without creating a copy).

Returns:

Pruned model.

Return type:

torch.nn.Module

The following function combines the functionality of the enot.pruning.calibrate_model_for_pruning() and enot.pruning.prune_model() functions:

calibrate_and_prune_model(label_selector, model, dataloader, loss_function, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')[source]

Estimates channel importance values and prunes the model with a user-defined strategy.

This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning. It then prunes the model by removing the channels selected by the user-defined channel selection strategy label_selector.

Parameters:
  • label_selector (PruningLabelSelector) – Channel selector object. This object should implement PruningLabelSelector.select() method which returns list with channel indices to prune.

  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for the number of epochs specified in calibration_epochs argument.

  • calibration_epochs (int, optional) – Number of total calibration epochs. Not used when calibration_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in single dataloader batch (dataloader sample). This function should take single input (dataloader sample) and return single integer - the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.

  • entry_point (str, optional) – Model function for execution. See Notes section for the detailed argument description. Default value is “forward”.

Returns:

Pruned model.

Return type:

torch.nn.Module

Notes

Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model(...)

An example of a user-defined strategy:

from typing import List

import numpy as np

from enot.pruning import ModelPruningInfo, PruningLabelSelector, iterate_over_gate_criteria

class CustomPruningSelector(PruningLabelSelector):

    def __init__(self, pruning_ratio: float):
        self.pruning_ratio: float = pruning_ratio
        super().__init__()

    def select(self, pruning_info: ModelPruningInfo) -> List[int]:
        labels_to_prune: List[int] = []
        # Iterate over pruning gates (channel groups) and their criteria.
        for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
            criteria: np.ndarray = np.array(criteria)
            # Number of channels to remove from this group.
            prune_channels_num = int(len(criteria) * self.pruning_ratio)
            # Labels of the least important channels within the group.
            index_for_pruning = np.argsort(criteria)[:prune_channels_num]
            labels_to_prune += np.array(labels)[index_for_pruning].tolist()
        return labels_to_prune
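Such a selector plugs directly into calibrate_and_prune_model(); for example:

from enot.pruning import calibrate_and_prune_model

pruned_model = calibrate_and_prune_model(
    label_selector=CustomPruningSelector(pruning_ratio=0.5),
    model=model,
    dataloader=dataloader,
    loss_function=loss_function,
)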

entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If such a function is located in the model’s submodule - then you should first write the submodule’s name followed by the function name, all separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.

Label selection:

class PruningLabelSelector[source]

Base class for all label selectors for pruning.

This class defines an abstract method PruningLabelSelector.select() which should return labels to prune.

__init__()
property labels_for_pruning: List[Label]

Labels to prune.

Returns:

List of all channel labels to prune.

Return type:

list of int

Raises:

RuntimeError – If labels were not calculated (PruningLabelSelector.select() method was not called).

abstract select(pruning_info)

Method that chooses which labels should be pruned based on the current label selector policy and saves them in the label selector instance.

Warning

Depending on the label selector implementation, this method may take significant time to execute.

Parameters:

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns:

List of channel labels which should be pruned.

Return type:

list of int

class TopKPruningLabelSelector[source]

Bases: PruningLabelSelector

Base class for label selectors based on pruning config.

For description of pruning configs see here.

This class implements the PruningLabelSelector.select() method by calling the abstract TopKPruningLabelSelector.get_config_for_pruning() method to generate a pruning config, and then selects the top-K least important channels from each group, with K values derived from the pruning config.

__init__()
abstract get_config_for_pruning(pruning_info)

Abstract method which should implement the Pruning config selection strategy.

Parameters:

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns:

Config for pruning.

Return type:

PruningConfig

static get_labels_by_config(pruning_info, pruning_config)

Transforms a Pruning config into channel labels to prune.

Parameters:
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • pruning_config (PruningConfig) – Config for pruning.

Returns:

Channel indices to prune (remove from the model).

Return type:

list of int

property pruning_cfg: Sequence[int] | Sequence[float] | None

Returns current pruning config.

Returns:

pruning_config – Config for pruning.

Return type:

PruningConfig

select(pruning_info)

Method that chooses which labels should be pruned based on the current label selector policy and saves them in the label selector instance.

Warning

Depending on the label selector implementation, this method may take significant time to execute.

Parameters:

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns:

List of channel labels which should be pruned.

Return type:

list of int

class UniformPruningLabelSelector(pruning_ratio)[source]

Bases: TopKPruningLabelSelector

Label selector for uniform (equal) pruning.

Removes an equal percentage of channels in each prunable layer.

__init__(pruning_ratio)
Parameters:

pruning_ratio (float) – Percentage of prunable channels to remove in each group.

get_config_for_pruning(pruning_info)

Abstract method which should implement the Pruning config selection strategy.

Parameters:

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns:

Config for pruning.

Return type:

PruningConfig

class OptimalPruningLabelSelector(model, latency_calculation_function, target_latency, *, architecture_optimization_strategy=<function default_architecture_optimization_strategy>, n_search_steps=200, **kwargs)

Label selector based on estimating the optimal Pruning config through Bayesian optimization.

__init__(model, latency_calculation_function, target_latency, *, architecture_optimization_strategy=<function default_architecture_optimization_strategy>, n_search_steps=200, **kwargs)
Parameters:
  • model (torch.nn.Module) – Model to prune.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function which calculates sample model latency. This function should take a sample pruned model (torch.nn.Module) and measure its execution “speed” (float). It could be the number of FLOPs, MACs, inference speed on CPU/GPU, or any other “speed” criterion.

  • target_latency (float) – Target model latency. This argument should have the same units as latency_calculation_function’s output. The latency value should be larger than the latency of the minimal model constructed according to the architecture_optimization_strategy argument.

  • architecture_optimization_strategy (Callable[[int], Tuple[int, int]]) – Function that constructs an architecture pruning configuration. This function should take the number of channels in a specific layer and return the minimal number of channels to keep and the channel search step. The search step defines the number of possible pruning ratios for a specific group: if the search step is 1, any number of channels after pruning is possible (with a lower bound set by the minimal number of channels); if the search step is 2, the number of options is reduced by a factor of 2. Default value is default_architecture_optimization_strategy().

  • n_search_steps (int, optional) – Number of sampled configurations for pruning (equal to the number of latency_calculation_function executions) to select optimal architecture. Default value is 200.

  • **kwargs – Additional keyword arguments for label selector.

property baseline_latency: float

Baseline model latency.

Returns:

Baseline model latency. Baseline model is the model provided to the constructor.

Return type:

float

get_config_for_pruning(pruning_info)

Abstract method which should implement the Pruning config selection strategy.

Parameters:

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns:

Config for pruning.

Return type:

PruningConfig

static get_labels_by_config(pruning_info, pruning_config)

Transforms a Pruning config into channel labels to prune.

Parameters:
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • pruning_config (PruningConfig) – Config for pruning.

Returns:

Channel indices to prune (remove from the model).

Return type:

list of int

property labels_for_pruning: List[Label]

Labels to prune.

Returns:

List of all channel labels to prune.

Return type:

list of int

Raises:

RuntimeError – If labels were not calculated (PruningLabelSelector.select() method was not called).

property minimal_network_latency: float

Minimal model latency.

Returns:

Minimal model latency. The minimal model is the model constructed according to the architecture_optimization_strategy argument of this class’s constructor.

Return type:

float

property n_search_steps: int

Number of models to evaluate.

Returns:

Number of models to evaluate during search procedure.

Return type:

int

property pruning_cfg: Sequence[int] | Sequence[float] | None

Returns current pruning config.

Returns:

pruning_config – Config for pruning.

Return type:

PruningConfig

select(pruning_info)

Method that chooses which labels should be pruned based on the current label selector policy and saves them in the label selector instance.

Warning

Depending on the label selector implementation, this method may take significant time to execute.

Parameters:

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns:

List of channel labels which should be pruned.

Return type:

list of int
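A construction sketch, reusing the parameter-count latency proxy and the pruning_info object from the examples above:

selector = OptimalPruningLabelSelector(
    model=model,
    latency_calculation_function=latency_calculation_function,
    target_latency=baseline_latency / 2,
)
labels_to_prune = selector.select(pruning_info)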

Utility functions

ModelPruningInfo utilities:

iterate_over_gate_criteria(pruning_info)

Iterates over pruning gates and yields channel labels and channel criteria.

Parameters:

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Yields:
  • group_id (int) – Global index number of a gate.

  • labels (list of Labels) – Channel labels of a gate.

  • criteria (list of float) – Channel criteria of a gate.

Returns:

Generator with items specified above.

Return type:

Generator

get_criteria_label_dict(pruning_info)

Collects information about channel importance values in human-readable format.

Parameters:

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns:

Dictionary with channel labels and importance values. Its keys are channel group names, and its values store both global channel indices (labels) and channel importance values.

Return type:

Dict[str, Dict[str, numpy.ndarray]]

get_labels_for_uniform_pruning(pruning_info, pruning_ratio=0.5)

Returns labels of the least important channels in every prunable layer.

Parameters:
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • pruning_ratio (float, optional) – Relative amount of channels to prune in each prunable layer. Default value is 0.5.

Returns:

List of least important channel labels.

Return type:

list of int
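For instance, to reproduce uniform pruning via the low-level API (assuming pruning_info was obtained from calibration):

from enot.pruning import get_labels_for_uniform_pruning, prune_model

labels = get_labels_for_uniform_pruning(pruning_info, pruning_ratio=0.5)
pruned_model = prune_model(model, pruning_info, labels, inplace=False)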

Pruning config utilities:

check_pruning_config(pruning_config)

Validates pruning config.

Parameters:

pruning_config (PruningConfig) – Config for pruning.

Raises:

ValueError – If pruning config is not valid.

Return type:

None

get_least_important_labels_by_config(pruning_info, pruning_config)

Extracts least important channel labels based on pruning config.

Parameters:
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • pruning_config (PruningConfig) – Config for pruning.

Returns:

Least important labels from all layers.

Return type:

list of Label

OptimalPruningLabelSelector utilities:

default_architecture_optimization_strategy(current_layer_channels, absolute_min_channels=16, relative_min_channels=0.25)[source]

Default function that specifies, from the total number of channels in a layer, the minimal number of channels to keep and the search step (channel stride) for the channel group.

This function sets two lower bounds on the number of channels in each layer: absolute and relative. The absolute bound is a minimal number of channels, and the relative bound is a fraction of the layer’s channel count. The result is the maximum of the two.

For the channel stride, it returns a step of 1 if the channel group has 64 or fewer channels, and a step of 4 otherwise.

Parameters:
  • current_layer_channels (int) – Number of channels in a layer.

  • absolute_min_channels (int, optional) – Absolute lower bound. If it is larger than current_layer_channels, then current_layer_channels is returned. Default value is 16.

  • relative_min_channels (float, optional) – Relative lower bound. Default value is 0.25.

Returns:

  • int – The minimal number of channels to keep out of current_layer_channels.

  • int – Search step for Pruning config generation.

Return type:

Tuple[int, int]
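Following the description above, one would expect, for example:

# 32 channels: keep at least max(16, 0.25 * 32) = 16; group size <= 64, so the step is 1.
default_architecture_optimization_strategy(32)   # -> (16, 1)

# 256 channels: keep at least max(16, 0.25 * 256) = 64; group size > 64, so the step is 4.
default_architecture_optimization_strategy(256)  # -> (64, 4)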