BordAX

A High-Performance JAX Framework for Programmatic Reinforcement Learning

Python 3.12+ · JAX · License: MIT


Overview

BordAX is a research-focused framework for Programmatic Reinforcement Learning (PRL) that combines the speed of JAX with support for structured, interpretable policies.

Key Features

  • 🚀 High Performance: Fully JIT-compiled training pipelines leveraging JAX's XLA compilation (see the sketch after this list)
  • 🧩 Modular Architecture: Clean separation between agents, algorithms, environments, and training logic
  • 🎯 Multiple Policy Types: Support for MLPs, boolean functions (HyperBool), and decision trees (DTSemNet)
  • 🔄 Flexible Algorithms: Built-in PPO (on-policy) and DQN (off-policy) with easy extensibility
  • 🔧 Extensible: Simple APIs for adding new agents, algorithms, and environments
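
A rough illustration of what "fully JIT-compiled" means here: the whole training loop is staged through XLA (for example with jax.lax.scan) instead of running Python per step. The sketch below is generic JAX, not BordAX's actual implementation; train_step is a placeholder for a real collect-and-update iteration.

import jax
import jax.numpy as jnp

def train_step(carry, _):
    # Placeholder for one collect-and-update iteration of a real pipeline.
    params, key = carry
    key, subkey = jax.random.split(key)
    grad = jax.random.normal(subkey, params.shape)  # stand-in for a real gradient
    return (params - 1e-3 * grad, key), jnp.mean(jnp.abs(grad))

@jax.jit
def train(params, key):
    # lax.scan keeps the entire loop inside one compiled XLA program,
    # so no Python executes per training step after compilation.
    (params, _), metrics = jax.lax.scan(train_step, (params, key), xs=None, length=1_000)
    return params, metrics

final_params, step_metrics = train(jnp.zeros(8), jax.random.PRNGKey(0))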

Installation

Setup

# Clone the repository
git clone https://github.com/SynthesisLab/bordax.git
cd bordax

# Create and activate virtual environment (recommended)
python -m venv .venv
source .venv/bin/activate

# Install dependencies
pip install -r requirements.txt

Quick Test

Verify your installation:

python -c "from bordax.trainer import Trainer; print('✓ BordAX installed successfully')"

Quick Start

Training PPO on CartPole

python train_ppo.py

This will:

  • Train an agent with MLP policy using PPO on CartPole-v1
  • Save results to runs/ppo_YYYYMMDD_HHMMSS/
  • Generate training plots (rewards, policy loss, value loss, entropy)

Expected results:

  • Solves CartPole-v1 (reward = 500) in ~100k steps
  • Training time: ~6 seconds on CPU

[Figure: PPO training rewards]

Training DQN on CartPole

python train_dqn.py

Expected results:

  • Solves CartPole-v1 (reward = 500) in ~50k steps
  • Training time: ~30 seconds on CPU

[Figure: DQN training rewards]

Custom Training Script

from bordax.trainer import Trainer, TrainerConfig
from bordax.algorithms.utils import make_algo
from bordax.environments.utils import make_env
from bordax.agents.utils import make_agent
import jax

# Setup environment
env = make_env("gymnax/CartPole-v1", {}, num_envs=1)
eval_env = make_env("gymnax/CartPole-v1", {}, num_envs=1)

# Create agent
agent = make_agent("mlp/mlp", env, {
    "policy_layers": [64, 64],
    "value_layers": [64, 64],
})

# Configure algorithm
algorithm = make_algo("ppo", {
    "lr": 3e-4,
    "rollout_length": 2048,
    "gamma": 0.99,
    "_lambda": 0.95,
    "clip_schedule": lambda _: 0.2,
    "vf_schedule": lambda _: 0.5,
    "ent_schedule": lambda _: 0.01,
    "num_minibatches": 16,
    "num_sgd_steps": 10,
})

# Setup trainer
config = TrainerConfig(
    num_checkpoints=100,
    epochs_per_checkpoint=1,
    evaluation_episodes=32,
    debug=True,
    save_model=True,
)

trainer = Trainer(env, eval_env, agent, algorithm, config)

# Initialize and train
key = jax.random.PRNGKey(0)
init_key, train_key = jax.random.split(key)
trainer.init(init_key)

metrics, eval_data, model_params = trainer.run(train_key)
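
The returned metrics, eval_data, and model_params can then be saved or inspected. A minimal sketch, assuming metrics behaves like a mapping of per-checkpoint arrays (that structure is an assumption, not documented above):

import numpy as np

# Assumption: `metrics` is dict-like with array-valued entries per checkpoint.
np.savez("ppo_cartpole_metrics.npz", **{k: np.asarray(v) for k, v in metrics.items()})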

Architecture

BordAX uses a modular pipeline architecture:

Trainer
  └─> Algorithm (Collector + BatchBuilder + Updater)
       ├─> Collector: Generates environment transitions
       ├─> BatchBuilder: Constructs training batches
       └─> Updater: Computes gradients and updates parameters
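
Conceptually, each training epoch wires these three roles together as sketched below. The names collect, build_batches, and update are hypothetical stand-ins for the Collector, BatchBuilder, and Updater components, not BordAX APIs.

def train_epoch(params, env_state, key, collect, build_batches, update):
    # Collector: roll out the current policy to gather transitions.
    transitions, env_state, key = collect(params, env_state, key)
    # BatchBuilder: turn the raw transitions into training minibatches.
    batches = build_batches(transitions, key)
    # Updater: one gradient step per minibatch.
    for batch in batches:
        params, metrics = update(params, batch)
    return params, env_state, key, metrics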

Core Components

Component     | Purpose                                            | Examples
Agent         | Defines policy and value networks                  | MLPPolicyValue, DQNAgent
Algorithm     | Bundles training pipeline components               | ppo_algo(), dqn_algo()
Collector     | Generates transitions via environment interaction  | OnPolicyCollector, EpsGreedyCollector
BatchBuilder  | Transforms data into training batches              | MiniBatch, UniformReplayBatch
Updater       | Updates parameters using loss functions            | SGDUpdate, DQNUpdater
Trainer       | Orchestrates full training loop                    | Trainer

Supported Algorithms

  • PPO
  • DQN

Project Structure

bordax/
├── bordax/                   # Main package
│   ├── agents/               # Agent definitions
│   │   ├── base.py           # Base classes and implementations
│   │   ├── components.py     # Neural network modules
│   │   └── utils.py          # Agent factory
│   ├── algorithms/           # RL algorithms
│   │   ├── base.py           # Algorithm implementations
│   │   ├── losses.py         # Algorithm-specific losses
│   │   └── utils.py          # Algorithm factory
│   ├── environments/         # Environment adapters (Gymnax, Gymnasium)
│   ├── batchbuilders.py      # Batch construction
│   ├── buffer.py             # Replay buffer
│   ├── collectors.py         # Data collection strategies
│   ├── trainer.py            # Training pipeline orchestration
│   ├── types.py              # Type definitions
│   └── updaters.py           # Model parameter updates
├── train_ppo.py              # PPO training script (CartPole example)
├── train_dqn.py              # DQN training script (CartPole example)
├── requirements.txt          # Python dependencies
└── README.md                 # This file

Policy Representations

In the agent identifiers below, the name before the slash selects the policy representation and the name after it selects the value network (for example, boolean/mlp pairs a boolean policy with an MLP value function).

Standard Neural Networks

MLP Policy-Value (mlp/mlp):

agent = make_agent("mlp/mlp", env, {
    "policy_layers": [128, 128, 64],
    "value_layers": [128, 128, 64],
})

Programmatic Policies

HyperBool (Boolean function-based):

agent = make_agent("boolean/mlp", env, {
    "n": 4,  # Number of boolean variables
    "value_layers": [128, 64, 32],
})

DTSemNet (Decision trees):

agent = make_agent("dt/mlp", env, {
    "tree_depth": 4,
    "value_layers": [64, 64],
})
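
All three representations plug into the same training pipeline; only the make_agent call changes. For example, the custom training script above can be reused with a decision-tree policy by swapping the agent (a sketch using only the calls shown earlier):

# Drop-in replacement for the MLP agent in the custom training script.
agent = make_agent("dt/mlp", env, {
    "tree_depth": 4,
    "value_layers": [64, 64],
})
trainer = Trainer(env, eval_env, agent, algorithm, config)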

License

BordAX is released under the MIT License.


Acknowledgments

BordAX builds on excellent work from the JAX ecosystem:

  • JAX: High-performance numerical computing
  • Flax: Neural network library
  • Gymnax: JAX-compatible RL environments
  • Optax: Gradient processing and optimization
  • Distrax: Probability distributions
