def default_channel_search_step(current_layer_channels: int) -> int:
    """
    Choose the search step for a channel group based on the group's size.

    Groups of at most 64 channels are searched one channel at a time. Larger groups are
    searched in steps of 4 channels.

    Parameters
    ----------
    current_layer_channels : int
        Number of channels in a layer.

    Returns
    -------
    int
        Search step for :ref:`pruning config` generation: ``1`` when
        ``current_layer_channels <= 64``, ``4`` otherwise.
    """
    fine_step, coarse_step = 1, 4
    if current_layer_channels > 64:
        return coarse_step
    return fine_step
def default_min_channels_to_leave(
    current_layer_channels: int,
    absolute_min_channels: int = 16,
    relative_min_channels: float = 0.25,
) -> int:
    """
    Compute how many channels of a layer must be kept after pruning.

    Two lower bounds are combined:

    * an absolute bound, ``absolute_min_channels``;
    * a relative bound, ``int(relative_min_channels * current_layer_channels)``
      (the fractional part is truncated).

    The larger of the two bounds is taken. That value is then capped at
    ``current_layer_channels``, because a layer cannot keep more channels than it has.

    Parameters
    ----------
    current_layer_channels : int
        Number of channels in a layer.
    absolute_min_channels : int, optional
        Absolute lower bound. If it exceeds ``current_layer_channels``, then
        ``current_layer_channels`` is returned. Default value is 16.
    relative_min_channels : float, optional
        Relative lower bound, as a fraction of ``current_layer_channels``. Default value is 0.25.

    Returns
    -------
    int
        The minimal number of channels to keep among ``current_layer_channels``.
    """
    relative_bound = int(relative_min_channels * current_layer_channels)
    lower_bound = max(absolute_min_channels, relative_bound)
    # Never ask to keep more channels than the layer actually has.
    return min(current_layer_channels, lower_bound)
class PruningLabelSelector(ABC):
    """
    Abstract base class for pruning label selectors.

    Subclasses implement :meth:`PruningLabelSelector.select`, which decides which channel labels
    should be pruned. Labels stored by ``select`` in ``self._labels`` are exposed through
    :attr:`PruningLabelSelector.labels_for_pruning`.
    """

    def __init__(self):
        # ``None`` is the "not selected yet" marker checked by ``labels_for_pruning``.
        self._labels: Optional[List[Label]] = None

    @abstractmethod
    def select(self, pruning_info: ModelPruningInfo) -> List[Label]:
        """
        Choose the labels to prune according to this selector's policy and store them on the
        selector instance.

        .. warning::
            Depending on label selector implementation, this function may have significant execution time.

        Parameters
        ----------
        pruning_info : ModelPruningInfo
            Pruning-related information collected during calibration stage
            (see more in :func:`.calibrate_model_for_pruning`).

        Returns
        -------
        list of Label
            Channel labels which should be pruned.
        """

    @property
    def labels_for_pruning(self) -> List[Label]:
        """
        Labels chosen for pruning by the last :meth:`PruningLabelSelector.select` call.

        Returns
        -------
        list of Label
            All channel labels to prune.

        Raises
        ------
        RuntimeError
            If labels have not been calculated yet (``self._labels`` is still ``None``).
        """
        stored = self._labels
        if stored is None:
            raise RuntimeError('Labels were not calculated')
        return stored
class TopKPruningLabelSelector(PruningLabelSelector):
    """
    Base class for label selectors driven by a pruning config.

    For description of pruning configs see :ref:`here <pruning config>`.

    :meth:`PruningLabelSelector.select` is implemented in two steps:

    1. :meth:`TopKPruningLabelSelector.get_config_for_pruning` (abstract) produces a pruning config;
    2. :meth:`TopKPruningLabelSelector.get_labels_by_config` turns that config into labels by taking
       the top-k least important channels of each group, with k taken from the config.
    """

    def __init__(self):
        # Stays ``None`` until ``select`` stores a generated config.
        self._pruning_cfg: Optional[PruningConfig] = None
        super().__init__()

    @abstractmethod
    def get_config_for_pruning(self, pruning_info: ModelPruningInfo) -> PruningConfig:
        """
        Produce a :ref:`pruning config` according to the subclass's selection strategy.

        Parameters
        ----------
        pruning_info : ModelPruningInfo
            Pruning-related information collected during calibration stage
            (see more in :func:`.calibrate_model_for_pruning`).

        Returns
        -------
        :ref:`PruningConfig <pruning config>`
            Config for pruning.
        """

    def select(self, pruning_info: ModelPruningInfo) -> List[Label]:
        """
        Generate a pruning config, convert it to labels, and store both on the instance.

        The config is stored before labels are computed, so it remains available through
        :attr:`pruning_cfg` even if label extraction fails.

        Parameters
        ----------
        pruning_info : ModelPruningInfo
            Pruning-related information collected during calibration stage
            (see more in :func:`.calibrate_model_for_pruning`).

        Returns
        -------
        list of Label
            Channel labels which should be pruned.
        """
        config = self.get_config_for_pruning(pruning_info)
        self._pruning_cfg = config
        chosen = self.get_labels_by_config(pruning_info, config)
        self._labels = chosen
        return chosen

    @staticmethod
    def get_labels_by_config(pruning_info: ModelPruningInfo, pruning_config: PruningConfig) -> List[Label]:
        """
        Convert a :ref:`pruning config` into the channel labels to prune.

        Parameters
        ----------
        pruning_info : ModelPruningInfo
            Pruning-related information collected during calibration stage
            (see more in :func:`.calibrate_model_for_pruning`).
        pruning_config : :ref:`PruningConfig <pruning config>`
            Config for pruning.

        Returns
        -------
        list of Label
            Channel labels to prune (remove from the model).
        """
        return get_least_important_labels_by_config(pruning_info, pruning_config)

    @property
    def pruning_cfg(self) -> Optional[PruningConfig]:
        """
        Pruning config stored by the most recent :meth:`select` call.

        Returns
        -------
        pruning_config : :ref:`PruningConfig <pruning config>` or None
            Config for pruning, or ``None`` if :meth:`select` has not been called yet.
        """
        return self._pruning_cfg