Pruning v2

The enot.pruning_v2 module is the successor to the enot.pruning module. It solves the same problem but has a more powerful engine and will support more operations and types of pruning in the future.

All the basic concepts used in this module are described here.

The module exports the following names:

(label selectors)

The pruning workflow can be described as follows:

  1. Create a PruningIR from the model.

  2. Create a LossAwareCriterion context manager from the PruningIR and collect scores inside this context (calibration).

  3. Create a label selector, for example UniformLabelSelector, and select the labels to be pruned using its select() method.

  4. Prune the model using the prune() method of PruningIR with the selected labels.

The pruning code for a user-defined model looks like this:

from enot.pruning_v2 import PruningIR
from enot.pruning_v2 import LossAwareCriterion
from enot.pruning_v2 import UniformLabelSelector

# 1. create PruningIR (model, train_loader and loss_function are user-defined)
ir = PruningIR(model)

# 2. create criterion and collect scores (calibration)
with LossAwareCriterion(ir) as score_collector:
    for sample in train_loader:
        model_output = score_collector(sample)
        loss = loss_function(model_output, sample)
        loss.backward()

# 3. create label selector and select the labels to be pruned
label_selector = UniformLabelSelector(pruning_ratio=0.5)
labels_for_pruning = label_selector.select(ir.snapshot())

# 4. prune model using selected labels
ir = ir.prune(labels=labels_for_pruning)
pruned_model = ir.model

class PruningIR(model, *, leaf_modules=None, registered_modules_as_leaf_modules=True, soft_tracing=True)

The main class containing internal pruning data and providing a high-level interface for pruning user-defined models.

force_thunk_update

Whether to force updates to arguments that are thunks. Default value is False.

Type:

bool

attention_pruning_cfg

Attention pruning configuration.

Type:

Dict

Notes

It should only be serialized/deserialized using the dill module:

>>> import dill
>>> import torch
>>> torch.save(ir, 'ir.pt', pickle_module=dill)
>>> ir = torch.load('ir.pt', pickle_module=dill)

__init__(model, *, leaf_modules=None, registered_modules_as_leaf_modules=True, soft_tracing=True)
Parameters:
  • model (torch.nn.Module) – Model for pruning.

  • leaf_modules (list of Type[torch.nn.Module], torch.nn.Module or PruningIR.LeafModule, Optional) – Types of modules or module instances that must be interpreted as leaf modules. Leaf modules are the atomic units that appear in the IR. Inputs/outputs of leaf modules can also be marked as non-prunable using PruningIR.LeafModule; this is possible only for modules that have their own handler. If the specified module (type) does not have its own handler, its inputs and outputs will be non-prunable in any case.

  • registered_modules_as_leaf_modules (bool) – Whether torch.nn.Module subclasses registered in the label inference registry should be treated as leaf modules. Default value is True.

  • soft_tracing (bool) – If True, untraceable modules are interpreted as leaf modules. True by default.

prune(name='main', *, labels, inplace=False, skip_label_checking=False)

Run pruning procedure.

Parameters:
  • name (str) – Name of the snapshot to prune. In most cases this parameter can be omitted.

  • labels (List[Label]) – List of labels that should be pruned.

  • inplace (bool) – Enables in-place modification of the input model (reduces memory consumption). False by default.

  • skip_label_checking (bool) – If True, label flag consistency checks are skipped. Default value is False.

Returns:

Reference to the pruned intermediate representation.

Return type:

PruningIR

Notes

If labels is empty and inplace is False, a copy of the current intermediate representation will be returned.

class LossAwareCriterion(ir, snapshot_name='main')

Default ENOT pruning criterion.

__init__(ir, snapshot_name='main')
Parameters:
  • ir (PruningIR) – Model intermediate representation.

  • snapshot_name (str) – Name of the snapshot to be used for calibration. Default value is DEFAULT_SNAPSHOT_NAME.

class UniformLabelSelector(pruning_ratio)

Label selector for uniform (equal) pruning.

Removes an equal percentage of labels in each prunable layer.

__init__(pruning_ratio)
Parameters:

pruning_ratio (float) – Ratio of prunable labels to remove in each group.
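To make "an equal percentage in each group" concrete, here is a minimal pure-Python sketch; it is not ENOT's implementation. It assumes each group is a list of (label_id, score) pairs and that the per-group count is rounded down. The function name uniform_select and the data layout are hypothetical.

```python
from typing import Dict, List, Tuple


def uniform_select(groups: Dict[str, List[Tuple[int, float]]],
                   pruning_ratio: float) -> List[Tuple[str, int]]:
    """Hypothetical illustration: pick the lowest-scored labels in every group.

    `groups` maps a group name to (label_id, score) pairs. The number removed
    per group is floor(pruning_ratio * group_size), which is an assumption.
    """
    if not 0.0 <= pruning_ratio <= 1.0:
        raise ValueError("pruning_ratio must be in [0, 1]")
    selected = []
    for name, labels in groups.items():
        n_remove = int(pruning_ratio * len(labels))  # assumed rounding: floor
        # Ascending score order puts the least important labels first.
        for label_id, _ in sorted(labels, key=lambda item: item[1])[:n_remove]:
            selected.append((name, label_id))
    return selected


groups = {
    "conv1": [(0, 0.9), (1, 0.1), (2, 0.5), (3, 0.3)],
    "conv2": [(0, 0.2), (1, 0.8)],
}
print(uniform_select(groups, pruning_ratio=0.5))
```

Every group loses the same fraction of its labels regardless of how important that group is overall, which is the key difference from global selection.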

class GlobalScoreLabelSelector(n_or_ratio)

Label selector for global pruning.

Selects the least important labels across the whole network.

__init__(n_or_ratio)
Parameters:

n_or_ratio (Number) – Number of labels, or ratio of labels, to remove across the whole network. If the value is in the range (0, 1), it is interpreted as a fraction of all prunable unique labels. If the value is greater than or equal to 1, it is interpreted as the number of labels to remove.
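The two interpretations of n_or_ratio can be sketched in plain Python. This is a hypothetical illustration, not ENOT code: the helper names, the dict of label scores, and the floor rounding for fractional values are assumptions.

```python
from typing import Dict, List


def n_labels_to_remove(n_or_ratio: float, n_prunable_labels: int) -> int:
    """Hypothetical helper mirroring the documented n_or_ratio rule."""
    if n_or_ratio <= 0:
        raise ValueError("n_or_ratio must be positive")
    if n_or_ratio < 1:
        # Values in (0, 1) are a fraction of all prunable unique labels.
        return int(n_or_ratio * n_prunable_labels)  # floor is an assumption
    # Values >= 1 are an absolute number of labels.
    return int(n_or_ratio)


def global_select(scores: Dict[str, float], n_or_ratio: float) -> List[str]:
    """Return the globally least important label ids (hypothetical)."""
    n = n_labels_to_remove(n_or_ratio, len(scores))
    return sorted(scores, key=scores.get)[:n]


scores = {"a": 0.4, "b": 0.1, "c": 0.9, "d": 0.2}
print(global_select(scores, 0.5))  # fraction of all labels
print(global_select(scores, 3))    # absolute number of labels
```

Unlike the uniform selector, a single ranking over the whole network is used, so some groups may lose many labels and others none.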

class BinarySearchLatencyLabelSelector(target_latency, latency_calculation_function, selector_cb=None)

Label selector based on the binary search algorithm.

It finds the model whose latency is as close as possible to target_latency, and always selects the labels with the lowest scores.
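The idea can be sketched as a binary search over how many of the lowest-scored labels to remove. This sketch assumes latency never increases as more labels are removed and returns the smallest removal that meets the target, which is one reading of "as close as possible"; ENOT's actual search may differ. The names binary_search_select, latency_of and the toy per-label costs are hypothetical.

```python
from typing import Callable, Dict, List, Optional


def binary_search_select(scores: Dict[str, float],
                         latency_of: Callable[[List[str]], float],
                         target_latency: float) -> Optional[List[str]]:
    """Hypothetical sketch: remove the fewest lowest-score labels that reach the target.

    `latency_of` receives the labels to remove and returns the resulting latency;
    it is assumed to be non-increasing as more labels are removed.
    """
    order = sorted(scores, key=scores.get)  # lowest score first
    lo, hi = 0, len(order)
    best_count = None
    while lo <= hi:
        mid = (lo + hi) // 2
        if latency_of(order[:mid]) <= target_latency:
            best_count = mid  # target met; try removing fewer labels
            hi = mid - 1
        else:
            lo = mid + 1
    return None if best_count is None else order[:best_count]


scores = {"a": 0.9, "b": 0.1, "c": 0.5, "d": 0.3}
costs = {"a": 4.0, "b": 3.0, "c": 2.0, "d": 1.0}


def latency_of(removed: List[str]) -> float:
    # Toy latency model: total cost of the labels that are kept.
    return sum(cost for name, cost in costs.items() if name not in removed)


print(binary_search_select(scores, latency_of, target_latency=6.5))
```

Each probe requires one latency measurement, so the number of calls to the latency function grows only logarithmically with the number of labels.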

__init__(target_latency, latency_calculation_function, selector_cb=None)
Parameters:
  • target_latency (float) – Target model latency. This argument should be in the same units as the output of latency_calculation_function.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It takes a model (torch.nn.Module) and returns the “speed” (float) of its execution. This can be the number of FLOPs or MACs, inference time on CPU/GPU, or any other “speed” criterion.

  • selector_cb (Optional[Callable[[float, float], bool]]) – An optional callback that is called at each iteration of the search process. The callback takes the current latency as the first argument and the target latency as the second, and returns True if the search should be stopped, False otherwise. It can be used for logging and early stopping.
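A callback with the documented (current_latency, target_latency) -> bool shape might look like the sketch below. The factory name make_selector_cb, the recorded history, and the 1% relative tolerance are illustrative choices, not ENOT defaults.

```python
def make_selector_cb(tolerance: float = 0.01):
    """Build a logging / early-stop callback for BinarySearchLatencyLabelSelector.

    Stops once the current latency is within `tolerance` (relative) of the target.
    """
    history = []

    def selector_cb(current_latency: float, target_latency: float) -> bool:
        history.append(current_latency)
        print(f"iteration {len(history)}: latency={current_latency:.3f}, "
              f"target={target_latency:.3f}")
        # Returning True asks the search procedure to stop.
        return abs(current_latency - target_latency) <= tolerance * target_latency

    selector_cb.history = history  # expose the log for later inspection
    return selector_cb


# Usage (requires ENOT; latency_fn is user-defined):
# selector = BinarySearchLatencyLabelSelector(
#     target_latency, latency_fn, selector_cb=make_selector_cb(0.01))
```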

class KnapsackLabelSelector(target_latency, latency_calculation_function, *, max_relative_latency_step=0.5, strict_label_order=True, constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)

Label selector based on the knapsack algorithm.

Label scores play the role of item values and label latencies play the role of item weights in the classical knapsack problem, so KnapsackLabelSelector maximizes the total score of the model under a latency constraint. This formulation cannot give an exact solution because model latency does not depend linearly on the set of removed labels, but the iterative approach gives good results. At each iteration, the label selector re-estimates the latency of each label according to the selection from the previous iteration and solves the knapsack problem with the target latency constraint, until the solution converges to this constraint.
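A single iteration of this idea can be sketched as a classic 0/1 knapsack solved by dynamic programming: keep the labels whose total score is maximal while their total latency cost fits the budget. The fixed integer per-label costs, the DP solver and the name knapsack_keep are assumptions for illustration; the real selector re-estimates per-label latency at every iteration, which this sketch does not do.

```python
from typing import Dict, List, Tuple


def knapsack_keep(labels: Dict[str, Tuple[float, int]], budget: int) -> List[str]:
    """Choose labels to KEEP, maximizing total score under a latency budget.

    `labels` maps a label id to (score, latency_cost); costs are assumed to be
    non-negative integers so the classic 0/1 knapsack DP applies.
    """
    # best[w] = (total score, kept labels) achievable with total cost <= w
    best: List[Tuple[float, List[str]]] = [(0.0, [])] * (budget + 1)
    for name, (score, cost) in labels.items():
        # Iterate budgets downwards so each label is used at most once.
        for w in range(budget, cost - 1, -1):
            candidate = best[w - cost][0] + score
            if candidate > best[w][0]:
                best[w] = (candidate, best[w - cost][1] + [name])
    return best[budget][1]


labels = {"a": (0.9, 4), "b": (0.1, 3), "c": (0.5, 2), "d": (0.3, 1)}
kept = knapsack_keep(labels, budget=6)
to_prune = [name for name in labels if name not in kept]
print(kept, to_prune)
```

In the iterative scheme described above, the costs would be recomputed from the model produced by the previous iteration's selection and the knapsack solved again.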

__init__(target_latency, latency_calculation_function, *, max_relative_latency_step=0.5, strict_label_order=True, constraint=None, max_iterations=15, solver_execution_time_limit=60, verbose=True)
Parameters:
  • target_latency (float) – Target model latency. This argument should be in the same units as the output of latency_calculation_function.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It takes a model (torch.nn.Module) and returns the “speed” (float) of its execution. This can be the number of FLOPs or MACs, inference time on CPU/GPU, or any other “speed” criterion.

  • max_relative_latency_step (float) – Determines the maximum relative difference between the baseline latency and the target latency for one knapsack iteration. The target latency for each iteration is calculated as max(latency * (1 - max_relative_latency_step), target_latency), which helps minimize score loss in the case of aggressive acceleration (more than 2x). Value must be in the range (0, 1]. Default value is 0.5.

  • strict_label_order (bool) – If True, then the labels within a group are always selected in ascending order of score, otherwise any order of selection is acceptable. Default value is True.

  • constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that calculates the low, high and step values of the constraint on the total number of labels in a group. If None, DefaultChannelsSelectorConstraint is used. None by default.

  • max_iterations (int) – Maximum number of iterations. Default value is 15.

  • solver_execution_time_limit (Optional[Number]) – Time limit in seconds for solving a linear problem at each iteration. Default value is 60.

  • verbose (bool) – Whether to print selection statistics and show progress bar or not. True by default.
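The per-iteration target formula for max_relative_latency_step can be traced with a small helper. It assumes, hypothetically, that every iteration exactly reaches its own target so the next iteration starts from that latency; real iterations only approximate this. The function name latency_targets is illustrative.

```python
from typing import List


def latency_targets(baseline_latency: float, target_latency: float,
                    max_relative_latency_step: float = 0.5,
                    max_iterations: int = 15) -> List[float]:
    """Per-iteration targets from max(latency * (1 - step), target_latency)."""
    if not 0 < max_relative_latency_step <= 1:
        raise ValueError("max_relative_latency_step must be in (0, 1]")
    targets = []
    latency = baseline_latency
    for _ in range(max_iterations):
        step_target = max(latency * (1 - max_relative_latency_step), target_latency)
        targets.append(step_target)
        if step_target == target_latency:
            break  # the final target has been reached
        latency = step_target  # assume this iteration hit its target exactly
    return targets


print(latency_targets(100.0, 20.0))                                 # 5x speed-up in steps
print(latency_targets(100.0, 20.0, max_relative_latency_step=1.0))  # a single step
```

With the default step of 0.5, a 5x acceleration is approached in several halving steps rather than in one aggressive jump.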

property max_iterations: int

Maximum number of iterations.

property max_relative_latency_step: float

Maximum relative difference between baseline latency and target latency for one knapsack iteration.

class SCBOLabelSelector(target_latency, latency_calculation_function, n_search_steps, n_starting_points=None, batch_size=4, constraint=None, additional_starting_points_generator=None, device=None, dtype=torch.float64, verbose=True, handle_ctrl_c=False, seed=None, scbo_kwargs=None)

Label selector based on the SCBO (Scalable Constrained Bayesian Optimization) algorithm.

__init__(target_latency, latency_calculation_function, n_search_steps, n_starting_points=None, batch_size=4, constraint=None, additional_starting_points_generator=None, device=None, dtype=torch.float64, verbose=True, handle_ctrl_c=False, seed=None, scbo_kwargs=None)
Parameters:
  • target_latency (float) – Target model latency. This argument should be in the same units as the output of latency_calculation_function.

  • latency_calculation_function (Callable[[torch.nn.Module], float]) – Function that calculates model latency. It takes a model (torch.nn.Module) and returns the “speed” (float) of its execution. This can be the number of FLOPs or MACs, inference time on CPU/GPU, or any other “speed” criterion. If latency cannot be measured for a particular model configuration for technical reasons (e.g., problems with measurements on a server), latency_calculation_function should raise LatencyMeasurementError.

  • n_search_steps (Optional[int]) – Number of pruning configurations (samples) to evaluate when searching for the optimal architecture. If None, the search loop stops when SCBO converges.

  • n_starting_points (Optional[int]) – Number of sampled configurations used for the startup step. If None, n_starting_points is calculated as min(total_number_of_groups_in_model * 2, n_search_steps - 1). None by default.

  • batch_size (int) – Batch size for SCBO algorithm. 4 by default.

  • constraint (Optional[ChannelsSelectorConstraint]) – Optional ChannelsSelectorConstraint instance that calculates the low, high and step values of the constraint on the total number of labels in a group. If None, DefaultChannelsSelectorConstraint is used. None by default.

  • additional_starting_points_generator (Optional[StartingPointsGenerator]) – User defined generator of additional starting points for SCBO algorithm. None by default.

  • device (Optional[Union[str, torch.device]]) – The desired device for SCBO algorithm. None by default (cuda if possible).

  • dtype (torch.dtype) – The desired dtype of SCBO algorithm. torch.double by default.

  • verbose (bool) – Whether to show progress bar or not. True by default.

  • handle_ctrl_c (bool) – Whether to allow early stopping with Ctrl+C. False by default.

  • seed (Optional[int]) – Optional seed for SCBO algorithm. Default value is None.

  • scbo_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments for SCBO (developers only).
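The documented default for n_starting_points can be written out directly. The helper name is hypothetical, and it only covers an integer n_search_steps because the behaviour for n_search_steps=None is not documented here.

```python
def default_n_starting_points(total_number_of_groups_in_model: int,
                              n_search_steps: int) -> int:
    # Documented default: min(total_number_of_groups_in_model * 2, n_search_steps - 1)
    return min(total_number_of_groups_in_model * 2, n_search_steps - 1)


print(default_n_starting_points(10, 100))  # two points per group fit in the budget
print(default_n_starting_points(60, 100))  # capped so at least one search step remains
```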

property batch_size: int

SCBO batch size.

property n_search_steps: int | None

Number of search steps.

property n_starting_points: int | None

Number of starting points.

set_search_step_cb(search_step_cb)

Set the search step callback. This callback is called every time the (score, latency) pair improves during the search procedure.

Parameters:
  • search_step_cb (Callable[[List[Label], int, float, float], bool]) – Callback function. It takes the list of labels selected by SCBO as the first argument, the number of the best step as the second argument, the new best score of the pruned model as the third argument and the new latency of the pruned model as the last argument. If the callback returns True, the search is terminated.

Return type:

None
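A callback matching the documented (labels, step, score, latency) -> bool shape could log every improvement and stop once latency drops below a user-chosen threshold. The factory name and the max_latency threshold are hypothetical; only set_search_step_cb comes from the API above.

```python
from typing import Any, List


def make_search_step_cb(max_latency: float):
    """Build a search step callback that logs improvements and stops early."""
    history = []

    def search_step_cb(labels: List[Any], step: int, score: float, latency: float) -> bool:
        history.append((step, score, latency, len(labels)))
        print(f"best step {step}: score={score:.4f}, latency={latency:.3f}, "
              f"labels={len(labels)}")
        # Returning True terminates the SCBO search.
        return latency <= max_latency

    search_step_cb.history = history  # expose the log for later inspection
    return search_step_cb


# Usage (requires ENOT; selector is an SCBOLabelSelector instance):
# selector.set_search_step_cb(make_search_step_cb(max_latency=10.0))
```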

property verbose: bool

Whether to show progress bar or not.

class LatencyMeasurementError

Must be raised by a latency measurement function if the measurement cannot be performed for a particular model configuration.
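One way to satisfy this contract is to wrap a user-defined measurement in retries and raise LatencyMeasurementError only when every attempt fails. This is a sketch: the wrapper name with_retries and the choice to retry on OSError (which covers ConnectionError and TimeoutError) are assumptions, and the fallback class only exists so the snippet runs without ENOT installed.

```python
import time
from typing import Any, Callable

try:
    from enot.pruning_v2 import LatencyMeasurementError
except ImportError:
    # Stand-in so this sketch runs without ENOT (assumed to be an Exception subclass).
    class LatencyMeasurementError(Exception):
        pass


def with_retries(measure: Callable[[Any], float], attempts: int = 3,
                 delay_s: float = 0.0) -> Callable[[Any], float]:
    """Wrap a measurement so persistent failures raise LatencyMeasurementError."""
    if attempts < 1:
        raise ValueError("attempts must be at least 1")

    def latency_calculation_function(model: Any) -> float:
        last_error = None
        for _ in range(attempts):
            try:
                return measure(model)
            except OSError as error:  # covers ConnectionError and TimeoutError
                last_error = error
                time.sleep(delay_s)
        raise LatencyMeasurementError(
            f"latency measurement failed after {attempts} attempts"
        ) from last_error

    return latency_calculation_function
```

Transient failures are absorbed by the retries, while a configuration that cannot be measured at all is reported to the selector through LatencyMeasurementError instead of an arbitrary exception.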