import logging
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
import torch
from torch.utils.data import DataLoader
from enot.pruning.calibrate import calibrate_model_for_pruning
from enot.pruning.prune import prune_model
from enot.pruning.prune_strategy import get_labels_for_equal_pruning
from enot.pruning.pruning_info import ModelPruningInfo
from enot.utils.batch_norm import tune_bn_stats
from enot.utils.dataloader2model import DataLoaderSampleToModelInputs
from enot.utils.dataloader2model import DataLoaderSampleToNSamples
from enot.utils.dataloader2model import default_sample_to_model_inputs
from enot.utils.dataloader2model import default_sample_to_n_samples
# Logger scoped to this module's import path.
# NOTE(review): nothing in this file references it — confirm it is still needed.
_LOGGER = logging.getLogger(name=__name__)
def calibrate_and_prune_model(
    label_selection_fn: Callable[[torch.nn.Module, ModelPruningInfo], List[int]],
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    finetune_bn: bool = False,
    n_steps: Optional[int] = None,
    epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    verbose: int = 0,
    **kwargs,
) -> torch.nn.Module:
    """
    Estimates channel importances and prunes model with user defined strategy.

    This function searches for prunable channels in user-defined ``model``. After extracting channel information from
    model graph, estimates channel importances for later pruning. After that prunes model by removing channels specified
    by user-defined strategy ``label_selection_fn``.

    Parameters
    ----------
    label_selection_fn : Callable[[torch.nn.Module, ModelPruningInfo], List[int]]
        Channel selection strategy for pruning. It is called as ``label_selection_fn(model, pruning_info, **kwargs)``
        and should return the labels of the channels to remove (typically the least important ones). Duplicate labels
        are allowed; they are de-duplicated and sorted before pruning.
    model : torch.nn.Module
        Model to calibrate pruning gates for.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model's channel importances.
    loss_function : Callable
        Function which takes model output and dataloader sample and computes loss tensor. This function should take two
        inputs and return scalar PyTorch tensor.
    finetune_bn : bool, optional
        If True, :func:`tune_bn_stats` is run on the pruned model over ``dataloader`` (with ``reset_bns=True`` and
        ``set_momentums_none=True``), reusing the ``n_steps``/``epochs`` budget of the calibration stage. This can
        improve pruned model quality. Default value is False.
    n_steps : int or None, optional
        Number of total threshold calibration steps. Default value is None, which runs calibration on all dataloader
        images for the number of epochs in ``epochs`` argument.
    epochs : int, optional
        Number of total threshold calibration epochs. Not used when ``n_steps`` argument is not None. Default value is
        1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number of
        instances. Default value is :func:`.default_sample_to_n_samples`. See more in :doc:`dataloader2model`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more in :doc:`dataloader2model`.
    verbose : int, optional
        Procedure verbosity level. 0 disables all messages, 1 enables ``tqdm`` progress bar logging, 2 gives additional
        information about calibration. Default value is 0.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``label_selection_fn``.

    Returns
    -------
    torch.nn.Module
        Pruned model, produced by :func:`prune_model` called with ``inplace=False``.

    Raises
    ------
    TypeError
        If ``label_selection_fn`` returns None instead of a collection of channel labels.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
    For example, it is your responsibility to call ``eval`` method of your model if your inference requires calling this
    method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model(...)

    ``label_selection_fn`` is a function that defines model pruning strategy. This function should return a list with
    integer values - channel labels to remove from the model. Input arguments of ``label_selection_fn`` are:

    model : torch.nn.Module
        Model to calibrate pruning gates for.
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage
        (see more in :py:meth:`calibrate_model_for_pruning`).
    kwargs : Dict[str, Any]
        Keyword arguments for ``label_selection_fn``.

    An example of user-defined strategy::

        from typing import List

        import numpy as np
        import torch

        from enot.pruning import ModelPruningInfo, iterate_over_gate_criteria

        def label_selection_fn(model: torch.nn.Module, pruning_info: ModelPruningInfo, **kwargs) -> List[int]:
            labels_to_prune: List[int] = []
            pruning_ratio = kwargs.pop('pruning_ratio')
            for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
                criteria_array = np.asarray(criteria)
                prune_channels_num = int(len(criteria_array) * pruning_ratio)
                index_for_pruning = np.argsort(criteria_array)[:prune_channels_num]
                labels_to_prune += np.asarray(labels)[index_for_pruning].tolist()
            return labels_to_prune

    """
    # Stage 1: collect per-channel pruning information for the model.
    pruning_info = calibrate_model_for_pruning(
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        n_steps=n_steps,
        epochs=epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        verbose=verbose,
    )

    # Stage 2: let the user-defined strategy pick the channel labels to remove.
    labels_to_prune = label_selection_fn(
        model,
        pruning_info,
        **kwargs,
    )
    if labels_to_prune is None:
        # Without this check the failure surfaces as an opaque "'NoneType' object is not iterable" from set() below.
        raise TypeError('``label_selection_fn`` must return a collection of channel labels, got None')

    # Stage 3: build the pruned model; labels are de-duplicated and sorted before being handed to prune_model.
    pruned_model = prune_model(
        model=model,
        pruning_info=pruning_info,
        prune_labels=sorted(set(labels_to_prune)),
        inplace=False,
    )

    # Optional stage 4: re-tune BatchNorm statistics of the pruned model with the same data budget as calibration.
    if finetune_bn:
        tune_bn_stats(
            model=pruned_model,
            dataloader=dataloader,
            reset_bns=True,
            set_momentums_none=True,
            n_steps=n_steps,
            epochs=epochs,
            sample_to_model_inputs=sample_to_model_inputs,
        )

    return pruned_model
def calibrate_and_prune_model_equal(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    pruning_ratio: float = 0.5,
    finetune_bn: bool = False,
    n_steps: Optional[int] = None,
    epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    verbose: int = 0,
    **kwargs,
) -> torch.nn.Module:
    """
    Estimates channel importances and prunes model with equal pruning strategy (the same relative amount of channels
    is pruned in each prunable layer).

    Parameters
    ----------
    model : torch.nn.Module
        Model to calibrate pruning gates for.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model's channel importances.
    loss_function : Callable
        Function which takes model output and dataloader sample and computes loss tensor. This function should take two
        inputs and return scalar PyTorch tensor.
    pruning_ratio : float, optional
        Relative amount of channels to prune in each prunable layer. Must not be None. Default value is 0.5.
    finetune_bn : bool, optional
        Finetune running mean and running variance of the pruned model for better model quality (see
        :func:`calibrate_and_prune_model`). Default value is False.
    n_steps : int or None, optional
        Number of total threshold calibration steps. Default value is None, which runs calibration on all dataloader
        images for the number of epochs in ``epochs`` argument.
    epochs : int, optional
        Number of total threshold calibration epochs. Not used when ``n_steps`` argument is not None. Default value is
        1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number of
        instances. Default value is :func:`.default_sample_to_n_samples`. See more in :doc:`dataloader2model`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more in :doc:`dataloader2model`.
    verbose : int, optional
        Procedure verbosity level. 0 disables all messages, 1 enables ``tqdm`` progress bar logging, 2 gives additional
        information about calibration. Default value is 0.
    **kwargs
        Extra keyword arguments forwarded to :func:`calibrate_and_prune_model` (and from there to the internal
        selection strategy). The equal strategy only reads ``pruning_ratio``; any other keys are ignored.

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Raises
    ------
    ValueError
        If ``pruning_ratio`` is None. This is checked before calibration starts.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as possible.
    For example, it is your responsibility to call ``eval`` method of your model if your inference requires calling this
    method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_equal(...)

    """
    # Validate up front: the selection strategy below only runs after the full (potentially long) calibration pass,
    # so checking there would waste the whole calibration run before reporting the error.
    if pruning_ratio is None:
        raise ValueError(
            'For equal pruning ``calibrate_and_prune_model_equal`` '
            'function you should pass keyword argument ``pruning_ratio``'
        )

    # The ratio travels to the strategy through the generic kwargs channel of calibrate_and_prune_model.
    kwargs['pruning_ratio'] = pruning_ratio

    def label_selection_fn(
        _model: torch.nn.Module,
        _pruning_info: ModelPruningInfo,
        **_kwargs,
    ):
        # Equal strategy: delegate to get_labels_for_equal_pruning; 'pruning_ratio' is always present and not None
        # here because it is set and validated above.
        return get_labels_for_equal_pruning(
            pruning_info=_pruning_info,
            pruning_ratio=_kwargs['pruning_ratio'],
        )

    return calibrate_and_prune_model(
        label_selection_fn=label_selection_fn,
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        finetune_bn=finetune_bn,
        n_steps=n_steps,
        epochs=epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        verbose=verbose,
        **kwargs,
    )