############
Distillation
############

The ``enot.distillation`` module provides a toolset for the knowledge distillation procedure.
It contains the :class:`~enot.distillation.distill` context manager and the :class:`~enot.distillation.Mapping` class:

.. code-block:: python

    from enot.distillation import distill
    from enot.distillation import Mapping

In knowledge distillation, a student model is generally supervised by a teacher model,
so :meth:`enot.distillation.distill.__init__` and :meth:`enot.distillation.Mapping.__init__` have ``student`` and ``teacher`` arguments.
To create a **distill** context, you need to specify a :class:`~enot.distillation.Mapping` between the modules of the teacher model and the modules of the student model through which you want to distill:

.. code-block:: python

    mapping = Mapping(student, teacher)
    mapping.add(student.features[3], teacher.features[3])  # Specify by module reference.
    mapping.add('features.5.conv', 'features.5.conv')      # Specify by module name.
    mapping.add(student.features[7], 'features.7')         # Mixed.

Once the **mapping** is specified, the **distill** context can be created; it yields a new (distillation) model object.
It is a normal model, except that it returns paired tensors from the modules specified in the mapping.
All that remains is to apply the desired criterion (loss) to them:

.. code-block:: python

    optimizer = torch.optim.RAdam(params=student.parameters(), lr=0.005)

    with distill(student=student, teacher=teacher, mapping=mapping) as distill_model:
        inputs = torch.ones(1, 3, 224, 224)  # Sample from dataloader.
        optimizer.zero_grad()

        output = distill_model(inputs)
        loss: torch.Tensor = torch.zeros(1)
        for student_output, teacher_output, _ in output:  # Paired tensors from mapped modules.
            loss += criterion(student_output, teacher_output)

        loss.backward()
        optimizer.step()

In most cases, in addition to the mapping between modules, it is useful to attach something else,
for example a loss function, a weight, or an adapter between modules.
You can put any of these into the ``payload`` argument of :meth:`.Mapping.add`, and it will be returned together with the paired output of the distillation model:

.. code-block:: python

    mapping.add(student.features[3], teacher.features[3], payload=nn.MSELoss())
    mapping.add(student.features[5], teacher.features[5], payload=nn.CrossEntropyLoss())
    mapping.add(student.features[7], teacher.features[7], payload=nn.MSELoss())
    ...
    for student_output, teacher_output, criterion in output:  # 3rd value is payload from mapping.
        loss += criterion(student_output, teacher_output)
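Putting the pieces together, a minimal end-to-end distillation loop might look like the sketch below.
The toy models, the dummy ``train_loader``, and the hyperparameters are illustrative placeholders, not part of the ``enot`` API:

.. code-block:: python

    import torch
    from torch import nn

    from enot.distillation import Mapping
    from enot.distillation import distill

    # Toy stand-ins for real models; in practice the student is typically
    # a pruned or smaller copy of the teacher.
    teacher = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))
    student = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 16, 3, padding=1))

    mapping = Mapping(student, teacher)
    mapping.add(student[0], teacher[0], payload=nn.MSELoss())  # Distill through both conv layers.
    mapping.add(student[2], teacher[2], payload=nn.MSELoss())

    optimizer = torch.optim.RAdam(params=student.parameters(), lr=0.005)
    train_loader = [torch.randn(2, 3, 32, 32) for _ in range(10)]  # Dummy data loader.

    student.train()
    teacher.eval()

    with distill(student=student, teacher=teacher, mapping=mapping) as distill_model:
        for inputs in train_loader:
            optimizer.zero_grad()

            loss = torch.zeros(1)
            for student_output, teacher_output, criterion in distill_model(inputs):  # Payload is the criterion.
                loss += criterion(student_output, teacher_output)

            loss.backward()
            optimizer.step()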
.. autoclass:: enot.distillation.distill
    :members:

.. autoclass:: enot.distillation.Mapping
    :members: __init__, add, payload

*******
Adapter
*******

Suppose we have two models, a student and a teacher, where the student is a **pruned** version of the teacher, and we want to distill through corresponding conv2d modules.
Consider a specific conv2d module in the teacher and the corresponding conv2d module in the student.
Since the student is a pruned version of the teacher, the student module may have fewer output channels than the teacher module, so we cannot distill directly through these modules (their output tensors have different shapes).
The solution is a trainable adapter that narrows the teacher's conv2d output tensor to the student's conv2d output tensor:

.. mermaid::
    :align: center

    graph LR
        subgraph teacher
            direction TB
            module_teacher_0([...]) --> conv2d_teacher([conv2d])
            conv2d_teacher -- 512 --> module_teacher_1([...])
        end
        adapter{{adapter}}
        subgraph student
            direction TB
            module_student_0([...]) --> conv2d_student([conv2d])
            conv2d_student -- 256 --> module_student_1([...])
        end
        teacher -- 512 --> adapter -- 256 --> student

Such an adapter must be properly initialized, otherwise distillation may fail.
We provide the :func:`.conv2d_adapter` function to create a properly initialized adapter: it maps all unpruned (with regard to the student model) channels of the teacher's conv2d module to the student's conv2d module:

.. code-block:: python

    from enot.distillation.adapter import conv2d_adapter

    adapter = conv2d_adapter(source=conv2d_teacher, target=conv2d_student)
    mapping.add(conv2d_student, conv2d_teacher, payload=adapter)  # Put adapter into payload.
    # To use different criteria for different adapters, add a tuple to the payload: (adapter, criterion).
    # Don't forget to add the adapter parameters to the optimizer to make the adapter trainable.
    ...
    with distill(student=student, teacher=teacher, mapping=mapping) as distill_model:
        ...
        output = distill_model(inputs)
        loss: torch.Tensor = torch.zeros(1)
        for student_output, teacher_output, adapter in output:  # Payload is adapter.
            loss += criterion(student_output, adapter(teacher_output))
        ...

We also provide :func:`.inverted_conv2d_adapter` if you want to do something like this:

.. code-block:: python

    loss += criterion(adapter(student_output), teacher_output)

.. autofunction:: enot.distillation.adapter.conv2d_adapter

.. autofunction:: enot.distillation.adapter.inverted_conv2d_adapter
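The comments in the example above mention two practical details: passing an ``(adapter, criterion)`` tuple as the payload and registering the adapter parameters with the optimizer.
A sketch combining both is shown below; ``student``, ``teacher``, ``conv2d_student``, ``conv2d_teacher``, and ``train_loader`` are placeholders for the objects introduced earlier in this section:

.. code-block:: python

    import torch
    from torch import nn

    from enot.distillation import Mapping
    from enot.distillation import distill
    from enot.distillation.adapter import conv2d_adapter

    # The adapter narrows the teacher's output channels to match the pruned student.
    adapter = conv2d_adapter(source=conv2d_teacher, target=conv2d_student)

    mapping = Mapping(student, teacher)
    mapping.add(conv2d_student, conv2d_teacher, payload=(adapter, nn.MSELoss()))

    # Register the adapter parameters so the adapter is trained together with the student.
    optimizer = torch.optim.RAdam(
        params=[
            {'params': student.parameters()},
            {'params': adapter.parameters()},
        ],
        lr=0.005,
    )

    with distill(student=student, teacher=teacher, mapping=mapping) as distill_model:
        for inputs in train_loader:
            optimizer.zero_grad()

            loss = torch.zeros(1)
            for student_output, teacher_output, (adapter, criterion) in distill_model(inputs):
                loss += criterion(student_output, adapter(teacher_output))

            loss.backward()
            optimizer.step()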