.. _Pruning:

#######
Pruning
#######

The ``enot.pruning`` module contains functionality for automatic pruning of user models.
The pruning engine works with models from many frameworks, such as Torchvision, PyTorch Image Models (TIMM), OpenMMLab, and others.

This package features pre-defined and custom pruning procedures for removing the least important filters or neurons.
The ENOT pruning module currently supports structured pruning.
The user can define the pruning ratio (the fraction of channels removed) manually or use one of the pre-defined strategies.

The first (and the simplest) is the `equal pruning`_ strategy, which keeps roughly the same percentage of channels in each prunable layer.
The second (and quite compute-intensive) is the `optimal pruning`_ strategy, which searches for the optimal pruning ratio for each prunable layer.

A few definitions to simplify reading this documentation:

* **Pruning ratio** is the fraction of channels to remove (prune) in a prunable group or prunable layer.
* **Gate** (or channel gate) is a special ``torch.nn.Module`` which gathers and saves channel "local importance".
  "Local importance" means that you can compare channels by their importance values within a single layer, but not between distinct layers.
* **Calibration** procedure estimates channel importance values according to the user model and data.
* **Channel label** is the global channel index in a network.
* **Pruning config** specifies the pruning amount for each prunable layer. See more :ref:`here `.

***********
Pruning API
***********

The pruning functionality accessible to the user is divided into two sections:

* `High-level interface`_: the easiest way to use the pruning functionality.
* `Low-level interface`_: low-level API for pruning, takes moderate time to understand and use.

.. _pruning high-level interface:

********************
High-level interface
********************

The high-level interface provides two ways to prune a model.
They work in the same manner, but have different policies for the amount of pruning for each layer:

* :func:`~enot.pruning.calibrate_and_prune_model_equal`
* :func:`~enot.pruning.calibrate_and_prune_model_optimal`

.. _equal pruning:

.. autofunction:: enot.pruning.calibrate_and_prune_model_equal

.. note::

    Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
    For example, it is your responsibility to call the ``eval`` method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

    If you encounter errors during the backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_equal(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced.
    Simple examples: *forward*, *execute_model*, *forward_train*.
    If such a function is located in the model's submodule, then you should first write the submodule's name followed by the function name, all separated by dots:
    *submodule.forward*, *head.predict_features*, *submodule1.submodule2.forward*.

.. _optimal pruning:

.. autofunction:: enot.pruning.calibrate_and_prune_model_optimal

.. _pruning low-level interface:

*******************
Low-level interface
*******************

The low-level interface consists of:

- :ref:`Pruning information (state) `
- :ref:`Calibration `
- :ref:`Model pruning `
- :ref:`Label selector `
- :ref:`Channels selector constraint `

The main class for pruning is :class:`.ModelPruningInfo`.
It contains the information necessary for model pruning: the model execution graph, the list of all prunable and non-prunable channels, all channel importance values, etc.
Model pruning info is obtained through the calibration process.
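
To make the terms "pruning ratio" and "local importance" concrete, here is a minimal, library-independent sketch of score-based selection: channels are ranked by importance within each layer separately, and the least important ones are picked according to that layer's pruning ratio.
It does not use the ENOT API; the function ``select_least_important``, the layer names, and the toy scores are all hypothetical and exist only for illustration.

.. code-block:: python

    from typing import Dict, List, Tuple

    # Hypothetical data layout: (global channel label, local importance) pairs per layer.
    LayerScores = List[Tuple[int, float]]


    def select_least_important(
        scores_by_layer: Dict[str, LayerScores],
        ratio_by_layer: Dict[str, float],
    ) -> List[int]:
        """Return labels of the least important channels of every layer."""
        labels_to_prune: List[int] = []
        for layer_name, layer_scores in scores_by_layer.items():
            ratio = ratio_by_layer[layer_name]
            # Number of channels to remove; always keep at least one channel per layer.
            n_prune = max(0, min(int(len(layer_scores) * ratio), len(layer_scores) - 1))
            # Importance is only comparable within a layer, so each layer is ranked separately.
            ranked = sorted(layer_scores, key=lambda pair: pair[1])
            labels_to_prune.extend(label for label, _ in ranked[:n_prune])
        return labels_to_prune


    # Toy example: two layers with four and two channels.
    scores = {
        "conv1": [(0, 0.9), (1, 0.1), (2, 0.5), (3, 0.3)],
        "conv2": [(4, 2.0), (5, 7.0)],
    }
    ratios = {"conv1": 0.5, "conv2": 0.5}
    pruned = select_least_important(scores, ratios)
    print(pruned)  # [1, 3, 4]

Note that the scores of ``conv2`` are larger than those of ``conv1`` in absolute terms, yet one of its channels is still selected: the ranking never crosses layer boundaries.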
Calibration for pruning is performed by the :class:`.PruningCalibrator` class.
An instance of this class should be used as a context manager.
Inside its context, the user should compute losses and perform backward passes through the model on calibration data.
Alternatively, the user can calibrate the model by calling the :func:`.calibrate_model_for_pruning` function and providing all arguments necessary for proper model execution and loss computation.

Actual model pruning is performed by the :func:`.prune_model` function, which requires a pruning info object and a list of channel labels to prune.

To select which labels should be pruned and which should not, we created a special interface for label selection.
The base class for all label selectors is :class:`.PruningLabelSelector`.
This class defines an abstract method :meth:`.PruningLabelSelector.select`, which selects the labels to prune; subclasses should implement it to fit pipelines based on :func:`.calibrate_and_prune_model`.

The abstract class :class:`.ScorePruningLabelSelector` is a straightforward way to select the least important channel labels.
It uses :ref:`pruning config` to select the least important channels in each prunable layer.
This is done by sorting channels by their importance values and selecting the K least important channels, where K follows from the corresponding pruning ratio in the pruning config.

These two label selectors are used internally in the high-level interface for pruning described :ref:`above `:

* :class:`.UniformPruningLabelSelector`
* :class:`.OptimalPruningLabelSelector`

.. _pruning pruning info:

============
Pruning info
============

.. autoclass:: enot.pruning.ModelPruningInfo
    :members: summary, n_prunable_groups

.. _pruning calibration:

===========
Calibration
===========

.. autoclass:: enot.pruning.PruningCalibrator
    :members: __init__, pruning_info

.. note::

    ``entry_point`` is a string that specifies which function of the user's model will be traced.
    Simple examples: *forward*, *execute_model*, *forward_train*.
    If such a function is located in the model's submodule, then you should first write the submodule's name followed by the function name, all separated by dots:
    *submodule.forward*, *head.predict_features*, *submodule1.submodule2.forward*.

.. autofunction:: enot.pruning.calibrate_model_for_pruning

.. _pruning model pruning:

=============
Model pruning
=============

.. autofunction:: enot.pruning.prune_model

The following function combines the functionality of the :func:`enot.pruning.calibrate_model_for_pruning` and :func:`enot.pruning.prune_model` functions:

.. autofunction:: enot.pruning.calibrate_and_prune_model

.. note::

    Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
    For example, it is your responsibility to call the ``eval`` method of your model if your inference requires calling this method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate the model for pruning on validation-like data without augmentations (but with inference input preprocessing).

    If you encounter errors during the backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced.
    Simple examples: *forward*, *execute_model*, *forward_train*.
    If such a function is located in the model's submodule, then you should first write the submodule's name followed by the function name, all separated by dots:
    *submodule.forward*, *head.predict_features*, *submodule1.submodule2.forward*.

.. _pruning label selector:

==============
Label selector
==============

To select which labels should be pruned, use one of the following selectors:

* :class:`~enot.pruning.UniformPruningLabelSelector` --- removes an equal percentage of channels in each prunable layer.
* :class:`~enot.pruning.GlobalPruningLabelSelectorByChannels` --- selects the least important channels within the whole network.
* :class:`~enot.pruning.GlobalLatencyPruningLabelSelector` --- finds the model (labels) with latency as close as possible to the target latency parameter.
* :class:`~enot.pruning.KnapsackPruningLabelSelector` --- maximizes the total score of the model under a latency constraint according to the classical knapsack algorithm.
* :class:`~enot.pruning.OptimalPruningLabelSelector` --- finds the optimal pruning ratios for each prunable layer.

There are also two useful classes to help you implement your own label selector:

* :class:`~enot.pruning.PruningLabelSelector` --- base class for all label selectors.
* :class:`~enot.pruning.ScorePruningLabelSelector` --- base class for selecting the least important channel labels.

.. autoclass:: enot.pruning.UniformPruningLabelSelector
    :members: __init__
    :show-inheritance:

.. autoclass:: enot.pruning.GlobalPruningLabelSelectorByChannels
    :members: __init__
    :show-inheritance:

.. autoclass:: enot.pruning.GlobalLatencyPruningLabelSelector
    :members: __init__
    :show-inheritance:

.. autoclass:: enot.pruning.KnapsackPruningLabelSelector
    :members: __init__
    :show-inheritance:

.. autoclass:: enot.pruning.OptimalPruningLabelSelector
    :members: __init__

.. autoclass:: enot.pruning.PruningLabelSelector
    :members:

.. autoclass:: enot.pruning.ScorePruningLabelSelector
    :members:
    :private-members:
    :show-inheritance:

Here is an example of a simple label selector that selects only even labels:

.. code-block:: python

    import operator
    from functools import reduce
    from typing import List

    import torch
    from enot.pruning import ModelPruningInfo
    from enot.pruning import PruningLabelSelector


    class EvenPruningLabelSelector(PruningLabelSelector):
        def select(self, model: torch.nn.Module, pruning_info: ModelPruningInfo) -> List[int]:
            labels_by_gates, labels_by_score = self.extract_labels_info(pruning_info)
            # we don't need model and labels_by_score for our label selector
            del model, labels_by_score
            is_even = lambda label: label.label % 2 == 0
            # flatten the per-gate label lists into a single list
            labels = reduce(operator.concat, labels_by_gates)
            even_labels = [*filter(is_even, labels)]
            return even_labels

.. _pruning channels selector constraint:

============================
Channels Selector Constraint
============================

.. autoclass:: enot.pruning.ChannelsSelectorConstraint
    :members:
    :private-members:

.. autoclass:: enot.pruning.DefaultChannelsSelectorConstraint
    :members:
    :show-inheritance:

Here is the implementation of the :meth:`~enot.pruning.ChannelsSelectorConstraint._get` method for :class:`~enot.pruning.DefaultChannelsSelectorConstraint`:

.. literalinclude:: ../../../src/enot/pruning/label_selector/channel_selector_constraint.py
    :pyobject: DefaultChannelsSelectorConstraint._get

*****
Tools
*****

Below is a set of tools that help inspect the pruning structure.

Label printing conventions:

- ``[--]:960:(75436, 76395)`` --- group of 960 prunable labels with IDs from 75436 to 76395
- ``[N-]:448:(3, 450)`` --- group of 448 non-prunable labels with IDs from 3 to 450

The second symbol (``-`` in both examples) indicates whether the labels in the group are sequential or not.

.. autofunction:: enot.pruning.tools.tabulate_module_dependencies

An example of usage:

.. container:: toggle, toggle-hidden

    .. code-block:: python

        print(tabulate_module_dependencies(module_name='features.17.conv.1.0', pruning_info=pruning_info, model=model))

    .. 
code-block:: text +-------------------+-----------------------------------------------+ | module name | features.17.conv.1.0 | +-------------------+-----------------------------------------------+ | node name | operation_node_conv2d_456 | +-------------------+-----------------------------------------------+ | node operation | conv2d | +-------------------+-----------------------------------------------+ | prunable groups | 1 | +-------------------+-----------------------------------------------+ | unprunable groups | 0 | +-------------------+-----------------------------------------------+ | node labels | [None, '[--]:960:(75436, 76395)', None, None] | +-------------------+-----------------------------------------------+ Prunable group 0 with labels [--]:960:(75436, 76395) +----------------------+---------------------------------------------------------------------------------------------+ | features.17.conv.1.0 | Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False) | +----------------------+---------------------------------------------------------------------------------------------+ | features.17.conv.0.0 | Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False) | +----------------------+---------------------------------------------------------------------------------------------+ Prunable group 0 users: +--------------------+-----------------------------------------------------------------+ | features.17.conv.2 | Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False) | +--------------------+-----------------------------------------------------------------+ .. autofunction:: enot.pruning.tools.tabulate_unprunable_groups An example of usage: .. container:: toggle, toggle-hidden .. code-block:: python print(tabulate_unprunable_groups(pruning_info=pruning_info, model=model, show_all_nodes=True)) .. 
code-block:: text Unprunable groups: Unprunable group 0 with labels [N-]:3:(0, 2) +----------------------+---------------------+ | parameter_node_199 | features.0.0.weight | +----------------------+---------------------+ | placeholder_node_314 | | +----------------------+---------------------+ Unprunable group 0 users: +--------------+------------------------------------------------------------------------------+ | features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) | +--------------+------------------------------------------------------------------------------+ Unprunable group 1 with labels [N-]:448:(3, 450) +----------------------+--+ | placeholder_node_314 | | +----------------------+--+ Unprunable group 1 users: +--------------+------------------------------------------------------------------------------+ | features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) | +--------------+------------------------------------------------------------------------------+ Unprunable group 2 with labels [N-]:6:(486, 491) +--------------------+---------------------+ | parameter_node_199 | features.0.0.weight | +--------------------+---------------------+ Unprunable group 2 users: +--------------+------------------------------------------------------------------------------+ | features.0.0 | Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) | +--------------+------------------------------------------------------------------------------+ ... .. autofunction:: enot.pruning.tools.tabulate_label_dependencies An example of usage: .. container:: toggle, toggle-hidden .. code-block:: python print(tabulate_label_dependencies(label=88964, pruning_info=pruning_info, model=model, show_all_nodes=True)) .. 
code-block:: text label: 88964 +---------------------+------------------------------------------------------------------+ | parameter_node_94 | features.18.0.weight | +---------------------+------------------------------------------------------------------+ | features.18.0 | Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False) | +---------------------+------------------------------------------------------------------+ | parameter_node_197 | features.18.1.running_mean | +---------------------+------------------------------------------------------------------+ | parameter_node_151 | features.18.1.running_var | +---------------------+------------------------------------------------------------------+ | parameter_node_81 | features.18.1.weight | +---------------------+------------------------------------------------------------------+ | parameter_node_113 | features.18.1.bias | +---------------------+------------------------------------------------------------------+ | batch_norm | | +---------------------+------------------------------------------------------------------+ | hardtanh | | +---------------------+------------------------------------------------------------------+ | adaptive_avg_pool2d | | +---------------------+------------------------------------------------------------------+ | flatten | | +---------------------+------------------------------------------------------------------+ | dropout | | +---------------------+------------------------------------------------------------------+ | parameter_node_42 | classifier.1.weight | +---------------------+------------------------------------------------------------------+ .. autofunction:: enot.pruning.tools.tabulate_label_filters An example of usage: .. container:: toggle, toggle-hidden .. code-block:: python print(tabulate_label_filters(label=88964, pruning_info=pruning_info, model=model)) .. 
code-block:: text label: 88964 +----------------------------+----------------+---------------------------------+ | tensor name | shape | mapping | +============================+================+=================================+ | features.18.0.weight | (1, 320, 1, 1) | [[(1202, 0)], None, None, None] | +----------------------------+----------------+---------------------------------+ | features.18.1.running_mean | (1,) | [[(1202, 0)]] | +----------------------------+----------------+---------------------------------+ | features.18.1.running_var | (1,) | [[(1202, 0)]] | +----------------------------+----------------+---------------------------------+ | features.18.1.weight | (1,) | [[(1202, 0)]] | +----------------------------+----------------+---------------------------------+ | features.18.1.bias | (1,) | [[(1202, 0)]] | +----------------------------+----------------+---------------------------------+ | classifier.1.weight | (10, 1) | [None, [(1202, 0)]] | +----------------------------+----------------+---------------------------------+ .. autofunction:: enot.pruning.tools.get_label_filters An example of usage: .. container:: toggle, toggle-hidden .. code-block:: python for filter_ in get_label_filters(label=label, pruning_info=pruning_info, model=model): print(filter_) .. 
code-block:: text

        LabelFilterInfo(tensor_name='features.18.0.weight', shape=(1, 320, 1, 1), mapping=[[(1202, 0)], None, None, None], tensor=None)
        LabelFilterInfo(tensor_name='features.18.1.running_mean', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
        LabelFilterInfo(tensor_name='features.18.1.running_var', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
        LabelFilterInfo(tensor_name='features.18.1.weight', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
        LabelFilterInfo(tensor_name='features.18.1.bias', shape=(1,), mapping=[[(1202, 0)]], tensor=None)
        LabelFilterInfo(tensor_name='classifier.1.weight', shape=(10, 1), mapping=[None, [(1202, 0)]], tensor=None)

Here, ``shape`` is the shape of the filter, and ``mapping`` can be used to extract the desired filter from the tensor.

.. autofunction:: enot.tensor_trace.tools.get_tracing_node_by_module_name

.. autofunction:: enot.tensor_trace.tools.get_module_name_by_tracing_node

.. _s2mi ref:

***************************************************
Converting dataloader items to PyTorch model inputs
***************************************************

**Short summary:** *When ENOT needs to pass items from your dataloader to your model, it requires a special mapping function. This function will be used as follows:*

.. code-block:: python

    # Your code:
    your_model = ...
    your_dataloader = ...
    your_mapping_function = ...

    # In the ENOT framework:
    sample = next(your_dataloader)
    model_args, model_kwargs = your_mapping_function(sample)
    model_result = your_model(*model_args, **model_kwargs)

When working with the ENOT framework, it is sometimes necessary to write functions that transform dataloader items into the model input format.
These functions are later used to convert items, extracted from the dataloader with the `__next__` method, into the input format of the user's custom model.

Such a function should take a single input of type `Any`: a single item from the user's dataloader.
It should return a tuple with two elements.
The first element is a tuple of positional model arguments, which will be passed to the model's `__call__` method with the unpacking operator `*args`.
The second element is a dictionary with string keys and values of any type, which defines the model keyword arguments and will be passed to its `__call__` method with the unpacking operator `**kwargs`.

Let's see the most basic example:

.. code-block:: python

    # DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
    # Model expects: (images: torch.Tensor)

    def my_conversion_function(x):
        return (x[0], ), {}  # Single positional argument for model, no keyword arguments.

The same functionality is provided by our function :func:`.default_sample_to_model_inputs`.

Example function which changes the default value of a model's forward keyword argument:

.. code-block:: python

    # DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
    # Model expects: (images: torch.Tensor, should_profile: bool = True)

    def my_conversion_function(x):
        return (x[0], ), {'should_profile': False}  # We do not want to profile!

Example function which also performs data pre-processing and moves the data to the GPU:

.. code-block:: python

    # DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)
    # Model expects: (images: torch.Tensor)

    def my_conversion_function(x):
        # Casting images to float32, normalizing them, and moving them to cuda.
        return ((x[0].float() / 255).cuda(), ), {}

Advanced case with a complex dataloader item structure and both positional and keyword model arguments:

.. code-block:: python

    # DataLoader returns: {'sequence': torch.Tensor, 'translation': torch.Tensor, 'masks': List[torch.Tensor]}
    # Model expects: (sequence: torch.Tensor, mask: Optional[torch.Tensor] = None, unroll_attention: bool = False)

    def my_conversion_function(x):
        sequence = x['sequence']
        mask = x['masks'][0]
        return (sequence, ), {'mask': mask, 'unroll_attention': True}

.. autofunction:: enot.utils.dataloader2model.default_sample_to_model_inputs

.. _s2ns ref:

*****************************************************
Extracting the number of samples from dataloader item
*****************************************************

Sometimes ENOT needs to know the number of samples of interest in a dataloader item.
A sample of interest could be an image, one second of audio, a bounding box, etc.
What a single sample means depends entirely on you.

To provide this information to ENOT, you should define a special function.
This function will be used as follows:

.. code-block:: python

    # Your code:
    your_model = ...
    your_dataloader = ...
    your_function = ...

    # In the ENOT framework:
    sample = next(your_dataloader)
    n_base_samples: int = your_function(sample)
    # process n_base_samples

Such a function should take a single input of type `Any`: a single item from the user's dataloader.
It should return a single integer value: the number of samples of interest in the dataloader item.

Let's see the most basic example:

.. code-block:: python

    # DataLoader returns: (images: torch.Tensor, labels: torch.Tensor)

    def my_function(x):
        # Extract the number of images from the batch dimension.
        return x[0].shape[0]

The same functionality is provided by our function :func:`.default_sample_to_n_samples`.

Advanced case with a complex dataloader item structure:

.. code-block:: python

    # DataLoader returns: {'sequence': torch.Tensor, 'translation': torch.Tensor, 'mask': torch.Tensor}

    def my_conversion_function(x):
        # sequence is a tensor of shape (N, T1) with int64 dtype.
        # translation is a tensor of shape (N, T2) with int64 dtype.
        # mask is a tensor of shape (N, T1) with int64 dtype.
        # mask is equal to 1 where sequence values are correct (not padded) and 0 elsewhere.
        total_time_steps = int(x['mask'].sum())  # Convert the 0-dim tensor to a Python int.
        total_seconds = total_time_steps // 100  # Let's suppose that one time step is 10ms.
        return total_seconds  # Total audio length in seconds in all samples.

.. autofunction:: enot.utils.dataloader2model.default_sample_to_n_samples
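
If your project mixes dataloaders that yield tuples with ones that yield dictionaries, a single sample-counting function may be convenient.
The sketch below is one possible way to write it; the function name ``sample_to_n_samples`` and the ``key`` argument are hypothetical and not part of the ENOT API.
It relies only on ``len()``, which returns the size of the first (batch) dimension for ``torch.Tensor`` objects as well as for plain Python lists, so the demo runs without PyTorch.

.. code-block:: python

    from typing import Any, Mapping


    def sample_to_n_samples(sample: Any, key: str = "images") -> int:
        """Return the batch size of a dataloader item (hypothetical helper)."""
        if isinstance(sample, Mapping):
            # Dict-like item: the batch is stored under a user-chosen key.
            batch = sample[key]
        elif isinstance(sample, (tuple, list)):
            # Tuple-like item such as (images, labels): the batch comes first.
            batch = sample[0]
        else:
            # The item is the batch itself.
            batch = sample
        # len() gives the size of the first dimension for tensors and lists alike.
        return len(batch)


    # Plain lists stand in for tensors in this demo.
    tuple_item = ([[0.1], [0.2], [0.3]], [0, 1, 0])
    dict_item = {"images": [[0.1], [0.2]], "labels": [1, 0]}
    print(sample_to_n_samples(tuple_item))  # 3
    print(sample_to_n_samples(dict_item))  # 2

A function like this can be passed wherever ENOT expects a sample-counting callable, as long as "one sample" in your task really corresponds to one element of the batch dimension.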