plantcad/plantcad

We’re excited to announce PlantCAD2 🌱, our new DNA foundation model for angiosperms.

We’re also releasing a collection of LoRA fine-tuned models 🎯 tailored to key downstream tasks, including accessible chromatin, gene expression, and protein translation.

  • Explore the fine-tuned PlantCAD2 models here

  • Explore the zero-shot evaluation of PlantCAD2 models here

  • Explore the post-training or pre-training of PlantCAD2 here

PlantCaduceus (PlantCAD for short) is a plant DNA language model based on the Caduceus architecture, which extends the efficient Mamba linear-time sequence modeling framework with bi-directionality and reverse complement equivariance, two properties designed specifically for DNA sequences. PlantCAD is pre-trained on a curated dataset of 16 angiosperm genomes. It shows state-of-the-art cross-species performance in predicting translation initiation sites (TIS), translation termination sites (TTS), splice donor sites, and splice acceptor sites. Zero-shot scoring with PlantCAD enables the identification of genome-wide deleterious mutations and known causal variants in Arabidopsis, sorghum, and maize.
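
Reverse complement equivariance matters because the same double-stranded region can be read from either strand. The short sketch below only illustrates the reverse-complement operation itself; `reverse_complement` is an illustrative helper and is not part of the PlantCAD codebase.

```python
# Illustrative helper (an assumption for this example, not a PlantCAD function).
_COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")


def reverse_complement(seq: str) -> str:
    """Return the opposite strand of `seq`, read in the 5'->3' direction."""
    return seq.translate(_COMPLEMENT)[::-1]


forward = "ATGGCC"
print(reverse_complement(forward))  # GGCCAT
# Applying the operation twice returns the original sequence.
print(reverse_complement(reverse_complement(forward)) == forward)  # True
```

A reverse complement equivariant model is built so that its outputs for a sequence and for the reverse complement of that sequence correspond to each other, rather than being learned independently.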

Quick Start

New to PlantCAD? Try our Google Colab demo - no installation required!

For local usage: See installation instructions here, then use notebooks/examples.ipynb to get started.

Model summary

Pre-trained models have been uploaded to HuggingFace 🤗: PlantCAD and PlantCAD2. Here is the comparison between PlantCAD (v1) and PlantCAD2 models:

Model               Sequence Length   Model Size   Embedding Size
PlantCAD
PlantCaduceus_l20   512bp             20M          384
PlantCaduceus_l24   512bp             40M          512
PlantCaduceus_l28   512bp             128M         768
PlantCaduceus_l32   512bp             225M         1024
PlantCAD2
PlantCAD2-Small     8192bp            88M          768
PlantCAD2-Medium    8192bp            311M         1024
PlantCAD2-Large     8192bp            694M         1536

Note: For PlantCAD, the maximum sequence length is 512bp. For PlantCAD2, it is 8,192bp.
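
Because each model family has a hard maximum input length, longer sequences have to be cut into windows before inference. The following is a minimal sketch, assuming a plain Python string and a 0-based position of interest; `center_window` is a hypothetical helper, not a function from this repository.

```python
def center_window(sequence: str, position: int, window: int = 512) -> tuple[str, int]:
    """Cut at most `window` bases around a 0-based `position`.

    Returns the window and the index of `position` inside that window.
    """
    if not 0 <= position < len(sequence):
        raise ValueError("position is outside the sequence")
    half = window // 2
    start = max(0, position - half)
    end = min(len(sequence), start + window)
    start = max(0, end - window)  # shift left when the right edge is reached
    return sequence[start:end], position - start


toy_genome = "ACGT" * 1000  # made-up 4,000 bp sequence
window, index = center_window(toy_genome, position=2000, window=512)
print(len(window), index)  # 512 256
```

Use `window=512` for PlantCAD (v1) models and up to `window=8192` for PlantCAD2 models.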

Prerequisites and System Requirements

For Google Colab: Just a Google account - GPU runtime recommended (free tier available)

For Local Installation: An NVIDIA GPU is required.

Installation

Option 1: Google Colab (Recommended for beginners)

No installation required! Just open our PlantCAD Google Colab notebook and start analyzing your data.

Option 2: Local installation

See our Local Installation Guide.

Basic Usage

Exploring model inputs and outputs

The easiest way to start is with our example notebook: notebooks/examples.ipynb

Quick example - Get sequence embeddings:

import torch
# mamba_ssm is not called directly below; importing it fails early if the package is missing
from mamba_ssm import Mamba
from transformers import AutoTokenizer, AutoModelForMaskedLM

device = 'cuda:0'
# Test PlantCAD model loading
tokenizer = AutoTokenizer.from_pretrained('kuleshov-group/PlantCaduceus_l32')
model = AutoModelForMaskedLM.from_pretrained('kuleshov-group/PlantCaduceus_l32', trust_remote_code=True)
model.to(device)
# Example plant DNA sequence (512bp max)
sequence = "CTTAATTAATATTGCCTTTGTAATAACGCGCGAAACACAAATCTTCTCTGCCTAATGCAGTAGTCATGTGTTGACTCCTTCAAAATTTCCAAGAAGTTAGTGGCTGGTGTGTCATTGTCTTCATCTTTTTTTTTTTTTTTTTAAAAATTGAATGCGACATGTACTCCTCAACGTATAAGCTCAATGCTTGTTACTGAAACATCTCTTGTCTGATTTTTTCAGGCTAAGTCTTACAGAAAGTGATTGGGCACTTCAATGGCTTTCACAAATGAAAAAGATGGATCTAAGGGATTTGTGAAGAGAGTGGCTTCATCTTTCTCCATGAGGAAGAAGAAGAATGCAACAAGTGAACCCAAGTTGCTTCCAAGATCGAAATCAACAGGTTCTGCTAACTTTGAATCCATGAGGCTACCTGCAACGAAGAAGATTTCAGATGTCACAAACAAAACAAGGATCAAACCATTAGGTGGTGTAGCACCAGCACAACCAAGAAGGGAAAAGATCGATGATCG"
# Get embeddings
encoding = tokenizer.encode_plus(
            sequence,
            return_tensors="pt",
            return_attention_mask=False,
            return_token_type_ids=False
        )

input_ids = encoding["input_ids"].to(device)
with torch.inference_mode():
    outputs = model(input_ids=input_ids, output_hidden_states=True)

embeddings = outputs.hidden_states[-1]
print(f"Embedding shape: {embeddings.shape}")  # [batch_size, seq_len, embedding_dim]

embeddings = embeddings.to(torch.float32).cpu().numpy()
# PlantCaduceus is bi-directional and reverse complement equivariant: the first half of the
# embedding channels corresponds to the forward sequence and the second half to the
# reverse-complemented sequence. Average the two halves before training a downstream classifier.

hidden_size = embeddings.shape[-1] // 2
forward = embeddings[..., 0:hidden_size]
reverse = embeddings[..., hidden_size:]
reverse = reverse[..., ::-1]
averaged_embeddings = (forward + reverse) / 2
print(averaged_embeddings.shape)
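
To see what the averaging step above does without loading a model or using a GPU, the same operations can be run on a small dummy array. This is only an illustration: the array values are made up and stand in for `outputs.hidden_states[-1]`.

```python
import numpy as np

# Dummy stand-in for the last hidden state: batch of 1, 4 tokens, 6 channels
# (3 "forward" channels followed by 3 "reverse-complement" channels).
embeddings = np.arange(24, dtype=np.float32).reshape(1, 4, 6)

hidden_size = embeddings.shape[-1] // 2
forward = embeddings[..., :hidden_size]
reverse = embeddings[..., hidden_size:][..., ::-1]  # flip the channel axis
averaged = (forward + reverse) / 2

print(embeddings.shape, "->", averaged.shape)  # (1, 4, 6) -> (1, 4, 3)
print(averaged[0, 0])  # [2.5 2.5 2.5]
```

The channel dimension is halved, so a downstream classifier sees one vector of size `hidden_size` per token.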

Zero-shot Scoring of Genomic Variants and Regions

The zero_shot_score.py script now provides unified functionality to estimate the functional impact of genetic variants or score genomic regions using PlantCAD's log-likelihood scores. It supports two primary modes:

  1. Variant Scoring (VCF Input): Scores specific genetic variants provided in a VCF file.
  2. Genome-Wide Region Scoring (BED Input): Calculates log-likelihood ratios for all positions within specified genomic regions (BED file).

# Download example reference genome
wget https://download.maizegdb.org/Zm-B73-REFERENCE-NAM-5.0/Zm-B73-REFERENCE-NAM-5.0.fa.gz
gunzip Zm-B73-REFERENCE-NAM-5.0.fa.gz
# VCF Input Mode
# --- Example: Variant Scoring (VCF Input) ---
# Estimate impact of specific variants from a VCF file.
# Note: Only the first 8 columns of the VCF file (CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO) are strictly required.
python src/zero_shot_score.py \
    -input-vcf examples/example_maize_snp.vcf \
    -input-fasta Zm-B73-REFERENCE-NAM-5.0.fa \
    -output scored_variants.vcf \
    -model 'kuleshov-group/PlantCaduceus_l32' \
    -device 'cuda:0'

# Expected output for VCF mode:
# - A new VCF file ('scored_variants.vcf') with PlantCAD scores added to the INFO field.
# - Scores represent log-likelihood ratios between reference and alternative alleles.
#   More negative scores indicate potentially more deleterious mutations.
# BED Input Mode
# --- Example: Genome-Wide Region Scoring (BED Input) ---
# Calculate log-likelihood ratios for all positions within specified BED regions.
# Note: You would need an example BED file for this.
# For demonstration, creating a dummy BED file:
echo -e "chr1\t1000\t1010\nchr1\t2000\t2015" > examples/example_regions.bed

python src/zero_shot_score.py \
    -input-bed examples/example_regions.bed \
    -input-fasta Zm-B73-REFERENCE-NAM-5.0.fa \
    -output genome_wide_scores.tsv \
    -model 'kuleshov-group/PlantCaduceus_l32' \
    -device 'cuda:0' \
    -step-size 1 \
    -aggregation average \
    -use-masking \
    -output-raw-prob

# Expected output for BED mode:
# - A tab-separated file ('genome_wide_scores.tsv') containing scores for each position.
# - Output includes chromosome, start, end, reference allele, aggregated score,
#   and optionally raw probabilities for all four nucleotides.
# - `-step-size`: Number of positions to analyze per window. If the step size is greater than 1, we recommend turning off masking.
# - `-aggregation`: How to aggregate alternative allele scores.
#   - `'max'`: Reports the maximum log-likelihood ratio among all three alternative alleles relative to the reference.
#   - `'average'`: Reports the average log-likelihood ratio across all three alternative alleles relative to the reference.
#   - `'all'`: Reports the individual log-likelihood ratios for each of the three alternative alleles relative to the reference.
# - `-use-masking`: Whether to mask the central position(s) during inference.
# - `-output-raw-prob`: Include raw probabilities in the output.
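
The log-likelihood ratio and the three aggregation modes described above can be sketched in a few lines. This is an illustration under stated assumptions, not the code in `zero_shot_score.py`: it assumes the score is log P(alt) minus log P(ref), which matches the statement that negative scores point to deleterious alleles, and the probabilities are made up. `llr_scores` and `aggregate` are hypothetical names.

```python
import math


def llr_scores(probs: dict[str, float], ref: str) -> dict[str, float]:
    """log P(alt) - log P(ref) for each of the three alternative alleles."""
    return {
        alt: math.log(p) - math.log(probs[ref])
        for alt, p in probs.items()
        if alt != ref
    }


def aggregate(scores: dict[str, float], how: str):
    """Combine per-allele scores following the 'max', 'average', or 'all' modes."""
    if how == "max":
        return max(scores.values())
    if how == "average":
        return sum(scores.values()) / len(scores)
    if how == "all":
        return scores
    raise ValueError(f"unknown aggregation: {how}")


# Made-up model probabilities at one position where the reference base is 'A'.
probs = {"A": 0.85, "C": 0.05, "G": 0.05, "T": 0.05}
scores = llr_scores(probs, ref="A")
print(aggregate(scores, "max"))      # about -2.83
print(aggregate(scores, "average"))  # about -2.83
```

When the model is confident in the reference base, every alternative allele receives a strongly negative score.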

When analyzing the entire genome or large genomic regions, the -step-size parameter is very important for speeding up the analysis. For a detailed guide on this trade-off between speed and accuracy, see here.
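
As a rough illustration of why the step size matters, the sketch below counts model calls under the assumption that each call scores `-step-size` adjacent positions. The actual windowing logic in `zero_shot_score.py` may differ, and `windows_needed` is a hypothetical helper.

```python
import math


def windows_needed(bed_intervals, step_size: int) -> int:
    """Count model calls if each call scores `step_size` adjacent positions."""
    return sum(
        math.ceil((end - start) / step_size) for _chrom, start, end in bed_intervals
    )


# Same two regions as the dummy BED file in the example above.
regions = [("chr1", 1000, 1010), ("chr1", 2000, 2015)]
for step in (1, 5):
    print(step, windows_needed(regions, step))
# 1 25
# 5 5
```

Under this assumption the number of forward passes shrinks roughly in proportion to the step size.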

In-silico mutagenesis pipeline

For large-scale simulation and analysis of genetic variants, we provide a comprehensive in-silico mutagenesis pipeline. See pipelines/in-silico-mutagenesis/README.md for detailed instructions.
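
Conceptually, in-silico mutagenesis enumerates every possible single-nucleotide substitution in a region and scores each one. The sketch below shows only the enumeration step; it is not the pipeline's implementation, and `single_nucleotide_variants` is a hypothetical helper.

```python
from typing import Iterator


def single_nucleotide_variants(
    seq: str, offset: int = 0
) -> Iterator[tuple[int, str, str]]:
    """Yield (position, ref, alt) for every possible substitution in `seq`."""
    for i, ref in enumerate(seq.upper()):
        if ref not in "ACGT":  # skip ambiguous bases such as N
            continue
        for alt in "ACGT":
            if alt != ref:
                yield offset + i, ref, alt


variants = list(single_nucleotide_variants("ACN", offset=100))
print(len(variants))  # 6
print(variants[0])    # (100, 'A', 'C')
```

Positions here are simply `offset` plus the 0-based index. VCF files use 1-based positions, so choose `offset` accordingly if you write the variants to a VCF.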

Advanced Usage

Training XGBoost classifiers

Train custom classifiers on top of PlantCAD embeddings for specific annotation tasks (e.g., TIS, TTS, splice sites).

Purpose: Fine-tune prediction performance for specific annotation tasks using supervised learning.

Data format: Training data should follow the format used in our cross-species annotation dataset.

python src/train_XGBoost.py \
    -train train.tsv \
    -valid valid.tsv \
    -test test_rice.tsv \
    -model 'kuleshov-group/PlantCaduceus_l20' \
    -output ./output \
    -device 'cuda:0'

Expected outputs:

  • Trained XGBoost classifier (.json file)
  • Performance metrics on validation/test sets
  • Feature importance analysis

Using pre-trained XGBoost classifiers

We provide pre-trained XGBoost classifiers for common annotation tasks in the classifiers directory.

Available classifiers:

  • TIS (Translation Initiation Sites)
  • TTS (Translation Termination Sites)
  • Splice donor/acceptor sites

python src/predict_XGBoost.py \
    -test test_rice.tsv \
    -model 'kuleshov-group/PlantCaduceus_l20' \
    -classifier classifiers/PlantCaduceus_l20/TIS_XGBoost.json \
    -device 'cuda:0' \
    -output ./output

Expected output: Predictions with confidence scores for each sequence in your test data.
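
Confidence scores usually need to be turned into calls with a threshold before they can be compared with known annotations. The sketch below assumes scores between 0 and 1 and binary labels; the labels and scores are made up, and the column names of the real output file are not assumed here.

```python
def precision_recall(labels, scores, threshold: float = 0.5):
    """Precision and recall of `scores >= threshold` against 0/1 labels."""
    predicted = [s >= threshold for s in scores]
    tp = sum(p and y == 1 for p, y in zip(predicted, labels))
    fp = sum(p and y == 0 for p, y in zip(predicted, labels))
    fn = sum((not p) and y == 1 for p, y in zip(predicted, labels))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


labels = [1, 0, 1, 1, 0]                  # made-up ground truth
scores = [0.92, 0.40, 0.55, 0.30, 0.71]   # made-up confidence scores
precision, recall = precision_recall(labels, scores, threshold=0.5)
print(f"precision={precision:.2f} recall={recall:.2f}")  # precision=0.67 recall=0.67
```

Raising the threshold generally trades recall for precision, so the right value depends on the annotation task.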

Development and Training

Pre-training PlantCAD

This section is for advanced users who want to pre-train PlantCAD or PlantCAD2 models from scratch or fine-tune them on custom datasets.

Requirements:

  • Large computational resources (multi-GPU recommended)
  • WandB account for experiment tracking
  • Custom genomic dataset in HuggingFace format

Basic pre-training command:

WANDB_PROJECT=PlantCAD python src/HF_pre_train.py \
    --do_train \
    --report_to wandb \
    --prediction_loss_only True \
    --remove_unused_columns False \
    --dataset_name 'kuleshov-group/Angiosperm_16_genomes' \
    --soft_masked_loss_weights_train 0.1 \
    --soft_masked_loss_weights_evaluation 0.0 \
    --weight_decay 0.01 \
    --optim adamw_torch \
    --dataloader_num_workers 16 \
    --preprocessing_num_workers 16 \
    --seed 32 \
    --save_strategy steps \
    --save_steps 1000 \
    --evaluation_strategy steps \
    --eval_steps 1000 \
    --logging_steps 10 \
    --max_steps 120000 \
    --warmup_steps 1000 \
    --save_total_limit 20 \
    --learning_rate 2E-4 \
    --lr_scheduler_type constant_with_warmup \
    --run_name test \
    --overwrite_output_dir \
    --output_dir "PlantCaduceus_train_1" \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 4 \
    --tokenizer_name 'kuleshov-group/PlantCaduceus_l20' \
    --config_name 'kuleshov-group/PlantCaduceus_l20' \
    --log_level info

Key parameters:

  • dataset_name: Your custom dataset or use our Angiosperm dataset
  • max_steps: Total training steps (adjust based on dataset size)
  • learning_rate: 2E-4 works well for most cases
  • Batch sizes: Adjust based on your GPU memory
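
When adjusting batch sizes, it helps to track the effective batch size per optimizer step and the learning rate implied by `constant_with_warmup`. The sketch below uses the values from the command above; the number of GPUs and the 512bp sequence length are assumptions, and both helper functions are illustrative.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Sequences consumed per optimizer step."""
    return per_device * grad_accum * num_gpus


def constant_with_warmup(step: int, base_lr: float = 2e-4, warmup_steps: int = 1000) -> float:
    """Linear warmup from 0 to `base_lr`, then a constant learning rate."""
    return base_lr * min(1.0, step / max(1, warmup_steps))


num_gpus = 8   # assumption: not stated in the command above
seq_len = 512  # assumption: PlantCAD (v1) context length

batch = effective_batch_size(per_device=32, grad_accum=4, num_gpus=num_gpus)
print(batch)                       # 1024 sequences per optimizer step
print(batch * seq_len)             # 524288 tokens per optimizer step
print(constant_with_warmup(500))   # 0.0001
print(constant_with_warmup(5000))  # 0.0002
```

If you lower `per_device_train_batch_size` to fit GPU memory, raising `gradient_accumulation_steps` by the same factor keeps the effective batch size unchanged.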

Model Recommendations

Inference speed

Here are the inference speed benchmark results for PlantCaduceus (v1) and PlantCAD2 models:

Model               Seq Len   Batch   Peak memory (GB)   Seq/s    Tokens/s
PlantCaduceus_l20   512       16      0.31               400.28   204,942
PlantCaduceus_l20   512       32      0.56               640.86   328,118
PlantCaduceus_l20   512       64      1.01               663.04   339,475
PlantCaduceus_l24   512       16      0.43               335.88   171,970
PlantCaduceus_l24   512       32      0.75               392.83   201,140
PlantCaduceus_l24   512       64      1.37               407.38   208,577
PlantCaduceus_l28   512       16      0.77               207.61   106,295
PlantCaduceus_l28   512       32      1.27               213.99   109,563
PlantCaduceus_l28   512       64      2.22               219.97   112,626
PlantCaduceus_l32   512       16      1.1                130.56   66,848
PlantCaduceus_l32   512       32      1.71               132.62   67,902
PlantCaduceus_l32   512       64      2.97               135.05   69,144
PlantCAD2-Small     8192      16      6.56               19.61    160,653
PlantCAD2-Small     8192      32      12.76              19.26    157,767
PlantCAD2-Small     8192      64      24.89              19       155,649
PlantCAD2-Medium    8192      16      9.62               6.88     56,386
PlantCAD2-Medium    8192      32      17.62              6.76     55,342
PlantCAD2-Medium    8192      64      33.62              6.79     55,636
PlantCAD2-Large     8192      16      14.89              3.92     32,111
PlantCAD2-Large     8192      32      26.95              3.87     31,741
PlantCAD2-Large     8192      64      51.09              3.89     31,833
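
These throughput figures can be turned into a rough runtime estimate for a planned analysis. The sketch below assumes one million windows and uses the batch-size-64 rows of the table; real runtimes depend on your hardware, which is not specified here, and `hours_to_score` is a hypothetical helper.

```python
def hours_to_score(num_windows: int, seqs_per_second: float) -> float:
    """Estimated wall-clock hours to run `num_windows` sequences."""
    return num_windows / seqs_per_second / 3600


num_windows = 1_000_000  # assumption: size of the planned analysis

# Seq/s values taken from the benchmark table (batch size 64).
print(round(hours_to_score(num_windows, 135.05), 2))  # PlantCaduceus_l32: 2.06
print(round(hours_to_score(num_windows, 3.89), 2))    # PlantCAD2-Large: 71.41
```

Note that each PlantCAD2 window covers 8192bp rather than 512bp, so sequences per second alone understates how much sequence context PlantCAD2 processes.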

Which model to use?

Variant Effect Analysis

  • Balanced Performance: For a good trade-off between speed and accuracy, we recommend using PlantCaduceus_l32 with a 512bp context window.
  • Maximum Accuracy: If computational resources are not a constraint, we recommend using PlantCAD2-Large with a context window larger than 1024bp to achieve the best performance.

Other Downstream Tasks

  • Long Sequences: For tasks requiring sequences longer than 512bp, we highly recommend fine-tuning PlantCAD2 models. PlantCAD (v1) models are limited to a 512bp context window, whereas PlantCAD2 supports up to 8192bp.

Citations

If you find PlantCAD useful for your research, please consider citing our papers:

  • Zhai, J., Gokaslan, A., Schiff, Y., Berthel, A., Liu, Z. Y., Lai, W. L., Miller, Z. R., Scheben, A., Stitzer, M. C., Romay, M. C., Buckler, E. S., & Kuleshov, V. (2025). Cross-species modeling of plant genomes at single nucleotide resolution using a pretrained DNA language model. Proceedings of the National Academy of Sciences, 122(24), e2421738122. https://doi.org/10.1073/pnas.2421738122
  • Zhai, J., Gokaslan, A., Hsu, S. K., Chen, S. P., Liu, Z. Y., Marroquin, E., Czech, E., Cannon, B., Berthel, A., Romay, M. C., Pennell, M., Kuleshov, V.*, & Buckler, E. S.* (2025). PlantCAD2: A Long-Context DNA Language Model for Cross-Species Functional Annotation in Angiosperms. bioRxiv, Nov 19. https://doi.org/10.1101/2025.08.27.672609
