from abc import ABCMeta
from abc import abstractmethod
from os import getpid
from pathlib import Path
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch_optimizer import RAdam
from tqdm.auto import tqdm
from enot.logging.trackers import RunningMeanLogger
# We are using absolute imports here instead of simple imports like
# ``from enot.quantization import calibrate_quantized_model``. This
# is because this file is part of the ``enot.quantization`` package,
# and it is used in enot.quantization.__init__.py file, so simple
# imports would cause circular import error.
#
# Please use simple imports in your projects instead of absolute ones.
from enot.quantization.calibration.network_calibration import calibrate_quantized_model
from enot.quantization.distillation.utils import DistillationLayerSelectionStrategy
from enot.quantization.distillation.utils import distillation_context
from enot.quantization.fake_quantized_model import FakeQuantizedModel
from enot.quantization.utils.common import float_model_from_quantized_model
from enot.utils.data.dataloaders import recursive_to
from enot.utils.dataloader2model import DataLoaderSampleToModelInputs
from enot.utils.dataloader2model import default_sample_to_model_inputs
from enot.utils.train import Scheduler
[docs]class RMSELoss(nn.Module):
def __init__(self, eps: float = 1e-6):
super().__init__()
self.mse: nn.MSELoss = nn.MSELoss()
self.eps: float = eps
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
loss = torch.sqrt(self.mse(input, target) + self.eps)
return loss
KNOWN_CRITERIA = {
'RMSELoss': RMSELoss,
'MSELoss': nn.MSELoss,
'CrossEntropyLoss': nn.CrossEntropyLoss,
}
def parse_distillation_criterion(criterion: Union[torch.nn.Module, str]) -> torch.nn.Module:
if isinstance(criterion, nn.Module):
return criterion
if isinstance(criterion, str):
return KNOWN_CRITERIA[criterion]()
raise TypeError(f'Unknown criterion type: {type(criterion)}')
[docs]class DistillerInterface(metaclass=ABCMeta):
"""Distiller base interface."""
@abstractmethod
def distill(self) -> None:
"""Launches distillation procedure."""
pass
[docs]class QuantizationDistiller(DistillerInterface):
"""Quantized model distillation class with a simple distillation implementation."""
[docs] def __init__(
self,
quantized_model: FakeQuantizedModel,
dataloader: DataLoader,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[Scheduler] = None,
distillation_layer_selection_strategy: DistillationLayerSelectionStrategy = DistillationLayerSelectionStrategy.DISTILL_LAST_QUANT_LAYERS,
distillation_criterion: Union[torch.nn.Module, str] = 'RMSELoss',
n_epochs: int = 1,
device: Union[str, torch.device] = 'cuda:0',
sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
logdir: Optional[Union[str, Path]] = None,
save_every: Optional[int] = None,
verbose: int = 0,
):
"""
Parameters
----------
quantized_model : FakeQuantizedModel
Fake-quantized model.
dataloader : torch.utils.data.DataLoader
Dataloader with model inputs for distillation.
optimizer : torch.optim.Optimizer or None, optional
Optimizer instance.
scheduler : Scheduler or None, optional
Scheduler instance.
distillation_layer_selection_strategy : DistillationLayerSelectionStrategy, optional
Distillation layer selection strategy. Default value is DISTILL_LAST_QUANT_LAYERS.
distillation_criterion : Callable, optional
Distillation criterion module. Default criterion is RMSE.
n_epochs : int, optional
Number of epochs for distillation. Default value is 1.
device : str or torch.device, optional
Device to use during distillation. Default value is "cuda:0".
sample_to_model_inputs : Callable, optional
Function to map dataloader samples to model input format. Default value is
:func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
logdir : str or Path or None, optional
Save directory. Default value is None, which disables logging to directory.
save_every : int or None, optional
Save checkpoint every n steps. Default value is None, which disables intermediate model checkpoints.
verbose : int, optional
Verbosity level. Default value is 0.
"""
self.quantized_model = quantized_model
self.dataloader = dataloader
self.optimizer = optimizer
self.scheduler = scheduler
self.distillation_strategy = distillation_layer_selection_strategy
self.distillation_criterion = parse_distillation_criterion(distillation_criterion)
self.n_epochs = n_epochs
self.device = device
self.sample_to_model_inputs = sample_to_model_inputs
self.logdir = logdir
self.save_every = save_every
self.verbose = verbose
self.quantized_model.to(self.device)
self.quantized_model.train()
self.quantized_model.enable_quantization_mode(True)
self.loss_logger = RunningMeanLogger()
self.minimal_loss_value: Optional[float] = None
if self.logdir is not None:
self.logdir = Path(self.logdir)
self.logdir.mkdir(exist_ok=True, parents=True)
[docs] def distill(self) -> None:
with distillation_context(self.quantized_model, self.distillation_strategy):
quantized_model = self.quantized_model
regular_model = float_model_from_quantized_model(quantized_model)
self._base_distill(quantized_model, regular_model)
[docs] def save(self, checkpoint_path: Path) -> None:
"""Saves quantized model state dict to the specified path."""
torch.save(self.quantized_model.state_dict(), checkpoint_path)
def _base_distill(
self,
quantized_model: FakeQuantizedModel,
regular_model: FakeQuantizedModel,
) -> None:
"""Main distillation fine-tuning loop."""
best_model_path = None
checkpoint_dir = None
if self.logdir is not None:
checkpoint_dir = self.logdir / 'model_checkpoints'
checkpoint_dir.mkdir(exist_ok=True)
total_steps = 0
for epoch in range(self.n_epochs):
tqdm_iterator = tqdm(self.dataloader, disable=(self.verbose == 0))
for batch in tqdm_iterator:
self.optimizer.zero_grad()
model_args, model_kwargs = self.sample_to_model_inputs(batch)
recursive_to(model_args, device=self.device, ignore_non_tensors=True)
recursive_to(model_kwargs, device=self.device, ignore_non_tensors=True)
quantized_output = quantized_model(*model_args, **model_kwargs)
with torch.no_grad():
regular_output = regular_model(*model_args, **model_kwargs)
loss: Union[float, torch.Tensor] = 0.0
for q_out, r_out in zip(quantized_output, regular_output):
loss += self.distillation_criterion(q_out, r_out)
loss.backward()
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()
total_steps += 1
self.loss_logger.tracked_value = loss.item()
min_loss_updated = False
if self.minimal_loss_value is None or (self.loss_logger.tracked_value < self.minimal_loss_value):
self.minimal_loss_value = self.loss_logger.tracked_value
min_loss_updated = True
if checkpoint_dir is not None:
if self.save_every is not None and total_steps % self.save_every == self.save_every - 1:
self.save(checkpoint_dir / f'model_{epoch}_{total_steps}.pth')
if min_loss_updated:
best_model_path = checkpoint_dir / 'best_model.pth'
self.save(best_model_path)
tqdm_iterator.set_description(
f'loss: {self.loss_logger.tracked_value:.5f}, ' f'min loss: {self.minimal_loss_value:.5f}'
)
if best_model_path is not None:
quantized_model.load_state_dict(torch.load(best_model_path))
[docs]class SequentialDistiller(DistillerInterface):
"""Compound distillation class which performs sequential distillation with multiple strategies."""
[docs] def __init__(self, *distillers: DistillerInterface):
"""
Parameters
----------
distillers : tuple with DistillerInterface
Tuple with distiller instances.
"""
self.distillers: Tuple[DistillerInterface] = distillers
[docs] def distill(self) -> None:
for distiller in self.distillers:
distiller.distill()
[docs]class ThresholdsAndScaleFactorsQuantizationDistiller(QuantizationDistiller):
"""
Quantization distiller for thresholds and scale factors
with a default well-performing distillation configuration.
"""
[docs] def __init__(
self,
quantized_model: FakeQuantizedModel,
dataloader: DataLoader,
learning_rate: float = 0.005,
device: Union[str, torch.device] = 'cuda:0',
sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
logdir: Optional[Union[str, Path]] = None,
save_every: Optional[int] = None,
n_batches_calibrate: int = 10,
tune_scale_factors: bool = True,
distillation_layer_selection_strategy: DistillationLayerSelectionStrategy = DistillationLayerSelectionStrategy.DISTILL_LAST_QUANT_LAYERS,
verbose: int = 0,
):
"""
Parameters
----------
quantized_model : FakeQuantizedModel
Fake-quantized model.
dataloader : torch.utils.data.DataLoader
Dataloader with model inputs for distillation.
learning_rate : float, optional
learning rate (default: 5e-3)
device : str or torch.device, optional
Device to use during distillation. Default value is "cuda:0".
sample_to_model_inputs : Callable, optional
Function to map dataloader samples to model input format. Default value is
:func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
logdir : str or Path or None, optional
Save directory. Default value is None, which disables logging to directory.
save_every : int or None, optional
Save checkpoint every n steps. Default value is None, which disables intermediate model checkpoints.
n_batches_calibrate: int, optional
Number of batches used for calibration. Default is 10.
tune_scale_factors : bool, optional
Whether to tune scale factors or not. True by default.
distillation_layer_selection_strategy : DistillationLayerSelectionStrategy, optional
Distillation layer selection strategy. Default value is DISTILL_LAST_QUANT_LAYERS.
verbose : int, optional
Verbosity level. Default value is 0.
"""
n_epochs = 1
batches_in_epoch = len(dataloader)
self.n_batches_calibrate: int = n_batches_calibrate
if logdir is not None:
logdir = Path(logdir) / f'{getpid()}_{id(quantized_model)}'
if tune_scale_factors:
logdir /= 'thresholds_and_scale_factors_distillation'
else:
logdir /= 'thresholds_distillation'
if verbose:
print(f'The working directory is {logdir.as_posix()}')
params = [*quantized_model.quantization_parameters()]
if tune_scale_factors:
params += [*quantized_model.scale_factors()]
optimizer = RAdam(
params=params,
lr=learning_rate,
betas=(0.9, 0.95),
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer,
T_max=batches_in_epoch * n_epochs,
)
super().__init__(
quantized_model=quantized_model,
dataloader=dataloader,
optimizer=optimizer,
scheduler=scheduler,
distillation_layer_selection_strategy=distillation_layer_selection_strategy,
distillation_criterion=RMSELoss(),
n_epochs=n_epochs,
device=device,
sample_to_model_inputs=sample_to_model_inputs,
logdir=logdir,
save_every=save_every,
verbose=verbose,
)
[docs] def distill(self) -> None:
calibrate_quantized_model(
quantized_model=self.quantized_model,
dataloader=self.dataloader,
n_steps=self.n_batches_calibrate,
sample_to_model_inputs=self.sample_to_model_inputs,
verbose=self.verbose,
)
if self.verbose >= 2:
print('Fine-tuning quantization parameters.')
super().distill()
[docs]class ThresholdsQuantizationDistiller(ThresholdsAndScaleFactorsQuantizationDistiller):
"""Quantization distiller for thresholds with a default well-performing distillation configuration."""
[docs] def __init__(
self,
quantized_model: FakeQuantizedModel,
dataloader: DataLoader,
learning_rate: float = 0.05,
device: Union[str, torch.device] = 'cuda:0',
sample_to_model_inputs: DataLoaderSampleToModelInputs = default_sample_to_model_inputs,
logdir: Optional[Union[str, Path]] = None,
save_every: Optional[int] = None,
n_batches_calibrate: int = 10,
distillation_layer_selection_strategy: DistillationLayerSelectionStrategy = DistillationLayerSelectionStrategy.DISTILL_LAST_QUANT_LAYERS,
verbose: int = 0,
):
"""
Parameters
----------
quantized_model : FakeQuantizedModel
Fake-quantized model.
dataloader : torch.utils.data.DataLoader
Dataloader with model inputs for distillation.
learning_rate : float, optional
learning rate (default: 5e-2)
device : str or torch.device, optional
Device to use during distillation. Default value is "cuda:0".
sample_to_model_inputs : Callable, optional
Function to map dataloader samples to model input format. Default value is
:func:`.default_sample_to_model_inputs`. See more :ref:`here <s2mi ref>`.
logdir : str or Path or None, optional
Save directory. Default value is None, which disables logging to directory.
save_every : int or None, optional
Save checkpoint every n steps. Default value is None, which disables intermediate model checkpoints.
n_batches_calibrate: int, optional
Number of batches used for calibration. Default is 10.
distillation_layer_selection_strategy : DistillationLayerSelectionStrategy, optional
Distillation layer selection strategy. Default value is DISTILL_LAST_QUANT_LAYERS.
verbose : int, optional
Verbosity level. Default value is 0.
"""
super().__init__(
quantized_model=quantized_model,
dataloader=dataloader,
learning_rate=learning_rate,
device=device,
sample_to_model_inputs=sample_to_model_inputs,
logdir=logdir,
save_every=save_every,
tune_scale_factors=False,
n_batches_calibrate=n_batches_calibrate,
distillation_layer_selection_strategy=distillation_layer_selection_strategy,
verbose=verbose,
)
DefaultQuantizationDistiller = ThresholdsAndScaleFactorsQuantizationDistiller