Official implementation of TelePiT: Physics-Informed Teleconnection-Aware Transformer for Global S2S Forecasting

| Component | Description | Key Innovation |
|---|---|---|
| Spherical Harmonic Embedding | Encodes global atmospheric variables onto spherical geometry | Learnable positional encoding for Earth's spherical structure |
| Multi-Scale Physics-Informed ODE | Captures atmospheric dynamics across multiple frequency bands | Physics-constrained neural ODEs with learnable decomposition |
| Teleconnection-Aware Transformer | Models global climate interactions and cross-scale processes | Attention mechanism biased by teleconnection patterns |
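
The table describes the teleconnection-aware attention only at a high level. One common way to bias attention is to add a prior matrix to the attention logits before the softmax. The sketch below assumes that form. The `bias` matrix is a hypothetical stand-in for whatever teleconnection prior TelePiT actually uses, and the code is not taken from this repository.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def biased_attention(q, k, v, bias):
    """Scaled dot-product attention with an additive bias on the logits.

    q, k, v: [tokens, dim] arrays (e.g. one token per spatial region).
    bias:    [tokens, tokens] hypothetical teleconnection prior; a larger
             bias[i, j] makes token i attend more strongly to token j.
    """
    d = q.shape[-1]
    logits = q @ k.T / np.sqrt(d) + bias
    weights = softmax(logits, axis=-1)  # each row sums to 1
    return weights @ v, weights


# Toy example: 4 regions, with regions 0 and 3 assumed to be teleconnected
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(4, 8)) for _ in range(3))
bias = np.zeros((4, 4))
bias[0, 3] = bias[3, 0] = 2.0

_, w_plain = biased_attention(q, k, v, np.zeros((4, 4)))
_, w_tele = biased_attention(q, k, v, bias)
print(w_plain[0, 3], w_tele[0, 3])  # the biased weight is the larger of the two
```

With a zero bias this reduces to standard scaled dot-product attention. The bias only changes which positions each query favours, so distant but physically linked regions can receive more weight.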

Installation steps:

```bash
# Create conda environment
conda create -n telepit python=3.8 -y
conda activate telepit
# Install PyTorch (adjust CUDA version as needed)
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
# Install dependencies
pip install -r requirements.txt
```

Core dependencies: `torch>=1.9.0`, `torchdiffeq>=0.2.3`, `xarray>=0.19.0`, `numpy>=1.21.0`
TelePiT uses the ERA5 reanalysis dataset via ChaosBench:
```bash
# Download ChaosBench dataset
git clone https://github.com/leap-stc/ChaosBench.git
cd ChaosBench
# Follow ChaosBench instructions to download ERA5 data
# Data will be automatically processed to the required format
```

| Dataset Info | Details |
|---|---|
| Source | ERA5 reanalysis (ECMWF) |
| Resolution | 1.5° (121 × 240 grid) |
| Variables | 63 total (60 pressure-level + 3 surface) |
| Time Split | Train: 1979-2016, Val: 2017, Test: 2018 |
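
The grid size and time split in the table can be reproduced in a few lines. This sketch assumes that latitudes run from 90°N to 90°S with both poles included and that longitudes start at 0°. Check the ChaosBench files for the actual coordinate order. `split_for_year` is a hypothetical helper.

```python
import numpy as np

RES_DEG = 1.5

# Assumed coordinate convention: latitudes from 90N to 90S (both poles included),
# longitudes from 0 to 358.5 E. Verify against the actual ChaosBench files.
lats = np.linspace(90.0, -90.0, int(180 / RES_DEG) + 1)  # 121 points
lons = np.arange(0.0, 360.0, RES_DEG)                     # 240 points

# Year ranges from the table (range end is exclusive)
SPLITS = {
    "train": range(1979, 2017),  # 1979-2016
    "val": range(2017, 2018),    # 2017
    "test": range(2018, 2019),   # 2018
}


def split_for_year(year):
    """Return the split name for a given year, or None if it is outside all splits."""
    for name, years in SPLITS.items():
        if year in years:
            return name
    return None


print(lats.size, lons.size)  # 121 240
print(split_for_year(2016), split_for_year(2017), split_for_year(2018))  # train val test
```

The 121 latitudes come from 180° / 1.5° = 120 intervals plus one endpoint. Longitude is periodic, so 360° / 1.5° = 240 points, without repeating 360°.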

You can download the pre-trained checkpoints from this link: checkpoint
```python
import torch
from S2S.models.TelePiT import Model

# Build the model; hyperparameters must match those used to train the checkpoint
model = Model(
    img_size=[121, 240], input_size=63, output_size=63,
    embed_dim=256, depth=6, num_heads=8, wavelet_levels=3
)

# Load the checkpoint weights (map_location='cpu' also works without a GPU)
checkpoint = torch.load('checkpoints/TelePiT/best.ckpt', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# Random placeholder input; replace with real ERA5 fields from the ChaosBench pipeline
input_data = torch.randn(1, 63, 121, 240)

with torch.no_grad():
    # Input: [batch_size, 63, 121, 240]
    # Output: [batch_size, 2, 63, 121, 240] for weeks 3-4 and 5-6
    prediction = model(input_data)
```

```bash
# Training
python train.py --config configs/telepit_config.yaml
# Evaluation
python step1_predict_to_npy.py --config_filepath configs/telepit.yaml
python step2_predict_with_wandb.py --config_filepath configs/telepit.yaml
```
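
The model returns both lead-time windows stacked along dimension 1. The sketch below splits them apart, assuming that index 0 is weeks 3-4 and index 1 is weeks 5-6, as the inference comment states. `split_lead_windows` and the window names are hypothetical. For a real PyTorch output, pass `prediction.cpu().numpy()`.

```python
import numpy as np

# Window names in the order stated by the inference comment (assumed index order)
LEAD_WINDOWS = ("weeks_3_4", "weeks_5_6")


def split_lead_windows(prediction):
    """Split a [batch, 2, channels, lat, lon] forecast into one array per window."""
    prediction = np.asarray(prediction)
    if prediction.ndim != 5 or prediction.shape[1] != len(LEAD_WINDOWS):
        raise ValueError(
            f"expected [batch, {len(LEAD_WINDOWS)}, C, H, W], got {prediction.shape}"
        )
    return {name: prediction[:, i] for i, name in enumerate(LEAD_WINDOWS)}


# Usage with a dummy array shaped like the model output
dummy = np.zeros((1, 2, 63, 121, 240), dtype=np.float32)
windows = split_lead_windows(dummy)
print({name: arr.shape for name, arr in windows.items()})
```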

Main results compared with the previous best model (CirT):

| Variable | Metric | Previous Best (CirT) | TelePiT | Relative Improvement |
|---|---|---|---|---|
| t2m (K) | RMSE (lower is better) | 28.526 | 12.057 | 57.7% lower |
| z500 (m²/s²) | RMSE (lower is better) | 53.807 | 48.671 | 9.5% lower |
| z850 (m²/s²) | RMSE (lower is better) | 34.006 | 31.082 | 8.6% lower |
| t2m | ACC (higher is better) | 0.977 | 0.996 | 1.9% higher |
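
The relative improvement column follows directly from the table values. For RMSE (lower is better), it is the reduction relative to CirT. For ACC (higher is better), it is the increase relative to CirT. The helper below reproduces the column, and its name is illustrative.

```python
def relative_improvement(previous, new, higher_is_better=False):
    """Percentage improvement of `new` over `previous`.

    Lower-is-better metrics (RMSE): (previous - new) / previous.
    Higher-is-better metrics (ACC): (new - previous) / previous.
    """
    delta = (new - previous) if higher_is_better else (previous - new)
    return 100.0 * delta / previous


# (label, CirT, TelePiT, higher_is_better), values copied from the results table
rows = [
    ("t2m RMSE", 28.526, 12.057, False),
    ("z500 RMSE", 53.807, 48.671, False),
    ("z850 RMSE", 34.006, 31.082, False),
    ("t2m ACC", 0.977, 0.996, True),
]
for label, prev, new, higher in rows:
    print(f"{label}: {relative_improvement(prev, new, higher):.1f}%")
# Prints 57.7%, 9.5%, 8.6% and 1.9%, matching the table
```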

Ablation study (contribution of each component):

| Component | Impact | Key Benefit |
|---|---|---|
| Spherical Harmonic Embedding | Critical | Largest impact on t2m: RMSE 27.2 K without it vs. 12.1 K with it |
| Wavelet Decomposition | Essential | 7-8% improvement in geopotential height predictions |
| Physics-Informed ODE | Essential | Enhanced wind component predictions |
| Teleconnection Attention | Critical | Increasingly valuable at longer lead times |

Hardware requirements:

| Component | Training | Inference |
|---|---|---|
| GPU | 2× RTX A40 (48GB total) | 1× A800 (80 GB) |
| Memory | 32 GB+ RAM | 16 GB+ RAM |
| Storage | 1 TB+ (for dataset) | 100 GB+ |
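
For a sense of scale, the raw float32 size of one input sample and one output sample follows from the tensor shapes in the inference example. The helper name is illustrative. Actual GPU memory use is dominated by activations, parameters, and optimizer state rather than by these tensors.

```python
import math

BYTES_PER_FLOAT32 = 4


def tensor_megabytes(shape, bytes_per_element=BYTES_PER_FLOAT32):
    """Size of a dense tensor with the given shape, in megabytes (1 MB = 1e6 bytes)."""
    return math.prod(shape) * bytes_per_element / 1e6


input_shape = (1, 63, 121, 240)       # [batch, variables, lat, lon]
output_shape = (1, 2, 63, 121, 240)   # [batch, lead windows, variables, lat, lon]
print(f"input:  {tensor_megabytes(input_shape):.2f} MB")   # input:  7.32 MB
print(f"output: {tensor_megabytes(output_shape):.2f} MB")  # output: 14.64 MB
```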

Evaluation metrics:

- RMSE: Root Mean Squared Error with latitude weighting
- ACC: Anomaly Correlation Coefficient
- MS-SSIM: Multi-Scale Structural Similarity
- SpecDiv: Spectral Divergence (physics-based)
- SpecRes: Spectral Residual (physics-based)
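
The sketch below implements the first two metrics in NumPy. It assumes standard cosine-latitude weighting and an uncentred anomaly correlation against a supplied climatology. ChaosBench's own implementation may differ in details such as anomaly centring, so use its evaluation code for reported numbers. MS-SSIM, SpecDiv, and SpecRes are not covered.

```python
import numpy as np


def latitude_weights(lats_deg):
    """cos(latitude) weights, normalised so they average to 1 over latitude."""
    w = np.cos(np.deg2rad(lats_deg))
    return w / w.mean()


def weighted_rmse(pred, target, lats_deg):
    """Latitude-weighted RMSE over the trailing [lat, lon] axes."""
    w = latitude_weights(lats_deg)[:, None]  # shape [lat, 1], broadcasts over lon
    return np.sqrt(np.mean(w * (pred - target) ** 2, axis=(-2, -1)))


def weighted_acc(pred, target, climatology, lats_deg):
    """Latitude-weighted (uncentred) anomaly correlation over [lat, lon]."""
    w = latitude_weights(lats_deg)[:, None]
    fa = pred - climatology    # forecast anomaly
    oa = target - climatology  # observed anomaly
    num = np.sum(w * fa * oa, axis=(-2, -1))
    den = np.sqrt(np.sum(w * fa**2, axis=(-2, -1)) * np.sum(w * oa**2, axis=(-2, -1)))
    return num / den


# Toy check on the 121 x 240 grid with a zero climatology
lats = np.linspace(90.0, -90.0, 121)
rng = np.random.default_rng(0)
target = rng.normal(size=(121, 240))
clim = np.zeros_like(target)
print(weighted_rmse(target + 2.0, target, lats))  # approximately 2.0 (uniform error of 2)
print(weighted_acc(target, target, clim, lats))   # approximately 1.0 (perfect forecast)
```

The weights are normalised to mean 1, so a spatially uniform error gives the same RMSE as the unweighted case. Non-uniform errors near the poles count for less because those grid cells cover less area.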

Many thanks for your review!

Advancing atmospheric science through AI.