Distillation
The enot.distillation module contains a toolset for the knowledge distillation procedure. The module provides the distill context manager and the Mapping class:
from enot.distillation import distill
from enot.distillation import Mapping
In knowledge distillation, a student model is generally supervised by a teacher model; therefore, both
enot.distillation.distill.__init__()
and enot.distillation.Mapping.__init__()
take student and teacher arguments.
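For the snippets below, assume that student and teacher are two models of compatible structure. A minimal sketch using torchvision's MobileNetV2 as a stand-in (the actual models, and in particular the pruned student, would come from your own pipeline):
import torch
from torch import nn
from torchvision.models import mobilenet_v2

teacher = mobilenet_v2()  # Pretrained weights would normally be loaded here.
student = mobilenet_v2()  # In practice, a pruned or otherwise smaller copy of the teacher.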
To create a distill context, you need to specify a Mapping between the modules of the teacher model and the modules of the student model through which you want to distill:
mapping = Mapping(student, teacher)
mapping.add(student.features[3], teacher.features[3]) # Specify by module reference.
mapping.add('features.5.conv', 'features.5.conv') # Specify by module name.
mapping.add(student.features[7], 'features.7') # Mixed.
Once the mapping is specified, the distill context can be created; it yields a new (distillation) model object. This is a normal model, except that it also returns paired tensors from the modules specified in the mapping. It remains to apply the desired criterion (loss) to them:
criterion = torch.nn.MSELoss()  # Any suitable distillation criterion.
optimizer = torch.optim.RAdam(params=student.parameters(), lr=0.005)

with distill(student=student, teacher=teacher, mapping=mapping) as distill_model:
    inputs = torch.ones(1, 3, 224, 224)  # Sample from dataloader.
    optimizer.zero_grad()
    output = distill_model(inputs)
    loss: torch.Tensor = torch.zeros(1)
    for student_output, teacher_output, _ in output:  # Paired tensors from mapped modules.
        loss += criterion(student_output, teacher_output)
    loss.backward()
    optimizer.step()
In most cases, in addition to the mapping between modules, it is useful to attach something else, for example a loss function, a weight, or an adapter between modules. You can pass any of these via the payload argument of Mapping.add(), and it will be returned together with the paired output of the distillation model:
mapping.add(student.features[3], teacher.features[3], payload=nn.MSELoss())
mapping.add(student.features[5], teacher.features[5], payload=nn.CrossEntropyLoss())
mapping.add(student.features[7], teacher.features[7], payload=nn.MSELoss())
...
for student_output, teacher_output, criterion in output:  # 3rd value is the payload from the mapping.
    loss += criterion(student_output, teacher_output)
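Since payload accepts any object, you can also pack several things together, for example a (criterion, weight) pair; the tuple layout below is just a convention of this sketch, not something the library requires:
mapping.add(student.features[3], teacher.features[3], payload=(nn.MSELoss(), 0.5))
mapping.add(student.features[7], teacher.features[7], payload=(nn.MSELoss(), 1.0))
...
for student_output, teacher_output, (criterion, weight) in output:  # Unpack the payload tuple.
    loss += weight * criterion(student_output, teacher_output)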
- class distill(student, teacher, *, mapping=None, nograd_teacher=True)
Context manager that glues the teacher and student models into one distillation model.
Examples
>>> from enot.distillation import Mapping
>>> from enot.distillation import distill
Create a mapping between the teacher and student models; we also add a criterion for each mapped pair as payload.
>>> mapping = Mapping(student, teacher)  # GraphModule - GraphModule mapping.
>>> mapping.add('features_18_0', 'features_18_0', payload=nn.MSELoss())
>>> mapping.add('features_2_conv_1_0', 'features_2_conv_1_0', payload=nn.CrossEntropyLoss())
>>> mapping.add('features_5_conv_0_2', 'features_5_conv_0_2', payload=nn.MSELoss())
Prepare the optimizer, scheduler, dataloaders, etc. as usual.
>>> optimizer = RAdam(params=student.parameters())
Use distill context to distill teacher knowledge to student.
>>> with distill(student=student, teacher=teacher, mapping=mapping) as distill_model:
>>>     inputs = torch.ones(1, 3, 224, 224)
>>>     optimizer.zero_grad()
>>>     output = distill_model(inputs)
>>>     loss: torch.Tensor = torch.zeros(1)
>>>     for student_output, teacher_output, criterion in output:
>>>         loss += criterion(student_output, teacher_output)
>>>     loss.backward()
>>>     optimizer.step()
- __init__(student, teacher, *, mapping=None, nograd_teacher=True)
- Parameters:
student (Union[nn.Module, GraphModule]) – Student; knowledge will be transferred to this model from the teacher model.
teacher (Union[nn.Module, GraphModule]) – Teacher; knowledge will be transferred from this model to the student model.
mapping (Optional[Mapping]) – Mapping that specifies which modules to distill from and to.
nograd_teacher (bool) – Whether to run the teacher under no_grad. Default value is True.
- class Mapping(student, teacher)
Mapping between modules of student and teacher models.
- __init__(student, teacher)
- Parameters:
student (Module) – Student module.
teacher (Module) – Teacher module.
- add(student_module, teacher_module, *, payload=None)
Add a pair to the mapping.
- Parameters:
student_module (Mappable) – Student module which will be associated with teacher module.
teacher_module (Mappable) – Teacher module which will be associated with student module.
payload (Any) – Payload, default value is None.
- Return type:
Adapter
Suppose we have two models, a student and a teacher, where the student is a pruned version of the teacher, and we want to distill through corresponding conv2d modules.
Consider a specific conv2d module in the teacher and the corresponding conv2d module in the student. Since the student is a pruned version of the teacher, the student module may have fewer output channels than the teacher module, so we cannot distill directly through these modules (tensors). The solution is a trainable adapter that narrows the teacher's conv2d output tensor to the student's conv2d output tensor.
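Conceptually, such an adapter can be a 1x1 convolution whose weights simply select the channels that survived pruning. A hedged sketch of the idea (channel_selection_adapter is a hypothetical helper shown only for illustration; the adapter actually returned by conv2d_adapter() may be implemented differently):
import torch
from torch import nn


def channel_selection_adapter(teacher_out_channels, kept_channel_indices):
    # Hypothetical helper, for illustration only; use conv2d_adapter() in practice.
    # Builds a 1x1 convolution that copies only the kept (unpruned) channels
    # from the teacher's wider output tensor, so its shape matches the student output.
    adapter = nn.Conv2d(teacher_out_channels, len(kept_channel_indices), kernel_size=1, bias=False)
    with torch.no_grad():
        adapter.weight.zero_()
        for out_idx, in_idx in enumerate(kept_channel_indices):
            adapter.weight[out_idx, in_idx, 0, 0] = 1.0
    return adapter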
Such an adapter must be properly initialized, otherwise distillation may fail. We provide the conv2d_adapter() function to create a properly initialized adapter: it maps all unpruned (with respect to the student model) channels of the teacher's conv2d module to the student's conv2d module:
from enot.distillation.adapter import conv2d_adapter
adapter = conv2d_adapter(source=conv2d_teacher, target=conv2d_student)
mapping.add(conv2d_student, conv2d_teacher, payload=adapter) # Put adapter to payload.
# To use different criteria for different adapters, add a tuple to the payload: (adapter, criterion).
# Don't forget to add adapter parameters to the optimizer to make it trainable.
...
with distill(student=student, teacher=teacher, mapping=mapping) as distill_model:
    ...
    output = distill_model(inputs)
    loss: torch.Tensor = torch.zeros(1)
    for student_output, teacher_output, adapter in output:  # Payload is the adapter.
        loss += criterion(student_output, adapter(teacher_output))
    ...
We also provide inverted_conv2d_adapter()
if you want to do something like this:
loss += criterion(adapter(student_output), teacher_output)
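A sketch of how the inverted adapter could be wired in, assuming the same conv2d_student and conv2d_teacher modules as above, and assuming inverted_conv2d_adapter() is importable from the same module as conv2d_adapter():
from enot.distillation.adapter import inverted_conv2d_adapter

adapter = inverted_conv2d_adapter(source=conv2d_teacher, target=conv2d_student)
mapping.add(conv2d_student, conv2d_teacher, payload=adapter)  # Put the adapter into the payload.
...
for student_output, teacher_output, adapter in output:  # Payload is the adapter.
    loss += criterion(adapter(student_output), teacher_output)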
- conv2d_adapter(source, target)
Conv2d 1x1 adapter for two convolutions: source and target. The target convolution should be a pruned version of the source convolution.
- Parameters:
source (torch.nn.Conv2d) – Convolution from which target is obtained.
target (torch.nn.Conv2d) – Pruned version of source convolution.
- Returns:
Adapter that maps all target (not pruned) channels of the source convolution into the target convolution.
- Return type:
- inverted_conv2d_adapter(source, target)
Inverted Conv2d 1x1 adapter for two convolutions: source and target. The target convolution should be a pruned version of the source convolution, but instead of narrowing the source to the target, this adapter extends the target convolution to the source.
- Parameters:
source (torch.nn.Conv2d) – Convolution from which target is obtained.
target (torch.nn.Conv2d) – Pruned version of source convolution.
- Returns:
Adapter that maps all target (not pruned) channels of the target convolution into the source convolution.
- Return type: