Source code for enot.pruning.calibrate_and_prune

import logging
from typing import Any
from typing import Callable
from typing import Optional

import torch
from torch.utils.data import DataLoader

from enot.pruning.calibrate import calibrate_model_for_pruning
from enot.pruning.label_selector import OptimalPruningLabelSelector
from enot.pruning.label_selector import PruningLabelSelector
from enot.pruning.label_selector import UniformPruningLabelSelector
from enot.pruning.label_selector import default_channel_search_step
from enot.pruning.label_selector import default_min_channels_to_leave
from enot.pruning.prune import prune_model
from enot.utils.batch_norm import tune_bn_stats
from enot.utils.dataloader2model import DataLoaderSampleToModelInputs
from enot.utils.dataloader2model import DataLoaderSampleToNSamples
from enot.utils.dataloader2model import default_sample_to_model_inputs
from enot.utils.dataloader2model import default_sample_to_n_samples

# Module-level logger, named after this module's import path.
_LOGGER = logging.getLogger(__name__)


def calibrate_and_prune_model(
    label_selector: PruningLabelSelector,
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
) -> torch.nn.Module:
    """
    Calibrate channel importance values and prune the model with a user-supplied selection strategy.

    The procedure has three stages: the model is calibrated on ``dataloader`` to collect pruning
    information (channel importance values), ``label_selector`` picks the channel labels to remove,
    and the model is pruned by removing those channels. Optionally, batch norm statistics of the
    pruned model are re-estimated afterwards.

    Parameters
    ----------
    label_selector : PruningLabelSelector
        Channel selection strategy. It should implement the
        :meth:`.PruningLabelSelector.select` method, which returns a list with channel indices
        (labels) to prune.
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader used to estimate model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes the model output and a dataloader sample and computes the loss
        tensor. It should take two inputs and return a scalar PyTorch tensor.
    finetune_bn : bool, optional
        Whether to finetune batch norm layers (specifically, their running mean and running
        variance values) for better model quality. Default value is False.
    calibration_steps : int or None, optional
        Total number of calibration steps. Default value is None, which runs calibration on all
        dataloader images for the number of epochs given by ``calibration_epochs``.
    calibration_epochs : int, optional
        Total number of calibration epochs. Not used when ``calibration_steps`` is not None.
        Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in a single
        dataloader batch (dataloader sample). It should take a single input (dataloader sample)
        and return a single integer - the number of instances.
        Default value is :func:`.default_sample_to_n_samples`.
        See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function which maps dataloader samples to the model input format.
        Default value is :func:`.default_sample_to_model_inputs`.
        See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log the calibration procedure with a `tqdm <https://github.com/tqdm/tqdm>`_
        progress bar. Default value is False.
    entry_point : str, optional
        Model function for execution. See the ``Notes`` section for a detailed description.
        Default value is "forward".

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical
    inference usage as possible. For example, it is your responsibility to call the ``eval``
    method of your model if your inference requires it (e.g. when the model contains dropout
    layers).

    Typically, it is better to calibrate the model for pruning on validation-like data without
    augmentations (but with inference input preprocessing).

    If you encounter errors during the backward call, you can wrap this function call with the
    following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model(...)

    An example of a user-defined strategy::

        import numpy as np
        from typing import List
        from enot.pruning import ModelPruningInfo, PruningLabelSelector, iterate_over_gate_criteria

        class CustomPruningSelector(PruningLabelSelector):

            def __init__(self, pruning_ratio: float):
                self.pruning_ratio: float = pruning_ratio
                super().__init__()

            def select(self, pruning_info: ModelPruningInfo) -> List[int]:
                labels_to_prune: List[int] = []
                for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
                    criteria: np.ndarray = np.array(criteria)
                    prune_channels_num = int(len(criteria) * self.pruning_ratio)
                    index_for_pruning = np.argsort(criteria)[:prune_channels_num]
                    labels_to_prune += np.array(labels)[index_for_pruning].tolist()
                return labels_to_prune

    ``entry_point`` is a string that specifies which function of the user's model will be traced.
    Simple examples: "forward", "execute_model", "forward_train". If such a function is located
    in the model's submodule - then you should first write the submodule's name followed by the
    function name, all separated by dots: "submodule.forward", "head.predict_features",
    "submodule1.submodule2.forward".

    """
    # Stage 1: run calibration to gather the pruning information for the model.
    importance_info = calibrate_model_for_pruning(
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        n_steps=calibration_steps,
        epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )

    # Stage 2: ask the strategy which labels to remove; duplicates are dropped
    # and the labels are sorted before they are handed to the pruner.
    chosen_labels = label_selector.select(importance_info)
    unique_sorted_labels = sorted(set(chosen_labels))

    # Stage 3: prune (``inplace=False`` is requested from ``prune_model``).
    result = prune_model(
        model=model,
        pruning_info=importance_info,
        prune_labels=unique_sorted_labels,
        inplace=False,
    )

    if not finetune_bn:
        return result

    # Optional stage: re-estimate batch norm statistics of the pruned model using
    # the same dataloader and the same step/epoch budget as the calibration.
    tune_bn_stats(
        model=result,
        dataloader=dataloader,
        reset_bns=True,
        set_momentums_none=True,
        n_steps=calibration_steps,
        epochs=calibration_epochs,
        sample_to_model_inputs=sample_to_model_inputs,
    )
    return result
