"""Source code for enot.pruning.calibrate_and_prune: helpers that calibrate a model for pruning and prune it."""

import logging
from typing import Any
from typing import Callable
from typing import Optional

import torch
from torch.utils.data import DataLoader

from enot.pruning.calibrate import calibrate_model_for_pruning
from enot.pruning.label_selector import GlobalPruningLabelSelectorByChannels
from enot.pruning.label_selector import OptimalPruningLabelSelector
from enot.pruning.label_selector import PruningLabelSelector
from enot.pruning.label_selector import TArchOptStrategyFunc
from enot.pruning.label_selector import UniformPruningLabelSelector
from enot.pruning.label_selector import default_architecture_optimization_strategy
from enot.pruning.prune import prune_model
from enot.utils.batch_norm import tune_bn_stats
from enot.utils.dataloader2model import DataLoaderSampleToModelInputs
from enot.utils.dataloader2model import DataLoaderSampleToNSamples
from enot.utils.dataloader2model import default_sample_to_model_inputs
from enot.utils.dataloader2model import default_sample_to_n_samples

# Module-level logger for progress messages (baseline/pruned metric values) emitted by the pruning helpers.
_LOGGER = logging.getLogger(__name__)


def calibrate_and_prune_model(
    label_selector: PruningLabelSelector,
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
) -> torch.nn.Module:
    """
    Estimates channel importance values and prunes model with user defined strategy.

    This function searches for prunable channels in user-defined ``model``. After extracting channel information
    from model graph, estimates channel importance values for later pruning. After that prunes model by removing
    channels provided by user-defined channel selection strategy ``label_selector``.

    Parameters
    ----------
    label_selector : PruningLabelSelector
        Channel selector object. This object should implement :meth:`.PruningLabelSelector.select` method which
        returns list with channel labels to prune. Duplicate labels in the returned list are ignored.
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values. When ``finetune_bn`` is True, the same
        dataloader is also used to re-estimate batch norm statistics of the pruned model.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take
        two inputs and return scalar PyTorch tensor.
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) for better model
        quality. Fine-tuning reuses ``dataloader``, ``calibration_steps`` and ``calibration_epochs``.
        Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images
        for the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None.
        Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number
        of instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar.
        Default value is False.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description.
        Default value is "forward".

    Returns
    -------
    torch.nn.Module
        Pruned model, built by :func:`.prune_model` with ``inplace=False``.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as
    possible. For example, it is your responsibility to call ``eval`` method of your model if your inference
    requires calling this method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model(...)

    An example of user-defined strategy::

        import numpy as np
        from typing import List
        from enot.pruning import ModelPruningInfo, PruningLabelSelector, iterate_over_gate_criteria

        class CustomPruningSelector(PruningLabelSelector):
            def __init__(self, pruning_ratio: float):
                self.pruning_ratio: float = pruning_ratio
                super().__init__()

            def select(self, pruning_info: ModelPruningInfo) -> List[int]:
                labels_to_prune: List[int] = []
                for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
                    criteria: np.ndarray = np.array(criteria)
                    prune_channels_num = int(len(criteria) * self.pruning_ratio)
                    index_for_pruning = np.argsort(criteria)[:prune_channels_num]
                    labels_to_prune += np.array(labels)[index_for_pruning].tolist()
                return labels_to_prune

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".

    """
    # Step 1: trace the model and estimate channel importance on the calibration data.
    pruning_info = calibrate_model_for_pruning(
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        n_steps=calibration_steps,
        epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )

    # Step 2: let the user strategy decide which channel labels to remove.
    labels_to_prune = label_selector.select(pruning_info)

    # Step 3: build the pruned model; duplicates are removed and labels are passed in a deterministic order.
    pruned_model = prune_model(
        model=model,
        pruning_info=pruning_info,
        prune_labels=sorted(set(labels_to_prune)),
        inplace=False,
    )

    # Optional step 4: re-estimate batch norm statistics of the pruned model on the same data budget.
    if finetune_bn:
        tune_bn_stats(
            model=pruned_model,
            dataloader=dataloader,
            reset_bns=True,
            set_momentums_none=True,
            n_steps=calibration_steps,
            epochs=calibration_epochs,
            sample_to_model_inputs=sample_to_model_inputs,
        )

    return pruned_model
