vignesh2027/TemporalMesh-Transformer


🔥 What Makes TMT Different?

Standard transformers make the same three assumptions.
TMT is designed to break all three simultaneously, in a single unified model.

❌ Old assumption                  | ✅ TMT breaks it with
All tokens are equally important   | Temporal semantic decay: irrelevant tokens fade
The sequence is flat               | Dynamic mesh graph: rebuilt each forward pass
Every token uses the same compute  | Adaptive depth routing: easy tokens exit early, hard tokens go deep

⚑ The Three Core Innovations

πŸ•ΈοΈ Innovation 1 β€” Mesh Attention (Dynamic Graph Topology)

Standard transformers connect every token to every other token: O(S²) cost and a fixed topology that ignores what the tokens mean.

TMT Mesh Attention treats tokens as nodes in a live graph. After every layer, cosine similarity re-ranks the connections, and only each token's top-k nearest neighbours get edges.

Step 1 ─ Compute cosine similarity matrix    (S × S)
Step 2 ─ Keep top-k per row                  → sparse edge_index (2, S·k)
Step 3 ─ Attention flows only along edges    → O(S·k) instead of O(S²)
Step 4 ─ Representations update, graph rebuilds  → topology adapts to content

As written, Step 1 still costs O(S²·D); the O(S·k) saving applies to the attention step itself.

Key insight: the graph is not pre-defined. It is rebuilt every forward pass from what the tokens currently mean, whereas Graph Transformers typically attend over a fixed input graph.
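The four steps above can be sketched for a single sequence in plain PyTorch. This is a minimal single-head illustration under stated assumptions, not the repository's MeshBuilder or MeshAttention: the function names are hypothetical, self-loops are excluded before the top-k (the README does not say either way), and because every token keeps exactly k neighbours, the edges can also be held as an (S, k) index table, which turns edge-restricted attention into a simple gather.

```python
import math
import torch
import torch.nn.functional as F

def build_knn_mesh(x: torch.Tensor, k: int):
    """Dynamic kNN graph for one sequence x of shape (S, D).

    Returns neighbour indices (S, k), edge_index (2, S*k) and edge_weight (S*k,).
    Assumption: self-loops are excluded before taking the top-k.
    """
    S = x.size(0)
    x_unit = F.normalize(x, dim=-1)                        # unit vectors, so dot product = cosine
    sim = x_unit @ x_unit.T                                # (S, S) cosine similarity
    self_mask = torch.eye(S, device=x.device).bool()
    sim = sim.masked_fill(self_mask, float("-inf"))        # never pick yourself as a neighbour
    edge_weight, nbr = sim.topk(k, dim=-1)                 # (S, k): top-k per row
    src = nbr.reshape(-1)                                  # neighbour token (sends the message)
    dst = torch.arange(S, device=x.device).repeat_interleave(k)  # querying token (receives it)
    edge_index = torch.stack([src, dst])                   # (2, S*k), (source, target) rows
    return nbr, edge_index, edge_weight.reshape(-1)

def mesh_attention(query, key, value, nbr):
    """Single-head attention where token i only attends to its k graph neighbours."""
    key_nb = key[nbr]                                      # (S, k, D) keys of each token's neighbours
    value_nb = value[nbr]                                  # (S, k, D)
    scores = (query.unsqueeze(1) * key_nb).sum(-1) / math.sqrt(query.size(-1))  # (S, k)
    attn = scores.softmax(dim=-1)                          # normalise over the k neighbours only
    return (attn.unsqueeze(-1) * value_nb).sum(dim=1)      # (S, D)
```

Usage: `nbr, edge_index, edge_weight = build_knn_mesh(h, k=8)` followed by `mesh_attention(h, h, h, nbr)` for a hidden state `h` of shape (S, D). The edge_index rows follow the PyTorch Geometric (source, target) convention, with the neighbour as source and the querying token as target.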

⏳ Innovation 2: Temporal Decay Encoding

RoPE, ALiBi, sinusoidal: every positional encoding tells a token where it is (ALiBi also adds a fixed, distance-based penalty). None of them learns how relevant a token is right now.

TMT Temporal Decay multiplies a learned scalar directly into the attention weights, down-weighting tokens by their temporal distance from the current prediction target.

attn_weight = softmax(QKᵀ / √d) × sigmoid(W_decay × temporal_distance)

W_decay           : learned per-head decay (n_heads scalars)
temporal_distance : normalised position t ∈ [0, 1]

Effect: recent tokens keep more of their attention weight while stale, distant tokens fade, without recurrence and without hidden state.
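A minimal single-sequence sketch of the decay formula follows. Two details are assumptions, not taken from the repository: temporal_distance is read as the normalised gap |i - j| / (S - 1) between query i and key j, and no causal mask is applied. Note the sign: sigmoid(W_decay × t) only shrinks with distance when W_decay is negative; with a positive W_decay the multiplier grows with distance.

```python
import math
import torch

def decayed_attention(q, k, v, w_decay):
    """q, k, v: (H, S, d) for one sequence; w_decay: (H,) one learned scalar per head."""
    H, S, d = q.shape
    pos = torch.arange(S, dtype=q.dtype, device=q.device)
    # assumption: temporal_distance = |i - j| / (S - 1), so it lies in [0, 1]
    dist = (pos[:, None] - pos[None, :]).abs() / max(S - 1, 1)     # (S, S)
    decay = torch.sigmoid(w_decay.view(H, 1, 1) * dist)            # (H, S, S), each in (0, 1)
    attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1)
    attn = attn * decay                                            # rows no longer sum to 1
    return attn @ v, attn
```

Because the multiplier lies in (0, 1) and is applied after the softmax, every row of weights sums to less than 1, and even at distance 0 the factor is sigmoid(0) = 0.5.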

🚀 Innovation 3: Adaptive Depth Routing Per Token

In GPT, LLaMA, and other standard transformers, a comma and a rare scientific term get the same compute: every layer, every time.

TMT Exit Gate gives each token a confidence score after each layer. Confident tokens freeze and skip remaining layers. Hard tokens use the full depth.

confidence = sigmoid(W_gate · x_token)    # scalar ∈ (0, 1) per token per layer

if confidence > 0.85:
    token is frozen: no more layers       # ~50% of tokens exit by layer 4
else:
    token continues to next layer         # rare/complex tokens use all 12

Result: ~50% average compute reduction with no accuracy loss on complex tokens. The exit gates are trained jointly with the model through an auxiliary gate loss, weighted 0.1 in the total training loss.
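The gating logic above can be sketched as a small stack of layers. This is a hypothetical illustration, not the repository's ExitGate: a residual per-token MLP stands in for the full TMTLayer, and exited tokens are frozen with torch.where. The returned exit masks are cumulative (once a token exits it stays exited), which is the kind of property a test such as test_exit_mask_monotone would check.

```python
import torch
import torch.nn as nn

class ExitGateStack(nn.Module):
    """Hypothetical per-token early exit; residual MLPs stand in for full TMT layers."""

    def __init__(self, d_model: int, n_layers: int, threshold: float = 0.85):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(d_model, d_model), nn.GELU()) for _ in range(n_layers)]
        )
        self.gates = nn.ModuleList([nn.Linear(d_model, 1) for _ in range(n_layers)])
        self.threshold = threshold

    def forward(self, x: torch.Tensor):
        exited = torch.zeros(x.shape[:2], dtype=torch.bool, device=x.device)  # (B, S)
        exit_masks, confidences = [], []
        for block, gate in zip(self.blocks, self.gates):
            updated = x + block(x)                               # residual update
            x = torch.where(exited.unsqueeze(-1), x, updated)    # frozen tokens keep their state
            conf = torch.sigmoid(gate(x)).squeeze(-1)            # (B, S), in (0, 1)
            exited = exited | (conf > self.threshold)            # sticky: a token never un-exits
            exit_masks.append(exited)
            confidences.append(conf)
        return x, exit_masks, confidences
```

Masking with torch.where still runs every block on every token, so this sketch shows the routing logic but saves no compute; realising the savings requires running each block only on the still-active tokens, for example by selecting them with a boolean index first.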


🧠 Full Architecture

input_ids  (B, S)
             │
             ▼
┌─────────────────────────┐
│   TokenEmbedding        │  Standard learned embedding × √(d_model)
└────────────┬────────────┘
             │  (B, S, D)
             ▼
┌─────────────────────────┐
│ TemporalPositionEncoder │  RoPE base + learned decay scalars
│                         │  → output: (B, S, D) encoded
│                         │  → decay_scalars: (B, S, D)  ∈ (0, 1)
└────────────┬────────────┘
             │
             ▼
┌─────────────────────────┐
│     MeshBuilder         │  Cosine similarity → top-k edges
│                         │  → edge_index:  (2, E)
│                         │  → edge_weight: (E,)
└────────────┬────────────┘
             │
             ▼  ×  n_layers  (default 12)
┌──────────────────────────────────────────────────────────┐
│                        TMTLayer                          │
│                                                          │
│  ┌──────────────┐  Attention restricted to graph edges   │
│  │ MeshAttention│  Temporal decay × attention weights    │
│  └──────┬───────┘                                        │
│         │ + residual                                     │
│  ┌──────────────┐  Syntax stream    (d=256)              │
│  │ DualStreamFFN│  Semantic stream  (d=256)              │
│  └──────┬───────┘  Fused by learned sigmoid gate         │
│         │ + residual                                     │
│  ┌──────────────┐  Confidence scalar per token           │
│  │   ExitGate   │  Freeze token if conf > 0.85           │
│  └──────┬───────┘                                        │
│  ┌──────────────┐  Cross-attend to 16 persistent         │
│  │MemoryAnchor  │  parameter vectors (EMA updated)       │
│  └──────┬───────┘                                        │
│         │ + residual                                     │
│  Rebuild graph from updated token representations        │
└────────────┬─────────────────────────────────────────────┘
             │
             ▼
┌─────────────────────────┐
│  LayerNorm              │
│  OutputProjection       │  (B, S, D) → (B, S, vocab_size)
│  Weight tying to emb    │  saves ~25M parameters
└─────────────────────────┘
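The DualStreamFFN box can be read as two parallel feed-forward streams whose outputs are mixed by a learned sigmoid gate. The sketch below is one hypothetical layout (the class name, the GELU activation, and gating on the concatenated stream outputs are all assumptions); it only shows how a per-feature gate fuses two streams of width ffn_stream_dim.

```python
import torch
import torch.nn as nn

class DualStreamFFNSketch(nn.Module):
    """Two parallel FFN streams fused by a learned per-feature sigmoid gate (hypothetical layout)."""

    def __init__(self, d_model: int = 512, stream_dim: int = 256, dropout: float = 0.1):
        super().__init__()

        def stream() -> nn.Module:
            return nn.Sequential(
                nn.Linear(d_model, stream_dim), nn.GELU(),
                nn.Dropout(dropout), nn.Linear(stream_dim, d_model),
            )

        self.syntax = stream()
        self.semantic = stream()
        self.gate = nn.Linear(2 * d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a, b = self.syntax(x), self.semantic(x)
        g = torch.sigmoid(self.gate(torch.cat([a, b], dim=-1)))  # (B, S, D), in (0, 1)
        return g * a + (1 - g) * b                               # convex mix of the two streams
```

With g close to 1 the layer reduces to the syntax stream and with g close to 0 to the semantic stream, so the gate can choose per token and per feature.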

Output (TMTOutput dataclass):
  ├── logits         (B, S, V)
  ├── exit_masks     list[(B, S) bool]   (one per layer)
  ├── confidences    list[(B, S) float]  (one per layer)
  ├── graph_edges    (edge_index, edge_weight)
  ├── memory_state   (M, D) final anchor state
  └── decay_scalars  (B, S, D)

πŸ“ Project Structure

TemporalMesh-Transformer/
│
├── tmt/
│   ├── model/
│   │   ├── config.py          ← TMTConfig: all hyperparameters in one place
│   │   ├── embedding.py       ← TokenEmbedding + TemporalPositionEncoder (RoPE + decay)
│   │   ├── mesh.py            ← MeshBuilder: dynamic kNN graph, rebuilt each pass
│   │   ├── attention.py       ← MeshAttention: multi-head attention over graph edges
│   │   ├── ffn.py             ← DualStreamFFN: parallel syntax + semantic streams
│   │   ├── exit_gate.py       ← ExitGate: per-token confidence, freeze if > 0.85
│   │   ├── memory.py          ← MemoryAnchorCross: 16 persistent KV nodes (EMA)
│   │   ├── layers.py          ← TMTLayer: assembles all components
│   │   └── model.py           ← TMTModel: full model + TMTOutput dataclass
│   │
│   ├── training/
│   │   ├── trainer.py         ← Training loop, wandb logging, checkpoint saving
│   │   ├── loss.py            ← CE loss + 0.1 × exit gate auxiliary loss
│   │   └── scheduler.py       ← Cosine warmup LR scheduler
│   │
│   ├── data/
│   │   ├── tokenizer.py       ← HuggingFace tokenizer wrapper
│   │   └── dataset.py         ← wikitext-2 / tinystories block dataset loader
│   │
│   └── experiments/
│       ├── 01_baseline.ipynb  ← Vanilla transformer baseline (control group)
│       ├── 02_mesh_only.ipynb ← Ablation: mesh attention only
│       ├── 03_full_tmt.ipynb  ← Full TMT training run
│       └── 04_compare.ipynb   ← Perplexity comparison table + bar chart
│
├── tests/
│   ├── test_shapes.py         ← Shape assertions for every module
│   └── test_forward.py        ← End-to-end forward, backward, invariant tests
│
├── docs/                      ← GitHub Pages live documentation
├── requirements.txt
├── pyproject.toml
└── README.md

🚀 Quick Start

1. Clone

git clone https://github.com/vignesh2027/TemporalMesh-Transformer.git
cd TemporalMesh-Transformer

2. Environment

python3 -m venv .venv
source .venv/bin/activate        # macOS / Linux
# .venv\Scripts\activate         # Windows

3. Install

pip install -r requirements.txt

Note on torch-geometric: it is listed in requirements but optional, because TMT has a pure-PyTorch fallback. For optimised sparse kernels, follow the official install guide.
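A pure-PyTorch fallback for message passing essentially needs a scatter-sum over edge_index. The sketch below shows one way to detect the optional dependency and to aggregate per-edge messages with Tensor.index_add_; the function name is hypothetical and this is not the repository's fallback code.

```python
import importlib.util
import torch

# True when torch_geometric is installed; code can branch on this to pick optimised kernels
HAS_PYG = importlib.util.find_spec("torch_geometric") is not None

def scatter_sum(messages: torch.Tensor, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Sum per-edge messages (E, D) into their target nodes (row 1 of edge_index)."""
    dst = edge_index[1]
    out = messages.new_zeros(num_nodes, messages.size(-1))
    return out.index_add_(0, dst, messages)               # (num_nodes, D)
```

For edges 0→1, 1→1 and 2→0 with all-ones messages, node 1 receives two messages, node 0 one, and node 2 none.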

4. Verify

pytest tests/ -v

Expected output (15/15 tests pass):

tests/test_forward.py::test_full_forward_shape         PASSED
tests/test_forward.py::test_output_has_all_fields      PASSED
tests/test_forward.py::test_loss_computable            PASSED
tests/test_forward.py::test_backward_pass              PASSED
tests/test_forward.py::test_exit_mask_monotone         PASSED
tests/test_forward.py::test_no_nan_in_logits           PASSED
tests/test_forward.py::test_model_repr                 PASSED
tests/test_shapes.py::test_token_embedding             PASSED
tests/test_shapes.py::test_temporal_position_encoder   PASSED
tests/test_shapes.py::test_mesh_builder                PASSED
tests/test_shapes.py::test_mesh_attention              PASSED
tests/test_shapes.py::test_dual_stream_ffn             PASSED
tests/test_shapes.py::test_exit_gate                   PASSED
tests/test_shapes.py::test_memory_anchor_cross         PASSED
tests/test_shapes.py::test_tmt_layer                   PASSED

15 passed in 12.80s

πŸ‹οΈ Training

CPU-Friendly Quick Run (~10 min)

from tmt.model.config import TMTConfig
from tmt.training.trainer import TMTTrainer, TrainConfig
from tmt.data.dataset import load_text_dataset

cfg = TMTConfig(
    vocab_size=50258,   # GPT-2 tokenizer
    d_model=256,
    n_heads=4,
    n_layers=4,
    graph_k=4,
    ffn_stream_dim=128,
    memory_anchors=8,
    max_seq_len=128,
)

loaders = load_text_dataset('wikitext-2', seq_len=128, batch_size=8)

trainer = TMTTrainer(
    cfg,
    TrainConfig(total_steps=500, warmup_steps=50, use_wandb=False, eval_every=100),
    loaders['train'],
    loaders.get('validation'),
)
trainer.train()

Full GPU Run: Publication Quality (~2-3 hrs on A100/RTX 3090)

cfg = TMTConfig(
    vocab_size=50258,
    d_model=512,
    n_heads=8,
    n_layers=12,
    graph_k=8,
    decay_rate=0.1,
    exit_threshold=0.85,
    dual_stream=True,
    memory_anchors=16,
    ffn_stream_dim=256,
    max_seq_len=256,
)

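train_cfg = TrainConfig(
    total_steps=10_000,
    warmup_steps=500,
    lr=3e-4,
    batch_size=16,
    eval_every=500,
    save_every=1000,
    use_wandb=True,    # run `wandb login` first; the API key is at wandb.ai/authorize
)

# Imports and trainer usage as in the CPU quick run above
loaders = load_text_dataset('wikitext-2', seq_len=256, batch_size=16)
trainer = TMTTrainer(cfg, train_cfg, loaders['train'], loaders.get('validation'))
trainer.train()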

Training Log Explained

step=   10 | loss=10.77 | ce=10.78 | gate=-0.01 | exit=0.000 | lr=6.00e-05
step=   50 | loss= 8.76 | ce= 8.79 | gate=-0.25 | exit=1.000 | lr=3.00e-04
step=  100 | loss= 8.13 | ce= 8.17 | gate=-0.36 | exit=1.000 | lr=2.92e-04
  val_perplexity=3874.81
Field | Meaning
loss  | CE + 0.1 × gate_loss
ce    | Cross-entropy on next-token prediction
gate  | Exit gate auxiliary loss (negative = gates becoming decisive)
exit  | Fraction of tokens that exited early (0.0 = none, 1.0 = every token)
lr    | Learning rate from the cosine warmup schedule
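The logged fields can be reproduced with two small functions. Both are sketches, not the contents of loss.py or scheduler.py: the gate loss is passed in as a number because the README does not define it, and the cosine schedule assumes a floor of 10% of the peak learning rate. With peak 3e-4, 50 warmup steps and 500 total steps, that assumption reproduces the logged lr values (6.00e-05 at step 10, 3.00e-04 at step 50, 2.92e-04 at step 100), and the loss column checks out as well: at step 100, 8.17 + 0.1 × (-0.36) = 8.134, logged as 8.13.

```python
import math
import torch
import torch.nn.functional as F

def total_loss(logits, input_ids, gate_loss, gate_weight=0.1):
    """Next-token cross-entropy plus a weighted exit-gate auxiliary term."""
    # predict token t+1 from position t: drop the last logit and the first target
    ce = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        input_ids[:, 1:].reshape(-1),
    )
    return ce + gate_weight * gate_loss, ce

def cosine_warmup_lr(step, peak_lr=3e-4, warmup_steps=50, total_steps=500, min_lr_ratio=0.1):
    """Linear warmup to peak_lr, then cosine decay to min_lr_ratio * peak_lr (assumed floor)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    min_lr = min_lr_ratio * peak_lr
    return min_lr + (peak_lr - min_lr) * cosine
```

Under the usual definition, val_perplexity is exp of the validation cross-entropy, so 3874.81 corresponds to a validation CE of about 8.26.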

📊 Ablation Results

Model                          | Parameters | Perplexity ↓ | Avg compute/token ↓
Vanilla Transformer (baseline) | ~120M      | highest      | 100%
+ Mesh Attention only          | ~120M      | lower        | ~60%
Full TMT (all 3 innovations)   | ~120M      | lowest       | ~50%

Run notebooks 01_baseline.ipynb → 04_compare.ipynb in order to reproduce.


🔧 Configuration Reference

from tmt.model.config import TMTConfig

cfg = TMTConfig(
    vocab_size     = 32000,   # vocabulary size
    d_model        = 512,     # hidden dimension
    n_heads        = 8,       # attention heads
    n_layers       = 12,      # transformer layers
    max_seq_len    = 1024,    # maximum sequence length

    graph_k        = 8,       # each token connects to k nearest (cosine sim)
    decay_rate     = 0.1,     # base for learned temporal decay scalars
    exit_threshold = 0.85,    # confidence above which a token exits early

    dual_stream    = True,    # syntax + semantic parallel FFN streams
    ffn_stream_dim = 256,     # width of each stream (total = 512)
    memory_anchors = 16,      # number of persistent KV memory anchor nodes
    dropout        = 0.1,
)
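The diagram's "saves ~25M parameters" for weight tying is plain arithmetic: an untied output projection would need its own vocab_size × d_model matrix. The sketch below ties a standard nn.Linear head to nn.Embedding (the common PyTorch idiom, not necessarily how model.py does it) and counts parameters. The ~25M figure matches the GPT-2 vocabulary (50,258) at d_model = 512; at the default vocab_size = 32000 shown above, the saving is about 16.4M.

```python
import torch.nn as nn

def tied_embedding_and_head(vocab_size: int, d_model: int):
    embedding = nn.Embedding(vocab_size, d_model)
    lm_head = nn.Linear(d_model, vocab_size, bias=False)
    lm_head.weight = embedding.weight          # share one (vocab_size, d_model) Parameter
    return embedding, lm_head

def tying_savings(vocab_size: int, d_model: int) -> int:
    embedding, lm_head = tied_embedding_and_head(vocab_size, d_model)
    # .parameters() de-duplicates shared Parameters, so the tied matrix is counted once
    tied = sum(p.numel() for p in nn.ModuleList([embedding, lm_head]).parameters())
    untied = 2 * vocab_size * d_model
    return untied - tied

print(tying_savings(50258, 512))   # 25732096
print(tying_savings(32000, 512))   # 16384000
```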

🖥️ Hardware Requirements

Config                   | Params | Memory      | Time (10k steps)
Small: d=256, 4 layers   | ~16M   | ~2 GB RAM   | ~10 min CPU
Medium: d=512, 6 layers  | ~60M   | ~6 GB VRAM  | ~45 min GPU
Full: d=512, 12 layers   | ~120M  | ~12 GB VRAM | ~2-3 hrs GPU

Apple Silicon (M1/M2/M3/M4): MPS acceleration is detected automatically; no extra config needed.


🔬 Inspecting the Model Output

Every forward pass returns a rich, structured output, not just logits:

output = model(input_ids)

output.logits         # (B, S, vocab_size)   ← use for loss / generation
output.exit_masks     # list of (B, S) bool  ← which tokens exited at each layer
output.confidences    # list of (B, S) float ← gate confidence per token per layer
output.graph_edges    # (edge_index, edge_weight) ← live dynamic graph
output.memory_state   # (M, D)               ← current memory anchor state (M = memory_anchors, 16 by default)
output.decay_scalars  # (B, S, D)            ← temporal weights applied
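The exit_masks list is enough to compute the per-layer exit fraction and the average compute per token reported in the ablation table. The helper below is hypothetical and assumes the masks are cumulative (a token that has exited stays True in every later layer); under that assumption, a token that first exits after layer e runs e layers.

```python
import torch

def exit_stats(exit_masks):
    """exit_masks: list of L bool tensors (B, S); True = exited at or before that layer."""
    stacked = torch.stack(exit_masks).float()             # (L, B, S)
    n_layers = stacked.size(0)
    per_layer_exit_frac = stacked.mean(dim=(1, 2))        # fraction exited after each layer
    # first exit after layer e -> e layers run; tokens that never exit run all L layers
    layers_used = n_layers - stacked[:-1].sum(dim=0)      # (B, S)
    avg_compute = (layers_used / n_layers).mean()         # 1.0 = every token used every layer
    return per_layer_exit_frac, layers_used, avg_compute
```

Usage, assuming the output fields above: `frac, used, avg = exit_stats(output.exit_masks)`, where `avg` is the "Avg compute/token" figure as a fraction of full depth.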

📂 Checkpoint Loading

import torch
from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel

cfg = TMTConfig(...)   # must match the config used during training
model = TMTModel(cfg)

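# PyTorch >= 2.6 defaults to weights_only=True; if loading fails because the checkpoint
# stores non-tensor objects, pass weights_only=False (only for files you trust)
ckpt = torch.load('checkpoints/ckpt_step10000.pt', map_location='cpu')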
model.load_state_dict(ckpt['model_state'])
model.eval()
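For a checkpoint to load with the snippet above, it needs a 'model_state' entry holding the state dict. A minimal save/load pair under that assumption follows; the helper names and the extra 'step' key are illustrative, not taken from trainer.py. Storing only tensors and plain Python numbers keeps the file loadable under PyTorch's weights_only=True default.

```python
import torch
import torch.nn as nn

def save_checkpoint(model: nn.Module, path: str, step: int) -> None:
    # 'model_state' is the key the loading snippet reads; 'step' is an illustrative extra
    torch.save({"model_state": model.state_dict(), "step": step}, path)

def load_checkpoint(model: nn.Module, path: str) -> int:
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    return ckpt["step"]
```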

📚 Literature Context

Paper | Core idea | TMT relation
Vaswani et al. 2017, Attention Is All You Need | Transformer baseline | TMT base architecture
Su et al. 2021, RoFormer (RoPE) | Rotary positional encoding | TMT extends RoPE with learned decay
Elbayad et al. 2020, Depth-Adaptive Transformer | Per-sequence and per-token early exit for sequence generation (machine translation) | TMT combines per-token exit with mesh attention and temporal decay
Shi et al. 2021, Masked Graph Attention | Graph attention with learned masks | TMT uses dynamic topology, not fixed masks
Graves 2016, Adaptive Computation Time | Halt tokens early in RNNs | TMT is the transformer-native equivalent
Weston et al. 2015, Memory Networks | External memory for QA | TMT uses EMA-updated persistent anchors

To the author's knowledge, no prior paper combines all of the above into a single unified architecture. That fusion is the research contribution.


📖 Citation

@misc{tmt2026,
  title        = {TemporalMesh Transformer: Dynamic Graph Attention with
                  Temporal Decay and Adaptive Depth Routing},
  author       = {Vignesh},
  year         = {2026},
  url          = {https://github.com/vignesh2027/TemporalMesh-Transformer},
  note         = {Novel architecture combining mesh attention, temporal decay
                  encoding, and per-token adaptive depth routing.}
}

📄 License

MIT: free to use, modify, and build on. If you publish results using this architecture, a citation is appreciated.


About

TemporalMesh-Transformer (TMT) fuses dynamic graph topology, token-level adaptive compute, and temporal semantic decay into a single unified model; to the author's knowledge, no prior work combines all three.
