######################
Automatic Quantization
######################

The ``enot.quantization`` module contains functions for automatic quantization
of user models. It is best suited for preparing models for `ENOT Lite`_ INT8
engines.

With the ``enot.quantization`` package, you can automatically convert your
PyTorch model to our intermediate representation, which allows you to
perform multiple kinds of quantization, including vector quantization for
`TensorRT`_, `OpenVINO`_ and `STM`_ devices.

This package features automatic distillation for weight fine-tuning, automatic
quantization threshold search as described in the `Fast Adjustable Threshold`_ paper,
several strategies for selecting the layers used in distillation, and a number of
fake-quantization algorithms.
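
The sketch below is a rough, plain-Python illustration of symmetric INT8 fake
quantization: values are divided by a scale derived from a clipping threshold,
rounded, clipped to the integer range and mapped back to floats, so computation
stays in floating point while the quantization error becomes visible. Following
the idea of the Fast Adjustable Threshold paper, the threshold is written as a
factor ``alpha`` times the maximum absolute value. Here ``alpha`` is chosen by a
simple grid search over mean squared error, whereas the paper learns it by
gradient descent. The helper names are hypothetical and do not describe ENOT's
internal implementation.

.. code-block:: python

    def fake_quantize(values, threshold, n_bits=8):
        """Quantize-dequantize a list of floats with a symmetric signed scheme."""
        q_max = 2 ** (n_bits - 1) - 1  # 127 for INT8
        scale = threshold / q_max
        result = []
        for v in values:
            q = round(v / scale)  # Python rounds halves to even
            q = max(-q_max, min(q_max, q))  # clip to the representable range
            result.append(q * scale)  # map back to float
        return result


    def mse(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


    def search_threshold_scale(values, alphas):
        """Hypothetical helper: pick alpha so that alpha * max|x| gives the lowest MSE."""
        max_abs = max(abs(v) for v in values)
        best_alpha = min(alphas, key=lambda a: mse(values, fake_quantize(values, a * max_abs)))
        return best_alpha, best_alpha * max_abs


    values = [0.01 * i for i in range(-50, 51)]  # evenly spaced in [-0.5, 0.5]
    alpha, threshold = search_threshold_scale(values, [0.5, 0.75, 1.0])
    quantized = fake_quantize(values, threshold)

For evenly spaced values like these, clipping only adds error, so the search
settles on ``alpha = 1.0``; thresholds below the maximum typically pay off only
for heavy-tailed distributions where most values are small.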

.. _ENOT Lite: https://enot-lite.rtd.enot.ai/en/stable/
.. _OpenVINO: https://docs.openvino.ai/latest/index.html
.. _TensorRT: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html
.. _STM: https://stm32ai.st.com
.. _Fast Adjustable Threshold: https://arxiv.org/abs/1812.07872

The module provides **fake quantized model classes** (special model wrappers that implement quantization schemes)
and **two context managers**:

- :class:`~enot.quantization.TensorRTFakeQuantizedModel` /
  :class:`~enot.quantization.OpenVINOFakeQuantizedModel` /
  :class:`~enot.quantization.STMFakeQuantizedModel`
- :class:`~enot.quantization.calibrate`, :class:`~enot.quantization.distill`

The quantization procedure is as follows:

- wrap the float model in one of the fake quantized models listed above
- write a calibration loop using the :class:`~enot.quantization.calibrate` context manager
- write a distillation loop using the :class:`~enot.quantization.distill` context manager

Example:

.. code-block:: python

    import itertools

    import torch
    from torch.optim import RAdam
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from tqdm import tqdm

    from enot.quantization import TensorRTFakeQuantizedModel
    from enot.quantization import calibrate
    from enot.quantization import distill
    from enot.quantization import RMSELoss

    # `model` (a float torch.nn.Module) and `dataloader` (yielding tuples whose
    # first element is the input tensor) are defined by the user

    # wrap float model to fake quantized model
    fq_model = TensorRTFakeQuantizedModel(model).cuda()

    # calibration: forward passes only, no gradients needed
    with torch.no_grad(), calibrate(fq_model):
        for batch in itertools.islice(dataloader, 10):  # 10 batches for calibration
            batch = batch[0].cuda()
            fq_model(batch)

    # distillation
    n_epochs = 5
    with distill(fq_model=fq_model, tune_weight_scale_factors=True) as (qdistill_model, params):
        optimizer = RAdam(params=params, lr=0.005, betas=(0.9, 0.95))
        scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=len(dataloader) * n_epochs)
        criterion = RMSELoss()

        for _ in range(n_epochs):
            for batch in (tqdm_it := tqdm(dataloader)):
                batch = batch[0].cuda()

                optimizer.zero_grad()
                # sum the losses over all (student, teacher) output pairs
                loss: torch.Tensor = torch.tensor(0.0).cuda()
                for student_output, teacher_output in qdistill_model(batch):
                    loss += criterion(student_output, teacher_output)

                loss.backward()
                optimizer.step()
                scheduler.step()

                tqdm_it.set_description(f'loss: {loss.item():.5f}')
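
In the example above, ``scheduler.step()`` is called after every batch, so
``T_max`` counts optimizer steps rather than epochs: ``len(dataloader) * n_epochs``.
The sketch below evaluates the closed-form cosine annealing formula given in the
PyTorch ``CosineAnnealingLR`` documentation, with its default ``eta_min=0``, to
show that the learning rate reaches ``eta_min`` exactly at the last step. The
batch count of 100 per epoch is an illustrative assumption.

.. code-block:: python

    import math


    def cosine_annealing_lr(step, base_lr, t_max, eta_min=0.0):
        """Closed-form cosine annealing learning rate after `step` scheduler steps."""
        return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


    batches_per_epoch = 100  # illustrative; stands in for len(dataloader)
    n_epochs = 5
    t_max = batches_per_epoch * n_epochs

    for step in (0, t_max // 2, t_max):
        print(step, cosine_annealing_lr(step, base_lr=0.005, t_max=t_max))

If the scheduler were stepped once per epoch instead, ``T_max`` would have to be
``n_epochs`` for the same schedule.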

*********************
Fake Quantized Models
*********************

.. autoclass:: enot.quantization.TensorRTFakeQuantizedModel
    :members: __init__
    :show-inheritance:

.. autoclass:: enot.quantization.OpenVINOFakeQuantizedModel
    :members: __init__
    :show-inheritance:

.. autoclass:: enot.quantization.STMFakeQuantizedModel
    :members: __init__
    :show-inheritance:

.. autoclass:: enot.quantization.FakeQuantizedModel
    :members: quantization_parameters, regular_parameters, enable_calibration_mode, enable_quantization_mode

.. autofunction:: enot.quantization.utils.float_model_from_quantized_model

.. autofunction:: enot.quantization.utils.optimal_quantization_scheme


***********
Calibration
***********

.. autoclass:: enot.quantization.calibrate
    :members: __init__
    :show-inheritance:
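
Calibration runs the fake quantized model on a few batches (10 in the example at
the top of this page) so that initial quantization thresholds can be derived from
the observed data. The statistic ENOT collects is not specified here; the
plain-Python sketch below uses one common, simple choice, the running maximum
absolute value, and turns it into a symmetric INT8 scale. ``MaxAbsObserver`` and
``int8_scale`` are hypothetical names used only for illustration.

.. code-block:: python

    class MaxAbsObserver:
        """Hypothetical observer that tracks the largest absolute value seen across batches."""

        def __init__(self):
            self.max_abs = 0.0
            self.n_batches = 0

        def observe(self, batch):
            # `batch` is a flat list of activation values from one forward pass
            self.max_abs = max(self.max_abs, max(abs(v) for v in batch))
            self.n_batches += 1

        def threshold(self):
            if self.n_batches == 0:
                raise RuntimeError("no batches observed; run calibration first")
            return self.max_abs


    def int8_scale(threshold):
        # symmetric signed INT8: map [-threshold, threshold] onto [-127, 127]
        return threshold / 127


    observer = MaxAbsObserver()
    for batch in ([0.1, -0.4, 0.2], [0.3, -0.9], [0.5, 0.05]):  # stand-ins for 3 batches
        observer.observe(batch)

    print(observer.threshold())  # 0.9
    print(int8_scale(observer.threshold()))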


************
Distillation
************

The classes listed below provide utilities and procedures for fine-tuning
quantized models with `knowledge distillation`_.

.. _knowledge distillation: https://arxiv.org/abs/1503.02531
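
In the example at the top of this page, ``qdistill_model(batch)`` yields pairs of
student (quantized) and teacher (float) outputs, and the per-pair losses are
summed. The plain-Python sketch below mirrors that accumulation with a
root-mean-square error. It assumes RMSE means the square root of the mean squared
difference over all elements, which may differ from the exact reduction used by
``RMSELoss``; ``rmse`` and ``distillation_loss`` are hypothetical helpers.

.. code-block:: python

    import math


    def rmse(student, teacher):
        """Root-mean-square error between two equally sized flat lists."""
        if len(student) != len(teacher):
            raise ValueError("student and teacher outputs must have the same size")
        return math.sqrt(sum((s - t) ** 2 for s, t in zip(student, teacher)) / len(student))


    def distillation_loss(output_pairs):
        # Sum the per-pair RMSE over all (student_output, teacher_output) pairs,
        # like the `loss += criterion(...)` loop in the example.
        return sum(rmse(s, t) for s, t in output_pairs)


    pairs = [
        ([1.0, 2.0], [1.0, 2.0]),  # identical outputs contribute 0
        ([0.0, 0.0], [3.0, 4.0]),  # contributes sqrt((9 + 16) / 2)
    ]
    print(distillation_loss(pairs))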

.. autoclass:: enot.quantization.distill
    :members: __init__
    :show-inheritance:

Helper classes for distillation:

.. autoclass:: enot.quantization.QuantDistillationModule

.. autoclass:: enot.quantization.DistillationLayerSelectionStrategy
    :show-inheritance:

.. autoclass:: enot.quantization.RMSELoss