PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor
========================================================================================

**Author**: `Leslie Fang <https://github.com/leslie-fang-intel>`_, `Weiwen Xia <https://github.com/Xia-Weiwen>`_, `Jiong Gong <https://github.com/jgong5>`_, `Jerry Zhang <https://github.com/jerryzh168>`_

Prerequisites
^^^^^^^^^^^^^^^

- `PyTorch 2 Export Post Training Quantization <https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html>`_
- `TorchInductor and torch.compile concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_

Introduction
^^^^^^^^^^^^^^

This tutorial introduces the steps for utilizing the PyTorch 2 Export Quantization flow to generate a quantized model customized
for the x86 Inductor backend and explains how to lower the quantized model into Inductor.

The new quantization flow uses PT2 Export to capture the model into a graph and performs quantization transformations on top of the ATen graph. This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX.
TorchInductor is the new compiler backend that compiles the FX Graphs generated by TorchDynamo into optimized C++/Triton kernels.

The quantization flow with Inductor mainly consists of three steps:

- Step 1: Capture the FX Graph from the eager model based on the `torch export mechanism <https://pytorch.org/docs/main/export.html>`_.
- Step 2: Apply the quantization flow based on the captured FX Graph, including defining the backend-specific quantizer, generating the prepared model with observers,
  performing the prepared model's calibration, and converting the prepared model into the quantized model.
- Step 3: Lower the quantized model into Inductor with the API ``torch.compile``.

The high-level architecture of this flow looks like this:

::

    float_model(Python)                        Example Input
        \                                           /
         \                                         /
    ----------------------------------------------------------
    |                        export                          |
    ----------------------------------------------------------
                               |
                       FX Graph in ATen
                               |          X86InductorQuantizer
                               |              /
    ----------------------------------------------------------
    |                      prepare_pt2e                      |
    |                           |                            |
    |                    Calibrate/Train                     |
    |                           |                            |
    |                      convert_pt2e                      |
    ----------------------------------------------------------
                               |
                        Quantized Model
                               |
    ----------------------------------------------------------
    |                   Lower into Inductor                  |
    ----------------------------------------------------------
                               |
                           Inductor

Combining quantization in PyTorch 2 Export with TorchInductor gives us flexibility and productivity from the new quantization frontend
and outstanding out-of-the-box performance from the compiler backend. The benefit is especially pronounced on fourth generation
Intel Xeon Scalable processors (code-named Sapphire Rapids, SPR), which can further boost model performance by leveraging the
`advanced-matrix-extensions <https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/advanced-matrix-extensions/overview.html>`_ feature.

Now, we will walk you through a step-by-step tutorial on how to use it with the `torchvision resnet18 model <https://download.pytorch.org/models/resnet18-f37072fd.pth>`_.

1. Capture FX Graph
---------------------

We will start by performing the necessary imports and capturing the FX Graph from the eager module.

::

    import torch
    import torchvision.models as models
    import copy
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
    from torch._export import capture_pre_autograd_graph

    # Create the Eager Model
    model_name = "resnet18"
    model = models.__dict__[model_name](pretrained=True)

    # Set the model to eval mode
    model = model.eval()

    # Create the data, using dummy data here as an example
    traced_bs = 50
    x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last)
    example_inputs = (x,)

    # Capture the FX Graph to be quantized
    with torch.no_grad():
        # If you are using the PyTorch nightlies or building from source with the PyTorch main branch,
        # use the `capture_pre_autograd_graph` API.
        # Note 1: `capture_pre_autograd_graph` is a short-term API; it will be updated to use the official `torch.export` API when that is ready.
        exported_model = capture_pre_autograd_graph(
            model,
            example_inputs
        )
        # Note 2: if you are using the PyTorch 2.1 release binary or building from source with the PyTorch 2.1 release branch,
        # please use the `torch._dynamo.export` API to capture the FX Graph instead.
        # exported_model, guards = torch._dynamo.export(
        #     model,
        #     *copy.deepcopy(example_inputs),
        #     aten_graph=True,
        # )

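As an optional sanity check (not part of the quantization flow itself), you can verify that the captured graph still
produces outputs matching the eager model. This is a minimal sketch using the dummy inputs defined above.

::

    # Optional sanity check: the captured FX graph should numerically match the eager model.
    with torch.no_grad():
        eager_out = model(*example_inputs)
        exported_out = exported_model(*example_inputs)
    print("Max abs difference:", (eager_out - exported_out).abs().max().item())
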
We now have the FX Module to be quantized.

2. Apply Quantization
----------------------------

After we capture the FX Module to be quantized, we will import the backend quantizer for the X86 CPU and configure how to
quantize the model.

::

    quantizer = X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

.. note::

    The default quantization configuration in ``X86InductorQuantizer`` uses 8 bits for both activations and weights.
    When the Vector Neural Network Instruction (VNNI) is not available, the oneDNN backend silently chooses kernels that assume
    `multiplications are 7-bit x 8-bit <https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html#inputs-of-mixed-type-u8-and-s8>`_. In other words, potential
    numeric saturation and accuracy issues may happen when running on a CPU without VNNI.

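If you are unsure whether the target machine supports VNNI, one quick check on Linux is to look for the corresponding
CPU flags in ``/proc/cpuinfo``. The snippet below is a minimal sketch for this check only (it is not part of the
quantization API); ``avx512_vnni`` and ``avx_vnni`` are the standard Linux CPU flag names.

::

    # Minimal sketch: detect VNNI support on Linux by inspecting /proc/cpuinfo.
    def has_vnni(cpuinfo_path="/proc/cpuinfo"):
        try:
            with open(cpuinfo_path) as f:
                flags = f.read()
        except OSError:
            return False  # not a Linux system, or the file is unavailable
        return ("avx512_vnni" in flags) or ("avx_vnni" in flags)

    print("VNNI available:", has_vnni())
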
After the backend-specific quantizer is configured, we will prepare the model for post-training quantization.
``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators and inserts observers at appropriate places in the model.

::

    prepared_model = prepare_pt2e(exported_model, quantizer)

Now that the observers are inserted in the model, we will calibrate the ``prepared_model``.

::

    # We use dummy data as an example here
    prepared_model(*example_inputs)

    # Alternatively: the user can define a dataset for calibration
    # def calibrate(model, data_loader):
    #     model.eval()
    #     with torch.no_grad():
    #         for image, target in data_loader:
    #             model(image)
    # calibrate(prepared_model, data_loader_test)  # run calibration on sample data

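For completeness, here is a minimal, self-contained sketch of calibration over a ``DataLoader``. The dataset here is
random dummy data used purely as a stand-in; in practice you would iterate over a representative sample of real data.
Note that the batch size matches ``traced_bs``, the batch size used when capturing the graph.

::

    from torch.utils.data import DataLoader, TensorDataset

    # Dummy calibration set (random images + dummy labels) as a stand-in for real data.
    calib_images = torch.randn(200, 3, 224, 224).contiguous(memory_format=torch.channels_last)
    calib_labels = torch.zeros(200, dtype=torch.long)
    calib_loader = DataLoader(TensorDataset(calib_images, calib_labels), batch_size=traced_bs)

    def calibrate(model, data_loader):
        # Running data through the prepared model lets the observers record statistics.
        with torch.no_grad():
            for image, target in data_loader:
                model(image)

    calibrate(prepared_model, calib_loader)
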
Finally, we will convert the calibrated model into a quantized model. ``convert_pt2e`` takes a calibrated model and produces a quantized model.

::

    converted_model = convert_pt2e(prepared_model)

After these steps, the quantization flow is finished and we have the quantized model.

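If you want to confirm that quantize/dequantize operators were inserted, you can inspect the quantized model. It is a
``torch.fx.GraphModule``, so (as one option) its generated Python code can be printed; this is only for inspection and
is not required by the flow.

::

    # Inspect the generated code of the quantized GraphModule to see the
    # quantize/dequantize ops inserted by convert_pt2e.
    print(converted_model.code)
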
3. Lower into Inductor
------------------------

After we get the quantized model, we will further lower it into the Inductor backend.

::

    optimized_model = torch.compile(converted_model)

    # Run the compiled model (the first call triggers compilation)
    optimized_model(*example_inputs)

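If you want a rough latency number, a simple timing loop like the sketch below can be used. Warm-up iterations are run
first so that compilation time is excluded; this is illustrative only and not a rigorous benchmarking methodology.

::

    import time

    # Warm up: the first calls trigger compilation, so exclude them from timing.
    for _ in range(3):
        optimized_model(*example_inputs)

    # Time a few steady-state iterations.
    n_iters = 10
    start = time.time()
    for _ in range(n_iters):
        optimized_model(*example_inputs)
    end = time.time()
    print(f"Average latency for a batch of {traced_bs}: {(end - start) / n_iters * 1000:.2f} ms")
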
Putting all of this code together gives us the toy example code.
Please note that since the Inductor ``freeze`` feature is not turned on by default yet, you need to run your example code with ``TORCHINDUCTOR_FREEZING=1`` set in the environment.

For example:

::

    TORCHINDUCTOR_FREEZING=1 python example_x86inductorquantizer_pytorch_2_1.py

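If you prefer to enable freezing from within Python rather than through the environment variable, the Inductor
configuration exposes a ``freezing`` flag. Note that ``torch._inductor.config`` is a private module that may change
between releases, so treat this as a sketch rather than a stable recipe; set the flag before the compiled model runs.

::

    import torch._inductor.config as inductor_config

    # Equivalent in effect to setting TORCHINDUCTOR_FREEZING=1 in the environment.
    # Private config; subject to change between PyTorch releases.
    inductor_config.freezing = True
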
4. Conclusion
---------------

In this tutorial, we introduced how to use Inductor with an X86 CPU in PyTorch 2 export quantization. Users can learn
how to use ``X86InductorQuantizer`` to quantize a model and lower it into Inductor on X86 CPU devices.