Please see our paper: https://doi.org/10.48550/arXiv.2411.15802
Figure: Overview of the Model Architecture and Attention Flow.
(a) The Medical Slice Transformer framework processes individual MRI or CT slices using 2D image encoders, such as DINOv2, and then passes the encoded outputs through the Slice Transformer for downstream classification tasks.
(b) Visualization of attention mechanisms showing how the Slice Transformer assigns attention to specific slices and how within-slice attention is further refined to specific patches, resulting in a combined attention map highlighting regions of interest in the input volume.
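To make the data flow in panel (a) concrete, below is a minimal, illustrative PyTorch sketch of the slice-encoder-plus-Slice-Transformer pattern: each slice is encoded by a 2D backbone, and a transformer then attends across the resulting slice tokens. This is not the repository's implementation; the class name, dimensions, and the ResNet18 stand-in for the 2D encoder (DINOv2 in the paper) are assumptions made only for this example.

```python
import torch
import torch.nn as nn
import torchvision


class SliceTransformerSketch(nn.Module):
    """Illustrative sketch: 2D encoding per slice, then a transformer across slices."""

    def __init__(self, num_classes=2, embed_dim=512, num_heads=8, num_layers=4):
        super().__init__()
        # Stand-in 2D encoder applied to every slice independently
        # (the paper uses DINOv2 or a 2D ResNet; ResNet18 keeps this sketch small).
        backbone = torchvision.models.resnet18(weights=None)
        backbone.fc = nn.Identity()  # expose the 512-dim feature vector
        self.encoder = backbone
        # Learnable classification token that aggregates information across slices.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Transformer attending over the sequence of slice embeddings.
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, batch_first=True
        )
        self.slice_transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, volume):
        # volume: [B, S, C, H, W] = batch, slices, channels, height, width
        b, s, c, h, w = volume.shape
        feats = self.encoder(volume.reshape(b * s, c, h, w))  # [B*S, 512]
        feats = feats.reshape(b, s, -1)                        # [B, S, 512]
        tokens = torch.cat([self.cls_token.expand(b, -1, -1), feats], dim=1)
        tokens = self.slice_transformer(tokens)
        return self.head(tokens[:, 0])  # classify from the [CLS] token


if __name__ == "__main__":
    model = SliceTransformerSketch(num_classes=2)
    dummy = torch.randn(1, 16, 3, 224, 224)  # one volume, 16 RGB-replicated slices
    print(model(dummy).shape)  # torch.Size([1, 2])
```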
- Clone this repository: `git clone https://github.com/mueller-franzes/MST`
- Create the conda environment: `conda env create -f environment.yaml`
- Activate the environment: `conda activate MST`
- Download data (use 'Classic Directory Name' for TCIA):
- Follow preprocessing steps in scripts/preprocessing
- Add your own dataset to `mst/data/datasets`
- Register it in `get_dataset()` in `scripts/main_train.py` (see the sketch below)
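As a hypothetical sketch only, registering a new dataset in `get_dataset()` could look roughly like the following; the real function in `scripts/main_train.py` may use a different signature, and the class names `LIDC_Dataset` and `MyDataset` are assumptions for illustration.

```python
# Hypothetical sketch: check scripts/main_train.py for the actual interface.
def get_dataset(name: str):
    if name == "LIDC":
        from mst.data.datasets import LIDC_Dataset  # assumed existing class name
        return LIDC_Dataset
    elif name == "MyData":
        from mst.data.datasets import MyDataset     # your new dataset class
        return MyDataset
    raise ValueError(f"Unknown dataset: {name}")
```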
Optionally, skip training and download the pre-trained weights from Zenodo.
Run Script: `scripts/main_train.py`
- E.g. `python scripts/main_train.py --dataset LIDC --model ResNet`
- Use `--model` to select:
  - `ResNet` = 3D ResNet50
  - `ResNetSliceTrans` = MST-ResNet
  - `DinoV2ClassifierSlice` = MST-DINOv2
Run Script: `scripts/main_predict.py`
- E.g. `python scripts/main_predict.py --run_folder LIDC/ResNet`
- Use `--get_attention` to compute saliency maps
- Use `--get_segmentation` to compute segmentation masks and the Dice score
- Use `--use_tta` to enable Test Time Augmentation (see the sketch below)
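For reference, test-time augmentation in general averages a model's predictions over a few deterministic transformations of the input. The sketch below illustrates the idea with a single horizontal flip; it is not the augmentation set used by `scripts/main_predict.py`, and `predict_with_tta` is a hypothetical helper, not part of the repository.

```python
import torch


@torch.no_grad()
def predict_with_tta(model, volume):
    """Average logits over the original and a width-flipped view of the volume."""
    # volume: [B, S, C, H, W]; flipping the last axis mirrors each slice horizontally.
    views = [volume, torch.flip(volume, dims=[-1])]
    logits = torch.stack([model(v) for v in views], dim=0)
    return logits.mean(dim=0)  # averaged prediction over the augmented views
```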