Rescaling
The enot.rescaling module contains a tool to improve accuracy when using quantization.
Rescaling consists of two steps:
- Collecting activation statistics (calibration).
- Rescaling activations and weights.
Both steps are performed with the RescalingCalibrator class.
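The two steps can be illustrated on a single linear layer with a minimal NumPy sketch. It assumes a SmoothQuant-style formula, s_j = max|X_j|^alpha / max|W_j|^(1 - alpha) per input channel j. This page does not specify the exact formula RescalingCalibrator uses, and the helpers `collect_act_absmax` and `smoothing_scales` are hypothetical. The sketch applies the division by s explicitly; it does not show how RescalingCalibrator realizes it inside the model.

```python
import numpy as np

# Assumption: SmoothQuant-style per-input-channel smoothing. The exact
# formula used by RescalingCalibrator is not specified on this page.


def collect_act_absmax(batches):
    """Step 1 (calibration): running per-channel max |x| over calibration batches."""
    absmax = None
    for x in batches:  # each x has shape (batch, in_features)
        batch_max = np.abs(x).max(axis=0)
        absmax = batch_max if absmax is None else np.maximum(absmax, batch_max)
    return absmax


def smoothing_scales(act_absmax, weight, alpha=0.5, eps=1e-8):
    """Hypothetical helper: s_j = max|X_j|**alpha / max|W_j|**(1 - alpha)."""
    w_absmax = np.abs(weight).max(axis=0)  # weight has shape (out_features, in_features)
    return np.maximum(act_absmax, eps) ** alpha / np.maximum(w_absmax, eps) ** (1.0 - alpha)


rng = np.random.default_rng(0)
in_features, out_features = 4, 3
weight = rng.normal(size=(out_features, in_features))
# Calibration data where input channel 2 carries large outliers.
channel_gain = np.array([1.0, 1.0, 50.0, 1.0])
batches = [rng.normal(size=(8, in_features)) * channel_gain for _ in range(4)]

act_absmax = collect_act_absmax(batches)
s = smoothing_scales(act_absmax, weight, alpha=0.5)

# Step 2 (rescaling): divide activations by s, multiply weight columns by s.
new_weight = weight * s
x = batches[0]

# The layer output is mathematically unchanged: (x / s) @ (W * s).T == x @ W.T
assert np.allclose((x / s) @ new_weight.T, x @ weight.T)
# With alpha=0.5, each channel's activation range and weight range become equal.
assert np.allclose(act_absmax / s, np.abs(new_weight).max(axis=0))
print("outlier channel range before:", act_absmax[2], "after:", (act_absmax / s)[2])
```

Because the rescaling is an exact algebraic identity, the float model's output is preserved. Only the value ranges that the quantizer sees change.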
- class RescalingCalibrator(model, *, args=(), kwargs=None, excluded_modules=None, inplace=False)
A simple tool to improve accuracy when using quantization.
Examples
>>> from enot.rescaling import RescalingCalibrator
>>> rescaling_calibrator = RescalingCalibrator(model)
>>> # calibration (collecting activation statistics):
>>> for sample, _ in dataloader:
...     rescaling_calibrator(sample)  # pass model args/kwargs as usual
>>> # rescaling (activations and weights):
>>> model = rescaling_calibrator.rescale(alpha=0.5)
- __init__(model, *, args=(), kwargs=None, excluded_modules=None, inplace=False)
- Parameters:
model (torch.nn.Module) – Model for rescaling.
args (Tuple) – Model positional arguments. Optional.
kwargs (Optional[Dict[str, Any]]) – Model keyword arguments. Optional.
excluded_modules (Optional[List[Union[Type[torch.nn.Module], torch.nn.Module]]]) – Modules or types of modules that should not be rescaled.
inplace (bool) – Rescaling not only inserts new submodules into the model but also changes the weights of some existing submodules. If inplace=False, the calibrator copies the model and the original model is left unchanged. If inplace=True, the weights of the original model are changed. Default value is False.
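The copy-versus-modify contract of inplace can be sketched in plain Python. This is only an illustration of the semantics: `ToyModel` and `apply_rescaling` are hypothetical stand-ins, not part of enot. With a real torch.nn.Module, copy.deepcopy produces an independent copy in the same way.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for a torch.nn.Module holding one weight vector."""

    def __init__(self, weight):
        self.weight = list(weight)


def apply_rescaling(model, scales, inplace=False):
    """Hypothetical helper: scale weights on a deep copy, or on the original if inplace."""
    target = model if inplace else copy.deepcopy(model)
    for i, s in enumerate(scales):
        target.weight[i] *= s  # mutates target's weight list element by element
    return target


original = ToyModel([1.0, 2.0, 3.0])

# inplace=False: work on a copy, the original keeps its weights.
rescaled = apply_rescaling(original, [2.0, 0.5, 1.0], inplace=False)
print(original.weight)  # [1.0, 2.0, 3.0]
print(rescaled.weight)  # [2.0, 1.0, 3.0]

# inplace=True: the original object's weights are overwritten.
rescaled_inplace = apply_rescaling(original, [2.0, 0.5, 1.0], inplace=True)
print(original.weight)  # [2.0, 1.0, 3.0]
```

Copying costs memory for large models, which is a plausible reason to offer inplace=True as an option.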
- rescale(*, alpha=0.5)
Rescale the model using the statistics collected during the calibration process.
- Parameters:
alpha (float) – Migration strength coefficient. Controls how much quantization difficulty migrates from activations to weights. 0.5 is a well-balanced point that splits the quantization difficulty evenly. Choose a larger alpha (e.g. 0.75) to migrate more quantization difficulty to the weights. Should be in the range [0, 1]. Default value is 0.5.
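How alpha shifts the range between activations and weights can be seen on a single channel. This sketch assumes the same SmoothQuant-style scale, s = max|X|^alpha / max|W|^(1 - alpha). The helper `migrated_ranges` is hypothetical, and the exact formula inside RescalingCalibrator is not confirmed here.

```python
import numpy as np


def migrated_ranges(act_absmax, w_absmax, alpha):
    """Per-channel max |X / s| and max |W * s| under the assumed smoothing formula."""
    s = act_absmax ** alpha / w_absmax ** (1.0 - alpha)
    return act_absmax / s, w_absmax * s


act_absmax = np.array([100.0])  # an outlier activation channel
w_absmax = np.array([1.0])      # a well-behaved weight column

for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):
    act_range, w_range = migrated_ranges(act_absmax, w_absmax, alpha)
    print(f"alpha={alpha:.2f}  activation max={act_range[0]:8.3f}  weight max={w_range[0]:8.3f}")
```

With these numbers, alpha=0 leaves the activation range at 100, alpha=0.5 gives 10 for both, and alpha=1 moves the full range of 100 onto the weights.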
- Returns:
Rescaled model. Note: with inplace=True, the model passed to RescalingCalibrator will be completely broken; use only the model returned by this method.
- Return type: