def get_labels_for_equal_pruning(
    pruning_info: ModelPruningInfo,
    pruning_ratio: float = 0.5,
) -> List[int]:
    """
    Returns labels of least important channels of every prunable layer.

    Every channel group yielded by ``iterate_over_gate_criteria`` is pruned by the
    same fraction: inside each group, the ``floor(len(group) * pruning_ratio)``
    channels with the lowest criterion values are selected.

    Parameters
    ----------
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage
        (see more in :py:meth:`calibrate_model_for_pruning`).
    pruning_ratio : float
        Fraction (in ``[0, 1]``) of channels to select for pruning in each group.

    Returns
    -------
    list of int
        List of least important channel labels, concatenated group by group.

    Raises
    ------
    ValueError
        If ``pruning_ratio`` is outside ``[0, 1]`` (or is NaN).
    """
    if not 0.0 <= pruning_ratio <= 1.0:
        # A negative ratio would turn the slice below into ``[:-k]`` and silently select
        # almost every channel; a ratio above 1 would select all of them.
        raise ValueError(f'pruning_ratio must be in [0, 1], got {pruning_ratio}')

    total_labels_for_prune: List[int] = []
    for _, labels, criteria in iterate_over_gate_criteria(pruning_info):
        criteria_array = np.asarray(criteria)
        # Round away floating-point noise before truncating: e.g. 100 * 0.29 evaluates to
        # 28.999999999999996, which plain int() would turn into 28 instead of 29.
        prune_channels_num = int(round(len(criteria_array) * pruning_ratio, 9))
        # Stable sort so that channels with equal importance are picked deterministically
        # (earliest position first) instead of depending on the default sort algorithm.
        index_for_pruning = np.argsort(criteria_array, kind='stable')[:prune_channels_num]
        total_labels_for_prune.extend(np.asarray(labels)[index_for_pruning].tolist())
    return total_labels_for_prune
def get_criteria_label_dict(pruning_info: ModelPruningInfo) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Collects information about channel importances in human-readable format.

    Parameters
    ----------
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage
        (see more in :py:meth:`calibrate_model_for_pruning`).

    Returns
    -------
    Dict[str, Dict[str, numpy.ndarray]]
        One entry per channel group yielded by ``iterate_over_gate_criteria``. Its key
        has the form ``gate_<gate_id>_channels_<number of labels>``; its value holds the
        global channel indices under ``'label'`` and the channel importances under
        ``'criteria'``, both converted to NumPy arrays.
    """
    # Key is evaluated before value in a dict comprehension, matching the original
    # statement order; a repeated key keeps its first position but the latest value.
    return {
        f'gate_{gate_id}_channels_{len(group_labels)}': {
            'criteria': np.array(group_criteria),
            'label': np.array(group_labels),
        }
        for gate_id, group_labels, group_criteria in iterate_over_gate_criteria(pruning_info)
    }