import logging
from typing import Any
from typing import Callable
from typing import Optional
import torch
from torch.utils.data import DataLoader
from enot.pruning.calibrate import calibrate_model_for_pruning
from enot.pruning.label_selector import GlobalPruningLabelSelectorByChannels
from enot.pruning.label_selector import OptimalPruningLabelSelector
from enot.pruning.label_selector import PruningLabelSelector
from enot.pruning.label_selector import TArchOptStrategyFunc
from enot.pruning.label_selector import UniformPruningLabelSelector
from enot.pruning.label_selector import default_architecture_optimization_strategy
from enot.pruning.prune import prune_model
from enot.utils.batch_norm import tune_bn_stats
from enot.utils.dataloader2model import DataLoaderSampleToModelInputs
from enot.utils.dataloader2model import DataLoaderSampleToNSamples
from enot.utils.dataloader2model import default_sample_to_model_inputs
from enot.utils.dataloader2model import default_sample_to_n_samples
_LOGGER = logging.getLogger(__name__)
def calibrate_and_prune_model(
    label_selector: PruningLabelSelector,
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
) -> torch.nn.Module:
    """
    Estimates channel importance values and prunes model with user defined strategy.

    This function searches for prunable channels in user-defined ``model``. After extracting channel information from
    model graph, estimates channel importance values for later pruning. After that prunes model by removing channels
    provided by user-defined channel selection strategy ``label_selector``.

    Parameters
    ----------
    label_selector : PruningLabelSelector
        Channel selector object. This object should implement :meth:`.PruningLabelSelector.select` method which
        returns list with channel indices to prune.
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take
        two inputs and return scalar PyTorch tensor.
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) for better model
        quality. Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for
        the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None. Default value
        is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number of
        instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar. Default value
        is False.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description. Default value is
        "forward".

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as
    possible. For example, it is your responsibility to call ``eval`` method of your model if your inference requires
    calling this method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model(...)

    An example of user-defined strategy::

        import numpy as np
        from typing import List
        from enot.pruning import ModelPruningInfo, PruningLabelSelector, iterate_over_gate_criteria

        class CustomPruningSelector(PruningLabelSelector):
            def __init__(self, pruning_ratio: float):
                self.pruning_ratio: float = pruning_ratio
                super().__init__()

            def select(self, pruning_info: ModelPruningInfo) -> List[int]:
                labels_to_prune: List[int] = []
                for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
                    criteria: np.ndarray = np.array(criteria)
                    prune_channels_num = int(len(criteria) * self.pruning_ratio)
                    index_for_pruning = np.argsort(criteria)[:prune_channels_num]
                    labels_to_prune += np.array(labels)[index_for_pruning].tolist()
                return labels_to_prune

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".
    """
    # Estimate per-channel importance values by running calibration passes over the dataloader.
    pruning_info = calibrate_model_for_pruning(
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        n_steps=calibration_steps,
        epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )

    # The user-supplied strategy decides which channel labels to remove.
    labels_to_prune = label_selector.select(pruning_info)

    # De-duplicate and sort labels for deterministic pruning; inplace=False keeps the original model intact.
    pruned_model = prune_model(
        model=model,
        pruning_info=pruning_info,
        prune_labels=sorted(set(labels_to_prune)),
        inplace=False,
    )

    if finetune_bn:
        # Re-estimate batch norm running statistics on the pruned model for better quality.
        tune_bn_stats(
            model=pruned_model,
            dataloader=dataloader,
            reset_bns=True,
            set_momentums_none=True,
            n_steps=calibration_steps,
            epochs=calibration_epochs,
            sample_to_model_inputs=sample_to_model_inputs,
        )

    return pruned_model
def calibrate_and_prune_model_equal(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    pruning_ratio: float = 0.5,
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
) -> torch.nn.Module:
    """
    Estimates channel importance values and prunes model with equal pruning strategy (same percentage of channels are
    pruned in each prunable layer).

    Parameters
    ----------
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take
        two inputs and return scalar PyTorch tensor.
    pruning_ratio : float, optional
        Relative amount of channels to prune at each prunable layers. Default value is 0.5 (prunes 50% of channels,
        which typically reduces the number of FLOPs and parameters by a factor of 4).
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) for better model
        quality. Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for
        the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None. Default value
        is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number of
        instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar. Default value
        is False.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description. Default value is
        "forward".

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as
    possible. For example, it is your responsibility to call ``eval`` method of your model if your inference requires
    calling this method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_equal(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".
    """
    # Uniform selector prunes the same fraction of channels in every prunable layer.
    label_selector = UniformPruningLabelSelector(pruning_ratio)
    return calibrate_and_prune_model(
        label_selector=label_selector,
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        finetune_bn=finetune_bn,
        calibration_steps=calibration_steps,
        calibration_epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )
