
branyang02/openpi-cuda

openpi with CUDA inference!

Making pi05 go 🚀🚀🚀🚀.

Tested on

  • Ubuntu 22.04
  • CUDA 12.6
  • Python 3.11.14
  • NVIDIA A100 40GB GPU

Installation

```shell
export CUDA_HOME=/usr/local/cuda-12.6
export PATH=${CUDA_HOME}/bin:${PATH}
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}

GIT_LFS_SKIP_SMUDGE=1 uv sync
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
```
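Before building, it can help to confirm the CUDA toolchain is actually visible. A minimal stdlib-only sketch (the `check_cuda_env` helper is ours, not part of the repo; the paths mirror the exports above):

```python
import os
import shutil

def check_cuda_env(expected_home="/usr/local/cuda-12.6"):
    """Report whether CUDA_HOME, nvcc, and the CUDA lib64 dir are visible."""
    cuda_home = os.environ.get("CUDA_HOME", "")
    return {
        "cuda_home_set": cuda_home == expected_home,
        "nvcc_on_path": shutil.which("nvcc") is not None,
        "lib64_on_ld_path": f"{expected_home}/lib64"
        in os.environ.get("LD_LIBRARY_PATH", "").split(":"),
    }

print(check_cuda_env())
```

If any entry is False, re-run the exports above before `uv sync`.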

Setup

```shell
# Convert JAX weights -> PyTorch weights
uv run python examples/convert_jax_model_to_pytorch.py \
    --checkpoint_dir ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid \
    --output_path ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch \
    --config_name pi05_droid \
    --precision float32
cp -r ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid/assets/ ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch/
```
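At its core, the conversion step walks the checkpoint's parameter tree and re-emits each array at the requested precision (here `--precision float32`). A rough numpy sketch of that idea; the function and parameter names are illustrative, not the actual pi05 layout, and the real script additionally remaps keys to the PyTorch module structure:

```python
import numpy as np

def convert_params(jax_params: dict, precision: str = "float32") -> dict:
    """Cast every leaf array in a flattened param dict to the target dtype.

    Stand-in for the JAX -> PyTorch step: real code would also rename keys
    and wrap each leaf in a torch.Tensor.
    """
    dtype = np.dtype(precision)
    return {name: np.asarray(arr, dtype=dtype) for name, arr in jax_params.items()}

# Illustrative parameter dict; real checkpoints hold the pi05 weights.
params = {"encoder/kernel": np.ones((2, 2), dtype=np.float16)}
converted = convert_params(params, "float32")
print(converted["encoder/kernel"].dtype)  # float32
```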

```shell
# PyTorch hacks: overwrite installed transformers files with patched versions
cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
```

Benchmarking

```shell
uv run scripts/benchmark.py
```

We also provide individual benchmarking and test scripts in the `tests/` folder.
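The latency tables below report mean, standard deviation, and percentiles over repeated inference calls. A stdlib-only sketch of how such a summary can be computed from per-call timings in milliseconds (`summarize_latencies` is our name, not the benchmark script's):

```python
import statistics

def summarize_latencies(samples_ms):
    """Mean/std plus the percentile columns used in the result tables."""
    # quantiles(n=100) returns the 99 cut points; qs[k-1] is the k-th percentile.
    qs = statistics.quantiles(samples_ms, n=100, method="inclusive")
    return {
        "mean": statistics.fmean(samples_ms),
        "std": statistics.stdev(samples_ms),
        **{f"p{p}": qs[p - 1] for p in (25, 50, 75, 90, 95, 99)},
    }

print(summarize_latencies([80, 81, 79, 80, 82, 150, 80, 81]))
```

Note how a single outlier (e.g. a first-call compile) inflates the mean and P99 while leaving the median nearly untouched, which is visible in the JAX numbers below.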

Benchmark Results

JAX Results

| Metric          | Mean | Std  | P25  | P50  | P75  | P90  | P95   | P99   |
|-----------------|------|------|------|------|------|------|-------|-------|
| inference_ms    | 92.9 | 50.9 | 79.8 | 80.0 | 81.0 | 90.9 | 102.2 | 272.0 |
| policy_infer_ms | 57.2 | 2.8  | 56.2 | 56.3 | 56.4 | 57.4 | 65.5  | 65.5  |

PyTorch Results

| Metric          | Mean  | Std | P25   | P50   | P75   | P90   | P95   | P99   |
|-----------------|-------|-----|-------|-------|-------|-------|-------|-------|
| inference_ms    | 322.0 | 0.9 | 321.5 | 321.8 | 322.6 | 323.3 | 323.4 | 323.7 |
| policy_infer_ms | 317.3 | 0.5 | 317.0 | 317.3 | 317.6 | 317.9 | 318.2 | 318.2 |

CUDA Results

| Metric          | Mean  | Std | P25   | P50   | P75   | P90   | P95   | P99   |
|-----------------|-------|-----|-------|-------|-------|-------|-------|-------|
| inference_ms    | 303.5 | 1.3 | 303.1 | 303.6 | 304.5 | 304.7 | 305.5 | 305.9 |
| policy_infer_ms | 298.7 | 1.2 | 298.6 | 298.8 | 299.2 | 299.7 | 300.0 | 301.0 |

Comparing JAX vs PyTorch:

```
           JAX vs PyTorch
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Metric               ┃     Value ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ Action Min           │ -0.724857 │
│ Action Max           │  0.966833 │
│ Action Mean          │  0.046649 │
│                      │           │
│ Mean Absolute Diff   │  0.005507 │
│ Max Absolute Diff    │  0.183711 │
│ Median Absolute Diff │  0.002901 │
│                      │           │
│ % within 0.001       │    24.83% │
│ % within 0.01        │    86.12% │
│ % within 0.1         │    99.96% │
│ % within 1.0         │   100.00% │
└──────────────────────┴───────────┘
```

Comparing JAX vs CUDA:

```
            JAX vs CUDA
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Metric               ┃     Value ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ Action Min           │ -0.724857 │
│ Action Max           │  0.966833 │
│ Action Mean          │  0.046649 │
│                      │           │
│ Mean Absolute Diff   │  0.004960 │
│ Max Absolute Diff    │  0.192997 │
│ Median Absolute Diff │  0.002596 │
│                      │           │
│ % within 0.001       │    27.29% │
│ % within 0.01        │    87.54% │
│ % within 0.1         │    99.96% │
│ % within 1.0         │   100.00% │
└──────────────────────┴───────────┘
```
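The agreement metrics above reduce to a few lines over two action arrays. A numpy sketch of that computation (`compare_actions` is our name, not the repo's test code):

```python
import numpy as np

def compare_actions(ref, other, tolerances=(0.001, 0.01, 0.1, 1.0)):
    """Absolute-difference stats and % of elements within each tolerance."""
    ref = np.asarray(ref, dtype=np.float64)
    other = np.asarray(other, dtype=np.float64)
    diff = np.abs(ref - other)
    stats = {
        "mean_abs_diff": float(diff.mean()),
        "max_abs_diff": float(diff.max()),
        "median_abs_diff": float(np.median(diff)),
    }
    for tol in tolerances:
        # Fraction of elements agreeing to within tol, as a percentage.
        stats[f"pct_within_{tol}"] = float((diff <= tol).mean() * 100)
    return stats

print(compare_actions([0.0, 0.5, 1.0], [0.0005, 0.52, 1.15]))
```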

TODO

  • [ ] Still figuring out how to pin CUDA 12.6 in uv; pinning currently triggers CMake resolution errors. Left unpinned, `uv sync` installs CUDA 12.6 on my machine by default.

About

Robot inference with custom CUDA kernels!


Languages

  • Python 92.7%
  • Cuda 6.5%
  • Other 0.8%