Distillation
The enot.distillation module contains tools for the knowledge distillation procedure.
It provides the distill context manager and the Mapping class:
from enot.distillation import distill
from enot.distillation import Mapping
In knowledge distillation, a student model is generally supervised by a teacher model; therefore, both distill.__init__() and Mapping.__init__() take teacher and student arguments.
To create a distill context, first specify a Mapping between the teacher modules and the student modules through which you want to distill. Each module can be specified by reference, by name, or by a mix of both:
mapping = Mapping(teacher, student)
mapping.add(teacher.features[3], student.features[3]) # Specify by module reference.
mapping.add('features.5.conv', 'features.5.conv') # Specify by module name.
mapping.add(teacher.features[7], 'features.7') # Mixed.
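Conceptually, each Mapping.add() call has to turn its arguments, whether module references or dotted names such as 'features.5.conv', into concrete submodules. The sketch below models that resolution in plain Python so it runs without PyTorch. ToyModule, ToyMapping and resolve are hypothetical names, and the dotted-name lookup mirrors torch.nn.Module.get_submodule; this is an illustration under those assumptions, not ENOT's actual implementation.

```python
from typing import Any, Dict, List, Optional, Tuple, Union


class ToyModule:
    """Hypothetical stand-in for nn.Module: a tree of named child modules."""

    def __init__(self, children: Optional[Dict[str, "ToyModule"]] = None) -> None:
        self.children: Dict[str, "ToyModule"] = dict(children or {})

    def get_submodule(self, target: str) -> "ToyModule":
        # Walk a dotted path such as 'features.5.conv' one component at a time.
        node = self
        for part in target.split("."):
            node = node.children[part]
        return node


Mappable = Union[str, ToyModule]


def resolve(root: ToyModule, key: Mappable) -> ToyModule:
    # A string is treated as a dotted name inside root; anything else is
    # assumed to already be a module reference.
    return root.get_submodule(key) if isinstance(key, str) else key


class ToyMapping:
    """Hypothetical model of Mapping: an ordered list of (teacher, student, payload)."""

    def __init__(self, teacher: ToyModule, student: ToyModule) -> None:
        self.teacher = teacher
        self.student = student
        self.pairs: List[Tuple[ToyModule, ToyModule, Any]] = []

    def add(self, teacher_module: Mappable, student_module: Mappable,
            payload: Any = None) -> None:
        self.pairs.append((resolve(self.teacher, teacher_module),
                           resolve(self.student, student_module),
                           payload))


def make_net() -> ToyModule:
    # Provides features.3, features.5.conv and features.7, as in the example above.
    return ToyModule({"features": ToyModule({
        "3": ToyModule(),
        "5": ToyModule({"conv": ToyModule()}),
        "7": ToyModule(),
    })})


teacher, student = make_net(), make_net()
mapping = ToyMapping(teacher, student)
mapping.add(teacher.get_submodule("features.3"), student.get_submodule("features.3"))  # By reference.
mapping.add("features.5.conv", "features.5.conv")  # By name.
mapping.add(teacher.get_submodule("features.7"), "features.7")  # Mixed.
```

Whichever form is used, the stored pair ends up holding the same module objects, which is why references and names can be mixed freely.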
Once the mapping is specified, you can create a distill context, which yields a new distillation model object. It behaves like a normal model, except that it returns paired tensors from the modules specified in the mapping: for each mapped pair, the teacher output, the student output, and the pair's payload (described below). All that remains is to apply the desired criterion (loss) to them:
with distill(teacher=teacher, student=student, mapping=mapping) as distill_model:
    inputs = torch.ones(1, 3, 224, 224)  # Sample from dataloader.
    optimizer.zero_grad()
    output = distill_model(inputs)
    loss: torch.Tensor = torch.zeros(1)
    for teacher_output, student_output, _ in output:  # Paired tensors from mapped modules.
        loss += criterion(teacher_output, student_output)
    loss.backward()
    optimizer.step()
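One plausible way a distillation model can produce these pairs is to attach forward hooks to the mapped modules, feed the same input to both models, and collect the recorded activations. The plain-Python sketch below models that idea with hypothetical Layer, Sequential and toy_distill names and floats in place of tensors; in PyTorch the analogous mechanism would be nn.Module.register_forward_hook. It is an assumption about the general technique, not a description of how ENOT implements distill, and it does not model nograd_teacher.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator, List, Optional, Tuple


class Layer:
    """Hypothetical layer: applies a function, then calls its forward hooks."""

    def __init__(self, fn: Callable[[float], float]) -> None:
        self.fn = fn
        self.hooks: List[Callable[[float], None]] = []

    def __call__(self, x: float) -> float:
        y = self.fn(x)
        for hook in self.hooks:
            hook(y)  # Observers record the intermediate output.
        return y


class Sequential:
    """Hypothetical model: applies its layers in order."""

    def __init__(self, *layers: Layer) -> None:
        self.layers = list(layers)

    def __call__(self, x: float) -> float:
        for layer in self.layers:
            x = layer(x)
        return x


Pair = Tuple[Layer, Layer, Any]  # (teacher layer, student layer, payload)
Triple = Tuple[Optional[float], Optional[float], Any]


def _recorder(store: List[Optional[float]], index: int) -> Callable[[float], None]:
    # Build a hook that writes a layer's output into store[index].
    def hook(y: float) -> None:
        store[index] = y
    return hook


@contextmanager
def toy_distill(teacher: Sequential, student: Sequential,
                pairs: List[Pair]) -> Iterator[Callable[[float], List[Triple]]]:
    teacher_out: List[Optional[float]] = [None] * len(pairs)
    student_out: List[Optional[float]] = [None] * len(pairs)
    installed: List[Tuple[Layer, Callable[[float], None]]] = []
    for i, (t_layer, s_layer, _) in enumerate(pairs):
        for layer, store in ((t_layer, teacher_out), (s_layer, student_out)):
            hook = _recorder(store, i)
            layer.hooks.append(hook)
            installed.append((layer, hook))

    def distill_model(x: float) -> List[Triple]:
        teacher(x)  # Final outputs are ignored; only mapped activations are returned.
        student(x)
        return [(teacher_out[i], student_out[i], payload)
                for i, (_, _, payload) in enumerate(pairs)]

    try:
        yield distill_model
    finally:
        # Detach the hooks so the original models are left unchanged.
        for layer, hook in installed:
            layer.hooks.remove(hook)


teacher = Sequential(Layer(lambda x: 2 * x), Layer(lambda x: x + 1))
student = Sequential(Layer(lambda x: 3 * x), Layer(lambda x: x - 1))
pairs = [(teacher.layers[0], student.layers[0], "p0"),
         (teacher.layers[1], student.layers[1], "p1")]
with toy_distill(teacher, student, pairs) as distill_model:
    output = distill_model(1.0)  # [(2.0, 3.0, 'p0'), (3.0, 2.0, 'p1')]
```

Removing the hooks in a finally clause mirrors the context-manager design: once the with block exits, teacher and student behave as ordinary models again.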
In most cases, it is useful to attach something else to each mapped pair, for example a loss function, a weight, or an adapter between modules.
You can pass any such object through the payload argument of Mapping.add(); it is returned alongside the paired outputs of the distillation model:
mapping.add(teacher.features[3], student.features[3], payload=nn.MSELoss())
mapping.add(teacher.features[5], student.features[5], payload=nn.CrossEntropyLoss())
mapping.add(teacher.features[7], student.features[7], payload=nn.MSELoss())
...
for teacher_output, student_output, criterion in output:  # 3rd value is payload from mapping.
    loss += criterion(teacher_output, student_output)
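Because the payload can be any object, it can also carry a per-pair weight next to the criterion. The plain-Python sketch below assumes a hypothetical payload layout of (criterion, weight) and uses lists in place of tensors, with mse and l1 standing in for nn.MSELoss() and nn.L1Loss() (mean reduction). The name distillation_loss is illustrative only; the documentation simply types payload as Any.

```python
from typing import Callable, List, Sequence, Tuple

Vector = Sequence[float]
Criterion = Callable[[Vector, Vector], float]


def mse(a: Vector, b: Vector) -> float:
    # Mean squared error, analogous to nn.MSELoss() with the default 'mean' reduction.
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def l1(a: Vector, b: Vector) -> float:
    # Mean absolute error, analogous to nn.L1Loss() with the default 'mean' reduction.
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def distillation_loss(output: List[Tuple[Vector, Vector, Tuple[Criterion, float]]]) -> float:
    # Hypothetical payload layout: (criterion, weight). Sum the weighted per-pair losses.
    loss = 0.0
    for teacher_output, student_output, (criterion, weight) in output:
        loss += weight * criterion(teacher_output, student_output)
    return loss


output = [
    ([1.0, 2.0], [1.0, 4.0], (mse, 1.0)),   # mse = (0 + 4) / 2 = 2.0
    ([0.0, 0.0], [1.0, -3.0], (l1, 0.5)),   # l1 = (1 + 3) / 2 = 2.0, weighted to 1.0
]
total = distillation_loss(output)  # 3.0
```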
- class distill(teacher, student, mapping=None, nograd_teacher=True)
Context manager that glues the teacher and student models into one distillation model.
Examples
>>> from enot.distillation import Mapping
>>> from enot.distillation import distill
Create a mapping between the teacher and student models, and additionally attach a criterion to each mapped pair as its payload.
>>> mapping = Mapping(teacher, student)  # GraphModule - GraphModule mapping.
>>> mapping.add('features_18_0', 'features_18_0', payload=nn.MSELoss())
>>> mapping.add('features_2_conv_1_0', 'features_2_conv_1_0', payload=nn.CrossEntropyLoss())
>>> mapping.add('features_5_conv_0_2', 'features_5_conv_0_2', payload=nn.MSELoss())
Prepare the optimizer, scheduler, dataloaders, etc., as usual.
>>> optimizer = RAdam(params=student.parameters())
Use the distill context to transfer the teacher's knowledge to the student.
>>> with distill(teacher=teacher, student=student, mapping=mapping) as distill_model:
...     inputs = torch.ones(1, 3, 224, 224)
...     optimizer.zero_grad()
...     output = distill_model(inputs)
...     loss: torch.Tensor = torch.zeros(1)
...     for teacher_output, student_output, criterion in output:
...         loss += criterion(teacher_output, student_output)
...     loss.backward()
...     optimizer.step()
- __init__(teacher, student, mapping=None, nograd_teacher=True)
- Parameters:
teacher (Union[nn.Module, GraphModule]) – Teacher; knowledge will be transferred from this model to student model.
student (Union[nn.Module, GraphModule]) – Student; knowledge will be transferred to this model from teacher model.
mapping (Optional[Mapping]) – Mapping that specifies the modules from which and into which knowledge is distilled.
nograd_teacher (bool) – Whether to run the teacher under no_grad, so that no gradients are computed for it. Default value is True.
- class Mapping(teacher, student)
Mapping between modules of teacher and student models.
- __init__(teacher, student)
- Parameters:
teacher (Module) – Teacher module.
student (Module) – Student module.
- add(teacher_module, student_module, payload=None)
Add a pair to the mapping.
- Parameters:
teacher_module (Mappable) – Teacher module which will be associated with student module.
student_module (Mappable) – Student module which will be associated with teacher module.
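payload (Any) – Arbitrary object (for example, a loss function, a weight, or an adapter) returned together with the paired outputs of this pair. Default value is None.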
- Return type: