
Export QAT model is not performing as expected when compared to the original model and FX Graph QAT #150746

Closed
Jacobdelgado1002 opened this issue Apr 6, 2025 · 5 comments
Labels
needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user oncall: export oncall: pt2 oncall: quantization Quantization support in PyTorch

@Jacobdelgado1002

Jacobdelgado1002 commented Apr 6, 2025

πŸ› Describe the bug

I'm trying to perform QAT on MobileNetV2 with the goal of converting it to TFLite. However, after training the model, I run a benchmarking script to compare its performance to the original model and see that the performance degrades greatly.

Here are the important code snippets:

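import torch
from torchvision import models
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config

# `dataloader` and `device` are defined elsewhere in my project; train_model is shown below
# Put the model on the same device as the example inputs before exporting
model = models.mobilenet_v2(weights='DEFAULT').to(device)

# Capture the model with torch.export, using one real batch as the example input
example_inputs = (next(iter(dataloader))[0].to(device),)
model = torch.export.export_for_training(model, example_inputs).module()

# Insert fake-quantize ops for QAT using the XNNPACK quantizer
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_qat=True))
model = prepare_qat_pt2e(model, quantizer)

train_model(model)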

I only included what I thought was relevant, since I didn't want to add confusion with all of my helper functions.

import torch
from tqdm import tqdm

# optimizer, criterion, train_loader, val_loader, device, epoch, epochs,
# soft_target_loss_weight, ce_loss_weight and the compute_* helpers
# are defined elsewhere in my project
def train_model(model):
    model = model.to(device)  # move the model once, not on every batch

    for phase in ['train', 'val']:
        is_train = phase == 'train'

        if is_train:
            torch.ao.quantization.move_exported_model_to_train(model)
        else:
            # Switch to evaluation mode to perform inference
            torch.ao.quantization.move_exported_model_to_eval(model)

        data_loader = train_loader if is_train else val_loader

        running_loss = 0.0
        total_samples = 0
        predictions, ground_truths, probabilities = [], [], []

        with tqdm(total=len(data_loader), desc=f"{phase.capitalize()} Epoch {epoch + 1}/{epochs}") as pbar:
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                # Zero gradients only during training
                if is_train:
                    optimizer.zero_grad()

                # Enable gradients only in the training phase
                with torch.set_grad_enabled(is_train):
                    model_logits = model(inputs)

                    soft_loss = compute_distillation_loss(model_logits)
                    label_loss, probs, preds = compute_loss_and_predictions(model_logits, labels, criterion)

                    # Weighted combination of the distillation and cross-entropy losses
                    loss = soft_target_loss_weight * soft_loss + ce_loss_weight * label_loss

                    # Backward pass and optimizer step in the training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Accumulate the loss so the running average below is defined
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                total_samples += batch_size

                # Update progress bar with average loss so far
                pbar.set_postfix(loss=f"{running_loss / total_samples:.4f}")
                pbar.update(1)
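`compute_distillation_loss` is one of the omitted helpers, so its exact form isn't shown. For readers unfamiliar with the weighted loss in `train_model`, here is a minimal sketch of one common formulation: Hinton-style knowledge distillation, which uses the KL divergence between temperature-softened teacher and student distributions, scaled by T². The explicit `teacher_logits` argument, the temperature `T`, the function names, and the default weights are all hypothetical. This is not the author's implementation.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits: torch.Tensor,
                      teacher_logits: torch.Tensor,
                      T: float = 2.0) -> torch.Tensor:
    """Hypothetical soft-target loss: KL(teacher || student) on temperature-softened logits.

    The T**2 factor keeps gradient magnitudes comparable across temperatures.
    """
    # Teacher probabilities and student log-probabilities at temperature T
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    student_log_probs = F.log_softmax(student_logits / T, dim=-1)
    # 'batchmean' divides the summed KL divergence by the batch size
    return F.kl_div(student_log_probs, soft_targets, reduction="batchmean") * T**2


def combined_loss(student_logits, teacher_logits, labels,
                  soft_target_loss_weight=0.25, ce_loss_weight=0.75, T=2.0):
    """Weighted sum mirroring the `loss = ...` line in train_model (weights are illustrative)."""
    soft_loss = distillation_loss(student_logits, teacher_logits, T)
    label_loss = F.cross_entropy(student_logits, labels)
    return soft_target_loss_weight * soft_loss + ce_loss_weight * label_loss
```

When the student exactly matches the teacher, the soft term is zero. Only the cross-entropy term then drives training.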

Actual vs expected behavior:

I would expect the quantized model to perform better than the original model, but it does not.

| Metric | Original | Export QAT (PT2E) |
|---|---|---|
| Model Size (MB) | 9.1899 | 11.1504 |
| Inference Time (sec/sample) | 0.002896 | 0.011141 |
| Throughput (samples/sec) | 345.29 | 89.76 |
| Energy per Sample (Joules) | 0.3436 | 1.350853 |
| Throughput per Watt (samples/sec/W) | 2.91 | 0.74 |
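The benchmarking script itself isn't shown in the issue. As a rough sketch of how per-sample latency and throughput figures like these are commonly measured, the helper below times any callable over a list of batches. It skips a few warm-up runs, uses wall-clock timing with `time.perf_counter`, and reports throughput as the reciprocal of latency. The function name `benchmark` and its parameters are hypothetical. For a model on a GPU, you would also need to call `torch.cuda.synchronize()` before reading the clock, because CUDA kernels run asynchronously.

```python
import time
from typing import Callable, Sequence, Tuple


def benchmark(run: Callable[[object], object],
              batches: Sequence[object],
              batch_size: int,
              warmup: int = 3) -> Tuple[float, float]:
    """Return (seconds per sample, samples per second) for `run` over `batches`.

    Hypothetical helper: `run` could be e.g. `lambda x: model(x)` under torch.inference_mode().
    """
    # Warm-up iterations are excluded from timing (caches, lazy init, autotuning)
    for batch in batches[:warmup]:
        run(batch)

    start = time.perf_counter()
    for batch in batches:
        run(batch)
    elapsed = time.perf_counter() - start

    n_samples = len(batches) * batch_size
    sec_per_sample = elapsed / n_samples
    return sec_per_sample, 1.0 / sec_per_sample
```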

This is even stranger because, if I switch to FX Graph QAT, I get the expected behavior. However, I need to use export (PT2E) quantization because I want to convert my model to TFLite with the ai-edge-torch API.

| Metric | Original | FX Graph QAT |
|---|---|---|
| Model Size (MB) | 9.1899 | 2.3465 |
| Inference Time (sec/sample) | 0.002896 | 0.000250 |
| Throughput (samples/sec) | 345.29 | 4003.28 |
| Energy per Sample (Joules) | 0.3436 | 0.0271 |
| Throughput per Watt (samples/sec/W) | 2.91 | 36.85 |
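The derived rows in both tables follow from the measured ones. Throughput (samples/sec) is the reciprocal of per-sample inference time. Throughput per watt (samples/sec/W, which equals samples/J) is the reciprocal of energy per sample. The quick check below recomputes both from the reported numbers. It reads the FX Graph QAT latency cell as 0.000250 s, which matches 1/4003.28. It also computes the FX QAT size ratio, which should be close to 4×, since int8 weights take one byte versus four for float32.

```python
# Values copied from the two tables: original, export (PT2E) QAT, FX Graph QAT
latency = {"original": 0.002896, "pt2e_qat": 0.011141, "fx_qat": 0.000250}
throughput = {"original": 345.29, "pt2e_qat": 89.76, "fx_qat": 4003.28}
energy = {"original": 0.3436, "pt2e_qat": 1.350853, "fx_qat": 0.0271}
thr_per_watt = {"original": 2.91, "pt2e_qat": 0.74, "fx_qat": 36.85}
size_mb = {"original": 9.1899, "pt2e_qat": 11.1504, "fx_qat": 2.3465}


def rel_err(a: float, b: float) -> float:
    """Relative difference between a recomputed value `a` and a reported value `b`."""
    return abs(a - b) / abs(b)


for name in latency:
    # samples/sec = 1 / (sec/sample)
    print(name, "1/latency:", round(1 / latency[name], 2), "reported:", throughput[name])
    # samples/sec/W = samples/J = 1 / (J/sample)
    print(name, "1/energy:", round(1 / energy[name], 2), "reported:", thr_per_watt[name])

# float32 -> int8 weights should shrink the model by roughly 4x
print("FX QAT size ratio:", round(size_mb["original"] / size_mb["fx_qat"], 2))
```

All derived values agree with the reported ones to within about 0.2%. This means the PT2E slowdown is a genuine measurement of that model, not a bookkeeping error.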

Additionally, when I print the resulting QAT model, I get the following:

GraphModule(
  (features): Module(
    (0): Module(
      (1): Module()
    )
    (1): Module(
      (conv): Module(
        (0): Module(
          (1): Module()
        )
        (2): Module()
      )
    )
    (2): Module(
      (conv): Module(
        (0): Module(
          (1): Module()
        )
        (1): Module(
          (1): Module()
        )
        (3): Module()
      )
    )
    (3): Module(
...

I would expect it to look more like the QAT model produced by FX Graph quantization, which leads me to believe that it is not training correctly. The FX Graph model is shown below:

GraphModule(
  (features): Module(
    (0): Module(
      (0): QuantizedConv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), scale=0.22475136816501617, zero_point=113, padding=(1, 1))
      (2): ReLU6(inplace=True)
    )
    (1): Module(
      (conv): Module(
        (0): Module(
          (0): QuantizedConv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.36381739377975464, zero_point=112, padding=(1, 1), groups=32)
          (2): ReLU6(inplace=True)
        )
        (1): QuantizedConv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), scale=0.5194709300994873, zero_point=139)
      )
    )
...
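One way to see what an exported QAT model actually contains is to inspect its graph rather than its module tree. In a GraphModule produced by torch.export, the computation lives in `model.graph.nodes`. The nested `Module()` entries in the printout are essentially empty containers that mirror the original attribute paths. The helper below counts `call_function` nodes whose target name contains a given substring. The function name `count_ops` is hypothetical. `model.print_readable()` or `model.graph.print_tabular()` (the latter needs `tabulate`) give a full listing.

```python
import torch
import torch.fx


def count_ops(gm: torch.fx.GraphModule, name_fragment: str) -> int:
    """Count call_function nodes whose target, as a string, contains `name_fragment`.

    For ATen-level ops the string form looks like
    'quantized_decomposed.quantize_per_tensor.default', so passing
    'quantize_per_tensor' matches both quantize and dequantize nodes.
    """
    return sum(
        1
        for node in gm.graph.nodes
        if node.op == "call_function" and name_fragment in str(node.target)
    )


# Small self-check on a symbolically traced module (no quantization involved)
class Tiny(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1


traced = torch.fx.symbolic_trace(Tiny())
print(count_ops(traced, "relu"))
```

On a model returned by `convert_pt2e`, `count_ops(model, "quantize_per_tensor")` would report the inserted quantize/dequantize ops. Those ops are the only place quantization shows up, since the printout contains no quantized modules.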

Versions

My system has an AMD Ryzen™ Threadripper™ 7960Xs × 48 and an NVIDIA GeForce RTX 4090.

Here is my virtual env:

absl-py==2.2.1
ai-edge-litert==1.2.0
ai-edge-quantizer==0.1.0
ai-edge-torch==0.4.0
anyio==4.8.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
astunparse==1.6.3
async-lru==2.0.4
attrs==25.3.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2024.12.14
cffi==1.17.1
charset-normalizer==3.4.1
coloredlogs==15.0.1
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.6
decorator==5.1.1
defusedxml==0.7.1
execnet==2.1.1
executing==2.1.0
executorch==0.5.0
expecttest==0.3.0
fastjsonschema==2.21.1
filelock==3.17.0
flatbuffers==25.2.10
fonttools==4.55.8
fqdn==1.5.1
fsspec==2024.12.0
gast==0.6.0
google-pasta==0.2.0
grpcio==1.71.0
h11==0.14.0
h5py==3.13.0
httpcore==1.0.7
httpx==0.28.1
humanfriendly==10.0
hypothesis==6.130.8
idna==3.10
immutabledict==4.2.1
iniconfig==2.1.0
ipykernel==6.29.5
ipython==8.28.0
ipywidgets==8.1.5
isoduration==20.11.0
jax==0.5.3
jaxlib==0.5.3
jedi==0.19.1
Jinja2==3.1.5
joblib==1.4.2
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter==1.1.1
jupyter-console==6.6.3
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kaggle==1.6.17
keras==3.9.1
kiwisolver==1.4.8
libclang==18.1.1
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.0
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.1.2
ml_dtypes==0.5.1
mpmath==1.3.0
namex==0.0.8
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
notebook==7.3.2
notebook_shim==0.2.4
numpy==2.0.0
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
onnx==1.16.1
onnx-graphsurgeon==0.5.7
onnx-tf==1.6.0
onnx2tf==1.27.1
onnxruntime==1.21.0
onnxscript==0.2.3
opt_einsum==3.4.0
optree==0.14.1
overrides==7.7.0
packaging==24.2
pandas==2.2.2
pandocfilters==1.5.1
parameterized==0.9.0
parso==0.8.4
pexpect==4.9.0
pillow==11.1.0
platformdirs==4.3.6
pluggy==1.5.0
prometheus_client==0.21.1
prompt_toolkit==3.0.48
protobuf==3.20.3
psutil==6.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
Pygments==2.19.1
pyparsing==3.2.1
pyRAPL==0.2.3.1
pytest==8.3.5
pytest-xdist==3.6.1
python-dateutil==2.9.0.post0
python-json-logger==3.3.0
python-slugify==8.0.4
pytz==2024.2
PyYAML==6.0.2
pyzmq==26.2.0
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.9.4
rpds-py==0.23.1
ruamel.yaml==0.18.10
ruamel.yaml.clib==0.2.12
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.1
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==75.8.0
six==1.17.0
sng4onnx==1.0.4
sniffio==1.3.1
sortedcontainers==2.4.0
soupsieve==2.6
stack-data==0.6.3
sympy==1.13.1
tabulate==0.9.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tensorflow==2.19.0
termcolor==2.5.0
terminado==0.18.1
text-unidecode==1.3
tf2onnx==1.16.1
tf_keras==2.19.0
tflite==2.18.0
threadpoolctl==3.5.0
tinycss2==1.4.0
torch==2.6.0
torch_xla2==0.0.1.dev202412041639
torchaudio==2.6.0
torchsummary==1.5.1
torchvision==0.21.0
tornado==6.4.1
tqdm==4.67.1
traitlets==5.14.3
triton==3.2.0
types-python-dateutil==2.9.0.20241206
typing_extensions==4.12.2
tzdata==2025.1
uri-template==1.3.0
urllib3==2.3.0
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.1.3
wheel==0.45.1
widgetsnbextension==4.0.13
wrapt==1.17.2

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @msaroufim @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

@jbschlosser
Contributor

Hey @Jacobdelgado1002, is it possible to get a small runnable script that reproduces the perf issue? This will help us investigate whether there is indeed a problem.

@jbschlosser jbschlosser added oncall: quantization Quantization support in PyTorch oncall: export labels Apr 7, 2025
@jbschlosser jbschlosser added the needs reproduction Someone else needs to try reproducing the issue given the instructions. No action needed from user label Apr 7, 2025
@Jacobdelgado1002
Author

Hello @jbschlosser, thank you very much for the quick reply. My code is part of a larger project, so I took the important parts and made the script linked below to reproduce the issue. Since I am using a custom dataset, I replaced the data loaders with random data. Please let me know if you need anything else from me.
https://drive.google.com/file/d/1lClCANRKs16hmkMxL9evoWYuulG79U0H/view?usp=sharing
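The linked script isn't reproduced here. For anyone building a similar repro, a minimal way to swap a custom dataset for random data is a `TensorDataset` of random images and labels. The 224×224 RGB input size matches MobileNetV2's usual ImageNet preprocessing. The function name, sample count, class count, and batch size are hypothetical placeholders.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_random_loader(num_samples: int = 32,
                       num_classes: int = 10,
                       batch_size: int = 16,
                       image_size: int = 224,
                       shuffle: bool = False) -> DataLoader:
    """Hypothetical stand-in for a custom dataset: random images and integer labels."""
    images = torch.randn(num_samples, 3, image_size, image_size)
    labels = torch.randint(0, num_classes, (num_samples,))
    return DataLoader(TensorDataset(images, labels), batch_size=batch_size, shuffle=shuffle)


# Separate loaders mirroring train_loader / val_loader in train_model
train_loader = make_random_loader(shuffle=True)
val_loader = make_random_loader()
```

Random inputs are enough to reproduce speed and size measurements. They obviously say nothing about accuracy.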

@Jacobdelgado1002
Author

Hello @jbschlosser, I was rereading the documentation, and I believe I found my error. I am converting my QAT model into ATen operators (with convert_pt2e) and evaluating the performance of the resulting model. This should not be done, since that model is just a representation of the quantized computation in ATen operators and has not been lowered to a specific device. Is my analysis correct?

@jerryzh168
Contributor

@Jacobdelgado1002 that's correct: the model after convert_pt2e is not the final model for performance measurement. More details are in https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html#convert-the-calibrated-model-to-a-quantized-model

@jerryzh168
Contributor

Closing for now since the issue seems to be resolved; please feel free to reopen if not.