def calibrate_and_prune_model_equal(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    pruning_ratio: float = 0.5,
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
) -> torch.nn.Module:
    """
    Estimates channel importance values and prunes model with equal pruning strategy (same percentage of channels
    are pruned in each prunable layer).

    This is a convenience wrapper around :func:`calibrate_and_prune_model` that uses
    :class:`.UniformPruningLabelSelector` as the channel selection strategy.

    Parameters
    ----------
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take
        two inputs and return scalar PyTorch tensor.
    pruning_ratio : float, optional
        Relative amount of channels to prune in each prunable layer. Default value is 0.5 (prunes 50% of channels,
        which typically reduces the number of FLOPs and parameters by a factor of 4).
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) for better model
        quality. Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images
        for the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None.
        Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number
        of instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar.
        Default value is False.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description.
        Default value is "forward".

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as
    possible. For example, it is your responsibility to call ``eval`` method of your model if your inference
    requires calling this method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_equal(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".

    """
    # The uniform selector applies the same pruning ratio to every prunable layer.
    uniform_selector = UniformPruningLabelSelector(pruning_ratio)
    return calibrate_and_prune_model(
        label_selector=uniform_selector,
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        finetune_bn=finetune_bn,
        calibration_steps=calibration_steps,
        calibration_epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )
def calibrate_and_prune_model_optimal(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    latency_calculation_function: Callable[[torch.nn.Module], float],
    target_latency: float,
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    architecture_optimization_strategy: TArchOptStrategyFunc = default_architecture_optimization_strategy,
    n_search_steps: int = 200,
    entry_point: str = 'forward',
    **kwargs,
) -> torch.nn.Module:
    """
    Estimates channel importance values, searches for the optimal pruning configuration (percentage of channels to
    prune in each prunable layer) and prunes model.

    Parameters
    ----------
    model : torch.nn.Module
        Model to prune.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take
        two inputs and return scalar PyTorch tensor.
    latency_calculation_function : Callable[[torch.nn.Module], float]
        Function which calculates sample model latency. This function should take sample pruned model
        (torch.nn.Module) and measure its execution "speed" (float). It could be a number of FLOPs, MACs, inference
        speed on CPU/GPU and other "speed" criteria.
    target_latency : float
        Target model latency. This argument should have the same units as ``latency_calculation_function``'s
        output.
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) for better model
        quality. Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images
        for the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None.
        Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number
        of instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar.
        Default value is False.
    architecture_optimization_strategy : Callable[[int], Tuple[int, int]], optional
        Function that constructs an architecture pruning configuration. This function should take the number of
        channels in a specific layer and return a pair: the minimal number of channels to keep and the channel
        search step. The search step defines the number of possible pruning ratios for a specific group. If the
        search step is equal to 1, then every number of remaining channels is possible (with a lower bound set by
        the minimal number of channels to keep). When the search step is equal to 2, the number of options is
        halved. Default value is :func:`.default_architecture_optimization_strategy`.
    n_search_steps : int, optional
        Number of sampled configurations for pruning (equal to the number of ``latency_calculation_function``
        executions) to select optimal architecture. Default value is 200.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description.
        Default value is "forward".
    **kwargs
        Additional keyword arguments for label selector (:class:`.OptimalPruningLabelSelector`).

    Returns
    -------
    torch.nn.Module
        Pruned model.

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as
    possible. For example, it is your responsibility to call ``eval`` method of your model if your inference
    requires calling this method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_optimal(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".

    """
    # The selector receives the original model, presumably so it can build candidate pruned models and measure
    # them with ``latency_calculation_function`` -- see OptimalPruningLabelSelector for details.
    label_selector = OptimalPruningLabelSelector(
        model=model,
        latency_calculation_function=latency_calculation_function,
        target_latency=target_latency,
        architecture_optimization_strategy=architecture_optimization_strategy,
        n_search_steps=n_search_steps,
        **kwargs,
    )
    pruned_model = calibrate_and_prune_model(
        label_selector=label_selector,
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        finetune_bn=finetune_bn,
        calibration_steps=calibration_steps,
        calibration_epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )
    return pruned_model
