ENOT Prunable Modules
ENOT Prunable Modules is a package designed for replacing unprunable modules with prunable ones. It is published under the Apache 2.0 license, so anyone can view and modify its code.
Installation
The package requires Python >= 3.8 and can be installed from PyPI:
pip install enot-prunable-modules
Overview
To replace unprunable modules with their prunable versions, use the replace_implementations function:
from enot_prunable_modules import ReplaceFactory
from enot_prunable_modules import replace_implementations
model = ...  # your torch.nn.Module
replace_implementations(model, factory=ReplaceFactory.KANDINSKY)
After that, some modules are replaced according to the chosen ReplaceFactory strategy. In the KANDINSKY case:
Attention ---> PrunableAttention
ResnetBlock2D ---> PrunableResnetBlock2D
Upsample2D ---> PrunableUpsample2D
Downsample2D ---> PrunableDownsample2D
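The replacement step can be pictured as a recursive walk over the model that swaps every submodule whose type appears in the strategy's mapping. The sketch below is a toy, pure-Python illustration of that idea under stated assumptions: `Module`, `replace_modules` and `KANDINSKY_LIKE` are hypothetical names, and the library's actual implementation may differ.

```python
class Module:
    """Toy stand-in for a model component with named children (hypothetical)."""

    def __init__(self, **children):
        self.children = dict(children)


# Hypothetical stand-ins for the original and prunable classes.
class Attention(Module): ...
class PrunableAttention(Module): ...
class ResnetBlock2D(Module): ...
class PrunableResnetBlock2D(Module): ...


# Hypothetical strategy: original class -> prunable replacement class.
KANDINSKY_LIKE = {
    Attention: PrunableAttention,
    ResnetBlock2D: PrunableResnetBlock2D,
}


def replace_modules(module, mapping):
    """Recursively swap children whose exact type is a key of `mapping`."""
    for name, child in list(module.children.items()):
        replace_modules(child, mapping)  # process nested children first
        new_cls = mapping.get(type(child))
        if new_cls is not None:
            # Build the replacement around the already-processed subtree.
            module.children[name] = new_cls(**child.children)
    return module


model = Module(block=ResnetBlock2D(attn=Attention()), head=Module())
replace_modules(model, KANDINSKY_LIKE)
print(type(model.children["block"]).__name__)                   # PrunableResnetBlock2D
print(type(model.children["block"].children["attn"]).__name__)  # PrunableAttention
```

With real PyTorch models the same walk would typically use `named_children()` and `setattr()` on `torch.nn.Module` instances.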
ReplaceFactory has the following strategies for pruning:
KANDINSKY - for diffusers.models.UNet2DConditionModel from Kandinsky2-2.
GROUP_OPS - for torch.nn.Conv2d with groups > 1 and torch.nn.GroupNorm.
TORCH_ATTENTION - for torch.nn.MultiheadAttention.
TIMM_ATTENTION - for timm.models.vision_transformer.Attention.
…
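The text does not say why grouped convolutions and GroupNorm need their own strategy; a likely reason is the divisibility rule these layers enforce (`torch.nn.Conv2d` requires `in_channels` and `out_channels` to be divisible by `groups`, and `torch.nn.GroupNorm` requires `num_channels` to be divisible by `num_groups`). The hypothetical helpers below only illustrate that constraint; they are not how GROUP_OPS works internally.

```python
def is_valid_group_config(channels: int, groups: int) -> bool:
    """True if `channels` splits evenly into `groups`, the rule enforced by
    torch.nn.Conv2d (groups > 1) and torch.nn.GroupNorm."""
    return channels > 0 and channels % groups == 0


def group_aligned_keep(channels: int, groups: int, ratio: float) -> int:
    """Hypothetical helper: number of channels to keep after pruning `ratio`
    of them, rounded down to a multiple of `groups` (at least one per group)."""
    target = int(channels * (1 - ratio))
    return max(groups, target - target % groups)


# Naively keeping half of 96 channels gives 48, which GroupNorm(32, 48) rejects.
print(is_valid_group_config(48, 32))    # False
print(group_aligned_keep(96, 32, 0.5))  # 32
print(is_valid_group_config(32, 32))    # True
```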
To revert the replacement, use the revert_implementations function:
from enot_prunable_modules import ReplaceFactory
from enot_prunable_modules import revert_implementations
revert_implementations(model, factory=ReplaceFactory.KANDINSKY)
This function restores the original modules where possible (-X-> marks modules that stay replaced):
PrunableAttention ---> Attention
PrunableResnetBlock2D ---> ResnetBlock2D
PrunableUpsample2D -X-> Upsample2D
PrunableDownsample2D -X-> Downsample2D
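The `-X->` arrows mark one-way replacements. A minimal sketch of that bookkeeping, assuming a hypothetical table that records whether each replacement is reversible (the names `REPLACEMENTS`, `build_revert_table` and `revert_names` are illustrative, not part of the package):

```python
# Hypothetical table: original name -> (prunable name, can it be reverted?)
REPLACEMENTS = {
    "Attention": ("PrunableAttention", True),
    "ResnetBlock2D": ("PrunableResnetBlock2D", True),
    "Upsample2D": ("PrunableUpsample2D", False),
    "Downsample2D": ("PrunableDownsample2D", False),
}


def build_revert_table(replacements):
    """Invert the table, keeping only the entries that can be safely reverted."""
    return {new: old for old, (new, reversible) in replacements.items() if reversible}


def revert_names(module_names, revert_table):
    """Map each name back to its original where possible; others stay as they are."""
    return [revert_table.get(name, name) for name in module_names]


revert_table = build_revert_table(REPLACEMENTS)
print(revert_names(["PrunableAttention", "PrunableUpsample2D"], revert_table))
# ['Attention', 'PrunableUpsample2D']
```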
Note
In the case of Kandinsky2-2, returning the original Upsample2D and Downsample2D modules would trigger an assertion inside them because of the pruned channels. Their prunable versions drop this assertion, so Upsample2D and Downsample2D are left replaced.
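The failure mode can be reproduced with a toy stand-in. It assumes, for illustration only, that the original module stores the channel count it was built with and asserts that its input matches; once upstream channels are pruned, that stored value is stale. `ToyUpsample`, `ToyPrunableUpsample` and `passes` are hypothetical, not diffusers code.

```python
class ToyUpsample:
    """Hypothetical stand-in: remembers its configured channel count and checks it."""

    def __init__(self, channels: int):
        self.channels = channels

    def forward(self, num_input_channels: int) -> int:
        assert num_input_channels == self.channels, "channel mismatch"
        return num_input_channels  # a real module would upsample a tensor here


class ToyPrunableUpsample(ToyUpsample):
    """Same toy module without the fixed-channel assertion."""

    def forward(self, num_input_channels: int) -> int:
        return num_input_channels


def passes(module, num_input_channels: int) -> bool:
    """Return True if the module accepts an input with that many channels."""
    try:
        module.forward(num_input_channels)
        return True
    except AssertionError:
        return False


# Half of the 320 channels were pruned upstream, so the input now has 160.
print(passes(ToyUpsample(320), 160))          # False
print(passes(ToyPrunableUpsample(320), 160))  # True
```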
Example: Kandinsky2-2 pruning
import torch
from diffusers.models import UNet2DConditionModel
from enot.pruning import prune_model
from enot.pruning import PruningCalibrator
from enot.pruning import UniformPruningLabelSelector
from enot_prunable_modules import replace_implementations
from enot_prunable_modules import revert_implementations
from enot_prunable_modules.replace_factory import ReplaceFactory
unet = UNet2DConditionModel.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder",
    subfolder="unet",
)
unet.eval()
# Replace unprunable modules with their prunable versions
replace_implementations(
    unet,
    factory=[
        ReplaceFactory.KANDINSKY,
        ReplaceFactory.GROUP_OPS,
    ],
)
criterion = torch.nn.MSELoss()
pruning_calibrator = PruningCalibrator(
    model=unet,
    max_prunable_labels=25600,
)
# Calibrating: one forward/backward pass on dummy inputs
with pruning_calibrator:
    result = unet(
        sample=torch.ones(1, 4, 64, 64),
        timestep=torch.ones(1),
        encoder_hidden_states=None,
        added_cond_kwargs={
            "image_embeds": torch.ones(1, 1280)
        },
    ).sample[:, :4]  # keep the first 4 channels so the output matches the target shape
    loss = criterion(result, torch.ones(1, 4, 64, 64))
    loss.backward()
pruning_info = pruning_calibrator.pruning_info
# Pruning
pruning_ratio = 0.5
label_selector = UniformPruningLabelSelector(pruning_ratio=pruning_ratio)
prune_labels = label_selector.select(model=unet, pruning_info=pruning_info)
pruned_unet = prune_model(
    model=unet,
    pruning_info=pruning_info,
    prune_labels=prune_labels,
    inplace=False,
)
# Return original modules where possible
revert_implementations(
    pruned_unet,
    factory=[
        ReplaceFactory.KANDINSKY,
        ReplaceFactory.GROUP_OPS,
    ],
)
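To check what pruning achieved, you can compare parameter counts. Because `prune_model` is called with `inplace=False`, `unet` is presumably left unpruned, so both models remain available. The helpers below are a small sketch, not part of ENOT; they rely only on the standard `torch.nn.Module.parameters()` and `Tensor.numel()` methods.

```python
def count_parameters(model) -> int:
    """Total number of parameter elements in a torch.nn.Module-like object."""
    return sum(p.numel() for p in model.parameters())


def compression_report(original, pruned) -> str:
    """Summarize parameter counts before and after pruning."""
    before = count_parameters(original)
    after = count_parameters(pruned)
    return f"{before:,} -> {after:,} parameters ({after / before:.1%} kept)"
```

For the example above, `print(compression_report(unet, pruned_unet))` prints the before/after counts.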