We’re excited to announce PlantCAD2 🌱, our new DNA foundation model for angiosperms.
We’re also releasing a collection of LoRA fine-tuned models 🎯 tailored for key downstream tasks, including accessible chromatin, gene expression, and protein translation.
- Explore the fine-tuned PlantCAD2 models here
- Explore the zero-shot evaluation of PlantCAD2 models here
- Explore the post-training or pre-training of PlantCAD2 here
- PlantCAD overview
- Quick Start
- Model summary
- Prerequisites and system requirements
- Installation
- Basic Usage
- Advanced Usage
- Development and Training
- Model Recommendations
- Citation
PlantCaduceus (PlantCAD for short) is a plant DNA language model based on the Caduceus architecture, which extends the efficient Mamba linear-time sequence modeling framework with bi-directionality and reverse-complement equivariance for DNA sequences. PlantCAD is pre-trained on a curated dataset of 16 angiosperm genomes. It showed state-of-the-art cross-species performance in predicting translation initiation sites (TIS), translation termination sites (TTS), splice donors, and splice acceptors. Zero-shot scoring with PlantCAD enables identification of genome-wide deleterious mutations and known causal variants in Arabidopsis, sorghum, and maize.
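To make reverse-complement equivariance more concrete: a DNA sequence and its reverse complement describe the two strands of the same molecule, so a model with this property treats both readings consistently. The snippet below is only an illustrative sketch of the reverse-complement operation itself; the helper name `reverse_complement` is ours and is not part of the PlantCAD code base.

```python
# Illustrative helper (not part of PlantCAD): compute the reverse complement of a DNA string.
_COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")


def reverse_complement(seq: str) -> str:
    """Complement every base, then reverse the strand."""
    return seq.translate(_COMPLEMENT)[::-1]


forward = "ATGGCC"
reverse = reverse_complement(forward)
print(forward, reverse)  # ATGGCC GGCCAT

# Applying the operation twice returns the original strand.
assert reverse_complement(reverse) == forward
```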
New to PlantCAD? Try our Google Colab demo - no installation required!
For local usage: See installation instructions here, then use notebooks/examples.ipynb to get started.
Pre-trained models have been uploaded to HuggingFace 🤗: PlantCAD and PlantCAD2. Here is the comparison between PlantCAD (v1) and PlantCAD2 models:
| Model | Sequence Length | Model Size | Embedding Size |
|---|---|---|---|
| PlantCAD | |||
| PlantCaduceus_l20 | 512bp | 20M | 384 |
| PlantCaduceus_l24 | 512bp | 40M | 512 |
| PlantCaduceus_l28 | 512bp | 128M | 768 |
| PlantCaduceus_l32 | 512bp | 225M | 1024 |
| PlantCAD2 | |||
| PlantCAD2-Small | 8192bp | 88M | 768 |
| PlantCAD2-Medium | 8192bp | 311M | 1024 |
| PlantCAD2-Large | 8192bp | 694M | 1536 |
Note: For PlantCAD, the maximum sequence length is 512bp. For PlantCAD2, it is 8,192bp.
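Sequences longer than these limits have to be split before they are passed to the model. The following is a minimal sketch of one way to do that, assuming fixed-size windows with an optional overlap; the helper `split_into_windows` and the `overlap` parameter are hypothetical and do not come from the PlantCAD repository.

```python
from typing import Iterator, Tuple

# Context limits taken from the model table above.
MAX_LEN = {"PlantCAD": 512, "PlantCAD2": 8192}


def split_into_windows(seq: str, max_len: int, overlap: int = 0) -> Iterator[Tuple[int, str]]:
    """Yield (start, window) pairs that cover seq with windows of at most max_len bp."""
    if max_len <= 0 or not 0 <= overlap < max_len:
        raise ValueError("need max_len > 0 and 0 <= overlap < max_len")
    step = max_len - overlap
    start = 0
    while start < len(seq):
        yield start, seq[start:start + max_len]
        if start + max_len >= len(seq):
            break  # the last window already reaches the end of the sequence
        start += step


# A 1,200 bp toy sequence split for a PlantCAD (v1) model with 64 bp overlap.
windows = list(split_into_windows("ACGT" * 300, MAX_LEN["PlantCAD"], overlap=64))
print([(start, len(window)) for start, window in windows])
# [(0, 512), (448, 512), (896, 304)]
```

Overlapping windows give positions near a window edge more surrounding context in at least one window, at the cost of extra inference.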
For Google Colab: Just a Google account - GPU runtime recommended (free tier available)
For Local Installation: NVIDIA GPU is required!
No installation required! Just open our PlantCAD Google Colab notebook and start analyzing your data.
See our Local Installation Guide.
The easiest way to start is with our example notebook: notebooks/examples.ipynb
Quick example - Get sequence embeddings:
import torch
from mamba_ssm import Mamba
from transformers import AutoTokenizer, AutoModelForMaskedLM
import pandas as pd
device = 'cuda:0'
# Test PlantCAD model loading
tokenizer = AutoTokenizer.from_pretrained('kuleshov-group/PlantCaduceus_l32')
model = AutoModelForMaskedLM.from_pretrained('kuleshov-group/PlantCaduceus_l32', trust_remote_code=True)
model.to(device)
# Example plant DNA sequence (512bp max)
sequence = "CTTAATTAATATTGCCTTTGTAATAACGCGCGAAACACAAATCTTCTCTGCCTAATGCAGTAGTCATGTGTTGACTCCTTCAAAATTTCCAAGAAGTTAGTGGCTGGTGTGTCATTGTCTTCATCTTTTTTTTTTTTTTTTTAAAAATTGAATGCGACATGTACTCCTCAACGTATAAGCTCAATGCTTGTTACTGAAACATCTCTTGTCTGATTTTTTCAGGCTAAGTCTTACAGAAAGTGATTGGGCACTTCAATGGCTTTCACAAATGAAAAAGATGGATCTAAGGGATTTGTGAAGAGAGTGGCTTCATCTTTCTCCATGAGGAAGAAGAAGAATGCAACAAGTGAACCCAAGTTGCTTCCAAGATCGAAATCAACAGGTTCTGCTAACTTTGAATCCATGAGGCTACCTGCAACGAAGAAGATTTCAGATGTCACAAACAAAACAAGGATCAAACCATTAGGTGGTGTAGCACCAGCACAACCAAGAAGGGAAAAGATCGATGATCG"
# Get embeddings
encoding = tokenizer.encode_plus(
    sequence,
    return_tensors="pt",
    return_attention_mask=False,
    return_token_type_ids=False
)
input_ids = encoding["input_ids"].to(device)
with torch.inference_mode():
    outputs = model(input_ids=input_ids, output_hidden_states=True)
embeddings = outputs.hidden_states[-1]
print(f"Embedding shape: {embeddings.shape}") # [batch_size, seq_len, embedding_dim]
embeddings = embeddings.to(torch.float32).cpu().numpy()
# PlantCaduceus is bi-directional and reverse-complement equivariant: the first half of the
# embedding corresponds to the forward sequence and the second half to the reverse-complemented
# sequence. Average the two halves before passing the embeddings to a downstream classifier.
hidden_size = embeddings.shape[-1] // 2
forward = embeddings[..., 0:hidden_size]
reverse = embeddings[..., hidden_size:]
reverse = reverse[..., ::-1]
averaged_embeddings = (forward + reverse) / 2
print(averaged_embeddings.shape)

The `zero_shot_score.py` script now provides unified functionality to estimate the functional impact of genetic variants or score genomic regions using PlantCAD's log-likelihood scores. It supports two primary modes:
- Variant Scoring (VCF Input): Scores specific genetic variants provided in a VCF file.
- Genome-Wide Region Scoring (BED Input): Calculates log-likelihood ratios for all positions within specified genomic regions (BED file).
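Both modes report log-likelihood ratios. As a rough sketch of the arithmetic only (this is not the implementation in `zero_shot_score.py`), assume the model yields a probability for each nucleotide at a position and that the score is log p(alt) - log p(ref), a sign convention chosen here so that alternative alleles the model finds unlikely receive negative scores. The function name and the probabilities below are made up for illustration.

```python
import math
from typing import Dict


def log_likelihood_ratio(probs: Dict[str, float], ref: str, alt: str) -> float:
    """Return log p(alt) - log p(ref); negative when alt is less likely than ref."""
    return math.log(probs[alt]) - math.log(probs[ref])


# Hypothetical per-nucleotide probabilities at one position (not real model output).
probs = {"A": 0.90, "C": 0.05, "G": 0.03, "T": 0.02}
score = log_likelihood_ratio(probs, ref="A", alt="T")
print(f"A>T score: {score:.3f}")
```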
# Download example reference genome
wget https://download.maizegdb.org/Zm-B73-REFERENCE-NAM-5.0/Zm-B73-REFERENCE-NAM-5.0.fa.gz
gunzip Zm-B73-REFERENCE-NAM-5.0.fa.gz

# VCF Input Mode
# --- Example: Variant Scoring (VCF Input) ---
# Estimate impact of specific variants from a VCF file.
# Note: Only the first 8 columns of the VCF file (CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO) are strictly required.
python src/zero_shot_score.py \
-input-vcf examples/example_maize_snp.vcf \
-input-fasta Zm-B73-REFERENCE-NAM-5.0.fa \
-output scored_variants.vcf \
-model 'kuleshov-group/PlantCaduceus_l32' \
-device 'cuda:0'
# Expected output for VCF mode:
# - A new VCF file ('scored_variants.vcf') with PlantCAD scores added to the INFO field.
# - Scores represent log-likelihood ratios between reference and alternative alleles.
# More negative scores indicate potentially more deleterious mutations.

# BED Input Mode
# --- Example: Genome-Wide Region Scoring (BED Input) ---
# Calculate log-likelihood ratios for all positions within specified BED regions.
# Note: You would need an example BED file for this.
# For demonstration, creating a dummy BED file:
echo -e "chr1\t1000\t1010\nchr1\t2000\t2015" > examples/example_regions.bed
python src/zero_shot_score.py \
-input-bed examples/example_regions.bed \
-input-fasta Zm-B73-REFERENCE-NAM-5.0.fa \
-output genome_wide_scores.tsv \
-model 'kuleshov-group/PlantCaduceus_l32' \
-device 'cuda:0' \
-step-size 1 \
-aggregation average \
-use-masking \
-output-raw-prob
# Expected output for BED mode:
# - A tab-separated file ('genome_wide_scores.tsv') containing scores for each position.
# - Output includes chromosome, start, end, reference allele, aggregated score,
# and optionally raw probabilities for all four nucleotides.
# - `-step-size`: Number of positions to analyze per window. If the step size is greater than 1, we recommend turning off masking.
# - `-aggregation`: How to aggregate alternative allele scores.
# - `'max'`: Reports the maximum log-likelihood ratio among all three alternative alleles relative to the reference.
# - `'average'`: Reports the average log-likelihood ratio across all three alternative alleles relative to the reference.
# - `'all'`: Reports the individual log-likelihood ratios for each of the three alternative alleles relative to the reference.
# - `-use-masking`: Whether to mask the central position(s) during inference.
# - `-output-raw-prob`: Include raw probabilities in the output.

When analyzing the entire genome or large genomic regions, the `-step-size` parameter is very important for speeding up the analysis. For a detailed guide on this trade-off between speed and accuracy, see here.
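The three `-aggregation` options can be illustrated with a small sketch. It assumes per-nucleotide probabilities at a single position and the convention log p(alt) - log p(ref) for each alternative allele; the function `aggregate_llr` and the example probabilities are hypothetical and are not taken from `zero_shot_score.py`.

```python
import math
from typing import Dict, Union

NUCLEOTIDES = "ACGT"


def aggregate_llr(
    probs: Dict[str, float], ref: str, mode: str = "average"
) -> Union[float, Dict[str, float]]:
    """Aggregate the log-likelihood ratios of the three alternative alleles versus ref."""
    llrs = {
        alt: math.log(probs[alt]) - math.log(probs[ref])
        for alt in NUCLEOTIDES
        if alt != ref
    }
    if mode == "max":
        return max(llrs.values())
    if mode == "average":
        return sum(llrs.values()) / len(llrs)
    if mode == "all":
        return llrs
    raise ValueError(f"unknown aggregation mode: {mode!r}")


# Made-up probabilities for one position whose reference allele is A.
probs = {"A": 0.90, "C": 0.05, "G": 0.03, "T": 0.02}
for mode in ("max", "average", "all"):
    print(mode, aggregate_llr(probs, ref="A", mode=mode))
```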
For large-scale simulation and analysis of genetic variants, we provide a comprehensive in-silico mutagenesis pipeline. See pipelines/in-silico-mutagenesis/README.md for detailed instructions.
Train custom classifiers on top of PlantCAD embeddings for specific annotation tasks (e.g., TIS, TTS, splice sites).
Purpose: Fine-tune prediction performance for specific annotation tasks using supervised learning.
Data format: Training data should follow the format used in our cross-species annotation dataset.
python src/train_XGBoost.py \
-train train.tsv \
-valid valid.tsv \
-test test_rice.tsv \
-model 'kuleshov-group/PlantCaduceus_l20' \
-output ./output \
-device 'cuda:0'

Expected outputs:
- Trained XGBoost classifier (`.json` file)
- Performance metrics on validation/test sets
- Feature importance analysis
We provide pre-trained XGBoost classifiers for common annotation tasks in the classifiers directory.
Available classifiers:
- TIS (Translation Initiation Sites)
- TTS (Translation Termination Sites)
- Splice donor/acceptor sites
python src/predict_XGBoost.py \
-test test_rice.tsv \
-model 'kuleshov-group/PlantCaduceus_l20' \
-classifier classifiers/PlantCaduceus_l20/TIS_XGBoost.json \
-device 'cuda:0' \
-output ./output

Expected output: Predictions with confidence scores for each sequence in your test data.
For advanced users who want to pre-train PlantCAD or PlantCAD2 models from scratch or fine-tune on custom datasets.
Requirements:
- Large computational resources (multi-GPU recommended)
- WandB account for experiment tracking
- Custom genomic dataset in HuggingFace format
Basic pre-training command:
WANDB_PROJECT=PlantCAD python src/HF_pre_train.py \
--do_train \
--report_to wandb \
--prediction_loss_only True \
--remove_unused_columns False \
--dataset_name 'kuleshov-group/Angiosperm_16_genomes' \
--soft_masked_loss_weights_train 0.1 \
--soft_masked_loss_weights_evaluation 0.0 \
--weight_decay 0.01 \
--optim adamw_torch \
--dataloader_num_workers 16 \
--preprocessing_num_workers 16 \
--seed 32 \
--save_strategy steps \
--save_steps 1000 \
--evaluation_strategy steps \
--eval_steps 1000 \
--logging_steps 10 \
--max_steps 120000 \
--warmup_steps 1000 \
--save_total_limit 20 \
--learning_rate 2E-4 \
--lr_scheduler_type constant_with_warmup \
--run_name test \
--overwrite_output_dir \
--output_dir "PlantCaduceus_train_1" \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--gradient_accumulation_steps 4 \
--tokenizer_name 'kuleshov-group/PlantCaduceus_l20' \
--config_name 'kuleshov-group/PlantCaduceus_l20' \
--log_level info

Key parameters:
- `dataset_name`: Your custom dataset, or use our Angiosperm dataset
- `max_steps`: Total training steps (adjust based on dataset size)
- `learning_rate`: 2E-4 works well for most cases
- Batch sizes: Adjust based on your GPU memory
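When adjusting batch sizes, it helps to keep track of the effective batch size per optimizer step, which depends on the per-device batch size, the number of gradient accumulation steps, and the number of GPUs. The sketch below works through that arithmetic with the values from the command above; the GPU count of 8 and the 512 bp sequence length (one token per base) are assumptions for illustration, not settings stated in this README.

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_gpus: int = 1) -> int:
    """Number of sequences contributing to each optimizer step."""
    return per_device_batch * grad_accum_steps * num_gpus


# 32 and 4 come from the pre-training command; 8 GPUs and 512 tokens per sequence are assumed.
batch = effective_batch_size(per_device_batch=32, grad_accum_steps=4, num_gpus=8)
tokens_per_step = batch * 512
print(batch, tokens_per_step)  # 1024 524288
```

If you reduce `--per_device_train_batch_size` to fit GPU memory, raising `--gradient_accumulation_steps` by the same factor keeps the effective batch size unchanged.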
Here are the inference speed benchmark results for PlantCaduceus (v1) and PlantCAD2 models:
| Model | Seq Len | Batch | Peak memory (GB) | Seq/s | Tokens/s |
|---|---|---|---|---|---|
| PlantCaduceus_l20 | 512 | 16 | 0.31 | 400.28 | 204,942 |
| PlantCaduceus_l20 | 512 | 32 | 0.56 | 640.86 | 328,118 |
| PlantCaduceus_l20 | 512 | 64 | 1.01 | 663.04 | 339,475 |
| PlantCaduceus_l24 | 512 | 16 | 0.43 | 335.88 | 171,970 |
| PlantCaduceus_l24 | 512 | 32 | 0.75 | 392.83 | 201,140 |
| PlantCaduceus_l24 | 512 | 64 | 1.37 | 407.38 | 208,577 |
| PlantCaduceus_l28 | 512 | 16 | 0.77 | 207.61 | 106,295 |
| PlantCaduceus_l28 | 512 | 32 | 1.27 | 213.99 | 109,563 |
| PlantCaduceus_l28 | 512 | 64 | 2.22 | 219.97 | 112,626 |
| PlantCaduceus_l32 | 512 | 16 | 1.1 | 130.56 | 66,848 |
| PlantCaduceus_l32 | 512 | 32 | 1.71 | 132.62 | 67,902 |
| PlantCaduceus_l32 | 512 | 64 | 2.97 | 135.05 | 69,144 |
| PlantCAD2-Small | 8192 | 16 | 6.56 | 19.61 | 160,653 |
| PlantCAD2-Small | 8192 | 32 | 12.76 | 19.26 | 157,767 |
| PlantCAD2-Small | 8192 | 64 | 24.89 | 19 | 155,649 |
| PlantCAD2-Medium | 8192 | 16 | 9.62 | 6.88 | 56,386 |
| PlantCAD2-Medium | 8192 | 32 | 17.62 | 6.76 | 55,342 |
| PlantCAD2-Medium | 8192 | 64 | 33.62 | 6.79 | 55,636 |
| PlantCAD2-Large | 8192 | 16 | 14.89 | 3.92 | 32,111 |
| PlantCAD2-Large | 8192 | 32 | 26.95 | 3.87 | 31,741 |
| PlantCAD2-Large | 8192 | 64 | 51.09 | 3.89 | 31,833 |
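The Tokens/s column can be used for a back-of-the-envelope runtime estimate. The sketch below assumes one token per base, a single non-overlapping pass over the sequence, and hardware comparable to the benchmark; the genome size is a hypothetical value chosen only for illustration. Workloads that need several forward passes per position (for example, masked scoring with a small step size) will take correspondingly longer.

```python
def hours_to_process(total_bp: float, tokens_per_second: float) -> float:
    """Estimated wall-clock hours for one non-overlapping inference pass over total_bp bases."""
    return total_bp / tokens_per_second / 3600


# Tokens/s values copied from the benchmark table (batch size 64).
throughput = {
    "PlantCaduceus_l20": 339_475,
    "PlantCaduceus_l32": 69_144,
    "PlantCAD2-Large": 31_833,
}

genome_bp = 2.2e9  # hypothetical genome size, for illustration only
for name, tokens_per_second in throughput.items():
    print(f"{name}: {hours_to_process(genome_bp, tokens_per_second):.1f} h")
```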
- Balanced Performance: For a good trade-off between speed and accuracy, we recommend using PlantCaduceus_l32 with a 512bp context window.
- Maximum Accuracy: If computational resources are not a constraint, we recommend using PlantCAD2-Large with a context window larger than 1024bp to achieve the best performance.
- Long Sequences: For tasks requiring sequences longer than 512bp, we highly recommend fine-tuning PlantCAD2 models. PlantCAD (v1) models are limited to a 512bp context window, whereas PlantCAD2 supports up to 8192bp.
If you find PlantCAD useful for your research, please consider citing our papers:
- Zhai, J., Gokaslan, A., Schiff, Y., Berthel, A., Liu, Z. Y., Lai, W. L., Miller, Z. R., Scheben, A., Stitzer, M. C., Romay, M. C., Buckler, E. S., & Kuleshov, V. (2025). Cross-species modeling of plant genomes at single nucleotide resolution using a pretrained DNA language model. Proceedings of the National Academy of Sciences, 122(24), e2421738122. https://doi.org/10.1073/pnas.2421738122
- Zhai J., Gokaslan A., Hsu SK., Chen SP., Liu ZY., Marroquin E., Czech E., Cannon B., Berthel A., Romay MC., Pennell M., Kuleshov V.*, Buckler ES.* PlantCAD2: A Long-Context DNA Language Model for Cross-Species Functional Annotation in Angiosperms. bioRxiv. 2025 Nov 19. doi: https://doi.org/10.1101/2025.08.27.672609