Pruning v2
The enot.pruning_v2 module is the successor of the enot.pruning module. It solves the same problem, but it has a more powerful engine and will support more operations and types of pruning in the future.
All the basic concepts used in this module are described here.
The module exports the following names:
- PruningIR
- LossAwareCriterion
- UniformLabelSelector, GlobalScoreLabelSelector, BinarySearchLatencyLabelSelector, KnapsackLabelSelector, SCBOLabelSelector (label selectors)
- LatencyMeasurementError
The pruning workflow can be described as follows:
1. Create PruningIR using a model.
2. Create a LossAwareCriterion context manager using the PruningIR and collect scores inside this context (calibration).
3. Create a label selector, for example UniformLabelSelector, and select the labels to be pruned using the select() method.
4. Prune the model using the prune() method of PruningIR and the selected labels.
The pruning code for a user-defined model looks like this:
from enot.pruning_v2 import PruningIR
from enot.pruning_v2 import LossAwareCriterion
from enot.pruning_v2 import UniformLabelSelector

# 1. create PruningIR
ir = PruningIR(model)

# 2. create criterion and collect scores (calibration)
with LossAwareCriterion(ir) as score_collector:
    for sample in train_loader:
        model_output = score_collector(sample)
        loss = loss_function(model_output, sample)
        loss.backward()

# 3. create label selector and select the labels to be pruned
label_selector = UniformLabelSelector(pruning_ratio=0.5)
labels_for_pruning = label_selector.select(ir.snapshot())

# 4. prune model using selected labels
ir = ir.prune(labels=labels_for_pruning)
pruned_model = ir.model
- class PruningIR(model, *, leaf_modules=None, registered_modules_as_leaf_modules=True, soft_tracing=True)
The main class containing internal pruning data and providing a high-level interface for pruning user-defined models.
- force_thunk_update
Whether to force updates to arguments that are thunks. Default value is False.
- Type:
bool
- attention_pruning_cfg
Attention pruning configuration.
- Type:
Dict
Notes

It should only be serialized/deserialized using the dill module:

>>> import dill
>>> torch.save(ir, 'ir.pt', pickle_module=dill)
>>> torch.load('ir.pt', pickle_module=dill)
- __init__(model, *, leaf_modules=None, registered_modules_as_leaf_modules=True, soft_tracing=True)
- Parameters:
model (torch.nn.Module) – Model for pruning.
leaf_modules (list of Type[torch.nn.Module], torch.nn.Module or PruningIR.LeafModule, Optional) – Types of modules or module instances that must be interpreted as leaf modules. Leaf modules are the atomic units that appear in the IR. It is also possible to mark inputs/outputs of leaf modules as non-prunable using PruningIR.LeafModule; this is possible only for modules that have their own handler. If the specified module (type) does not have its own handler, then its inputs and outputs will be non-prunable in any case (see the sketch after this list).
registered_modules_as_leaf_modules (bool) – Whether torch.nn.Module modules registered in the label inference registry should be treated as leaf modules or not. Default value is True.
soft_tracing (bool) – If True, untraceable modules will be interpreted as leaf modules. True by default.
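A sketch of marking a custom block as a leaf module (MyCustomBlock is a hypothetical user-defined module type):

from enot.pruning_v2 import PruningIR

# Treat every instance of MyCustomBlock as an atomic unit of the IR,
# so its internals are neither traced nor pruned individually.
ir = PruningIR(model, leaf_modules=[MyCustomBlock])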
- prune(name='main', *, labels, inplace=False, skip_label_checking=False)
Run pruning procedure.
- Parameters:
name (str) – Name of the snapshot for pruning. In most cases this parameter can be omitted.
labels (List[Label]) – List of labels that should be pruned.
inplace (bool) – Enables inplace modification of input model (reduces memory consumption). False by default.
skip_label_checking (bool) – If True, label flag consistency checks will be skipped. Default value is False.
- Returns:
Reference to the pruned intermediate representation.
- Return type:
PruningIR
Notes
If labels is empty and inplace is False, a copy of the current intermediate representation will be returned.
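For large models, prune() can be run in place to reduce peak memory consumption (a minimal sketch; labels_for_pruning comes from the workflow example above):

# Modify the current IR and its model directly instead of copying them
ir = ir.prune(labels=labels_for_pruning, inplace=True)
pruned_model = ir.model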
- class LossAwareCriterion(ir, snapshot_name='main')
Default ENOT pruning Criterion.
- class UniformLabelSelector(pruning_ratio)
Label selector for uniform (equal) pruning.
Removes an equal percentage of labels in each prunable layer.
- class GlobalScoreLabelSelector(n_or_ratio)
Label selector for global pruning.
Selects the least important labels within the network.
- __init__(n_or_ratio)
- Parameters:
n_or_ratio (Number) – Number of labels or ratio of labels to remove within the whole network. If the parameter is in the range (0, 1), it is interpreted as a fraction of all prunable unique labels. If the parameter is greater than or equal to 1, it is interpreted as the number of labels to remove.
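A usage sketch (ir is the PruningIR from the workflow example above; the 30% ratio is an arbitrary illustration):

from enot.pruning_v2 import GlobalScoreLabelSelector

# Remove the 30% least important labels across the whole network
label_selector = GlobalScoreLabelSelector(n_or_ratio=0.3)
labels_for_pruning = label_selector.select(ir.snapshot())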
- class BinarySearchLatencyLabelSelector(target_latency, latency_calculation_function, selector_cb=None)
Label selector based on the binary search algorithm.
It finds the model with latency as close as possible to the target latency, always selecting the labels with the lowest scores.
- __init__(target_latency, latency_calculation_function, selector_cb=None)
- Parameters:
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take a model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be a number of FLOPs, MACs, inference time on CPU/GPU, or another “speed” criterion (see the sketch after this list).
selector_cb (Optional[Callable[[float, float], bool]]) – An optional callback that is called at each iteration of the search process. The callback should take the current latency as the first parameter and the target latency as the second parameter, and return True if the search procedure should be stopped, False otherwise. Can be used for logging and early stopping.
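A usage sketch (the input shape and target latency are assumptions for illustration):

import time

import torch

from enot.pruning_v2 import BinarySearchLatencyLabelSelector

def latency_calculation_function(model: torch.nn.Module) -> float:
    # Average wall-clock latency in milliseconds; any “speed” metric with
    # consistent units (FLOPs, MACs, GPU time, ...) would work as well.
    model.eval()
    sample = torch.randn(1, 3, 224, 224)  # assumed input shape
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(10):
            model(sample)
    return (time.perf_counter() - start) / 10 * 1000.0

selector = BinarySearchLatencyLabelSelector(
    target_latency=15.0,  # same units as latency_calculation_function (ms)
    latency_calculation_function=latency_calculation_function,
)
labels_for_pruning = selector.select(ir.snapshot())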
- class KnapsackLabelSelector(target_latency, latency_calculation_function, *, max_relative_latency_step=0.5, strict_label_order=True, constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)
Label selector based on the knapsack algorithm.
We treat the label score as the value and the label latency as the weight from the classical algorithm, so KnapsackLabelSelector maximizes the total score of the model under a latency constraint. This algorithm cannot find an exact solution because the latency of labels has a non-linear dependency, but the iterative approach gives good results. At each iteration, the label selector recalculates the latency estimation for each label according to the selection at the previous iteration and solves the knapsack problem with the target latency constraint, until the problem converges to this constraint.
- __init__(target_latency, latency_calculation_function, *, max_relative_latency_step=0.5, strict_label_order=True, constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)
- Parameters:
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take a model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be a number of FLOPs, MACs, inference time on CPU/GPU, or another “speed” criterion (see the sketch after this list).
max_relative_latency_step (float) – Determines the maximum relative difference between the baseline latency and the target latency for one knapsack iteration. The target latency for each iteration is calculated as max(latency * (1 - max_relative_latency_step), target_latency), which allows minimizing score loss in case of aggressive acceleration (more than x2). Value must be in the range (0, 1]. Default value is 0.5.
strict_label_order (bool) – If True, the labels within a group are always selected in ascending order of score; otherwise any order of selection is acceptable. Default value is True.
constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that calculates low, high and step for the constraint on the total number of labels in a group. If None, DefaultChannelsSelectorConstraint will be used. None by default.
max_iterations (int) – Maximum number of iterations. Default value is 15.
solver_execution_time_limit (Optional[Number]) – Time limit in seconds for solving a linear problem at each iteration. Default value is 60.
verbose (bool) – Whether to print selection statistics and show a progress bar or not. True by default.
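A usage sketch, reusing the latency_calculation_function defined above (the 2x speed-up target is an assumption for illustration):

from enot.pruning_v2 import KnapsackLabelSelector

# Target half of the baseline latency, i.e. roughly a 2x speed-up
baseline_latency = latency_calculation_function(ir.model)
selector = KnapsackLabelSelector(
    target_latency=0.5 * baseline_latency,
    latency_calculation_function=latency_calculation_function,
)
labels_for_pruning = selector.select(ir.snapshot())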
- class SCBOLabelSelector(target_latency, latency_calculation_function, n_search_steps, n_starting_points=None, batch_size=4, constraint=None, additional_starting_points_generator=None, device=None, dtype=torch.float64, verbose=True, handle_ctrl_c=False, seed=None, scbo_kwargs=None)
Label selector based on SCBO algorithm.
- __init__(target_latency, latency_calculation_function, n_search_steps, n_starting_points=None, batch_size=4, constraint=None, additional_starting_points_generator=None, device=None, dtype=torch.float64, verbose=True, handle_ctrl_c=False, seed=None, scbo_kwargs=None)
- Parameters:
target_latency (float) – Target model latency. This argument should have the same units as the output of latency_calculation_function.
latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It should take a model (torch.nn.Module) and measure the “speed” (float) of its execution. It could be a number of FLOPs, MACs, inference time on CPU/GPU, or another “speed” criterion. If latency measurement cannot be performed for a particular model configuration due to technical reasons (e.g., problems with measurements on a server), then latency_calculation_function should raise LatencyMeasurementError.
n_search_steps (Optional[int]) – Number of configurations (samples) for pruning to find the optimal architecture. If None, the search loop will stop when SCBO converges.
n_starting_points (Optional[int]) – Number of sampled configurations used for the startup step. If None, n_starting_points is calculated as min(total_number_of_groups_in_model * 2, n_search_steps - 1). None by default.
batch_size (int) – Batch size for the SCBO algorithm. 4 by default.
constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that calculates low, high and step for the constraint on the total number of labels in a group. If None, DefaultChannelsSelectorConstraint will be used. None by default.
additional_starting_points_generator (Optional[StartingPointsGenerator]) – User-defined generator of additional starting points for the SCBO algorithm. None by default.
device (Optional[Union[str, torch.device]]) – The desired device for the SCBO algorithm. None by default (cuda if possible).
dtype (torch.dtype) – The desired dtype of the SCBO algorithm. torch.double by default.
verbose (bool) – Whether to show a progress bar or not. True by default.
handle_ctrl_c (bool) – Whether to allow early stopping via Ctrl+C or not. False by default.
seed (Optional[int]) – Optional seed for the SCBO algorithm. Default value is None.
scbo_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments for SCBO (developers only).
- set_search_step_cb(search_step_cb)
Set search step callback. This callback will be called every time a pair (score, latency) improvement occurs during the search procedure.
- Parameters:
search_step_cb (Callable[[List[Label], int, float, float], bool]) – Callback function. This function should take the list of labels (which were selected by SCBO) as the first argument, the number of the best step as the second argument, the new best score of the pruned model as the third argument, and the new latency of the pruned model as the last argument. If the callback returns True, the search is terminated.
- Return type:
None
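A usage sketch, reusing the latency_calculation_function defined above (the step budget and target latency are assumptions for illustration):

from enot.pruning_v2 import SCBOLabelSelector

def search_step_cb(labels, step, score, latency):
    # Log every (score, latency) improvement; return True to stop the search early
    print(f'step {step}: score={score:.4f}, latency={latency:.2f}')
    return False

selector = SCBOLabelSelector(
    target_latency=15.0,  # same units as latency_calculation_function (ms)
    latency_calculation_function=latency_calculation_function,
    n_search_steps=100,
)
selector.set_search_step_cb(search_step_cb)
labels_for_pruning = selector.select(ir.snapshot())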
- class LatencyMeasurementError
Must be raised by a latency measurement function if the measurement cannot be performed for a particular model configuration.
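For example (a sketch; measure_on_server is a hypothetical helper that performs remote measurements and may fail):

from enot.pruning_v2 import LatencyMeasurementError

def latency_calculation_function(model):
    try:
        return measure_on_server(model)  # hypothetical remote measurement helper
    except ConnectionError as error:
        # Signal the label selector that this configuration could not be measured
        raise LatencyMeasurementError('latency server is unavailable') from error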