Making pi05 go 🚀🚀🚀🚀.

Tested on:

- Ubuntu 22.04
- CUDA 12.6
- Python 3.11.14
- NVIDIA A100 40GB GPU
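
Make the CUDA 12.6 toolkit visible to the build, then install the project (`GIT_LFS_SKIP_SMUDGE=1` tells Git LFS to skip downloading LFS-tracked files when git dependencies are cloned):

```bash
export CUDA_HOME=/usr/local/cuda-12.6
export PATH=${CUDA_HOME}/bin:${PATH}
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
GIT_LFS_SKIP_SMUDGE=1 uv sync
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
```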

Convert the JAX checkpoint to PyTorch weights and copy its `assets/` directory next to the converted weights:

```bash
# Convert JAX weights to PyTorch weights
uv run python examples/convert_jax_model_to_pytorch.py \
    --checkpoint_dir ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid \
    --output_path ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch \
    --config_name pi05_droid \
    --precision float32
cp -r ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid/assets/ ~/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch/
```
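
Patch the installed `transformers` package with the modified files shipped in this repo (the path assumes the Python 3.11 virtualenv created by `uv sync`):

```bash
# PyTorch hacks: overwrite transformers modules inside the venv
cp -r ./src/openpi/models_pytorch/transformers_replace/* .venv/lib/python3.11/site-packages/transformers/
```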
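
The `cp` above hard-codes `python3.11` in the site-packages path. As a hedged alternative, the sketch below asks the interpreter where `transformers` is installed and copies the overlay files there. It assumes it is run from the repo root with `uv run python`; `package_dir` and `overlay` are hypothetical helpers, not part of openpi.

```python
"""Sketch: copy the transformers overlay into whichever venv is active.

Hypothetical helpers, not part of openpi. Run from the repo root with `uv run python`.
"""
from __future__ import annotations

import importlib.util
import shutil
from pathlib import Path


def package_dir(name: str) -> Path:
    """Return the on-disk directory of an installed regular package."""
    spec = importlib.util.find_spec(name)  # locates the package without importing it
    if spec is None or not spec.submodule_search_locations:
        raise ModuleNotFoundError(f"{name!r} is not an installed package")
    return Path(next(iter(spec.submodule_search_locations)))


def overlay(src: Path, dst: Path) -> list[Path]:
    """Copy every file under src into dst, keeping relative paths and overwriting existing files."""
    copied = []
    for path in sorted(src.rglob("*")):
        if path.is_file():
            target = dst / path.relative_to(src)
            target.parent.mkdir(parents=True, exist_ok=True)
            # shutil.copy (not copy2) gives the target a fresh mtime, like plain `cp`
            shutil.copy(path, target)
            copied.append(target)
    return copied


if __name__ == "__main__":
    src = Path("src/openpi/models_pytorch/transformers_replace")
    for f in overlay(src, package_dir("transformers")):
        print(f"patched {f}")
```

One difference from the shell version: `pathlib`'s `rglob("*")` also matches dotfiles, which the shell glob `*` skips at the top level.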

Run the benchmark:

```bash
uv run scripts/benchmark.py
```

Individual benchmarking and testing scripts are also provided in the `tests/` folder.

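
For writing additional benchmarks, a minimal timing loop might look like the sketch below. It is not the contents of `scripts/benchmark.py`, and `time_calls` is a hypothetical helper. Warm-up calls are discarded so one-time costs such as JIT compilation stay out of the statistics, and because GPU work is launched asynchronously, the timed callable should block until results are ready (for example by calling `torch.cuda.synchronize()` in PyTorch or `jax.block_until_ready()` in JAX).

```python
"""Sketch of a latency-measurement loop (hypothetical helper, not scripts/benchmark.py)."""
from __future__ import annotations

import time
from typing import Callable


def time_calls(fn: Callable[[], object], n_iters: int = 100, n_warmup: int = 10) -> list[float]:
    """Call fn n_warmup times untimed, then n_iters times timed; return latencies in ms.

    For GPU code, fn must block until the work finishes (e.g. call
    torch.cuda.synchronize() or jax.block_until_ready() inside it), otherwise
    only the asynchronous launch is measured.
    """
    for _ in range(n_warmup):
        fn()
    samples_ms = []
    for _ in range(n_iters):
        start = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - start) * 1e3)
    return samples_ms


if __name__ == "__main__":
    latencies = time_calls(lambda: sum(range(100_000)), n_iters=20)
    print(f"median {sorted(latencies)[len(latencies) // 2]:.3f} ms over {len(latencies)} calls")
```
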
| Metric | Mean | Std | P25 | P50 | P75 | P90 | P95 | P99 |
|---|---|---|---|---|---|---|---|---|
| inference_ms | 92.9 | 50.9 | 79.8 | 80.0 | 81.0 | 90.9 | 102.2 | 272.0 |
| policy_infer_ms | 57.2 | 2.8 | 56.2 | 56.3 | 56.4 | 57.4 | 65.5 | 65.5 |

| Metric | Mean | Std | P25 | P50 | P75 | P90 | P95 | P99 |
|---|---|---|---|---|---|---|---|---|
| inference_ms | 322.0 | 0.9 | 321.5 | 321.8 | 322.6 | 323.3 | 323.4 | 323.7 |
| policy_infer_ms | 317.3 | 0.5 | 317.0 | 317.3 | 317.6 | 317.9 | 318.2 | 318.2 |

| Metric | Mean | Std | P25 | P50 | P75 | P90 | P95 | P99 |
|---|---|---|---|---|---|---|---|---|
| inference_ms | 303.5 | 1.3 | 303.1 | 303.6 | 304.5 | 304.7 | 305.5 | 305.9 |
| policy_infer_ms | 298.7 | 1.2 | 298.6 | 298.8 | 299.2 | 299.7 | 300.0 | 301.0 |
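
Each table row reduces a list of per-call latencies to summary statistics. The sketch below shows one way to produce such a row with the standard library. It assumes population standard deviation and linearly interpolated percentiles (NumPy's default), which may differ from what `scripts/benchmark.py` actually uses; `summarize` and `markdown_row` are hypothetical names.

```python
"""Sketch: turn raw latency samples (ms) into one markdown table row."""
from __future__ import annotations

import statistics

PERCENTILES = (25, 50, 75, 90, 95, 99)


def summarize(samples_ms: list[float]) -> dict[str, float]:
    """Mean, population std, and linearly interpolated percentiles."""
    # 99 cut points; cuts[p - 1] is the p-th percentile. "inclusive" matches
    # numpy.percentile's default linear interpolation.
    cuts = statistics.quantiles(samples_ms, n=100, method="inclusive")
    stats = {"Mean": statistics.fmean(samples_ms), "Std": statistics.pstdev(samples_ms)}
    stats.update({f"P{p}": cuts[p - 1] for p in PERCENTILES})
    return stats


def markdown_row(metric: str, samples_ms: list[float]) -> str:
    """Format summary statistics in the column order Mean, Std, P25 ... P99."""
    values = summarize(samples_ms).values()
    return "| " + " | ".join([metric, *(f"{v:.1f}" for v in values)]) + " |"


if __name__ == "__main__":
    print(markdown_row("inference_ms", [float(x) for x in range(101)]))
```

For the synthetic samples 0 to 100 this prints `| inference_ms | 50.0 | 29.2 | 25.0 | 50.0 | 75.0 | 90.0 | 95.0 | 99.0 |`.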

Comparing JAX vs PyTorch:

| Metric | Value |
|---|---|
| Action Min | -0.724857 |
| Action Max | 0.966833 |
| Action Mean | 0.046649 |
| Mean Absolute Diff | 0.005507 |
| Max Absolute Diff | 0.183711 |
| Median Absolute Diff | 0.002901 |
| % within 0.001 | 24.83% |
| % within 0.01 | 86.12% |
| % within 0.1 | 99.96% |
| % within 1.0 | 100.00% |

Comparing JAX vs CUDA:

| Metric | Value |
|---|---|
| Action Min | -0.724857 |
| Action Max | 0.966833 |
| Action Mean | 0.046649 |
| Mean Absolute Diff | 0.004960 |
| Max Absolute Diff | 0.192997 |
| Median Absolute Diff | 0.002596 |
| % within 0.001 | 27.29% |
| % within 0.01 | 87.54% |
| % within 0.1 | 99.96% |
| % within 1.0 | 100.00% |
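
The comparison metrics can be computed from two flattened action arrays produced for the same inputs. Since Action Min/Max/Mean are identical in both comparisons, they appear to describe the JAX reference actions; the sketch below assumes that, and also assumes "within t" means an absolute difference of at most t. `compare_actions` is a hypothetical helper, not the repo's comparison script.

```python
"""Sketch: summary metrics for comparing two implementations' action outputs."""
from __future__ import annotations

import statistics

THRESHOLDS = (0.001, 0.01, 0.1, 1.0)


def compare_actions(reference: list[float], candidate: list[float]) -> dict[str, float]:
    """Compare flattened action values elementwise against a reference (e.g. JAX)."""
    if len(reference) != len(candidate) or not reference:
        raise ValueError("action arrays must be non-empty and the same length")
    diffs = [abs(r - c) for r, c in zip(reference, candidate)]
    metrics = {
        # Assumption: the Action Min/Max/Mean rows describe the reference outputs
        "Action Min": min(reference),
        "Action Max": max(reference),
        "Action Mean": statistics.fmean(reference),
        "Mean Absolute Diff": statistics.fmean(diffs),
        "Max Absolute Diff": max(diffs),
        "Median Absolute Diff": statistics.median(diffs),
    }
    for t in THRESHOLDS:
        # Percentage of elements whose absolute difference is <= t
        metrics[f"% within {t}"] = 100.0 * sum(d <= t for d in diffs) / len(diffs)
    return metrics


if __name__ == "__main__":
    ref = [0.0, 0.5, -0.5, 1.0]
    cand = [0.0005, 0.5, -0.45, 1.0]
    for name, value in compare_actions(ref, cand).items():
        print(f"{name:<22} {value:.6f}")
```
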

- [ ] Still trying to figure out how to pin CUDA 12.6 in uv; currently getting CMake resolution errors. Without the pin, `uv sync` installs 12.6 on my machine by default.
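
Until the pin works, one way to confirm which CUDA build `uv sync` actually resolved is to compare `torch.version.cuda` (the CUDA version the installed PyTorch wheel was built against, or `None` for CPU-only builds) with the toolkit exported in `CUDA_HOME`. A minimal sketch, assuming PyTorch is installed in the uv environment; `EXPECTED_CUDA`, `major_minor`, and `torch_cuda_version` are hypothetical names.

```python
"""Sketch: check that the installed PyTorch wheel targets the expected CUDA version."""
from __future__ import annotations

import os
import sys

EXPECTED_CUDA = "12.6"  # hypothetical constant; matches the CUDA_HOME exported during setup


def major_minor(version: str) -> tuple[int, int]:
    """Parse '12.6' or '12.6.77' into (12, 6)."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def torch_cuda_version() -> str | None:
    """Return the CUDA version PyTorch was built with, or None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda  # None for CPU-only builds


def main() -> int:
    built = torch_cuda_version()
    if built is None:
        print("PyTorch is missing or is a CPU-only build")
        return 1
    ok = major_minor(built) == major_minor(EXPECTED_CUDA)
    print(f"torch built for CUDA {built}; CUDA_HOME={os.environ.get('CUDA_HOME')}; match={ok}")
    return 0 if ok else 1


if __name__ == "__main__":
    sys.exit(main())
```
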