
TAG: A Tabular Approach to Graph Learning

Introduction

This repository can be used to replicate all experimental results for TAG (Tabular Approach to Graph Learning) from our paper "Bringing Graphs to the Table: Zero-shot Node Classification via Tabular Foundation Models". We provide easy reproducibility via shell scripts located in ./dev_docker (section 2). Section 3 gives short installation instructions for the conda environments we use, section 4 collects additional comments, and section 5 gives an overview of the repository structure. We generally recommend using the shell scripts from section 2.

Docker

Wandb

To make Wandb run from inside a docker container, you need to create a .env file in the root directory with your Wandb configuration. The .env file should contain:

USE_WANDB=true_or_false
WANDB_API_KEY=your_wandb_api_key_here
WANDB_ENTITY=your_wandb_entity_here  
WANDB_PROJECT=your_wandb_project_here

The docker scripts will automatically load these environment variables and use them for the Wandb configuration. If you do not want to use Wandb, set USE_WANDB=false; otherwise, set USE_WANDB=true. We recommend using Wandb to track runs.
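As a rough sketch of how such a .env file is commonly loaded in shell scripts (illustrative only; the exact mechanism lives in the dev_docker scripts):

set -a          # export every variable defined while this flag is on
source .env     # defines USE_WANDB, WANDB_API_KEY, WANDB_ENTITY, WANDB_PROJECT
set +a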

Running Docker Scripts

Generally, each docker script is already initialized with default hyperparameters that can be overwritten in the shell. The general pattern is ./dev_docker/<your_script> <container_suffix> <gpu_device> [additional_args], where:

  • container_suffix is the suffix appended to the container. This enables running multiple containers of the same type in parallel.
  • gpu_device is the index of the GPU you would like the script to run on.
  • [additional_args] are script-specific parameters that can be overridden using name=value, separated by spaces. Default values are set for all parameters, so every script can be run without specifying additional arguments; see the illustrative invocation below.
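For illustration, a hypothetical invocation following this pattern (the suffix my_run and the override value are made up; the concrete scripts and their arguments are described below):

./dev_docker/run_container_node_classification.sh my_run 0 seeds=[0,1]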

The individual docker scripts are described below.

Note: Prepending DETACHED=false to one of our docker scripts (e.g., DETACHED=false ./dev_docker/some_script.sh) runs it in attached mode, which can be very helpful for debugging.

build_image.sh: Building the image

This has to be run once to create the docker image. There are no additional hyperparameters to change.

./dev_docker/build_image.sh <container_suffix> 

preprocess_node_classification.sh: Preprocess Node Classification Datasets

This has to be run once (after build_image.sh) to create preprocessed node classification datasets for all five random seeds (0, 1, 2, 3, 4). There are no additional hyperparameters to change. Unlike all the other scripts, this one does not directly start a single detached docker container; instead, it spins up five containers (one per seed) sequentially. We therefore highly recommend running this script inside tmux (see the example after the note below).

./dev_docker/preprocess_node_classification.sh <container_suffix> <gpu_device> 

Note: At the time of writing, some of the dataset downloads through pytorch-geometric are broken. A fix has been proposed but not yet merged.
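If you follow the tmux recommendation above, a session could look like the following (standard tmux usage; the container suffix prep is a made-up example):

tmux new -s preprocess                                  # start a named session
./dev_docker/preprocess_node_classification.sh prep 0   # runs five containers sequentially
# detach with Ctrl-b d; later, reattach with:
tmux attach -t preprocess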

run_container_node_classification.sh: Evaluate a TAG model configuration

The following additional arguments are available:

  • model.name: The name of your model, used for logging (including Wandb) and for saving the evaluation artifacts to output/node_classification.
  • ensemble: Specifies the ensemble config. You should provide the relative path from src/node_classification/configs/ensemble. For an overview of which ensembles are which, please refer to the last section of this README.
  • seeds: A list of integers specifying the seeds on which the model should be evaluated.
  • dataset: Specifies the datasets on which the model should be evaluated. The three configurations used in the paper are:
    • DatasetsTSGNNs_w_PCA: For direct model-to-model comparison, we apply PCA preprocessing on three datasets (see Appendix C for details); used in Section 5.1.
    • DatasetsTSGNNs: The same datasets as above, but without the PCA preprocessing on the three affected datasets.
    • DatasetsTSGNNsFinetuned: The 20 datasets that were not used for finetuning the TabPFN checkpoint in Section 5.2 (and on which we evaluate).

For example, if you would like to rerun the evaluation of the finetuned checkpoint presented in Section 5.2 of the paper on the GPU with index 0 across all five seeds, you could do this using:

./dev_docker/run_container_node_classification.sh finetuned_eval 0 model.name=finetuned_eval ensemble=finetuned_models/finetuned.yaml dataset=DatasetsTSGNNsFinetuned seeds=[0,1,2,3,4]

run_container_finetuning.sh: Finetune TabPFN checkpoint

This script finetunes a TabPFN checkpoint using the same configuration as described in Section 5.2 of the paper.

The following additional arguments can be overridden:

  • finetuning.learning_rate: Learning rate for finetuning (default: 1e-6)
  • finetuning.epochs: Maximum number of training epochs (default: 5000)
  • finetuning.early_stopping_patience: Number of epochs to wait before early stopping (default: 500)
  • finetuning.datasets.training: List of datasets used for training
  • finetuning.datasets.additional_validation: List of additional datasets used for validation
  • finetuning.experiment_name: Name of the experiment for logging (default: finetune_tabpfn)
  • finetuning.equal_weight_for_all_datasets: Whether to give equal weight to all datasets (default: true)

For example, to finetune TabPFN with a custom learning rate and fewer epochs (on GPU 0):

./dev_docker/run_container_finetuning.sh finetune_short 0 finetuning.learning_rate=1e-5 finetuning.epochs=1000

If you would like to replicate the results reported in Section 5.2, please use the default configurations:

./dev_docker/run_container_finetuning.sh finetune 0

In the standard configuration, training the checkpoint used in the paper took ~48h on an NVIDIA RTX PRO 6000 Blackwell.

Manual Installation of Environments

The tag conda environment is the only environment needed to run all experiments for the node classification task AFTER all datasets have been preprocessed.

TAG

conda env create -y -f environment.yaml
conda activate tag  # or your environment name
pip install -e .
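As an optional sanity check (the import name tag is an assumption based on the src/tag layout):

python -c "import tag; print('tag imported successfully')"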

Node Classification Preprocessing

If you would like to preprocess the node classification datasets without the docker setup, please create the preprocessing environment below. Further details on how preprocessing is performed can be found in the relevant shell scripts. (The different graph libraries used to load the node classification datasets lead to dependency issues; we therefore created a dedicated environment that performs all dataset preprocessing on the CPU.)

conda create -y -n preprocess_node_classification python=3.10
conda activate preprocess_node_classification
pip install -r requirements.txt
pip install -e .
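As above, an optional sanity check (assuming pytorch-geometric is among the pinned requirements, since the dataset downloads rely on it):

python -c "import torch_geometric; print(torch_geometric.__version__)"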

Additional comments

When generating the cache, you might need to set the following environment variable (if you get an error):

export CUBLAS_WORKSPACE_CONFIG=":4096:8"

Running Mitra

If you would like to run the Mitra TFM (e.g., using the other_tfms/mitra_10_ges+one_hot.yaml config), you will need to create a separate environment, as Mitra does not work with the flash-attn installed in the main tag environment. We provide an environment file at extra/environment_mitra.yaml. Please note that you will have to change the conda environment in, e.g., ./dev_docker/run_container_node_classification.sh from tag to tag_mitra.
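A minimal setup sketch, assuming the environment file defines the environment name tag_mitra and that the editable install mirrors the main environment:

conda env create -y -f extra/environment_mitra.yaml
conda activate tag_mitra
pip install -e .   # assumed to mirror the main tag environment setup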

Structure of the Repository

High-Level

├── README.md
├── environment.yaml                                  # Env for TAG (main env)
├── requirements.txt                                  # Env for node_classification preprocessing
├── limix                                             # Folder containing the LimiX codebase (currently not available as an installable package)
├── extra                                             
│   └── environment_mitra.yaml                        # Dedicated environment for running TAG with the Mitra backbone as it's incompatible with flash-attn
├── src                                              
│   ├── tag                                           # The tag package
│   └── node_classification                           # The node_classification task (imports from tag)
├── dev_docker                                        # See above
├── data                                              # All data files (raw and cache)
├── output                                            # Folder that every task saves (optional) outputs to (e.g. ensemble weights used for plotting)
│   ├── finetuning                                    # Where finetuning checkpoints are saved (We provide the checkpoint used in Section 5.2 called final_model_1e-6_5k.ckpt)
│   └── node_classification                           # Where artifacts from model evaluation runs are saved
├── .gitignore
├── setup.py                                          # Minimal setup for editable package installation
└── .project-root                                     # Needed for rootutils

TAG

The folder structure at src/tag is the following:

├── __init__.py
├── tag.py                                           # The main TAG model, which is used in the downstream tasks
├── ensemble_config_builder.py                       # Scripts to convert the Hydra model configs to Pydantic model configurations
├── ensemble_config.py                               # The Pydantic config classes needed to specify a TAG ensemble
├── ensemble_models.py                               # The actual ensemble models, which are created by TAG based on the ensemble config (all implement the scikit-learn API)
├── many_classes.py                                  # A slightly modified version of the TabPFN ManyClassClassifier (for datasets with >= 10 classes)
├── phe.py                                           # Code relating to Post-Hoc Ensembling. This includes an implementation of ensemble selection and code to get the Out-of-Fold predictions needed for ensemble selection      
└── utils.py                                         

Node Classification

The folder structure at src/node_classification is the following:

├── configs                                           
│   ├── data.yaml                                     # Configs that define the datasets and presets, which determine on which datasets we run
│   ├── main.yaml                                     # Paths, seeds, logging, etc.
│   ├── model.yaml                                    
│   └── ensemble                                      # Includes all the different model configurations as YAML files. Detailed overview is given below.
├── utils
│   ├── __init__.py                                  
│   ├── config.py                                     
│   ├── data.py                                       
│   ├── experiment.py                                 
│   ├── logging.py                                    
│   └── random_walk_pe.py                             # Slight modification of the PyTorch Geometric RandomWalkPE to tackle MKL issues
├── data.py                                           # The main dataloader for both the individual GraphDatasets and the CombinedDataset.
├── finetuning_tabpfn.py                              # Can be used to finetune TabPFN
└── main.py                                           # Main entry point. Runs the node_classification evaluation across multiple random seeds (can be changed in the config).
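If you prefer to bypass docker entirely (the dev_docker scripts remain the recommended path), a hypothetical direct invocation, assuming main.py is a Hydra entry point that accepts the same name=value overrides as the docker scripts, could look like:

conda activate tag
python src/node_classification/main.py model.name=my_eval ensemble=10_ges+one_hot.yaml seeds=[0,1,2,3,4]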

Ensemble

The folder structure at src/node_classification/configs/ensemble/, which contains the configurations for all experiments in the paper, is the following:

├── component_ablations                               # Section 5.5
├── encoder_ablations                                 # Section 5.3 and Appendix A.1
├── finetuned_models                                  # Section 5.2
├── other_tfms                                        # Appendix A.2
├── scaling                                           # Section 5.4
├── generate_cache.yaml                               # Used to generate cache (see above)
└── 10_ges+one_hot.yaml                               # Section 5.1: The standard TAG configuration
