Pruning v2

The enot.pruning_v2 module is the successor of the enot.pruning module. It solves the same problem but has a more powerful engine, and will support more operations and types of pruning in the future.

All the basic concepts used in this module are described here.

Module exports the following names:

The pruning workflow can be described as follows:

  1. Create a PruningIR from the model.

  2. Create a LossAwareCriterion context manager from the PruningIR and collect scores inside this context (calibration).

  3. Create a label selector, for example UniformLabelSelector, and select the labels to be pruned with its select() method.

  4. Prune the model with the prune() method of PruningIR and the selected labels.

Pruning code for a user-defined model looks like this:

```python
from enot.pruning_v2 import PruningIR
from enot.pruning_v2 import LossAwareCriterion
from enot.pruning_v2 import UniformLabelSelector

# model, train_loader and loss_function are user-defined objects.

# 1. create PruningIR
ir = PruningIR(model)

# 2. create criterion and collect scores (calibration)
with LossAwareCriterion(ir) as score_collector:
    for sample in train_loader:
        model_output = score_collector(sample)
        loss = loss_function(model_output, sample)
        loss.backward()

# 3. create label selector and select the labels to be pruned
label_selector = UniformLabelSelector(pruning_ratio=0.5)
labels_for_pruning = label_selector.select(ir)

# 4. prune model using selected labels
ir = ir.prune(labels=labels_for_pruning)
pruned_model = ir.model
```

class PruningIR(model, *, leaf_modules=None, soft_tracing=True)

The main class containing internal pruning data and providing a high-level interface for pruning user-defined models.

Notes

PruningIR should only be serialized and deserialized with the dill module:

```python
>>> import dill
>>> import torch
>>> torch.save(ir, 'ir.pt', pickle_module=dill)
>>> ir = torch.load('ir.pt', pickle_module=dill)
```

__init__(model, *, leaf_modules=None, soft_tracing=True)
Parameters:
  • model (torch.nn.Module) – Model for pruning.

  • leaf_modules (list of Type[torch.nn.Module] or torch.nn.Module, Optional) – Types of modules or module instances that must be interpreted as leaf modules. Leaf modules are the atomic units that appear in the IR.

  • soft_tracing (bool) – If True, untraceable modules are interpreted as leaf modules. True by default.

prune(name='main', *, labels, inplace=False, skip_label_checking=False)

Run pruning procedure.

Parameters:
  • name (str) – Name of the snapshot for pruning. In most cases this parameter can be omitted.

  • labels (List[Label]) – List of labels that should be pruned.

  • inplace (bool) – Enables in-place modification of the input model (reduces memory consumption). False by default.

  • skip_label_checking (bool) – If True, consistency checks of label flags are skipped. False by default.

Returns:

Reference to the pruned intermediate representation.

Return type:

PruningIR

Notes

If labels is empty and inplace is False, a copy of the current intermediate representation will be returned.

class LossAwareCriterion(ir, snapshot_name='main')

Default ENOT pruning criterion.

__init__(ir, snapshot_name='main')
Parameters:
  • ir (PruningIR) – Model intermediate representation.

  • snapshot_name (str) – Name of the snapshot to be used for calibration. Default value is DEFAULT_SNAPSHOT_NAME ('main').

class UniformLabelSelector(pruning_ratio)

Label selector for uniform (equal) pruning.

Removes an equal percentage of channels in each prunable layer.

__init__(pruning_ratio)
Parameters:

pruning_ratio (float) – Ratio of prunable channels to remove in each group.
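To make the uniform policy concrete, here is a small pure-Python sketch that does not use ENOT. Labels are modeled as hypothetical (group_name, channel_index) tuples with one score per channel, and the helper uniform_select is illustrative only, not part of enot.pruning_v2. It assumes that, within each group, the lowest-scoring channels are removed first and that the number of removed channels is rounded down; the actual selector may differ in both respects.

```python
from typing import Dict, List, Tuple

# Hypothetical label representation: (group_name, channel_index).
# This is NOT the enot.pruning_v2 Label type.
Label = Tuple[str, int]


def uniform_select(scores: Dict[str, List[float]], pruning_ratio: float) -> List[Label]:
    """Select the same fraction of channels in every group (illustrative only)."""
    if not 0.0 <= pruning_ratio < 1.0:
        raise ValueError("pruning_ratio must be in [0, 1)")
    selected: List[Label] = []
    for group, group_scores in scores.items():
        # Assumption: the number of channels to remove is rounded down.
        n_remove = int(len(group_scores) * pruning_ratio)
        # Assumption: channels are ranked from least to most important by score.
        order = sorted(range(len(group_scores)), key=lambda i: group_scores[i])
        selected.extend((group, i) for i in sorted(order[:n_remove]))
    return selected


scores = {"conv1": [0.9, 0.1, 0.5, 0.3], "conv2": [0.2, 0.8]}
print(uniform_select(scores, pruning_ratio=0.5))
# [('conv1', 1), ('conv1', 3), ('conv2', 0)]
```

Every group loses the same share of its channels regardless of how its scores compare with other groups.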

select(ir, snapshot_name='main')

Chooses which labels should be pruned, based on the current label selector policy.

Warning

Depending on the label selector implementation, this method may take a significant amount of time to execute.

Parameters:
  • ir (PruningIR) – Model intermediate representation.

  • snapshot_name (str) – Name of the snapshot.

Returns:

List of channel labels which should be pruned.

Return type:

list of Label

class GlobalScoreLabelSelector(n_or_ratio)

Label selector for global pruning.

Selects the least important channels within the whole network.

__init__(n_or_ratio)
Parameters:

n_or_ratio (Number) – Number or ratio of labels to remove across the whole network. If the parameter is in the range (0, 1), it is interpreted as a fraction of all prunable unique labels. If the parameter is greater than or equal to 1, it is interpreted as the number of labels to remove.
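The following pure-Python sketch illustrates the n_or_ratio rule and the global policy: unlike the uniform selector, all channels compete in a single pool, so some groups can lose more channels than others. The names labels_to_remove and global_select, the tuple-based labels, rounding ratios down and clamping large counts to the number of available labels are assumptions made for illustration, not the ENOT implementation.

```python
from typing import Dict, List, Tuple, Union

# Hypothetical label representation: (group_name, channel_index), not ENOT's Label type.
Label = Tuple[str, int]


def labels_to_remove(n_or_ratio: Union[int, float], total: int) -> int:
    """Turn n_or_ratio into a label count, following the rule described above."""
    if 0 < n_or_ratio < 1:
        # Fraction of all prunable labels (rounding down is an assumption).
        return int(total * n_or_ratio)
    if n_or_ratio >= 1:
        # Absolute number of labels (clamping to `total` is an assumption).
        return min(int(n_or_ratio), total)
    raise ValueError("n_or_ratio must be positive")


def global_select(scores: Dict[str, List[float]], n_or_ratio: Union[int, float]) -> List[Label]:
    """Select the globally least important channels (illustrative only)."""
    # Put every channel of every group into a single pool, ranked by score.
    pool = sorted(
        ((score, (group, i)) for group, group_scores in scores.items()
         for i, score in enumerate(group_scores)),
        key=lambda item: item[0],
    )
    n = labels_to_remove(n_or_ratio, len(pool))
    return sorted(label for _, label in pool[:n])


scores = {"conv1": [0.9, 0.1, 0.5, 0.3], "conv2": [0.05, 0.8, 0.02]}
print(global_select(scores, 0.5))  # ratio: 7 labels in total, 3 removed
# [('conv1', 1), ('conv2', 0), ('conv2', 2)]
print(global_select(scores, 2))    # absolute count
# [('conv2', 0), ('conv2', 2)]
```

With the same ratio of 0.5, a per-group policy would take two channels from conv1 and one from conv2; the global pool instead takes most of its picks from conv2, whose scores are lowest.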

select(ir, snapshot_name='main')

Chooses which labels should be pruned, based on the current label selector policy.

Warning

Depending on the label selector implementation, this method may take a significant amount of time to execute.

Parameters:
  • ir (PruningIR) – Model intermediate representation.

  • snapshot_name (str) – Name of the snapshot.

Returns:

List of channel labels which should be pruned.

Return type:

list of Label

class KnapsackLabelSelector(target_latency, latency_calculation_function, strict_label_order=True, channels_constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)

Label selector based on knapsack algorithm.

We treat the label score as the value and the label latency as the weight of the classical knapsack problem, so KnapsackLabelSelector maximizes the total score of the model under a latency constraint. This formulation cannot yield an exact solution, because label latencies depend on each other non-linearly, but the iterative approach gives good results. At each iteration, the label selector recalculates the latency estimate of each label according to the selection made at the previous iteration and solves the knapsack problem with the target latency constraint, until the solution converges to this constraint.
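Below is a simplified, self-contained sketch of the iterative scheme described above; it is not the ENOT solver. It assumes that a label's latency is estimated as its marginal latency with respect to the current selection, linearizes the model latency around that selection to obtain the knapsack capacity, and uses a classic dynamic-programming 0/1 knapsack to pick the labels to keep. The helper names, the integer discretization via resolution, the convergence test (the selection stops changing) and the toy latency function are all assumptions; max_iterations only mirrors the constructor argument of the same name.

```python
import math
from typing import Callable, FrozenSet, List, Sequence


def knapsack(values: Sequence[float], weights: Sequence[int], capacity: int) -> List[int]:
    """Classic 0/1 knapsack solved by dynamic programming; returns chosen indices."""
    n = len(values)
    # best[i][c]: highest total value using the first i items within capacity c.
    best = [[0.0] * (capacity + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        w, v = weights[i - 1], values[i - 1]
        for c in range(capacity + 1):
            best[i][c] = best[i - 1][c]
            if w <= c and best[i - 1][c - w] + v > best[i][c]:
                best[i][c] = best[i - 1][c - w] + v
    # Walk back through the table to recover which items were taken.
    chosen, c = [], capacity
    for i in range(n, 0, -1):
        if best[i][c] != best[i - 1][c]:
            chosen.append(i - 1)
            c -= weights[i - 1]
    return sorted(chosen)


def iterative_knapsack_select(
    scores: Sequence[float],
    latency_fn: Callable[[FrozenSet[int]], float],  # hypothetical latency model
    target_latency: float,
    max_iterations: int = 15,
    resolution: float = 1.0,  # assumption: latency units per integer knapsack weight
) -> FrozenSet[int]:
    """Return indices of labels to KEEP; all other labels would be pruned."""
    keep: FrozenSet[int] = frozenset(range(len(scores)))  # start from the unpruned model
    best_feasible = None
    for _ in range(max_iterations):
        current = latency_fn(keep)
        # Marginal latency of each label relative to the current selection.
        weights = [
            current - latency_fn(keep - {i}) if i in keep else latency_fn(keep | {i}) - current
            for i in range(len(scores))
        ]
        # Linearize around `keep`: L(S) ~ current + sum_{i in S} w_i - sum_{i in keep} w_i.
        capacity = target_latency - current + sum(weights[i] for i in keep)
        int_w = [max(0, round(w / resolution)) for w in weights]
        int_cap = max(0, math.floor(capacity / resolution))
        new_keep = frozenset(knapsack(scores, int_w, int_cap))
        # Remember the best selection whose true latency meets the target.
        if latency_fn(new_keep) <= target_latency and (
            best_feasible is None
            or sum(scores[i] for i in new_keep) > sum(scores[i] for i in best_feasible)
        ):
            best_feasible = new_keep
        if new_keep == keep:  # selection stopped changing: converged
            break
        keep = new_keep
    return best_feasible if best_feasible is not None else keep


def toy_latency(keep: FrozenSet[int]) -> float:
    k = len(keep)
    return 10 + 5 * k + k * k  # deliberately non-linear in the number of kept labels


kept = iterative_knapsack_select([5.0, 4.0, 3.0, 1.0], toy_latency, target_latency=30)
print(sorted(kept))  # [0, 1]
```

The sketch returns the labels to keep; the labels to prune are the complement, which corresponds to what select() returns.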

__init__(target_latency, latency_calculation_function, strict_label_order=True, channels_constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)
Parameters:
property max_iterations: int

Maximum number of iterations.

select(ir, snapshot_name='main')

Chooses which labels should be pruned, based on the current label selector policy.

Warning

Depending on the label selector implementation, this method may take a significant amount of time to execute.

Parameters:
  • ir (PruningIR) – Model intermediate representation.

  • snapshot_name (str) – Name of the snapshot.

Returns:

List of channel labels which should be pruned.

Return type:

list of Label