ENOT Prunable Modules

ENOT Prunable Modules is a package designed for replacing unprunable modules with prunable ones. It is published under the Apache 2.0 license, so anyone can view and modify its code.

Installation

First of all, you need Python>=3.8. The ENOT Prunable Modules package can be installed from PyPI:

pip install enot-prunable-modules
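
To check that the installation succeeded, you can try importing the package:

python -c "import enot_prunable_modules"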

Overview

To replace unprunable modules with their prunable versions, use the replace_implementations function:

from enot_prunable_modules import ReplaceFactory
from enot_prunable_modules import replace_implementations

model = ...
replace_implementations(model, factory=ReplaceFactory.KANDINSKY)

After that, some modules will be replaced according to the chosen ReplaceFactory. In the KANDINSKY case:

Attention ---> PrunableAttention
ResnetBlock2D ---> PrunableResnetBlock2D
Upsample2D ---> PrunableUpsample2D
Downsample2D ---> PrunableDownsample2D
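
Since the replacement happens in place, one way to confirm it is to scan the model for the new module classes. Here is a minimal sketch (assuming model is a torch.nn.Module and relying only on the Prunable* naming shown above):

# Print every module whose class was swapped for a prunable version.
for name, module in model.named_modules():
    if type(module).__name__.startswith("Prunable"):
        print(name, "->", type(module).__name__)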

ReplaceFactory has the following strategies for pruning (a usage sketch follows the list):

  • KANDINSKY - for diffusers.models.UNet2DConditionModel from Kandinsky2-2.

  • GROUP_OPS - for torch.nn.Conv2d with groups > 1 and torch.nn.GroupNorm.

  • TORCH_ATTENTION - for torch.nn.MultiheadAttention.

  • TIMM_ATTENTION - for timm.models.vision_transformer.Attention.
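
For example, here is a minimal sketch of the TORCH_ATTENTION strategy applied to a toy model (the model itself is hypothetical, built just for illustration):

import torch

from enot_prunable_modules import ReplaceFactory
from enot_prunable_modules import replace_implementations

# Toy model containing a torch.nn.MultiheadAttention submodule.
model = torch.nn.ModuleDict(
    {"attn": torch.nn.MultiheadAttention(embed_dim=64, num_heads=4)}
)

# Replace it with the prunable implementation in place.
replace_implementations(model, factory=ReplaceFactory.TORCH_ATTENTION)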

To revert the module replacement, use the revert_implementations function:

from enot_prunable_modules import ReplaceFactory
from enot_prunable_modules import revert_implementations

revert_implementations(model, factory=ReplaceFactory.KANDINSKY)

This function restores the original modules where possible:

PrunableAttention ---> Attention
PrunableResnetBlock2D ---> ResnetBlock2D
PrunableUpsample2D -X-> Upsample2D
PrunableDownsample2D -X-> Downsample2D

Note

In the case of Kandinsky2-2, reverting to the original Upsample2D and Downsample2D would trigger an assert because of the pruned channels. Therefore, this assert is removed and these modules stay replaced, as the -X-> arrows above indicate.

Example: Kandinsky2-2 pruning

import torch
from diffusers.models import UNet2DConditionModel
from enot.pruning import prune_model
from enot.pruning import PruningCalibrator
from enot.pruning import UniformPruningLabelSelector
from enot_prunable_modules import replace_implementations
from enot_prunable_modules import revert_implementations
from enot_prunable_modules.replace_factory import ReplaceFactory


unet = UNet2DConditionModel.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder",
    subfolder="unet",
)
unet.eval()

# Replace unprunable modules with their prunable versions
replace_implementations(
    unet,
    factory=[
        ReplaceFactory.KANDINSKY,
        ReplaceFactory.GROUP_OPS,
    ],
)

criterion = torch.nn.MSELoss()
pruning_calibrator = PruningCalibrator(
    model=unet,
    max_prunable_labels=25600,
)

# Calibration: one forward/backward pass to gather pruning statistics
with pruning_calibrator:
    result = unet(
        sample=torch.ones(1, 4, 64, 64),
        timestep=torch.ones(1),
        encoder_hidden_states=None,
        added_cond_kwargs={
            "image_embeds": torch.ones(1, 1280)
        },
    ).sample[:, :4]

    loss = criterion(result, torch.ones(1, 4, 64, 64))
    loss.backward()

pruning_info = pruning_calibrator.pruning_info

# Pruning: select 50% of the labels uniformly and prune the model
pruning_ratio = 0.5
label_selector = UniformPruningLabelSelector(pruning_ratio=pruning_ratio)
prune_labels = label_selector.select(model=unet, pruning_info=pruning_info)

pruned_unet = prune_model(
    model=unet,
    pruning_info=pruning_info,
    prune_labels=prune_labels,
    inplace=False,
)

# Return the original modules where possible
revert_implementations(
    pruned_unet,
    factory=[
        ReplaceFactory.KANDINSKY,
        ReplaceFactory.GROUP_OPS,
    ],
)
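
As a final sanity check, you can rerun the same forward pass on the pruned model; this sketch simply reuses the calibration inputs from above:

# The pruned UNet should still accept the original input shapes.
with torch.no_grad():
    sample = pruned_unet(
        sample=torch.ones(1, 4, 64, 64),
        timestep=torch.ones(1),
        encoder_hidden_states=None,
        added_cond_kwargs={"image_embeds": torch.ones(1, 1280)},
    ).sample

print(sample.shape)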