This example demonstrates a practical pattern for running a persistent kernel on NVIDIA GPUs while hot-swapping device-side operators at runtime using NVRTC JIT and a device function-pointer jump table.
Highlights:
- Persistent kernel with a global work queue (single producer on host, many consumers on device).
- Device jump table `g_op_table[]` of `__device__` function pointers (a device-side sketch follows the file list below).
- Host compiles new operators at runtime via NVRTC, loads them with the CUDA Driver API, fetches a device function pointer from the JIT module, and patches the jump table.
Repository layout:
- `src/common.h` — shared types: `Task`, `WorkQueue`, `OpFn`, and extern globals.
- `src/persistent_kernel.cu` — persistent worker + built-in `op_add` + processed counter; `g_op_table` is `__managed__`.
- `src/host.cpp` — host program: sets up the queue, launches workers, NVRTC-compiles `op_mul`, updates the jump table, verifies results.
- `CMakeLists.txt` — builds with `-rdc=true` and links against `cudart`, `cuda_driver`, and `nvrtc`.
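To make the pattern concrete, here is a minimal device-side sketch. The element-wise operator signature, `Task` fields, and queue layout shown here are illustrative assumptions; the real definitions live in `src/common.h` and differ in detail.

```cuda
// Illustrative signatures only; the real Task/OpFn/WorkQueue live in src/common.h.
typedef void (*OpFn)(float* out, const float* a, const float* b, int n);
struct Task { int op; float* out; const float* a; const float* b; int n; };

__device__ void op_add(float* out, const float* a, const float* b, int n) {
    for (int i = threadIdx.x; i < n; i += blockDim.x) out[i] = a[i] + b[i];
}

// Managed jump table: the host can patch slots with pointers fetched from
// NVRTC-compiled modules while the worker kernel keeps running.
__managed__ OpFn g_op_table[16];
__global__ void install_builtins() { g_op_table[0] = op_add; }

// Persistent worker, reduced to a single consumer block for brevity: spin on
// the queue until the host raises `quit`, dispatching through the jump table.
__global__ void worker(volatile int* quit, const Task* tasks,
                       volatile int* published, volatile int* consumed) {
    __shared__ int next, done;
    for (;;) {
        if (threadIdx.x == 0) {
            next = (*consumed < *published) ? *consumed : -1;  // -1: queue empty
            done = (*quit && next < 0);
        }
        __syncthreads();
        if (done) break;
        if (next >= 0) {
            __threadfence();                          // observe the latest table patch
            Task t = tasks[next];
            g_op_table[t.op](t.out, t.a, t.b, t.n);   // hot-swappable call site
            __syncthreads();
            if (threadIdx.x == 0) *consumed = next + 1;
        }
        __syncthreads();
    }
}
```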
- CUDA Toolkit 12.x (or newer) with NVRTC.
- GPU with device function pointer support (sm_50+; recommend sm_70+).
```bash
mkdir -p build && cd build
cmake -DCMAKE_BUILD_TYPE=Release ..
cmake --build . -j
./persistent_jit
```
You should see the host JIT-inject `op[1] = op_mul`, after which the persistent kernel calls through the updated function pointer for tasks with `op=1`. The program waits for completion, signals the kernel to stop, and verifies `C = A * B` on a few elements.
Additional tests:
- Online in-place switch (overwrite slot 0): `./test_online_switch`
- Dual-slot alias switch with rollback (logical 0 -> slot 0/1): `./test_dual_slot_switch` (one possible shape is sketched below)
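The dual-slot mechanism is only summarized above; one plausible shape for it (identifiers such as `g_op_alias` and `switch_logical_op0` are hypothetical, not the repo's) adds a level of indirection: the device resolves a logical op id to a physical slot through an alias table, the host stages the new pointer in the inactive slot, publishes by flipping the alias, and rolls back by flipping it again.

```cuda
// Hypothetical alias indirection on top of the jump table (names illustrative).
typedef void (*OpFn)(float*, const float*, const float*, int);
extern __managed__ OpFn g_op_table[];   // physical slots
__managed__ int g_op_alias[16];         // logical op id -> physical slot

// Device call site pays one extra load for the indirection.
__device__ void dispatch(int op, float* out, const float* a, const float* b, int n) {
    g_op_table[g_op_alias[op]](out, a, b, n);
}

// Host side: stage in the standby slot, publish, remember how to roll back.
// Assumes host writes to __managed__ memory are safe on this platform at this point.
int switch_logical_op0(OpFn new_fn_device_ptr) {
    int active  = g_op_alias[0];
    int standby = active ^ 1;                 // physical slots 0 and 1 alternate
    g_op_table[standby] = new_fn_device_ptr;  // stage the JIT-fetched pointer
    g_op_alias[0] = standby;                  // publish the switch
    return active;                            // rollback: g_op_alias[0] = active
}
```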
A minimal PyTorch extension demonstrates micro-batching many tiny ops and executing them as a single aggregated operator compiled at runtime with NVRTC.
What it does:
- Exposes functions to submit tiny add/mul requests on CUDA tensors without launching individual kernels.
- Accumulates pending requests on the host and, on `flush()`, JIT-compiles a batch operator (`op_batch`) via NVRTC (once) and enqueues a single Task that processes the entire batch on the persistent kernel.
- The batch operator iterates over sub-requests and uses block-local threading to process each one, maximizing GPU utilization without per-request launch overhead (a sketch follows below).
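A rough device-side sketch of what such an aggregated operator can look like. The `BatchReq` descriptor and the entry-point signature are illustrative assumptions; the real layout is generated by the extension.

```cuda
// Illustrative per-request descriptor; the extension defines the real layout.
struct BatchReq {
    int op;                 // 0 = add, 1 = mul
    const float* a;
    const float* b;
    float* out;
    int n;
};

// One aggregated operator: the block walks the request list and processes each
// tiny request with block-local threading, so a single Task (one persistent-kernel
// dispatch) covers the whole micro-batch.
extern "C" __device__ void op_batch(const BatchReq* reqs, int num_reqs) {
    for (int r = 0; r < num_reqs; ++r) {
        const BatchReq q = reqs[r];
        for (int i = threadIdx.x; i < q.n; i += blockDim.x) {
            q.out[i] = (q.op == 0) ? (q.a[i] + q.b[i]) : (q.a[i] * q.b[i]);
        }
        __syncthreads();    // finish one request before starting the next
    }
}
```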
Build/run (example):
- Dynamic in-place build with PyTorch tools: `python examples/pytorch_batch_demo.py`
- Or pre-build the extension: `cd pytorch_ext && python setup.py build_ext --inplace`
- Then import `gpuos_ext` in Python.
API (`gpuos_ext`):
- `init(capacity=4096, threads_per_block=256)` — allocates the queue, launches the persistent kernel, installs the builtins.
- `submit_add(a, b, out)` / `submit_mul(a, b, out)` — enqueue micro-requests into the host-side pending buffer (expects float32 CUDA tensors).
- `flush(sync=False)` — JIT-installs `op_batch` (once) and publishes a single aggregated Task pointing to the batch of requests. With `sync=True`, waits for completion.
- `shutdown()` — signals quit and joins the persistent kernel.
Notes:
- Set `GPUOS_NVRTC_ARCH` (e.g., `compute_90`) to override the NVRTC arch if needed.
- For simplicity, an async `flush(sync=False)` does not reclaim the per-batch descriptor buffer immediately; use `sync=True` or add a small GC loop in production.
A lightweight TorchDispatch-based scheduler that transparently aggregates tiny pointwise ops (add/mul) into the persistent-kernel runtime. Use it as a context manager:
```python
from pytorch_ext.scheduler import scheduler_context
import torch

a = torch.randn(4096, device="cuda")  # small CUDA tensors to be batched
b = torch.randn(4096, device="cuda")

with torch.no_grad():
    with scheduler_context(capacity=8192, threads_per_block=256, size_threshold=1<<15, auto_flush_ms=2.0):
        y = a + b  # small ops are queued and batched
        z = a * b
    # leaving the context flushes and waits
```
Demos:
- `python examples/pytorch_scheduler_demo.py` (builds the extension on the fly and runs a quick correctness check).
- `python examples/pytorch_scheduler_advanced_demo.py` (broadcast, non-contiguous views, mixed dtypes like fp16+fp32).
- `python examples/pytorch_reduce_demo.py` (single-dimension sum/mean over the last dim, keepdim/nostack; non-contiguous input).
Behavior and constraints:
- Intercepts many unary/binary elementwise ops (relu/sigmoid/tanh/exp/log/sqrt/abs/sin/cos/gelu/hardsigmoid/hardswish/maximum/minimum/pow/leaky_relu/hardtanh/elu/softplus/clamp variants) in addition to add/mul/div/sub.
- Supports broadcasting and non-contiguous input strides; outputs are contiguous by default. Mixed dtypes (fp16/bf16/fp32) compute in fp32 and cast to output dtype.
- Falls back to regular PyTorch for unsupported ops, large tensors (beyond `size_threshold`), or when autograd is enabled (best used under `torch.no_grad()`).
- Ensures correctness by auto-flushing synchronously if a downstream op consumes a tensor produced by the scheduler before the batch is flushed.
- A background timer (`auto_flush_ms`) opportunistically flushes pending work to keep latency bounded.
Reductions (beta):
- The scheduler intercepts `aten::sum.dim_IntList` and `aten::mean.dim` for a single dimension equal to the last axis. It generates and caches a dedicated JIT reduce kernel (sum/mean), and supports keepdim and non-contiguous inputs (a sketch of such a kernel follows below). More general multi-dimensional reductions can be added similarly.
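For intuition, a JIT-generated last-dim sum kernel could look roughly like this. This is a sketch only; the generated source, symbol names, and strided-indexing details are owned by the scheduler.

```cuda
// Last-dim sum reduction: one block per output row, threads strided over the
// reduced axis, then a shared-memory tree reduction. Launch with a power-of-two
// blockDim.x and dynamic shared memory of blockDim.x * sizeof(float).
extern "C" __global__ void reduce_sum_lastdim(const float* in, float* out,
                                              long rows, long cols,
                                              long row_stride, long col_stride) {
    extern __shared__ float smem[];
    long row = blockIdx.x;
    if (row >= rows) return;

    float acc = 0.0f;
    for (long c = threadIdx.x; c < cols; c += blockDim.x)
        acc += in[row * row_stride + c * col_stride];   // strides allow non-contiguous input

    smem[threadIdx.x] = acc;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {      // tree reduction in shared memory
        if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
        __syncthreads();
    }
    if (threadIdx.x == 0) out[row] = smem[0];           // a mean variant divides by cols
}
```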
Implementation notes:
- The queue is implemented with Unified Memory for simplicity. For production, prefer explicit device memory plus lightweight doorbells (atomics in mapped pinned memory) to avoid UM migration overhead.
- `g_op_table` is declared `__managed__` to simplify host updates (we use `cudaMemcpyToSymbol` with an offset). Workers call `__threadfence()` before reading the table.
- The JIT module exports a bridge `__device__ void* op_mul_ptr = (void*)op_mul;` so the host can fetch the function pointer value via `cuModuleGetGlobal` + `cuMemcpyDtoH` and store it in `g_op_table[op_id]` (see the sketch after this list).
- The sample keeps the CUmodule alive. In a real system, track modules per operator so they can be unloaded or replaced safely when no tasks are executing that operator.
- If you need to support multiple operator signatures, create multiple jump tables or a thin bytecode interpreter.
- RDC: Both the host build and NVRTC must enable relocatable device code (`-rdc=true` / `--relocatable-device-code=true`).
- Arch: The default NVRTC target is `--gpu-architecture=compute_90`. Override via the env var `GPUOS_NVRTC_ARCH` if needed.
- ABI: Keep `extern "C" __device__` and an identical `Task` layout on the host and in JIT sources.
- Pointer bridge: Always expose `__device__ void* op_x_ptr = (void*)op_x;` in JIT modules to fetch the function pointer value.
- Loading: Prefer PTX loading with the Driver API; `cudaLibraryLoadData` is also viable on CUDA 12+ if you use the runtime loader variants.
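A condensed host-side sketch of the compile / fetch / patch sequence described in these notes, written as a CUDA C++ translation unit. Error checks are omitted; the operator signature, slot index, and table declaration are assumptions in the spirit of the notes above, and the authoritative version is `src/host.cpp`.

```cuda
#include <cuda.h>
#include <cuda_runtime.h>
#include <nvrtc.h>
#include <vector>

typedef void (*OpFn)(float*, const float*, const float*, int);   // illustrative signature
extern __managed__ OpFn g_op_table[];                             // defined in persistent_kernel.cu

// JIT source: the new operator plus the pointer bridge described above.
static const char* kSrc = R"(
extern "C" __device__ void op_mul(float* out, const float* a, const float* b, int n) {
    for (int i = threadIdx.x; i < n; i += blockDim.x) out[i] = a[i] * b[i];
}
extern "C" __device__ void* op_mul_ptr = (void*)op_mul;
)";

void install_op_mul(int op_id) {
    // 1) Compile to PTX with NVRTC; arch and RDC options must match the host build.
    nvrtcProgram prog;
    nvrtcCreateProgram(&prog, kSrc, "op_mul.cu", 0, nullptr, nullptr);
    const char* opts[] = { "--gpu-architecture=compute_90",
                           "--relocatable-device-code=true" };
    nvrtcCompileProgram(prog, 2, opts);
    size_t ptx_size; nvrtcGetPTXSize(prog, &ptx_size);
    std::vector<char> ptx(ptx_size); nvrtcGetPTX(prog, ptx.data());
    nvrtcDestroyProgram(&prog);

    // 2) Load with the Driver API (primary context already created by the runtime)
    //    and read the bridge variable, which holds the device address of op_mul.
    CUmodule mod; cuModuleLoadData(&mod, ptx.data());
    CUdeviceptr bridge; size_t bytes;
    cuModuleGetGlobal(&bridge, &bytes, mod, "op_mul_ptr");
    void* fn = nullptr;
    cuMemcpyDtoH(&fn, bridge, sizeof(fn));

    // 3) Patch slot op_id of the managed jump table with an offset copy.
    cudaMemcpyToSymbol(g_op_table, &fn, sizeof(fn), op_id * sizeof(fn));
    // Keep `mod` alive while any task may still dispatch through this pointer.
}
```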