Source code for enot.pruning.calibrate

import logging
from itertools import islice
from typing import Any
from typing import Callable
from typing import Optional

import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from enot.pruning.prune_calibrator import PruningCalibrator
from enot.pruning.pruning_info import ModelPruningInfo
from enot.utils.dataloader2model import DataLoaderSampleToModelInputs
from enot.utils.dataloader2model import DataLoaderSampleToNSamples
from enot.utils.dataloader2model import default_sample_to_model_inputs
from enot.utils.dataloader2model import default_sample_to_n_samples

_LOGGER = logging.getLogger(__name__)


def calibrate_model_for_pruning(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    n_steps: Optional[int] = None,
    epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
    max_prunable_labels: int = 4096,
) -> ModelPruningInfo:
    """
    Estimates model channel importance values for later pruning.

    This function searches for prunable channels in user-defined ``model``. After extracting channel information
    from model graph, estimates channel importance values for later pruning.

    Parameters
    ----------
    model : torch.nn.Module
        Model to calibrate pruning gates for.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values. Must contain at least one batch.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take
        two inputs and return scalar PyTorch tensor.
    n_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader samples
        for the number of epochs specified in ``epochs`` argument.
    epochs : int, optional
        Number of total calibration epochs. Not used when ``n_steps`` argument is not None. Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number
        of instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar. Default value
        is False.
    entry_point: str, optional
        Name of model method for tracing. In this method all prunable operations should be forwarded. Default value
        is 'forward'.
    max_prunable_labels : int, optional
        Maximum number of labels in group which can be pruned. If there are more labels in group, then this group
        will be marked as non-prunable. Default value is 4096.

    Returns
    -------
    ModelPruningInfo
        Pruning information for later usage in pruning methods.

    Raises
    ------
    ValueError
        If ``dataloader`` is empty (``len(dataloader) == 0``).

    """
    steps_in_epoch = len(dataloader)
    if steps_in_epoch == 0:
        # An empty dataloader provides no calibration data; with ``n_steps`` set, the loop below would also never
        # terminate because ``remaining_steps`` would never decrease.
        raise ValueError('dataloader is empty; at least one batch is required for pruning calibration.')

    if n_steps is None:
        remaining_steps = steps_in_epoch * epochs
        str_epochs = f'{epochs}'
    else:
        # ``epochs`` is ignored in this mode, so report the (possibly fractional) epoch count implied by ``n_steps``.
        remaining_steps = n_steps
        str_epochs = f'{n_steps / steps_in_epoch:.3f}'
    _LOGGER.info(f'Calibrating channels for pruning for {str_epochs} epochs ({remaining_steps} steps).')

    pruning_calibrator = PruningCalibrator(
        model=model,
        entry_point=entry_point,
        max_prunable_labels=max_prunable_labels,
    )
    with pruning_calibrator:
        while remaining_steps > 0:
            # Slicing dataloader if necessary (the last, partial pass over the data).
            dataloader_ = islice(dataloader, remaining_steps) if remaining_steps < steps_in_epoch else dataloader
            tqdm_iterator = tqdm(dataloader_, total=min(remaining_steps, steps_in_epoch), disable=not show_tqdm)
            for sample in tqdm_iterator:
                # Model arguments preparation and forward.
                model_args, model_kwargs = sample_to_model_inputs(sample)
                model_output = model(*model_args, **model_kwargs)

                # Loss evaluation and gate gradient computation.
                loss = loss_function(model_output, sample)
                # Scaling loss to number of samples. Out-of-place on purpose: an in-place multiply can invalidate a
                # tensor autograd saved for backward and make ``loss.backward()`` fail.
                loss = loss * sample_to_n_samples(sample)
                # Errors raised here usually stem from ``loss_function`` or the model forward above
                # (e.g. a non-scalar loss tensor).
                loss.backward()

            remaining_steps -= steps_in_epoch

    _LOGGER.info('Calibration finished.')
    return pruning_calibrator.pruning_info