# Source code for enot.pruning.prune_strategy

def check_pruning_config(pruning_config: PruningConfig) -> None:
    """
    Validate the structure of a pruning config.

    A valid config is a sequence whose items are either all integers or all floats.
    An empty sequence is accepted.

    Parameters
    ----------
    pruning_config : :ref:`PruningConfig <pruning config>`
        Config for pruning.

    Raises
    ------
    ValueError
        If ``pruning_config`` is not a sequence, or if its items are not all ints
        or all floats.
    """
    if not isinstance(pruning_config, Sequence):
        raise ValueError('Pruning config should be a sequence')

    def _every_item_is(expected_type: type) -> bool:
        return all(isinstance(item, expected_type) for item in pruning_config)

    # ``or`` short-circuits, so the float scan only runs when the int scan fails.
    is_homogeneous = _every_item_is(int) or _every_item_is(float)
    if not is_homogeneous:
        raise ValueError(
            'Pruning config should either be a sequence with integer numbers or a sequence with float numbers'
        )
def get_least_important_labels_by_config(
    pruning_info: ModelPruningInfo,
    pruning_config: PruningConfig,
) -> List[int]:
    """
    Extracts least important channel labels based on pruning config.

    Parameters
    ----------
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage (see more in
        :func:`.calibrate_model_for_pruning`).
    pruning_config : :ref:`PruningConfig <pruning config>`
        Config for pruning. Integer entries are the number of channels to prune
        and are looked up by the gate id yielded by ``iterate_over_gate_criteria``.
        Float entries are the fraction of channels to prune. They are matched to
        channel groups in iteration order and converted to channel counts first.

    Returns
    -------
    list of int
        Least important labels from all layers, i.e. for each group, the labels
        with the smallest criteria values.

    Raises
    ------
    ValueError
        If pruning config is not valid (see :func:`check_pruning_config`).
    """
    check_pruning_config(pruning_config)

    # Materialize once: the per-group criteria are needed both for the ratio
    # conversion and for the selection loop below.
    gate_criteria = list(iterate_over_gate_criteria(pruning_info))

    # An empty config passes validation, so guard before peeking at its first
    # element (``next(iter(...))`` used to raise a bare StopIteration here).
    if pruning_config and isinstance(pruning_config[0], float):
        pruning_config = [
            # Round away float noise before truncating, otherwise e.g.
            # 100 * 0.29 == 28.999999999999996 would give 28 instead of 29.
            int(round(len(labels) * pruning_ratio, 9))
            for pruning_ratio, (_, labels, _) in zip(pruning_config, gate_criteria)
        ]

    total_labels_for_prune: List[int] = []
    for i, labels, criteria in gate_criteria:
        # NOTE(review): the config is indexed by gate id; this assumes gate ids are
        # 0..n-1 in iteration order (required for float configs) — confirm.
        prune_channels_num = pruning_config[i]
        # Positions of the ``prune_channels_num`` smallest criteria values.
        index_for_pruning = np.argsort(np.asarray(criteria))[:prune_channels_num]
        total_labels_for_prune.extend(np.asarray(labels)[index_for_pruning].tolist())

    return total_labels_for_prune
def get_labels_for_uniform_pruning(
    pruning_info: ModelPruningInfo,
    pruning_ratio: float = 0.5,
) -> List[int]:
    """
    Returns labels of least important channels of every prunable layer.

    Parameters
    ----------
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage (see more in
        :func:`.calibrate_model_for_pruning`).
    pruning_ratio : float
        Fraction (in ``[0, 1]``) of the least important channels of every prunable
        group to select for pruning. Default is 0.5.

    Returns
    -------
    list of int
        List of least important channel labels.

    Raises
    ------
    ValueError
        If ``pruning_ratio`` is not a number in ``[0, 1]``.
    """
    # Coerce to float: an int ratio such as ``1`` would otherwise produce an
    # *integer* config, which is interpreted as channel counts, not fractions.
    try:
        ratio = float(pruning_ratio)
    except (TypeError, ValueError):
        raise ValueError(f'pruning_ratio should be a number in [0, 1], got {pruning_ratio!r}') from None

    # A negative ratio would turn into ``argsort(...)[:-k]`` downstream and select
    # almost every channel; NaN also fails this check.
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f'pruning_ratio should be a number in [0, 1], got {pruning_ratio!r}')

    pruning_cfg = [ratio] * pruning_info.n_prunable_groups
    return get_least_important_labels_by_config(
        pruning_info=pruning_info,
        pruning_config=pruning_cfg,
    )
def get_criteria_label_dict(
    pruning_info: ModelPruningInfo,
) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Gather channel importance values per channel group in a human-readable form.

    Parameters
    ----------
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage (see more in
        :func:`.calibrate_model_for_pruning`).

    Returns
    -------
    Dict[str, Dict[str, numpy.ndarray]]
        Mapping from a group name of the form ``gate_<gate id>_channels_<n labels>``
        to a dict with two arrays. ``'criteria'`` holds the channel importance
        values. ``'label'`` holds the global channel indices (labels). Its entries
        follow the order of ``iterate_over_gate_criteria``.
    """
    return {
        f'gate_{gate_id}_channels_{len(group_labels)}': {
            'criteria': np.array(group_criteria),
            'label': np.array(group_labels),
        }
        for gate_id, group_labels, group_criteria in iterate_over_gate_criteria(pruning_info)
    }