def calibrate_and_prune_model_optimal(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    latency_calculation_function: Callable[[torch.nn.Module], float],
    target_latency: float,
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    architecture_optimization_strategy: TArchOptStrategyFunc = default_architecture_optimization_strategy,
    n_search_steps: int = 200,
    entry_point: str = 'forward',
    **kwargs,
) -> torch.nn.Module:
    """
    Estimates channel importance values, searches for the optimal pruning configuration (percentage of channels to
    prune in each prunable layer) and prunes model.

    Parameters
    ----------
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take
        two inputs and return scalar PyTorch tensor.
    latency_calculation_function : Callable[[torch.nn.Module], float]
        Function which calculates sample model latency. This function should take sample pruned model
        (torch.nn.Module) and measure its execution "speed" (float). It could be a number of FLOPs, MACs, inference
        speed on CPU/GPU and other "speed" criteria.
    target_latency : float
        Target model latency. This argument should have the same units as ``latency_calculation_function``'s output.
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) for better model
        quality. Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for
        the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None. Default value
        is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number of
        instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar. Default value
        is False.
    architecture_optimization_strategy : Callable[[int], Tuple[int, int]]
        Function that constructs an architecture pruning configuration. This function should take the number of
        channels in specific layer and return the minimal number of channels to keep and channel search step. This
        search step defines the number of possible pruning ratios for a specific group: if search step is equal to 1,
        then all numbers of channels after pruning are possible (with a lower bound set by the minimal number of
        channels); when search step is equal to 2 - the number of options is reduced by 2. Default value is
        :func:`.default_architecture_optimization_strategy`.
    n_search_steps : int, optional
        Number of sampled configurations for pruning (equal to the number of ``latency_calculation_function``
        executions) to select optimal architecture. Default value is 200.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description. Default value is
        "forward".
    **kwargs
        Additional keyword arguments for label selector.

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as
    possible. For example, it is your responsibility to call ``eval`` method of your model if your inference requires
    calling this method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_optimal(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".
    """
    # Selector that samples pruning configurations and picks the best one fitting the latency budget.
    label_selector = OptimalPruningLabelSelector(
        model=model,
        latency_calculation_function=latency_calculation_function,
        target_latency=target_latency,
        architecture_optimization_strategy=architecture_optimization_strategy,
        n_search_steps=n_search_steps,
        **kwargs,
    )
    return calibrate_and_prune_model(
        label_selector=label_selector,
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        finetune_bn=finetune_bn,
        calibration_steps=calibration_steps,
        calibration_epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )
def calibrate_and_prune_model_global_wrt_metric_drop(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    validation_function: Callable[[torch.nn.Module], float],
    maximal_acceptable_metric_drop: float,
    minimal_channels_to_prune: int = 100,
    maximal_channels_to_prune: int = 300,
    channel_step_to_search: int = 10,
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
) -> torch.nn.Module:
    """
    Estimates channel importance values and prunes model with global pruning strategy: a total number of channels is
    pruned across the whole network, and the channel count is increased step by step until the measured metric drop
    exceeds ``maximal_acceptable_metric_drop`` (the last acceptable pruned model is returned).

    Parameters
    ----------
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take
        two inputs and return scalar PyTorch tensor.
    validation_function : Callable[[torch.nn.Module], float]
        Function which evaluates pruned model to measure desired metric.
    maximal_acceptable_metric_drop : float
        Maximal value of the metric decrease.
    minimal_channels_to_prune : int
        Minimal channels amount to prune within all network.
    maximal_channels_to_prune : int
        Maximal channels amount to prune within all network.
    channel_step_to_search : int
        Channel configuration search step size. The greater value gives faster but less accurate results.
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) for better model
        quality. Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images for
        the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None. Default value
        is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number of
        instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar. Default value
        is False.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description. Default value is
        "forward".

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as
    possible. For example, it is your responsibility to call ``eval`` method of your model if your inference requires
    calling this method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_global_wrt_metric_drop(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".
    """
    # Reference metric of the unpruned model; every pruned candidate is compared against it.
    baseline_metric = validation_function(model)
    _LOGGER.info(f'Baseline metric of user model is: {baseline_metric}')
    # Remember the original device so candidates (and the final result) can be moved back to it.
    baseline_model_device = next(model.parameters()).device
    # Channel importance values are estimated once; the search loop below only re-selects labels.
    pruning_info = calibrate_model_for_pruning(
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        n_steps=calibration_steps,
        epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )
    # Keep the original model on CPU while candidate pruned copies occupy the original device.
    model = model.cpu()
    prev_pruned_model = None  # last candidate whose metric drop was acceptable
    for n_channels in range(minimal_channels_to_prune, maximal_channels_to_prune + 1, channel_step_to_search):
        _LOGGER.info(f"Current pruned channels: {n_channels}")
        # Select n_channels labels globally (across the whole network) by importance.
        label_selector = GlobalPruningLabelSelectorByChannels(n_channels)
        labels_to_prune = label_selector.select(pruning_info)
        pruned_model = prune_model(
            model=model,
            pruning_info=pruning_info,
            prune_labels=sorted(set(labels_to_prune)),
            inplace=False,
        )
        pruned_model = pruned_model.to(device=baseline_model_device)
        if finetune_bn:
            # Re-estimate batch norm running statistics on the candidate before validation.
            tune_bn_stats(
                model=pruned_model,
                dataloader=dataloader,
                reset_bns=True,
                set_momentums_none=True,
                n_steps=calibration_steps,
                epochs=calibration_epochs,
                sample_to_model_inputs=sample_to_model_inputs,
            )
        p_metric = validation_function(pruned_model)
        pruned_model = pruned_model.cpu()
        _LOGGER.info(f"Metric after pruning: {p_metric}")
        if baseline_metric - p_metric > maximal_acceptable_metric_drop:
            if prev_pruned_model is None:
                # Even the least-pruned candidate exceeds the budget; return it anyway.
                prev_pruned_model = pruned_model
            # Restore the original model's device before returning the last acceptable candidate.
            model.to(device=baseline_model_device)
            return prev_pruned_model.to(device=baseline_model_device)
        prev_pruned_model = pruned_model
    # All candidates were acceptable; return the most heavily pruned one.
    model.to(device=baseline_model_device)
    return prev_pruned_model.to(device=baseline_model_device)