
torch.export AOTInductor Tutorial for Python runtime (Beta)

Created On: Aug 23, 2024 | Last Updated: Jan 24, 2025 | Last Verified: Nov 05, 2024

Author: Ankith Gunapal, Bin Bao, Angela Yi

Warning

torch._inductor.aoti_compile_and_package and torch._inductor.aoti_load_package are in Beta status and are subject to backwards-compatibility-breaking changes. This tutorial provides an example of how to use these APIs for model deployment with the Python runtime.

It has been shown previously how AOTInductor can perform ahead-of-time compilation of models exported with torch.export, producing an artifact that can run in a non-Python environment. In this tutorial, you will walk through an end-to-end example of using AOTInductor with the Python runtime.


Prerequisites

What you will learn

Model Compilation

We will use the TorchVision pretrained ResNet18 model as an example.

The first step is to export the model to a graph representation using torch.export.export(). To learn more about using this function, you can check out the docs or the tutorial.

Once we have exported the PyTorch model and obtained an ExportedProgram, we can call torch._inductor.aoti_compile_and_package() to compile the program with AOTInductor for a specified device and save the generated contents into a “.pt2” artifact.

Note

This API supports the same options as torch.compile(), such as mode and max_autotune, for those who want to enable CUDA graphs or use Triton-based matrix multiplications and convolutions.

import os
import torch
import torch._inductor
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():
    inductor_configs = {}

    # Enable autotuning of kernel choices (including Triton templates) only when a GPU is available.
    if torch.cuda.is_available():
        device = "cuda"
        inductor_configs["max_autotune"] = True
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # Trace the model into an ExportedProgram.
    exported_program = torch.export.export(
        model,
        example_inputs,
    )
    # Compile with AOTInductor and save the artifact as resnet18.pt2
    # in the current working directory.
    path = torch._inductor.aoti_compile_and_package(
        exported_program,
        package_path=os.path.join(os.getcwd(), "resnet18.pt2"),
        inductor_configs=inductor_configs,
    )
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth

/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:320: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torch/_inductor/select_algorithm.py:4628: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  current_out_size = out_base.storage().size()
Autotune Choices Stats:
{"num_choices": 10, "num_triton_choices": 9, "best_kernel": "convolution", "best_time": 0.09830400347709656, "best_triton_pos": 1, "best_triton_time": 0.37785598635673523, "best_triton_kernel": "triton_convolution2d_8", "best_triton_kernel_desc": "ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=4, num_warps=4"}
AUTOTUNE convolution(2x3x224x224, 64x3x7x7)
strides: [150528, 1, 672, 3], [147, 1, 21, 3]
dtypes: torch.float32, torch.float32
  convolution 0.0983 ms 100.0%
  triton_convolution2d_8 0.3779 ms 26.0% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=4, num_warps=4
  triton_convolution2d_7 0.4014 ms 24.5% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=4, num_warps=4
  triton_convolution2d_0 0.4250 ms 23.1% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_5 0.4516 ms 21.8% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_6 0.4946 ms 19.9% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=3, num_warps=8
  triton_convolution2d_3 0.5038 ms 19.5% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_2 0.5407 ms 18.2% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_1 0.6410 ms 15.3% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_4 0.7035 ms 14.0% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.4810 seconds and 0.0193 seconds precompiling for 10 choices
Autotune Choices Stats:
{"num_choices": 12, "num_triton_choices": 11, "best_kernel": "convolution", "best_time": 0.09417600184679031, "best_triton_pos": 1, "best_triton_time": 0.11366400122642517, "best_triton_kernel": "triton_convolution2d_13", "best_triton_kernel_desc": "ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4"}
AUTOTUNE convolution(2x64x56x56, 64x64x3x3)
strides: [200704, 1, 3584, 64], [576, 1, 192, 64]
dtypes: torch.float32, torch.float32
  convolution 0.0942 ms 100.0%
  triton_convolution2d_13 0.1137 ms 82.9% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_16 0.1178 ms 80.0% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=3, num_warps=8
  triton_convolution2d_12 0.1208 ms 77.9% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_20 0.1229 ms 76.6% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_9 0.1280 ms 73.6% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_17 0.1280 ms 73.6% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=4, num_warps=4
  triton_convolution2d_15 0.1423 ms 66.2% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_10 0.1495 ms 63.0% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_19 0.1526 ms 61.7% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.3546 seconds and 0.0162 seconds precompiling for 12 choices
Autotune Choices Stats:
{"num_choices": 13, "num_triton_choices": 12, "best_kernel": "convolution", "best_time": 0.09728000313043594, "best_triton_pos": 1, "best_triton_time": 0.11366400122642517, "best_triton_kernel": "triton_convolution2d_25", "best_triton_kernel_desc": "ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4"}
AUTOTUNE convolution(2x64x56x56, 64x64x3x3)
strides: [200704, 1, 3584, 64], [576, 1, 192, 64]
dtypes: torch.float32, torch.float32
  convolution 0.0973 ms 100.0%
  triton_convolution2d_25 0.1137 ms 85.6% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_28 0.1178 ms 82.6% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=3, num_warps=8
  triton_convolution2d_24 0.1198 ms 81.2% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_32 0.1229 ms 79.2% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_21 0.1280 ms 76.0% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_29 0.1280 ms 76.0% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=4, num_warps=4
  triton_convolution2d_27 0.1423 ms 68.3% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_22 0.1495 ms 65.1% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_31 0.1526 ms 63.8% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.2759 seconds and 0.0001 seconds precompiling for 13 choices
Autotune Choices Stats:
{"num_choices": 13, "num_triton_choices": 12, "best_kernel": "triton_convolution2d_61", "best_kernel_desc": "ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4", "best_time": 0.07065600156784058, "best_triton_pos": 0}
AUTOTUNE convolution(2x64x56x56, 128x64x3x3)
strides: [200704, 1, 3584, 64], [576, 1, 192, 64]
dtypes: torch.float32, torch.float32
  triton_convolution2d_61 0.0707 ms 100.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  convolution 0.0942 ms 75.0%
  triton_convolution2d_57 0.0963 ms 73.4% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_62 0.1198 ms 59.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_63 0.1413 ms 50.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_68 0.1413 ms 50.0% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_58 0.1423 ms 49.6% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_60 0.1454 ms 48.6% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_59 0.2048 ms 34.5% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_67 0.4065 ms 17.4% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=256, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 34.2879 seconds and 2.1925 seconds precompiling for 13 choices
Autotune Choices Stats:
{"num_choices": 17, "num_triton_choices": 16, "best_kernel": "convolution", "best_time": 0.09318400174379349, "best_triton_pos": 1, "best_triton_time": 0.13516800105571747, "best_triton_kernel": "triton_convolution2d_73", "best_triton_kernel_desc": "ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4"}
AUTOTUNE convolution(2x128x28x28, 128x128x3x3)
strides: [100352, 1, 3584, 128], [1152, 1, 384, 128]
dtypes: torch.float32, torch.float32
  convolution 0.0932 ms 100.0%
  triton_convolution2d_73 0.1352 ms 68.9% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_69 0.1864 ms 50.0% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_74 0.2447 ms 38.1% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_72 0.2693 ms 34.6% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_75 0.2734 ms 34.1% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_80 0.2765 ms 33.7% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_70 0.2857 ms 32.6% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_71 0.4219 ms 22.1% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_79 0.7926 ms 11.8% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=256, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 53.7784 seconds and 0.0003 seconds precompiling for 17 choices
Autotune Choices Stats:
{"num_choices": 13, "num_triton_choices": 12, "best_kernel": "triton_convolution2d_85", "best_kernel_desc": "ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4", "best_time": 0.027648000046610832, "best_triton_pos": 0}
AUTOTUNE convolution(2x64x56x56, 128x64x1x1)
strides: [200704, 1, 3584, 64], [64, 1, 1, 1]
dtypes: torch.float32, torch.float32
  triton_convolution2d_85 0.0276 ms 100.0% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_89 0.0276 ms 100.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_90 0.0287 ms 96.4% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_88 0.0317 ms 87.1% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_91 0.0317 ms 87.1% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_86 0.0338 ms 81.8% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_93 0.0369 ms 75.0% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=4, num_warps=4
  triton_convolution2d_94 0.0369 ms 75.0% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=4, num_warps=4
  triton_convolution2d_87 0.0399 ms 69.2% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
  triton_convolution2d_92 0.0410 ms 67.5% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=3, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 16.7395 seconds and 0.0002 seconds precompiling for 13 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "convolution", "best_time": 0.08089599758386612, "best_triton_pos": 1, "best_triton_time": 0.13414399325847626, "best_triton_kernel": "triton_convolution2d_133", "best_triton_kernel_desc": "ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4"}
AUTOTUNE convolution(2x128x28x28, 256x128x3x3)
strides: [100352, 1, 3584, 128], [1152, 1, 384, 128]
dtypes: torch.float32, torch.float32
  convolution 0.0809 ms 100.0%
  triton_convolution2d_133 0.1341 ms 60.3% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_129 0.2724 ms 29.7% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_135 0.2744 ms 29.5% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_130 0.2765 ms 29.3% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_132 0.2826 ms 28.6% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_134 0.2877 ms 28.1% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_131 0.3656 ms 22.1% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_140 0.7700 ms 10.5% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_139 0.7844 ms 10.3% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=256, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 67.3476 seconds and 0.0002 seconds precompiling for 18 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "convolution", "best_time": 0.08396799862384796, "best_triton_pos": 1, "best_triton_time": 0.26214399933815, "best_triton_kernel": "triton_convolution2d_150", "best_triton_kernel_desc": "ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4"}
AUTOTUNE convolution(2x256x14x14, 256x256x3x3)
strides: [50176, 1, 3584, 256], [2304, 1, 768, 256]
dtypes: torch.float32, torch.float32
  convolution 0.0840 ms 100.0%
  triton_convolution2d_150 0.2621 ms 32.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_148 0.4813 ms 17.4% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_146 0.5294 ms 15.9% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_149 0.5315 ms 15.8% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_152 0.5356 ms 15.7% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_151 0.5530 ms 15.2% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_147 0.5663 ms 14.8% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_157 1.5043 ms 5.6% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_156 1.5534 ms 5.4% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=256, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 67.3523 seconds and 0.0002 seconds precompiling for 18 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_convolution2d_167", "best_kernel_desc": "ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4", "best_time": 0.02969600073993206, "best_triton_pos": 0}
AUTOTUNE convolution(2x128x28x28, 256x128x1x1)
strides: [100352, 1, 3584, 128], [128, 1, 1, 1]
dtypes: torch.float32, torch.float32
  triton_convolution2d_167 0.0297 ms 100.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_164 0.0399 ms 74.4% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_168 0.0410 ms 72.5% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_163 0.0430 ms 69.0% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_166 0.0430 ms 69.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_169 0.0430 ms 69.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_165 0.0461 ms 64.4% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
  convolution 0.0594 ms 50.0%
  triton_convolution2d_174 0.1085 ms 27.4% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_173 0.1106 ms 26.9% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=256, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 83.8629 seconds and 0.0002 seconds precompiling for 18 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "convolution", "best_time": 0.08806400001049042, "best_triton_pos": 1, "best_triton_time": 0.26419198513031006, "best_triton_kernel": "triton_convolution2d_218", "best_triton_kernel_desc": "ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4"}
AUTOTUNE convolution(2x256x14x14, 512x256x3x3)
strides: [50176, 1, 3584, 256], [2304, 1, 768, 256]
dtypes: torch.float32, torch.float32
  convolution 0.0881 ms 100.0%
  triton_convolution2d_218 0.2642 ms 33.3% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_216 0.4710 ms 18.7% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_214 0.5396 ms 16.3% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_220 0.5560 ms 15.8% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_217 0.5581 ms 15.8% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_219 0.5693 ms 15.5% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_215 0.5898 ms 14.9% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_224 1.5012 ms 5.9% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=256, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_225 1.5032 ms 5.9% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 66.8366 seconds and 0.0002 seconds precompiling for 18 choices
Autotune Choices Stats:
{"num_choices": 17, "num_triton_choices": 16, "best_kernel": "convolution", "best_time": 0.09932799637317657, "best_triton_pos": 1, "best_triton_time": 0.5222399830818176, "best_triton_kernel": "triton_convolution2d_235", "best_triton_kernel_desc": "ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4"}
AUTOTUNE convolution(2x512x7x7, 512x512x3x3)
strides: [25088, 1, 3584, 512], [4608, 1, 1536, 512]
dtypes: torch.float32, torch.float32
  convolution 0.0993 ms 100.0%
  triton_convolution2d_235 0.5222 ms 19.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_233 0.6164 ms 16.1% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
  triton_convolution2d_232 0.7506 ms 13.2% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_237 0.9124 ms 10.9% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_231 1.0639 ms 9.3% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_234 1.0660 ms 9.3% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_241 1.0762 ms 9.2% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_236 1.1049 ms 9.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_242 2.9839 ms 3.3% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 48.6255 seconds and 0.0002 seconds precompiling for 17 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_convolution2d_251", "best_kernel_desc": "ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4", "best_time": 0.043007999658584595, "best_triton_pos": 0}
AUTOTUNE convolution(2x256x14x14, 512x256x1x1)
strides: [50176, 1, 3584, 256], [256, 1, 1, 1]
dtypes: torch.float32, torch.float32
  triton_convolution2d_251 0.0430 ms 100.0% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_249 0.0532 ms 80.8% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
  triton_convolution2d_250 0.0645 ms 66.7% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_253 0.0645 ms 66.7% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_247 0.0666 ms 64.6% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_252 0.0676 ms 63.6% ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_248 0.0686 ms 62.7% ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
  convolution 0.0748 ms 57.5%
  triton_convolution2d_257 0.1925 ms 22.3% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=256, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_258 0.1925 ms 22.3% ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 65.7551 seconds and 0.0002 seconds precompiling for 18 choices
Autotune Choices Stats:
{"num_choices": 19, "num_triton_choices": 18, "best_kernel": "triton_mm_299", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, OUT_DTYPE='tl.float32', USE_FAST_ACCUM=False, num_stages=5, num_warps=2", "best_time": 0.027648000046610832, "best_triton_pos": 0}
AUTOTUNE addmm(2x1000, 2x512, 512x1000)
strides: [0, 1], [512, 1], [1, 512]
dtypes: torch.float32, torch.float32, torch.float32
  triton_mm_299 0.0276 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, OUT_DTYPE='tl.float32', USE_FAST_ACCUM=False, num_stages=5, num_warps=2
  triton_mm_300 0.0307 ms 90.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, OUT_DTYPE='tl.float32', USE_FAST_ACCUM=False, num_stages=5, num_warps=2
  triton_mm_297 0.0328 ms 84.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, OUT_DTYPE='tl.float32', USE_FAST_ACCUM=False, num_stages=2, num_warps=2
  addmm 0.0410 ms 67.5%
  triton_mm_310 0.0420 ms 65.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, EVEN_K=True, GROUP_M=8, OUT_DTYPE='tl.float32', USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_309 0.0429 ms 64.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, EVEN_K=True, GROUP_M=8, OUT_DTYPE='tl.float32', USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_296 0.0430 ms 64.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, OUT_DTYPE='tl.float32', USE_FAST_ACCUM=False, num_stages=1, num_warps=2
  triton_mm_298 0.0430 ms 64.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, EVEN_K=True, GROUP_M=8, OUT_DTYPE='tl.float32', USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_303 0.0430 ms 64.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, EVEN_K=True, GROUP_M=8, OUT_DTYPE='tl.float32', USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_302 0.0440 ms 62.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, EVEN_K=True, GROUP_M=8, OUT_DTYPE='tl.float32', USE_FAST_ACCUM=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 9.7501 seconds and 0.0002 seconds precompiling for 19 choices
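
In each AUTOTUNE table above, the percentage next to a candidate appears to be the fastest candidate's time divided by that candidate's time (for example, 0.0993 / 0.5222 ≈ 19.0%). The minimal sketch below parses one table and recomputes that column. It assumes the two-space-indented `name time ms pct% ...` line format seen in these logs. `parse_autotune_table` and `recompute_percentages` are hypothetical helpers, not part of PyTorch. Because the printed times are rounded to four decimals, recomputed values can differ slightly from the reported ones (for example, 89.9% instead of 90.0% for triton_mm_300).

```python
import re
from typing import List, Tuple

# Matches candidate lines such as
# "  triton_mm_299 0.0276 ms 100.0% ACC_TYPE=..." (assumed log format).
# Header lines ("AUTOTUNE ...", "strides: ...") have no leading indent and are skipped.
CANDIDATE_RE = re.compile(r"^\s+(\S+)\s+([0-9.]+)\s+ms\s+([0-9.]+)%")


def parse_autotune_table(text: str) -> List[Tuple[str, float, float]]:
    """Return (kernel_name, time_ms, reported_pct) for every candidate line."""
    rows = []
    for line in text.splitlines():
        match = CANDIDATE_RE.match(line)
        if match:
            name, time_ms, pct = match.groups()
            rows.append((name, float(time_ms), float(pct)))
    return rows


def recompute_percentages(rows: List[Tuple[str, float, float]]) -> List[Tuple[str, float]]:
    """Relative speed of each candidate versus the fastest one, in percent."""
    best = min(time_ms for _, time_ms, _ in rows)
    return [(name, round(100.0 * best / time_ms, 1)) for name, time_ms, _ in rows]


if __name__ == "__main__":
    log = """AUTOTUNE addmm(2x1000, 2x512, 512x1000)
  triton_mm_299 0.0276 ms 100.0% ACC_TYPE='tl.float32', num_warps=2
  triton_mm_300 0.0307 ms 90.0% ACC_TYPE='tl.float32', num_warps=2
  addmm 0.0410 ms 67.5%
"""
    rows = parse_autotune_table(log)
    for (name, _, reported), (_, recomputed) in zip(rows, recompute_percentages(rows)):
        print(f"{name}: reported {reported}%, recomputed {recomputed}%")
```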

The result of aoti_compile_and_package() is an artifact, “resnet18.pt2”, which can be loaded and executed in both Python and C++.

The artifact itself contains a set of AOTInductor-generated files, such as a generated C++ runner file, a shared library compiled from that C++ file, and, when optimizing for CUDA, CUDA binary files (also known as cubin files).

Structurally, the artifact is a .zip file with the following specification:

We can use the following command to inspect the artifact contents:

$ unzip -l resnet18.pt2
Archive:  resnet18.pt2
  Length      Date    Time    Name
---------  ---------- -----   ----
        1  01-08-2025 16:40   version
        3  01-08-2025 16:40   archive_format
    10088  01-08-2025 16:40   data/aotinductor/model/cagzt6akdaczvxwtbvqe34otfe5jlorktbqlojbzqjqvbfsjlge4.cubin
    17160  01-08-2025 16:40   data/aotinductor/model/c6oytfjmt5w4c7onvtm6fray7clirxt7q5xjbwx3hdydclmwoujz.cubin
    16616  01-08-2025 16:40   data/aotinductor/model/c7ydp7nocyz323hij4tmlf2kcedmwlyg6r57gaqzcsy3huneamu6.cubin
    17776  01-08-2025 16:40   data/aotinductor/model/cyqdf46ordevqhiddvpdpp3uzwatfbzdpl3auj2nx23uxvplnne2.cubin
    10856  01-08-2025 16:40   data/aotinductor/model/cpzfebfgrusqslui7fxsuoo4tvwulmrxirc5tmrpa4mvrbdno7kn.cubin
    14608  01-08-2025 16:40   data/aotinductor/model/c5ukeoz5wmaszd7vczdz2qhtt6n7tdbl3b6wuy4rb2se24fjwfoy.cubin
    11376  01-08-2025 16:40   data/aotinductor/model/csu3nstcp56tsjfycygaqsewpu64l5s6zavvz7537cm4s4cv2k3r.cubin
    10984  01-08-2025 16:40   data/aotinductor/model/cp76lez4glmgq7gedf2u25zvvv6rksv5lav4q22dibd2zicbgwj3.cubin
    14736  01-08-2025 16:40   data/aotinductor/model/c2bb5p6tnwz4elgujqelsrp3unvkgsyiv7xqxmpvuxcm4jfl7pc2.cubin
    11376  01-08-2025 16:40   data/aotinductor/model/c6eopmb2b4ngodwsayae4r5q6ni3jlfogfbdk3ypg56tgpzhubfy.cubin
    11624  01-08-2025 16:40   data/aotinductor/model/chmwe6lvoekzfowdbiizitm3haiiuad5kdm6sd2m6mv6dkn2zk32.cubin
    15632  01-08-2025 16:40   data/aotinductor/model/c3jop5g344hj3ztsu4qm6ibxyaaerlhkzh2e6emak23rxfje6jam.cubin
    25472  01-08-2025 16:40   data/aotinductor/model/chaiixybeiuuitm2nmqnxzijzwgnn2n7uuss4qmsupgblfh3h5hk.cubin
   139389  01-08-2025 16:40   data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.cpp
       27  01-08-2025 16:40   data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t_metadata.json
 47195424  01-08-2025 16:40   data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.so
---------                     -------
 47523148                     18 files
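
Because the .pt2 artifact is an ordinary .zip archive, a similar inspection can also be done from Python with the standard-library zipfile module. The sketch below counts entries by file extension so that the generated C++ source, shared library, and cubin files are easy to spot. `summarize_pt2` is a hypothetical helper name, and the path "resnet18.pt2" assumes the artifact from the compile step is in the working directory. For the archive listed above, it should report 13 .cubin entries, one each of .cpp, .so and .json, and two extensionless entries (version and archive_format).

```python
import os
import zipfile
from collections import Counter


def summarize_pt2(path: str) -> Counter:
    """Count the file entries of a .pt2 (zip) archive by file extension."""
    counts = Counter()
    with zipfile.ZipFile(path) as archive:
        for info in archive.infolist():
            if info.is_dir():
                continue
            # Entries such as "version" have no extension; label them "(none)".
            ext = os.path.splitext(info.filename)[1] or "(none)"
            counts[ext] += 1
    return counts


if __name__ == "__main__":
    pt2_path = "resnet18.pt2"  # assumed location of the compiled artifact
    if os.path.exists(pt2_path):
        for ext, n in sorted(summarize_pt2(pt2_path).items()):
            print(f"{ext}: {n}")
```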

Model Inference in Python#

To load and run the artifact in Python, we can use torch._inductor.aoti_load_package().

import os
import torch
import torch._inductor

# Run on the same device type the artifact was compiled for
device = "cuda" if torch.cuda.is_available() else "cpu"

model_path = os.path.join(os.getcwd(), "resnet18.pt2")

compiled_model = torch._inductor.aoti_load_package(model_path)
example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

with torch.inference_mode():
    output = compiled_model(example_inputs)

When to use AOTInductor with a Python Runtime#

There are two main reasons to use AOTInductor with a Python runtime:

  • torch._inductor.aoti_compile_and_package generates a single serialized artifact. This is useful for versioning models across deployments and tracking model performance over time.

  • Because torch.compile() is a JIT compiler, there is a warmup cost associated with the first compilation, so your deployment needs to account for the compilation time taken by the first inference. With AOTInductor, compilation is done ahead of time using torch.export.export and torch._inductor.aoti_compile_and_package. At deployment time, after the model is loaded, running inference incurs no additional compilation cost.
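
Since the output of the compile step is a single file, one simple way to track versions is to record a content hash of the artifact alongside each deployment. The sketch below computes a SHA-256 digest with the standard-library hashlib module, reading the file in chunks. `fingerprint_artifact` is a hypothetical helper, and storing the digest in your own deployment metadata is an assumption for illustration, not something AOTInductor does for you.

```python
import hashlib


def fingerprint_artifact(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the SHA-256 hex digest of a file, read in chunks (1 MiB by default)."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Reading in chunks keeps memory use flat even for large artifacts.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For example, `fingerprint_artifact("resnet18.pt2")` returns a 64-character hex string that changes whenever the artifact is recompiled with different contents.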

The section below shows the speedup AOTInductor achieves for the first inference.

We define a utility function, timed, to measure the time taken for inference.

import time

import torch


def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in milliseconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA-enabled devices.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.perf_counter()

    result = fn()
    if torch.cuda.is_available():
        end.record()
        # Wait for queued GPU work to finish before reading the events
        torch.cuda.synchronize()
    else:
        end = time.perf_counter()

    # Express the measured interval in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)  # already in milliseconds
    else:
        duration = (end - start) * 1000  # seconds -> milliseconds

    return result, duration
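
The timed helper measures a single call, which is what the first-inference comparison needs. To see how much of that first call is warmup, a variant like the sketch below times the first call separately from the median of the calls that follow. It uses only time.perf_counter, so for CUDA work the callable itself must synchronize before returning (an assumption left to the caller). `first_and_steady_ms` is a hypothetical helper, not part of this tutorial's code.

```python
import statistics
import time
from typing import Callable, Tuple


def first_and_steady_ms(fn: Callable[[], object], iters: int = 10) -> Tuple[float, float]:
    """Return (first-call latency, median latency of the next `iters` calls), in ms.

    Wall-clock timing only: for asynchronous backends such as CUDA, `fn` must
    synchronize before returning for the numbers to be meaningful.
    """
    # Time the first call on its own: this is where JIT warmup would show up.
    start = time.perf_counter()
    fn()
    first_ms = (time.perf_counter() - start) * 1000

    # Then collect steady-state samples and summarize them with the median.
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000)
    return first_ms, statistics.median(samples)
```

A large gap between the two numbers for a torch.compile model, and a small gap for the loaded AOTInductor artifact, would be consistent with the warmup argument above.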

Let's measure the time for the first inference using AOTInductor.

torch._dynamo.reset()

model = torch._inductor.aoti_load_package(model_path)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")
Time taken for first inference for AOTInductor is 3.75 ms

Let's measure the time for the first inference using torch.compile.

torch._dynamo.reset()

model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()

model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)

with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
    print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")
Time taken for first inference for torch.compile is 4372.93 ms

We see a drastic speedup in first-inference time with AOTInductor compared to torch.compile: 3.75 ms versus 4372.93 ms in this run, roughly 1,166 times faster.

Conclusion#

In this tutorial, we have learned how to use AOTInductor for a Python runtime by compiling and loading a pretrained ResNet18 model. This process demonstrates the practical application of generating a compiled artifact and running it within a Python environment. We also looked at the advantage of using AOTInductor in model deployments, in particular the speedup in first-inference time.

Total running time of the script: (9 minutes 8.232 seconds)