def check_pruning_config(pruning_config: PruningConfig) -> None:
    """
    Validates pruning config.

    A valid config is a (non-string) sequence whose items are either all integers
    (numbers of channels to prune) or all floats (fractions of channels to prune),
    and whose items are all non-negative. An empty sequence is accepted.

    Parameters
    ----------
    pruning_config : :ref:`PruningConfig <pruning config>`
        Config for pruning.

    Raises
    ------
    ValueError
        If pruning config is not valid.
    """
    # str/bytes are Sequences too; iterating bytes even yields ints, so an
    # accidental bytes object would otherwise pass as an "integer" config.
    if isinstance(pruning_config, (str, bytes, bytearray)) or not isinstance(pruning_config, Sequence):
        raise ValueError('Pruning config should be a sequence')

    all_ints = all(isinstance(p, int) for p in pruning_config)
    all_floats = all(isinstance(p, float) for p in pruning_config)
    if not (all_ints or all_floats):
        raise ValueError(
            'Pruning config should either be a sequence with integer numbers or a sequence with float numbers'
        )

    # A negative count k turns `np.argsort(criteria)[:k]` into "all but |k|
    # channels", and a negative ratio truncates to such a count; NaN ratios
    # fail later with an unrelated error. `not p >= 0` rejects both cases.
    if any(not p >= 0 for p in pruning_config):
        raise ValueError('Pruning config values should be non-negative')
def get_least_important_labels_by_config(
    pruning_info: ModelPruningInfo,
    pruning_config: PruningConfig,
) -> List[int]:
    """
    Extracts least important channel labels based on pruning config.

    Parameters
    ----------
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage
        (see more in :func:`.calibrate_model_for_pruning`).
    pruning_config : :ref:`PruningConfig <pruning config>`
        Config for pruning. ``pruning_config[gate_id]`` applies to the channel
        group with that gate id: an integer is the number of channels to prune,
        a float is the fraction of the group's channels to prune (converted to a
        count with ``int()``).

    Returns
    -------
    list of int
        Least important labels from all layers: for each group, the labels of
        the channels with the smallest criteria values, groups concatenated in
        iteration order.

    Raises
    ------
    ValueError
        If pruning config is not valid (see :func:`check_pruning_config`).
    """
    check_pruning_config(pruning_config)
    # All entries share one type after validation, so the first one decides.
    # Guarding on length avoids a bare StopIteration for an empty config.
    is_ratio = len(pruning_config) > 0 and isinstance(pruning_config[0], float)

    total_labels_for_prune: List[int] = []
    # Single pass over the gates: ratios are converted to counts per gate here
    # instead of in a separate traversal of pruning_info.
    for gate_id, labels, criteria in iterate_over_gate_criteria(pruning_info):
        value = pruning_config[gate_id]
        prune_channels_num = int(len(labels) * value) if is_ratio else value
        # Ascending order: the smallest criteria values are the least important.
        order = np.argsort(np.asarray(criteria))
        selected = np.asarray(labels)[order[:prune_channels_num]]
        total_labels_for_prune.extend(selected.tolist())
    return total_labels_for_prune
def get_criteria_label_dict(
    pruning_info: ModelPruningInfo,
) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Collects information about channel importance values in human-readable format.

    Parameters
    ----------
    pruning_info : ModelPruningInfo
        Pruning-related information collected during calibration stage
        (see more in :func:`.calibrate_model_for_pruning`).

    Returns
    -------
    Dict[str, Dict[str, numpy.ndarray]]
        Mapping from a group name of the form ``gate_<gate_id>_channels_<n_labels>``
        to a dict with two arrays: ``'criteria'`` (channel importance values) and
        ``'label'`` (global channel indices, i.e. labels).
    """
    return {
        f'gate_{gate_id}_channels_{len(labels)}': {
            'criteria': np.array(criteria),
            'label': np.array(labels),
        }
        for gate_id, labels, criteria in iterate_over_gate_criteria(pruning_info)
    }