Skip to content

Running VLA at 30Hz frame rate and 480Hz trajectory frequency

License

Notifications You must be signed in to change notification settings

dexmal/realtime-vla

Repository files navigation

Running VLA in Real Time (Pi0 + Pi05)

This project provides accelerated inference kernels of Pi0 and Pi05 models from the OpenPI project.

Real-time VLA Inference Demo Real-world demonstration: catching a falling pen with sub-200ms end-to-end latency using 30 FPS inference. (From "Running VLAs at Real-time Speed")

Pi0 (Triton) latency on RTX 4090 (max boosted clock 2.79GHz) for one set of observations (10 flow steps, chunk size 63, empty prompt) is as follows:

1 view 2 views 3 views
20.0ms 27.3ms 36.8ms

For more realistic Pi0 settings (prompt length 20 tokens, chunk size 50), we have:

1 view 2 views 3 views
23.6ms 32.9ms 39.2ms

To match camera speeds, you should consider using 30fps for one or two views, and 25fps for three views.

What's NEW!

  • [2026/01] 🔥 Added Pi05 Triton inference: pi05_infer.py (major upgrade; Pi0 + Pi05 supported).
  • [2026/01] 🔥 Added test.py for Triton vs JAX correctness/consistency checks, supporting both Pi0 and Pi05 (MAE + per-dimension MAE).
  • [2026/01] 🔥 Added a benchmark matrix: RTX 4090/5090 × 1/2/3 views × Pi0/Pi05 × Triton.

Benchmark Matrix

Model / Backend RTX 4090 (1 view) RTX 4090 (2 views) RTX 4090 (3 views) RTX 5090 (1 view) RTX 5090 (2 views) RTX 5090 (3 views)
Pi0 Triton 20.0ms 27.3ms 36.8ms 17.6ms 24.0ms 31.9ms
Pi05 Triton 22.1ms 29.2ms 38.9ms 20.1ms 26.6ms 34.2ms

How to Use

The intended usage is to directly copy the inference file into your project:

  • pi0_infer.py: Pi0 Triton inference
  • pi05_infer.py: Pi05 Triton inference (new)

Pi0 Triton inference

converted_checkpoint = pickle.load(open('converted_checkpoint.pkl', 'rb'))

from pi0_infer import Pi0Inference

infer = Pi0Inference(converted_checkpoint, number_of_images, length_of_trajectory)
output_actions = infer.forward(
   normalized_observation_image_bfloat16, # (number_of_images, 224, 224, 3)
   observation_state_bfloat16, # (32,)
   diffusion_input_noise_bfloat16, # (length_of_trajectory, 32)
)

Pi05 Triton inference (new)

converted_checkpoint = pickle.load(open('converted_checkpoint.pkl', 'rb'))

from pi05_infer import Pi05Inference

infer = Pi05Inference(
  checkpoint=converted_checkpoint,
  num_views=number_of_images,
  chunk_size=length_of_trajectory,
  tokenizer_path="/path/to/paligemma-3b-pt-224",
  # discrete_state_input=True is recommended (and matches `test.py`)
  discrete_state_input=True,
)
output_actions = infer.forward(
   normalized_observation_image_bfloat16, # (number_of_images, 224, 224, 3)
   diffusion_input_noise_bfloat16, # (length_of_trajectory, 32)
   task_prompt, # str
   state_tokens, # np.ndarray (discrete tokens if discrete_state_input=True)
)

Convert from JAX checkpoint

pytho3 convert_from_jax.py \
   --jax_path /path/to/checkpoint/folder\
   --output converted_checkpoint.pkl\
   --prompt "your task prompt"
   --tokenizer_path /path/to/paligemma-3b-pt-224
python3 convert_from_jax_pi05.py \
   --jax_path /path/to/checkpoint/folder\
   --output converted_checkpoint.pkl\
   --prompt "your task prompt"
   --tokenizer_path /path/to/paligemma-3b-pt-224

The code is specifically tuned on RTX 4090, CUDA 12.6, but it should work on similar platforms so long as torch and triton themselves work.

Checking Performance

You can check the inference time on you local machine by

python3 benchmark.py --num_views 2 --prompt_len 0 --chunk_size 63 --model_version pi0

Correctness / Consistency Testing (Pi0 & Pi05)

We provide test.py to compare Triton vs JAX outputs for both Pi0 and Pi05 (simple MAE report per-dimension).

Example (Pi05):

python3 test.py \
  --model_version pi05 \
  --triton_path converted_checkpoint.pkl \
  --jax_path /path/to/jax/checkpoint/folder \
  --norm_stats_dir /path/to/norm_stats_dir \
  --config_name <openpi_config_name> \
  --prompt "your task prompt" \
  --tokenizer_path /path/to/paligemma-3b-pt-224 \
  --discrete_state_input

Example (Pi0):

python3 test.py \
  --model_version pi0 \
  --triton_path converted_checkpoint.pkl \
  --jax_path /path/to/jax/checkpoint/folder \
  --norm_stats_dir /path/to/norm_stats_dir \
  --config_name <openpi_config_name>

Structure of the Code

The implementation is organized into three main components: a vision encoder, an LLM, and an action expert. Each component is decomposed into fundamental operations, with the entire computation graph simplified to 24 GEMM-like operations and associated scalar operators. This modular structure allows for efficient Triton kernel optimization at each computational stage.

Computation Graph Simplified computation graph showing the 24 GEMM-like operations that constitute the core of the inference pipeline.

Acknowledgements

This project is developed based on Physical Intelligence's OpenPI project.

Citation

If you want, you can cite this work with:

@article{ma2025runningvlasrealtimespeed,
  title={Running VLAs at Real-time Speed},
  author={Ma, Yunchao and Zhou, Yizhuang and Yang, Yunhuan and Wang, Tiancai and Fan, Haoqiang},
  journal={arXiv preprint arXiv:2510.26742},
  year={2025}
}

About

Running VLA at 30Hz frame rate and 480Hz trajectory frequency

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Contributors 4

  •  
  •  
  •  
  •  

Languages