This repository contains the code and resources for the experiments in our paper, *U-Prithvi: Integrating a GeoAI Foundation Model with UNet for Flood Inundation Mapping*. The paper introduces U-Prithvi, a novel framework that combines a GeoAI foundation model with a UNet architecture to improve flood inundation mapping.
Requirements:

- Python 3.12

Setup:

- Clone the repository:
  ```bash
  git clone https://github.com/your-repo/segmentation-floods.git
  cd segmentation-floods
  ```
- Install all required packages:
  ```bash
  pip install -r requirements.txt
  ```
- Download the Prithvi model:
  ```bash
  # Ensure git-lfs is installed (https://git-lfs.com)
  git lfs install
  git clone https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M
  # Rename to a valid Python module name
  mv Prithvi-100M prithvi
  touch prithvi/__init__.py
  ```
- Download the Sen1Floods11 dataset by following the Sen1Floods11 download instructions. The folder structure should be as follows:
  ```
  sen1floods11/
  ├── LabelHand/
  ├── S1Hand/
  ├── S2Hand/
  └── splits/
  ```
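
Before training, it can help to confirm that the setup steps produced what the code expects. The following is a minimal sketch, not part of this repository: the helper names `prithvi_is_importable` and `missing_sen1floods11_dirs` are hypothetical. It also illustrates why the Prithvi clone is renamed: `Prithvi-100M` contains a hyphen, so it is not a valid Python identifier and cannot be used in an `import` statement, and adding `__init__.py` makes `prithvi/` a regular package.

```python
import importlib
import importlib.machinery
from pathlib import Path

# Subfolders from the README's expected Sen1Floods11 layout.
EXPECTED_SUBDIRS = ("LabelHand", "S1Hand", "S2Hand", "splits")


def prithvi_is_importable(parent_dir: str, package: str = "prithvi") -> bool:
    """Return True if `package` is a regular (non-namespace) package in parent_dir."""
    # A hyphenated name such as "Prithvi-100M" cannot appear in an import statement.
    if not package.isidentifier():
        return False
    init_file = Path(parent_dir) / package / "__init__.py"
    if not init_file.is_file():
        return False
    # Search only parent_dir (sys.path is left untouched) for the package.
    importlib.invalidate_caches()
    spec = importlib.machinery.PathFinder.find_spec(package, [str(parent_dir)])
    return (
        spec is not None
        and spec.origin is not None
        and Path(spec.origin).resolve() == init_file.resolve()
    )


def missing_sen1floods11_dirs(data_dir: str) -> list[str]:
    """Return the expected subfolders missing from data_dir (empty list if complete)."""
    root = Path(data_dir)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]
```

For example, from the repository root, `prithvi_is_importable(".")` should return `True` after the rename step, and `missing_sen1floods11_dirs("sen1floods11")` should return an empty list once the dataset is in place (the `sen1floods11` location is an assumption; pass whatever path you use for `--data_dir`).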
To train the model, run:

```bash
python train.py [OPTIONS]
```

This starts training with the specified dataset and model directories. The available options are:
- `--data_dir`: Path to the dataset directory.
- `--model_dir`: Path to the Prithvi model directory.
- `--batch_size`: Batch size for training (default: 32).
- `--epochs`: Number of training epochs (default: 50).
- `--learning_rate`: Learning rate for the optimizer (default: 0.001).
- `--checkpoint_dir`: Directory to save model checkpoints.
- `--log_dir`: Directory to save training logs.
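
The flags above can be pictured as the following `argparse` parser. This is a sketch that assumes `train.py` parses its command line with `argparse`; only the flag names and defaults come from the list, while the argument types and the choice to make every flag optional are assumptions. `build_parser` is a hypothetical name.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented train.py flags and defaults."""
    parser = argparse.ArgumentParser(description="Train a flood segmentation model.")
    # Paths: no defaults are documented, so they stay None unless given.
    parser.add_argument("--data_dir", type=str, help="Path to the dataset directory.")
    parser.add_argument("--model_dir", type=str, help="Path to the Prithvi model directory.")
    # Training hyperparameters with the documented defaults.
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training.")
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Optimizer learning rate.")
    # Output locations.
    parser.add_argument("--checkpoint_dir", type=str, help="Directory for model checkpoints.")
    parser.add_argument("--log_dir", type=str, help="Directory for training logs.")
    return parser
```

For example, `build_parser().parse_args(["--data_dir", "sen1floods11", "--model_dir", "prithvi"])` leaves `batch_size` at 32, `epochs` at 50 and `learning_rate` at 0.001, matching what the list says happens when those flags are omitted.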

Embedding-related options:

- `--combine_func`: Function used to combine the embeddings; applies only in U-Prithvi mode. Options: `concat`, `product`, `sum` (default: `concat`).
- `--random_dropout_prob`: Probability that one of the embeddings is dropped (default: 2/3).
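
To make these two options concrete, here is a minimal sketch in plain Python. It illustrates the idea under stated assumptions and is not the repository's implementation: it assumes the two embeddings are the UNet and Prithvi embeddings, treats them as flat lists of floats rather than feature-map tensors, assumes that dropping an embedding means replacing it with zeros, and assumes the dropped embedding is chosen uniformly at random. Under that last assumption, the default of 2/3 makes the three outcomes (keep both, drop UNet, drop Prithvi) equally likely, each with probability 1/3. The function names are hypothetical.

```python
import random
from typing import List, Optional, Tuple

# Simplification: an embedding is a flat list of floats, not a feature-map tensor.
Vector = List[float]


def combine(unet_emb: Vector, prithvi_emb: Vector, combine_func: str = "concat") -> Vector:
    """Fuse two embeddings with one of the documented options: concat, product, or sum."""
    if combine_func == "concat":
        # Stack the features; the output length is the sum of both lengths.
        return unet_emb + prithvi_emb
    if len(unet_emb) != len(prithvi_emb):
        raise ValueError("product and sum require embeddings of equal length")
    if combine_func == "product":
        # Element-wise multiplication.
        return [a * b for a, b in zip(unet_emb, prithvi_emb)]
    if combine_func == "sum":
        # Element-wise addition.
        return [a + b for a, b in zip(unet_emb, prithvi_emb)]
    raise ValueError(f"unknown combine_func: {combine_func!r}")


def random_embedding_dropout(
    unet_emb: Vector,
    prithvi_emb: Vector,
    p: float = 2 / 3,
    rng: Optional[random.Random] = None,
) -> Tuple[Vector, Vector]:
    """With probability p, zero out one of the two embeddings.

    Assumptions (not confirmed by the repository): "dropping" means zero-filling,
    and each embedding is equally likely to be the one dropped.
    """
    if rng is None:
        rng = random.Random()
    if rng.random() < p:
        # Pick which embedding to drop with equal probability.
        if rng.random() < 0.5:
            unet_emb = [0.0] * len(unet_emb)
        else:
            prithvi_emb = [0.0] * len(prithvi_emb)
    return unet_emb, prithvi_emb
```

Zero-filling is only one way to drop an embedding. Combined with `product`, it would zero the entire fused output, which is one reason the actual code may drop embeddings differently.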