MMENOT

To integrate the ENOT Framework with the OpenMMLab codebase, we provide the MMENOT package. MMENOT is published under the Apache 2.0 license, so anyone can view and modify its code.

Installation

The MMENOT package can be installed from PyPI:

pip install mmenot

Note

The package depends on the enot-autodl package. Please read the ENOT Framework license setup guide carefully and install enot-autodl first.

Note

The package supports optional dependencies. To make sure everything is installed correctly, specify the OpenMMLab package(s) you use: mmdet, mmseg, or both.

Example: pip install mmenot[mmdet,mmseg]

MMENOT depends on OpenMMLab packages such as mmengine and mmcv, which can be installed using the mim package manager distributed by OpenMMLab:

pip install openmim
mim install mmengine mmcv

Depending on whether you want to work with MMDetection, MMSegmentation, or both, install one of them or both:

  • MMDetection: mim install mmdet

  • MMSegmentation: pip install mmsegmentation

You can also install them from the git repositories.

Since the OpenMMLab packages are installed from different sources, a correct installation together with the mmenot package looks like this:

# MMDetection:
pip install openmim
mim install mmengine mmcv
mim install mmdet
pip install mmenot[mmdet]

# MMSegmentation:
pip install openmim
mim install mmengine mmcv
pip install mmsegmentation
pip install mmenot[mmseg]
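The install_mmenot helper below is hypothetical and is not part of MMENOT. It sketches how the two installation sequences above, plus the combined "both" case, could be wrapped into one function. With DRY_RUN=1 it only prints the commands. The extras are quoted so that the printed commands can also be pasted into shells such as zsh, which would otherwise try to glob-expand the square brackets.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of MMENOT): install MMENOT with MMDetection,
# MMSegmentation, or both, in the order shown above.
# Set DRY_RUN=1 to print the commands instead of running them.

install_mmenot() {
  local target=$1  # one of: mmdet, mmseg, both
  local extras
  case "$target" in
    mmdet) extras="mmdet" ;;
    mmseg) extras="mmseg" ;;
    both) extras="mmdet,mmseg" ;;
    *) echo "unknown target: $target (expected mmdet, mmseg or both)" >&2; return 1 ;;
  esac

  local -a cmds=("pip install openmim" "mim install mmengine mmcv")
  if [[ $target == mmdet || $target == both ]]; then
    cmds+=("mim install mmdet")
  fi
  if [[ $target == mmseg || $target == both ]]; then
    cmds+=("pip install mmsegmentation")
  fi
  # Quoted extras: safe to paste into shells that glob-expand [ ] (e.g. zsh).
  cmds+=("pip install 'mmenot[$extras]'")

  local cmd
  for cmd in "${cmds[@]}"; do
    if [[ ${DRY_RUN:-0} == 1 ]]; then
      echo "$cmd"
    else
      # Stop at the first failing step.
      eval "$cmd" || return 1
    fi
  done
}
```

For example, DRY_RUN=1 install_mmenot both prints the five commands needed for a combined MMDetection and MMSegmentation setup without running them.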

Package overview

The MMENOT package provides five tools:

  • mmprune: the core of the package, which integrates the ENOT pruning module into the OpenMMLab codebase. Only this tool actually depends on the enot-autodl package and requires a license. For details, see the pruning module documentation or our tutorials.

  • mmtune: fine-tuning script.

  • mmhpo: hyperparameter search script.

  • mmval: validation script.

  • mmexport: PyTorch to ONNX export script.

Note

mmtune, mmval and mmexport are the standard OpenMMLab training, validation and export scripts, except that they can work with checkpoints obtained from mmprune.

Warning

Here and below, a checkpoint means a whole torch.nn.Module object saved to a file. A state dict means a state dict with metadata, as OpenMMLab usually saves it.

When the package is installed, each of these scripts is added to the PATH of the virtual environment. For example, you can list the mmprune arguments like this:

mmprune --help

This prints many arguments, but these three are the most important:

--pruning-type PRUNING_TYPE
    pruning type: equal, global, global-mmacs, optimal-hw-aware
--pruning-parameter PRUNING_PARAMETER
    has its own meaning for each type of pruning
--input-shape INPUT_SHAPE [INPUT_SHAPE ...]
    input shape for MMAC / Latency calculation

The --pruning-type parameter selects the type of pruning. For each type, --pruning-parameter has its own meaning and value range:

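+------------------+---------------------------------------------------------+--------+
| pruning-type     | pruning-parameter                                       | range  |
+==================+=========================================================+========+
| equal            | fraction of prunable channels to remove in each group   | (0, 1) |
+------------------+---------------------------------------------------------+--------+
| global           | fraction of prunable channels to remove in the whole    | (0, 1) |
|                  | network                                                 |        |
+------------------+---------------------------------------------------------+--------+
| global-mmacs     | acceleration ratio, with latency measured in MMACs:     | (0, 1) |
|                  | pruned_latency < baseline_latency * acceleration_ratio  |        |
|                  | For a 2x reduction in MMACs, use 0.5.                   |        |
+------------------+---------------------------------------------------------+--------+
| optimal-hw-aware | acceleration ratio, with latency measured on the target | (0, 1) |
|                  | hardware by ENOT Latency Server:                        |        |
|                  | pruned_latency < baseline_latency * acceleration_ratio  |        |
+------------------+---------------------------------------------------------+--------+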
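For global-mmacs and optimal-hw-aware, the constraint pruned_latency < baseline_latency * acceleration_ratio means that a target speedup of S corresponds to a ratio of 1 / S. The speedup_to_ratio helper below is hypothetical and is not part of MMENOT. It sketches that conversion and rejects speedups that would fall outside the (0, 1) range. It does not apply to equal and global, whose parameter is a fraction of channels.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of MMENOT): convert a desired speedup factor S
# into the acceleration ratio 1 / S for global-mmacs / optimal-hw-aware.

speedup_to_ratio() {
  local speedup=$1
  # The ratio must lie in (0, 1), which requires S > 1.
  awk -v s="$speedup" 'BEGIN {
    if (s + 0 <= 1) { print "error: speedup must be greater than 1" > "/dev/stderr"; exit 1 }
    printf "%g\n", 1 / s
  }'
}
```

For example, speedup_to_ratio 2 prints 0.5, matching the global-mmacs row above. It could be used as --pruning-parameter "$(speedup_to_ratio 2)".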

Note

For the optimal-hw-aware pruning type, an ENOT Latency Server must be deployed. The server host and port are passed as arguments: --host <host> --port <port>.

A deploy configuration (from the MMDeploy package) must also be passed, so that the model is correctly exported to ONNX for remote latency measurement: --deploy-cfg <path/to/deploy/config.py>
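The hw_aware_prune helper below is hypothetical. It is a sketch of a complete optimal-hw-aware call that uses only the flags named in this section: --pruning-type, --pruning-parameter, --input-shape, --host, --port and --deploy-cfg. The 1280x800 input shape is taken from the workflow example. The server address, port and paths are placeholders to replace with your own. With DRY_RUN=1 the helper prints the command instead of running it.

```shell
#!/usr/bin/env bash
# Hypothetical helper: assemble an optimal-hw-aware mmprune call.
# Host, port and file paths are placeholders, not MMENOT defaults.
# Set DRY_RUN=1 to print the command instead of running it.

hw_aware_prune() {
  local config=$1 deploy_cfg=$2 host=$3 port=$4
  # Acceleration ratio in (0, 1); 0.5 means pruned latency < 0.5 * baseline.
  local ratio=${5:-0.5}
  local -a cmd=(
    mmprune "$config"
    --pruning-type optimal-hw-aware
    --pruning-parameter "$ratio"
    --input-shape 1280 800
    --host "$host"
    --port "$port"
    --deploy-cfg "$deploy_cfg"
  )
  if [[ ${DRY_RUN:-0} == 1 ]]; then
    echo "${cmd[*]}"
  else
    "${cmd[@]}"
  fi
}
```

A call might look like hw_aware_prune path/to/config.py path/to/deploy/config.py <host> <port> 0.5.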

The --input-shape parameter is used to calculate the total number of multiply-accumulate operations (MMACs, millions of multiply-accumulate operations) or the latency of the network. This is an important metric, so specify this parameter carefully.

Workflow

Suppose you have chosen the model and dataset, prepared the OpenMMLab configuration files, and decided on the pruning type and the input shape of your model (for example, equal pruning with a fraction of 0.2 and a 1280x800 input). The first step is then the following command:

mmprune path/to/config.py \
    --pruning-type equal \
    --pruning-parameter 0.2 \
    --input-shape 1280 800

When pruning completes, a checkpoint of the pruned model, pruned.pth, is written to the working directory. This checkpoint must be used in the second step, which consists of the optional hyperparameter search and fine-tuning:

mmhpo path/to/config.py --checkpoint path/to/pruned.pth --lr-range 0.01 0.1
mmtune path/to/config.py path/to/pruned.pth

Note

mmtune supports knowledge distillation; see mmtune --help for details.

Note

mmprune, mmhpo and mmtune produce checkpoints and state dicts.

Checkpoints must be used with the tools in the mmenot package, while state dicts can be used with the standard OpenMMLab tools.

You can easily distinguish them from each other by the _state_dict postfix.
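The artifact_kind helper below is hypothetical and is not part of MMENOT. It sketches the distinction described above. It assumes that any file whose name contains _state_dict is a state dict (for example pruned_state_dict.pth) and that every other file is a checkpoint. The exact file names are an assumption.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of MMENOT): classify a saved file by the
# _state_dict postfix. Assumption: names containing "_state_dict" are state dicts.

artifact_kind() {
  local name
  name=$(basename "$1")
  if [[ $name == *_state_dict* ]]; then
    echo "state-dict"  # for standard OpenMMLab tools
  else
    echo "checkpoint"  # for mmenot tools (mmhpo, mmtune, mmval, mmexport)
  fi
}
```

For example, the following loop labels every .pth file in a working directory:

for f in path/to/work_dir/*.pth; do echo "$f -> $(artifact_kind "$f")"; done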

Once the fine-tuned pruned model reaches the required metrics, you can validate it and export it to ONNX:

mmval path/to/config.py path/to/best.pth
mmexport path/to/deploy/config.py \
    path/to/model/config.py \
    path/to/best.pth \
    path/to/image.jpg

The result is an ONNX model that runs faster than the original, unpruned model.
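The workflow above can be sketched as a single hypothetical function. Each step runs only if the previous one succeeded. It makes two assumptions that this page does not document: pruned.pth and best.pth are both found in work_dir. The optional mmhpo step is left out because this page does not describe how its result is passed to mmtune. With DRY_RUN=1 the function prints the commands instead of running them.

```shell
#!/usr/bin/env bash
# Hypothetical sketch of the prune -> fine-tune -> validate -> export workflow.
# Assumptions (not documented MMENOT behavior): pruned.pth and best.pth are
# both found in work_dir. The optional mmhpo step is omitted.
# Set DRY_RUN=1 to print the commands instead of running them.

run() {
  if [[ ${DRY_RUN:-0} == 1 ]]; then
    echo "$*"
  else
    "$@"
  fi
}

prune_tune_export() {
  local config=$1 deploy_cfg=$2 image=$3 work_dir=$4
  local pruned="$work_dir/pruned.pth"  # assumed mmprune output location
  local best="$work_dir/best.pth"      # assumed mmtune output location

  # Chained with && so that a failing step stops the rest of the pipeline.
  run mmprune "$config" --pruning-type equal --pruning-parameter 0.2 \
      --input-shape 1280 800 &&
    run mmtune "$config" "$pruned" &&
    run mmval "$config" "$best" &&
    run mmexport "$deploy_cfg" "$config" "$best" "$image"
}
```

Chaining with && rather than running the steps unconditionally avoids, for instance, exporting a stale best.pth after a failed fine-tuning run.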

Distributed

For distributed fine-tuning of a pruned model, you can use the following script:

dist_tune.sh
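#!/usr/bin/env bash
# Usage: bash dist_tune.sh CONFIG CHECKPOINT WORK_DIR GPUS [extra trainer args]
if [ $# -lt 4 ]; then
    echo "Usage: $0 CONFIG CHECKPOINT WORK_DIR GPUS [extra trainer args]" >&2
    exit 1
fi

CONFIG=$1
CHECKPOINT=$2
WORK_DIR=$3
GPUS=$4
# Multi-node settings; override them through environment variables.
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
# Path to the mmenot fine-tuning module inside the installed package.
TUNE_EXEC=$(python -c "import mmenot.trainer; print(mmenot.trainer.__file__)")

PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes="$NNODES" \
    --node_rank="$NODE_RANK" \
    --master_addr="$MASTER_ADDR" \
    --nproc_per_node="$GPUS" \
    --master_port="$PORT" \
    "$TUNE_EXEC" \
    "$CONFIG" \
    "$CHECKPOINT" \
    --work-dir "$WORK_DIR" \
    --launcher pytorch "${@:5}"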

For example:

bash dist_tune.sh \
    path/to/config.py \
    path/to/pruned.pth \
    work/dir \
    4  # number of GPUs
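dist_tune.sh reads NNODES, NODE_RANK, MASTER_ADDR and PORT from the environment, so the same script can also drive a multi-node job. The multinode_commands helper below is hypothetical. It prints the command each node would run, reusing the placeholder paths from the example above. The node with NODE_RANK=0 must be the machine whose address is passed as MASTER_ADDR. The address 10.0.0.1 used in the usage example is only an illustration.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print the dist_tune.sh command for every node of a
# multi-node job. Rank 0 must run on the machine at master_addr.

multinode_commands() {
  local nnodes=$1 master_addr=$2 gpus=$3
  local rank
  for ((rank = 0; rank < nnodes; rank++)); do
    echo "NNODES=$nnodes NODE_RANK=$rank MASTER_ADDR=$master_addr PORT=29500" \
      "bash dist_tune.sh path/to/config.py path/to/pruned.pth work/dir $gpus"
  done
}
```

For example, multinode_commands 2 10.0.0.1 4 prints two lines for two nodes with 4 GPUs each. Run the first line on the node at 10.0.0.1 and the second on the other node.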