def calibrate_and_prune_model_global_wrt_metric_drop(
    model: torch.nn.Module,
    dataloader: DataLoader,
    loss_function: Callable[[Any, Any], torch.Tensor],
    validation_function: Callable[[torch.nn.Module], float],
    maximal_acceptable_metric_drop: float,
    minimal_channels_to_prune: int = 100,
    maximal_channels_to_prune: int = 300,
    channel_step_to_search: int = 10,
    finetune_bn: bool = False,
    calibration_steps: Optional[int] = None,
    calibration_epochs: int = 1,
    sample_to_n_samples: DataLoaderSampleToNSamples = default_sample_to_n_samples,
    sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
    show_tqdm: bool = False,
    entry_point: str = 'forward',
) -> torch.nn.Module:
    """
    Estimates channel importance values and prunes model with global pruning strategy, searching for the largest
    number of pruned channels whose metric drop stays within ``maximal_acceptable_metric_drop``.

    Unlike the equal strategy, the channels to prune are selected across the whole network with
    :class:`.GlobalPruningLabelSelectorByChannels` (a total channel count), not as a fixed percentage per layer.

    The search works as follows. For every ``n_channels`` in
    ``range(minimal_channels_to_prune, maximal_channels_to_prune + 1, channel_step_to_search)``, the model is pruned
    by ``n_channels`` channels (optionally followed by batch norm fine-tuning) and evaluated with
    ``validation_function``. The search stops at the first candidate whose metric drop
    (``baseline_metric - pruned_metric``, i.e. larger metric values are assumed to be better) exceeds
    ``maximal_acceptable_metric_drop``, and the previous (acceptable) candidate is returned. If even the first
    candidate exceeds the threshold, that candidate is returned anyway and a warning is logged. If no candidate
    exceeds the threshold, the most aggressively pruned candidate is returned.

    Parameters
    ----------
    model : torch.nn.Module
        Model to prune. It is temporarily moved to CPU during the search and moved back to its original device (the
        device of its first parameter) before this function returns, including when an exception is raised.
    dataloader : torch.utils.data.DataLoader
        Dataloader for estimation of model channel importance values.
    loss_function : Callable[[Any, Any], torch.Tensor]
        Function which takes model output and dataloader sample and computes loss tensor. This function should take
        two inputs and return scalar PyTorch tensor.
    validation_function : Callable[[torch.nn.Module], float]
        Function which evaluates a model to measure desired metric (larger is better). It is called once on the
        original model and once on every pruned candidate.
    maximal_acceptable_metric_drop : float
        Maximal value of the metric decrease.
    minimal_channels_to_prune : int, optional
        Minimal channels amount to prune within all network. Default value is 100.
    maximal_channels_to_prune : int, optional
        Maximal channels amount to prune within all network. Default value is 300.
    channel_step_to_search : int, optional
        Channel configuration search step size. The greater value gives faster but less accurate results.
        Default value is 10.
    finetune_bn : bool, optional
        Finetune batch norm layers (specifically, their running mean and running variance values) of every
        candidate for better model quality. Default value is False.
    calibration_steps : int or None, optional
        Number of total calibration steps. Default value is None, which runs calibration on all dataloader images
        for the number of epochs specified in ``calibration_epochs`` argument.
    calibration_epochs : int, optional
        Number of total calibration epochs. Not used when ``calibration_steps`` argument is not None.
        Default value is 1.
    sample_to_n_samples : Callable, optional
        Function which computes the number of instances (objects to process) in single dataloader batch (dataloader
        sample). This function should take single input (dataloader sample) and return single integer - the number
        of instances. Default value is :func:`.default_sample_to_n_samples`. See more :ref:`here <s2ns ref>`.
    sample_to_model_inputs : Callable, optional
        Function to map dataloader samples to model input format. Default value is
        :func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
    show_tqdm : bool, optional
        Whether to log calibration procedure with `tqdm <https://github.com/tqdm/tqdm>`_ progress bar.
        Default value is False.
    entry_point : str, optional
        Model function for execution. See ``Notes`` section for the detailed argument description.
        Default value is "forward".

    Returns
    -------
    torch.nn.Module
        Pruned model, placed on the original model's device.

    Raises
    ------
    ValueError
        If ``channel_step_to_search`` is not positive, or if ``minimal_channels_to_prune`` is greater than
        ``maximal_channels_to_prune`` (the search range would be empty).

    Notes
    -----
    Before calling this function, your model should be prepared to be as close to practical inference usage as
    possible. For example, it is your responsibility to call ``eval`` method of your model if your inference
    requires calling this method (e.g. when the model contains dropout layers).

    Typically, it is better to calibrate model for pruning on validation-like data without augmentations (but with
    inference input preprocessing).

    If you encounter errors during backward call, you can wrap this function call with the following statement::

        with torch.autograd.set_detect_anomaly(True):
            calibrate_and_prune_model_global_wrt_metric_drop(...)

    ``entry_point`` is a string that specifies which function of the user's model will be traced. Simple examples:
    "forward", "execute_model", "forward_train". If such a function is located in the model's submodule - then you
    should first write the submodule's name followed by the function name, all separated by dots:
    "submodule.forward", "head.predict_features", "submodule1.submodule2.forward".

    """
    # Validate the search range before doing any expensive work: an empty range previously crashed
    # with AttributeError on ``None.to(...)`` only after validation and calibration had run.
    if channel_step_to_search <= 0:
        raise ValueError(f'channel_step_to_search must be positive, got {channel_step_to_search}')
    if minimal_channels_to_prune > maximal_channels_to_prune:
        raise ValueError(
            'minimal_channels_to_prune must not exceed maximal_channels_to_prune, '
            f'got {minimal_channels_to_prune} > {maximal_channels_to_prune}'
        )

    baseline_metric = validation_function(model)
    _LOGGER.info('Baseline metric of user model is: %s', baseline_metric)
    baseline_model_device = next(model.parameters()).device

    pruning_info = calibrate_model_for_pruning(
        model=model,
        dataloader=dataloader,
        loss_function=loss_function,
        n_steps=calibration_steps,
        epochs=calibration_epochs,
        sample_to_n_samples=sample_to_n_samples,
        sample_to_model_inputs=sample_to_model_inputs,
        show_tqdm=show_tqdm,
        entry_point=entry_point,
    )

    # Keep the unpruned model (and accepted candidates) on CPU during the search; presumably this limits
    # accelerator memory to one candidate at a time. The ``finally`` restores the user's model device even
    # if pruning, fine-tuning or validation raises.
    model = model.cpu()
    try:
        selected_model = None
        for n_channels in range(minimal_channels_to_prune, maximal_channels_to_prune + 1, channel_step_to_search):
            _LOGGER.info('Current pruned channels: %s', n_channels)
            label_selector = GlobalPruningLabelSelectorByChannels(n_channels)
            labels_to_prune = label_selector.select(pruning_info)
            pruned_model = prune_model(
                model=model,
                pruning_info=pruning_info,
                prune_labels=sorted(set(labels_to_prune)),
                inplace=False,
            )
            pruned_model = pruned_model.to(device=baseline_model_device)

            if finetune_bn:
                tune_bn_stats(
                    model=pruned_model,
                    dataloader=dataloader,
                    reset_bns=True,
                    set_momentums_none=True,
                    n_steps=calibration_steps,
                    epochs=calibration_epochs,
                    sample_to_model_inputs=sample_to_model_inputs,
                )

            p_metric = validation_function(pruned_model)
            pruned_model = pruned_model.cpu()
            _LOGGER.info('Metric after pruning: %s', p_metric)

            if baseline_metric - p_metric > maximal_acceptable_metric_drop:
                if selected_model is None:
                    # Even the least aggressive candidate violates the constraint; it is still returned
                    # (backward-compatible behavior), but the caller is told explicitly.
                    _LOGGER.warning(
                        'Pruning %s channels already drops the metric by %s, which exceeds the acceptable drop %s; '
                        'returning this model anyway.',
                        n_channels,
                        baseline_metric - p_metric,
                        maximal_acceptable_metric_drop,
                    )
                    selected_model = pruned_model
                break

            selected_model = pruned_model

        # The range is validated as non-empty above, so at least one candidate was evaluated.
        return selected_model.to(device=baseline_model_device)
    finally:
        model.to(device=baseline_model_device)