Tile Language (tile-lang) is a concise domain-specific language designed to streamline the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention). By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM, tile-lang allows developers to focus on productivity without sacrificing the low-level optimizations necessary for state-of-the-art performance.
- 09/29/2025 🎉: Thrilled to announce that AscendC and AscendNPU IR backends targeting Huawei Ascend chips are now supported! Check out the preview here: 🔗 link. This includes implementations across two branches: ascendc_pto and ascendnpu_ir. Feel free to explore and share your feedback!
- 07/04/2025 🚀: Introduced T.gemm_sp for 2:4 sparse tensor core support; see Pull Request #526 for details.
- 06/05/2025 ✨: Added NVRTC Backend to significantly reduce compilation time for cute templates!
- 04/14/2025 🚀: Added high-performance FlashMLA implementation for AMD MI300X, achieving performance parity with hand-optimized assembly kernels of Aiter! See example_mla_amd for details.
- 03/03/2025 🚀: Added high-performance MLA Decoding support using only 80 lines of Python code, achieving performance on par with FlashMLA on H100 (see example_mla_decode.py)! We also provide documentation explaining how TileLang achieves this.
- 02/15/2025 ✨: Added WebGPU Codegen support, see Pull Request #86!
- 02/12/2025 ✨: Excited to announce the release of v0.1.0!
- 02/10/2025 🚀: Added debug tools for TileLang: T.print for printing variables/buffers (docs) and a memory layout plotter (examples/plot_layout).
- 01/20/2025 ✨: We are excited to announce that tile-lang, a DSL for high-performance AI workloads, is now open source and available to the public!
Although tile-lang aims to be portable across a range of devices, it has been specifically tested and validated on the following hardware. For NVIDIA GPUs, this includes the H100 (with Auto TMA/WGMMA support), A100, V100, RTX 4090, RTX 3090, and RTX A6000. For AMD GPUs, it includes the MI250 (with Auto MatrixCore support) and the MI300X (with Async Copy support).
tile-lang provides the building blocks to implement a wide variety of operators. Some examples include:
- Matrix Multiplication
- Dequantization GEMM
- Flash Attention
- Flash Linear Attention
- Flash MLA Decoding
- Native Sparse Attention
Within the examples directory, you will also find additional complex kernels, such as convolutions and forward/backward passes for FlashAttention. More operators will be added over time.
TileLang achieves exceptional performance across a variety of computational patterns. Comprehensive benchmark scripts and settings are available at tilelang-benchmark. Below are selected results showcasing its capabilities:
- MLA Decoding Performance on H100
- Flash Attention Performance on H100
- Matmul Performance on GPUs (RTX 4090, A100, H100, MI300X)
- Dequantize Matmul Performance on A100
The quickest way to get started is to install the latest release from PyPI:
pip install tilelang
Alternatively, you can install directly from the GitHub repository:
pip install git+https://github.com/tile-ai/tilelang
Or install locally:
# install required system dependencies
sudo apt-get update
sudo apt-get install -y python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
pip install -e . -v # remove -e option if you don't want to install in editable mode, -v for verbose output
We currently provide three ways to install tile-lang from source:
- Install from Source (using your own TVM installation)
- Install from Source (using the bundled TVM submodule)
- Install Using the Provided Script
For users who want access to the latest features and improvements before official releases, we provide nightly builds of tile-lang.
pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/
# or pip install tilelang --find-links https://tile-ai.github.io/whl/nightly/cu121/
Note: Nightly builds contain the most recent code changes but may be less stable than official releases. They're ideal for testing new features or if you need a specific bugfix that hasn't been released yet.
In this section, you'll learn how to write and execute a straightforward GEMM (matrix multiplication) kernel using tile-lang, followed by techniques for layout optimizations, pipelining, and L2-cache–friendly swizzling.
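Before the annotated kernel, it may help to see the blocking scheme on its own. The following is a minimal plain-Python sketch, not tile-lang code: it assumes nested lists as matrices, and the helper names `ceildiv` and `tiled_matmul_relu` are illustrative only. It mirrors the same steps in order: one loop iteration per output tile, a cleared accumulator, a reduction over K in `block_K`-sized slices, a ReLU epilogue, and a write-back.

```python
def ceildiv(a, b):
    """Integer ceiling division, used to count tiles along a dimension."""
    return -(-a // b)


def tiled_matmul_relu(A, B, block_M, block_N, block_K):
    """Blocked C = relu(A @ B) on nested lists (plain-Python reference)."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    # One (by, bx) pair per output tile, like a 2-D grid of thread blocks.
    for by in range(ceildiv(M, block_M)):
        for bx in range(ceildiv(N, block_N)):
            rows = range(by * block_M, min((by + 1) * block_M, M))
            cols = range(bx * block_N, min((bx + 1) * block_N, N))
            # Per-tile accumulator, cleared before the reduction loop.
            acc = {(i, j): 0.0 for i in rows for j in cols}
            # Walk the K dimension one block_K-wide slice at a time.
            for ko in range(ceildiv(K, block_K)):
                ks = range(ko * block_K, min((ko + 1) * block_K, K))
                for i in rows:
                    for j in cols:
                        acc[i, j] += sum(A[i][k] * B[k][j] for k in ks)
            # Epilogue: apply ReLU, then write the tile back to C.
            for (i, j), v in acc.items():
                C[i][j] = max(v, 0.0)
    return C


A = [[1.0, -2.0, 3.0], [0.5, 4.0, -1.0]]
B = [[1.0, 0.0], [0.0, 1.0], [-1.0, 2.0]]
# K = 3 is not a multiple of block_K = 2, so the last K slice is partial.
C = tiled_matmul_relu(A, B, block_M=2, block_N=1, block_K=2)
print(C)
```

The sketch clamps partial tiles with `min(...)`. How a real kernel handles shapes that are not multiples of the block sizes depends on the compiler and is not covered here.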
Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware.
import tilelang
import tilelang.language as T
# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def matmul_relu_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Enable rasterization for better L2 cache locality (Optional)
            # T.use_swizzle(panel_size=10, enable=True)
            # Clear local accumulation
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                # T.copy is syntactic sugar for a parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)
                # Copy tile of B
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                T.gemm(A_shared, B_shared, C_local)
            
            # relu
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)
            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])
    return matmul_relu_kernel
M = 1024  # M = T.symbolic("m") if you want to use dynamic shape
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# Run the compiled kernel directly; the result is written into c
matmul_relu_kernel(a, b, c)
print(c)
# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = matmul_relu_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile kernel latency
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
In addition to GEMM, we provide a variety of examples to showcase the versatility and power of TileLang, including:
- Dequantize GEMM: Achieve high-performance dequantization through fine-grained control over per-thread operations. Many of these features are now default behaviors in BitBLAS, which uses magic layout transformations and intrinsics to accelerate dequantize GEMM.
- FlashAttention: Enable cross-operator fusion with simple and intuitive syntax. An auto-tuning example is also provided.
- LinearAttention: Examples include RetNet and Mamba implementations.
- Convolution: Implementations of Convolution with IM2Col.
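To make the im2col idea from the last item concrete, here is a minimal pure-Python sketch. It is not taken from the tile-lang examples: it assumes a single channel, stride 1, and no padding, and the names `im2col` and `conv2d_im2col` are illustrative. Each sliding window is unfolded into a row, so the convolution reduces to a matrix product, which is what lets a GEMM kernel do the heavy lifting.

```python
def im2col(x, kh, kw):
    """Unfold a 2-D image (list of rows) into patch rows; stride 1, no padding."""
    H, W = len(x), len(x[0])
    out_h, out_w = H - kh + 1, W - kw + 1
    cols = []
    for i in range(out_h):
        for j in range(out_w):
            # Flatten the kh x kw window whose top-left corner is (i, j).
            cols.append([x[i + di][j + dj] for di in range(kh) for dj in range(kw)])
    return cols, out_h, out_w


def conv2d_im2col(x, w):
    """Single-channel cross-correlation computed as a matrix-vector product."""
    kh, kw = len(w), len(w[0])
    cols, out_h, out_w = im2col(x, kh, kw)
    w_flat = [v for row in w for v in row]
    # Every output element is one dot product: patch row times flattened filter.
    flat = [sum(a * b for a, b in zip(patch, w_flat)) for patch in cols]
    # Fold the flat result back into an out_h x out_w image.
    return [flat[r * out_w:(r + 1) * out_w] for r in range(out_h)]


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
w = [[1, 0], [0, -1]]
y = conv2d_im2col(x, w)
print(y)
```

With several output filters, `w_flat` becomes a matrix with one column per filter and the same unfolding yields a full GEMM.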
Check our tilelang v0.2.0 release plan for upcoming features.
TileLang is now used in the BitBLAS and AttentionEngine projects.
Join our Discord community for discussions, support, and collaboration!
We would like to express our gratitude to the TVM community for their invaluable contributions. The initial version of this project was mainly developed by LeiWang1999, chengyupku and nox-410 with supervision from Prof. Zhi Yang at Peking University. Part of this work was carried out during an internship at Microsoft Research, where Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang offered valuable advice and support. We deeply appreciate their mentorship and contributions.