Pruning

The enot.pruning module contains functionality for automatic pruning of user models.

The pruning engine works with models from many frameworks, including Torchvision, PyTorch Image Models (TIMM), OpenMMLab, and others.

This package features pre-defined and custom pruning procedures for removing the least important filters or neurons. The ENOT pruning module currently supports structured pruning. You can define the pruning ratio (the percentage of channels to remove) manually or use one of the pre-defined strategies.

The first (and the simplest) is the equal pruning strategy, which keeps roughly the same percentage of channels in each prunable layer.

The second (and more compute-intensive) is the optimal pruning strategy, which searches for the optimal pruning ratio for each prunable layer.

A few definitions to simplify reading this documentation:

  • Pruning ratio is the fraction of channels to remove (prune) in a prunable group or prunable layer.

  • Gate (or channel gate) is a special torch.nn.Module which gathers and saves channel “local importance”. “Local importance” means that you can compare channels by their importance values within a single layer, but not between distinct layers.

  • Calibration is the procedure that estimates channel importance values based on the user's model and data.

  • Channel label is the global channel index in a network.

  • Pruning config specifies pruning amount for each prunable layer. See more here.

Pruning API

The pruning functionality accessible to the user is divided into the following sections:

High-level interface

The high-level interface provides two ways to prune a model. They work in the same manner but apply different policies for the amount of pruning in each layer:

calibrate_and_prune_model_equal(model, dataloader, loss_function, pruning_ratio=0.5, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')[source]

Estimates channel importance values and prunes the model with the equal pruning strategy (the same percentage of channels is pruned in each prunable layer).

Parameters:
  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • pruning_ratio (float, optional) – Relative amount of channels to prune in each prunable layer. Default value is 0.5 (prunes 50% of channels, which typically reduces the number of FLOPs and parameters by a factor of 4).

  • finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified in calibration_epochs argument.

  • calibration_epochs (int, optional) – Number of total calibration epochs. Not used when calibration_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer - the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.

  • entry_point (str, optional) – Model function for execution. See notes below for the detailed argument description. Default value is “forward”.

Returns:

Pruned model.

Return type:

torch.nn.Module

Note

Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_equal(...)

entry_point is a string that specifies which function of the user's model will be traced. Simple examples: forward, execute_model, forward_train. If such a function is located in a submodule of the model, write the submodule's name followed by the function name, separated by dots: submodule.forward, head.predict_features, submodule1.submodule2.forward.
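
A minimal usage sketch (hedged: model, val_dataloader, and the loss wrapper below are placeholders for your own objects, assuming a classification dataloader that yields (images, labels) pairs):

import torch

from enot.pruning import calibrate_and_prune_model_equal

criterion = torch.nn.CrossEntropyLoss()

def loss_function(model_output, sample):
    # `sample` is a (images, labels) pair from the dataloader.
    _, labels = sample
    return criterion(model_output, labels)

model.eval()  # prepare the model as close to inference usage as possible
pruned_model = calibrate_and_prune_model_equal(
    model=model,
    dataloader=val_dataloader,
    loss_function=loss_function,
    pruning_ratio=0.5,
    show_tqdm=True,
)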

calibrate_and_prune_model_optimal(model, dataloader, loss_function, latency_calculation_function, target_latency, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, channels_selection_constraint=None, n_search_steps=200, entry_point='forward', **kwargs)[source]

Estimates channel importance values, searches for the optimal pruning configuration (percentage of channels to prune in each prunable layer) and prunes model.

Parameters:
  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be a number of FLOPs, MACs, inference time on CPU/GPU or other “speed” criteria.

  • target_latency (float) – Target model latency. This argument should have the same units as latency_calculation_function’s output.

  • finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified in calibration_epochs argument.

  • calibration_epochs (int, optional) – Number of total calibration epochs. Not used when calibration_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer - the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.

  • channels_selection_constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance, that calculates low, high and step for the channels constraint for total number of channels in gate. If None, then DefaultChannelsSelectorConstraint will be used. None by default.

  • n_search_steps (int, optional) – Number of sampled configurations for pruning to select optimal architecture. Default value is 200.

  • entry_point (str, optional) – Model function for execution. See notes in calibrate_and_prune_model_equal() for the detailed argument description. Default value is “forward”.

  • **kwargs – Additional keyword arguments for label selector.

Returns:

Pruned model.

Return type:

torch.nn.Module
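
A hedged sketch using the total parameter count as the "speed" criterion (any of the measures listed above would do; model, val_dataloader, and loss_function are your own objects, as in the previous example):

import torch

from enot.pruning import calibrate_and_prune_model_optimal

def parameter_count(model: torch.nn.Module) -> float:
    # A cheap "speed" proxy: the total number of model parameters.
    return float(sum(p.numel() for p in model.parameters()))

baseline = parameter_count(model)
pruned_model = calibrate_and_prune_model_optimal(
    model=model,
    dataloader=val_dataloader,
    loss_function=loss_function,
    latency_calculation_function=parameter_count,
    target_latency=baseline / 2,  # aim for a model with half the parameters
    n_search_steps=200,
)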

Low-level interface

The low-level interface consists of the following components:

The main class for pruning is ModelPruningInfo. It contains all information necessary for model pruning: the model execution graph, the list of all prunable and non-prunable channels, all channel importance values, etc.

Model pruning info is obtained through the calibration process, which is performed by the PruningCalibrator class. An instance of this class should be used as a context manager. Inside its context, the user should compute losses and perform backward passes through the model on calibration data.

Alternatively, you can calibrate your model by calling the calibrate_model_for_pruning() function and providing all arguments necessary for proper model execution and loss computation.

The actual model pruning is performed by the prune_model() function, which requires a pruning info object and a list of channel labels to prune.

To select which labels should be pruned and which should not, we created a special interface for label selection. The base class for all label selectors is PruningLabelSelector. This class defines an abstract method, PruningLabelSelector.select(), which should select the labels to prune; subclasses should implement it to fit our pipelines based on calibrate_and_prune_model().

The abstract class ScorePruningLabelSelector is a straightforward way to select the least important channel labels. It uses a pruning config to select the least important channels in each prunable layer. This is done by sorting channels by their importance values and selecting the top-K least important channels based on the corresponding pruning ratio from the pruning config.

These two label selectors are used internally in the high-level pruning interface described above.
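
As a rough sketch, the low-level pipeline ties these pieces together as follows (model, val_dataloader, and loss_function are placeholders for your own objects; the imports assume these names are exported from enot.pruning, as in the examples further below):

from enot.pruning import UniformPruningLabelSelector
from enot.pruning import calibrate_model_for_pruning
from enot.pruning import prune_model

# 1. Calibration: estimate channel importance values.
pruning_info = calibrate_model_for_pruning(
    model=model,
    dataloader=val_dataloader,
    loss_function=loss_function,
)

# 2. Label selection: pick the least important channel labels.
label_selector = UniformPruningLabelSelector(pruning_ratio=0.5)
labels_to_prune = label_selector.select(model=model, pruning_info=pruning_info)

# 3. Pruning: remove the selected channels from the model.
pruned_model = prune_model(model, pruning_info=pruning_info, prune_labels=labels_to_prune)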

Pruning info

class ModelPruningInfo(graph, prunable_groups, unprunable_groups, channel_labels_by_nodes, node_users, gate_container, gate_names_by_locations, prunable_group_gate_names, gate_apply_info, child_labels_by_parents, int_pruners)

Container with model pruning-related information.

property n_prunable_groups: int

Number of prunable channel groups.

Returns:

Number of prunable channel groups.

Return type:

int

summary(level=0)

Returns a string representation of useful information from the pruning parser.

Parameters:

level (int, optional) – Degree of detail. Should be from 0 to 2 inclusive. Default value is 0.

Returns:

String with information about pruning.

Return type:

str
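
For instance, after calibration you can inspect the collected information like this (pruning_info is a hypothetical ModelPruningInfo instance obtained from a calibrator):

# Print an overview of the collected pruning information; level=2 gives the most detail.
print(pruning_info.summary(level=1))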

Calibration

class PruningCalibrator(model, entry_point='forward', max_prunable_labels=4096)

Context manager for gathering pruning information.

This context manager walks through the model, collects information about prunable and unprunable channels, and estimates channel importance values. The main purpose of this calibrator is to create a ModelPruningInfo object which can later be used to prune the model.

Model pruning info can be accessed via PruningCalibrator.pruning_info.

Examples

>>> p_calibrator = PruningCalibrator(model=baseline_model)
>>> with p_calibrator:
...     for inputs, labels in dataloader:
...         model_output = baseline_model(inputs)
...         loss = loss_function(model_output, labels)
...         loss.backward()
>>> p_calibrator.pruning_info
__init__(model, entry_point='forward', max_prunable_labels=4096)
Parameters:
  • model (torch.nn.Module) – Model which user wants to calibrate for pruning.

  • entry_point (str, optional) – Model function for execution. See note below for the detailed argument description. Default value is “forward”.

  • max_prunable_labels (int, optional) – Maximum number of labels in a group which can be pruned. If a group contains more labels, it will be marked as non-prunable. Default value is 4096.

property pruning_info: ModelPruningInfo | None

Information about model pruning.

Returns:

Pruning-related information about model. None if calibrator instance was not used as a context manager or model was calibrated incorrectly.

Return type:

ModelPruningInfo or None

Note

entry_point is a string that specifies which function of the user's model will be traced. Simple examples: forward, execute_model, forward_train. If such a function is located in a submodule of the model, write the submodule's name followed by the function name, separated by dots: submodule.forward, head.predict_features, submodule1.submodule2.forward.

calibrate_model_for_pruning(model, dataloader, loss_function, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward', max_prunable_labels=4096)[source]

Estimates model channel importance values for later pruning.

This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning.

Parameters:
  • model (torch.nn.Module) – Model to calibrate pruning gates for.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • n_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified in the epochs argument.

  • epochs (int, optional) – Number of total calibration epochs. Not used when n_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer - the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.

  • entry_point (str, optional) – Name of the model method for tracing. All prunable operations should be forwarded inside this method. Default value is ‘forward’.

  • max_prunable_labels (int, optional) – Maximum number of labels in a group which can be pruned. If a group contains more labels, it will be marked as non-prunable. Default value is 4096.

Returns:

Pruning information for later usage in pruning methods.

Return type:

ModelPruningInfo

Model pruning

prune_model(model, pruning_info, prune_labels, inplace=True)

Remove (prune) the least important channels defined by the prune_labels parameter.

Parameters:
  • model (torch.nn.Module) – Model for pruning.

  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • prune_labels (Sequence of Label) – Sequence of global channel indices to prune.

  • inplace (bool) – Whether to perform pruning in-place (modify the model itself without creating a copy).

Returns:

Pruned model.

Return type:

torch.nn.Module

The following function combines the functionality of the enot.pruning.calibrate_model_for_pruning() and enot.pruning.prune_model() functions:

calibrate_and_prune_model(label_selector, model, dataloader, loss_function, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')[source]

Estimates channel importance values and prunes the model with a user-defined strategy.

This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning. It then prunes the model by removing the channels selected by the user-defined channel selection strategy label_selector.

Parameters:
  • label_selector (PruningLabelSelector) – Channel selector object. This object should implement the PruningLabelSelector.select() method, which returns a list of channel indices to prune.

  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function which takes model output and dataloader sample and computes loss tensor. This function should take two inputs and return scalar PyTorch tensor.

  • finetune_bn (bool, optional) – Finetune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified in calibration_epochs argument.

  • calibration_epochs (int, optional) – Number of total calibration epochs. Not used when calibration_steps argument is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function which computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). This function should take a single input (dataloader sample) and return a single integer - the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to log calibration procedure with tqdm progress bar. Default value is False.

  • entry_point (str, optional) – Model function for execution. See the note below for the detailed argument description. Default value is “forward”.

Returns:

Pruned model.

Return type:

torch.nn.Module

Note

Before calling this function, your model should be prepared to be as close to practical inference usage as possible. For example, it is your responsibility to call the eval method of your model if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model(...)

entry_point is a string that specifies which function of the user's model will be traced. Simple examples: forward, execute_model, forward_train. If such a function is located in a submodule of the model, write the submodule's name followed by the function name, separated by dots: submodule.forward, head.predict_features, submodule1.submodule2.forward.
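
A hedged usage sketch with one of the selectors documented below (model, val_dataloader, and loss_function are placeholders for your own objects):

from enot.pruning import GlobalPruningLabelSelectorByChannels
from enot.pruning import calibrate_and_prune_model

# Remove the 30% globally least important channels.
label_selector = GlobalPruningLabelSelectorByChannels(n_channels_or_ratio=0.3)
pruned_model = calibrate_and_prune_model(
    label_selector=label_selector,
    model=model,
    dataloader=val_dataloader,
    loss_function=loss_function,
)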

Label selector

To select which labels should be pruned, use one of the following selectors:

There are also two useful classes, described below, to help you implement your own label selector:

class UniformPruningLabelSelector(pruning_ratio)

Bases: ScorePruningLabelSelector

Label selector for uniform (equal) pruning.

Removes an equal percentage of channels in each prunable layer.

__init__(pruning_ratio)
Parameters:

pruning_ratio (float) – Ratio of prunable channels to remove in each group.

class GlobalPruningLabelSelectorByChannels(n_channels_or_ratio)

Bases: ScorePruningLabelSelector

Label selector for global pruning.

Selects the least important channels within the network.

__init__(n_channels_or_ratio)
Parameters:

n_channels_or_ratio (Union[int, float]) – Number of channels or ratio of channels to remove within the whole network. If the parameter is in the range (0, 1), it is interpreted as a fraction of all unique labels. If the parameter is greater than or equal to 1, it is interpreted as the number of labels to remove.

class GlobalLatencyPruningLabelSelector(target_latency, latency_calculation_function, selector_cb=None)

Bases: PruningLabelSelector, TargetLatencyMixin

Latency-based label selector.

It finds the model with latency as close as possible to the target latency parameter, and always selects the labels with the lowest scores.

__init__(target_latency, latency_calculation_function, selector_cb=None)
Parameters:
  • target_latency (float) – Target model latency. This argument should have the same units as output of latency_calculation_function.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be a number of FLOPs, MACs, inference time on CPU/GPU or other “speed” criteria.

  • selector_cb (Optional[Callable[[float, float], bool]]) – An optional callback that is called at each iteration of the search process. The callback should take the current latency as the first parameter and the target latency as the second parameter and return True if the search procedure should be stopped, False otherwise. Can be used for logging and early stop.
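
A minimal sketch of a selector_cb implementation: a logging callback that never requests an early stop (parameter_count is the hypothetical helper from the earlier sketch, and baseline_latency is a hypothetical baseline measurement in the same units):

from enot.pruning import GlobalLatencyPruningLabelSelector

def log_progress(current_latency: float, target_latency: float) -> bool:
    print(f'current latency: {current_latency:.2f}, target: {target_latency:.2f}')
    return False  # continue the search until it finishes on its own

label_selector = GlobalLatencyPruningLabelSelector(
    target_latency=baseline_latency / 2,
    latency_calculation_function=parameter_count,
    selector_cb=log_progress,
)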

class KnapsackPruningLabelSelector(target_latency, latency_calculation_function, strict_label_order=True, channels_constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)

Bases: PruningLabelSelector, TargetLatencyMixin

Label selector based on knapsack algorithm.

We treat the label score as the value and the latency as the weight from the classical knapsack problem, so KnapsackPruningLabelSelector maximizes the total score of the model under a latency constraint. The algorithm cannot find an exact solution because label latencies depend on each other non-linearly, but the iterative approach gives good results. At each iteration, the label selector recalculates the latency estimate for each label according to the selection from the previous iteration and solves the knapsack problem with the target latency constraint, until the problem converges to this constraint.

__init__(target_latency, latency_calculation_function, strict_label_order=True, channels_constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)
Parameters:
  • target_latency (float) – Target model latency. This argument should have the same units as output of latency_calculation_function.

  • latency_calculation_function (Callable[[nn.Module], float]) – Function that calculates model latency. It should take model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be a number of FLOPs, MACs, inference time on CPU/GPU or other “speed” criteria.

  • strict_label_order (bool) – If True, then the labels within a gate are always selected in ascending order of score, otherwise any order of selection is acceptable. Default value is True.

  • channels_constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance, that calculates low, high and step for the channels constraint for total number of channels in gate. If None, then DummyChannelsSelectorConstraint will be used. None by default.

  • max_iterations (int) – Maximum number of iterations. Default value is 15.

  • solver_execution_time_limit (Optional[float]) – Time limit in seconds for solving a linear problem at each iteration. Default value is 60.

  • verbose (bool) – Whether to print selection statistics and show progress bar or not. True by default.

class OptimalPruningLabelSelector(target_latency, latency_calculation_function, n_search_steps, n_starting_points=None, batch_size=4, channels_constraint=None, additional_starting_points_generator=None, device=None, dtype=torch.float64, verbose=True, handle_ctrl_c=False, seed=None, scbo_kwargs=None)
__init__(target_latency, latency_calculation_function, n_search_steps, n_starting_points=None, batch_size=4, channels_constraint=None, additional_starting_points_generator=None, device=None, dtype=torch.float64, verbose=True, handle_ctrl_c=False, seed=None, scbo_kwargs=None)
Parameters:
  • target_latency (float) – Target model latency. This argument should have the same units as output of latency_calculation_function.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be a number of FLOPs, MACs, inference time on CPU/GPU or other “speed” criteria. If latency measurement cannot be performed for a particular model configuration due to technical reasons (e.g., problems with measurements on a server), then latency_calculation_function should raise LatencyMeasurementError from enot.pruning.label_selector.

  • n_search_steps (Optional[int]) – Number of configurations (samples) for pruning to find the optimal architecture. If None, the search loop stops when SCBO converges.

  • n_starting_points (Optional[int]) – Number of sampled configurations used for the startup step. If None, n_starting_points is calculated as min(total_number_of_gates_in_model * 2, n_search_steps - 1). None by default.

  • batch_size (int) – Batch size for SCBO algorithm. 4 by default.

  • channels_constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance, that calculates low, high and step for the channels constraint for total number of channels in gate. If None, then DefaultChannelsSelectorConstraint will be used. None by default.

  • additional_starting_points_generator (Optional[StartingPointsGenerator]) – User defined generator of additional starting points for SCBO algorithm. None by default.

  • device (Optional[Union[str, torch.device]]) – The desired device for SCBO algorithm. None by default (cuda if possible).

  • dtype (torch.dtype) – The desired dtype of SCBO algorithm. torch.double by default.

  • verbose (bool) – Whether to show progress bar or not. True by default.

  • handle_ctrl_c (bool) – Whether to allow early stop by Ctrl+c or not. False by default.

  • seed (Optional[int]) – Optional seed for SCBO algorithm. Default value is None.

  • scbo_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments for SCBO (developers only).
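
A hedged construction sketch (parameter_count and baseline_latency are the hypothetical helpers from the earlier sketches; pruning_info is obtained from calibration):

from enot.pruning import OptimalPruningLabelSelector

label_selector = OptimalPruningLabelSelector(
    target_latency=baseline_latency / 2,
    latency_calculation_function=parameter_count,
    n_search_steps=200,
    seed=42,  # fix the seed to make the SCBO search reproducible
)
labels_to_prune = label_selector.select(model=model, pruning_info=pruning_info)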

class PruningLabelSelector

Base class for all pruning label selectors.

This class defines an abstract method PruningLabelSelector.select() that should return labels to prune.

static extract_labels_info(pruning_info, sort_labels=True, drop_shared_gates=True, normalize_score=True, eps=1e-07)

Extract labels information from pruning_info.

Parameters:
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • sort_labels (bool) – Whether to sort labels in ascending order according to the pruning criterion (score) within each group (sort_labels=True) or leave them unsorted (sort_labels=False). True by default.

  • drop_shared_gates (bool) – Whether to drop groups of labels for repeated (shared) gates (drop_shared_gates=True) or add them to the output as is (drop_shared_gates=False). True by default.

  • normalize_score (bool) – Whether to normalize label scores by the total score. True by default.

  • eps (float) – Coefficient for numerical stability in score normalization. Only used when normalize_score=True. Default is 1e-7.

Returns:

The first position of the tuple contains labels grouped by gates. The second position contains a mapping from labels to the pruning criterion value (score).

Return type:

Tuple[List[Tuple[Label, …]], Dict[Label, float]]

static get_most_important_labels(sorted_labels_by_gates)

Return the set of the most important labels, i.e. the labels with the highest score in every gate.

Parameters:

sorted_labels_by_gates (List[Tuple[Label, ...]]) – Labels grouped by gates and sorted by score (ascending order) within every gate group.

Returns:

Set of the most important labels.

Return type:

Set[Label]

abstract select(model, pruning_info)

Method that chooses which labels should be pruned based on current label selector policy.

Warning

Depending on label selector implementation, this function may have significant execution time.

Parameters:
  • model (torch.nn.Module) – Model to prune.

  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns:

List of channel labels which should be pruned.

Return type:

list of Label

class ScorePruningLabelSelector

Bases: PruningLabelSelector

Base class for pruning label selectors, which select labels based on PruningLabelSelector.extract_labels_info() output.

This class defines an abstract method _select() that takes PruningLabelSelector.extract_labels_info() output and should return list of labels to prune.

abstract _select(labels_by_gates, labels_score)

select() implementation.

Parameters:
  • labels_by_gates (List[Tuple[Label, ...]]) – Labels grouped by gates.

  • labels_score (Dict[Label, float]) – Mapping of labels and pruning criteria value (score).

Returns:

List of labels to prune.

Return type:

List[Label]

select(model, pruning_info)

Method that chooses which labels should be pruned based on current label selector policy.

Warning

Depending on label selector implementation, this function may have significant execution time.

Parameters:
  • model (torch.nn.Module) – Model to prune.

  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

Returns:

List of channel labels which should be pruned.

Return type:

list of Label

Here is an example of a simple label selector that selects only even labels:

import operator
from functools import reduce
from typing import List

import torch

from enot.pruning import ModelPruningInfo
from enot.pruning import PruningLabelSelector


class EvenPruningLabelSelector(PruningLabelSelector):
    def select(self, model: torch.nn.Module, pruning_info: ModelPruningInfo) -> List[int]:
        labels_by_gates, labels_by_score = self.extract_labels_info(pruning_info)
        # we don't need model and labels_by_score for our label selector
        del model, labels_by_score

        is_even = lambda label: label.label % 2 == 0
        labels = reduce(operator.concat, labels_by_gates)
        even_labels = [*filter(is_even, labels)]

        return even_labels
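
Similarly, a ScorePruningLabelSelector subclass only needs to implement _select(). Below is a hedged sketch (TopKPruningLabelSelector is hypothetical, not part of the library) that prunes a fixed number of the least important labels in every gate, relying on the ascending score order produced by extract_labels_info():

from enot.pruning import ScorePruningLabelSelector


class TopKPruningLabelSelector(ScorePruningLabelSelector):
    def __init__(self, k: int):
        super().__init__()
        self._k = k

    def _select(self, labels_by_gates, labels_score):
        # Scores are unused here: the sort order already reflects them.
        del labels_score

        labels_to_prune = []
        for gate_labels in labels_by_gates:
            # Labels within each gate are sorted by ascending score,
            # so the first k entries are the least important ones.
            labels_to_prune.extend(gate_labels[:self._k])

        return labels_to_prune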

Channels Selector Constraint

class ChannelsSelectorConstraint

Base class for all channels selection constraints.

To implement a new channel selection constraint, write an implementation for _get().

abstract _get(number_of_channels)

Implementation of get(). It should return low, high, and step for the constraint given the total number of channels in the gate.

  • low - minimal number of channels that should be pruned

  • high - maximal number of channels allowed to prune

  • step - step of discretization

Parameters:

number_of_channels (int) – Total number of channels in gate.

Returns:

Tuple of low, high and step for the constraint.

Return type:

Tuple[int, int, int]

apply(prune_n, number_of_channels)

Apply channels selector constraint.

Parameters:
  • prune_n (int) – Desired number of channels to prune.

  • number_of_channels (int) – Total number of channels in gate.

Returns:

Number of channels to prune according to constraint.

Return type:

int

get(number_of_channels)

Calculate low, high, and step for the constraint given the total number of channels in the gate.

  • low - minimal number of channels that should be pruned

  • high - maximal number of channels allowed to prune

  • step - a step of discretization

Parameters:

number_of_channels (int) – Total number of channels in gate.

Returns:

Tuple of low, high and step for the constraint.

Return type:

Tuple[int, int, int]

class DefaultChannelsSelectorConstraint(absolute_min_channels=16, relative_min_channels=0.25, fixed_step=None)

Bases: ChannelsSelectorConstraint

Default channels selector constraint.

This class allows you to set two lower bounds on the number of channels kept in each layer: absolute and relative. The absolute bound is a minimal number of channels, and the relative bound is a fraction of the number of channels in a layer. The resulting bound is the maximum of the two.

You can also specify a fixed value for the discretization step. If no fixed step is specified, the following rules apply:

  • get() returns step equal to 1 if channel group has size less than or equal to 64

  • get() returns step equal to 4 if channel group has size greater than 64

__init__(absolute_min_channels=16, relative_min_channels=0.25, fixed_step=None)
Parameters:
  • absolute_min_channels (int) – Absolute lower bound. If it is larger than number_of_channels, then number_of_channels is returned. Default value is 16.

  • relative_min_channels (float) – Relative lower bound. Default value is 0.25.

  • fixed_step (Optional[int]) – Optional fixed value for the discretization step. Default value is None.

Here is the implementation of the _get() method for DefaultChannelsSelectorConstraint:

    def _get(self, number_of_channels: int) -> Tuple[int, int, int]:
        high = max(self._absolute_min_channels, int(self._relative_min_channels * number_of_channels))
        high = number_of_channels - high
        high = max(0, high)

        if self._fixed_step is None:
            step = 4 if number_of_channels > 64 else 1
        else:
            step = self._fixed_step

        return 0, high, step
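
Following that implementation, the default constraint behaves as in this small sketch (assuming DefaultChannelsSelectorConstraint is importable from enot.pruning):

from enot.pruning import DefaultChannelsSelectorConstraint

constraint = DefaultChannelsSelectorConstraint()

# For a 128-channel gate, at least max(16, 0.25 * 128) = 32 channels are kept,
# so at most 128 - 32 = 96 channels may be pruned, in steps of 4.
print(constraint.get(128))  # (0, 96, 4)

# apply() adjusts a desired pruning amount to fit the constraint,
# so 100 should be reduced to the allowed maximum of 96.
print(constraint.apply(100, 128))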

Tools

Below you can find a set of tools that help inspect the pruning structure information.

Label printing conventions:

  • [--]:960:(75436, 76395) — group of 960 prunable labels with IDs from 75436 to 76395

  • [N-]:448:(3, 450) — group of 448 non-prunable labels with IDs from 3 to 450

The second symbol in the bracketed prefix indicates whether the labels in the group are sequential or not.

tabulate_module_dependencies(module_name, pruning_info, model=None, show_op_with_weights_only=True, show_parameter_nodes=False, show_all_nodes=False)

Return a string containing all of the tracing/pruning information for the node corresponding to the module with the given name.

Parameters:
  • module_name (str) – Name of the module in torch model, for example: “features.5.conv.1.0”.

  • pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.

  • model (Optional[nn.Module]) – PyTorch model used for calibration.

  • show_op_with_weights_only (bool) – Include only nodes with weights (modules). Default value is True.

  • show_parameter_nodes (bool) – Whether to include Parameter nodes. Default value is False.

  • show_all_nodes (bool) – Whether to include all types of nodes. Default value is False.

Returns:

Tracing/pruning information for the node corresponding to the module with the given name.

Return type:

str

An example of usage:

print(tabulate_module_dependencies(module_name='features.17.conv.1.0', pruning_info=pruning_info, model=model))
+-------------------+-----------------------------------------------+
| module name       | features.17.conv.1.0                          |
+-------------------+-----------------------------------------------+
| node name         | operation_node_conv2d_456                     |
+-------------------+-----------------------------------------------+
| node operation    | conv2d                                        |
+-------------------+-----------------------------------------------+
| prunable groups   | 1                                             |
+-------------------+-----------------------------------------------+
| unprunable groups | 0                                             |
+-------------------+-----------------------------------------------+
| node labels       | [None, '[--]:960:(75436, 76395)', None, None] |
+-------------------+-----------------------------------------------+
Prunable group 0 with labels [--]:960:(75436, 76395)
+----------------------+---------------------------------------------------------------------------------------------+
| features.17.conv.1.0 | Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False) |
+----------------------+---------------------------------------------------------------------------------------------+
| features.17.conv.0.0 | Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)                             |
+----------------------+---------------------------------------------------------------------------------------------+
Prunable group 0 users:
+--------------------+-----------------------------------------------------------------+
| features.17.conv.2 | Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False) |
+--------------------+-----------------------------------------------------------------+
tabulate_unprunable_groups(pruning_info, model=None, show_op_with_weights_only=True, show_parameter_nodes=False, show_all_nodes=False)

Return a string containing all of the tracing/pruning information for all unprunable groups.

Parameters:
  • pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.

  • model (Optional[nn.Module]) – PyTorch model used for calibration.

  • show_op_with_weights_only (bool) – Include only nodes with weights (modules). Default value is True.

  • show_parameter_nodes (bool) – Whether to include Parameter nodes. Default value is False.

  • show_all_nodes (bool) – Whether to include all types of nodes. Default value is False.

Returns:

Tracing/pruning information for all unprunable groups.

Return type:

str

An example of usage:

print(tabulate_unprunable_groups(pruning_info=pruning_info, model=model, show_all_nodes=True))
Unprunable groups:
Unprunable group 0 with labels [N-]:3:(0, 2)
+----------------------+---------------------+
| parameter_node_199   | features.0.0.weight |
+----------------------+---------------------+
| placeholder_node_314 |                     |
+----------------------+---------------------+
Unprunable group 0 users:
+--------------+------------------------------------------------------------------------------+
| features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) |
+--------------+------------------------------------------------------------------------------+
Unprunable group 1 with labels [N-]:448:(3, 450)
+----------------------+--+
| placeholder_node_314 |  |
+----------------------+--+
Unprunable group 1 users:
+--------------+------------------------------------------------------------------------------+
| features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) |
+--------------+------------------------------------------------------------------------------+
Unprunable group 2 with labels [N-]:6:(486, 491)
+--------------------+---------------------+
| parameter_node_199 | features.0.0.weight |
+--------------------+---------------------+
Unprunable group 2 users:
+--------------+------------------------------------------------------------------------------+
| features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) |
+--------------+------------------------------------------------------------------------------+
...
tabulate_label_dependencies(label, pruning_info, model=None, show_op_with_weights_only=True, show_parameter_nodes=False, show_all_nodes=False)

Return a string containing all of the tracing/pruning information for the label.

Parameters:
  • label (Union[Label, int]) – Label for which information is to be printed.

  • pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.

  • model (Optional[nn.Module]) – PyTorch model used for calibration.

  • show_op_with_weights_only (bool) – Include only nodes with weights (modules). Default value is True.

  • show_parameter_nodes (bool) – Whether to include Parameter nodes. Default value is False.

  • show_all_nodes (bool) – Whether to include all types of nodes. Default value is False.

Returns:

Tracing/pruning information for the label.

Return type:

str

An example of usage:

print(tabulate_label_dependencies(label=88964, pruning_info=pruning_info, model=model, show_all_nodes=True))
label: 88964
+---------------------+------------------------------------------------------------------+
| parameter_node_94   | features.18.0.weight                                             |
+---------------------+------------------------------------------------------------------+
| features.18.0       | Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False) |
+---------------------+------------------------------------------------------------------+
| parameter_node_197  | features.18.1.running_mean                                       |
+---------------------+------------------------------------------------------------------+
| parameter_node_151  | features.18.1.running_var                                        |
+---------------------+------------------------------------------------------------------+
| parameter_node_81   | features.18.1.weight                                             |
+---------------------+------------------------------------------------------------------+
| parameter_node_113  | features.18.1.bias                                               |
+---------------------+------------------------------------------------------------------+
| batch_norm          |                                                                  |
+---------------------+------------------------------------------------------------------+
| hardtanh            |                                                                  |
+---------------------+------------------------------------------------------------------+
| adaptive_avg_pool2d |                                                                  |
+---------------------+------------------------------------------------------------------+
| flatten             |                                                                  |
+---------------------+------------------------------------------------------------------+
| dropout             |                                                                  |
+---------------------+------------------------------------------------------------------+
| parameter_node_42   | classifier.1.weight                                              |
+---------------------+------------------------------------------------------------------+
tabulate_label_filters(label, pruning_info, model)

Return a string containing all of the filter information for the label.

Parameters:
  • label (Union[Label, int]) – The label for which filters are to be found.

  • pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.

  • model (nn.Module) – PyTorch model used for calibration.

Returns:

Label filters information.

Return type:

str

An example of usage:

print(tabulate_label_filters(label=88964, pruning_info=pruning_info, model=model))
label: 88964
+----------------------------+----------------+---------------------------------+
| tensor name                | shape          | mapping                         |
+============================+================+=================================+
| features.18.0.weight       | (1, 320, 1, 1) | [[(1202, 0)], None, None, None] |
+----------------------------+----------------+---------------------------------+
| features.18.1.running_mean | (1,)           | [[(1202, 0)]]                   |
+----------------------------+----------------+---------------------------------+
| features.18.1.running_var  | (1,)           | [[(1202, 0)]]                   |
+----------------------------+----------------+---------------------------------+
| features.18.1.weight       | (1,)           | [[(1202, 0)]]                   |
+----------------------------+----------------+---------------------------------+
| features.18.1.bias         | (1,)           | [[(1202, 0)]]                   |
+----------------------------+----------------+---------------------------------+
| classifier.1.weight        | (10, 1)        | [None, [(1202, 0)]]             |
+----------------------------+----------------+---------------------------------+
get_label_filters(label, pruning_info, model, *, with_tensors=False)

Return the filters corresponding to the passed label.

Parameters:
  • label (Union[Label, int]) – The label for which filters are to be found.

  • pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.

  • model (nn.Module) – PyTorch model used for calibration.

  • with_tensors (bool) – Include filter values (tensors) or not. Default value is False.

Returns:

List of containers with filter tensor name, shape, mapping and filter value (optionally).

Return type:

List[LabelFilterInfo]

An example of usage:

for filter_ in get_label_filters(label=label, pruning_info=pruning_info, model=model):
    print(filter_)
LabelFilterInfo(tensor_name='features.18.0.weight', shape=(1, 320, 1, 1), mapping=[[(1202, 0)], None, None, None], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.running_mean', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.running_var', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.weight', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.bias', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='classifier.1.weight', shape=(10, 1), mapping=[None, [(1202, 0)]], tensor=None)

Here, shape is the shape of the filter, and mapping can be used to extract the desired filter from the tensor.

get_tracing_node_by_module_name(name, graph)

Return the tracing node corresponding to the module with the given name.

Note: the current implementation works ONLY for modules with a “weight” parameter.

Parameters:
  • name (str) – Name of module in torch model, for example: “features.5.conv.1.0”.

  • graph (List[Node]) – Tracing graph.

Returns:

Node (OpNode) corresponding to the module with the given name.

Return type:

OpNode

Raises:

ValueError – If node corresponding to the module name is not found.

get_module_name_by_tracing_node(node)

Return the module name corresponding to the node.

Note: the current implementation works ONLY for modules with a “weight” parameter. Also, you can easily get a module by its name with the help of getattr_complex.

Parameters:

node (OpNode) – The node for which the module name should be found.

Returns:

Module name corresponding to the node.

Return type:

str

Raises:

ValueError – If module name corresponding to the node is not found.

Converting dataloader items to PyTorch model inputs

Short summary: when ENOT needs to pass items from your dataloader to your model, it requires a special mapping function. This function will be used as follows:

# Your code:
your_model = ...
your_dataloader = ...
your_mapping_function = ...

# In the ENOT framework:
sample = next(iter(your_dataloader))
model_args, model_kwargs = your_mapping_function(sample)
model_result = your_model(*model_args, **model_kwargs)

When working with the ENOT framework, it is sometimes necessary to write a function that transforms dataloader items into your model's input format. ENOT fetches items from your dataloader and uses this function to convert them before calling your model.

Such a function should take a single input of type Any - a single item from the user's dataloader. It should return a tuple with two elements. The first element is a tuple of model positional arguments, which will be passed to the model's __call__ method via the unpacking operator *args. The second element is a dictionary with string keys and any-type values that defines the model's keyword arguments; it will be passed to the model's __call__ method via the unpacking operator **kwargs.

Let’s see the most basic example:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor)

def my_conversion_function(x):
    return (x[0], ), {}  # Single positional argument for model, no keyword arguments.

The same functionality is provided by our function default_sample_to_model_inputs().

Example function which changes model’s default forward keyword argument value:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor, should_profile: bool = True)

def my_conversion_function(x):
    return (x[0], ), {'should_profile': False}  # We do not want to profile!

Example function which also performs data pre-processing and moves it to the GPU:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
Model expects: (images: torch.Tensor)

def my_conversion_function(x):
    # Normalize images, cast them to float32, and move them to CUDA.
    return ((x[0].float() / 255).cuda(), ), {}

Advanced case with complex dataloader item structure and model positional and keyword arguments:

DataLoader returns: {'sequence': torch.Tensor, 'translation': torch.Tensor, 'masks': List[torch.Tensor]}
Model expects: (sequence: torch.Tensor, mask: Optional[torch.Tensor] = None, unroll_attention: bool = False)

def my_conversion_function(x):
    sequence = x['sequence']
    mask = x['masks'][0]
    return (sequence, ), {'mask': mask, 'unroll_attention': True}

default_sample_to_model_inputs(sample)

Default function for sample to model input conversion.

This function covers the most simple case when dataloader returns pairs of images and labels in the form (images, labels). If the model receives a single positional argument, then you can use this function to convert dataloader output to model input.

Parameters:

sample (tuple) – Tuple with one or more elements from which only the first item will be passed to the model. So, the user-defined model must have a single positional argument in its forward function definition.

Returns:

Model input args and kwargs, see more here.

Return type:

tuple with two items - tuple and dict with str keys

Extracting the number of samples from dataloader item

Sometimes ENOT needs to know the number of samples of interest in a dataloader item. A sample of interest could be an image, one second of audio, a bounding box, etc. What a single sample means is entirely up to you. To provide this information to ENOT, you should define a special function. This function will be used as follows:

# Your code:
your_model = ...
your_dataloader = ...
your_function = ...

# In the ENOT framework:
sample = next(iter(your_dataloader))
n_base_samples: int = your_function(sample)
# process n_base_samples

Such a function should take a single input of type Any - a single item from the user's dataloader. It should return a single integer value - the number of samples of interest in the dataloader item.

Let’s see the most basic example:

DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)

def my_function(x):
    # Extract the number of images from the batch dimension.
    return x[0].shape[0]

The same functionality is provided by our function default_sample_to_n_samples().

Advanced case with complex dataloader item structure:

DataLoader returns: {'sequence': torch.Tensor, 'translation': torch.Tensor, 'mask': torch.Tensor}

def my_function(x):
    # sequence is a tensor of shape (N, T1) with int64 dtype.
    # translation is a tensor of shape (N, T2) with int64 dtype.
    # mask is a tensor of shape (N, T1) with int64 dtype.
    # mask is equal to 1 where sequence values are correct (not padded) and 0 elsewhere.
    total_time_steps = x['mask'].sum()
    total_seconds = total_time_steps // 100  # Let's suppose that one time step is 10ms.
    return int(total_seconds)  # Total audio length in seconds over all samples.

default_sample_to_n_samples(sample)

Default function to extract the number of samples of interest from dataloader item.

This function covers the most simple case when dataloader returns pairs of images and labels in the form (images, labels).

Parameters:

sample (tuple) – Tuple with one or more elements where the first element is a tensor with the batch dimension axis equal to 0.

Returns:

Number of samples of interest in dataloader item, see more here.

Return type:

int