.. _ENOT Prunable Modules:

#####################
ENOT Prunable Modules
#####################

**ENOT Prunable Modules** is a package designed for replacing unprunable modules with prunable ones.
It is published under the Apache 2.0 license, so anyone can view and modify its code.

Installation
============

First of all, you need Python>=3.8.

The ENOT Prunable Modules package can be installed from PyPI:

::

    pip install enot-prunable-modules

Overview
========

To replace unprunable modules with their prunable versions, use the ``replace_implementations`` function:

.. code-block:: python

    from enot_prunable_modules import ReplaceFactory
    from enot_prunable_modules import replace_implementations

    model = ...
    replace_implementations(model, factory=ReplaceFactory.KANDINSKY)

After that, modules are replaced according to the selected ``ReplaceFactory`` strategy.
For ``KANDINSKY``, the replacements are:

.. code-block:: text

    Attention     ---> PrunableAttention
    ResnetBlock2D ---> PrunableResnetBlock2D
    Upsample2D    ---> PrunableUpsample2D
    Downsample2D  ---> PrunableDownsample2D

``ReplaceFactory`` provides the following replacement strategies:

- ``KANDINSKY`` - for ``diffusers.models.UNet2DConditionModel`` from Kandinsky2-2.
- ``GROUP_OPS`` - for ``torch.nn.Conv2d`` with ``groups > 1`` and ``torch.nn.GroupNorm``.
- ``TORCH_ATTENTION`` - for ``torch.nn.MultiheadAttention``.
- ``TIMM_ATTENTION`` - for ``timm.models.vision_transformer.Attention``.
- ...

The ``factory`` argument also accepts a list of strategies.

To revert the replacement, use the ``revert_implementations`` function:

.. code-block:: python

    from enot_prunable_modules import ReplaceFactory
    from enot_prunable_modules import revert_implementations

    revert_implementations(model, factory=ReplaceFactory.KANDINSKY)

This function restores the original modules where possible (``-X->`` marks modules that are not reverted):

.. code-block:: text

    PrunableAttention     ---> Attention
    PrunableResnetBlock2D ---> ResnetBlock2D
    PrunableUpsample2D    -X-> Upsample2D
    PrunableDownsample2D  -X-> Downsample2D
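
The replace and revert tables above can be mimicked in plain Python to make the one-way entries explicit.
This is a hypothetical sketch, not how ``enot_prunable_modules`` is implemented: ``ToyModule``, ``REPLACEMENTS``,
``NOT_REVERTIBLE``, ``REVERSIONS``, ``replace``, ``revert`` and ``kinds`` are illustrative names, and the sketch
renames type labels in a toy tree instead of swapping real ``torch.nn.Module`` objects.

.. code-block:: python

    from dataclasses import dataclass, field
    from typing import Dict, List, Set


    @dataclass
    class ToyModule:
        """Hypothetical stand-in for a model: a module type name plus child modules."""

        kind: str
        children: List["ToyModule"] = field(default_factory=list)


    # Forward table, copied from the KANDINSKY mapping above.
    REPLACEMENTS: Dict[str, str] = {
        "Attention": "PrunableAttention",
        "ResnetBlock2D": "PrunableResnetBlock2D",
        "Upsample2D": "PrunableUpsample2D",
        "Downsample2D": "PrunableDownsample2D",
    }
    # Entries shown with -X-> above: they keep their prunable form on revert.
    NOT_REVERTIBLE: Set[str] = {"PrunableUpsample2D", "PrunableDownsample2D"}
    # Inverse table derived from the forward one, minus the one-way entries.
    REVERSIONS: Dict[str, str] = {
        new: old for old, new in REPLACEMENTS.items() if new not in NOT_REVERTIBLE
    }


    def _apply(module: ToyModule, table: Dict[str, str]) -> None:
        """Recursively rename every module whose kind appears in the table, in place."""
        module.kind = table.get(module.kind, module.kind)
        for child in module.children:
            _apply(child, table)


    def replace(module: ToyModule) -> None:
        """Toy counterpart of replace_implementations: swap in prunable kinds."""
        _apply(module, REPLACEMENTS)


    def revert(module: ToyModule) -> None:
        """Toy counterpart of revert_implementations: undo reversible swaps only."""
        _apply(module, REVERSIONS)


    def kinds(module: ToyModule) -> List[str]:
        """Collect module kinds in depth-first order."""
        result = [module.kind]
        for child in module.children:
            result.extend(kinds(child))
        return result


    toy_unet = ToyModule(
        "UNet",
        [
            ToyModule("ResnetBlock2D"),
            ToyModule("Attention"),
            ToyModule("Downsample2D"),
            ToyModule("Upsample2D"),
        ],
    )
    replace(toy_unet)
    print(kinds(toy_unet))
    revert(toy_unet)
    print(kinds(toy_unet))

After ``revert``, ``Attention`` and ``ResnetBlock2D`` are restored while the two sampling blocks stay in their
prunable form. Deriving the inverse table from the forward one, minus an explicit exclusion set, keeps the two
directions from drifting apart; the real package may organise this differently.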

.. note::

   In the case of Kandinsky2-2, reverting ``Upsample2D`` and ``Downsample2D`` to the original modules would
   trigger an assertion, because their channels have been pruned. Restoring them would require removing that
   assertion, so ``Upsample2D`` and ``Downsample2D`` are left replaced.

Example: Kandinsky2-2 pruning
=================================

.. code-block:: python

    import torch
    from diffusers.models import UNet2DConditionModel

    from enot.pruning import prune_model
    from enot.pruning import PruningCalibrator
    from enot.pruning import UniformPruningLabelSelector

    from enot_prunable_modules import replace_implementations
    from enot_prunable_modules import revert_implementations
    from enot_prunable_modules.replace_factory import ReplaceFactory

    unet = UNet2DConditionModel.from_pretrained(
        "kandinsky-community/kandinsky-2-2-decoder",
        subfolder="unet",
    )
    unet.eval()

    # Replace unprunable modules with their prunable versions
    replace_implementations(
        unet,
        factory=[
            ReplaceFactory.KANDINSKY,
            ReplaceFactory.GROUP_OPS,
        ],
    )

    criterion = torch.nn.MSELoss()

    pruning_calibrator = PruningCalibrator(
        model=unet,
        max_prunable_labels=25600,
    )

    # Calibrating: run a forward and backward pass on dummy inputs
    with pruning_calibrator:
        result = unet(
            sample=torch.ones(1, 4, 64, 64),
            timestep=torch.ones(1),
            encoder_hidden_states=None,
            added_cond_kwargs={"image_embeds": torch.ones(1, 1280)},
        ).sample[:, :4]
        loss = criterion(result, torch.ones(1, 4, 64, 64))
        loss.backward()

    pruning_info = pruning_calibrator.pruning_info

    # Pruning with a uniform pruning ratio
    pruning_ratio = 0.5
    label_selector = UniformPruningLabelSelector(pruning_ratio=pruning_ratio)
    prune_labels = label_selector.select(model=unet, pruning_info=pruning_info)

    pruned_unet = prune_model(
        model=unet,
        pruning_info=pruning_info,
        prune_labels=prune_labels,
        inplace=False,
    )

    # Restore the original modules where possible
    revert_implementations(
        pruned_unet,
        factory=[
            ReplaceFactory.KANDINSKY,
            ReplaceFactory.GROUP_OPS,
        ],
    )
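
The Kandinsky2-2 note about ``Upsample2D`` and ``Downsample2D`` comes down to a channel count captured when a
module is built. The toy classes below are hypothetical (``ToyDownsample`` and ``ToyPrunableDownsample`` do not
reproduce the diffusers or ENOT code); they only assume, as the note implies, that the original module asserts
that its input has the channel count it was constructed with, while a prunable version has no such stored value
to go stale after pruning.

.. code-block:: python

    from typing import List

    # A "feature map" here is a list of channels, each channel a list of values.
    FeatureMap = List[List[float]]


    class ToyDownsample:
        """Hypothetical original module: stores its channel count and checks every input."""

        def __init__(self, channels: int) -> None:
            self.channels = channels

        def forward(self, x: FeatureMap) -> FeatureMap:
            # This is the kind of check that fails once upstream channels are pruned.
            assert len(x) == self.channels, f"expected {self.channels} channels, got {len(x)}"
            return [channel[::2] for channel in x]  # keep every second value


    class ToyPrunableDownsample:
        """Hypothetical prunable counterpart: works with whatever channels arrive."""

        def forward(self, x: FeatureMap) -> FeatureMap:
            return [channel[::2] for channel in x]


    features = [[1.0, 2.0, 3.0, 4.0] for _ in range(8)]  # 8 channels
    pruned = features[:4]  # pretend pruning removed half of the channels upstream

    original = ToyDownsample(channels=8)
    print(len(original.forward(features)))  # the unpruned input passes the check
    try:
        original.forward(pruned)  # a "reverted" module still expects 8 channels
    except AssertionError as err:
        print("reverted module fails:", err)
    print(len(ToyPrunableDownsample().forward(pruned)))  # the prunable version accepts 4

In this toy, keeping the prunable module avoids editing the original class, which mirrors the choice described
in the note.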