Distillation

The enot.distillation module contains a toolset for knowledge distillation.

The module provides the distill context manager and the Mapping class:

from enot.distillation import distill
from enot.distillation import Mapping

In knowledge distillation, a student model is generally supervised by a teacher model; therefore, enot.distillation.distill.__init__() and enot.distillation.Mapping.__init__() both take teacher and student arguments. To create a distill context, you must specify a Mapping between the modules of the teacher model and the modules of the student model through which you want to distill. Each module can be specified by reference or by name:

mapping = Mapping(teacher, student)

mapping.add(teacher.features[3], student.features[3])  # Specify by module reference.
mapping.add('features.5.conv', 'features.5.conv')      # Specify by module name.
mapping.add(teacher.features[7], 'features.7')         # Mixed.
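To make the mapping idea concrete, here is a minimal pure-Python sketch of a mapping that accepts either a module reference or a module name on each side and stores pairs with an optional payload. It is not ENOT's implementation: the class ToyMapping and the name-to-module dictionaries are hypothetical, chosen only to mirror what Mapping.add() and Mapping.payload() are documented to do.

```python
from typing import Any, Dict, List, Optional, Tuple, Union


class ToyMapping:
    """Hypothetical stand-in for a teacher/student module mapping (not ENOT's code)."""

    def __init__(self, teacher_modules: Dict[str, Any], student_modules: Dict[str, Any]):
        # Each model is represented by a {qualified name: module object} dict,
        # similar in spirit to dict(model.named_modules()) in PyTorch.
        self._teacher = teacher_modules
        self._student = student_modules
        self._pairs: List[Tuple[Any, Any, Any]] = []

    @staticmethod
    def _resolve(modules: Dict[str, Any], ref: Union[str, Any]) -> Any:
        # A string is looked up by name; any other object must be one of the model's modules.
        if isinstance(ref, str):
            return modules[ref]
        if not any(ref is m for m in modules.values()):
            raise ValueError("module does not belong to this model")
        return ref

    def add(self, teacher_module: Any, student_module: Any, payload: Optional[Any] = None) -> None:
        t = self._resolve(self._teacher, teacher_module)
        s = self._resolve(self._student, student_module)
        self._pairs.append((t, s, payload))

    def pairs(self) -> List[Tuple[Any, Any, Any]]:
        return list(self._pairs)

    def payload(self) -> List[Any]:
        # Payloads are returned in insertion order.
        return [p for _, _, p in self._pairs]


# Hypothetical "models": plain objects standing in for submodules.
teacher_modules = {"features.3": object(), "features.5.conv": object(), "features.7": object()}
student_modules = {"features.3": object(), "features.5.conv": object(), "features.7": object()}

mapping = ToyMapping(teacher_modules, student_modules)
mapping.add(teacher_modules["features.3"], student_modules["features.3"], payload="mse")  # By reference.
mapping.add("features.5.conv", "features.5.conv", payload="ce")  # By name.
mapping.add(teacher_modules["features.7"], "features.7")  # Mixed.
```

Resolving names eagerly in add() means a typo in a module name fails at mapping time rather than during training.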

Once the mapping is specified, you can create a distill context, which provides a new (distillation) model object. It behaves like a normal model, except that it returns paired tensors from the modules specified in the mapping. All that remains is to apply the desired criterion (loss) to them:

criterion = torch.nn.MSELoss()  # Any loss that compares the paired tensors.
optimizer = torch.optim.RAdam(params=student.parameters(), lr=0.005)

with distill(teacher=teacher, student=student, mapping=mapping) as distill_model:
    inputs = torch.ones(1, 3, 224, 224)  # Sample from dataloader.
    optimizer.zero_grad()

    output = distill_model(inputs)
    loss: torch.Tensor = torch.zeros(1)
    for teacher_output, student_output, _ in output:  # Paired tensors from mapped modules.
        loss += criterion(teacher_output, student_output)

    loss.backward()
    optimizer.step()
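The distillation model returns outputs of the mapped modules rather than only the final predictions. The toy sketch below illustrates that idea in plain Python by running both models and recording the intermediate outputs of named layers. In PyTorch, this kind of capture is commonly done with forward hooks (nn.Module.register_forward_hook), but whether ENOT works this way internally is an assumption. ToyModel, paired_outputs and the scalar "layers" are all hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple

Layer = Callable[[float], float]


class ToyModel:
    """Hypothetical network: an ordered chain of named scalar layers."""

    def __init__(self, layers: Dict[str, Layer]):
        self.layers = layers  # Dicts preserve insertion order (Python 3.7+).

    def __call__(self, x: float, record: Dict[str, float]) -> float:
        for name, layer in self.layers.items():
            x = layer(x)
            if name in record:  # Capture outputs of mapped layers only.
                record[name] = x
        return x


def paired_outputs(
    teacher: ToyModel, student: ToyModel, pairs: List[Tuple[str, str, Any]], x: float
) -> List[Tuple[float, float, Any]]:
    # Run both models on the same input; return (teacher_out, student_out, payload) triples.
    t_rec = {t: 0.0 for t, _, _ in pairs}
    s_rec = {s: 0.0 for _, s, _ in pairs}
    teacher(x, t_rec)
    student(x, s_rec)
    return [(t_rec[t], s_rec[s], p) for t, s, p in pairs]


def squared_error(a: float, b: float) -> float:
    return (a - b) ** 2


teacher = ToyModel({"f1": lambda v: 2 * v, "f2": lambda v: v + 1})
student = ToyModel({"f1": lambda v: 1.5 * v, "f2": lambda v: v + 1})

output = paired_outputs(teacher, student, [("f1", "f1", None), ("f2", "f2", None)], 2.0)
loss = 0.0
for teacher_output, student_output, _ in output:  # Same loop shape as above.
    loss += squared_error(teacher_output, student_output)
```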

In most cases, it is useful to attach extra information to each mapped pair, for example a loss function, a loss weight, or an adapter between modules. You can pass any such object as the payload argument of Mapping.add(), and it will be returned together with the paired outputs of the distillation model:

mapping.add(teacher.features[3], student.features[3], payload=nn.MSELoss())
mapping.add(teacher.features[5], student.features[5], payload=nn.CrossEntropyLoss())
mapping.add(teacher.features[7], student.features[7], payload=nn.MSELoss())
...
for teacher_output, student_output, criterion in output:  # 3rd value is payload from mapping.
    loss += criterion(teacher_output, student_output)
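Because the payload can be any object, one pattern is to bundle a criterion with a loss weight. The plain-Python sketch below assumes the distillation output is an iterable of (teacher_output, student_output, payload) triples, as in the loop above, and that each payload is a hypothetical (criterion, weight) tuple. Lists of floats stand in for tensors.

```python
from typing import Callable, List, Sequence, Tuple

Criterion = Callable[[Sequence[float], Sequence[float]], float]


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    # Mean squared error between two equal-length vectors.
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def l1(a: Sequence[float], b: Sequence[float]) -> float:
    # Mean absolute error between two equal-length vectors.
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


# Hypothetical distillation output: (teacher_output, student_output, (criterion, weight)).
output: List[Tuple[List[float], List[float], Tuple[Criterion, float]]] = [
    ([1.0, 2.0], [1.0, 0.0], (mse, 1.0)),
    ([3.0, 3.0], [1.0, 1.0], (l1, 0.5)),
]

loss = 0.0
for teacher_output, student_output, (criterion, weight) in output:
    loss += weight * criterion(teacher_output, student_output)
```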

class distill(teacher, student, mapping=None, nograd_teacher=True)

Context manager that glues the teacher and student models into one distillation model.

Examples

>>> from enot.distillation import Mapping
>>> from enot.distillation import distill

Create a mapping between the teacher and student models; here we also attach a criterion to each mapping pair as its payload.

>>> mapping = Mapping(teacher, student)  # GraphModule - GraphModule mapping.
>>> mapping.add('features_18_0', 'features_18_0', payload=nn.MSELoss())
>>> mapping.add('features_2_conv_1_0', 'features_2_conv_1_0', payload=nn.CrossEntropyLoss())
>>> mapping.add('features_5_conv_0_2', 'features_5_conv_0_2', payload=nn.MSELoss())

Prepare the optimizer, scheduler, dataloaders, etc., as usual.

>>> optimizer = RAdam(params=student.parameters())

Use the distill context to distill the teacher's knowledge into the student.

>>> with distill(teacher=teacher, student=student, mapping=mapping) as distill_model:
...     inputs = torch.ones(1, 3, 224, 224)
...     optimizer.zero_grad()
...     output = distill_model(inputs)
...     loss: torch.Tensor = torch.zeros(1)
...     for teacher_output, student_output, criterion in output:
...         loss += criterion(teacher_output, student_output)
...     loss.backward()
...     optimizer.step()

__init__(teacher, student, mapping=None, nograd_teacher=True)
Parameters:
  • teacher (Union[nn.Module, GraphModule]) – Teacher; knowledge will be transferred from this model to the student model.

  • student (Union[nn.Module, GraphModule]) – Student; knowledge will be transferred to this model from the teacher model.

  • mapping (Optional[Mapping]) – Mapping that specifies the teacher and student modules through which to distill.

  • nograd_teacher (bool) – Whether to apply the no_grad decorator to the teacher. Default value is True.

class Mapping(teacher, student)

Mapping between modules of teacher and student models.

__init__(teacher, student)
Parameters:
  • teacher (Module) – Teacher module.

  • student (Module) – Student module.

add(teacher_module, student_module, payload=None)

Add pair to mapping.

Parameters:
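  • teacher_module (Mappable) – Teacher module (a module reference or its name) to associate with the student module.

  • student_module (Mappable) – Student module (a module reference or its name) to associate with the teacher module.

  • payload (Any) – Arbitrary object returned together with this pair's outputs, for example a criterion or an adapter. Default value is None.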

Return type:

None

payload()

Payloads of all added pairs, in the order in which the pairs were added.

Return type:

List[Any]

Adapter

Suppose we have two models, a teacher and a student, where the student is a pruned version of the teacher, and we want to distill through corresponding conv2d modules.

Consider a specific conv2d module in the teacher and the corresponding conv2d module in the student. Since the student is a pruned version of the teacher, the student module may have fewer output channels than the teacher module, so we cannot distill directly through the output tensors of these modules. The solution is a trainable adapter that narrows the teacher's conv2d output tensor to the shape of the student's conv2d output tensor:

[Diagram: the teacher's conv2d module outputs 512 channels; the adapter maps them to the 256 channels produced by the student's conv2d module.]

Such an adapter must be properly initialized; otherwise, distillation may fail. We provide the conv2d_adapter() function to create a properly initialized adapter: it maps all channels of the teacher's conv2d module that are not pruned in the student model to the student's conv2d module:

from enot.distillation.adapter import conv2d_adapter

adapter = conv2d_adapter(source=conv2d_teacher, target=conv2d_student)
mapping.add(conv2d_teacher, conv2d_student, payload=adapter)  # Put adapter to payload.
# To use different criteria for different adapters, pass a tuple as the payload: (adapter, criterion).
# Don't forget to add the adapter parameters to the optimizer to make them trainable.
...
with distill(teacher=teacher, student=student, mapping=mapping) as distill_model:
    ...
    output = distill_model(inputs)
    loss: torch.Tensor = torch.zeros(1)
    for teacher_output, student_output, adapter in output:  # Payload is adapter.
        loss += criterion(adapter(teacher_output), student_output)
    ...
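The comment in the example above suggests packing a tuple into the payload when pairs need different criteria. A minimal sketch of the matching unpacking loop follows, in plain Python with hypothetical stand-ins: make_narrow() builds a function that plays the role of an adapter by keeping selected channels, and lists of per-channel values stand in for tensors. In real PyTorch code, the adapters' parameters would also need to go into the optimizer, for example by chaining student.parameters() with each adapter's parameters() via itertools.chain.

```python
from typing import Callable, List, Sequence

Adapter = Callable[[Sequence[float]], List[float]]


def make_narrow(kept: Sequence[int]) -> Adapter:
    # Hypothetical adapter: keeps only the teacher channels that survive in the student.
    return lambda channels: [channels[i] for i in kept]


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def abs_err(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


# Hypothetical distillation output: (teacher_output, student_output, (adapter, criterion)).
output = [
    ([1.0, 5.0, 2.0], [1.0, 2.0], (make_narrow([0, 2]), mse)),
    ([4.0, 0.0, 6.0, 8.0], [4.0, 6.0, 5.0], (make_narrow([0, 2, 3]), abs_err)),
]

loss = 0.0
for teacher_output, student_output, (adapter, criterion) in output:
    loss += criterion(adapter(teacher_output), student_output)
```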

We also provide inverted_conv2d_adapter(), which maps the student's output to the teacher's channels instead, if you want to do something like this:

loss += criterion(teacher_output, adapter(student_output))

conv2d_adapter(source, target)

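Conv2d 1x1 adapter for two convolutions: source and target. The target convolution should be a pruned version of the source convolution.

Parameters:
  • source – Source convolution.

  • target – Target convolution; a pruned version of source.

Returns:

Adapter that maps all channels of the source convolution that are kept (not pruned) in the target convolution to the target convolution.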

Return type:

torch.nn.Conv2d
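A 1x1 convolution acts at every spatial position as a matrix multiplication over channels, so an adapter that maps all kept channels can be pictured as a selection matrix. The pure-Python sketch below assumes the indices of the teacher channels kept in the student are known; kept_channels is a hypothetical input, and how conv2d_adapter() determines these indices is not described here. It builds a weight matrix of shape (student_channels, teacher_channels) with a single 1 per row. This is one initialization consistent with the description above, not necessarily the one ENOT uses.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def selection_weight(kept_channels: Sequence[int], teacher_channels: int) -> Matrix:
    # Row i copies teacher channel kept_channels[i] into student channel i.
    weight = [[0.0] * teacher_channels for _ in kept_channels]
    for out_c, in_c in enumerate(kept_channels):
        weight[out_c][in_c] = 1.0
    return weight


def apply_1x1(weight: Matrix, pixel: Sequence[float]) -> List[float]:
    # A bias-free 1x1 convolution applied to the channel vector of a single pixel.
    return [sum(w * x for w, x in zip(row, pixel)) for row in weight]


# Teacher conv has 4 output channels; the pruned student kept channels 0, 2 and 3.
w = selection_weight([0, 2, 3], teacher_channels=4)
narrowed = apply_1x1(w, [10.0, 20.0, 30.0, 40.0])  # -> [10.0, 30.0, 40.0]
```

In PyTorch terms, such a matrix would correspond to the weight of an nn.Conv2d(teacher_channels, student_channels, kernel_size=1), whose weight tensor has shape (out_channels, in_channels, 1, 1). Since the adapter is trainable, it would start as an exact channel selection and could then be refined during distillation.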

inverted_conv2d_adapter(source, target)

Inverted Conv2d 1x1 adapter for two convolutions: source and target. The target convolution should be a pruned version of the source convolution, but instead of narrowing the source to the target, this adapter extends the target to the source.

Parameters:
  • source – Source convolution.

  • target – Target convolution; a pruned version of source.

Returns:

Adapter that maps all kept (not pruned) channels of the target convolution to the corresponding channels of the source convolution.

Return type:

torch.nn.Conv2d
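The inverted direction can be pictured as the transpose of a channel-selection matrix: each student channel is written back to the position of the teacher channel it came from, and pruned positions start at zero. As before, this pure-Python sketch assumes the kept channel indices are known (kept_channels is hypothetical) and shows one initialization consistent with the description, not necessarily ENOT's.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def inverted_selection_weight(kept_channels: Sequence[int], teacher_channels: int) -> Matrix:
    # Shape (teacher_channels, student_channels): student channel i goes back to
    # teacher position kept_channels[i]; rows for pruned channels stay all-zero.
    weight = [[0.0] * len(kept_channels) for _ in range(teacher_channels)]
    for student_c, teacher_c in enumerate(kept_channels):
        weight[teacher_c][student_c] = 1.0
    return weight


def apply_1x1(weight: Matrix, pixel: Sequence[float]) -> List[float]:
    # A bias-free 1x1 convolution applied to the channel vector of a single pixel.
    return [sum(w * x for w, x in zip(row, pixel)) for row in weight]


# Student kept teacher channels 0, 2 and 3 out of 4.
w_inv = inverted_selection_weight([0, 2, 3], teacher_channels=4)
extended = apply_1x1(w_inv, [1.0, 2.0, 3.0])  # -> [1.0, 0.0, 2.0, 3.0]
```

With this initialization, the criterion initially compares the teacher's pruned channels against zeros; whether that is desirable depends on the criterion and is a design choice left to training.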