Quick Start with QAT (eager mode)

Contents

  • 1. Quick overview of necessary steps
  • 2. Explain each step in detail
  • 2.1 Floating-point model preparation
  • 2.2 Data calibration
  • 2.3 Quantization training
  • 2.4 Fixed-point conversion
  • 2.5 Model compilation
  • 3. Common Problems

PyTorch officially provides two quantization modes: Eager Mode Quantization and FX Graph Mode Quantization. The differences between the two modes are described in the [PyTorch official documentation](https://pytorch.org/docs/stable/quantization.html#quantization-api-summary). For newcomers, we recommend trying FX mode first (see [QAT quick start (FX mode)](https://developer.horizon.cc/forumDetail/177840589839214597)). Although FX mode does have the limitation that the model must be "symbolically traceable", it is much more automated overall: users do not need to write operator fusions or manually replace unsupported operators, and can fall back to eager mode if they run into problems that FX mode cannot solve.

1. Quick overview of necessary steps

The whole QAT solution consists of five steps from floating point to deployment: floating-point model preparation, data calibration, quantization training (optional), fixed-point conversion, and model compilation. The necessary steps and sample code are shown below; detailed instructions and precautions for each step are given in the following sections. For a complete example, refer to the eager_mode.py script in the /ddk/samples/ai_toolchain/horizon_model_train_sample/plugin_basic directory of the OE development package. **It is highly recommended to skip the training process and complete the prepare->convert->check flow before quantization training (even during the floating-point model design phase) to confirm that the model is supported by the hardware.**

from horizon_plugin_pytorch.quantization import (
    convert, 
    prepare_qat,
    set_fake_quantize,
    FakeQuantState,
    check_model,
    compile_model,
)
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,
    default_qat_8bit_fake_quant_qconfig,
    default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
    default_qat_8bit_weight_32bit_out_fake_quant_qconfig
)
from horizon_plugin_pytorch.march import March, set_march
import torch
import copy

set_march(March.BAYES)

# 1. Floating-point model preparation (fuse ops, set qconfig)
float_model = load_float_model(pretrain=True)
float_model.fuse_model()
ori_float_model = float_model
float_model = copy.deepcopy(ori_float_model)
float_model.qconfig = default_calib_8bit_fake_quant_qconfig
# high-precision (int32) output for the last conv/linear layer
float_model.classifier.qconfig = default_calib_8bit_weight_32bit_out_fake_quant_qconfig

# 2. Data calibration
calib_model = prepare_qat(float_model)
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calibrate(calib_model)
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
evaluate(calib_model)
torch.save(calib_model.state_dict(), "calib-checkpoint.ckpt")

# 3. Quantization training (optional), starting from the calibrated weights
float_model = copy.deepcopy(ori_float_model)
float_model.qconfig = default_qat_8bit_fake_quant_qconfig
float_model.classifier.qconfig = default_qat_8bit_weight_32bit_out_fake_quant_qconfig
qat_model = prepare_qat(float_model)
qat_model.load_state_dict(calib_model.state_dict())
qat_model.train()
set_fake_quantize(qat_model, FakeQuantState.QAT)
train(qat_model)
qat_model.eval()
set_fake_quantize(qat_model, FakeQuantState.VALIDATION)
evaluate(qat_model)

# 4. Fixed-point conversion
base_model = qat_model
quantized_model = convert(base_model)
evaluate(quantized_model)

# 5. Model compilation
script_model = torch.jit.trace(quantized_model.cpu(), example_input)
check_model(script_model, [example_input])
compile_model(
    script_model, [example_input], hbm="model.hbm", input_source="pyramid", opt=3
)

2. Explain each step in detail

2.1 Floating-point model preparation

  • a. Train the floating-point model on sufficient data until it converges before starting quantization training.
  • b. It is strongly recommended to normalize the input data; this helps floating-point convergence and makes the model more quantization-friendly.
  • c. Check the supported-operators list during the design phase of the floating-point model to avoid prepare_qat or compilation errors caused by unsupported operators.
  • d. For more guidance on building quantization-friendly models, refer to user manual section 4.2.4.1, "Requirements for floating-point models".
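As a concrete illustration of point b, per-channel input normalization can be done as in the pure-Python sketch below. The mean/std values are the common ImageNet statistics, used here only as an assumed example; in practice use your own dataset's statistics.

```python
# Illustrative per-channel normalization; MEAN/STD are assumed
# ImageNet statistics, not values mandated by the QAT toolchain.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(channel_values, mean=MEAN, std=STD):
    """Normalize one pixel given as [R, G, B] floats in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(channel_values, mean, std)]
```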

If you select eager mode, you need to complete the following steps when preparing the floating-point model (see the [PyTorch official documentation](https://pytorch.org/docs/stable/quantization.html#model-preparation-for-eager-mode-static-quantization)):

  • **Insert a QuantStub node before each model input and a DeQuantStub node after each model output.** Note the following:

  • Multiple inputs may share one QuantStub only if they have the same scale; otherwise, define a separate QuantStub for each input.

  • It is recommended to use horizon_plugin_pytorch.quantization.QuantStub, which by default collects the input scale dynamically. If the scale can be computed ahead of time (for example, the homography matrix of a bev model), it is recommended to set the scale manually; the corresponding torch interface, torch.quantization.QuantStub, does not support setting the scale manually.

  • **Refer to the operator constraint list and complete the operator substitutions.** This step is needed because eager mode has limited operator support and applies special handling to the quantization of some operators (for example, torch.nn.functional.relu must be replaced with torch.nn.ReLU, and add/cat/matmul etc. must be replaced with horizon_plugin_pytorch.nn.quantized.FloatFunctional; see the sample below).

    from horizon_plugin_pytorch.nn.quantized import FloatFunctional

    class ModelName(nn.Module):
        def __init__(self, ···):
            super().__init__()
            self.matmul1 = FloatFunctional()

        def forward(self, ···):
            result = self.matmul1.matmul(a, b)
            ···
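To see why ops like add/cat/matmul need a dedicated module carrying its own quantization parameters, here is a plugin-independent numeric sketch (all names and scale values below are illustrative, not the plugin's API): integer codes quantized with different scales cannot be combined directly, so the op must dequantize, combine, and requantize with an output scale of its own.

```python
def quantize(x, scale):
    """Symmetric int8 quantization: real value -> integer code."""
    return max(-128, min(127, round(x / scale)))

def dequantize(q, scale):
    """Integer code -> approximate real value."""
    return q * scale

# Two operands arriving with different scales.
a_scale, b_scale, out_scale = 0.1, 0.02, 0.1
qa = quantize(3.0, a_scale)  # code 30 represents 3.0
qb = quantize(1.0, b_scale)  # code 50 represents 1.0
# Naively adding the codes (30 + 50 = 80) would decode to 8.0, not 4.0.
# A quantized add must bring both operands to a common scale first:
q_sum = quantize(dequantize(qa, a_scale) + dequantize(qb, b_scale), out_scale)
```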

  • **Operator fusion.** Fusing operators can significantly improve both the quantization accuracy and the deployment performance of the model, so it is an indispensable step. Fusions between conv/ConvTranspose2d/linear, bn, relu/relu6, and add are currently supported (for details, refer to PythonPath/horizon_plugin_pytorch/quantization/fuse_modules.py). Fusion functions can be written using either operator subscripts or operator names, as shown in the following example.

from horizon_plugin_pytorch.quantization import fuse_known_modules

class ConvBnRelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, 1, bias=None)
        self.bn = nn.BatchNorm2d(2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        ···

    def fuse_model(self):
        # fuse by operator subscripts (applicable when submodules are
        # indexed, e.g. inside an nn.Sequential)
        torch.quantization.fuse_modules(
            self,
            ["0", "1", "2"],
            inplace=True,
            fuser_func=fuse_known_modules,
        )
        # or, equivalently, fuse by operator names
        torch.quantization.fuse_modules(
            self,
            ["conv", "bn", "relu"],
            inplace=True,
            fuser_func=fuse_known_modules,
        )

2.2 Data calibration

For some models, the required accuracy can be reached by calibration alone, without time-consuming quantization-aware training. Even if the model cannot meet the accuracy requirements after calibration, this process reduces the difficulty of the subsequent quantization-aware training, shortens the training time, and improves the final training accuracy. For the specific configuration and tuning suggestions for data calibration, refer to [QAT solution calibration instructions](https://developer.horizon.cc/forumDetail/177840589839214596).
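Conceptually, calibration runs forward passes on representative data and lets observers derive quantization scales from the observed activation ranges. A toy, plugin-independent sketch of what a symmetric max-observer does (the function name and data are illustrative):

```python
def calibrate_scale(batches, qmax=127):
    """Toy symmetric per-tensor observer: track the running max of
    |activation| over calibration batches and derive the int8 scale."""
    max_abs = 0.0
    for batch in batches:
        max_abs = max(max_abs, max(abs(v) for v in batch))
    return max_abs / qmax

# Two "batches" of activations; the largest magnitude seen is 2.0,
# so the derived scale maps code 127 to the value 2.0.
scale = calibrate_scale([[0.5, -2.0], [1.27, 0.3]])
```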

2.3 Quantization training

Some recommended hyperparameter configurations for quantization training are shown in the following table:

| Hyperparameter | Recommended configuration | Advanced configuration (try if the recommended one does not work) |
| --- | --- | --- |
| lr | StepLR starting from 0.001, with scale=0.1 lr decay 2 times | 1. Adjust lr between 0.0001 and 0.001, with 1-2 lr decays. 2. The lr schedule can also be switched from StepLR to CosLR. 3. When QAT uses AMP, lower the lr appropriately; too large an lr leads to NaN. |
| epoch | 10% of the floating-point epochs | 1. Based on the convergence of loss and metric, consider whether the number of epochs needs to be extended. |
| weight decay | Consistent with floating point | 1. It is recommended to adjust around 4e-5; a weight decay that is too small leads to too large a weight variance, particularly at the output layer of tasks with large outputs. |
| optimizer | Consistent with floating point | 1. If floating-point training used an optimizer that affects lr settings, such as OneCycle, do not keep it consistent with floating point; use SGD instead. |
| transforms (data augmentation) | Consistent with floating point | 1. Can be weakened appropriately in the QAT stage, e.g. the color jitter used for classification can be removed, and the scale range of RandomResizeCrop can be reduced. |
| averaging_constant (qconfig_params) | 1. When calibration is used, it is recommended to freeze activation updates: weight averaging_constant=1.0, activation averaging_constant=0.0 | 1. If there is a large gap between the calibration accuracy and floating point, do not set the activation averaging_constant to 0.0. 2. The weight averaging_constant generally does not need to be set to 0.0 and can be adjusted within (0, 1.0). |
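The averaging_constant controls an exponential-moving-average update of the observer's statistics; a minimal scalar sketch of the update rule (the real observers work on min/max tensors, so this is only illustrative): a constant of 0.0 freezes the statistic, 1.0 replaces it outright with the latest observation.

```python
def ema_update(old, observed, averaging_constant):
    """Moving-average observer update: averaging_constant=0.0 freezes
    the statistic, 1.0 replaces it with the new observation."""
    return old + averaging_constant * (observed - old)
```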

It is highly recommended to try data calibration first and move on to quantization training only if the accuracy does not meet expectations (remember to load the weight parameters obtained from data calibration). Suggestions for tuning the quantization training phase can be found in the [User manual - Quantization training accuracy tuning recommendations](https://developer.horizon.cc/api/v1/fileData/horizon_j5_open_explorer_cn_doc/plugin/source/user_guide/debug_precision.html#a-name-recommended-configuration-a).

2.4 Fixed-point conversion

**Please note that the fixed-point model is not numerically identical to the pseudo-quantized model, so always take the accuracy of the fixed-point model as the standard. If the fixed-point accuracy does not meet the target, continue quantization training; it is recommended to keep the QAT model weights of several epochs to help find the best fixed-point accuracy. (A high QAT or calibration accuracy does not necessarily mean a high fixed-point accuracy, so consider backtracking a few checkpoints to balance the final fixed-point accuracy.)**

Under normal circumstances, the accuracy of the fixed-point model is exactly the same as that of the model deployed on the board, so the fixed-point model can be used to evaluate the final deployment accuracy.

2.5 Model compilation

The model compilation phase consists of the following three steps:

script_model = torch.jit.trace(quantized_model.cpu(), example_input)
check_model(script_model, [example_input])
compile_model(
    script_model, [example_input], hbm="model.hbm", input_source="pyramid", opt=3
)

For more configuration items of compile_model(), see the [User manual - Model compilation](https://developer.horizon.cc/api/v1/fileData/horizon_j5_open_explorer_cn_doc/plugin/source/api_reference/apis/compiler.html).

The script_model generated by trace can be saved with the horizon_plugin_pytorch.jit.save interface and then transferred to other machines for inference evaluation. Since inference with the saved model requires the device to be consistent with the one used for trace, using a to(device) operation may result in a forward error; the specific reasons and recommended solutions can be found in the [User manual - Cross-device inference of quantized deployment PT models](https://developer.horizon.cc/api/v1/fileData/horizon_j5_open_explorer_cn_doc/plugin/source/user_guide/pt_device_moving.html). If the model was trained with rgb/bgr input and input_source is set to pyramid or resizer at deployment, the preprocessing nodes centered_yuv2rgb or centered_yuv2bgr must be inserted manually before trace; for details, refer to the [User manual - RGB888 data deployment](https://developer.horizon.cc/api/v1/fileData/horizon_j5_open_explorer_cn_doc/plugin/source/advanced_content/rgb888_deploy.html).

3. Common Problems

**1. Why set high-precision output?** A: As introduced in the background of neural network quantization, the activation value produced by the multiply-accumulator is int32; to continue with the next layer's op, it is requantized to int8/int16. Therefore, if the last layer is a conv/linear node, it is recommended to configure high-precision output so the model outputs int32 directly, which greatly benefits accuracy preservation.

For plugin versions ≤ v1.6.2, configure high-precision output with default_calib_out_8bit_fake_quant_qconfig; note that this qconfig will be deprecated in later versions.
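The int32-accumulate-then-requantize behavior described above can be sketched in plain Python (the scales and values below are made up for illustration): requantizing to int8 discards fractional precision that a high-precision int32 output would keep.

```python
def conv_accumulate(x_int8, w_int8):
    """Multiply-accumulate of int8 codes; the accumulator is int32."""
    acc = 0
    for x, w in zip(x_int8, w_int8):
        acc += x * w  # each product fits easily in int32
    return acc

def requantize(acc_int32, in_scale, w_scale, out_scale):
    """Rescale the int32 accumulator to an int8 code for the next op."""
    real = acc_int32 * in_scale * w_scale
    return max(-128, min(127, round(real / out_scale)))

# High-precision output would return `acc` (int32) directly; the
# requantized int8 code 4 only represents 4 * 0.5 = 2.0 of the
# accumulated real value ~2.25, losing precision at the output.
acc = conv_accumulate([100, -50, 127], [127, 127, 127])
q_out = requantize(acc, 0.01, 0.01, 0.5)
```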

**2. How to understand the states of fake quantize?** There are three states of fake quantize, and set_fake_quantize should be called before QAT, calibration, and validation respectively to put the model's fake quantize nodes into the corresponding state. In the CALIBRATION state, only the statistics of each operator's input and output are observed. In the QAT state, a pseudo-quantization operation is performed in addition to observing the statistics. In the VALIDATION state, statistics are not observed and only the pseudo-quantization operation is performed. Interested readers can look at the implementation of set_fake_quantize in horizon_plugin_pytorch/quantization/fake_quantize.py under the Python path; the key segment is excerpted below:

···
if mode == FakeQuantState.QAT:
    assert (
        mod.training
    ), "Call model.train() before set fake quant to QAT mode."
    enable_fake_quant(mod)
elif mode == FakeQuantState.CALIBRATION:
    assert (
        not mod.training
    ), "Call model.eval() before set fake quant to CALIBRATION mode."
    disable_fake_quant(mod)
    if isinstance(
        mod, FakeQuantizeBase
    ) or _is_fake_quant_script_module(mod):
        mod.train()
elif mode == FakeQuantState.VALIDATION:
    assert (
        not mod.training
    ), "Call model.eval() before set fake quant to VALIDATION mode."
    # observer won't work in eval mode
    enable_fake_quant(mod)

From this code we can learn that QAT requires the model to be in the train() state, while calibration and validation require the model to be in the eval() state. This is mainly to ensure that bn and dropout are in the correct state (bn statistics are updated during training but not during evaluation).

  • CALIBRATION calls disable_fake_quant and sets the fake_quant modules to the train() state: no pseudo-quantization is performed; only the operator input/output statistics are observed and the scale is updated. QAT both observes statistics and performs pseudo-quantization. VALIDATION does not observe statistics and only performs pseudo-quantization.

Therefore, the following common misoperations may lead to some abnormal phenomena:

  • Before data calibration, the model is set to the train() state and set_fake_quantize is not used, which is equivalent to running QAT training;
  • Before data calibration, the model is set to the eval() state and set_fake_quantize is not used, which causes the scale to remain at its initial value of all 1s;
  • The model is set to the eval() state before data calibration and set_fake_quantize is used correctly, but model.eval() is called again afterwards, which results in fake_quant not being trained and the scale remaining at its initial value of all 1s;
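The consequence of a scale that never left its initial value of 1.0 can be seen with a minimal fake-quant sketch (pure Python, values illustrative): with a calibrated scale the quantize-dequantize round trip introduces only a small error, while with scale=1.0 every activation in (-0.5, 0.5] collapses to 0.

```python
def fake_quantize(x, scale, qmin=-128, qmax=127):
    """Quantize-dequantize in float, as fake quant does in QAT/VALIDATION."""
    q = max(qmin, min(qmax, round(x / scale)))
    return q * scale

# With a properly calibrated scale (here, for activations in [-2, 2]),
# the round-trip error is small:
good = fake_quantize(0.5, 2.0 / 127)  # ~0.504
# If calibration was skipped, the scale stays at its initial 1.0 and
# small activations are wiped out entirely:
bad = fake_quantize(0.4, 1.0)  # 0.0
```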

**3. How do I keep some parameters of the model from being updated?** A: This arises especially in multi-task model scenarios: the model is quantized again after a new head is added, and the quantization parameters of the backbone and the existing heads must not change. Following the introduction in the previous topic, you can call set_fake_quantize(model, FakeQuantState.CALIBRATION) or set_fake_quantize(model, FakeQuantState.QAT) and then execute model.backbone.eval(); the backbone's quantization parameters will then not be updated.