Pruning

The enot.pruning package provides functionality for automatic pruning of user models.

Supported models

The pruning engine works with many models, including (but not limited to):

  • (from torchvision)
    • classification
      • efficientnet_b0

      • resnet18

      • resnet34

      • resnet50

      • wide_resnet50_2

      • densenet161

      • mobilenet_v2

      • mobilenet_v3_large

    • detection
      • ssd300_vgg16

      • ssdlite320_mobilenet_v3_large

    • segmentation
      • fcn_resnet50

      • fcn_resnet101

      • deeplabv3_resnet50

      • deeplabv3_resnet101

      • deeplabv3_mobilenet_v3_large

      • lraspp_mobilenet_v3_large

  • (from PyTorch Image Models (timm))
    • efficientnet_b0

    • resnet18

    • resnet34

    • resnet50

    • wide_resnet50_2

    • densenet161

    • mobilenetv2_100

    • mobilenetv3_large_100

Introduction

This package provides pre-defined and custom pruning procedures that remove the least important filters or neurons. The enot.pruning package currently supports structured pruning. You can set the pruning ratio (the fraction of channels to remove) manually or use one of the pre-defined strategies.

The first (and simplest) is the equal pruning strategy, which keeps roughly the same percentage of channels in each prunable layer.

The second (and considerably more compute-intensive) is the optimal pruning strategy, which searches for the optimal pruning ratio of each prunable layer.

A few definitions that make this documentation easier to read:

Pruning ratio is the fraction of channels to remove (prune) in a prunable group or prunable layer.

Gate (or channel gate) is a special torch.nn.Module which gathers and stores channel “local importance”. “Local importance” means that channels can be compared by their importance values within a single layer, but not across different layers.

Calibration procedure estimates channel importance values for the user’s model on the user’s data.

Channel label is the global index of a channel in the network.

Pruning config specifies the amount of pruning for each prunable layer. See more here.

Pruning API

The pruning functionality available to the user is divided into three sections:

High-level interface

The high-level interface provides two ways to prune a model. Both work the same way, but they use different policies to decide how much to prune in each layer.

calibrate_and_prune_model_equal(model, dataloader, loss_function, pruning_ratio=0.5, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False)

Estimates channel importance values and prunes the model with the equal pruning strategy (the same percentage of channels is pruned in each prunable layer).

Parameters
  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function that takes the model output and a dataloader sample and computes the loss. It should take two inputs and return a scalar PyTorch tensor.

  • pruning_ratio (float, optional) – Fraction of channels to prune in each prunable layer. Default value is 0.5 (prunes 50% of channels, which typically reduces the number of FLOPs and parameters by a factor of 4).

  • finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Total number of calibration steps. Default value is None, which runs calibration over the whole dataloader for the number of epochs specified by the calibration_epochs argument.

  • calibration_epochs (int, optional) – Total number of calibration epochs. Ignored when the calibration_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function that computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). It should take a single input (the dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.
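A minimal sketch of a loss_function that satisfies the contract above, assuming each dataloader sample is an (inputs, labels) tuple and the model is a classifier returning logits. The sample layout and the choice of cross-entropy are assumptions; adapt both to your task.

```python
import torch
import torch.nn.functional as F


def loss_function(model_output: torch.Tensor, sample) -> torch.Tensor:
    # Assumed sample layout: (inputs, labels), as yielded by a typical
    # classification DataLoader. Only the labels are needed here.
    _, labels = sample
    # cross_entropy averages over the batch by default, so the result
    # is a scalar (0-dimensional) tensor.
    return F.cross_entropy(model_output, labels)
```

The same function can be passed as loss_function to every calibration entry point in this package, since they all share this signature.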

Returns

Pruned model.

Return type

torch.nn.Module
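The factor of 4 quoted for the default pruning_ratio=0.5 comes from the fact that a convolution’s weight count (and its multiply-accumulate count) is proportional to the product of its input and output channels. A small arithmetic sketch for one bias-free 3×3 convolution; the layer sizes are arbitrary examples. Real networks usually deviate somewhat from 4×, because layers such as the input convolution or the classifier output cannot have both sides halved.

```python
def conv2d_cost(c_in: int, c_out: int, k: int, h_out: int, w_out: int):
    """Return (weights, MACs) of a bias-free k x k Conv2d with groups=1."""
    weights = c_out * c_in * k * k
    # Every output element needs c_in * k * k multiply-accumulates.
    macs = weights * h_out * w_out
    return weights, macs


# 256 -> 256 channels before pruning; 128 -> 128 after removing 50%
# of the channels on both the input and the output side.
before = conv2d_cost(256, 256, 3, 56, 56)
after = conv2d_cost(128, 128, 3, 56, 56)
print(before[0] / after[0], before[1] / after[1])  # 4.0 4.0
```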

Notes

Before calling this function, prepare your model to be as close to its practical inference usage as possible. For example, it is your responsibility to call your model’s eval method if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during the backward call, you can wrap this function call in the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_equal(...)

calibrate_and_prune_model_optimal(model, dataloader, loss_function, latency_calculation_function, target_latency, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, min_channels_to_leave_fn=<function default_min_channels_to_leave>, channel_search_step_fn=<function default_channel_search_step>, n_search_steps=200, **kwargs)

Estimates channel importance values, searches for the optimal pruning configuration (percentage of channels to prune in each prunable layer), and prunes the model.

Parameters
  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function that takes the model output and a dataloader sample and computes the loss. It should take two inputs and return a scalar PyTorch tensor.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates the latency of a candidate model. It should take a candidate pruned model (torch.nn.Module) and return its execution “speed” as a float. This could be the number of FLOPs or MACs, inference time on CPU/GPU, or another “speed” criterion.

  • target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.

  • finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Total number of calibration steps. Default value is None, which runs calibration over the whole dataloader for the number of epochs specified by the calibration_epochs argument.

  • calibration_epochs (int, optional) – Total number of calibration epochs. Ignored when the calibration_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function that computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). It should take a single input (the dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to the model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log the calibration procedure with a tqdm progress bar. Default value is False.

  • min_channels_to_leave_fn (Callable[[int], int], optional) – Function to construct the minimal pruning configuration. It should take the number of channels in a specific layer and return the minimal number of channels to keep. Default value is default_min_channels_to_leave().

  • channel_search_step_fn (Callable[[int], int], optional) – Function to select the search step for each group of channels. The search step determines how many pruning ratios are possible for a specific group. With a search step of 1, every number of remaining channels is possible (with a lower bound set by min_channels_to_leave_fn). With a search step of 2, the number of options is halved. It should take the number of channels in a specific layer and return the search step. Default value is default_channel_search_step().

  • n_search_steps (int, optional) – Number of pruning configurations to sample (equal to the number of latency_calculation_function executions) when selecting the optimal architecture. Default value is 200.

  • **kwargs – Additional keyword arguments for the label selector.
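Any callable mapping a model to a float can serve as latency_calculation_function. A minimal sketch of one, assuming a fixed input resolution and counting only the multiply-accumulates (MACs) of Conv2d and Linear layers through PyTorch forward hooks. The count_macs name and the default (1, 3, 224, 224) input shape are hypothetical, and MACs are only a proxy for real on-device speed.

```python
import torch
from torch import nn


def count_macs(model: nn.Module, input_shape=(1, 3, 224, 224)) -> float:
    """Hypothetical latency proxy: MACs of Conv2d and Linear layers.

    Runs one forward pass on a dummy input of ``input_shape`` (batch size 1
    assumed) and sums the MACs reported by forward hooks. Other layer types
    are ignored.
    """
    macs = 0

    def conv_hook(module: nn.Conv2d, inputs, output: torch.Tensor) -> None:
        nonlocal macs
        kh, kw = module.kernel_size
        # Each output element needs (in_channels / groups) * kh * kw MACs.
        macs += output.numel() * (module.in_channels // module.groups) * kh * kw

    def linear_hook(module: nn.Linear, inputs, output: torch.Tensor) -> None:
        nonlocal macs
        # Each output element needs in_features MACs.
        macs += output.numel() * module.in_features

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))

    was_training = model.training
    device = next(model.parameters()).device
    model.eval()
    try:
        with torch.no_grad():
            model(torch.zeros(input_shape, device=device))
    finally:
        # Remove the hooks and restore the original train/eval mode.
        for handle in handles:
            handle.remove()
        model.train(was_training)
    return float(macs)
```

With this sketch, one could pass latency_calculation_function=count_macs together with, for example, target_latency=0.5 * count_macs(model) to ask for a model with roughly half of the baseline MACs; both values are then in the same units, as required.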

Returns

Pruned model.

Return type

torch.nn.Module

Notes

Before calling this function, prepare your model to be as close to its practical inference usage as possible. For example, it is your responsibility to call your model’s eval method if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during the backward call, you can wrap this function call in the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_optimal(...)

Low-level interface

The main class for pruning is ModelPruningInfo. It contains the information necessary for model pruning: the model execution graph, the lists of all prunable and non-prunable channels, all channel importance values, etc.

Model pruning info is obtained through the calibration process, which is performed by the EnotPruningCalibrator class. An instance of this class should be used as a context manager: inside its context, compute losses and perform backward passes through your model on calibration data.

Alternatively, you can calibrate your model by calling the calibrate_model_for_pruning() function and providing all arguments necessary for proper model execution and loss computation.

The actual pruning is performed by the prune_model() function, which requires a pruning info object and a list of channel labels to prune.

To decide which labels should be pruned and which should be kept, we provide a dedicated label selection interface. The base class for all label selectors is PruningLabelSelector. It defines the abstract method PruningLabelSelector.select(), which selects labels to prune; subclasses implement it to plug into our pipelines based on calibrate_and_prune_model().

The abstract class TopKPruningLabelSelector is a straightforward way to select the least important channel labels. It uses a Pruning config to select the least important channels in each prunable layer: channels are sorted by their importance values, and the K least important ones are selected, with K derived from the layer’s pruning ratio in the pruning config.

Our label selectors based on this class are the following:

  • UniformPruningLabelSelector

  • OptimalPruningLabelSelector

They are used internally by the high-level pruning interface described above.
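The top-K selection described above can be sketched in plain Python. This is an illustration under stated assumptions, not ENOT’s implementation: the select_least_important name is hypothetical, each prunable group is given as a (labels, importance values) pair, and the pruning config is assumed to hold one pruning ratio per group, with the resulting channel count rounded down.

```python
from typing import List, Sequence, Tuple


def select_least_important(
    groups: Sequence[Tuple[Sequence[int], Sequence[float]]],
    ratios: Sequence[float],
) -> List[int]:
    """Hypothetical helper: pick the least important labels group by group.

    groups[i] holds (global channel labels, importance values) of prunable
    group i; ratios[i] is the pruning ratio for that group.
    """
    if len(groups) != len(ratios):
        raise ValueError("expected one pruning ratio per prunable group")
    to_prune: List[int] = []
    for (labels, importance), ratio in zip(groups, ratios):
        k = int(len(labels) * ratio)  # channels to remove, rounded down
        # Positions within the group, sorted from least to most important.
        order = sorted(range(len(labels)), key=lambda i: importance[i])
        # Importance is only compared inside the group ("local importance").
        to_prune.extend(labels[i] for i in order[:k])
    return to_prune
```

Because importance values are compared only within each group, channels from different layers never compete with each other; only the per-group ratios decide how many channels each layer loses.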

Pruning info class:

class ModelPruningInfo(graph, prunable_groups, unprunable_groups, labels_by_nodes, node_users, prunable_module_names_by_nodes, gate_container, gate_names_by_nodes, prunable_group_gate_names, ops_and_gate_names)

Container with model pruning-related information.

property n_prunable_groups: int

Number of prunable channel groups.

Returns

Number of prunable channel groups.

Return type

int

summary(level=0)

Returns string representation of useful information from pruning parser.

Parameters

level (int, optional) – Degree of detail. Should be from 0 to 2 inclusive. Default value is 0.

Returns

String with information about pruning.

Return type

str

Calibration:

class EnotPruningCalibrator(model, entry_point='forward')

Context manager for gathering pruning information.

This context manager walks through the model, collects information about prunable and unprunable channels, and estimates channel importance values. The main purpose of this calibrator is to create a ModelPruningInfo object, which can later be used to prune the model.

Model pruning info can be accessed via EnotPruningCalibrator.pruning_info.

Examples

>>> p_calibrator = EnotPruningCalibrator(
...     model=baseline_model,
... )
>>> with p_calibrator:
...     for inputs, labels in dataloader:
...         model_output = baseline_model(inputs)
...         loss = loss_function(model_output, labels)
...         loss.backward()
>>> p_calibrator.pruning_info

__init__(model, entry_point='forward')
Parameters
  • model (torch.nn.Module) – Model which user wants to calibrate for pruning.

  • entry_point (str, optional) – Model function for execution. See Notes section for the detailed argument description. Default value is “forward”.

Notes

entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: “forward”, “execute_model”, “forward_train”. If the function belongs to one of the model’s submodules, write the submodule’s name first, followed by the function name, all separated by dots: “submodule.forward”, “head.predict_features”, “submodule1.submodule2.forward”.
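To make the naming convention concrete, here is a hypothetical model whose prediction method lives in a head submodule, plus a small resolve_entry_point helper that follows the dotted string with getattr. The model, the helper, and its name are assumptions for illustration only; the helper shows which attribute each entry_point string refers to and is not ENOT’s tracing code.

```python
import functools

import torch
from torch import nn


class Head(nn.Module):
    """Hypothetical head whose method we want to trace."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def predict_features(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.head = Head()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head.predict_features(x)


def resolve_entry_point(model: nn.Module, entry_point: str):
    """Follow a dotted entry_point string from the model to a bound method.

    "head.predict_features" -> getattr(getattr(model, "head"), "predict_features")
    """
    return functools.reduce(getattr, entry_point.split("."), model)
```

For such a model, entry_point="head.predict_features" designates Head.predict_features, while the default "forward" designates Model.forward.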

property pruning_info: Optional[enot.pruning.pruning_info.ModelPruningInfo]

Information about model pruning.

Returns

Pruning-related information about the model. None if the calibrator instance was not used as a context manager or the model was calibrated incorrectly.

Return type

ModelPruningInfo or None

calibrate_model_for_pruning(model, dataloader, loss_function, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False)

Estimates model channel importance values for later pruning.

This function searches for prunable channels in a user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning.

Parameters
  • model (torch.nn.Module) – Model to calibrate pruning gates for.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function that takes the model output and a dataloader sample and computes the loss. It should take two inputs and return a scalar PyTorch tensor.

  • n_steps (int or None, optional) – Total number of calibration steps. Default value is None, which runs calibration over the whole dataloader for the number of epochs specified by the epochs argument.

  • epochs (int, optional) – Total number of calibration epochs. Ignored when the n_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function that computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). It should take a single input (the dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to the model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log the calibration procedure with a tqdm progress bar. Default value is False.

Returns

Pruning information for later usage in pruning methods.

Return type

ModelPruningInfo
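If your dataloader yields samples in a format the default sample_to_n_samples may not handle, for example dictionaries, you can pass your own function. A minimal sketch, assuming samples of the form {"image": batch, "label": batch}; the key names and the function name are hypothetical.

```python
from typing import Any, Mapping


def dict_sample_to_n_samples(sample: Mapping[str, Any]) -> int:
    """Hypothetical sample_to_n_samples for samples like {"image": ..., "label": ...}."""
    # The number of instances is the size of the batch (first) dimension
    # of the inputs; len() returns exactly that for tensors, arrays and lists.
    return len(sample["image"])
```

It would be passed as sample_to_n_samples=dict_sample_to_n_samples. Mapping such samples to model inputs is controlled separately by sample_to_model_inputs.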

Notes

Before calling this function, prepare your model to be as close to its practical inference usage as possible. For example, it is your responsibility to call your model’s eval method if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during the backward call, you can wrap this function call in the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_model_for_pruning(...)

Pruning:

prune_model(model, pruning_info, prune_labels, inplace=True)

Removes (prunes) the channels specified by the prune_labels parameter (typically the least important ones).

Parameters
  • model (torch.nn.Module) – Model for pruning.

  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • prune_labels (Sequence of int) – Sequence of global channel indices (labels) to prune.

  • inplace (bool, optional) – Whether to perform pruning in-place (modify the model itself without creating a copy). Default value is True.

Returns

Pruned model.

Return type

torch.nn.Module

The following function combines the functionality of the enot.pruning.calibrate_model_for_pruning() and enot.pruning.prune_model() functions:

calibrate_and_prune_model(label_selector, model, dataloader, loss_function, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False)

Estimates channel importance values and prunes the model with a user-defined strategy.

This function searches for prunable channels in a user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning. It then prunes the model by removing the channels chosen by the user-defined channel selection strategy label_selector.

Parameters
  • label_selector (PruningLabelSelector) – Label selector object. It should implement the PruningLabelSelector.select() method, which returns a list of channel indices to prune.

  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function that takes the model output and a dataloader sample and computes the loss. It should take two inputs and return a scalar PyTorch tensor.

  • finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Total number of calibration steps. Default value is None, which runs calibration over the whole dataloader for the number of epochs specified by the calibration_epochs argument.

  • calibration_epochs (int, optional) – Total number of calibration epochs. Ignored when the calibration_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function that computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). It should take a single input (the dataloader sample) and return a single integer: the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to the model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log the calibration procedure with a tqdm progress bar. Default value is False.

Returns

Pruned model.

Return type

torch.nn.Module

Notes

Before calling this function, prepare your model to be as close to its practical inference usage as possible. For example, it is your responsibility to call your model’s eval method if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during the backward call, you can wrap this function call in the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model(...)

An example of a user-defined strategy:

import numpy as np
from typing import List
from enot.pruning import ModelPruningInfo, PruningLabelSelector, iterate_over_gate_criteria

class CustomPruningSelector(PruningLabelSelector):

    def __init__(self, pruning_ratio: float):
        # Fraction of channels to remove in every prunable group.
        self.pruning_ratio: float = pruning_ratio
        super().__init__()

    def select(self, pruning_info: ModelPruningInfo) -> List[int]:
        labels_to_prune: List[int] = []
        for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
            importance: np.ndarray = np.asarray(criteria)
            # Number of channels to remove in this group (rounded down).
            prune_channels_num = int(len(importance) * self.pruning_ratio)
            # Positions of the least important channels within the group.
            index_for_pruning = np.argsort(importance)[:prune_channels_num]
            # Map group positions back to global channel labels.
            labels_to_prune += np.array(labels)[index_for_pruning].tolist()
        return labels_to_prune

Label selection:

class PruningLabelSelector

Base class for all label selectors for pruning.

This class defines an abstract method PruningLabelSelector.select() which should return labels to prune.

__init__()

property labels_for_pruning: List[int]

Labels to prune.

Returns

List of all channel labels to prune.

Return type

list of int

Raises

RuntimeError – If labels were not calculated (PruningLabelSelector.select() method was not called).

abstract select(pruning_info)

Method that chooses which labels should be pruned based on the current label selector policy and saves them in the label selector instance.

Warning

Depending on label selector implementation, this function may have significant execution time.

Parameters

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns

List of channel labels which should be pruned.

Return type

list of int

class TopKPruningLabelSelector

Bases: enot.pruning.label_selector.PruningLabelSelector

Base class for label selectors based on a pruning config.

For a description of pruning configs, see here.

This class implements the PruningLabelSelector.select() method: it calls the abstract method TopKPruningLabelSelector.get_config_for_pruning() to generate a pruning config, then selects the top-k least important channels from each group, with k values based on the pruning config.

__init__()

abstract get_config_for_pruning(pruning_info)

Abstract method which should implement the Pruning config selection strategy.

Parameters

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns

Config for pruning.

Return type

PruningConfig

static get_labels_by_config(pruning_info, pruning_config)

Transforms Pruning config into channel labels to prune.

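Parameters
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • pruning_config (PruningConfig) – Config for pruning.

Returns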

Channel indices to prune (remove from the model).

Return type

list of int

property pruning_cfg: Optional[Union[Sequence[int], Sequence[float]]]

Returns current pruning config.

Returns

pruning_config – Config for pruning.

Return type

PruningConfig

select(pruning_info)

Method that chooses which labels should be pruned based on the current label selector policy and saves them in the label selector instance.

Warning

Depending on label selector implementation, this function may have significant execution time.

Parameters

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns

List of channel labels which should be pruned.

Return type

list of int

class UniformPruningLabelSelector(pruning_ratio)

Bases: enot.pruning.label_selector.TopKPruningLabelSelector

Label selector for uniform (equal) pruning.

Removes an equal percentage of channels in each prunable layer.

__init__(pruning_ratio)
Parameters

pruning_ratio (float) – Fraction of prunable channels to remove in each group (for example, 0.5 removes 50% of the channels).

get_config_for_pruning(pruning_info)

Abstract method which should implement the Pruning config selection strategy.

Parameters

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns

Config for pruning.

Return type

PruningConfig

class OptimalPruningLabelSelector(model, latency_calculation_function, target_latency, *, min_channels_to_leave_fn=<function default_min_channels_to_leave>, channel_search_step_fn=<function default_channel_search_step>, n_search_steps=200, **kwargs)

Label selector that estimates the optimal Pruning config through Bayesian optimization.

__init__(model, latency_calculation_function, target_latency, *, min_channels_to_leave_fn=<function default_min_channels_to_leave>, channel_search_step_fn=<function default_channel_search_step>, n_search_steps=200, **kwargs)
Parameters
  • model (torch.nn.Module) – Model to prune.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates the latency of a candidate model. It should take a candidate pruned model (torch.nn.Module) and return its execution “speed” as a float. This could be the number of FLOPs or MACs, inference time on CPU/GPU, or another “speed” criterion.

  • target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function. The latency value should be larger than the latency of the minimal model constructed according to the min_channels_to_leave_fn argument.

  • min_channels_to_leave_fn (Callable[[int], int], optional) – Function to construct the minimal pruning configuration. It should take the number of channels in a specific layer and return the minimal number of channels to keep. Default value is default_min_channels_to_leave().

  • channel_search_step_fn (Callable[[int], int], optional) – Function to select the search step for each group of channels. The search step determines how many pruning ratios are possible for a specific group. With a search step of 1, every number of remaining channels is possible (with a lower bound set by min_channels_to_leave_fn). With a search step of 2, the number of options is halved. It should take the number of channels in a specific layer and return the search step. Default value is default_channel_search_step().

  • n_search_steps (int, optional) – Number of sampled configurations for pruning (equal to the number of latency_calculation_function executions) to select optimal architecture. Default value is 200.

  • **kwargs – Additional keyword arguments for label selector.

property baseline_latency: float

Baseline model latency.

Returns

Baseline model latency. Baseline model is the model provided to the constructor.

Return type

float

get_config_for_pruning(pruning_info)

Abstract method which should implement the Pruning config selection strategy.

Parameters

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns

Config for pruning.

Return type

PruningConfig

static get_labels_by_config(pruning_info, pruning_config)

Transforms Pruning config into channel labels to prune.

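Parameters
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • pruning_config (PruningConfig) – Config for pruning.

Returns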

Channel indices to prune (remove from the model).

Return type

list of int

property labels_for_pruning: List[int]

Labels to prune.

Returns

List of all channel labels to prune.

Return type

list of int

Raises

RuntimeError – If labels were not calculated (PruningLabelSelector.select() method was not called).

property minimal_network_latency: float

Minimal model latency.

Returns

Minimal model latency. The minimal model is the model constructed according to the min_channels_to_leave_fn argument of this class’s constructor.

Return type

float

property n_search_steps: int

Number of models to evaluate.

Returns

Number of models to evaluate during search procedure.

Return type

int

property pruning_cfg: Optional[Union[Sequence[int], Sequence[float]]]

Returns current pruning config.

Returns

pruning_config – Config for pruning.

Return type

PruningConfig

select(pruning_info)

Method that chooses which labels should be pruned based on the current label selector policy and saves them in the label selector instance.

Warning

Depending on label selector implementation, this function may have significant execution time.

Parameters

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns

List of channel labels which should be pruned.

Return type

list of int

Utility functions

ModelPruningInfo utilities:

iterate_over_gate_criteria(pruning_info)

Iterates over pruning gates and yields channel labels and channel criteria.

Parameters

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Yields
  • group_id (int) – Global index number of a gate.

  • labels (list of int) – Channel labels of a gate.

  • criteria (list of float) – Channel criteria of a gate.

Returns

Generator with items specified above.

Return type

Generator

get_criteria_label_dict(pruning_info)

Collects information about channel importance values in human-readable format.

Parameters

pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns

Dictionary with channel labels and importance values. Its keys are channel group names, and its values store both global channel indices (labels) and channel importance values.

Return type

Dict[str, Dict[str, numpy.ndarray]]

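get_labels_for_uniform_pruning(pruning_info, pruning_ratio=0.5)

Returns labels of the least important channels in every prunable layer.

Parameters
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • pruning_ratio (float, optional) – Fraction of channels to prune in each prunable layer. Default value is 0.5.

Returns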

List of least important channel labels.

Return type

list of int

Pruning config utilities:

check_pruning_config(pruning_config)

Validates pruning config.

Parameters

pruning_config (PruningConfig) – Config for pruning.

Raises

ValueError – If pruning config is not valid.

Return type

None

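get_least_important_labels_by_config(pruning_info, pruning_config)

Extracts the least important channel labels based on a pruning config.

Parameters
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • pruning_config (PruningConfig) – Config for pruning.

Returns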

Least important labels from all layers.

Return type

list of int

OptimalPruningLabelSelector utilities:

default_min_channels_to_leave(current_layer_channels, absolute_min_channels=16, relative_min_channels=0.25)

Default function that determines the minimal number of channels to keep, given the total number of channels in a layer.

This function sets two lower bounds on the number of channels in each layer: absolute and relative. The absolute bound is a fixed minimal number of channels, and the relative bound is a fraction of the number of channels in the layer. The result is the larger of these two bounds.

Parameters
  • current_layer_channels (int) – Number of channels in a layer.

  • absolute_min_channels (int, optional) – Absolute lower bound. If it is larger than current_layer_channels, then current_layer_channels is returned. Default value is 16.

  • relative_min_channels (float, optional) – Relative lower bound. Default value is 0.25.

Returns

The minimal number of channels to keep among current_layer_channels.

Return type

int
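The rule described above can be sketched as follows. This reproduces only the documented behavior (the larger of the two bounds, capped at the layer width); rounding the relative bound down is an assumption, and the function name is hypothetical so as not to be confused with the library function.

```python
def min_channels_to_leave(
    current_layer_channels: int,
    absolute_min_channels: int = 16,
    relative_min_channels: float = 0.25,
) -> int:
    """Sketch of the documented rule; rounding the relative bound down is an assumption."""
    relative_bound = int(current_layer_channels * relative_min_channels)
    # Take the larger of the two lower bounds ...
    bound = max(absolute_min_channels, relative_bound)
    # ... but never ask to keep more channels than the layer has.
    return min(current_layer_channels, bound)


print(min_channels_to_leave(256), min_channels_to_leave(32), min_channels_to_leave(8))  # 64 16 8
```

With the defaults, wide layers are limited by the relative bound (a quarter of their channels), narrow layers by the absolute bound of 16, and layers with fewer than 16 channels are left unpruned. To change the bounds for calibrate_and_prune_model_optimal, a functools.partial of default_min_channels_to_leave with different keyword arguments fits the Callable[[int], int] signature expected by min_channels_to_leave_fn.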

default_channel_search_step(current_layer_channels)

Default function that selects the search step for a channel group based on its size.

It returns a step of 1 if the channel group has 64 or fewer channels, and a step of 4 otherwise.

Parameters

current_layer_channels (int) – Number of channels in a layer.

Returns

Search step for Pruning config generation.

Return type

int
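To see how the search step shrinks the search space, the sketch below re-implements the documented default step and enumerates candidate channel counts for one group. The enumeration rule (start at the full width and go down by the step, staying at or above the minimal channel count) is an assumption made for illustration and is not taken from ENOT’s search code; both function names are hypothetical.

```python
from typing import List


def channel_search_step(current_layer_channels: int) -> int:
    # Mirrors the documented default: step 1 up to 64 channels, step 4 above.
    return 1 if current_layer_channels <= 64 else 4


def candidate_channel_counts(current_layer_channels: int, min_channels: int) -> List[int]:
    """Hypothetical enumeration of channel counts a search could consider.

    Assumption: candidates start at the full layer width and go down by the
    search step while staying at or above min_channels.
    """
    step = channel_search_step(current_layer_channels)
    return list(range(current_layer_channels, min_channels - 1, -step))


print(len(candidate_channel_counts(256, 64)), len(candidate_channel_counts(32, 16)))  # 49 17
```

Under this assumption, a 256-channel group that must keep at least 64 channels has 49 candidate widths instead of 193 with a step of 1, while small groups keep full resolution.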