import logging
from typing import Any
from typing import Callable
from typing import Optional
import torch
from torch.utils.data import DataLoader
from enot.pruning.calibrate import calibrate_model_for_pruning
from enot.pruning.label_selector import OptimalPruningLabelSelector
from enot.pruning.label_selector import PruningLabelSelector
from enot.pruning.label_selector import UniformPruningLabelSelector
from enot.pruning.label_selector import default_channel_search_step
from enot.pruning.label_selector import default_min_channels_to_leave
from enot.pruning.prune import prune_model
from enot.utils.batch_norm import tune_bn_stats
from enot.utils.dataloader2model import DataLoaderSampleToModelInputs
from enot.utils.dataloader2model import DataLoaderSampleToNSamples
from enot.utils.dataloader2model import default_sample_to_model_inputs
from enot.utils.dataloader2model import default_sample_to_n_samples
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(name=__name__)
def calibrate_and_prune_model(
    label_selector: PruningLabelSelector,
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
) -> torch.nn.Module:
    """
    Estimates channel importance values and prunes model with user defined strategy.

    This function searches for prunable channels in user-defined ``model``. After extracting channel information from
    model graph, estimates channel importance values for later pruning. After that prunes model by removing channels
    provided by user-defined channel selection strategy ``label_selector``.

    Parameters
    ----------
    label_selector : PruningLabelSelector
        Channel selector object. This object should implement :meth:`.PruningLabelSelector.select` method which returns
        list with channel labels to prune. Duplicate labels are allowed: they are removed (and the labels are sorted)
        before pruning.
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values. Also used for batch norm tuning when
        ``finetune_bn`` is True.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take two
        inputs and return scalar PyTorch tensor.
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) of the pruned model
        for better model quality. Batch norm tuning reuses ``dataloader``, ``calibration_steps``,
        ``calibration_epochs`` and ``sample_to_model_inputs``. Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for
        the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None. Default value is
        1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number of
        instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar. Default value is
        False.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description. Default value is
        "forward".

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
    For example, it is your responsibility to call ``eval`` method of your model if your inference requires calling this
    method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model(...)

    An example of user-defined strategy::

        import numpy as np
        from typing import List
        from enot.pruning import ModelPruningInfo, PruningLabelSelector, iterate_over_gate_criteria

        class CustomPruningSelector(PruningLabelSelector):

            def __init__(self, pruning_ratio: float):
                self.pruning_ratio: float = pruning_ratio
                super().__init__()

            def select(self, pruning_info: ModelPruningInfo) -> List[int]:
                labels_to_prune: List[int] = []
                for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
                    criteria: np.ndarray = np.array(criteria)
                    prune_channels_num = int(len(criteria) * self.pruning_ratio)
                    index_for_pruning = np.argsort(criteria)[:prune_channels_num]
                    labels_to_prune += np.array(labels)[index_for_pruning].tolist()
                return labels_to_prune

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".
    """
    # Step 1: run calibration to collect channel information and importance values for every prunable channel.
    pruning_info = calibrate_model_for_pruning(
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        n_steps=calibration_steps,
        epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )

    # Step 2: let the user-supplied strategy decide which channel labels to remove.
    labels_to_prune = label_selector.select(pruning_info)

    # Step 3: prune. Labels are deduplicated and sorted so prune_model receives each label exactly once, in a
    # deterministic order, regardless of how the selector built its list. ``inplace=False`` requests out-of-place
    # pruning; the result is returned as ``pruned_model``.
    pruned_model = prune_model(
        model=model,
        pruning_info=pruning_info,
        prune_labels=sorted(set(labels_to_prune)),
        inplace=False,
    )

    # Optional step 4: re-estimate batch norm statistics of the pruned model with the same data and step/epoch budget
    # used for calibration. NOTE(review): judging by the flag names, running stats are reset and BN momentum is set to
    # None (cumulative average) - confirm against tune_bn_stats.
    if finetune_bn:
        tune_bn_stats(
            model=pruned_model,
            dataloader=dataloader,
            reset_bns=True,
            set_momentums_none=True,
            n_steps=calibration_steps,
            epochs=calibration_epochs,
            sample_to_model_inputs=sample_to_model_inputs,
        )

    return pruned_model
