Pruning v2
The enot.pruning_v2 module is the successor of the enot.pruning module. It solves the same problem, but has a more powerful engine and will support more operations and types of pruning in the future. All the basic concepts used in this module are described here.
The module exports the following names: PruningIR, LossAwareCriterion, UniformLabelSelector, GlobalScoreLabelSelector, and KnapsackLabelSelector.
The pruning workflow can be described as follows:
1. Create PruningIR using a model.
2. Create a LossAwareCriterion context manager using the PruningIR and collect scores inside this context (calibration).
3. Create a label selector, for example UniformLabelSelector, and select the labels to be pruned using its select() method.
4. Prune the model using the prune() method of PruningIR and the selected labels.
The pruning code of a user-defined model looks like this:
from enot.pruning_v2 import PruningIR
from enot.pruning_v2 import LossAwareCriterion
from enot.pruning_v2 import UniformLabelSelector
# 1. create PruningIR
ir = PruningIR(model)
# 2. create criterion and collect scores (calibration)
with LossAwareCriterion(ir) as score_collector:
for sample in train_loader:
model_output = score_collector(sample)
loss = loss_function(model_output, sample)
loss.backward()
# 3. create label selector and select the labels to be pruned
label_selector = UniformLabelSelector(pruning_ratio=0.5)
labels_for_pruning = label_selector.select(ir)
# 4. prune model using selected labels
ir = ir.prune(labels=labels_for_pruning)
pruned_model = ir.model
- class PruningIR(model, *, leaf_modules=None, soft_tracing=True)
The main class containing internal pruning data and providing a high-level interface for pruning user-defined models.
Notes
It should only be serialized/deserialized using the dill module:
>>> import dill
>>> torch.save(ir, 'ir.pt', pickle_module=dill)
>>> torch.load('ir.pt', pickle_module=dill)
- __init__(model, *, leaf_modules=None, soft_tracing=True)
- Parameters:
model (torch.nn.Module) – Model for pruning.
leaf_modules (list of Type[torch.nn.Module] or torch.nn.Module, optional) – Types of modules or module instances that must be interpreted as leaf modules. Leaf modules are the atomic units that appear in the IR (see the sketch below).
soft_tracing (bool) – If True, untraceable modules will be interpreted as leaf modules. True by default.
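For illustration, here is a minimal sketch of building the IR with a custom leaf module; the model and the CustomBlock type are assumptions, not part of enot:
import torch.nn as nn
from enot.pruning_v2 import PruningIR

class CustomBlock(nn.Module):  # hypothetical user module
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(16, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    CustomBlock(),
)
# treat every CustomBlock instance as an atomic (leaf) unit of the IR
ir = PruningIR(model, leaf_modules=[CustomBlock])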
- prune(name='main', *, labels, inplace=False, skip_label_checking=False)
Run pruning procedure.
- Parameters:
name (str) – Name of the snapshot for pruning. In most cases this parameter can be omitted.
labels (List[Label]) – List of labels that should be pruned.
inplace (bool) – Enables inplace modification of input model (reduces memory consumption). False by default.
skip_label_checking (bool) – If True, label flag consistency checks will be skipped. Default value is False.
- Returns:
Reference to the pruned intermediate representation.
- Return type:
PruningIR
Notes
If labels is empty and inplace is False, a copy of the current intermediate representation will be returned.
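As a short sketch, the pruning call itself can be made in place to avoid holding two copies of the model; labels_for_pruning here is assumed to come from a label selector as in the workflow example above:
# inplace=True modifies the original model and reduces memory consumption
ir = ir.prune(labels=labels_for_pruning, inplace=True)
pruned_model = ir.model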
- class LossAwareCriterion(ir, snapshot_name='main')
Default ENOT pruning Criterion.
- class UniformLabelSelector(pruning_ratio)
Label selector for uniform (equal) pruning.
Removes an equal percentage of channels in each prunable layer.
- __init__(pruning_ratio)
- Parameters:
pruning_ratio (float) – Ratio of prunable channels to remove in each group.
- select(ir, snapshot_name='main')
Method that chooses which labels should be pruned based on the current label selector policy.
Warning
Depending on label selector implementation, this function may have significant execution time.
- class GlobalScoreLabelSelector(n_or_ratio)
Label selector for global pruning.
Selects the least important channels within the network.
- __init__(n_or_ratio)
- Parameters:
n_or_ratio (Number) – Number of labels or ratio of labels to remove across the whole network. If the parameter is in the range (0, 1), it is interpreted as a fraction of all prunable unique labels. If it is greater than or equal to 1, it is interpreted as the number of labels to remove (see the sketch below).
- select(ir, snapshot_name='main')
Method that chooses which labels should be pruned based on the current label selector policy.
Warning
Depending on label selector implementation, this function may have significant execution time.
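For comparison with the uniform selector used in the workflow example, a sketch of global selection; the 0.3 ratio is an arbitrary illustrative value:
from enot.pruning_v2 import GlobalScoreLabelSelector

# remove the 30% least important labels across the whole network:
# values in (0, 1) are treated as a ratio, values >= 1 as an absolute count
label_selector = GlobalScoreLabelSelector(n_or_ratio=0.3)
labels_for_pruning = label_selector.select(ir)
ir = ir.prune(labels=labels_for_pruning)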
- class KnapsackLabelSelector(target_latency, latency_calculation_function, strict_label_order=True, channels_constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)
Label selector based on knapsack algorithm.
We assume that the label score is the value and the label latency is the weight from the classical algorithm, so KnapsackLabelSelector maximizes the total score of the model under a latency constraint. This algorithm cannot be used to find an exact solution because the latency of labels has a non-linear dependency, but the iterative approach gives good results. At each iteration, the label selector recalculates the latency estimate for each label according to the selection at the previous iteration and solves the knapsack problem with the target latency constraint, until the problem converges to this constraint (see the usage sketch below).
- __init__(target_latency, latency_calculation_function, strict_label_order=True, channels_constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)
- select(ir, snapshot_name='main')
Method that chooses which labels should be pruned based on the current label selector policy.
Warning
Depending on label selector implementation, this function may have significant execution time.
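A usage sketch under the assumption that latency_calculation_function receives a model and returns its latency as a single number; measure_latency and the target value of 10.0 are hypothetical illustrations, not part of the enot API:
from enot.pruning_v2 import KnapsackLabelSelector

def measure_latency(model):
    # hypothetical helper: benchmark the model on the target device
    # and return its latency as a single number (e.g. milliseconds)
    ...

label_selector = KnapsackLabelSelector(
    target_latency=10.0,  # illustrative value, same units as measure_latency
    latency_calculation_function=measure_latency,
)
labels_for_pruning = label_selector.select(ir)
ir = ir.prune(labels=labels_for_pruning)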