Source code for enot.pruning.prune_strategy

def get_labels_for_equal_pruning(
    pruning_info: ModelPruningInfo,
    pruning_ratio: float = 0.5,
) -> List[int]:
    """
    Returns labels of least important channels of every prunable layer.

    The same fraction of channels is selected independently in every channel
    group yielded by ``iterate_over_gate_criteria``: within a group, the
    channels with the smallest criteria values are picked.

    Parameters
    ----------
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage
        (see more in :py:meth:`calibrate_model_for_pruning`).
    pruning_ratio : float
        Fraction (from 0 to 1 inclusive) of least important channels to select
        for pruning in every group. The number of selected channels in a group
        is ``int(group_size * pruning_ratio)``, i.e. it is rounded down.

    Returns
    -------
    list of int
        List of least important channel labels.

    Raises
    ------
    ValueError
        If ``pruning_ratio`` is not within the [0, 1] range.

    """
    # A negative ratio would produce a negative slice bound below, which selects
    # "all but the last k" channels instead of none; reject such values early.
    if not 0.0 <= pruning_ratio <= 1.0:
        raise ValueError(f'pruning_ratio must be in range [0, 1], got {pruning_ratio}')

    total_labels_for_prune: List[int] = []
    for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
        criteria_array = np.array(criteria)
        prune_channels_num = int(len(criteria_array) * pruning_ratio)
        # argsort is ascending, so the first entries index the least important channels.
        index_for_pruning = np.argsort(criteria_array)[:prune_channels_num]
        total_labels_for_prune.extend(np.array(labels)[index_for_pruning].tolist())

    return total_labels_for_prune
def get_criteria_label_dict(pruning_info: ModelPruningInfo) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Collects information about channel importances in human-readable format.

    Parameters
    ----------
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage
        (see more in :py:meth:`calibrate_model_for_pruning`).

    Returns
    -------
    Dict[str, Dict[str, numpy.ndarray]]
        Dictionary with channel labels and importances. Its keys are channel
        group names of the form ``gate_<gate_id>_channels_<number of labels>``.
        Every value is a dictionary with two arrays: ``'criteria'`` (channel
        importances) and ``'label'`` (global channel indices).

    """
    # One entry per item yielded by iterate_over_gate_criteria; the key is
    # evaluated before the value, matching the statement order of a plain loop.
    return {
        f'gate_{gate_id}_channels_{len(labels)}': {
            'criteria': np.array(criteria),
            'label': np.array(labels),
        }
        for gate_id, labels, criteria in iterate_over_gate_criteria(pruning_info)
    }