.. _ENOT Prunable Modules:

#####################
ENOT Prunable Modules
#####################

**ENOT Prunable Modules** is a package that replaces unprunable modules with prunable ones.
It is published under the Apache 2.0 license, so anyone can view and modify its code.


Installation
============

ENOT Prunable Modules requires Python 3.8 or newer.
It can be installed from PyPI:
::

    pip install enot-prunable-modules


Overview
========

To replace unprunable modules with their prunable versions, use the ``replace_implementations`` function:

.. code-block:: python

    from enot_prunable_modules import ReplaceFactory
    from enot_prunable_modules import replace_implementations

    model = ...
    replace_implementations(model, factory=ReplaceFactory.KANDINSKY)

After that, modules are replaced according to the chosen ``ReplaceFactory`` strategy. For ``KANDINSKY``, the replacements are:

.. code-block:: text

    Attention ---> PrunableAttention
    ResnetBlock2D ---> PrunableResnetBlock2D
    Upsample2D ---> PrunableUpsample2D
    Downsample2D ---> PrunableDownsample2D
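
Conceptually, a replacement of this kind walks the module tree and swaps every child whose type appears in a mapping. The sketch below illustrates that idea with plain-Python stand-ins for ``torch.nn.Module``; the ``Module`` class, the ``replace_by_type`` helper and the toy module classes are hypothetical and are not the package's actual implementation. A real replacement would also have to transfer weights and configuration, which this sketch omits.

.. code-block:: python

    # Hypothetical sketch of type-based module replacement. The classes below are
    # plain-Python stand-ins for torch.nn.Module and for the diffusers/prunable
    # modules; they are NOT the enot_prunable_modules implementation.


    class Module:
        """Tiny stand-in for torch.nn.Module: children are stored by name."""

        def __init__(self, **children):
            self._children = dict(children)

        def named_children(self):
            # Return a list copy so callers can replace children while iterating.
            return list(self._children.items())

        def set_child(self, name, module):
            # With real PyTorch modules this would be setattr(parent, name, module).
            self._children[name] = module


    class Attention(Module):
        pass


    class PrunableAttention(Module):
        pass


    class ResnetBlock2D(Module):
        pass


    class PrunableResnetBlock2D(Module):
        pass


    def replace_by_type(module, mapping):
        """Recursively swap children whose exact type is a key of ``mapping``.

        The replacement keeps the original child's (already processed) children.
        Returns the number of modules that were replaced.
        """
        replaced = 0
        for name, child in module.named_children():
            # Handle nested modules first, then the child itself.
            replaced += replace_by_type(child, mapping)
            new_cls = mapping.get(type(child))
            if new_cls is not None:
                module.set_child(name, new_cls(**dict(child.named_children())))
                replaced += 1
        return replaced


    # A mapping in the spirit of the KANDINSKY strategy above (hypothetical).
    mapping = {Attention: PrunableAttention, ResnetBlock2D: PrunableResnetBlock2D}
    model = Module(block=ResnetBlock2D(attn=Attention()), head=Module())
    count = replace_by_type(model, mapping)
    print(count)  # 2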

``ReplaceFactory`` has the following strategies for pruning:

- ``KANDINSKY`` - for ``diffusers.models.UNet2DConditionModel`` from Kandinsky2-2.
- ``GROUP_OPS`` - for ``torch.nn.Conv2d`` with ``groups > 1`` and ``torch.nn.GroupNorm``.
- ``TORCH_ATTENTION`` - for ``torch.nn.MultiheadAttention``.
- ``TIMM_ATTENTION`` - for ``timm.models.vision_transformer.Attention``.
- ...
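
Grouped layers are a natural candidate for a separate strategy because their channel counts are tied to the number of groups: ``torch.nn.Conv2d`` requires ``in_channels`` and ``out_channels`` to be divisible by ``groups``, and ``torch.nn.GroupNorm`` requires ``num_channels`` to be divisible by ``num_groups``. The plain-Python sketch below (hypothetical helper names) shows that, if the number of groups stays the same, removing an arbitrary subset of channels breaks this constraint, while removing the same number of channels from every group keeps it. It only illustrates the constraint and does not describe how ``GROUP_OPS`` works internally.

.. code-block:: python

    # Hypothetical helpers illustrating the channel constraint of grouped layers;
    # they are not part of enot_prunable_modules.


    def channels_per_group(kept_channels, num_channels, groups):
        """Count how many kept channels fall into each contiguous group."""
        if num_channels % groups != 0:
            raise ValueError("num_channels must be divisible by groups")
        group_size = num_channels // groups
        counts = [0] * groups
        for channel in kept_channels:
            counts[channel // group_size] += 1
        return counts


    def is_valid_with_fixed_groups(kept_channels, num_channels, groups):
        """With ``groups`` unchanged, every group must keep the same, non-zero count."""
        counts = channels_per_group(kept_channels, num_channels, groups)
        return len(set(counts)) == 1 and counts[0] > 0


    # 8 channels in 4 groups: {0, 1}, {2, 3}, {4, 5}, {6, 7}
    arbitrary = [0, 1, 2, 4]  # per-group counts [2, 1, 1, 0]
    per_group = [0, 2, 4, 6]  # per-group counts [1, 1, 1, 1]
    print(is_valid_with_fixed_groups(arbitrary, 8, 4))  # False
    print(is_valid_with_fixed_groups(per_group, 8, 4))  # True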


To revert the replacement, use the ``revert_implementations`` function:

.. code-block:: python

    from enot_prunable_modules import ReplaceFactory
    from enot_prunable_modules import revert_implementations

    revert_implementations(model, factory=ReplaceFactory.KANDINSKY)

This function restores the original modules where possible; ``-X->`` marks modules that stay replaced:

.. code-block:: text

    PrunableAttention ---> Attention
    PrunableResnetBlock2D ---> ResnetBlock2D
    PrunableUpsample2D -X-> Upsample2D
    PrunableDownsample2D -X-> Downsample2D
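
The ``-X->`` entries can be thought of as exclusions applied when the replacement table is inverted. The sketch below models this with class names as strings; ``REPLACEMENTS``, ``NOT_REVERTIBLE``, ``build_revert_table`` and ``revert_names`` are hypothetical and only mirror the tables above, not the package's code.

.. code-block:: python

    # Hypothetical replacement table, mirroring the KANDINSKY mapping above.
    REPLACEMENTS = {
        "Attention": "PrunableAttention",
        "ResnetBlock2D": "PrunableResnetBlock2D",
        "Upsample2D": "PrunableUpsample2D",
        "Downsample2D": "PrunableDownsample2D",
    }

    # Original modules that are not restored after pruning (the "-X->" entries).
    NOT_REVERTIBLE = {"Upsample2D", "Downsample2D"}


    def build_revert_table(replacements, not_revertible):
        """Invert a replacement table, leaving out modules that must stay replaced."""
        return {
            prunable: original
            for original, prunable in replacements.items()
            if original not in not_revertible
        }


    def revert_names(module_names, revert_table):
        """Map each class name back to its original where a reverse entry exists."""
        return [revert_table.get(name, name) for name in module_names]


    revert_table = build_revert_table(REPLACEMENTS, NOT_REVERTIBLE)
    print(revert_names(["PrunableAttention", "PrunableUpsample2D"], revert_table))
    # ['Attention', 'PrunableUpsample2D']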

.. note::
    For Kandinsky2-2, reverting ``Upsample2D`` and ``Downsample2D`` to the original modules would trigger an assertion, because pruning changes their number of channels.
    The prunable versions drop this assertion, so ``Upsample2D`` and ``Downsample2D`` stay replaced after reverting.
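
The note can be illustrated with a toy module. It assumes, as the note implies, that the original ``Upsample2D`` and ``Downsample2D`` compare the incoming channel count with a ``channels`` value fixed at construction time; ``ToyUpsample2D`` is a hypothetical stand-in that mimics only that check, using a shape tuple instead of a tensor.

.. code-block:: python

    class ToyUpsample2D:
        """Hypothetical stand-in for a block that checks its input channel count."""

        def __init__(self, channels):
            self.channels = channels  # fixed when the module is constructed

        def forward(self, shape):
            # shape is (batch, channels, height, width)
            assert shape[1] == self.channels, "channel mismatch"
            batch, channels, height, width = shape
            return (batch, channels, 2 * height, 2 * width)


    block = ToyUpsample2D(channels=320)
    print(block.forward((1, 320, 32, 32)))  # (1, 320, 64, 64)

    # After pruning, the preceding layer emits fewer channels, but the stored
    # attribute still says 320, so the check fails.
    try:
        block.forward((1, 160, 32, 32))
    except AssertionError as err:
        print("AssertionError:", err)  # AssertionError: channel mismatch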


Example: Kandinsky2-2 pruning
=================================

.. code-block:: python

    import torch
    from diffusers.models import UNet2DConditionModel
    from enot.pruning import prune_model
    from enot.pruning import PruningCalibrator
    from enot.pruning import UniformPruningLabelSelector
    from enot_prunable_modules import replace_implementations
    from enot_prunable_modules import revert_implementations
    from enot_prunable_modules import ReplaceFactory


    unet = UNet2DConditionModel.from_pretrained(
        "kandinsky-community/kandinsky-2-2-decoder",
        subfolder="unet",
    )
    unet.eval()

    # replace unprunable modules with their prunable versions
    replace_implementations(
        unet,
        factory=[
            ReplaceFactory.KANDINSKY,
            ReplaceFactory.GROUP_OPS,
        ],
    )

    criterion = torch.nn.MSELoss()
    pruning_calibrator = PruningCalibrator(
        model=unet,
        max_prunable_labels=25600,
    )

    # Calibrating: run a forward and backward pass inside the calibrator context
    with pruning_calibrator:
        result = unet(
            sample=torch.ones(1, 4, 64, 64),
            timestep=torch.ones(1),
            encoder_hidden_states=None,
            added_cond_kwargs={
                "image_embeds": torch.ones(1, 1280)
            },
        ).sample[:, :4]

        loss = criterion(result, torch.ones(1, 4, 64, 64))
        loss.backward()

    pruning_info = pruning_calibrator.pruning_info

    # Pruning
    pruning_ratio = 0.5
    label_selector = UniformPruningLabelSelector(pruning_ratio=pruning_ratio)
    prune_labels = label_selector.select(model=unet, pruning_info=pruning_info)

    pruned_unet = prune_model(
        model=unet,
        pruning_info=pruning_info,
        prune_labels=prune_labels,
        inplace=False,
    )

    # restore original modules where possible
    revert_implementations(
        pruned_unet,
        factory=[
            ReplaceFactory.KANDINSKY,
            ReplaceFactory.GROUP_OPS,
        ],
    )
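
To get a feel for what ``pruning_ratio = 0.5`` can mean for model size, the sketch below computes the parameter count of a single ``torch.nn.Conv2d`` (``out_channels * in_channels / groups * k * k`` weights plus ``out_channels`` biases) before and after halving its channels. The ``conv2d_params`` helper is hypothetical, and the sketch assumes, for illustration only, that both the input and the output channels of the layer are halved; which channels are actually removed depends on the label selector and on how pruning labels are shared between layers, so real savings will differ.

.. code-block:: python

    def conv2d_params(in_channels, out_channels, kernel_size, groups=1, bias=True):
        """Number of parameters in a square-kernel torch.nn.Conv2d layer."""
        weights = out_channels * (in_channels // groups) * kernel_size * kernel_size
        return weights + (out_channels if bias else 0)


    before = conv2d_params(320, 320, kernel_size=3)
    after = conv2d_params(160, 160, kernel_size=3)  # assumes both sides are halved
    print(before, after, round(before / after, 2))  # 921920 230560 4.0

Halving both sides of a convolution shrinks its weights roughly fourfold, which is why the same ratio can remove more than half of a layer's parameters.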