This repository contains the code associated with the paper: Compositional Risk Minimization
Authors: Divyat Mahajan, Mohammad Pezeshki, Charles Arnal, Ioannis Mitliagkas, Kartik Ahuja†, Pascal Vincent†
Accepted at ICML 2025 (OpenReview).
Compositional generalization is a crucial step towards developing data-efficient intelligent machines that generalize in human-like ways. In this work, we tackle a challenging form of distribution shift, termed compositional shift, where some attribute combinations are completely absent at training but present in the test distribution. This shift tests the model's ability to generalize compositionally to novel attribute combinations in discriminative tasks. We model the data with flexible additive energy distributions, where each energy term represents an attribute, and derive a simple alternative to empirical risk minimization termed compositional risk minimization (CRM). We first train an additive energy classifier to predict the multiple attributes and then adjust this classifier to tackle compositional shifts. We provide an extensive theoretical analysis of CRM, where we show that our proposal extrapolates to special affine hulls of seen attribute combinations. Empirical evaluations on benchmark datasets confirm the improved robustness of CRM compared to other methods from the literature designed to tackle various forms of subpopulation shifts.
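For intuition, here is a minimal, hypothetical sketch of the additive scoring structure that CRM builds on; it is not the repository's implementation, and all names in it are illustrative. The key point is that the joint logit for an attribute combination decomposes as a sum of per-attribute scores, so combinations never seen together during training still receive a well-defined score.

```python
# Hypothetical sketch of an additive energy classifier (illustrative only,
# not the repository's implementation). Each attribute gets its own score
# head; the joint logit for a combination (y, a) is the sum of the heads.
import numpy as np

rng = np.random.default_rng(0)
n_feat, n_y, n_a = 16, 2, 2           # feature dim, #classes, #attribute values
W_y = rng.normal(size=(n_feat, n_y))  # score head for label y
W_a = rng.normal(size=(n_feat, n_a))  # score head for attribute a

def joint_logits(x):
    # Additive decomposition: logit(y, a | x) = s_y(x) + s_a(x).
    s_y = x @ W_y                        # shape (n_y,)
    s_a = x @ W_a                        # shape (n_a,)
    return s_y[:, None] + s_a[None, :]   # (n_y, n_a) grid over ALL combinations

x = rng.normal(size=n_feat)
logits = joint_logits(x)
# Even if group (y=0, a=1) was absent from training, its logit is defined
# compositionally; CRM additionally adjusts the trained classifier (see the
# paper) before predicting on such unseen combinations.
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(probs)
```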
To set up the environment, clone this repository and install the required packages:
```bash
conda env create -n your_env -f environment.yml
```
To download and prepare the datasets, run:
```bash
python -m src.scripts.download --datasets waterbirds celeba civilcomments multinli metashift nicopp --data_path ./datasets --download --gen_meta_data --precompute_feat
```
Note that the --precompute_feat flag does a forward pass for one epoch to store the features from the pretrained backbone for each dataset. This is useful for linear probing (no finetuning) experiments, as we can load these precomputed features directly instead of computing them during the training/inference loop; a conceptual sketch follows below. For larger datasets (CivilComments, MultiNLI) this step may take a while to complete, so you can skip the --precompute_feat flag if you do not intend to run linear probing experiments.
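Conceptually, feature precomputation is one frozen forward pass that caches backbone features to disk. The following hypothetical sketch illustrates the idea; the actual logic lives in src.scripts.download, and the function and path names here are assumptions.

```python
# Hypothetical sketch of feature precomputation (illustrative only; the
# actual logic lives in src.scripts.download). A frozen pretrained backbone
# is run once over the dataset and its features are cached to disk.
import numpy as np
import torch

@torch.no_grad()
def precompute_features(backbone, loader, out_path):
    backbone.eval()
    feats, labels = [], []
    for x, y in loader:
        feats.append(backbone(x).cpu().numpy())  # single forward pass, no grads
        labels.append(y.numpy())
    np.savez(out_path,
             feats=np.concatenate(feats),
             labels=np.concatenate(labels))
```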
We illustrate the commands for CRM on the Waterbirds dataset, where we discard group 0 during training. To launch the complete list of experiments, check the launcher.py file.
DATASETS="Waterbirds"
ALGOS="CRM"For a fast linear probing experiment with fixed hyperparameters, run the following command.
```bash
python -m src.main --case train --datasets $DATASETS --algorithms $ALGOS --data_path ./datasets --out_dir ./out --add_discarded_group 0 --quick_run --precompute_features --seed 1
```
The --quick_run flag uses the hyperparameters defined in the launcher.py file, which can be changed by the user. Also, removing the --precompute_features flag allows for full finetuning experiments.
For each (method, dataset), we obtain results for 5 different hyperparameter combinations with a fixed seed.
```bash
python -m src.main --case train --datasets $DATASETS --algorithms $ALGOS --data_path ./datasets --out_dir ./out --add_discarded_group 0 --num_hparams_combs 5 --num_seeds 1
```
For each (method, dataset), we use the group-balanced accuracy on the validation set to select the best hyperparameter combination, as sketched below.
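A minimal sketch of this selection rule, assuming hypothetical per-run validation predictions (not the repository's actual data structures): pick the hyperparameter combination whose unweighted mean of per-group validation accuracies is highest.

```python
# Hypothetical sketch of the model-selection rule: choose the hyperparameter
# combination with the best group-balanced validation accuracy, i.e. the
# unweighted mean of per-group accuracies.
import numpy as np

def group_balanced_acc(y_true, y_pred, groups):
    # Accuracy computed per group, then averaged across groups, so small
    # (or nearly discarded) groups count as much as large ones.
    per_group = [np.mean(y_pred[groups == g] == y_true[groups == g])
                 for g in np.unique(groups)]
    return float(np.mean(per_group))

def select_best(runs):
    # `runs` maps a hparam-comb id to its validation (y_true, y_pred, groups).
    return max(runs, key=lambda k: group_balanced_acc(*runs[k]))
```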
The best hyperparameter combination is then run for 3 random seeds.
```bash
python -m src.main --case train --datasets $DATASETS --algorithms $ALGOS --data_path ./datasets --out_dir ./out --add_discarded_group 0 --best_hparams_comb_selection_metric va_unf_acc --num_seeds 3
```
To read the results from the log files, run the following command.
```bash
python -m src.scripts.analyze_results --dataset $DATASETS --algorithms $ALGOS --dir ./out --selection_criterion unf_acc
```
To reproduce the results for CRM's analysis with varying group size (Section G.6), first use the following command to generate data for a specific (m, d) configuration, where m is the total number of attributes and d is the number of categories per attribute.
```bash
python -m src.data.multivariate_gaussian --total_attr m --total_cat d
```
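For intuition only, here is a hypothetical sketch of what such multivariate Gaussian group data could look like; the actual generator is src.data.multivariate_gaussian, and every name and parameter below is illustrative.

```python
# Hypothetical sketch of synthetic data with m attributes, each taking d
# categories (the real generator is src.data.multivariate_gaussian).
import itertools
import numpy as np

def sample_synthetic(m=2, d=3, n_per_group=50, n_feat=8, seed=0):
    rng = np.random.default_rng(seed)
    # One Gaussian mean per (attribute, category); a group's mean is the SUM
    # of its attributes' means, mirroring the additive energy structure.
    means = rng.normal(size=(m, d, n_feat))
    X, A = [], []
    for combo in itertools.product(range(d), repeat=m):
        mu = sum(means[i, c] for i, c in enumerate(combo))
        X.append(mu + rng.normal(scale=0.5, size=(n_per_group, n_feat)))
        A.append(np.tile(combo, (n_per_group, 1)))
    return np.concatenate(X), np.concatenate(A)

X, A = sample_synthetic()
print(X.shape, A.shape)  # (450, 8) (450, 2) for m=2, d=3
```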
Then launch the following command to train the model with the desired fraction of discarded groups.
```bash
python -m src.main --case train --datasets SynAED_{m}_{d} --precompute_feat --quick_run --algorithms CRM --data_path ./datasets --out_dir ./out --add_discarded_group {(x-1)/10} --num_seeds 1
```
This source code is released under the CC-BY-NC license, included here.
