Authors: Wei-Yang Alex Lee, Rudrasis Chakraborty, Vishnu Suresh Lokhande
The 29th International Conference on Artificial Intelligence and Statistics (AISTATS)
We present GeoTTER, a novel framework that rethinks optimal transport (OT) for zero-shot classification. Conventional methods often suffer from miscalibration and limited adaptability because they rely on fixed cost matrices derived solely from pre-trained model embeddings. GeoTTER addresses these limitations with two key techniques. First, to alleviate high-frequency label jaggedness (sample-level manifold jitter that assigns neighboring embeddings to different classes), GeoTTER integrates local geometric structure into the OT formulation via graph-Laplacian smoothing, a technique grounded in spectral graph theory that enforces neighborhood consistency. Second, to correct coherent angular drift (a low-frequency orientation bias in which large groups of samples share the same angular offset from their true label prototypes), GeoTTER fuses clustering-guided cost components with a globally adjusted transport cost, yielding a multi-objective optimization that respects both global distribution constraints and latent data structure. Across a diverse set of benchmarks, GeoTTER achieves a median improvement of +6.82% over zero-shot prediction and +2.13% over OTTER. The code is available at https://github.com/TeleViaBox/GeoTTER
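To make the two ingredients concrete, here is a minimal NumPy sketch under our own assumptions; it is not the GeoTTER implementation from the repository. It builds a cosine cost between embeddings and label prototypes, blends in a clustering-guided cost taken from each sample's spherical k-means centroid, smooths the cost over a kNN graph by solving (I + βL)X = C (graph Tikhonov smoothing, one standard form of Laplacian smoothing), and then solves an entropic OT problem with Sinkhorn iterations whose column marginal is a class prior. The function names, the weights `alpha` and `beta`, the neighbor count `k`, and the entropic parameter `eps` are all hypothetical, and the toy usage passes the true label proportions as the prior purely for illustration.

```python
import numpy as np


def normalize(x):
    """L2-normalize rows so dot products become cosine similarities."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def sinkhorn(cost, a, b, eps=0.05, n_iter=200):
    """Entropic OT plan whose rows sum to a and whose columns sum to b."""
    K = np.exp(-cost / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def knn_laplacian(x, k=5):
    """Unnormalized Laplacian L = D - W of a symmetrized kNN graph."""
    sim = x @ x.T
    np.fill_diagonal(sim, -np.inf)                 # exclude self-loops
    nbrs = np.argsort(-sim, axis=1)[:, :k]         # k most similar samples
    W = np.zeros(sim.shape)
    W[np.repeat(np.arange(len(x)), k), nbrs.ravel()] = 1.0
    W = np.maximum(W, W.T)                         # make the graph undirected
    return np.diag(W.sum(axis=1)) - W


def spherical_kmeans(x, n_clusters, n_iter=20, seed=0):
    """k-means with cosine assignment; x must be row-normalized."""
    rng = np.random.default_rng(seed)
    centers = x[rng.choice(len(x), n_clusters, replace=False)]
    for _ in range(n_iter):
        assign = np.argmax(x @ centers.T, axis=1)
        for c in range(n_clusters):
            if np.any(assign == c):
                centers[c] = x[assign == c].mean(axis=0)
        centers = normalize(centers)
    return np.argmax(x @ centers.T, axis=1), centers


def geotter_like_predict(emb, protos, prior, alpha=0.3, beta=1.0, k=5, eps=0.05):
    """Hypothetical GeoTTER-style pipeline; all weights are illustrative."""
    emb, protos = normalize(emb), normalize(protos)
    cost_global = 1.0 - emb @ protos.T             # cosine cost from embeddings
    assign, centers = spherical_kmeans(emb, protos.shape[0])
    cost_cluster = 1.0 - centers[assign] @ protos.T  # cost of each sample's centroid
    cost = (1 - alpha) * cost_global + alpha * cost_cluster
    lap = knn_laplacian(emb, k=k)
    # Graph Tikhonov smoothing: argmin_X ||X - C||^2 + beta * tr(X^T L X)
    cost = np.linalg.solve(np.eye(len(emb)) + beta * lap, cost)
    a = np.full(len(emb), 1.0 / len(emb))          # uniform mass per sample
    plan = sinkhorn(cost, a, prior, eps=eps)       # columns follow the class prior
    return plan.argmax(axis=1)


# Toy usage on synthetic data (illustrative only).
rng = np.random.default_rng(0)
protos = rng.normal(size=(3, 16))                  # stand-in label prototypes
labels = rng.integers(0, 3, size=120)
emb = protos[labels] + 0.8 * rng.normal(size=(120, 16))
prior = np.bincount(labels, minlength=3) / len(labels)
pred = geotter_like_predict(emb, protos, prior)
print("accuracy:", (pred == labels).mean())
```

Because I + βL has unit row sums and an entrywise nonnegative inverse, each smoothed cost row is a convex combination of original cost rows, so neighboring samples are pulled toward similar costs without leaving the original cost range.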
It is recommended to use a Python virtual environment (venv). Create one:
On macOS / Linux:
python3 -m venv venv
On Windows:
python -m venv venv
Then activate it.
On macOS / Linux:
source venv/bin/activate
On Windows (Command Prompt):
venv\Scripts\activate
On Windows (PowerShell):
venv\Scripts\Activate.ps1
Then change into the repository directory:
cd GeoTTER/
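Before installing dependencies, you can confirm that the virtual environment is active. The sketch below uses the standard check that `sys.prefix` differs from `sys.base_prefix` inside a venv; the helper name `in_virtualenv` is hypothetical and not part of the repository.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when running inside a venv created by `python -m venv`."""
    # Inside a venv, sys.prefix points at the environment directory,
    # while sys.base_prefix still points at the base interpreter.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("venv active:", in_virtualenv())
```

Run it with the interpreter you intend to use; after activation it should report `venv active: True`.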
Install the dependencies:
pip install -r requirements.txt
To run on synthetic data, set the following in run_geotter.py:
use_dummy = True
Then run:
python run_geotter.py
This will:
- generate a synthetic dataset
- run zero-shot (ZS), optimal transport (OT), and GeoTTER predictions
- print accuracy information
- save results to GeoTTER_results/
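The dummy run can be pictured with a short, self-contained sketch of the zero-shot (ZS) baseline: synthetic embeddings scattered around random class prototypes, prediction by highest cosine similarity, then accuracy and prediction distribution. The generator `make_dummy_dataset` and its parameters are hypothetical; the synthetic data produced by run_geotter.py may be built differently.

```python
import numpy as np


def make_dummy_dataset(n=200, n_classes=4, dim=32, noise=0.8, seed=0):
    """Hypothetical synthetic set: noisy samples around random class prototypes."""
    rng = np.random.default_rng(seed)
    protos = rng.normal(size=(n_classes, dim))     # stand-ins for label encodings
    labels = rng.integers(0, n_classes, size=n)
    emb = protos[labels] + noise * rng.normal(size=(n, dim))
    return emb, protos, labels


def zero_shot_predict(emb, protos):
    """ZS rule: pick the prototype with the highest cosine similarity."""
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    protos = protos / np.linalg.norm(protos, axis=1, keepdims=True)
    return np.argmax(emb @ protos.T, axis=1)


emb, protos, labels = make_dummy_dataset()
pred = zero_shot_predict(emb, protos)
acc = (pred == labels).mean()
dist = np.bincount(pred, minlength=protos.shape[0])  # prediction distribution
print(f"ZS accuracy: {acc:.4f}")
print("prediction distribution:", dist.tolist())
```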
If you already have a loading pipeline for your real dataset, set:
use_dummy = False
Then make sure the required dataset-loading utilities are available, such as:
- embedding loader
- label encoding loader
- label loader
- class count utility
- class balance utility
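As a rough illustration of what the last two utilities might compute, here is a sketch with hypothetical names (`num_classes`, `class_balance`); the repository's actual helpers and signatures may differ. The class balance is an empirical label-proportion vector, the kind of quantity an OT-based predictor can use as its target class marginal.

```python
import numpy as np


def num_classes(label_encodings: np.ndarray) -> int:
    """Hypothetical class-count utility: one encoding row per class."""
    return label_encodings.shape[0]


def class_balance(labels: np.ndarray, n_classes: int) -> np.ndarray:
    """Hypothetical class-balance utility: empirical fraction of samples per class.

    minlength keeps classes with zero samples in the output vector.
    """
    counts = np.bincount(labels, minlength=n_classes)
    return counts / counts.sum()


labels = np.array([0, 0, 1, 2, 2, 2])
encodings = np.zeros((4, 8))  # 4 classes, 8-dim encodings (dummy values)
k = num_classes(encodings)
print(k, class_balance(labels, k))
```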
After that, run:
python run_geotter.py
The script writes experiment results to a CSV file, for example:
GeoTTER_results/DummySet_results.csv
The CSV typically includes:
- dataset name
- backbone name
- method name (GeoTTER)
- parameter configuration
- accuracy
- prediction distribution
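If you want to collect results programmatically, a stdlib-only sketch for appending and reading such rows follows. The column names in `FIELDS`, the JSON encoding of the parameter and distribution columns, and the placeholder row values are assumptions for illustration, not the exact format written by run_geotter.py.

```python
import csv
import json
import tempfile
from pathlib import Path

# Hypothetical column names mirroring the fields listed above.
FIELDS = ["dataset", "backbone", "method", "params", "accuracy", "pred_distribution"]


def append_result(path, row):
    """Append one result row, writing the header only when the file is new."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    is_new = not path.exists()
    with path.open("a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        if is_new:
            writer.writeheader()
        writer.writerow(row)


def load_results(path):
    """Read rows back, decoding the numeric and JSON-encoded columns."""
    with Path(path).open(newline="") as f:
        rows = list(csv.DictReader(f))
    for r in rows:
        r["accuracy"] = float(r["accuracy"])
        r["params"] = json.loads(r["params"])
        r["pred_distribution"] = json.loads(r["pred_distribution"])
    return rows


# Usage with placeholder values written to a temporary directory.
out = Path(tempfile.mkdtemp()) / "GeoTTER_results" / "DummySet_results.csv"
placeholder = {
    "dataset": "DummySet", "backbone": "example-backbone", "method": "GeoTTER",
    "params": json.dumps({"alpha": 0.3}),          # placeholder configuration
    "accuracy": 0.0,                               # placeholder, not a result
    "pred_distribution": json.dumps([10, 12, 8]),  # placeholder counts
}
append_result(out, placeholder)
append_result(out, placeholder)
rows = load_results(out)
print(len(rows), rows[0]["params"])
```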