
The official implementation of Global Occlusion-Aware Transformer for Robust Stereo Matching (WACV 2024).


GOAT: Global Occlusion-Aware Transformer for Robust Stereo Matching


Global Occlusion-Aware Transformer for Robust Stereo Matching
WACV 2024


🎯 Overview

GOAT (Global Occlusion-Aware Transformer) is a robust stereo matching network that achieves state-of-the-art performance on multiple benchmarks. The network features:

  • Global Context Modeling: Transformer-based architecture for capturing long-range dependencies
  • Occlusion Awareness: Explicit occlusion detection and handling mechanism
  • Multi-Scale Processing: Pyramid cost volume construction for robust matching
  • Model Architecture: GOAT-T (Tiny) optimized for accuracy and efficiency

✨ Features

  • ✅ Global attention mechanism for robust feature matching
  • ✅ Occlusion-aware design for handling challenging scenarios
  • ✅ Support for multiple datasets: SceneFlow, KITTI, Middlebury, ETH3D, FAT
  • ✅ Distributed training (DDP) support
  • ✅ TensorBoard integration for training visualization
  • ✅ Flexible loss configuration
  • ✅ Comprehensive evaluation metrics (EPE, P1-Error, mIOU)

🔧 Installation

Prerequisites

  • Python >= 3.7
  • PyTorch >= 1.7.0
  • CUDA >= 10.2 (for GPU support)
  • GCC >= 5.4 (for building deformable convolution)

Quick Setup

Step 1: Clone the repository

git clone [email protected]:Magicboomliu/GOAT.git
cd GOAT

Step 2: Create and activate conda environment

conda create -n goat python=3.8
conda activate goat

Step 3: Install PyTorch

# For CUDA 11.0
pip install torch==1.7.0+cu110 torchvision==0.8.0+cu110 -f https://download.pytorch.org/whl/torch_stable.html

# Or for CUDA 10.2
pip install torch==1.7.0+cu102 torchvision==0.8.0+cu102 -f https://download.pytorch.org/whl/torch_stable.html

Step 4: Install dependencies

pip install -r requirements.txt

Step 5: Install GOAT package

pip install -e .

Step 6: Build Deformable Convolution (Optional)

cd third_party/deform
bash make.sh
cd ../..

Note: Deformable convolution is optional. The model works without it, but may achieve slightly better performance with it enabled via the --use_deform flag.

Verify Installation

# Test imports
python -c "import goat; from goat.models.networks.Methods.GOAT_T import GOAT_T; print('Installation successful!')"

# Check GPU availability
python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}, GPU count: {torch.cuda.device_count()}')"

📂 Project Structure

The repository is organized as follows:

GOAT/
├── goat/              # Main source package (models, losses, utilities)
├── data/              # Dataloaders and dataset file lists
├── scripts/           # Training and evaluation scripts
├── configs/           # Configuration files
├── third_party/       # External dependencies (deformable convolution)
├── docs/              # Documentation and assets
└── tests/             # Unit tests (future)

Key directories:

  • goat/models/: Network architectures (GOAT-T, GOAT-L, attention modules, etc.)
  • goat/losses/: Loss functions for training
  • goat/utils/: Utility functions and metrics
  • data/dataloaders/: Dataset loaders for KITTI, SceneFlow, Middlebury, ETH3D, FAT
  • data/filenames/: Dataset file lists organized by dataset type
  • scripts/: Executable training scripts

For a detailed structure guide, see STRUCTURE.md.


📊 Dataset Preparation

SceneFlow Dataset

  1. Download the SceneFlow dataset
  2. Organize the data structure as:
/path/to/sceneflow/
├── frames_cleanpass/
├── frames_finalpass/
└── disparity/
  3. Update the dataset path in the training script:
--datapath /path/to/sceneflow
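Before launching training, it can save time to sanity-check the directory layout. A minimal sketch of such a check (the helper name check_sceneflow_root is ours for illustration, not part of the repo):

```python
# Hypothetical helper (not part of the repo): verify that a SceneFlow
# root contains the three sub-directories shown above.
from pathlib import Path

EXPECTED = ("frames_cleanpass", "frames_finalpass", "disparity")

def check_sceneflow_root(root):
    """Return the expected sub-directories that are missing under `root`."""
    root = Path(root)
    return [name for name in EXPECTED if not (root / name).is_dir()]

missing = check_sceneflow_root("/path/to/sceneflow")
if missing:
    print(f"Missing sub-directories: {missing}")
```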

KITTI Dataset

  1. Download KITTI 2012 and/or KITTI 2015
  2. Organize as:
/path/to/kitti/
├── 2012/
│   ├── training/
│   └── testing/
└── 2015/
    ├── training/
    └── testing/

Other Datasets

See data/filenames/ directory for supported datasets:

  • MiddleBurry
  • ETH3D
  • FAT (Falling Things)

🚀 Training

Quick Start

Prepare your dataset first (see Dataset Preparation), then run:

# Create necessary directories
mkdir -p models_saved logs experiments_logdir

# Single GPU training (recommended to test first)
python scripts/train.py \
    --cuda \
    --model GOAT_T_Origin \
    --loss configs/loss_config_disp.json \
    --lr 1e-3 \
    --batch_size 4 \
    --dataset sceneflow \
    --trainlist data/filenames/SceneFlow/SceneFlow_With_Occ.list \
    --vallist data/filenames/SceneFlow/FlyingThings3D_Test_With_Occ.list \
    --datapath /path/to/sceneflow \
    --outf models_saved/goat_experiment \
    --logFile logs/train.log \
    --save_logdir experiments_logdir/goat_experiment \
    --devices 0 \
    --local_rank 0 \
    --datathread 4 \
    --manualSeed 1024

Multi-GPU Training (Recommended)

For faster training, use distributed data parallel (DDP) with multiple GPUs:

# 2 GPUs
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
    --nproc_per_node=2 \
    scripts/train.py \
    --cuda \
    --model GOAT_T_Origin \
    --loss configs/loss_config_disp.json \
    --lr 1e-3 \
    --batch_size 2 \
    --dataset sceneflow \
    --trainlist data/filenames/SceneFlow/SceneFlow_With_Occ.list \
    --vallist data/filenames/SceneFlow/FlyingThings3D_Test_With_Occ.list \
    --datapath /path/to/sceneflow \
    --outf models_saved/goat_experiment \
    --logFile logs/train.log \
    --save_logdir experiments_logdir/goat_experiment \
    --devices 0,1 \
    --datathread 4 \
    --manualSeed 1024

# 4 GPUs (adjust batch_size and nproc_per_node accordingly)
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch \
    --nproc_per_node=4 \
    scripts/train.py \
    --cuda \
    --model GOAT_T_Origin \
    --loss configs/loss_config_disp.json \
    --lr 1e-3 \
    --batch_size 1 \
    --dataset sceneflow \
    --trainlist data/filenames/SceneFlow/SceneFlow_With_Occ.list \
    --vallist data/filenames/SceneFlow/FlyingThings3D_Test_With_Occ.list \
    --datapath /path/to/sceneflow \
    --outf models_saved/goat_experiment \
    --logFile logs/train.log \
    --save_logdir experiments_logdir/goat_experiment \
    --devices 0,1,2,3 \
    --datathread 4 \
    --manualSeed 1024

Using the Training Script

Modify and use the provided shell script:

# 1. Edit scripts/train.sh and update these variables:
#    - datapath: path to your dataset
#    - pretrain_name: experiment name
#    - Other parameters as needed

# 2. Run the script
bash scripts/train.sh

Training Arguments

Required Arguments:

| Argument | Description | Example |
|----------|-------------|---------|
| --cuda | Enable CUDA | (flag, no value) |
| --model | Model architecture | GOAT_T_Origin |
| --loss | Loss configuration file | configs/loss_config_disp.json |
| --dataset | Dataset name | sceneflow |
| --trainlist | Training file list | data/filenames/SceneFlow/SceneFlow_With_Occ.list |
| --vallist | Validation file list | data/filenames/SceneFlow/FlyingThings3D_Test_With_Occ.list |
| --datapath | Dataset root path | /path/to/sceneflow |
| --outf | Output directory for models | models_saved/experiment_name |
| --logFile | Log file path | logs/train.log |
| --save_logdir | TensorBoard log directory | experiments_logdir/experiment_name |
| --devices | GPU device IDs | 0 or 0,1,2,3 |
| --local_rank | Local rank (DDP) | Auto-set by launcher |
| --datathread | Number of data loading workers | 4 |

Optional Arguments:

| Argument | Description | Default |
|----------|-------------|---------|
| --net | Legacy network name | simplenet |
| --lr | Learning rate | 0.0002 |
| --batch_size | Batch size per GPU | 8 |
| --test_batch | Test batch size | 4 |
| --maxdisp | Maximum disparity | -1 (auto) |
| --pretrain | Path to pretrained model | none |
| --initial_pretrain | Partial weight loading | none |
| --use_deform | Use deformable convolution | False |
| --startRound | Start training round | 0 |
| --startEpoch | Start epoch | 0 |
| --manualSeed | Random seed | Random |
| --workers | Number of workers | 8 |
| --momentum | SGD momentum | 0.9 |
| --beta | Adam beta | 0.999 |

With Deformable Convolution:

# First build deformable convolution (see Installation)
python scripts/train.py \
    --use_deform \
    ... other arguments ...

Tracked Metrics:

  • Total loss (combined disparity + occlusion loss)
  • Disparity sequence loss
  • Occlusion loss
  • Disparity EPE (End-Point Error)
  • Occlusion EPE
  • Occlusion mIOU (mean Intersection over Union)
  • P1-Error (percentage of pixels with >1px error)
  • Learning rate
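For reference, the disparity and occlusion metrics above can be sketched in a few lines of NumPy. This is a hedged illustration of the standard definitions, not the repo's implementation in goat/utils/:

```python
import numpy as np

def disparity_epe(pred, gt):
    """End-Point Error: mean absolute disparity error in pixels."""
    return float(np.abs(pred - gt).mean())

def p1_error(pred, gt):
    """P1-Error: fraction of pixels whose absolute error exceeds 1 px."""
    return float((np.abs(pred - gt) > 1.0).mean())

def occlusion_miou(pred_mask, gt_mask):
    """mIOU over the two occlusion classes (occluded / non-occluded)."""
    ious = []
    for cls in (False, True):
        p, g = (pred_mask == cls), (gt_mask == cls)
        union = np.logical_or(p, g).sum()
        if union > 0:
            ious.append(np.logical_and(p, g).sum() / union)
    return float(np.mean(ious))
```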

Model Checkpoints: Models are automatically saved in the --outf directory:

  • {net}_{round}_{epoch}_{EPE}.pth - Regular checkpoints
  • model_best.pth - Best model based on validation EPE
  • checkpoint.pth - Latest checkpoint
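Because regular checkpoints embed their validation EPE in the filename, the best one can also be found by parsing that pattern. A convenience sketch we wrote around the naming scheme above (model_best.pth already tracks this for you):

```python
import re
from pathlib import Path

# Matches the {net}_{round}_{epoch}_{EPE}.pth naming scheme above.
CKPT_RE = re.compile(r"^(?P<net>.+)_(?P<round>\d+)_(?P<epoch>\d+)_(?P<epe>[\d.]+)\.pth$")

def best_checkpoint(paths):
    """Return the path whose filename encodes the lowest EPE, or None."""
    best, best_epe = None, float("inf")
    for p in map(Path, paths):
        m = CKPT_RE.match(p.name)
        if m and float(m.group("epe")) < best_epe:
            best, best_epe = p, float(m.group("epe"))
    return best
```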

Training Schedule: The learning rate follows a piecewise schedule (defined in goat/trainer.py):

  • Epochs 0-10: 3e-4
  • Epochs 11-39: 1e-4
  • Epochs 40-49: 5e-5
  • Epochs 50-59: 3e-5
  • Epochs 60+: 1.5e-5
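As code, the schedule above amounts to a simple piecewise function. The sketch below mirrors the table; goat/trainer.py remains the authoritative source:

```python
def scheduled_lr(epoch):
    """Learning rate for a given epoch, per the schedule above."""
    if epoch <= 10:      # epochs 0-10
        return 3e-4
    if epoch <= 39:      # epochs 11-39
        return 1e-4
    if epoch <= 49:      # epochs 40-49
        return 5e-5
    if epoch <= 59:      # epochs 50-59
        return 3e-5
    return 1.5e-5        # epochs 60+
```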

📊 Dataset Information

Supported Datasets

| Dataset | Type | Training Samples | Test Samples | Resolution | Usage |
|---------|------|------------------|--------------|------------|-------|
| SceneFlow | Synthetic | ~35K | ~4K | Variable | Pre-training |
| KITTI 2015 | Real-world | 200 | 200 | 1242×375 | Fine-tuning/Eval |
| KITTI 2012 | Real-world | 194 | 195 | 1242×375 | Fine-tuning/Eval |
| Middlebury | Real-world | ~100 | ~15 | Variable | Evaluation |
| ETH3D | Real-world | 27 | 20 | Variable | Evaluation |
| FAT | Synthetic | ~35K | ~5K | Variable | Training/Eval |

File Lists

All dataset file lists are in data/filenames/:

data/filenames/
├── SceneFlow/
│   ├── SceneFlow_With_Occ.list        # Training with occlusion
│   └── FlyingThings3D_Test_With_Occ.list  # Validation with occlusion
├── KITTI/
│   ├── KITTI_mix_train.txt            # Combined KITTI 2012+2015 train
│   ├── KITTI_mix_val.txt              # Combined KITTI 2012+2015 val
│   ├── KITTI_2015_train.txt           # KITTI 2015 train only
│   └── KITTI_2015_val.txt             # KITTI 2015 val only
├── MiddleBurry/
│   └── middleburry_all_training.list
├── ETH3D/
│   └── ETH3D.list
└── FAT/
    ├── FAT_Trainlist_occ.txt
    └── FAT_Testlist_occ.txt
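The list files are plain text; a minimal reader could look like the sketch below. We assume one sample per line with whitespace-separated columns (e.g. left image, right image, disparity, occlusion mask) — check the actual .list/.txt files, since the column layout may differ per dataset:

```python
def read_file_list(path):
    """Parse a dataset file list: one sample per line, columns whitespace-separated.

    Assumes a whitespace-separated format; verify against the real files.
    """
    samples = []
    with open(path) as f:
        for line in f:
            cols = line.split()
            if cols:  # skip blank lines
                samples.append(cols)
    return samples
```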

📖 Citation

If you find this work useful in your research, please consider citing:

@InProceedings{Liu_2024_WACV,
    title     = {Global Occlusion-Aware Transformer for Robust Stereo Matching},
    booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
    month     = {January},
    year      = {2024},
    pages     = {TBA}
}

πŸ™ Acknowledgements

This project builds upon several excellent open-source works, and we thank their authors for releasing the implementations.

📜 License

This project is licensed under the MIT License - see the LICENSE file for details.
