PSAL-POSTECH/PyTorchSim

PyTorchSim: A Comprehensive, Fast, and Accurate NPU Simulation Framework

PyTorchSim is a comprehensive, high-speed, cycle-accurate NPU simulation framework.

  • We define a RISC-V-based NPU architecture and implement a PyTorch compiler backend to run inference and training for PyTorch models.
  • It achieves high speed and accuracy through our novel Tile-Level Simulation (TLS) with a compiler-generated Tile-Operation Graph (TOG), exploiting deterministic tile compute latency.
  • It provides a generic and extensible NPU architecture based on the RISC-V vector extension.
  • The functional simulator supports code correctness validation and data-dependent timing simulation.

For more details, please refer to our paper!

Navigation

Overview | Model Zoo | Getting Started

PyTorchSim Framework Overview

PyTorchSim consists of two main components:

  • Compiler: Integrates with the PyTorch 2 compiler stack and generates NPU machine code and TOG for existing PyTorch models.
  • TOGSim: Executes TOG for high-speed simulation and accurately models shared resources (DRAM, NoC) through integrated cycle-accurate simulators (BookSim and Ramulator2).

PyTorchSim supports:

Model Zoo

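| Model | Source | Status | Note |
|---|---|---|---|
| ResNet-18 | | | channel last format |
| ResNet-50 | | | channel last format |
| BERT | | | |
| GPT-2 | | | |
| ViT | | | |
| Mistral | | | |
| Diffusion | 🤗 | | |
| Llama-4 | 🤗 | | Under Development |
| DeepSeek v1 | 🤗 | | Under Development |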

Supported Operations

  • GEMM
  • Batched GEMM
  • Convolution
  • Elementwise
  • Reduction
  • Batchnorm
  • Layernorm
  • Softmax
  • Transpose
  • View
  • Activation
  • Pooling

Getting Started

Quick start with pre-built Docker image

To download the latest Docker image and set up the environment, use the following commands:

# Run the Docker container
docker run -it --ipc=host --name torchsim -w /workspace/PyTorchSim ghcr.io/psal-postech/torchsim-ci:latest bash

Run Examples

The tests directory contains several example AI workloads.

python tests/test_matmul.py 

The result is stored in TORCHSIM_DUMP_PATH/hash/backendsim_result/. The log file contains detailed core, memory, and interconnect statistics.

Run Your Own Model on PyTorchSim

You can run your own PyTorch model on PyTorchSim by setting up a custom NPU device.
This method also applies when you want to simulate models beyond the provided examples.

import torch
from torchvision.models import resnet18
from Scheduler.scheduler import ExecutionEngine

# Declare a custom NPU device
device = ExecutionEngine.setup_device().custom_device()

# Declare your own model (e.g., resnet18 from torchvision)
model = resnet18().eval()

# Example input tensor; replace it with your own data
x = torch.randn(1, 3, 224, 224)

# Move the model and input tensor to the custom device
model.to(device)
x = x.to(device)  # Tensor.to() is not in-place, so keep the returned tensor

# Compile and run the model with PyTorchSim
compiled_model = torch.compile(dynamic=False)(model)
y = compiled_model(x)

model is your PyTorch model to be simulated, and x is the input tensor. PyTorchSim automatically generates a Tile-Operation Graph (TOG), and runs it through the TOGSim backend.

Result

Running log in CLI

Wrapper Codegen Path = /tmp/torchinductor_root/yd/cyda7nhzv5mtakfhfcxtmmhtsv6kg7sza4k6wpkdgk7oxbpvqnlz.py
[Gem5Simulator] cmd>  /workspace/gem5/build/RISCV/gem5.opt -r --stdout-file=sto.log -d /tmp/torchinductor/tmp/fy6nnyudtno/m5out /root/workspace/PyTorchSim/gem5_script/script_systolic.py -c /tmp/torchinductor/tmp/fy6nnyudtno/cycle_bin --vlane 128
[Gem5Simulator] Simulation is still running... 
[SpikeSimulator] cmd>  spike --isa rv64gcv --varch=vlen:256,elen:64 --vectorlane-size=128 -m0x80000000:0x1900000000,0x2000000000:0x1000000 --scratchpad-base-paddr=137438953472 --scratchpad-base-vaddr=3489660928 --scratchpad-size=131072  --kernel-addr=0000000000010400:10846 --base-path=/tmp/torchinductor/tmp/fy6nnyudtno/runtime_0001 /workspace/riscv-pk/build/pk /tmp/torchinductor/tmp/fy6nnyudtno/validation_binary /tmp/torchinductor/tmp/fy6nnyudtno/runtime_0001/arg0_1/0.raw /tmp/torchinductor/tmp/fy6nnyudtno/runtime_0001/arg1_1/0.raw /tmp/torchinductor/tmp/fy6nnyudtno/runtime_0001/buf0/0.raw
[BackendSimulator] cmd>  /root/workspace/PyTorchSim/PyTorchSimBackend/build/bin/Simulator --config /root/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json --models_list /tmp/torchinductor/tmp/fy6nnyudtno/tile_graph.onnx --attributes_list /tmp/torchinductor/tmp/fy6nnyudtno/runtime_0001/attribute/0
[BackendSimulator] Simulation is still running..  
[BackendSimulator] Simulation of "/tmp/torchinductor/tmp/fy6nnyudtno/tile_graph.onnx" is stored to "/tmp/torchinductor/tmp/fy6nnyudtno/backendsim_result/0"
----------------------------
|Matmul Forward Test Passed|
----------------------------

Simulation consists of three steps:

  1. Gem5Simulator obtains the compute latency for the TOG.
  2. SpikeSimulator verifies the output code.
  3. BackendSimulator simulates the NPU architecture.

To turn off SpikeSimulator for faster simulation, set the following environment variable.

export TORCHSIM_VALIDATION_MODE=False

The log contains memory and core statistics.

[info] HBM2-CH_0: avg BW utilization 37% (255 reads, 128 writes)
[info] Row hits: 359, Row misses: 26, Row conflicts: 0
[info] ========= Core stat =========
[info] Core [0] : Systolic array [0] Utilization(%) 0.00, active cycle 0, idle cycle 1014
[info] Core [0] : Systolic array [1] Utilization(%) 12.62, active cycle 128, idle cycle 886
[info] Core [0] : TMA active cycle 3 TMA idle cycle 1011 DRAM BW 182.000 GB/s (6144)
[info] Core [0] : Vector Unit Utilization(%) 4.34, active cycle 44, idle_cycle 0
[info] Core [0] : Numa hit count : 0, Numa miss count : 0
[info] Core [0] : Total cycle 1014
[info] Total execution cycle: 1014
[info] Simulation time: 0.039296 seconds
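
The utilization figures follow directly from the cycle counts: in the sample above, 128 active cycles out of 1014 total cycles gives 12.62%. Below is a minimal sketch for extracting these numbers from a log, assuming the line formats shown above stay stable. The function name parse_core_stats and the regular expressions are illustrative and not part of PyTorchSim.

```python
import re

SAMPLE_LOG = """\
[info] Core [0] : Systolic array [0] Utilization(%) 0.00, active cycle 0, idle cycle 1014
[info] Core [0] : Systolic array [1] Utilization(%) 12.62, active cycle 128, idle cycle 886
[info] Core [0] : Total cycle 1014
[info] Total execution cycle: 1014
"""

# Patterns derived from the sample log lines above (an assumption about the format).
SA_RE = re.compile(
    r"Core \[(\d+)\] : Systolic array \[(\d+)\] Utilization\(%\) ([\d.]+), "
    r"active cycle (\d+), idle cycle (\d+)"
)
TOTAL_RE = re.compile(r"Total execution cycle: (\d+)")


def parse_core_stats(log_text):
    """Return (list of systolic array stats, total execution cycles or None)."""
    arrays = []
    for core, sa, util, active, idle in SA_RE.findall(log_text):
        arrays.append({
            "core": int(core),
            "array": int(sa),
            "utilization": float(util),
            "active": int(active),
            "idle": int(idle),
        })
    match = TOTAL_RE.search(log_text)
    total = int(match.group(1)) if match else None
    return arrays, total


arrays, total_cycles = parse_core_stats(SAMPLE_LOG)
for entry in arrays:
    # Recompute utilization from the raw cycle counts as a sanity check.
    recomputed = 100.0 * entry["active"] / (entry["active"] + entry["idle"])
    print(entry["core"], entry["array"], entry["utilization"], round(recomputed, 2))
print("total cycles:", total_cycles)
```

The same approach extends to the TMA and vector unit lines by adding one pattern per line type.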

The log is written under TORCHSIM_DUMP_PATH, which you can set as shown below.

export TORCHSIM_DUMP_PATH=/tmp/torchinductor # output file dump path
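
The sample run above writes its result into a backendsim_result directory below the dump path. The following is a small helper for locating such files, sketched under the assumption that every result lives in a directory with that name somewhere under TORCHSIM_DUMP_PATH. The name find_result_files is hypothetical, and the /tmp/torchinductor fallback only mirrors the example export above; it is not a documented default.

```python
import os
from pathlib import Path


def find_result_files(dump_path=None):
    """Return files inside any backendsim_result directory, newest first."""
    # Fallback path mirrors the example export; treat it as an assumption.
    root = Path(dump_path or os.environ.get("TORCHSIM_DUMP_PATH", "/tmp/torchinductor"))
    if not root.is_dir():
        return []
    files = [
        path
        for result_dir in root.rglob("backendsim_result")
        if result_dir.is_dir()
        for path in result_dir.iterdir()
        if path.is_file()
    ]
    # Sort by modification time so the most recent simulation comes first.
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)


if __name__ == "__main__":
    for path in find_result_files()[:5]:
        print(path)
```

The recursive search is used so the helper does not depend on the exact number of directory levels between the dump path and the per-run hash directory.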

Training

backward() automatically generates TOG and executes simulation for backward propagation. If you want to simulate optimizers on NPU units, you can compile the optimizer’s step function.

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
compiled_step = torch.compile(dynamic=False)(optimizer.step)

optimizer.zero_grad()
loss.backward()
compiled_step()

tests/test_mlp.py provides an example of MLP training.

Multi-tenancy

The load generator supports multi-tenancy experiments. To try it, run tests/test_scheduler.py:

python tests/test_scheduler.py

Below is example multi-tenancy code, where target_model1 and target_model2 are your own PyTorch models. You can set the arrival time and the request queue index of each request. Request queues are used for scheduling, and you can set the number of queues for each core in the TOGSim configuration.

# Init scheduler
scheduler = Scheduler(num_request_queue=2, engine_select=Scheduler.FIFO_ENGINE, backend_config=config)
# Register compiled model
opt_model1 = torch.compile(target_model1.to(device=scheduler.execution_engine.module.custom_device(), memory_format=torch.channels_last))
opt_model2 = torch.compile(target_model2.to(device=scheduler.execution_engine.module.custom_device()))
SchedulerDNNModel.register_model("resnet18", opt_model1)
SchedulerDNNModel.register_model("bert", opt_model2)

# Init input data
model_input1 = torch.randn(1, 3, 224, 224)
model_input2 = torch.randn(128, 768)

# Init request
new_request1 = Request("resnet18", [model_input1], [], request_queue_idx=0)
new_request2 = Request("bert", [model_input2], [], request_queue_idx=1)
new_request3 = Request("resnet18", [model_input1], [], request_queue_idx=0)
new_request4 = Request("bert", [model_input2], [], request_queue_idx=1)

# Add request to scheduler
scheduler.add_request(new_request1, request_time=0)
scheduler.add_request(new_request2, request_time=0)
scheduler.add_request(new_request3, request_time=0)
scheduler.add_request(new_request4, request_time=0)

# Run scheduler
while not scheduler.is_finished():
    scheduler.schedule()

Compiler Optimizations

The PyTorchSim compiler supports the following fusions:

  • GEMM prologue fusion
  • GEMM epilogue fusion
  • GEMM reduction fusion
  • CONV epilogue fusion

Depending on the tensor shape, the compiler uses a different convolution template:

  • Single batch optimization
  • Multi-channel optimization

Mapping

PyTorchSim provides three mapping strategies.

Heuristic-based mapping

By default, we adopt a modified version of GEMMINI's heuristic-based mapping, which maximizes scratchpad memory utilization.

Auto-tuning

The heuristic method is not optimal in some cases. PyTorchSim provides auto-tuning to find the best mapping for GEMM, CONV, and vector operations. It reduces the search space by sorting candidates by scratchpad memory utilization and picking the top-k candidates. The search parameters are tile shape and vector lane stride.

export AUTOTUNE=True
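
To illustrate the pruning idea described above, here is a rough sketch of ranking GEMM tile shapes by scratchpad utilization and keeping the top k. This is not PyTorchSim's actual auto-tuner: the function name topk_tile_candidates, the candidate tile sizes, the 8 MiB capacity, and the footprint model (one tile of each input operand plus one output tile resident at once) are all assumptions made for illustration.

```python
from itertools import product


def topk_tile_candidates(M, K, N, spad_bytes, precision=4, k=3,
                         tile_options=(128, 256, 512, 1024)):
    """Rank GEMM tile shapes by scratchpad utilization and keep the top k.

    Footprint model (an assumption): an M x K input tile, a K x N input tile,
    and an M x N output tile must fit in the scratchpad at the same time.
    """
    candidates = []
    for tm, tk, tn in product(tile_options, repeat=3):
        if tm > M or tk > K or tn > N:
            continue  # tile is larger than the problem dimension
        footprint = (tm * tk + tk * tn + tm * tn) * precision
        if footprint > spad_bytes:
            continue  # tile set does not fit in the scratchpad
        candidates.append((footprint / spad_bytes, (tm, tk, tn)))
    # Highest utilization first; only these k shapes would then be simulated.
    candidates.sort(key=lambda c: c[0], reverse=True)
    return candidates[:k]


top = topk_tile_candidates(512, 2048, 2048, spad_bytes=8 * 2**20)
for util, (tm, tk, tn) in top:
    print(f"TILE_M={tm} TILE_K={tk} TILE_N={tn} utilization={util:.2f}")
```

Pruning before simulation keeps the expensive step (running each candidate) limited to k runs instead of the full cross product of tile sizes.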

Manual setting

Users can apply a third-party mapping (e.g., from Timeloop) by setting the cheatsheet path and writing their own mapping.

export CONFIG_GEMM_CHEATSHEET_PATH=validation/gemm_tpuv3_cheatsheet.json

Key: "M_K_N" for GEMM

{
    "512_2048_8192" : {
        "TILE_M" : 512,
        "TILE_K" : 512,
        "TILE_N" : 1024
    },
    "512_2048_2048" : {
        "TILE_M" : 512,
        "TILE_K" : 512,
        "TILE_N" : 1024
    },
    "2048_2048_512" : {
        "TILE_M" : 1024,
        "TILE_K" : 512,
        "TILE_N" : 512
    }
}
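
When a mapping comes from an external tool, it can be convenient to generate the cheatsheet rather than write it by hand. The sketch below builds the same structure as the JSON above, using the "M_K_N" key format; the helper names gemm_key and make_cheatsheet are hypothetical and not part of PyTorchSim.

```python
import json


def gemm_key(M, K, N):
    """Build the "M_K_N" cheatsheet key for a GEMM shape."""
    return f"{M}_{K}_{N}"


def make_cheatsheet(mappings):
    """Convert {(M, K, N): (TILE_M, TILE_K, TILE_N)} into a cheatsheet dict."""
    sheet = {}
    for (M, K, N), (tile_m, tile_k, tile_n) in mappings.items():
        sheet[gemm_key(M, K, N)] = {
            "TILE_M": tile_m,
            "TILE_K": tile_k,
            "TILE_N": tile_n,
        }
    return sheet


cheatsheet = make_cheatsheet({
    (512, 2048, 8192): (512, 512, 1024),
    (2048, 2048, 512): (1024, 512, 512),
})

# Print the JSON; save it to the file that CONFIG_GEMM_CHEATSHEET_PATH points to.
print(json.dumps(cheatsheet, indent=4))
```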

If you want to explore specific tile size, set the environment variable as below.

export TORCHSIM_MANUAL_TILE_SIZE=1
export TORCHSIM_TILE_M=512
export TORCHSIM_TILE_N=512
export TORCHSIM_TILE_K=512
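
As a quick sanity check when picking a tile size, the number of tiles per dimension can be estimated with ceiling division. This sketch assumes each GEMM dimension is partitioned independently; how PyTorchSim treats remainder tiles is not covered here, and gemm_tile_counts is a hypothetical helper.

```python
import math


def gemm_tile_counts(M, K, N, tile_m, tile_k, tile_n):
    """Number of tiles along M, K, and N (ceiling division per dimension)."""
    return (
        math.ceil(M / tile_m),
        math.ceil(K / tile_k),
        math.ceil(N / tile_n),
    )


# A 512 x 2048 x 8192 GEMM with the 512 x 512 x 512 tile set above.
tiles_m, tiles_k, tiles_n = gemm_tile_counts(512, 2048, 8192, 512, 512, 512)
print(tiles_m, tiles_k, tiles_n, tiles_m * tiles_k * tiles_n)  # 1 4 16 64
```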

Compiler Configuration

PyTorchSimFrontend/extension_config.py contains the target hardware configuration used for compilation.

You can configure these options using environment variables.

export TORCHSIM_VECTOR_LANE=128 # vector lane size
export TORCHSIM_VECTOR_LANE_STRIDE=2  # vector lane stride for DMA
export TORCHSIM_DIR=/workspace/PyTorchSim # home directory

export BLOCK_SPARSE=0 # If you want to use block sparse workload, turn it on

# Plan which tensors are allocated in TPUv4's CMEM
export SRAM_BUFFER_PLAN_PATH=/workspace/PyTorchSim/tpuv4/gemm_plan.py

export TORCHSIM_TLS_MODE=1 # User can choose TLS or ILS mode
export TORCHSIM_USE_TIMING_POOLING=0 # use lightweight pooling for timing

TOGSim Configuration

NPU_Core

PyTorchSimBackend/configs directory contains example NPU configuration files in the JSON format.

  "num_cores" : 2,                   // Number of NPU cores
  "core_freq" : 940,                 // Core's frequency (MHz)
  "num_systolic_array_per_core" : 2, // Number of systolic arrays per core

  "dram_type" : "ramulator2",        // DRAM type (ex. ramulator2, simple)
  "dram_freq" : 940,                 // DRAM frequency (MHz)
  "dram_channels": 32,               // Number of DRAM channels
  "dram_req_size": 32,               // DRAM request size (B)
  "dram_latency" : 10,               // DRAM latency (cycle)
  "dram_nbl" : 2,                    // DRAM burst length size
  "dram_config_path" : "../configs/ramulator2_configs/HBM2_TPUv3.yaml", // Ramulator2 config file path

  "icnt_type" : "simple",            // Interconnect type (ex. booksim, simple)
  "icnt_latency" : 7,                // Interconnect latency (cycle)
  "icnt_freq" : 28000,               // Interconnect frequency (MHz)
  "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", // Booksim2 config file path

  "precision" : 4,                   // Element's precision in tensor (Byte)
  "scheduler" : "simple",            // Scheduler type (currently only the simple scheduler is supported)
  "num_partition" : 2,               // Multi-core Partitioning
  "partition": {                     // allocate request queue index
    "core_0":0,
    "core_1":1
  }

You can set TOGSim config path as below.

export TORCHSIM_CONFIG=/workspace/PyTorchSim/PyTorchSimBackend/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.json
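
For design-space sweeps, it may help to derive variants from an existing configuration instead of editing files by hand. The sketch below copies a JSON config and overrides selected top-level keys. It assumes the files in PyTorchSimBackend/configs are plain JSON (the // comments in the listing above are treated as annotations only), and write_config_variant is a hypothetical helper, not part of PyTorchSim.

```python
import json
from pathlib import Path


def write_config_variant(base_path, out_path, **overrides):
    """Copy a TOGSim JSON config, overriding selected top-level keys."""
    config = json.loads(Path(base_path).read_text())
    # Reject keys that are absent from the base config to catch typos early.
    unknown = [key for key in overrides if key not in config]
    if unknown:
        raise KeyError(f"keys not present in base config: {unknown}")
    config.update(overrides)
    Path(out_path).write_text(json.dumps(config, indent=2))
    return config
```

For example, write_config_variant(base, "/tmp/variant.json", dram_channels=16) writes a copy with 16 DRAM channels, which can then be selected by pointing TORCHSIM_CONFIG at /tmp/variant.json.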

Future Work

Currently, PyTorchSim supports PyTorch 2.2. Support for newer versions will be added soon.

Artifact Evaluation

Artifact evaluation is being prepared for v1.0.0. The following scripts reproduce the validation and speedup results from the paper.

Build

docker run -it --ipc=host --name torchsim -w /workspace/PyTorchSim ghcr.io/psal-postech/torchsim-ci:v1.0.0 bash

To generate validation results

# Run a cycle accuracy script
./experiments/artifact/cycle_validation/run_cycle.sh

To generate speedup results

# Run the speedup script
./experiments/artifact/speedup/run_speedup.sh

Contributing

We welcome contributions and issue reports. Contribution guidelines will be posted.

Citation

If you use PyTorchSim for your research, please cite the following paper.

@INPROCEEDINGS{yang2025pytorchsim,
  author={Yang, Wonhyuk and Shin, Yunseon and Woo, Okkyun and Park, Geonwoo and Ham, Hyungkyu and Kang, Jeehoon and Park, Jongse and Kim, Gwangsun},
  title={PyTorchSim: A Comprehensive, Fast, and Accurate NPU Simulation Framework},
  booktitle={2025 58th IEEE/ACM International Symposium on Microarchitecture (MICRO)}, 
  volume={},
  number={},
  pages={},
  year={2025},
  doi={10.1145/3725843.3756045}
}