def calibrate_and_prune_model_equal(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    pruning_ratio: float = 0.5,
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
) -> torch.nn.Module:
    """
    Estimates channel importance values and prunes model with equal pruning strategy.

    With the equal strategy, the same fraction of channels (``pruning_ratio``) is pruned in every prunable layer. This
    is a convenience wrapper around :func:`calibrate_and_prune_model` that uses :class:`.UniformPruningLabelSelector`
    as the channel selection strategy.

    Parameters
    ----------
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take two
        inputs and return scalar PyTorch tensor.
    pruning_ratio : float, optional
        Fraction of channels to prune in each prunable layer. Default value is 0.5 (prunes 50% of channels, which
        typically reduces the number of FLOPs and parameters by a factor of 4).
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) for better model
        quality. Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for
        the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None. Default value is
        1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number of
        instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar. Default value is
        False.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description. Default value is
        "forward".

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
    For example, it is your responsibility to call ``eval`` method of your model if your inference requires calling this
    method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_equal(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".
    """
    # The uniform selector depends only on ``pruning_ratio``; every other argument is forwarded unchanged.
    return calibrate_and_prune_model(
        label_selector=UniformPruningLabelSelector(pruning_ratio),
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        finetune_bn=finetune_bn,
        calibration_steps=calibration_steps,
        calibration_epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )
def calibrate_and_prune_model_optimal(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    latency_calculation_function: Callable[[torch.nn.Module], float],
    target_latency: float,
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    min_channels_to_leave_fn: Callable[[int], int] = default_min_channels_to_leave,
    channel_search_step_fn: Callable[[int], int] = default_channel_search_step,
    n_search_steps: int = 200,
    entry_point: str = 'forward',
    **kwargs,
) -> torch.nn.Module:
    """
    Estimates channel importance values, searches for the optimal pruning configuration and prunes model.

    A pruning configuration is the percentage of channels to prune in each prunable layer. The search is performed by
    :class:`.OptimalPruningLabelSelector`, which is then used as the channel selection strategy of
    :func:`calibrate_and_prune_model`.

    Parameters
    ----------
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take two
        inputs and return scalar PyTorch tensor.
    latency_calculation_function : Callable[[torch.nn.Module], float]
        Function which calculates sample model latency. This function should take sample pruned model (torch.nn.Module)
        and measure its execution "speed" (float). It could be a number of FLOPs, MACs, inference speed on CPU/GPU and
        other "speed" criteria.
    target_latency : float
        Target model latency. This argument should have the same units as ``latency_calculation_function``'s output.
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) for better model
        quality. Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for
        the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None. Default value is
        1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number of
        instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar. Default value is
        False.
    min_channels_to_leave_fn : Callable[[int], int], optional
        Function to construct minimal pruning configuration. This function should take the number of channels in
        specific layer and return the minimal number of channels to keep. Default value is
        :func:`.default_min_channels_to_leave`.
    channel_search_step_fn : Callable[[int], int], optional
        Function to select search step for each group of channels. The search step defines which numbers of remaining
        channels are considered for a specific group. When the search step equals 1, every number of remaining
        channels is possible (bounded from below by ``min_channels_to_leave_fn``). When the search step equals 2, the
        number of options is reduced by a factor of 2. This function should take the number of channels in specific
        layer and return search step. Default value is :func:`.default_channel_search_step`.
    n_search_steps : int, optional
        Number of sampled configurations for pruning (equal to the number of ``latency_calculation_function``
        executions) to select optimal architecture. Default value is 200.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description. Default value is
        "forward".
    **kwargs
        Additional keyword arguments forwarded to the :class:`.OptimalPruningLabelSelector` constructor.

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
    For example, it is your responsibility to call ``eval`` method of your model if your inference requires calling this
    method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_optimal(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".
    """
    # Latency-aware strategy: the selector receives the model plus all search settings; extra kwargs go to it as well.
    latency_aware_selector = OptimalPruningLabelSelector(
        model=model,
        latency_calculation_function=latency_calculation_function,
        target_latency=target_latency,
        min_channels_to_leave_fn=min_channels_to_leave_fn,
        channel_search_step_fn=channel_search_step_fn,
        n_search_steps=n_search_steps,
        **kwargs,
    )

    # Calibration, pruning and optional BN tuning are shared with the generic entry point.
    return calibrate_and_prune_model(
        label_selector=latency_aware_selector,
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        finetune_bn=finetune_bn,
        calibration_steps=calibration_steps,
        calibration_epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )