import logging
from itertools import islice
from typing import Any
from typing import Callable
from typing import Optional
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from enot.pruning.prune_calibrator import PruningCalibrator
from enot.pruning.pruning_info import ModelPruningInfo
from enot.utils.dataloader2model import DataLoaderSampleToModelInputs
from enot.utils.dataloader2model import DataLoaderSampleToNSamples
from enot.utils.dataloader2model import default_sample_to_model_inputs
from enot.utils.dataloader2model import default_sample_to_n_samples
# Module-level logger; calibrate_model_for_pruning reports calibration start/finish through it at INFO level.
_LOGGER = logging.getLogger(__name__)
def calibrate_model_for_pruning(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    n_steps: Optional[int] = None,
    epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
    max_prunable_labels: int = 4096,
) -> ModelPruningInfo:
    """
    Estimates model channel importance values for later pruning.

    This function searches for prunable channels in user-defined ``model``. After extracting channel information from
    model graph, estimates channel importance values for later pruning.

    Parameters
    ----------
    model : torch.nn.Module
        Model to calibrate pruning gates for.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values. Must support ``len(dataloader)``.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take two
        inputs and return scalar PyTorch tensor.
    n_steps : int or None, optional
        Number of total calibration steps (one step processes one dataloader batch). Default value is None, which runs
        calibration over the whole dataloader for the number of epochs specified in ``epochs`` argument. When
        ``n_steps`` is larger than ``len(dataloader)``, the dataloader is iterated again from the beginning and the
        last pass is truncated to the remaining number of steps.
    epochs : int, optional
        Number of total calibration epochs. Not used when ``n_steps`` argument is not None. Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number of
        instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar. Default value is
        False.
    entry_point : str, optional
        Name of model method for tracing. In this method all prunable operations should be forwarded. Default value is
        'forward'.
    max_prunable_labels : int, optional
        Maximum number of labels in group which can be pruned. If there are more labels in group, then this group
        will be marked as non-prunable. Default value is 4096.

    Returns
    -------
    ModelPruningInfo
        Pruning information for later usage in pruning methods.

    Raises
    ------
    ValueError
        If ``dataloader`` is empty while a positive number of calibration steps (or epochs) was requested.

    """
    steps_in_epoch = len(dataloader)

    # An empty dataloader can never consume the requested steps: the loop below would subtract 0 forever.
    requested = epochs if n_steps is None else n_steps
    if steps_in_epoch == 0 and requested > 0:
        raise ValueError('Cannot calibrate model for pruning: dataloader is empty (len(dataloader) == 0).')

    if n_steps is None:
        remaining_steps = steps_in_epoch * epochs
        str_epochs = f'{epochs}'
    else:
        remaining_steps = n_steps
        # ``epochs`` is ignored in this mode, so report the (possibly fractional) number of epochs ``n_steps`` covers.
        str_epochs = f'{n_steps / steps_in_epoch:.3f}' if steps_in_epoch > 0 else f'{0:.3f}'
    _LOGGER.info('Calibrating channels for pruning for %s epochs (%s steps).', str_epochs, remaining_steps)

    pruning_calibrator = PruningCalibrator(
        model=model,
        entry_point=entry_point,
        max_prunable_labels=max_prunable_labels,
    )
    with pruning_calibrator:
        while remaining_steps > 0:
            # Truncate the final (partial) pass over the dataloader to the steps that are still left.
            dataloader_ = islice(dataloader, remaining_steps) if remaining_steps < steps_in_epoch else dataloader
            tqdm_iterator = tqdm(dataloader_, total=min(remaining_steps, steps_in_epoch), disable=not show_tqdm)
            for sample in tqdm_iterator:
                # Model arguments preparation and forward.
                model_args, model_kwargs = sample_to_model_inputs(sample)
                model_output = model(*model_args, **model_kwargs)

                # Loss evaluation, scaled to the number of samples in the batch. The multiplication is deliberately
                # out-of-place: an in-place ``*=`` on the tensor returned by ``loss_function`` raises an autograd
                # error when that tensor is a leaf requiring grad or is saved for backward by the op producing it.
                loss = loss_function(model_output, sample)
                loss = loss * sample_to_n_samples(sample)

                # Gate gradient computation. ``backward()`` without arguments requires a scalar loss, so a non-scalar
                # ``loss_function`` result fails here.
                loss.backward()
            remaining_steps -= steps_in_epoch

    _LOGGER.info('Calibration finished.')
    return pruning_calibrator.pruning_info