# Source code for enot.pruning.calibrate_and_prune

import logging
from typing import Any
from typing import Callable
from typing import List
from typing import Optional

import torch
from torch.utils.data import DataLoader

from enot.pruning.calibrate import calibrate_model_for_pruning
from enot.pruning.prune import prune_model
from enot.pruning.prune_strategy import get_labels_for_equal_pruning
from enot.pruning.pruning_info import ModelPruningInfo
from enot.utils.batch_norm import tune_bn_stats
from enot.utils.dataloader2model import DataLoaderSampleToModelInputs
from enot.utils.dataloader2model import DataLoaderSampleToNSamples
from enot.utils.dataloader2model import default_sample_to_model_inputs
from enot.utils.dataloader2model import default_sample_to_n_samples

_LOGGER = logging.getLogger(__name__)


def calibrate_and_prune_model(
    label_selection_fn: Callable[[torch.nn.Module, ModelPruningInfo], List[int]],
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    finetune_bn: bool = False,
    n_steps: Optional[int] = None,
    epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    verbose: int = 0,
    **kwargs,
) -> torch.nn.Module:
    """
    Estimates channel importances and prunes model with user defined strategy.

    This function searches for prunable channels in user-defined ``model``.
    After extracting channel information from model graph, estimates channel
    importances for later pruning. After that prunes model by removing
    channels specified by user-defined strategy ``label_selection_fn``.

    Parameters
    ----------
    label_selection_fn : Callable[[torch.nn.Module, ModelPruningInfo], List[int]]
        Channel selection strategy for pruning. This function should typically
        return labels of the least important channels.
    model : torch.nn.Module
        Model to calibrate pruning gates for.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model's channel importances. When
        ``finetune_bn`` is True, the same dataloader is also used for batch
        norm statistics tuning.
    loss_function : Callable
        Function which takes model output and dataloader sample and computes
        loss tensor. This function should take two inputs and return scalar
        PyTorch tensor.
    finetune_bn : bool, optional
        Finetune running mean and running variance of the pruned model for
        better model quality. Uses the same ``n_steps`` and ``epochs`` values
        as calibration. Default value is False.
    n_steps : int or None, optional
        Number of total threshold calibration steps. Default value is None,
        which runs calibration on all dataloader images for the number of
        epochs in ``epochs`` argument.
    epochs : int, optional
        Number of total threshold calibration epochs. Not used when
        ``n_steps`` argument is not None. Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process)
        in single dataloader batch (dataloader sample). This function should
        take single input (dataloader sample) and return single integer - the
        number of instances. Default value is
        :func:`.default_sample_to_n_samples`. See more in
        :doc:`dataloader2model`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default
        value is :func:`.default_sample_to_model_inputs`. See more in
        :doc:`dataloader2model`.
    verbose : int, optional
        Procedure verbosity level. 0 disables all messages, 1 enables ``tqdm``
        progress bar logging, 2 gives additional information about
        calibration. Default value is 0.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``label_selection_fn``.

    Returns
    -------
    torch.nn.Module
        Pruned model (produced by ``prune_model`` called with
        ``inplace=False``).

    Notes
    -----
    Before calling this function, your model should be prepared to be as
    close to practical inference usage as possible. For example, it is your
    responsibility to call ``eval`` method of your model if your inference
    requires calling this method (e.g. when the model contains dropout
    layers).

    Typically, it is better to calibrate model for pruning on
    validation-like data without augmentations (but with inference input
    preprocessing).

    If you encounter errors during backward call, you can wrap this function
    call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model(...)

    ``label_selection_fn`` is a function that defines model pruning
    strategy. This function should return a list with integer values -
    channel labels to remove from the model. Duplicate labels are ignored.
    Input arguments of ``label_selection_fn`` are:

    model : torch.nn.Module
        Model to calibrate pruning gates for.
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage (see
        more in :py:meth:`calibrate_model_for_pruning`).
    kwargs : Dict[str, Any]
        Keyword arguments for ``label_selection_fn``.

    An example of user-defined strategy::

        from typing import List

        import numpy as np
        import torch

        from enot.pruning import ModelPruningInfo, iterate_over_gate_criteria

        def label_selection_fn(
            model: torch.nn.Module,
            pruning_info: ModelPruningInfo,
            **kwargs,
        ):
            labels_to_prune: List[int] = []
            pruning_ratio = kwargs.pop('pruning_ratio')
            for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
                criteria: np.ndarray = np.array(criteria)
                prune_channels_num = int(len(criteria) * pruning_ratio)
                index_for_pruning = np.argsort(criteria)[:prune_channels_num]
                labels_to_prune += np.array(labels)[index_for_pruning].tolist()
            return labels_to_prune

    """
    # Stage 1: collect per-channel importance information.
    pruning_info = calibrate_model_for_pruning(
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        n_steps=n_steps,
        epochs=epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        verbose=verbose,
    )

    # Stage 2: let the user strategy pick channel labels to remove.
    labels_to_prune = label_selection_fn(
        model,
        pruning_info,
        **kwargs,
    )

    # Deduplicate and order the labels so prune_model receives a canonical
    # list regardless of how the strategy assembled it.
    prune_labels = sorted(set(labels_to_prune))
    if not prune_labels:
        # An empty selection is most likely a strategy mistake (e.g. a ratio
        # too small to select any channel); surface it instead of staying silent.
        _LOGGER.warning('label_selection_fn returned no channel labels to prune')

    # Stage 3: build the pruned model without modifying it in-place.
    pruned_model = prune_model(
        model=model,
        pruning_info=pruning_info,
        prune_labels=prune_labels,
        inplace=False,
    )

    # Optional stage 4: re-estimate batch norm statistics of the pruned model
    # on the same data and with the same step/epoch budget as calibration.
    if finetune_bn:
        tune_bn_stats(
            model=pruned_model,
            dataloader=dataloader,
            reset_bns=True,
            set_momentums_none=True,
            n_steps=n_steps,
            epochs=epochs,
            sample_to_model_inputs=sample_to_model_inputs,
        )

    return pruned_model
def calibrate_and_prune_model_equal(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    pruning_ratio: float = 0.5,
    finetune_bn: bool = False,
    n_steps: Optional[int] = None,
    epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    verbose: int = 0,
    **kwargs,
) -> torch.nn.Module:
    """
    Estimates channel importances and prunes model with equal pruning strategy.

    The same fraction of channels is pruned in each prunable layer.

    Parameters
    ----------
    model : torch.nn.Module
        Model to calibrate pruning gates for.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model's channel importances.
    loss_function : Callable
        Function which takes model output and dataloader sample and computes
        loss tensor. This function should take two inputs and return scalar
        PyTorch tensor.
    pruning_ratio : float, optional
        Relative amount of channels to prune in each prunable layer. Must not
        be None. Default value is 0.5.
    finetune_bn : bool, optional
        Finetune running mean and running variance for better model quality.
        Default value is False.
    n_steps : int or None, optional
        Number of total threshold calibration steps. Default value is None,
        which runs calibration on all dataloader images for the number of
        epochs in ``epochs`` argument.
    epochs : int, optional
        Number of total threshold calibration epochs. Not used when
        ``n_steps`` argument is not None. Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process)
        in single dataloader batch (dataloader sample). This function should
        take single input (dataloader sample) and return single integer - the
        number of instances. Default value is
        :func:`.default_sample_to_n_samples`. See more in
        :doc:`dataloader2model`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default
        value is :func:`.default_sample_to_model_inputs`. See more in
        :doc:`dataloader2model`.
    verbose : int, optional
        Procedure verbosity level. 0 disables all messages, 1 enables ``tqdm``
        progress bar logging, 2 gives additional information about
        calibration. Default value is 0.
    **kwargs
        Extra keyword arguments forwarded to :func:`calibrate_and_prune_model`
        (and from there to the internal equal-pruning strategy, which ignores
        everything except ``pruning_ratio``).

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Raises
    ------
    ValueError
        If ``pruning_ratio`` is None. The check is done before calibration so
        the expensive calibration pass is not wasted.

    Notes
    -----
    Before calling this function, your model should be prepared to be as
    close to practical inference usage as possible. For example, it is your
    responsibility to call ``eval`` method of your model if your inference
    requires calling this method (e.g. when the model contains dropout
    layers).

    Typically, it is better to calibrate model for pruning on
    validation-like data without augmentations (but with inference input
    preprocessing).

    If you encounter errors during backward call, you can wrap this function
    call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_equal(...)

    """
    # Fail fast: previously this was only detected inside the selection
    # callback, i.e. after the full (costly) calibration pass had run.
    if pruning_ratio is None:
        raise ValueError(
            'For equal pruning ``calibrate_and_prune_model_equal`` '
            'function you should pass keyword argument ``pruning_ratio``'
        )

    # The ratio travels to the strategy through the generic kwargs channel of
    # calibrate_and_prune_model.
    kwargs['pruning_ratio'] = pruning_ratio

    def label_selection_fn(
        _model: torch.nn.Module,
        _pruning_info: ModelPruningInfo,
        **_kwargs,
    ):
        """Select labels so that every prunable layer loses the same fraction."""
        return get_labels_for_equal_pruning(
            pruning_info=_pruning_info,
            pruning_ratio=_kwargs['pruning_ratio'],
        )

    return calibrate_and_prune_model(
        label_selection_fn=label_selection_fn,
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        finetune_bn=finetune_bn,
        n_steps=n_steps,
        epochs=epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        verbose=verbose,
        **kwargs,
    )