CLAX: Fast and Flexible Neural Click Models in JAX

CLAX is a modular framework for building click models with gradient-based optimization in JAX and Flax NNX. CLAX is built to be fast, providing orders-of-magnitude speedups over classic EM-based frameworks such as PyClick by leveraging automatic differentiation and vectorized computation on GPUs.
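To illustrate the principle in miniature (this is a toy NumPy sketch, not CLAX code): instead of EM's alternating expectation and maximization steps, a click model such as the position-based model can be fit by directly maximizing the click log-likelihood with gradient ascent, which vectorizes naturally over sessions and ranks.

```python
# Toy sketch (not CLAX code): fit a position-based model (PBM) by
# gradient ascent on the click log-likelihood instead of EM.
import numpy as np

rng = np.random.default_rng(0)
n_docs, n_ranks, n_sessions = 50, 5, 2_000

# Ground-truth PBM: P(click) = attractiveness[doc] * examination[rank]
true_attr = rng.uniform(0.1, 0.9, n_docs)
true_exam = np.linspace(0.9, 0.3, n_ranks)

docs = rng.integers(0, n_docs, (n_sessions, n_ranks))  # ranked lists
clicks = rng.random((n_sessions, n_ranks)) < true_attr[docs] * true_exam

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Unconstrained logits, squashed into (0, 1) via the sigmoid
attr_logit = np.zeros(n_docs)
exam_logit = np.zeros(n_ranks)

lr = 0.5
for _ in range(300):
    a = sigmoid(attr_logit[docs])          # (sessions, ranks)
    e = sigmoid(exam_logit)                # (ranks,), broadcasts over sessions
    p = a * e                              # click probability per impression
    # d log-likelihood / d p for Bernoulli click observations
    grad_p = np.where(clicks, 1.0 / p, -1.0 / (1.0 - p))
    # Chain rule through the sigmoid for both parameter sets
    grad_attr = grad_p * e * a * (1.0 - a)
    grad_exam = (grad_p * a * e * (1.0 - e)).sum(axis=0)
    np.add.at(attr_logit, docs, lr * grad_attr / n_sessions)
    exam_logit += lr * grad_exam / n_sessions

print(sigmoid(exam_logit))  # learned examination curve, decreasing with rank
```

CLAX replaces the hand-written gradients above with JAX's automatic differentiation and runs the vectorized updates on GPU, which is where the speedup over EM-based implementations comes from.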

The current documentation is available here and our pre-print here.

Installation

CLAX requires JAX. To install JAX with CUDA support, please refer to the JAX documentation. CLAX itself is available via PyPI:

pip install clax-models

Basic Usage

CLAX is designed with sensible defaults while allowing a high degree of customization. For example, training a User Browsing Model in CLAX is as simple as:

from clax import Trainer, UserBrowsingModel
from flax import nnx
from optax import adamw

model = UserBrowsingModel(
    query_doc_pairs=100_000_000, # Number of query-document pairs in the dataset
    positions=10, # Number of ranks per result page
    rngs=nnx.Rngs(42), # NNX random number generator
)
trainer = Trainer(
    optimizer=adamw(0.003),
    epochs=50,
)
train_df = trainer.train(model, train_loader, val_loader)
test_df = trainer.test(model, test_loader)

The modular design of CLAX also supports more complex setups, such as two-tower models, mixture models, or plugging in custom Flax modules as model parameters. We provide usage examples under examples/ to help you get started.

Reproducibility & Development

The following covers how to reproduce the experiments from our paper and how to set up CLAX for development.

Initial Setup

  1. Install the UV package manager
    UV is a fast Python dependency manager. Install it from: https://github.com/astral-sh/uv

  2. Clone the CLAX repository

   git clone git@github.com:philipphager/clax.git
   cd clax/

  3. Install dependencies

   uv sync

This creates a virtual environment and installs all required dependencies.

Running Experiments

Our paper's experiments are located in the experiments/ directory. Each experiment contains:

  • A Python script with the experiment logic: main.py
  • A Hydra config file for configuration management: config.yaml
  • A bash script with all experimental configurations: main.sh
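As a rough, hypothetical sketch of how these pieces fit together (the field names below are illustrative; each experiment's actual config.yaml is authoritative), a Hydra config ties the training hyperparameters and data location into one file that main.py reads and main.sh sweeps over:

```yaml
# Illustrative only: field names are hypothetical, not copied from the repo.
defaults:
  - _self_

dataset_dir: ./clax-datasets/  # where the datasets are expected on disk
model: ubm                     # which click model to train (hypothetical field)
epochs: 50
learning_rate: 0.003
```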

To run an experiment, follow these steps.

  1. Install experiment dependencies
    Installs additional packages for SLURM support and data analysis/plotting.
   uv sync --group experiments
  2. Download datasets
    Clone the Yandex and Baidu-ULTR datasets from HuggingFace. If you have Git LFS installed, clone the datasets using:
   git lfs install
   git clone https://huggingface.co/datasets/philipphager/clax-datasets

Otherwise, download the datasets manually from HuggingFace. Note: the full datasets require 85 GB of disk space. By default, CLAX expects the datasets at ./clax-datasets/ relative to the project root. To use a custom path, update the dataset_dir parameter in each experiment's config.yaml:

   dataset_dir: /my/custom/path/to/datasets/
  3. Run an experiment script
    Navigate to your experiment of interest and run the bash script, e.g.:
   cd experiments/1-yandex-baseline/
   chmod +x ./main.sh
   ./main.sh

Optionally, you can run the script directly on a SLURM cluster using:

   sbatch ./main.sh +launcher=slurm

You can adjust the SLURM configuration to your cluster under: experiments/config/slurm.yaml
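As a hypothetical sketch of what such a launcher override can look like (all names and values below are placeholders in the style of Hydra's submitit SLURM launcher; the repository's experiments/config/slurm.yaml is authoritative):

```yaml
# Illustrative only: adjust partition, time limit, and resources
# to match your cluster; the repo's slurm.yaml is authoritative.
hydra:
  launcher:
    partition: gpu        # your cluster's GPU partition
    timeout_min: 720      # job time limit in minutes
    gpus_per_node: 1
    cpus_per_task: 8
    mem_gb: 32
```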

PyClick Experiments

Baseline experiments using PyClick require the PyPy interpreter and are maintained in a separate repository: https://github.com/philipphager/clax-baselines

Reference

If CLAX is useful to you, please consider citing our paper:

@misc{hager2025clax,
  title  = {CLAX: Fast and Flexible Neural Click Models in JAX},
  author = {Philipp Hager and Onno Zoeter and Maarten de Rijke},
  year   = {2025},
  note   = {arXiv preprint}
}
