A command-line tool and Python library for converting JAX models to PyTorch, with a focus on generating performant PyTorch code.
pip install jax2torch
# Convert a single JAX file to PyTorch
jax2torch convert input.py --output output.py
# Convert an entire directory
jax2torch convert ./jax_models --output ./pytorch_models --recursive
# Validate conversion results
jax2torch validate original.py converted.py
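What `jax2torch validate` checks internally is not documented here. As a rough illustration of the kind of check involved, the sketch below compares the outputs of an original and a converted model element-wise within a floating-point tolerance. The helper name `outputs_match` and the default tolerances are assumptions made for this example, not part of jax2torch.

```python
import math


def outputs_match(reference, candidate, rel_tol=1e-5, abs_tol=1e-6):
    """Return True when two flat sequences of floats agree within tolerance.

    Hypothetical helper for illustration only. Exact equality is too strict
    here because JAX and PyTorch may order floating-point operations
    differently, so small rounding differences are expected.
    """
    if len(reference) != len(candidate):
        return False
    return all(
        math.isclose(r, c, rel_tol=rel_tol, abs_tol=abs_tol)
        for r, c in zip(reference, candidate)
    )


# Toy values standing in for outputs of the original and converted models.
jax_outputs = [0.1234567, -2.5, 0.0]
torch_outputs = [0.1234568, -2.5, 1e-7]
print(outputs_match(jax_outputs, torch_outputs))  # True
```

The absolute tolerance matters for values near zero, where a purely relative comparison would reject any nonzero difference.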
from jax2torch import convert_jax_to_torch
# Convert JAX code string to PyTorch
pytorch_code = convert_jax_to_torch(jax_code_string)
# Convert with custom configuration
pytorch_code = convert_jax_to_torch(
    jax_code_string,
    optimize_performance=True,
    preserve_comments=True,
)
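Because the Python API works on code strings, converting a file from Python means reading and writing it yourself. The sketch below shows one way this might look. `convert_file` is a hypothetical helper, not part of jax2torch, and the converter is passed in as an argument so the sketch relies only on the `convert_jax_to_torch(code_string)` call shown above.

```python
from pathlib import Path
from typing import Callable


def convert_file(src: str, dst: str, converter: Callable[[str], str]) -> str:
    """Read JAX source from `src`, convert it, and write the result to `dst`.

    Hypothetical helper for illustration. `converter` is any callable that
    maps a JAX code string to a PyTorch code string.
    """
    jax_code = Path(src).read_text(encoding="utf-8")
    pytorch_code = converter(jax_code)

    out_path = Path(dst)
    # Create the output directory if it does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(pytorch_code, encoding="utf-8")
    return pytorch_code


# Usage (assumes jax2torch is installed):
#   from jax2torch import convert_jax_to_torch
#   convert_file("input.py", "output.py", convert_jax_to_torch)
```

Passing the converter as a parameter also makes it easy to wrap it, for example with `functools.partial`, to fix options such as `preserve_comments=True`.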
- Complete Model Conversion: Convert JAX/Flax/Haiku models to PyTorch
- Optimizer Translation: Translate Optax optimizers to their torch.optim equivalents
- Performance Optimization: Generate efficient PyTorch code patterns
- Weight Migration: Convert parameters and weights automatically
- CLI and Python API: Use from the command line or from Python code
- Validation Tools: Verify conversion correctness
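To make the optimizer translation feature concrete, here is a minimal sketch of how the hyperparameters of `optax.adam` line up with the keyword arguments of `torch.optim.Adam`. The function name `optax_adam_to_torch_kwargs` is an assumption for this example, and this is not necessarily how jax2torch performs the translation.

```python
def optax_adam_to_torch_kwargs(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
    """Translate optax.adam hyperparameters into torch.optim.Adam kwargs.

    Hypothetical helper for illustration. Only a constant learning rate is
    handled; an Optax schedule (a callable) would need a separate
    torch.optim.lr_scheduler and is rejected here.
    """
    if callable(learning_rate):
        raise TypeError("learning-rate schedules are not handled by this sketch")
    # optax uses b1/b2; torch.optim.Adam groups them into a `betas` tuple.
    return {"lr": learning_rate, "betas": (b1, b2), "eps": eps}


kwargs = optax_adam_to_torch_kwargs(learning_rate=3e-4, b1=0.9, b2=0.95)
print(kwargs)  # {'lr': 0.0003, 'betas': (0.9, 0.95), 'eps': 1e-08}
```

The resulting dictionary could then be passed on as `torch.optim.Adam(model.parameters(), **kwargs)`. Arguments that exist only on the Optax side, such as `eps_root`, have no direct counterpart and would need separate handling.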
- JAX: Core JAX operations and transformations
- Flax: Neural network layers and modules
- Haiku: Module system conversion
- Optax: Optimizer conversion to torch.optim
- Custom JAX code: User-defined functions and operations
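Converting layers between frameworks involves more than renaming parameters, because the frameworks store them in different layouts. The sketch below illustrates this for a dense layer: Flax's `Dense` keeps its kernel as `(in_features, out_features)`, while `torch.nn.Linear` keeps its weight as `(out_features, in_features)`. The helper `flax_dense_to_torch_linear` is hypothetical and uses NumPy arrays as stand-ins for real parameters. It is an illustration of the layout difference, not jax2torch's actual implementation.

```python
import numpy as np


def flax_dense_to_torch_linear(flax_params):
    """Map Flax Dense parameters onto the names and layout of torch.nn.Linear.

    Hypothetical helper for illustration. The kernel is transposed from
    (in_features, out_features) to (out_features, in_features); the bias is
    a 1-D vector in both frameworks and is simply copied.
    """
    kernel = np.asarray(flax_params["kernel"])
    state = {"weight": np.ascontiguousarray(kernel.T)}
    if "bias" in flax_params:
        state["bias"] = np.asarray(flax_params["bias"]).copy()
    return state


# Toy parameters: 3 input features, 2 output features.
flax_params = {
    "kernel": np.arange(6, dtype=np.float32).reshape(3, 2),
    "bias": np.zeros(2, dtype=np.float32),
}
state = flax_dense_to_torch_linear(flax_params)
print(state["weight"].shape)  # (2, 3)
```

With PyTorch installed, each array could be wrapped with `torch.from_numpy` and the result loaded into an `nn.Linear(3, 2)` via `load_state_dict`.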
- Python 3.8+
- JAX >= 0.4.0
- PyTorch >= 1.12.0
- Flax (optional, for Flax model conversion)
- Optax (optional, for optimizer conversion)
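The minimum versions above can be checked before running a conversion. The following is a minimal sketch using only the standard library; `check_requirements` and `parse_version` are hypothetical helpers, and the distribution names `jax` and `torch` are assumed to be the names those packages are installed under. The sketch itself needs Python 3.8+ for `importlib.metadata`.

```python
import re
from importlib import metadata

# Minimum versions copied from the list above.
MINIMUMS = {"jax": (0, 4, 0), "torch": (1, 12, 0)}


def parse_version(text):
    """Return the leading numeric part of a version string as a 3-tuple.

    Local suffixes such as "+cu118" are ignored and missing components are
    padded with zeros, so "1.12" becomes (1, 12, 0).
    """
    match = re.match(r"\d+(?:\.\d+)*", text)
    parts = tuple(int(p) for p in match.group().split(".")) if match else ()
    return (parts + (0, 0, 0))[:3]


def check_requirements(minimums=MINIMUMS):
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    for dist, minimum in minimums.items():
        try:
            installed = parse_version(metadata.version(dist))
        except metadata.PackageNotFoundError:
            problems.append(f"{dist} is not installed")
            continue
        if installed < minimum:
            wanted = ".".join(str(n) for n in minimum)
            problems.append(f"{dist} >= {wanted} is required")
    return problems


for problem in check_requirements():
    print(problem)
```

Flax and Optax are left out of `MINIMUMS` because they are optional and no minimum version is stated for them.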