Pruning

The enot.pruning module contains functionality for automatic pruning of user models.

The pruning engine works with models from many frameworks and model libraries, such as Torchvision, PyTorch Image Models (TIMM), OpenMMLab, and others.

This package provides pre-defined or custom pruning procedures that remove the least important filters or neurons. The ENOT pruning module currently supports structured pruning. Users can set the pruning ratio (the percentage of channels removed) manually or use one of the pre-defined strategies.

The first (and simplest) is the equal pruning strategy, which keeps roughly the same percentage of channels in each prunable layer.

The second (and considerably more compute-intensive) is the optimal pruning strategy, which searches for the optimal pruning ratio of each prunable layer.

A few definitions that make this documentation easier to read:

  • Pruning ratio is the fraction of channels to remove (prune) in a prunable group or prunable layer.

  • Gate (or channel gate) is a special torch.nn.Module that gathers and stores the “local importance” of channels. “Local importance” means that channels can be compared by their importance values within a single layer, but not across different layers.

  • Calibration procedure estimates channel importance values for the user’s model on the user’s data.

  • Channel label is the global index of a channel in the network.

  • Pruning config specifies the pruning amount for each prunable layer. See more here.

Pruning API

The user-facing pruning functionality is divided into three sections: the high-level interface, the low-level interface, and tools.

High-level interface

The high-level interface provides two ways to prune a model. Both work the same way but use different policies to decide how much to prune in each layer:

calibrate_and_prune_model_equal(model, dataloader, loss_function, pruning_ratio=0.5, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')

Estimates channel importance values and prunes the model with the equal pruning strategy (the same percentage of channels is pruned in each prunable layer).

Parameters:
  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function that takes the model output and a dataloader sample and computes the loss. It should take two inputs and return a scalar PyTorch tensor.

  • pruning_ratio (float, optional) – Fraction of channels to prune in each prunable layer. Default value is 0.5 (prunes 50% of channels, which typically reduces the number of FLOPs and parameters by a factor of 4).

  • finetune_bn (bool, optional) – Fine-tune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Total number of calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified by the calibration_epochs argument.

  • calibration_epochs (int, optional) – Total number of calibration epochs. Ignored when calibration_steps is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function that computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). It should take a single input (the dataloader sample) and return a single integer, the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to show a tqdm progress bar during calibration. Default value is False.

  • entry_point (str, optional) – Model function for execution. See notes below for the detailed argument description. Default value is “forward”.

Returns:

Pruned model.

Return type:

torch.nn.Module
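The factor of 4 quoted for pruning_ratio=0.5 follows from the cost of a convolution, which is proportional to the product of its input and output channel counts. The plain-Python sketch below is a back-of-the-envelope check, assuming a standard (non-grouped) 2D convolution whose input and output channels are both pruned by the same ratio; the helpers conv2d_macs and remaining_channels are hypothetical and not part of enot.pruning.

```python
def conv2d_macs(in_channels: int, out_channels: int, kernel_size: int, out_h: int, out_w: int) -> int:
    """Multiply-accumulate count of a standard (groups=1) 2D convolution, ignoring bias."""
    return in_channels * out_channels * kernel_size * kernel_size * out_h * out_w


def remaining_channels(n_channels: int, pruning_ratio: float) -> int:
    """Channels left after removing `pruning_ratio` of them (rounding down is an assumption)."""
    return n_channels - int(n_channels * pruning_ratio)


baseline = conv2d_macs(256, 256, kernel_size=3, out_h=28, out_w=28)
kept = remaining_channels(256, 0.5)
pruned = conv2d_macs(kept, kept, kernel_size=3, out_h=28, out_w=28)
print(baseline / pruned)  # 4.0: halving both channel dimensions quarters the cost
```

Layers where only one side is pruned (for example, the first convolution, whose input is the unpruned image) shrink only by a factor of 2, which is why the overall reduction is "typically" around 4 rather than exactly 4.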

Note

Before calling this function, prepare your model to be as close as possible to its practical inference setup. For example, it is your responsibility to call the model’s eval method if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during the backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model_equal(...)

entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: forward, execute_model, forward_train. If the function belongs to a submodule of the model, write the submodule’s name followed by the function name, all separated by dots: submodule.forward, head.predict_features, submodule1.submodule2.forward.
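A dotted entry_point can be read as a chain of attribute lookups starting from the model. The sketch below shows one way such a string could be resolved with getattr; resolve_entry_point and the toy Head and Model classes are hypothetical illustrations, not part of enot.pruning, and the library's own resolution logic may differ.

```python
from functools import reduce


def resolve_entry_point(model, entry_point: str):
    """Follow a dotted path such as "head.predict_features" from `model` to a bound method."""
    return reduce(getattr, entry_point.split("."), model)


# Toy stand-ins for a model with a submodule (no PyTorch needed for this illustration).
class Head:
    def predict_features(self, x):
        return x * 2


class Model:
    def __init__(self):
        self.head = Head()

    def forward(self, x):
        return self.head.predict_features(x) + 1


model = Model()
print(resolve_entry_point(model, "forward")(3))                # 7
print(resolve_entry_point(model, "head.predict_features")(3))  # 6
```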

calibrate_and_prune_model_optimal(model, dataloader, loss_function, latency_calculation_function, target_latency, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, channels_selection_constraint=None, n_search_steps=200, entry_point='forward', **kwargs)

Estimates channel importance values, searches for the optimal pruning configuration (percentage of channels to prune in each prunable layer) and prunes the model.

Parameters:
  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function that takes the model output and a dataloader sample and computes the loss. It should take two inputs and return a scalar PyTorch tensor.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take a model (torch.nn.Module) and return the “speed” of its execution as a float. This could be the number of FLOPs or MACs, the inference time on CPU/GPU, or another “speed” criterion.

  • target_latency (float) – Target model latency, in the same units as the output of latency_calculation_function.

  • finetune_bn (bool, optional) – Fine-tune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Total number of calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified by the calibration_epochs argument.

  • calibration_epochs (int, optional) – Total number of calibration epochs. Ignored when calibration_steps is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function that computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). It should take a single input (the dataloader sample) and return a single integer, the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to show a tqdm progress bar during calibration. Default value is False.

  • channels_selection_constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that computes the low, high and step values of the channel constraint from the total number of channels in a gate. If None, DefaultChannelsSelectorConstraint is used. Default value is None.

  • n_search_steps (int, optional) – Number of pruning configurations to sample when searching for the optimal architecture. Default value is 200.

  • entry_point (str, optional) – Model function for execution. See notes in calibrate_and_prune_model_equal() for the detailed argument description. Default value is “forward”.

  • **kwargs – Additional keyword arguments for the label selector.

Returns:

Pruned model.

Return type:

torch.nn.Module
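Any callable that maps a model to a float can serve as latency_calculation_function, as long as target_latency uses the same units. As a minimal sketch, the function below uses the parameter count as a cheap proxy for “speed”; it relies only on model.parameters() and Tensor.numel(), which torch.nn.Module and torch.Tensor provide. The name parameter_count_latency and the choice of parameter count as the metric are assumptions made for illustration; FLOPs or measured inference time usually track real speed more closely.

```python
def parameter_count_latency(model) -> float:
    """Return the number of parameters of `model` as a float "latency" value.

    Works with any torch.nn.Module, since it only calls model.parameters()
    and Tensor.numel() on each parameter.
    """
    return float(sum(parameter.numel() for parameter in model.parameters()))
```

With this function, passing target_latency=0.5 * parameter_count_latency(model) would ask the search for a model with roughly half as many parameters as the original.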

Low-level interface

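The low-level interface consists of pruning info, calibration, model pruning, label selectors, and channel selection constraints, each described in the sections below.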

The main class for pruning is ModelPruningInfo. It contains the information needed for model pruning: the model execution graph, the lists of all prunable and non-prunable channels, all channel importance values, etc.

Model pruning info is obtained through the calibration process. Calibration for pruning is performed by the PruningCalibrator class, whose instances should be used as context managers. Inside the context, the user should compute losses and perform backward passes through the model on calibration data.

Users can also calibrate their model by calling the calibrate_model_for_pruning() function with all arguments needed for proper model execution and loss computation.

The actual pruning is performed by the prune_model() function, which requires a pruning info object and a list of channel labels to prune.

To decide which labels should be pruned and which should not, there is a dedicated label selection interface. The base class for all label selectors is PruningLabelSelector. It defines the abstract method PruningLabelSelector.select(), which selects the labels to prune; subclasses implement it so that they fit into pipelines based on calibrate_and_prune_model().

The abstract class ScorePruningLabelSelector provides a straightforward way to select the least important channel labels. It uses the pruning config to select the least important channels in each prunable layer: channels are sorted by their importance values, and the top-K least important channels are selected, where K follows from that layer’s pruning ratio in the pruning config.

Two label selectors are used internally by the high-level interface described above: UniformPruningLabelSelector for equal pruning and OptimalPruningLabelSelector for optimal pruning.
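The score-based selection performed by ScorePruningLabelSelector can be sketched in plain Python: within each prunable group, channels are sorted by importance and the K least important ones are chosen, where K comes from that group's pruning ratio. In this sketch labels are plain integers, the pruning config is a list of per-group ratios, and K is rounded down with int(); all of these are simplifying assumptions, and the real selector works with Label objects and may round differently.

```python
from typing import Dict, List, Sequence


def select_least_important(
    labels_by_groups: Sequence[Sequence[int]],
    scores: Dict[int, float],
    pruning_ratios: Sequence[float],
) -> List[int]:
    """Pick the lowest-scoring labels in every group according to its pruning ratio."""
    selected: List[int] = []
    for labels, ratio in zip(labels_by_groups, pruning_ratios):
        k = int(len(labels) * ratio)  # number of channels to prune in this group
        ranked = sorted(labels, key=lambda label: scores[label])  # ascending importance
        selected.extend(ranked[:k])
    return selected


groups = [[0, 1, 2, 3], [4, 5]]
scores = {0: 0.9, 1: 0.1, 2: 0.5, 3: 0.2, 4: 0.3, 5: 0.7}
print(select_least_important(groups, scores, pruning_ratios=[0.5, 0.5]))  # [1, 3, 4]
```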

Pruning info

class ModelPruningInfo(graph, prunable_groups, unprunable_groups, channel_labels_by_nodes, node_users, gate_container, gate_names_by_locations, prunable_group_gate_names, gate_apply_info, child_labels_by_parents, int_pruners)

Container with model pruning-related information.

property n_prunable_groups: int

Number of prunable channel groups.

Returns:

Number of prunable channel groups.

Return type:

int

summary(level=0)

Returns a string representation of the useful information collected by the pruning parser.

Parameters:

level (int, optional) – Degree of detail. Should be from 0 to 2 inclusive. Default value is 0.

Returns:

String with information about pruning.

Return type:

str

Calibration

class PruningCalibrator(model, entry_point='forward', max_prunable_labels=4096)

Context manager for gathering pruning information.

This context manager walks through the model, collects information about prunable and unprunable channels, and estimates channel importance values. Its main purpose is to create a ModelPruningInfo object that can later be used to prune the model.

Model pruning info can be accessed via PruningCalibrator.pruning_info.

Examples

>>> p_calibrator = PruningCalibrator(model=baseline_model)
>>> with p_calibrator:
...     for inputs, labels in dataloader:
...         model_output = baseline_model(inputs)
...         loss = loss_function(model_output, labels)
...         loss.backward()
>>> p_calibrator.pruning_info
__init__(model, entry_point='forward', max_prunable_labels=4096)
Parameters:
  • model (torch.nn.Module) – Model which user wants to calibrate for pruning.

  • entry_point (str, optional) – Model function for execution. See note below for the detailed argument description. Default value is “forward”.

  • max_prunable_labels (int, optional) – Maximum number of labels in a prunable group. If a group contains more labels, it is marked as non-prunable. Default value is 4096.

property pruning_info: ModelPruningInfo | None

Information about model pruning.

Returns:

Pruning-related information about the model. None if the calibrator instance was not used as a context manager or the model was calibrated incorrectly.

Return type:

ModelPruningInfo or None

Note

entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: forward, execute_model, forward_train. If the function belongs to a submodule of the model, write the submodule’s name followed by the function name, all separated by dots: submodule.forward, head.predict_features, submodule1.submodule2.forward.

calibrate_model_for_pruning(model, dataloader, loss_function, n_steps=None, epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward', max_prunable_labels=4096)

Estimates model channel importance values for later pruning.

This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning.

Parameters:
  • model (torch.nn.Module) – Model to calibrate pruning gates for.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function that takes the model output and a dataloader sample and computes the loss. It should take two inputs and return a scalar PyTorch tensor.

  • n_steps (int or None, optional) – Total number of calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified by the epochs argument.

  • epochs (int, optional) – Total number of calibration epochs. Ignored when n_steps is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function that computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). It should take a single input (the dataloader sample) and return a single integer, the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to show a tqdm progress bar during calibration. Default value is False.

  • entry_point (str, optional) – Name of the model method to trace. All prunable operations should be executed within this method. Default value is ‘forward’.

  • max_prunable_labels (int, optional) – Maximum number of labels in a prunable group. If a group contains more labels, it is marked as non-prunable. Default value is 4096.

Returns:

Pruning information for later usage in pruning methods.

Return type:

ModelPruningInfo
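The interplay between n_steps and epochs can be summarised by the sketch below, which yields the dataloader samples a calibration loop would consume. It is a hypothetical illustration of the documented rule (n_steps, when given, overrides epochs); the function name calibration_samples is invented here, and restarting the dataloader when n_steps exceeds one pass is an assumption, not documented behaviour.

```python
from typing import Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


def calibration_samples(dataloader: Iterable[T], n_steps: Optional[int] = None, epochs: int = 1) -> Iterator[T]:
    """Yield the samples a calibration loop would process.

    If n_steps is None, run `epochs` full passes over the dataloader.
    Otherwise stop after exactly n_steps samples, starting new passes as
    needed (this restart behaviour is an assumption of the sketch).
    """
    if n_steps is None:
        for _ in range(epochs):
            yield from dataloader
        return

    produced = 0
    while produced < n_steps:
        empty_pass = True
        for sample in dataloader:
            if produced == n_steps:
                return
            empty_pass = False
            yield sample
            produced += 1
        if empty_pass:  # avoid looping forever on an empty dataloader
            raise ValueError("dataloader is empty")
```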

Model pruning

prune_model(model, pruning_info, prune_labels, inplace=True)

Removes (prunes) the channels specified by the prune_labels parameter (typically the least important ones).

Parameters:
  • model (torch.nn.Module) – Model for pruning.

  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • prune_labels (Sequence of Label) – Sequence of global channel indices (labels) to prune.

  • inplace (bool) – Whether to perform pruning in-place (modify the model itself without creating a copy). Default value is True.

Returns:

Pruned model.

Return type:

torch.nn.Module

The following function combines the functionality of enot.pruning.calibrate_model_for_pruning() and enot.pruning.prune_model():

calibrate_and_prune_model(label_selector, model, dataloader, loss_function, finetune_bn=False, calibration_steps=None, calibration_epochs=1, sample_to_n_samples=<function default_sample_to_n_samples>, sample_to_model_inputs=<function default_sample_to_model_inputs>, show_tqdm=False, entry_point='forward')

Estimates channel importance values and prunes the model with a user-defined strategy.

This function searches for prunable channels in the user-defined model. After extracting channel information from the model graph, it estimates channel importance values for later pruning. It then prunes the model by removing the channels chosen by the user-defined channel selection strategy label_selector.

Parameters:
  • label_selector (PruningLabelSelector) – Channel selector object. It should implement the PruningLabelSelector.select() method, which returns a list of channel labels to prune.

  • model (torch.nn.Module) – Model to prune.

  • dataloader (torch.utils.data.DataLoader) – Dataloader for estimation of model channel importance values.

  • loss_function (Callable[[Any, Any], torch.Tensor]) – Function that takes the model output and a dataloader sample and computes the loss. It should take two inputs and return a scalar PyTorch tensor.

  • finetune_bn (bool, optional) – Fine-tune batch norm layers (specifically, their running mean and running variance values) for better model quality. Default value is False.

  • calibration_steps (int or None, optional) – Total number of calibration steps. Default value is None, which runs calibration on all dataloader samples for the number of epochs specified by the calibration_epochs argument.

  • calibration_epochs (int, optional) – Total number of calibration epochs. Ignored when calibration_steps is not None. Default value is 1.

  • sample_to_n_samples (Callable, optional) – Function that computes the number of instances (objects to process) in a single dataloader batch (dataloader sample). It should take a single input (the dataloader sample) and return a single integer, the number of instances. Default value is default_sample_to_n_samples(). See more here.

  • sample_to_model_inputs (Callable, optional) – Function to map dataloader samples to model input format. Default value is default_sample_to_model_inputs(). See more here.

  • show_tqdm (bool, optional) – Whether to show a tqdm progress bar during calibration. Default value is False.

  • entry_point (str, optional) – Model function for execution. See Notes section for the detailed argument description. Default value is “forward”.

Returns:

Pruned model.

Return type:

torch.nn.Module

Note

Before calling this function, prepare your model to be as close as possible to its practical inference setup. For example, it is your responsibility to call the model’s eval method if your inference requires it (e.g. when the model contains dropout layers).

Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

If you encounter errors during the backward call, you can wrap this function call with the following statement:

with torch.autograd.set_detect_anomaly(True):
    calibrate_and_prune_model(...)

entry_point is a string that specifies which function of the user’s model will be traced. Simple examples: forward, execute_model, forward_train. If the function belongs to a submodule of the model, write the submodule’s name followed by the function name, all separated by dots: submodule.forward, head.predict_features, submodule1.submodule2.forward.
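calibrate_and_prune_model() is described as calibration followed by pruning, with label_selector deciding what to remove. The sketch below spells out that data flow with the calibration and pruning steps passed in as callables so it runs without enot installed; calibrate_select_prune, calibrate_fn and prune_fn are hypothetical names standing in for calibrate_model_for_pruning() and prune_model(), and the real function additionally handles options such as finetune_bn.

```python
from typing import Any, Callable, Sequence


def calibrate_select_prune(
    model: Any,
    label_selector: Any,
    calibrate_fn: Callable[[Any], Any],
    prune_fn: Callable[[Any, Any, Sequence[Any]], Any],
) -> Any:
    """Calibrate, ask the selector which labels to drop, then prune."""
    # Stand-in for calibrate_model_for_pruning(model, dataloader, loss_function, ...).
    pruning_info = calibrate_fn(model)
    # Any PruningLabelSelector-like object exposing select(model, pruning_info).
    prune_labels = label_selector.select(model, pruning_info)
    # Stand-in for prune_model(model, pruning_info, prune_labels).
    return prune_fn(model, pruning_info, prune_labels)
```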

Label selector

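To select which labels should be pruned, use one of the following selectors: UniformPruningLabelSelector, GlobalPruningLabelSelectorByChannels, LatencyPruningLabelSelector or OptimalPruningLabelSelector.

There are also two base classes that help you implement your own label selector: PruningLabelSelector and ScorePruningLabelSelector.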

class UniformPruningLabelSelector(pruning_ratio)

Bases: ScorePruningLabelSelector

Label selector for uniform (equal) pruning.

Removes an equal percentage of channels in each prunable layer.

__init__(pruning_ratio)
Parameters:

pruning_ratio (float) – Ratio of prunable channels to remove in each group.

class GlobalPruningLabelSelectorByChannels(n_channels_or_ratio)

Bases: ScorePruningLabelSelector

Label selector for global pruning.

Selects the least important channels across the whole network.

__init__(n_channels_or_ratio)
Parameters:

n_channels_or_ratio (Union[int, float]) – Number of channels, or fraction of channels, to remove across the whole network. If the parameter is in the range (0, 1), it is interpreted as a fraction of all unique labels. If the parameter is greater than or equal to 1, it is interpreted as the number of labels to remove.
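The interpretation of n_channels_or_ratio can be illustrated with a small plain-Python sketch that turns the argument into a label count and then picks the globally least important labels. Labels are plain integers here, the helper names are hypothetical, and rounding the fraction down with int() as well as rejecting values of 0 or below are assumptions; the exact behaviour of GlobalPruningLabelSelectorByChannels in those cases is not documented above.

```python
from typing import Dict, List, Union


def n_labels_to_prune(n_channels_or_ratio: Union[int, float], n_unique_labels: int) -> int:
    """Convert the selector argument into an absolute number of labels."""
    if 0 < n_channels_or_ratio < 1:
        return int(n_channels_or_ratio * n_unique_labels)  # fraction of all labels
    if n_channels_or_ratio >= 1:
        return int(n_channels_or_ratio)  # absolute number of labels
    # Values <= 0 are not covered by the documentation; this sketch rejects them.
    raise ValueError("n_channels_or_ratio must be positive")


def select_globally(scores: Dict[int, float], n_channels_or_ratio: Union[int, float]) -> List[int]:
    """Return the least important labels across the whole network."""
    n = n_labels_to_prune(n_channels_or_ratio, len(scores))
    return sorted(scores, key=scores.__getitem__)[:n]


scores = {0: 0.9, 1: 0.1, 2: 0.5, 3: 0.2}
print(select_globally(scores, 0.5))  # [1, 3]
print(select_globally(scores, 3))    # [1, 3, 2]
```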

class LatencyPruningLabelSelector(target_latency, latency_calculation_function, selector_cb=None)

Bases: PruningLabelSelector

Latency-based label selector.

It searches for the model whose latency is as close as possible to target_latency, always selecting the labels with the lowest scores.

__init__(target_latency, latency_calculation_function, selector_cb=None)
Parameters:
  • target_latency (float) – Target model latency, in the same units as the output of latency_calculation_function.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take a model (torch.nn.Module) and return the “speed” of its execution as a float. This could be the number of FLOPs or MACs, the inference time on CPU/GPU, or another “speed” criterion.

  • selector_cb (Optional[Callable[[float, float], bool]]) – An optional callback that is called at each iteration of the search process. The callback should take the current latency as the first parameter and the target latency as the second parameter and return True if the search procedure should be stopped, False otherwise. Can be used for logging and early stop.
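As an illustration of the selector_cb contract (current latency first, target latency second, True to stop), the callback below records each latency it sees and stops the search once the current latency is within a relative tolerance of the target. The factory name make_tolerance_callback and the 5% default tolerance are hypothetical choices for this sketch.

```python
from typing import Callable, List, Optional


def make_tolerance_callback(
    rel_tol: float = 0.05, history: Optional[List[float]] = None
) -> Callable[[float, float], bool]:
    """Build a selector_cb that records latencies and stops near the target."""

    def callback(current_latency: float, target_latency: float) -> bool:
        if history is not None:
            history.append(current_latency)  # simple logging
        # Returning True stops the search; False lets it continue.
        return abs(current_latency - target_latency) <= rel_tol * target_latency

    return callback
```

It would be passed as LatencyPruningLabelSelector(target_latency, latency_calculation_function, selector_cb=make_tolerance_callback()).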

class OptimalPruningLabelSelector(target_latency, latency_calculation_function, n_search_steps, latency_penalty=300, channels_constraint=None, n_jobs=1, show_progress_bar=False, verbose=True, seed=None)
__init__(target_latency, latency_calculation_function, n_search_steps, latency_penalty=300, channels_constraint=None, n_jobs=1, show_progress_bar=False, verbose=True, seed=None)
Parameters:
  • target_latency (float) – Target model latency, in the same units as the output of latency_calculation_function.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take a model (torch.nn.Module) and return the “speed” of its execution as a float. This could be the number of FLOPs or MACs, the inference time on CPU/GPU, or another “speed” criterion.

  • n_search_steps (int) – Number of pruning configurations (samples) to evaluate when searching for the optimal architecture.

  • latency_penalty (float) – Weight of the latency in the objective function. Default value is 300.

  • channels_constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that computes the low, high and step values of the channel constraint from the total number of channels in a gate. If None, DefaultChannelsSelectorConstraint is used. Default value is None.

  • n_jobs (int) – Number of parallel jobs. Default value is 1.

  • show_progress_bar (bool) – Whether to show a progress bar. Default value is False.

  • verbose (bool) – If True, the default logger callback also prints its messages to stdout. Default value is True.

  • seed (Optional[int]) – Optional seed for the sampler. Default value is None.

class PruningLabelSelector

Base class for all pruning label selectors.

This class defines an abstract method PruningLabelSelector.select() that should return labels to prune.

static extract_labels_info(pruning_info, sort_labels=True, drop_shared_gates=True)

Extract labels information from pruning_info.

Parameters:
  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).

  • sort_labels (bool) – Whether to sort labels within each group in ascending order of the pruning criterion (score) (sort_labels=True) or leave them unsorted (sort_labels=False). Default value is True.

  • drop_shared_gates (bool) – Whether to drop label groups of repeated (shared) gates (drop_shared_gates=True) or keep them in the output as is (drop_shared_gates=False). Default value is True.

Returns:

A tuple whose first element contains labels grouped by gates and whose second element maps each label to its pruning criterion value (score).

Return type:

Tuple[List[Tuple[Label, …]], Dict[Label, float]]

static get_most_important_labels(sorted_labels_by_gates)

Returns the set of the most important labels, i.e. the label with the highest score in every gate.

Parameters:

sorted_labels_by_gates (List[Tuple[Label, ...]]) – Labels grouped by gates and sorted by score (ascending order) within every gate group.

Returns:

Set of the most important labels.

Return type:

Set[Label]
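Because each gate's labels are sorted in ascending order of score, the most important label of a gate is simply the last element of its tuple. A plain-Python equivalent of that rule looks like this; integer labels stand in for Label objects, and skipping empty groups is an assumption of the sketch.

```python
from typing import Sequence, Set, Tuple


def most_important_labels(sorted_labels_by_gates: Sequence[Tuple[int, ...]]) -> Set[int]:
    """Return the highest-scoring label of every ascending-sorted gate group."""
    # The last element of an ascending-sorted group has the highest score.
    return {labels[-1] for labels in sorted_labels_by_gates if labels}
```

One plausible use in a custom selector is to exclude these labels from the prune list, so that no gate loses all of its channels.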

abstract select(model, pruning_info)

Method that chooses which labels should be pruned based on current label selector policy.

Warning

Depending on label selector implementation, this function may have significant execution time.

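Parameters:
  • model (torch.nn.Module) – Model to prune.

  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
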
Returns:

List of channel labels which should be pruned.

Return type:

list of Label

class ScorePruningLabelSelector

Bases: PruningLabelSelector

Base class for pruning label selectors that select labels based on the output of PruningLabelSelector.extract_labels_info().

This class defines an abstract method _select() that takes the output of PruningLabelSelector.extract_labels_info() and should return a list of labels to prune.

abstract _select(labels_by_gates, labels_score)

select() implementation.

Parameters:
  • labels_by_gates (List[Tuple[Label, ...]]) – Labels grouped by gates.

  • labels_score (Dict[Label, float]) – Mapping of labels and pruning criteria value (score).

Returns:

List of labels to prune.

Return type:

List[Label]

select(model, pruning_info)

Method that chooses which labels should be pruned based on current label selector policy.

Warning

Depending on label selector implementation, this function may have significant execution time.

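Parameters:
  • model (torch.nn.Module) – Model to prune.

  • pruning_info (ModelPruningInfo) – Pruning-related information collected during calibration stage (see more in calibrate_model_for_pruning()).
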
Returns:

List of channel labels which should be pruned.

Return type:

list of Label

Here is an example of a simple label selector that selects only even labels:

import itertools
from typing import List

import torch

from enot.pruning import ModelPruningInfo
from enot.pruning import PruningLabelSelector


class EvenPruningLabelSelector(PruningLabelSelector):
    def select(self, model: torch.nn.Module, pruning_info: ModelPruningInfo) -> List:
        labels_by_gates, labels_by_score = self.extract_labels_info(pruning_info)
        # This selector needs neither the model nor the label scores.
        del model, labels_by_score

        # Flatten the per-gate label tuples into one sequence (works for an empty list too).
        labels = itertools.chain.from_iterable(labels_by_gates)
        # Keep the labels whose global channel index (label.label) is even.
        return [label for label in labels if label.label % 2 == 0]

Channels Selector Constraint

class ChannelsSelectorConstraint

Base class for all channel selection constraints.

To implement a new channel selection constraint, implement _get().

abstract _get(number_of_channels)

Implementation of get(). It should return the low, high and step values of the constraint for the given total number of channels in a gate.

  • low - minimal number of channels that should be pruned

  • high - maximal number of channels allowed to prune

  • step - step of discretization

Parameters:

number_of_channels (int) – Total number of channels in gate.

Returns:

Tuple of low, high and step for the constraint.

Return type:

Tuple[int, int, int]

apply(prune_n, number_of_channels)

Apply channels selector constraint.

Parameters:
  • prune_n (int) – Desired number of channels to prune.

  • number_of_channels (int) – Total number of channels in gate.

Returns:

Number of channels to prune according to constraint.

Return type:

int
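The documentation above does not spell out how apply() combines prune_n with the (low, high, step) triple. One natural reading, shown here as a hypothetical sketch rather than the library's implementation, is to clamp the request to [low, high] and round it down to the step grid that starts at low; the function apply_constraint is invented for this illustration.

```python
from typing import Tuple


def apply_constraint(prune_n: int, constraint: Tuple[int, int, int]) -> int:
    """Clamp prune_n to [low, high] and snap it down to the step grid starting at low."""
    low, high, step = constraint
    clamped = min(max(prune_n, low), high)
    return low + (clamped - low) // step * step
```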

get(number_of_channels)

Calculates the low, high and step values of the constraint for the given total number of channels in a gate.

  • low - minimal number of channels that should be pruned

  • high - maximal number of channels allowed to prune

  • step - a step of discretization

Parameters:

number_of_channels (int) – Total number of channels in gate.

Returns:

Tuple of low, high and step for the constraint.

Return type:

Tuple[int, int, int]

class DefaultChannelsSelectorConstraint(absolute_min_channels=16, relative_min_channels=0.25, fixed_step=None)

Bases: ChannelsSelectorConstraint

Default channels selector constraint.

This class sets two lower bounds on the number of channels kept in each layer: an absolute one and a relative one. The absolute bound is a minimal number of channels, and the relative bound is a fraction of the number of channels in the layer. The larger of the two is used.

You can also specify a fixed value for the discretization step. If no fixed step is specified, the step is chosen as follows:

  • get() returns step equal to 1 if channel group has size less than or equal to 64

  • get() returns step equal to 4 if channel group has size greater than 64

__init__(absolute_min_channels=16, relative_min_channels=0.25, fixed_step=None)
Parameters:
  • absolute_min_channels (int) – Absolute lower bound on the number of channels kept. If it is larger than number_of_channels, all number_of_channels channels are kept. Default value is 16.

  • relative_min_channels (float) – Relative lower bound, as a fraction of number_of_channels. Default value is 0.25.

  • fixed_step (Optional[int]) – Optional fixed value for the discretization step. Default value is None.

Here is the implementation of the _get() method of DefaultChannelsSelectorConstraint:

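    def _get(self, number_of_channels: int) -> Tuple[int, int, int]:
        # Minimal number of channels to keep: the larger of the absolute and relative bounds.
        high = max(self._absolute_min_channels, int(self._relative_min_channels * number_of_channels))
        # Maximal number of channels that may be pruned, never negative.
        high = number_of_channels - high
        high = max(0, high)

        # Discretization step: fixed if given, otherwise 4 for groups wider than 64 and 1 otherwise.
        if self._fixed_step is None:
            step = 4 if number_of_channels > 64 else 1
        else:
            step = self._fixed_step

        # low is always 0, so pruning nothing is always allowed.
        return 0, high, step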
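To see what these bounds mean in practice, the method body above can be lifted into a standalone function and evaluated for a few layer widths. default_get is a hypothetical copy made only for this illustration; it uses the documented defaults (absolute_min_channels=16, relative_min_channels=0.25, fixed_step=None).

```python
from typing import Optional, Tuple


def default_get(
    number_of_channels: int,
    absolute_min_channels: int = 16,
    relative_min_channels: float = 0.25,
    fixed_step: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Standalone copy of DefaultChannelsSelectorConstraint._get for experimentation."""
    min_kept = max(absolute_min_channels, int(relative_min_channels * number_of_channels))
    high = max(0, number_of_channels - min_kept)  # at most this many channels may be pruned
    step = (4 if number_of_channels > 64 else 1) if fixed_step is None else fixed_step
    return 0, high, step


for width in (960, 64, 32, 10):
    print(width, default_get(width))
# 960 (0, 720, 4)
# 64 (0, 48, 1)
# 32 (0, 16, 1)
# 10 (0, 0, 1)
```

So with the defaults a 960-channel layer can lose at most 720 channels, with a step of 4, while a 10-channel layer is never pruned because the absolute bound of 16 exceeds its width.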

Tools

Below is a set of tools that help inspect the pruning structure of a model.

Label printing conventions:

  • [--]:960:(75436, 76395) — group of 960 prunable labels with IDs from 75436 to 76395

  • [N-]:448:(3, 450) — group of 448 non-prunable labels with IDs from 3 to 450

The second character (- in both examples) indicates whether the labels in the group are sequential.
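The printed format can be decoded mechanically. The hypothetical helper parse_label_group below parses one group string with a regular expression and lets you check that the ID range matches the group size, which holds for both examples above (76395 - 75436 + 1 = 960 and 450 - 3 + 1 = 448). Treating N as non-prunable and - as prunable in the first position follows the two examples; any other characters the tools might print are not covered by this sketch.

```python
import re
from typing import NamedTuple

_GROUP_RE = re.compile(
    r"\[(?P<prunable>[N-])(?P<seq>.)\]:(?P<size>\d+):\((?P<first>\d+), (?P<last>\d+)\)"
)


class LabelGroup(NamedTuple):
    prunable: bool
    sequence_marker: str  # second character inside the brackets
    size: int
    first_id: int
    last_id: int


def parse_label_group(text: str) -> LabelGroup:
    """Parse strings such as "[--]:960:(75436, 76395)"."""
    match = _GROUP_RE.fullmatch(text.strip())
    if match is None:
        raise ValueError(f"unrecognised label group: {text!r}")
    return LabelGroup(
        prunable=match["prunable"] == "-",
        sequence_marker=match["seq"],
        size=int(match["size"]),
        first_id=int(match["first"]),
        last_id=int(match["last"]),
    )


group = parse_label_group("[--]:960:(75436, 76395)")
print(group.prunable, group.last_id - group.first_id + 1 == group.size)  # True True
```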

tabulate_module_dependencies(module_name, pruning_info, model=None, show_op_with_weights_only=True, show_parameter_nodes=False, show_all_nodes=False)

Return a string containing all of the tracing/pruning information for the node corresponding to the module with the given name.

Parameters:
  • module_name (str) – Name of the module in torch model, for example: “features.5.conv.1.0”.

  • pruning_info (ModelPruningInfo) – Pruning state, can be obtained from calibrator.

  • model (Optional[nn.Module]) – PyTorch model used for calibration.

  • show_op_with_weights_only (bool) – Include nodes with weights (modules) only. Default value is True.

  • show_parameter_nodes (bool) – Whether to include Parameter nodes. Default value is False.

  • show_all_nodes (bool) – Whether to include nodes of all types. Default value is False.

Returns:

Tracing/pruning information for the node corresponding to the module with the given name.

Return type:

str

An example of usage:

print(tabulate_module_dependencies(module_name='features.17.conv.1.0', pruning_info=pruning_info, model=model))
+-------------------+-----------------------------------------------+
| module name       | features.17.conv.1.0                          |
+-------------------+-----------------------------------------------+
| node name         | operation_node_conv2d_456                     |
+-------------------+-----------------------------------------------+
| node operation    | conv2d                                        |
+-------------------+-----------------------------------------------+
| prunable groups   | 1                                             |
+-------------------+-----------------------------------------------+
| unprunable groups | 0                                             |
+-------------------+-----------------------------------------------+
| node labels       | [None, '[--]:960:(75436, 76395)', None, None] |
+-------------------+-----------------------------------------------+
Prunable group 0 with labels [--]:960:(75436, 76395)
+----------------------+---------------------------------------------------------------------------------------------+
| features.17.conv.1.0 | Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False) |
+----------------------+---------------------------------------------------------------------------------------------+
| features.17.conv.0.0 | Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)                             |
+----------------------+---------------------------------------------------------------------------------------------+
Prunable group 0 users:
+--------------------+-----------------------------------------------------------------+
| features.17.conv.2 | Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False) |
+--------------------+-----------------------------------------------------------------+
tabulate_unprunable_groups(pruning_info, model=None, show_op_with_weights_only=True, show_parameter_nodes=False, show_all_nodes=False)

Return a string containing all of the tracing/pruning information for all unprunable groups.

Parameters:
  • pruning_info (ModelPruningInfo) – Pruning state; can be obtained from the calibrator.

  • model (Optional[nn.Module]) – PyTorch model used for calibration.

  • show_op_with_weights_only (bool) – Include only nodes with weights (modules). Default value is True.

  • show_parameter_nodes (bool) – Whether to include Parameter nodes. Default value is False.

  • show_all_nodes (bool) – Whether to include all types of nodes. Default value is False.

Returns:

Tracing/pruning information for all unprunable groups.

Return type:

str

An example of usage:

print(tabulate_unprunable_groups(pruning_info=pruning_info, model=model, show_all_nodes=True))
Unprunable groups:
Unprunable group 0 with labels [N-]:3:(0, 2)
+----------------------+---------------------+
| parameter_node_199   | features.0.0.weight |
+----------------------+---------------------+
| placeholder_node_314 |                     |
+----------------------+---------------------+
Unprunable group 0 users:
+--------------+------------------------------------------------------------------------------+
| features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) |
+--------------+------------------------------------------------------------------------------+
Unprunable group 1 with labels [N-]:448:(3, 450)
+----------------------+--+
| placeholder_node_314 |  |
+----------------------+--+
Unprunable group 1 users:
+--------------+------------------------------------------------------------------------------+
| features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) |
+--------------+------------------------------------------------------------------------------+
Unprunable group 2 with labels [N-]:6:(486, 491)
+--------------------+---------------------+
| parameter_node_199 | features.0.0.weight |
+--------------------+---------------------+
Unprunable group 2 users:
+--------------+------------------------------------------------------------------------------+
| features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) |
+--------------+------------------------------------------------------------------------------+
...
tabulate_label_dependencies(label, pruning_info, model=None, show_op_with_weights_only=True, show_parameter_nodes=False, show_all_nodes=False)

Return a string containing all of the tracing/pruning information for the label.

Parameters:
  • label (Union[Label, int]) – Label for which information is to be printed.

  • pruning_info (ModelPruningInfo) – Pruning state; can be obtained from the calibrator.

  • model (Optional[nn.Module]) – PyTorch model used for calibration.

  • show_op_with_weights_only (bool) – Include only nodes with weights (modules). Default value is True.

  • show_parameter_nodes (bool) – Whether to include Parameter nodes. Default value is False.

  • show_all_nodes (bool) – Whether to include all types of nodes. Default value is False.

Returns:

Tracing/pruning information for the label.

Return type:

str

An example of usage:

print(tabulate_label_dependencies(label=88964, pruning_info=pruning_info, model=model, show_all_nodes=True))
label: 88964
+---------------------+------------------------------------------------------------------+
| parameter_node_94   | features.18.0.weight                                             |
+---------------------+------------------------------------------------------------------+
| features.18.0       | Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False) |
+---------------------+------------------------------------------------------------------+
| parameter_node_197  | features.18.1.running_mean                                       |
+---------------------+------------------------------------------------------------------+
| parameter_node_151  | features.18.1.running_var                                        |
+---------------------+------------------------------------------------------------------+
| parameter_node_81   | features.18.1.weight                                             |
+---------------------+------------------------------------------------------------------+
| parameter_node_113  | features.18.1.bias                                               |
+---------------------+------------------------------------------------------------------+
| batch_norm          |                                                                  |
+---------------------+------------------------------------------------------------------+
| hardtanh            |                                                                  |
+---------------------+------------------------------------------------------------------+
| adaptive_avg_pool2d |                                                                  |
+---------------------+------------------------------------------------------------------+
| flatten             |                                                                  |
+---------------------+------------------------------------------------------------------+
| dropout             |                                                                  |
+---------------------+------------------------------------------------------------------+
| parameter_node_42   | classifier.1.weight                                              |
+---------------------+------------------------------------------------------------------+
tabulate_label_filters(label, pruning_info, model)

Return a string containing information about all filters associated with the label.

Parameters:
  • label (Union[Label, int]) – The label whose filters should be found.

  • pruning_info (ModelPruningInfo) – Pruning state; can be obtained from the calibrator.

  • model (nn.Module) – PyTorch model used for calibration.

Returns:

Information about the label's filters.

Return type:

str

An example of usage:

print(tabulate_label_filters(label=88964, pruning_info=pruning_info, model=model))
label: 88964
+----------------------------+----------------+---------------------------------+
| tensor name                | shape          | mapping                         |
+============================+================+=================================+
| features.18.0.weight       | (1, 320, 1, 1) | [[(1202, 0)], None, None, None] |
+----------------------------+----------------+---------------------------------+
| features.18.1.running_mean | (1,)           | [[(1202, 0)]]                   |
+----------------------------+----------------+---------------------------------+
| features.18.1.running_var  | (1,)           | [[(1202, 0)]]                   |
+----------------------------+----------------+---------------------------------+
| features.18.1.weight       | (1,)           | [[(1202, 0)]]                   |
+----------------------------+----------------+---------------------------------+
| features.18.1.bias         | (1,)           | [[(1202, 0)]]                   |
+----------------------------+----------------+---------------------------------+
| classifier.1.weight        | (10, 1)        | [None, [(1202, 0)]]             |
+----------------------------+----------------+---------------------------------+
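Pruning a label removes every filter listed for it, so the shape column shows how many values depend on that single channel. The sketch below assumes that exactly these filters are removed and copies the shapes from the table above. It sums them for label 88964. Note that running_mean and running_var are BatchNorm buffers, not trainable parameters.

```python
from math import prod
from typing import Dict, Tuple

# Filter shapes for label 88964, copied from the tabulate_label_filters output above.
FILTER_SHAPES: Dict[str, Tuple[int, ...]] = {
    "features.18.0.weight": (1, 320, 1, 1),
    "features.18.1.running_mean": (1,),
    "features.18.1.running_var": (1,),
    "features.18.1.weight": (1,),
    "features.18.1.bias": (1,),
    "classifier.1.weight": (10, 1),
}


def count_filter_elements(shapes: Dict[str, Tuple[int, ...]]) -> Dict[str, int]:
    # Number of scalar values each filter occupies in its tensor.
    return {name: prod(shape) for name, shape in shapes.items()}


per_tensor = count_filter_elements(FILTER_SHAPES)
total = sum(per_tensor.values())
print(total)  # 334 = 320 (conv) + 4 (BatchNorm) + 10 (classifier)
```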
get_label_filters(label, pruning_info, model, *, with_tensors=False)

Return the filters corresponding to the given label.

Parameters:
  • label (Union[Label, int]) – The label whose filters should be found.

  • pruning_info (ModelPruningInfo) – Pruning state; can be obtained from the calibrator.

  • model (nn.Module) – PyTorch model used for calibration.

  • with_tensors (bool) – Whether to include filter values (tensors). Default value is False.

Returns:

List of containers with the tensor name, filter shape, mapping, and (optionally) the filter values.

Return type:

List[LabelFilterInfo]

An example of usage:

for filter_ in get_label_filters(label=88964, pruning_info=pruning_info, model=model):
    print(filter_)
LabelFilterInfo(tensor_name='features.18.0.weight', shape=(1, 320, 1, 1), mapping=[[(1202, 0)], None, None, None], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.running_mean', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.running_var', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.weight', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='features.18.1.bias', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
LabelFilterInfo(tensor_name='classifier.1.weight', shape=(10, 1), mapping=[None, [(1202, 0)]], tensor=None)

Here, shape is the shape of the filter, and mapping can be used to extract the filter from the full tensor.
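The mapping format is not specified above, so the sketch below rests on an assumption. It treats each mapping as one entry per tensor dimension: None keeps the whole dimension, and a list of (tensor_index, filter_index) pairs selects slices. This reading is consistent with the example output. Under it, [[(1202, 0)], None, None, None] applied to the (1280, 320, 1, 1) weight of features.18.0 yields the documented filter shape (1, 320, 1, 1). The function operates on nested Python lists and is illustrative only.

```python
from typing import Any, List, Optional, Sequence, Tuple

# Assumed mapping format: one entry per tensor dimension, where None keeps the
# whole dimension and a list of (tensor_index, filter_index) pairs selects slices.
DimMapping = Optional[List[Tuple[int, int]]]


def extract_filter(tensor: Any, mapping: Sequence[DimMapping]) -> Any:
    """Extract a filter from a tensor stored as nested Python lists."""
    if not mapping:
        return tensor  # all mapped dimensions handled
    dim_mapping, rest = mapping[0], mapping[1:]
    if dim_mapping is None:
        # Keep every slice along this dimension.
        return [extract_filter(sub, rest) for sub in tensor]
    size = max(dst for _, dst in dim_mapping) + 1
    result: List[Any] = [None] * size
    for src, dst in dim_mapping:
        # Slice src of the tensor lands at position dst of the filter.
        result[dst] = extract_filter(tensor[src], rest)
    return result


# Toy 3x4 "classifier weight"; the mapping [None, [(2, 0)]] selects column 2.
weight = [[10 * row + col for col in range(4)] for row in range(3)]
print(extract_filter(weight, [None, [(2, 0)]]))  # [[2], [12], [22]]
```

With real torch tensors, the same kind of selection can be done with Tensor.index_select along each mapped dimension. For example, classifier.1.weight.index_select(1, torch.tensor([1202])) would select a single column.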

get_tracing_node_by_module_name(name, graph)

Return the tracing node corresponding to the module with the given name.

Note: the current implementation works ONLY for modules that have a “weight” parameter.

Parameters:
  • name (str) – Name of the module in the PyTorch model, for example: “features.5.conv.1.0”.

  • graph (List[Node]) – Tracing graph.

Returns:

Node (OpNode) corresponding to the module with the given name.

Return type:

OpNode

Raises:

ValueError – If no node corresponding to the module name is found.

get_module_name_by_tracing_node(node)

Return the module name corresponding to the node.

Note: the current implementation works ONLY for modules that have a “weight” parameter. You can also easily get a module by its name with getattr_complex.

Parameters:

node (OpNode) – The node for which the module name should be found.

Returns:

Module name corresponding to the node.

Return type:

str

Raises:

ValueError – If no module name corresponding to the node is found.
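Once you have a module name, resolving it to the module object means following the dotted path one attribute at a time. The helper below, get_by_dotted_name, is a hypothetical minimal stand-in for the idea behind getattr_complex and is not ENOT's implementation. It works on nn.Module because PyTorch exposes child modules (including the numbered children of nn.Sequential) as attributes. Recent PyTorch versions also provide nn.Module.get_submodule, which accepts the same dotted names.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any


def get_by_dotted_name(root: Any, name: str) -> Any:
    """Follow a dotted path such as "features.5.conv.1.0" one attribute at a time.

    Hypothetical helper; raises AttributeError if any path component is missing.
    """
    return reduce(getattr, name.split("."), root)


# Toy object tree standing in for a model; "5" mimics an nn.Sequential child name.
model = SimpleNamespace(features=SimpleNamespace(**{"5": SimpleNamespace(conv="conv-module")}))
print(get_by_dotted_name(model, "features.5.conv"))  # conv-module
```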

Module replacement

Some modules of your model may be non-traceable by ENOT Framework (e.g., nn.MultiheadAttention from PyTorch) and therefore non-prunable. This package provides utilities for replacing non-prunable modules with prunable ENOT modules and vice versa.

replace_prunable_modules(original_model, inplace=True)

Replace non-prunable modules with their prunable ENOT versions.

Parameters:
  • original_model (nn.Module) – PyTorch model in which non-prunable modules are replaced with prunable ones.

  • inplace (bool) – Whether to modify the model itself without creating a copy.

Returns:

model – PyTorch model with substituted modules.

Return type:

nn.Module

revert_modules_replacement(original_model, inplace=True)

Revert the module substitution, restoring the original modules.

Parameters:
  • original_model (nn.Module) – PyTorch model in which prunable ENOT modules are replaced with the original ones.

  • inplace (bool) – Whether to modify the model itself without creating a copy.

Returns:

model – PyTorch model with original modules.

Return type:

nn.Module
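To picture the replace/revert round trip and the inplace flag, here is a toy illustration of recursive module substitution. ToyModule, swap_modules, and both substitution tables are hypothetical. ToyModule stands in for nn.Module, and ENOT's actual substitution rules and prunable module classes are not shown. With inplace=False the original tree is deep-copied first and stays untouched. With inplace=True the given tree is modified and returned.

```python
import copy
from typing import Callable, Dict, Optional


class ToyModule:
    """Toy stand-in for nn.Module: a type name plus named children."""

    def __init__(self, kind: str, children: Optional[Dict[str, "ToyModule"]] = None):
        self.kind = kind
        self.children: Dict[str, "ToyModule"] = dict(children or {})


Factory = Callable[[ToyModule], ToyModule]


def swap_modules(model: ToyModule, table: Dict[str, Factory], inplace: bool = True) -> ToyModule:
    """Recursively replace children whose kind appears in `table`."""
    target = model if inplace else copy.deepcopy(model)

    def visit(parent: ToyModule) -> None:
        for name, child in list(parent.children.items()):
            factory = table.get(child.kind)
            if factory is not None:
                parent.children[name] = factory(child)  # swap, keeping the same name
            else:
                visit(child)  # descend into untouched children

    visit(target)
    return target


# Hypothetical forward and reverse tables standing in for ENOT's substitution rules.
TO_PRUNABLE = {"MultiheadAttention": lambda m: ToyModule("PrunableMultiheadAttention", m.children)}
TO_ORIGINAL = {"PrunableMultiheadAttention": lambda m: ToyModule("MultiheadAttention", m.children)}

model = ToyModule("Net", {"attn": ToyModule("MultiheadAttention"), "head": ToyModule("Linear")})
prunable = swap_modules(model, TO_PRUNABLE, inplace=False)
print(model.children["attn"].kind, prunable.children["attn"].kind)
# MultiheadAttention PrunableMultiheadAttention
restored = swap_modules(prunable, TO_ORIGINAL)
print(restored.children["attn"].kind)  # MultiheadAttention
```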