BordAX is a research-focused framework for Programmatic Reinforcement Learning (PRL) that combines the speed of JAX with support for structured, interpretable policies.
- High Performance: Fully JIT-compiled training pipelines leveraging JAX's XLA compilation
- Modular Architecture: Clean separation between agents, algorithms, environments, and training logic
- Multiple Policy Types: Support for MLPs, boolean functions (HyperBool), and decision trees (DTSemNet)
- Flexible Algorithms: Built-in PPO (on-policy) and DQN (off-policy) with easy extensibility
- Extensible: Simple APIs for adding new agents, algorithms, and environments
```bash
# Clone the repository
git clone https://github.com/SynthesisLab/bordax.git
cd bordax

# Create and activate a virtual environment (recommended)
python -m venv .venv
source .venv/bin/activate

# Install dependencies
pip install -r requirements.txt
```

Verify your installation:
```bash
python -c "from bordax.trainer import Trainer; print('✓ BordAX installed successfully')"
```

To train a PPO agent, run:

```bash
python train_ppo.py
```

This will:
- Train an agent with an MLP policy using PPO on CartPole-v1
- Save results to runs/ppo_YYYYMMDD_HHMMSS/
- Generate training plots (rewards, policy loss, value loss, entropy)
Expected results:
- Solves CartPole-v1 (reward = 500) in ~100k steps
- Training time: ~6 seconds on CPU
To train a DQN agent, run:

```bash
python train_dqn.py
```

Expected results:
- Solves CartPole-v1 (reward = 500) in ~50k steps
- Training time: ~30 seconds on CPU
A complete training run can also be set up programmatically:

```python
import jax

from bordax.trainer import Trainer, TrainerConfig
from bordax.algorithms.utils import make_algo
from bordax.environments.utils import make_env
from bordax.agents.utils import make_agent

# Setup environment
env = make_env("gymnax/CartPole-v1", {}, num_envs=1)
eval_env = make_env("gymnax/CartPole-v1", {}, num_envs=1)
# Create agent
agent = make_agent("mlp/mlp", env, {
    "policy_layers": [64, 64],
    "value_layers": [64, 64],
})
# Configure algorithm
algorithm = make_algo("ppo", {
    "lr": 3e-4,
    "rollout_length": 2048,
    "gamma": 0.99,
    "_lambda": 0.95,
    "clip_schedule": lambda _: 0.2,
    "vf_schedule": lambda _: 0.5,
    "ent_schedule": lambda _: 0.01,
    "num_minibatches": 16,
    "num_sgd_steps": 10,
})
# Setup trainer
config = TrainerConfig(
    num_checkpoints=100,
    epochs_per_checkpoint=1,
    evaluation_episodes=32,
    debug=True,
    save_model=True,
)
trainer = Trainer(env, eval_env, agent, algorithm, config)
# Initialize and train
key = jax.random.PRNGKey(0)
init_key, train_key = jax.random.split(key)
trainer.init(init_key)
metrics, eval_data, model_params = trainer.run(train_key)
```
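
The returned metrics can then be inspected or plotted directly. The snippet below is a minimal sketch that assumes `metrics` behaves like a dictionary of per-checkpoint arrays containing a reward-like entry; the actual keys are defined by the algorithm, so inspect the object (or the plots generated under runs/) to confirm.

```python
# Post-training sketch: plot the learning curve from the returned metrics.
# NOTE: the "reward" key below is an assumption, not a documented field --
# print(metrics) first to see the actual structure.
import numpy as np
import matplotlib.pyplot as plt

rewards = np.asarray(metrics["reward"])  # hypothetical key
plt.plot(rewards)
plt.xlabel("checkpoint")
plt.ylabel("mean reward")
plt.title("PPO on CartPole-v1")
plt.show()
```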

BordAX uses a modular pipeline architecture:

```
Trainer
└─> Algorithm (Collector + BatchBuilder + Updater)
    ├─> Collector: Generates environment transitions
    ├─> BatchBuilder: Constructs training batches
    └─> Updater: Computes gradients and updates parameters
```

| Component | Purpose | Examples |
|---|---|---|
| Agent | Defines policy and value networks | MLPPolicyValue, DQNAgent |
| Algorithm | Bundles training pipeline components | ppo_algo(), dqn_algo() |
| Collector | Generates transitions via environment interaction | OnPolicyCollector, EpsGreedyCollector |
| BatchBuilder | Transforms data into training batches | MiniBatch, UniformReplayBatch |
| Updater | Updates parameters using loss functions | SGDUpdate, DQNUpdater |
| Trainer | Orchestrates full training loop | Trainer |
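
To make the division of labour concrete, the following self-contained sketch mimics the shape of this pipeline with plain Python stand-ins. These toy classes are illustrative only and are not BordAX's actual Collector, BatchBuilder, or Updater implementations.

```python
# Conceptual illustration of the Trainer pipeline (NOT bordax's real classes).
# Each epoch: the collector gathers transitions, the batch builder slices them
# into training batches, and the updater applies one update step per batch.
import random

class ToyCollector:
    def collect(self, num_steps):
        # Stand-in for environment interaction: returns fake transitions.
        return [random.random() for _ in range(num_steps)]

class ToyBatchBuilder:
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def build(self, transitions):
        for i in range(0, len(transitions), self.batch_size):
            yield transitions[i:i + self.batch_size]

class ToyUpdater:
    def update(self, params, batch):
        # Stand-in for a gradient step: nudge params toward the batch mean.
        return params + 0.1 * (sum(batch) / len(batch) - params)

params = 0.0
collector, builder, updater = ToyCollector(), ToyBatchBuilder(8), ToyUpdater()
for epoch in range(3):
    transitions = collector.collect(32)
    for batch in builder.build(transitions):
        params = updater.update(params, batch)
print("final params:", params)
```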

Built-in algorithms:

- PPO (on-policy)
- DQN (off-policy)
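
Both are created through the same make_algo factory shown in the example above; switching algorithms is just a different identifier string. The sketch below shows a DQN setup, but the configuration keys and the agent specification are assumptions rather than documented defaults, so check train_dqn.py for the values actually used.

```python
# Sketch of a DQN setup -- the "dqn" config keys and the agent spec below are
# assumptions, not documented defaults (see train_dqn.py for the real setup).
import jax

from bordax.trainer import Trainer, TrainerConfig
from bordax.algorithms.utils import make_algo
from bordax.environments.utils import make_env
from bordax.agents.utils import make_agent

env = make_env("gymnax/CartPole-v1", {}, num_envs=1)
eval_env = make_env("gymnax/CartPole-v1", {}, num_envs=1)

agent = make_agent("mlp/mlp", env, {          # assumed agent spec for DQN
    "policy_layers": [64, 64],
    "value_layers": [64, 64],
})
algorithm = make_algo("dqn", {
    "lr": 1e-3,     # assumed key
    "gamma": 0.99,  # assumed key
})

config = TrainerConfig(num_checkpoints=100, epochs_per_checkpoint=1,
                       evaluation_episodes=32, debug=True, save_model=True)
trainer = Trainer(env, eval_env, agent, algorithm, config)

key = jax.random.PRNGKey(0)
init_key, train_key = jax.random.split(key)
trainer.init(init_key)
metrics, eval_data, model_params = trainer.run(train_key)
```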

Repository layout:

```
bordax/
├── bordax/               # Main package
│   ├── agents/           # Agent definitions
│   │   ├── base.py       # Base classes and implementations
│   │   ├── components.py # Neural network modules
│   │   └── utils.py      # Agent factory
│   ├── algorithms/       # RL algorithms
│   │   ├── base.py       # Algorithm implementations
│   │   ├── losses.py     # Algorithm-specific losses
│   │   └── utils.py      # Algorithm factory
│   ├── environments/     # Environment adapters (Gymnax, Gymnasium)
│   ├── batchbuilders.py  # Batch construction
│   ├── buffer.py         # Replay buffer
│   ├── collectors.py     # Data collection strategies
│   ├── trainer.py        # Training pipeline orchestration
│   ├── types.py          # Type definitions
│   └── updaters.py       # Model parameter updates
├── train_ppo.py
├── train_dqn.py
├── requirements.txt
└── README.md
```

Supported agent types:

MLP Policy-Value (mlp/mlp):

```python
agent = make_agent("mlp/mlp", env, {
    "policy_layers": [128, 128, 64],
    "value_layers": [128, 128, 64],
})
```

HyperBool (Boolean function-based):

```python
agent = make_agent("boolean/mlp", env, {
    "n": 4,  # Number of boolean variables
    "value_layers": [128, 64, 32],
})
```

DTSemNet (Decision trees):

```python
agent = make_agent("dt/mlp", env, {
    "tree_depth": 4,
    "value_layers": [64, 64],
})
```

BordAX is released under the MIT License.
BordAX builds on excellent work from the JAX ecosystem: