Distillation

The enot.distillation module contains a toolset for knowledge distillation.

The module provides the distill context manager and the Mapping class:

from enot.distillation import distill
from enot.distillation import Mapping

In knowledge distillation, a student model is generally supervised by a teacher model, so both distill.__init__() and Mapping.__init__() take teacher and student arguments. Before creating a distill context, you must specify a Mapping between the teacher modules and the student modules through which you want to distill. Modules can be given by reference, by name, or a mix of both:

mapping = Mapping(teacher, student)

mapping.add(teacher.features[3], student.features[3])  # Specify by module reference.
mapping.add('features.5.conv', 'features.5.conv')      # Specify by module name.
mapping.add(teacher.features[7], 'features.7')         # Mixed.
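The string names above look like the dotted paths PyTorch reports in named_modules(). Assuming enot resolves string names the same way (this page does not confirm it), the sketch below shows that a dotted name and a direct reference reach the same submodule. Net and Block are hypothetical toy classes; get_submodule requires PyTorch 1.9 or newer.

```python
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical block holding a `conv` submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


class Net(nn.Module):
    """Hypothetical toy model with the features.N.conv layout used above."""

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(*[Block() for _ in range(8)])

    def forward(self, x):
        return self.features(x)


model = Net()
modules = dict(model.named_modules())  # Dotted name -> submodule.

# The same submodule, reached by reference and by dotted name.
by_reference = model.features[5].conv
by_name = model.get_submodule("features.5.conv")
print(by_reference is by_name)                      # True
print(modules["features.7"] is model.features[7])   # True
```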

Once the mapping is specified, you can create a distill context, which provides a new (distillation) model object. It behaves like a normal model, except that it returns paired tensors: the outputs of the teacher and student modules specified in the mapping. All that remains is to apply the desired criterion (loss) to each pair:

with distill(teacher=teacher, student=student, mapping=mapping) as distill_model:
    inputs = torch.ones(1, 3, 224, 224)  # Sample from dataloader.
    optimizer.zero_grad()

    output = distill_model(inputs)
    loss: torch.Tensor = torch.zeros(1)
    for teacher_output, student_output, _ in output:  # Paired tensors from mapped modules.
        loss += criterion(teacher_output, student_output)

    loss.backward()
    optimizer.step()
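To make "returns paired tensors" concrete, here is a minimal plain-PyTorch sketch of one common way to build such pairs: forward hooks that capture module outputs during one teacher pass and one student pass. It is not enot's implementation. The helper paired_outputs and the toy models are hypothetical, and it assumes nograd_teacher means the teacher's forward pass runs under torch.no_grad().

```python
from typing import Any, List, Sequence, Tuple

import torch
import torch.nn as nn


def paired_outputs(
    teacher: nn.Module,
    student: nn.Module,
    pairs: Sequence[Tuple[nn.Module, nn.Module, Any]],
    inputs: torch.Tensor,
    nograd_teacher: bool = True,
) -> List[Tuple[torch.Tensor, torch.Tensor, Any]]:
    """Hypothetical helper: run both models once and return
    (teacher_output, student_output, payload) for every mapped pair."""
    captured = {}
    handles = []

    def make_hook(key):
        def hook(module, args, output):
            captured[key] = output  # Store this module's forward output.
        return hook

    for i, (t_mod, s_mod, _) in enumerate(pairs):
        handles.append(t_mod.register_forward_hook(make_hook(("teacher", i))))
        handles.append(s_mod.register_forward_hook(make_hook(("student", i))))
    try:
        if nograd_teacher:
            with torch.no_grad():  # Teacher activations are not recorded for backward.
                teacher(inputs)
        else:
            teacher(inputs)
        student(inputs)
    finally:
        for handle in handles:  # Always remove hooks, even if a forward pass fails.
            handle.remove()
    return [
        (captured[("teacher", i)], captured[("student", i)], payload)
        for i, (_, _, payload) in enumerate(pairs)
    ]


torch.manual_seed(0)
teacher = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
student = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
pairs = [(teacher[0], student[0], nn.MSELoss()), (teacher[2], student[2], nn.MSELoss())]
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)

inputs = torch.randn(16, 4)
optimizer.zero_grad()
output = paired_outputs(teacher, student, pairs, inputs)
# Student output first, matching PyTorch's (input, target) loss convention.
loss = sum(criterion(s_out, t_out) for t_out, s_out, criterion in output)
loss.backward()
optimizer.step()
```

Because the teacher runs under no_grad, only the student's parameters receive gradients. For symmetric losses such as nn.MSELoss the argument order does not matter; for asymmetric ones such as nn.CrossEntropyLoss it does.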

In most cases it is useful to attach something else to each pair besides the modules themselves, for example a loss function, a weight, or an adapter between the modules. Anything passed as the payload argument of Mapping.add() is returned together with that pair's outputs from the distillation model:

mapping.add(teacher.features[3], student.features[3], payload=nn.MSELoss())
mapping.add(teacher.features[5], student.features[5], payload=nn.CrossEntropyLoss())
mapping.add(teacher.features[7], student.features[7], payload=nn.MSELoss())
...
for teacher_output, student_output, criterion in output:  # 3rd value is payload from mapping.
    loss += criterion(teacher_output, student_output)
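Since a payload can be any object, one option is to bundle a criterion, a weight and an adapter for each pair. The dict layout, the 0.5 weight and the 1x1 convolution adapter below are illustrative assumptions, not an enot convention, and random tensors stand in for the output of distill_model(inputs).

```python
import torch
import torch.nn as nn

# Hypothetical payload: criterion, weight, and an adapter that maps
# 16 student channels to the teacher's 32 channels.
payload = {
    "criterion": nn.MSELoss(),
    "weight": 0.5,
    "adapter": nn.Conv2d(16, 32, kernel_size=1),
}

teacher_output = torch.randn(2, 32, 7, 7)
student_output = torch.randn(2, 16, 7, 7, requires_grad=True)
output = [(teacher_output, student_output, payload)]  # Stands in for distill_model(inputs).

loss = torch.zeros(())
for t_out, s_out, p in output:
    # Adapt the student features, then apply the weighted criterion.
    loss = loss + p["weight"] * p["criterion"](p["adapter"](s_out), t_out)
loss.backward()
```

A trainable adapter's parameters also have to be passed to the optimizer, otherwise they stay at their initial values.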

class distill(teacher, student, mapping=None, nograd_teacher=True)

Context manager that glues the teacher and student models into one distillation model.

Examples

>>> from enot.distillation import Mapping
>>> from enot.distillation import distill

Create a mapping between the teacher and student models; a criterion is also attached to each mapping pair as its payload.

>>> mapping = Mapping(teacher, student)  # GraphModule - GraphModule mapping.
>>> mapping.add('features_18_0', 'features_18_0', payload=nn.MSELoss())
>>> mapping.add('features_2_conv_1_0', 'features_2_conv_1_0', payload=nn.CrossEntropyLoss())
>>> mapping.add('features_5_conv_0_2', 'features_5_conv_0_2', payload=nn.MSELoss())
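The GraphModule example uses underscore names such as 'features_18_0', which resemble torch.fx node names (the dots of a submodule path become underscores). Assuming enot expects these node names for traced models, which this page does not confirm, one way to list candidate names is to trace a model and print its call_module nodes. Net and Block are hypothetical toy classes.

```python
import torch.nn as nn
from torch.fx import symbolic_trace


class Block(nn.Module):
    """Hypothetical block: conv followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv(x))


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(Block(), Block())

    def forward(self, x):
        return self.features(x)


gm = symbolic_trace(Net())
call_module_nodes = [n for n in gm.graph.nodes if n.op == "call_module"]
for node in call_module_nodes:
    # node.name is the graph-level name; node.target is the dotted submodule path.
    print(node.name, "->", node.target)
```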

Prepare the optimizer, scheduler, dataloaders, etc., as usual.

>>> optimizer = RAdam(params=student.parameters())

Use the distill context to transfer the teacher's knowledge to the student.

>>> with distill(teacher=teacher, student=student, mapping=mapping) as distill_model:
...     inputs = torch.ones(1, 3, 224, 224)
...     optimizer.zero_grad()
...     output = distill_model(inputs)
...     loss: torch.Tensor = torch.zeros(1)
...     for teacher_output, student_output, criterion in output:
...         loss += criterion(teacher_output, student_output)
...     loss.backward()
...     optimizer.step()

__init__(teacher, student, mapping=None, nograd_teacher=True)
Parameters:
  • teacher (Union[nn.Module, GraphModule]) – Teacher; knowledge will be transferred from this model to the student model.

  • student (Union[nn.Module, GraphModule]) – Student; knowledge will be transferred to this model from the teacher model.

  • mapping (Optional[Mapping]) – Specifies the pairs of teacher and student modules between which knowledge is distilled.

  • nograd_teacher (bool) – Whether to apply the no_grad decorator to the teacher. Default value is True.

class Mapping(teacher, student)

Mapping between modules of teacher and student models.

__init__(teacher, student)
Parameters:
  • teacher (Module) – Teacher module.

  • student (Module) – Student module.

add(teacher_module, student_module, payload=None)

Add a pair to the mapping.

Parameters:
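  • teacher_module (Mappable) – Teacher module, given by reference or by name, to associate with student_module.

  • student_module (Mappable) – Student module, given by reference or by name, to associate with teacher_module.

  • payload (Any) – Arbitrary object returned together with this pair's outputs, for example a loss function. Default value is None.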

Return type:

None

payload()

Payloads of all added pairs, in the order in which the pairs were added.

Return type:

List[Any]
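"Order is preserved" suggests the pairs are kept in insertion order. Below is a minimal pure-Python sketch of such a container, assuming a plain list of (teacher, student, payload) triples. PairMapping is hypothetical and omits the module-name resolution a real mapping needs.

```python
from typing import Any, List, Tuple


class PairMapping:
    """Hypothetical stand-in for a teacher/student mapping; not enot's implementation."""

    def __init__(self) -> None:
        # A list keeps insertion order and allows repeated modules.
        self._pairs: List[Tuple[Any, Any, Any]] = []

    def add(self, teacher_module: Any, student_module: Any, payload: Any = None) -> None:
        self._pairs.append((teacher_module, student_module, payload))

    def payload(self) -> List[Any]:
        # Payloads come back in the same order as the add() calls.
        return [p for _, _, p in self._pairs]


mapping = PairMapping()
mapping.add("features.3", "features.3", payload="mse")
mapping.add("features.5.conv", "features.5.conv")
mapping.add("features.7", "features.7", payload="ce")
print(mapping.payload())  # ['mse', None, 'ce']
```

Storing triples in a list rather than a dict keyed by the teacher module also lets one teacher module be paired with several student modules; whether enot permits that is not stated on this page.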