TrainLoop/jax2torch

JAX2Torch

A command-line tool and Python API for converting JAX models to PyTorch, with a focus on generating performant PyTorch code.

Installation

pip install jax2torch

Usage

Command Line Interface

# Convert a single JAX file to PyTorch
jax2torch convert input.py --output output.py

# Convert an entire directory
jax2torch convert ./jax_models --output ./pytorch_models --recursive

# Validate conversion results
jax2torch validate original.py converted.py
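
What `jax2torch validate` checks internally is not documented here. One common way to verify a conversion by hand is to run both models on the same input and compare the outputs within a tolerance. The following is a minimal sketch of that idea in plain Python; `outputs_match` and the sample values are hypothetical and are not part of jax2torch.

```python
import math
from typing import Sequence


def outputs_match(
    reference: Sequence[float],
    candidate: Sequence[float],
    rel_tol: float = 1e-5,
    abs_tol: float = 1e-6,
) -> bool:
    """Return True if two flat output vectors agree element-wise within tolerance."""
    if len(reference) != len(candidate):
        return False
    return all(
        math.isclose(r, c, rel_tol=rel_tol, abs_tol=abs_tol)
        for r, c in zip(reference, candidate)
    )


# Hypothetical outputs from the original JAX model and the converted PyTorch model
jax_out = [0.12345678, -1.5, 3.0]
torch_out = [0.12345690, -1.5, 3.0000001]
print(outputs_match(jax_out, torch_out))
```

Small differences are expected between frameworks because of floating-point ordering, so an exact equality check would usually be too strict.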

Python API

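from pathlib import Path

from jax2torch import convert_jax_to_torch

# Read the JAX source to convert (input.py is a placeholder path)
jax_code_string = Path("input.py").read_text()

# Convert JAX code string to PyTorch
pytorch_code = convert_jax_to_torch(jax_code_string)

# Convert with custom configuration
pytorch_code = convert_jax_to_torch(
    jax_code_string,
    optimize_performance=True,
    preserve_comments=True,
)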

Features

  • Complete Model Conversion: Convert JAX/Flax/Haiku models to PyTorch
  • Optimizer Translation: Translates Optax optimizers to their torch.optim equivalents
  • Performance Optimization: Generates efficient PyTorch code patterns
  • Weight Migration: Automatic parameter/weight conversion
  • CLI and Python API: Flexible usage options
  • Validation Tools: Verify conversion correctness
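
To illustrate what weight migration involves, the sketch below shows the axis reordering needed for two common layers. Flax stores a `Dense` kernel as (in, out) and a 2D `Conv` kernel as (H, W, in, out), while PyTorch stores `Linear.weight` as (out, in) and `Conv2d.weight` as (out, in, H, W). The helper names `flax_to_torch_axes` and `permuted_shape` are hypothetical; this is not jax2torch's own implementation, only the layout bookkeeping it would have to perform in some form.

```python
from typing import Tuple


def flax_to_torch_axes(kernel_rank: int) -> Tuple[int, ...]:
    """Axis permutation that maps a Flax kernel layout to the PyTorch one.

    rank 2: Dense (in, out)        -> Linear (out, in)
    rank 4: Conv  (H, W, in, out)  -> Conv2d (out, in, H, W)
    """
    if kernel_rank == 2:
        return (1, 0)
    if kernel_rank == 4:
        return (3, 2, 0, 1)
    raise ValueError(f"unsupported kernel rank: {kernel_rank}")


def permuted_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape of the kernel after applying the permutation."""
    axes = flax_to_torch_axes(len(shape))
    return tuple(shape[a] for a in axes)


print(permuted_shape((784, 128)))      # Dense kernel: 784 inputs, 128 outputs
print(permuted_shape((3, 3, 16, 32)))  # 3x3 conv kernel: 16 -> 32 channels
```

The same permutation tuple can be passed to `numpy.transpose` or `torch.Tensor.permute` to reorder the actual parameter array.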

Supported Libraries

  • JAX: Core JAX operations and transformations
  • Flax: Neural network layers and modules
  • Haiku: Module system conversion
  • Optax: Optimizer conversion to torch.optim
  • Custom JAX code: User-defined functions and operations
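
As an example of what optimizer conversion has to account for, `optax.adam` and `torch.optim.Adam` use different keyword names for the same hyperparameters. The sketch below renames them; `adam_kwargs_optax_to_torch` is a hypothetical helper written for illustration and is not a jax2torch function.

```python
from typing import Any, Dict


def adam_kwargs_optax_to_torch(optax_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Rename optax.adam keyword arguments to torch.optim.Adam keyword arguments."""
    torch_kwargs: Dict[str, Any] = {}
    if "learning_rate" in optax_kwargs:
        torch_kwargs["lr"] = optax_kwargs["learning_rate"]
    # Optax takes b1 and b2 separately; PyTorch takes a (beta1, beta2) tuple.
    # 0.9 and 0.999 are the defaults in both libraries.
    if "b1" in optax_kwargs or "b2" in optax_kwargs:
        torch_kwargs["betas"] = (
            optax_kwargs.get("b1", 0.9),
            optax_kwargs.get("b2", 0.999),
        )
    if "eps" in optax_kwargs:
        torch_kwargs["eps"] = optax_kwargs["eps"]
    return torch_kwargs


print(adam_kwargs_optax_to_torch({"learning_rate": 1e-3, "b1": 0.95}))
```

This only covers a constant learning rate. An Optax schedule passed as `learning_rate` has no direct keyword equivalent and would need a `torch.optim.lr_scheduler` object on the PyTorch side.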

Requirements

  • Python 3.8+
  • JAX >= 0.4.0
  • PyTorch >= 1.12.0
  • Flax (optional, for Flax model conversion)
  • Optax (optional, for optimizer conversion)
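
A quick way to see which of these requirements are present in the current environment is sketched below, using only the standard library. `check_environment` is a hypothetical helper, not part of jax2torch, and it checks only that each package is importable, not that its version meets the minimum listed above.

```python
import sys
from importlib.util import find_spec
from typing import Dict


def check_environment() -> Dict[str, bool]:
    """Report whether the Python version and each listed package are available."""
    report = {"python>=3.8": sys.version_info >= (3, 8)}
    # flax and optax are optional; jax and torch are required.
    for module in ("jax", "torch", "flax", "optax"):
        report[module] = find_spec(module) is not None
    return report


for name, available in check_environment().items():
    print(f"{name}: {'found' if available else 'missing'}")
```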