def calibrate_and_prune_model_equal(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    pruning_ratio: float = 0.5,
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
) -> torch.nn.Module:
    """
    Calibrate channel importance values and prune the model with the equal pruning strategy.

    With the equal strategy, the same percentage of channels is pruned in each prunable layer.
    This is a convenience wrapper which builds a :class:`UniformPruningLabelSelector` from
    ``pruning_ratio`` and delegates everything else to :func:`calibrate_and_prune_model`.

    Parameters
    ----------
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader used to estimate model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes the model output and a dataloader sample and computes the loss
        tensor. It should take two inputs and return a scalar PyTorch tensor.
    pruning_ratio : float, optional
        Relative amount of channels to prune at each prunable layer. Default value is 0.5
        (prunes 50% of channels, which typically reduces the number of FLOPs and parameters
        by a factor of 4).
    finetune_bn : bool, optional
        Whether to finetune batch norm layers (specifically, their running mean and running
        variance values) for better model quality. Default value is False.
    calibration_steps : int or None, optional
        Total number of calibration steps. Default value is None, which runs calibration on all
        dataloader images for the number of epochs given by ``calibration_epochs``.
    calibration_epochs : int, optional
        Total number of calibration epochs. Not used when ``calibration_steps`` is not None.
        Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in a single
        dataloader batch (dataloader sample). It should take a single input (dataloader sample)
        and return a single integer - the number of instances.
        Default value is :func:`.default_sample_to_n_samples`.
        See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function which maps dataloader samples to the model input format.
        Default value is :func:`.default_sample_to_model_inputs`.
        See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log the calibration procedure with a `tqdm <https://github.com/tqdm/tqdm>`_
        progress bar. Default value is False.
    entry_point : str, optional
        Model function for execution. See the ``Notes`` section for a detailed description.
        Default value is "forward".

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical
    inference usage as possible. For example, it is your responsibility to call the ``eval``
    method of your model if your inference requires it (e.g. when the model contains dropout
    layers).

    Typically, it is better to calibrate the model for pruning on validation-like data without
    augmentations (but with inference input preprocessing).

    If you encounter errors during the backward call, you can wrap this function call with the
    following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_equal(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced.
    Simple examples: "forward", "execute_model", "forward_train". If such a function is located
    in the model's submodule - then you should first write the submodule's name followed by the
    function name, all separated by dots: "submodule.forward", "head.predict_features",
    "submodule1.submodule2.forward".

    """
    # The uniform selector is the only thing specific to this wrapper; all other
    # arguments are forwarded unchanged to the generic routine.
    return calibrate_and_prune_model(
        label_selector=UniformPruningLabelSelector(pruning_ratio),
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        finetune_bn=finetune_bn,
        calibration_steps=calibration_steps,
        calibration_epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )
