Source code for enot.pruning.label_selector

def default_channel_search_step(current_layer_channels: int) -> int:
    """
    Choose the default search step for a channel group from its size.

    Groups with more than 64 channels get a step of 4; every other group
    (64 channels or fewer) gets a step of 1.

    Parameters
    ----------
    current_layer_channels : int
        Number of channels in a layer.

    Returns
    -------
    int
        Search step for :ref:`pruning config` generation.

    """
    # Large groups are searched on a coarser grid.
    if current_layer_channels > 64:
        return 4
    return 1
def default_min_channels_to_leave(
    current_layer_channels: int,
    absolute_min_channels: int = 16,
    relative_min_channels: float = 0.25,
) -> int:
    """
    Compute the default number of channels that must be kept in a layer.

    Two lower bounds are considered: an absolute one (a fixed channel count)
    and a relative one (a fraction of the layer's channels, truncated to an
    integer). The larger of the two is used, but the result never exceeds the
    number of channels the layer actually has.

    Parameters
    ----------
    current_layer_channels : int
        Number of channels in a layer.
    absolute_min_channels : int, optional
        Absolute lower bound. If it is larger than ``current_layer_channels``,
        then ``current_layer_channels`` is returned. Default value is 16.
    relative_min_channels : float, optional
        Relative lower bound, as a fraction of ``current_layer_channels``.
        Default value is 0.25.

    Returns
    -------
    int
        The minimal number of channels to keep among ``current_layer_channels``.

    """
    # int() truncates the fractional part of the relative bound.
    relative_bound = int(relative_min_channels * current_layer_channels)
    lower_bound = max(absolute_min_channels, relative_bound)
    # A layer cannot keep more channels than it has.
    return min(current_layer_channels, lower_bound)
class PruningLabelSelector(ABC):
    """
    Base class for all label selectors for pruning.

    Subclasses implement the abstract method :meth:`PruningLabelSelector.select`,
    which should return the labels to prune. The most recently stored labels
    are exposed through :attr:`PruningLabelSelector.labels_for_pruning`.
    """

    def __init__(self):
        # None means that no labels have been stored yet.
        self._labels: Optional[List[int]] = None

    @abstractmethod
    def select(self, pruning_info: ModelPruningInfo) -> List[int]:
        """
        Choose the labels to prune according to the selector policy and save them in the selector instance.

        .. warning::
            Depending on label selector implementation, this function may have significant execution time.

        Parameters
        ----------
        pruning_info : ModelPruningInfo
            Pruning-related information collected during calibration stage
            (see more in :func:`.calibrate_model_for_pruning`).

        Returns
        -------
        list of int
            List of channel labels which should be pruned.

        """

    @property
    def labels_for_pruning(self) -> List[int]:
        """
        Labels to prune.

        Returns
        -------
        list of int
            List of all channel labels to prune.

        Raises
        ------
        RuntimeError
            If labels were not calculated (:meth:`PruningLabelSelector.select` method was not called).

        """
        if self._labels is None:
            raise RuntimeError('Labels were not calculated')
        return self._labels
class TopKPruningLabelSelector(PruningLabelSelector):
    """
    Base class for label selectors driven by a pruning config.

    For description of pruning configs see :ref:`here <pruning config>`.

    :meth:`PruningLabelSelector.select` is implemented here in two steps: the abstract method
    :meth:`TopKPruningLabelSelector.get_config_for_pruning` produces a pruning config, and the top-k least
    important channels of each group are then selected, with the k values taken from that config.
    """

    def __init__(self):
        # No config is stored until select() stores one.
        self._pruning_cfg: Optional[PruningConfig] = None
        super().__init__()

    @abstractmethod
    def get_config_for_pruning(self, pruning_info: ModelPruningInfo) -> PruningConfig:
        """
        Abstract method which should implement the :ref:`pruning config` selection strategy.

        Parameters
        ----------
        pruning_info : ModelPruningInfo
            Pruning-related information collected during calibration stage
            (see more in :func:`.calibrate_model_for_pruning`).

        Returns
        -------
        :ref:`PruningConfig <pruning config>`
            Config for pruning.

        """

    def select(self, pruning_info: ModelPruningInfo) -> List[int]:
        # The config is stored before it is converted to labels, so it stays
        # available through ``pruning_cfg`` even if the conversion fails.
        config = self.get_config_for_pruning(pruning_info)
        self._pruning_cfg = config

        labels = self.get_labels_by_config(pruning_info, config)
        self._labels = labels
        return labels

    @staticmethod
    def get_labels_by_config(pruning_info: ModelPruningInfo, pruning_config: PruningConfig) -> List[int]:
        """
        Convert a :ref:`pruning config` into the channel labels to prune.

        Parameters
        ----------
        pruning_info : ModelPruningInfo
            Pruning-related information collected during calibration stage
            (see more in :func:`.calibrate_model_for_pruning`).
        pruning_config : :ref:`PruningConfig <pruning config>`
            Config for pruning.

        Returns
        -------
        list of int
            Channel indices to prune (remove from the model).

        """
        return get_least_important_labels_by_config(pruning_info, pruning_config)

    @property
    def pruning_cfg(self) -> Optional[PruningConfig]:
        """
        Current pruning config.

        Returns
        -------
        pruning_config : :ref:`PruningConfig <pruning config>`
            Config for pruning most recently stored by :meth:`TopKPruningLabelSelector.select`;
            ``None`` if no config has been stored yet.

        """
        return self._pruning_cfg
class UniformPruningLabelSelector(TopKPruningLabelSelector):
    """
    Label selector for uniform (equal) pruning.

    Removes an equal percentage of channels in each prunable layer.
    """

    def __init__(self, pruning_ratio: float):
        """
        Parameters
        ----------
        pruning_ratio : float
            Percentage of prunable channels to remove in each group. The value is stored
            unchanged and used as every entry of the generated pruning config.

        """
        self.pruning_ratio: float = pruning_ratio
        super().__init__()

    def get_config_for_pruning(self, pruning_info: ModelPruningInfo) -> PruningConfig:
        # Uniform policy: the same ratio is repeated once per prunable group.
        ratio = self.pruning_ratio
        group_count = pruning_info.n_prunable_groups
        return [ratio] * group_count