SAM-TTA: SAM-aware Test-Time Adaptation for Universal Medical Image Segmentation

Paper: arXiv:2506.05221

Requirements: Python 3.8+, PyTorch 2.0+. License: MIT.

SAM-aware Test-time Adaptation for Universal Medical Image Segmentation
Jianghao Wu, Yicheng Wu*, Yutong Xie, Wenjia Bai, You Zhang, Feilong Tang, Yulong Li, Imran Razzak, Daniel F Schmidt, Yasmeen George


Overview

SAM-TTA is a lightweight test-time adaptation framework that adapts the Segment Anything Model (SAM) to medical image segmentation without labeled data, retraining, or access to the source training set.


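SAM-TTA framework: the Self-adaptive Bezier Curve-based Transformation (SBCT) converts single-channel grayscale inputs into SAM-compatible 3-channel images; IoU-guided Multi-scale Adaptation (IMA) uses a teacher–student structure with IoU-weighted losses to align predictions at multiple scales.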
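To illustrate the idea behind SBCT, here is a minimal NumPy sketch that maps a min-max-normalized grayscale slice through three cubic Bezier intensity curves and stacks the results into a 3-channel image. It assumes fixed, hand-picked control points (the values in DEFAULT_CTRL are hypothetical); the repository's DynamicBezierTransform2D and LearnableBezierTransform make the curve adaptive or learnable instead, so treat this as an illustration of the transform family, not the authors' implementation.

```python
import numpy as np

# Hypothetical inner control points (P1, P2) per output channel.
# P0 = (0, 0) and P3 = (1, 1) are fixed so every curve maps [0, 1] onto [0, 1].
DEFAULT_CTRL = (
    ((1 / 3, 1 / 3), (2 / 3, 2 / 3)),  # evenly spaced on the diagonal: identity, y == x
    ((0.2, 0.6), (0.5, 0.9)),          # curve above the diagonal: brightens
    ((0.5, 0.1), (0.8, 0.4)),          # curve below the diagonal: darkens
)


def bezier_curve(p0, p1, p2, p3, n=1000):
    """Sample n points of a cubic Bezier curve; returns (xs, ys)."""
    t = np.linspace(0.0, 1.0, n)[:, None]  # shape (n, 1) broadcasts against (2,) points
    pts = ((1 - t) ** 3 * p0 + 3 * (1 - t) ** 2 * t * p1
           + 3 * (1 - t) * t ** 2 * p2 + t ** 3 * p3)
    return pts[:, 0], pts[:, 1]


def bezier_intensity_map(img, p1, p2):
    """Min-max normalize img to [0, 1], then map every intensity through the curve."""
    img = np.asarray(img, dtype=float)
    x = (img - img.min()) / (img.max() - img.min() + 1e-8)
    p0, p3 = np.array([0.0, 0.0]), np.array([1.0, 1.0])
    xs, ys = bezier_curve(p0, np.asarray(p1, float), np.asarray(p2, float), p3)
    # With control x-coordinates ordered 0 < p1x <= p2x < 1, xs increases with t,
    # so the curve is a function y = f(x) and np.interp can evaluate it.
    return np.interp(x, xs, ys)


def to_three_channels(img, ctrl_sets=DEFAULT_CTRL):
    """Stack one Bezier-mapped copy of a 2D slice per control-point set: (3, H, W)."""
    return np.stack([bezier_intensity_map(img, p1, p2) for p1, p2 in ctrl_sets], axis=0)
```

For example, to_three_channels(slice_2d) returns an array of shape (3, H, W) with values in [0, 1]; it would still need to be rescaled to 0-255 and resized the way SAM's image encoder expects.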

Installation

git clone https://github.com/JianghaoWu/SAM-TTA.git
cd SAM-TTA
pip install -r requirements.txt

Download SAM Checkpoint

Download the SAM ViT-B checkpoint and place it in checkpoints/:

mkdir -p checkpoints
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -P checkpoints/

Datasets

Dataset Structure

Each dataset requires a CSV file listing image–mask pairs. The expected directory structure is:

data/
├── Pancreas/
│   ├── train.csv        # columns: image_path, mask_path
│   ├── val.csv
│   ├── test.csv
│   └── all.csv
└── BraTS_PED_t2w_2D/
    ├── train.csv
    ├── val.csv
    ├── test.csv
    └── all.csv

Each CSV row should contain: relative/path/to/image.nii.gz, relative/path/to/mask.nii.gz
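If your dataset does not ship with these CSVs, a small script can generate them. The sketch below writes all.csv plus a random train/val/test split, each with an image_path,mask_path header (the column names come from the tree above). The split fractions, the seed, and the assumption that the loader expects a header row are hypothetical; check datasets/NII_test.py for what it actually reads and which root directory the paths are resolved against.

```python
import csv
import random
from pathlib import Path


def write_split_csvs(pairs, out_dir, val_frac=0.1, test_frac=0.2, seed=0):
    """Write all/train/val/test CSVs of (image_path, mask_path) pairs; returns row counts."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    pairs = list(pairs)
    shuffled = pairs[:]
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible
    n_test = int(len(shuffled) * test_frac)
    n_val = int(len(shuffled) * val_frac)
    splits = {
        "test": shuffled[:n_test],
        "val": shuffled[n_test:n_test + n_val],
        "train": shuffled[n_test + n_val:],
        "all": pairs,  # keep the original order for the full listing
    }
    for name, rows in splits.items():
        with open(out_dir / f"{name}.csv", "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["image_path", "mask_path"])
            writer.writerows(rows)
    return {name: len(rows) for name, rows in splits.items()}
```

For example, write_split_csvs([("images/case0.nii.gz", "masks/case0.nii.gz"), ...], "data/MyOrgan") would produce the four files shown for Pancreas above.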

Data Sources

We evaluate on eight public datasets (the T2W and T2F variants of BraTS-PED and BraTS-SSA count as separate datasets):

| Dataset | Modality | Source |
|---|---|---|
| BraTS-PED 2024 (T2W / T2F) | MRI | BraTS 2024 Challenge |
| BraTS-SSA 2023 (T2W / T2F) | MRI | BraTS 2023 Challenge |
| Pancreas (AMOS 2022) | MRI | AMOS Challenge |
| Pancreatic Cancer (MSWAL) | CT | MSWAL Dataset |
| CVC-ColonDB | Endoscopy | CVC-ColonDB |
| Kvasir-SEG | Endoscopy | Kvasir-SEG |

Please refer to the respective official sources for download and terms of use.


Usage

Source-only Baseline (no adaptation)

python methods/adaptation_source.py --cfg configs/config_brats.py --dataset Pancreas --prompt box
python methods/adaptation_source.py --cfg configs/config_brats.py --dataset BraTS_PED_t2w_2D --prompt box

SAM-TTA (our method)

python methods/adaptation_samtta.py --cfg configs/config_brats.py --dataset Pancreas --prompt box
python methods/adaptation_samtta.py --cfg configs/config_brats.py --dataset BraTS_PED_t2w_2D --prompt box

Mean Teacher Baseline

python methods/adaptation_mt.py --cfg configs/config_brats.py --dataset Pancreas --prompt box

Arguments

| Argument | Description | Example |
|---|---|---|
| --cfg | Path to the config file | configs/config_brats.py |
| --dataset | Dataset name; must match an entry under the datasets key in configs/base_config.py | Pancreas, BraTS_PED_t2w_2D |
| --prompt | Prompt type | box or point |
| --gpu_ids | GPU device IDs | 0 or 0,1 |
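To run several methods, datasets, and prompt types in one go, you could generate the command lines above programmatically. This is a minimal sketch that only uses the script paths and flags documented here; the short method keys ("source", "samtta", "mt") are hypothetical labels, and the default dry_run=True just prints each command.

```python
import shlex
import subprocess
from itertools import product

# Hypothetical short names mapped to the scripts documented above.
SCRIPTS = {
    "source": "methods/adaptation_source.py",
    "samtta": "methods/adaptation_samtta.py",
    "mt": "methods/adaptation_mt.py",
}


def build_commands(methods, datasets, prompts, cfg="configs/config_brats.py", gpu_ids="0"):
    """Return one argv list per (method, dataset, prompt) combination."""
    return [
        ["python", SCRIPTS[method], "--cfg", cfg, "--dataset", dataset,
         "--prompt", prompt, "--gpu_ids", gpu_ids]
        for method, dataset, prompt in product(methods, datasets, prompts)
    ]


def run_all(cmds, dry_run=True):
    """Print each command; execute it (stopping on the first failure) unless dry_run."""
    for cmd in cmds:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(build_commands(["source", "samtta"], ["Pancreas", "BraTS_PED_t2w_2D"], ["box", "point"]))
```

Run it once as a dry run to inspect the commands, then call run_all(..., dry_run=False) from the repository root.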

Output

Results are saved to output/<method>/<dataset>/:

output/samtta/Pancreas/
├── Pancreas-box-pred_masks/      # per-image binary mask PNGs
├── Pancreas-box-test-results.csv # per-image Dice / ASSD / HD95
├── configs/                      # saved YAML config
└── save-ckpt/                    # adapted model checkpoint
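The per-image CSV is easiest to interpret as a mean ± std per metric. The sketch below shows the standard Dice definition on binary masks and a summary over the results file. The column names "Dice", "ASSD", and "HD95" are assumptions based on the comment above (check the actual header), scoring two empty masks as Dice 1.0 is a common convention that utils/eval_utils.py may not follow, and non-finite values (for example an infinite HD95 for an empty prediction) are skipped.

```python
import csv
import math
import statistics

import numpy as np


def dice(pred, gt):
    """Dice = 2|P & G| / (|P| + |G|) for binary masks; two empty masks score 1.0."""
    pred, gt = np.asarray(pred, dtype=bool), np.asarray(gt, dtype=bool)
    total = pred.sum() + gt.sum()
    if total == 0:
        return 1.0
    return 2.0 * np.logical_and(pred, gt).sum() / total


def summarize_results(csv_path, metrics=("Dice", "ASSD", "HD95")):
    """Return {metric: (mean, population std)} over rows with a finite value."""
    with open(csv_path, newline="") as f:
        rows = list(csv.DictReader(f))
    summary = {}
    for m in metrics:
        vals = [float(r[m]) for r in rows if r.get(m) not in (None, "")]
        vals = [v for v in vals if math.isfinite(v)]  # drop inf/nan entries
        if vals:
            summary[m] = (statistics.mean(vals), statistics.pstdev(vals))
    return summary
```

For example, summarize_results("output/samtta/Pancreas/Pancreas-box-test-results.csv") would return the mean and spread for each metric found in that file.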

Configuration

Edit configs/config_brats.py or configs/base_config.py to customize settings:

config = {
    "gpu_ids": "0",
    "batch_size": 1,
    "opt": {
        "learning_rate": 1e-3,   # LoRA + prompt encoder LR
    },
    "model": {
        "type": "vit_b",         # SAM backbone: vit_b / vit_l / vit_h
    },
}

Dataset paths are configured in configs/base_config.py under the datasets key.
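If you prefer to keep the shared base config untouched, one option is to layer your own overrides on top of it with a recursive merge, in the same spirit as a dataset-specific override file. This is a minimal sketch; whether config_brats.py is combined with base_config.py this way is an assumption, and the layout of the datasets entry (a root_dir field) is hypothetical. Switching model.type to vit_l or vit_h presumably also requires the matching SAM checkpoint in checkpoints/, since only ViT-B is downloaded above.

```python
import copy


def deep_merge(base, override):
    """Return a new dict: base updated recursively with override (override wins)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)  # merge nested sections key by key
        else:
            merged[key] = copy.deepcopy(value)  # scalars and new keys replace outright
    return merged


# Hypothetical stand-ins for the real config files.
base = {
    "gpu_ids": "0",
    "batch_size": 1,
    "opt": {"learning_rate": 1e-3},
    "model": {"type": "vit_b"},
    "datasets": {"Pancreas": {"root_dir": "data/Pancreas"}},
}
override = {
    "opt": {"learning_rate": 5e-4},
    "datasets": {"MyOrgan": {"root_dir": "data/MyOrgan"}},
}
cfg = deep_merge(base, override)
```

Here cfg keeps model.type = "vit_b", lowers the learning rate, and lists both Pancreas and MyOrgan, while base itself is left unchanged.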


Project Structure

SAM-TTA/
├── methods/
│   ├── adaptation_samtta.py    # SAM-TTA (ours)
│   ├── adaptation_source.py    # Source-only baseline
│   ├── adaptation_mt.py        # Mean Teacher baseline
│   └── adaptation_medsam.py    # MedSAM inference baseline
├── utils/
│   ├── nonlinear_net.py        # DynamicBezierTransform2D (SBCT)
│   ├── nonlinear.py            # LearnableBezierTransform (static variant)
│   ├── eval_utils.py           # Metrics: Dice, HD95, ASSD
│   └── tools.py                # EMA update, model copy, etc.
├── datasets/
│   ├── NII_test.py             # NIfTI loader for TTA (main dataset class)
│   ├── ISIC_test.py            # 2D image loader for TTA
│   └── ...
├── segment_anything/           # SAM implementation (ViT-B/L/H)
├── configs/
│   ├── base_config.py          # Global hyperparameters & dataset paths
│   └── config_brats.py         # BraTS-specific overrides
├── losses.py                   # DiceLoss, kl_spatial_per_channel, etc.
├── model.py                    # SAM wrapper (encode / decode)
├── sam_lora.py                 # LoRA adaptation for SAM image encoder
└── requirements.txt
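The teacher–student part of IMA relies on an EMA teacher (the "EMA update" in utils/tools.py) and on IoU-weighted, multi-scale alignment. The NumPy sketch below shows one plausible reading of that description, not the repository's losses.py: parameters are plain dicts of arrays standing in for a PyTorch state_dict, and the mean-squared consistency term, the pooling scales (1, 2, 4), and the momentum 0.99 are all assumptions.

```python
import numpy as np


def ema_update(teacher, student, alpha=0.99):
    """teacher <- alpha * teacher + (1 - alpha) * student, for each named parameter."""
    for name, s_param in student.items():
        teacher[name] = alpha * teacher[name] + (1.0 - alpha) * s_param
    return teacher


def avg_pool2d(x, k):
    """Non-overlapping k x k average pooling of a 2D map (trailing rows/cols dropped)."""
    h, w = x.shape[0] // k, x.shape[1] // k
    return x[: h * k, : w * k].reshape(h, k, w, k).mean(axis=(1, 3))


def iou_weighted_multiscale_loss(student_prob, teacher_prob, teacher_iou, scales=(1, 2, 4)):
    """Student/teacher mean-squared disagreement averaged over pooled scales,
    scaled by the teacher's predicted IoU so confident pseudo-labels count more."""
    per_scale = [
        np.mean((avg_pool2d(student_prob, k) - avg_pool2d(teacher_prob, k)) ** 2)
        for k in scales
    ]
    return float(teacher_iou) * float(np.mean(per_scale))
```

In PyTorch, the same EMA step is typically applied to teacher.parameters() inside torch.no_grad() after each student update, with the IoU score taken from SAM's mask decoder output.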

Citation

If you find SAM-TTA useful in your research, please cite:

@article{wu2025samtta,
  title     = {SAM-aware Test-time Adaptation for Universal Medical Image Segmentation},
  author    = {Wu, Jianghao and Wu, Yicheng and Xie, Yutong and Bai, Wenjia and Zhang, You and
               Tang, Feilong and Li, Yulong and Razzak, Imran and Schmidt, Daniel F and George, Yasmeen},
  journal   = {arXiv},
  year      = {2025},
  note      = {arXiv:2506.05221}
}

Acknowledgements

This work was supported by the Commonwealth of Australia under the Medical Research Future Fund (No. NCRI000074). We thank the organizers of BraTS 2023/2024, AMOS 2022, and MSWAL for making their datasets available.

Our code builds on:


License

This project is released under the MIT License.