def calibrate_and_prune_model_optimal(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    latency_calculation_function: Callable[[torch.nn.Module], float],
    target_latency: float,
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    min_channels_to_leave_fn: Callable[[int], int] = default_min_channels_to_leave,
    channel_search_step_fn: Callable[[int], int] = default_channel_search_step,
    n_search_steps: int = 200,
    entry_point: str = 'forward',
    **kwargs,
) -> torch.nn.Module:
    """
    Calibrate channel importance values, search for the optimal pruning configuration and prune.

    The pruning configuration is the percentage of channels to prune in each prunable layer.
    This is a convenience wrapper which builds an :class:`OptimalPruningLabelSelector` from the
    search-related arguments (plus ``**kwargs``) and delegates calibration and pruning to
    :func:`calibrate_and_prune_model`.

    Parameters
    ----------
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader used to estimate model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes the model output and a dataloader sample and computes the loss
        tensor. It should take two inputs and return a scalar PyTorch tensor.
    latency_calculation_function : Callable[[torch.nn.Module], float]
        Function which calculates sample model latency. It should take a sample pruned model
        (torch.nn.Module) and measure its execution "speed" (float). It could be a number of
        FLOPs, MACs, inference speed on CPU/GPU or another "speed" criterion.
    target_latency : float
        Target model latency. This argument should have the same units as the output of
        ``latency_calculation_function``.
    finetune_bn : bool, optional
        Whether to finetune batch norm layers (specifically, their running mean and running
        variance values) for better model quality. Default value is False.
    calibration_steps : int or None, optional
        Total number of calibration steps. Default value is None, which runs calibration on all
        dataloader images for the number of epochs given by ``calibration_epochs``.
    calibration_epochs : int, optional
        Total number of calibration epochs. Not used when ``calibration_steps`` is not None.
        Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in a single
        dataloader batch (dataloader sample). It should take a single input (dataloader sample)
        and return a single integer - the number of instances.
        Default value is :func:`.default_sample_to_n_samples`.
        See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function which maps dataloader samples to the model input format.
        Default value is :func:`.default_sample_to_model_inputs`.
        See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log the calibration procedure with a `tqdm <https://github.com/tqdm/tqdm>`_
        progress bar. Default value is False.
    min_channels_to_leave_fn : Callable[[int], int], optional
        Function used to construct the minimal pruning configuration. It should take the number
        of channels in a specific layer and return the minimal number of channels to keep.
        Default value is :func:`.default_min_channels_to_leave`.
    channel_search_step_fn : Callable[[int], int], optional
        Function which selects the search step for each group of channels. The search step
        defines the number of possible pruning ratios for a specific group. If the search step
        is equal to 1, then every number of channels after pruning is possible (with a lower
        bound set by ``min_channels_to_leave_fn``). When the search step is equal to 2, the
        number of options is reduced by a factor of 2. It should take the number of channels in
        a specific layer and return the search step.
        Default value is :func:`.default_channel_search_step`.
    n_search_steps : int, optional
        Number of sampled configurations for pruning (equal to the number of
        ``latency_calculation_function`` executions) used to select the optimal architecture.
        Default value is 200.
    entry_point : str, optional
        Model function for execution. See the ``Notes`` section for a detailed description.
        Default value is "forward".
    **kwargs
        Additional keyword arguments for the label selector.

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical
    inference usage as possible. For example, it is your responsibility to call the ``eval``
    method of your model if your inference requires it (e.g. when the model contains dropout
    layers).

    Typically, it is better to calibrate the model for pruning on validation-like data without
    augmentations (but with inference input preprocessing).

    If you encounter errors during the backward call, you can wrap this function call with the
    following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_optimal(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced.
    Simple examples: "forward", "execute_model", "forward_train". If such a function is located
    in the model's submodule - then you should first write the submodule's name followed by the
    function name, all separated by dots: "submodule.forward", "head.predict_features",
    "submodule1.submodule2.forward".

    """
    # Search-related settings (and any extra keyword arguments) go to the selector;
    # calibration-related settings go to the generic routine below.
    optimal_selector = OptimalPruningLabelSelector(
        model=model,
        latency_calculation_function=latency_calculation_function,
        target_latency=target_latency,
        min_channels_to_leave_fn=min_channels_to_leave_fn,
        channel_search_step_fn=channel_search_step_fn,
        n_search_steps=n_search_steps,
        **kwargs,
    )

    return calibrate_and_prune_model(
        label_selector=optimal_selector,
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        finetune_bn=finetune_bn,
        calibration_steps=calibration_steps,
        calibration_epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )