A PyTorch implementation of Tversky Neural Networks (TNNs), a novel architecture that replaces traditional linear classification layers with Tversky similarity-based projection layers. This implementation faithfully reproduces the key concepts from the original paper and provides optimized, production-ready models for both research and practical applications.
Tversky Neural Networks introduce a fundamentally different approach to neural network classification by leveraging Tversky similarity functions instead of traditional dot-product operations. The key innovation is the Tversky Projection Layer, which:
- Replaces linear layers with learnable prototype-based similarity computations
- Uses asymmetric similarity through the Tversky index (α, β parameters)
- Provides interpretable representations through learned prototypes
- Maintains competitive accuracy while offering explainable decision boundaries
The Tversky Projection Layer computes similarities between input features and learned prototypes using:
S_Ω,α,β,θ(x, π_k) = |x ∩ π_k|_Ω / (|x ∩ π_k|_Ω + α|x \ π_k|_Ω + β|π_k \ x|_Ω + θ)
Where:
- x is the input feature vector
- π_k are learned prototypes
- Ω is a learned feature bank
- α, β control asymmetric similarity weighting
- θ provides numerical stability
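For intuition, here is a minimal sketch of that computation for a batch of inputs against a bank of prototypes. It is a simplified illustration only: it folds the feature bank Ω into the raw activations, uses an elementwise minimum for the intersection and ReLU differences, and does not reproduce the repo's `product`/`subtractmatch` reduction options listed under the configuration section below.

```python
import torch

def tversky_similarity(x, prototypes, alpha=0.5, beta=0.5, theta=1e-7):
    """Simplified Tversky similarity (illustration, not the repo's layer).

    x:          (batch, dim) nonnegative feature activations
    prototypes: (num_prototypes, dim) prototype vectors
    returns:    (batch, num_prototypes) scores in [0, 1]
    """
    x = x.unsqueeze(1)                      # (batch, 1, dim)
    p = prototypes.unsqueeze(0)             # (1, K, dim)

    common = torch.minimum(x, p).sum(-1)    # |x ∩ π_k|: shared feature mass
    x_only = torch.relu(x - p).sum(-1)      # |x \ π_k|: features of x missing from π_k
    p_only = torch.relu(p - x).sum(-1)      # |π_k \ x|: features of π_k missing from x

    return common / (common + alpha * x_only + beta * p_only + theta)

# Score 4 random inputs against 8 prototypes
scores = tversky_similarity(torch.rand(4, 16), torch.rand(8, 16))
print(scores.shape)  # torch.Size([4, 8])
```

Because α and β weight the two difference terms independently, the similarity is asymmetric: swapping x and π_k changes the score unless α = β.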
```bash
pip install tversky-nn
```

Or install from source:

```bash
git clone https://github.com/akshathmangudi/tnn.git
cd tnn
pip install -e .
```
- Python 3.10+
- PyTorch 2.0+
- torchvision 0.15+
- numpy
- scikit-learn
- tqdm
- pillow
```python
import torch
from tnn.models import get_resnet_model
from tnn.datasets import get_mnist_loaders

# Create a TverskyResNet model
model = get_resnet_model(
    architecture='resnet18',
    num_classes=10,
    use_tversky=True,
    num_prototypes=8,
    alpha=0.5,
    beta=0.5
)

# Load MNIST dataset
train_loader, val_loader, test_loader = get_mnist_loaders(
    data_dir='./data',
    batch_size=64
)

# Use the model
x = torch.randn(32, 3, 224, 224)  # Batch of images
outputs = model(x)                # Shape: (32, 10)
```
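To sanity-check the pieces above end to end, the snippet below runs the untrained model on one validation batch. It reuses `model` and `val_loader` from the quick start and assumes the loaders already yield tensors in the shape the model expects; everything else is standard PyTorch.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

with torch.no_grad():
    images, labels = next(iter(val_loader))   # one batch from the MNIST loaders
    logits = model(images.to(device))         # (batch, 10) class scores
    preds = logits.argmax(dim=1).cpu()        # hard predictions
    print("batch accuracy:", (preds == labels).float().mean().item())
```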
Demonstrate TNN capabilities on the classic XOR problem:
```python
from tnn.models.xor import TverskyXORNet
import torch

# Create XOR model
model = TverskyXORNet(
    hidden_dim=8,
    num_prototypes=4,
    alpha=0.5,
    beta=0.5
)

# XOR data
x = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = torch.tensor([0, 1, 1, 0])

# Forward pass
predictions = model(x)
```
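The returned `predictions` are per-class scores; assuming `TverskyXORNet` emits one score per class, hard labels fall out of an `argmax`:

```python
# Read off hard XOR predictions (assumes one output score per class)
pred_labels = predictions.argmax(dim=1)
accuracy = (pred_labels == y).float().mean()
print(pred_labels, accuracy)  # ideally tensor([0, 1, 1, 0]) once trained
```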
Train a TverskyResNet on MNIST:
```bash
# Train with Tversky layer (recommended)
python train_resnet.py --dataset mnist --architecture resnet18 --epochs 50 --lr 0.01

# Train baseline (linear layer)
python train_resnet.py --dataset mnist --architecture resnet18 --use-linear --epochs 50 --lr 0.01

# Quick test (2 epochs)
python train_resnet.py --dataset mnist --epochs 2 --lr 0.01
```
```bash
python train_xor.py
```
TO BE UPDATED
Our implementation achieves strong performance across different tasks:
| Configuration | Architecture | Classifier | Val Accuracy | Train Accuracy | Training Time |
|---|---|---|---|---|---|
| Optimized TNN | ResNet18 | Tversky (8 prototypes) | 98.88% | 98.81% | ~32 min (2 epochs) |
| Baseline | ResNet18 | Linear | - | - | - |
Key Training Metrics:
- Epoch 1: Training Acc: 89.81%
- Epoch 2: Training Acc: 98.81%, Validation Acc: 98.88%
- Model Size: 11.18M parameters (4,608 in Tversky classifier)
- Convergence: Fast and stable with proper hyperparameters
| Metric | Value |
|---|---|
| Final Test Accuracy | 93.00% |
| Class 0 Accuracy | 95.40% |
| Class 1 Accuracy | 91.15% |
| Training Epochs | 500 |
| Convergence | Smooth, interpretable decision boundary |
Visual Results:
- Clear non-linear decision boundary
- Interpretable learned prototypes
- Smooth training curves
- Fast Convergence: With proper hyperparameters (lr=0.01), TNNs converge quickly
- High Accuracy: Achieves 98.88% validation accuracy on MNIST
- Interpretability: Learned prototypes provide insight into model decisions
- Flexibility: Support for multiple ResNet architectures
- Stability: Robust training with mixed precision and proper initialization
- Modular Design: Easy to swap Tversky layers for linear layers (see the sketch after this list)
- Multiple Architectures: ResNet18/50/101/152 support
- Pretrained Weights: ImageNet initialization available
- Mixed Precision: Automatic mixed precision training
- Comprehensive Logging: Detailed metrics and checkpointing
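Both classifier heads come from the same factory call used in the quick start. The sketch below contrasts the two configurations with only the arguments already documented above; the `use_tversky=False` baseline is assumed to mirror the `--use-linear` training flag.

```python
from tnn.models import get_resnet_model

# Tversky prototype classifier head
tversky_model = get_resnet_model(
    architecture='resnet18', num_classes=10,
    use_tversky=True, num_prototypes=8, alpha=0.5, beta=0.5
)

# Baseline: same backbone with a plain linear classifier head
linear_model = get_resnet_model(
    architecture='resnet18', num_classes=10,
    use_tversky=False
)
```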
```python
# Tversky similarity parameters
alpha: float = 0.5        # Controls importance of false positives
beta: float = 0.5         # Controls importance of false negatives
num_prototypes: int = 8   # Number of learned prototypes
theta: float = 1e-7       # Numerical stability constant

# Architecture options
intersection_reduction = "product"      # or "mean"
difference_reduction = "subtractmatch"  # or "ignorematch"
feature_bank_init = "xavier"            # Feature bank initialization
prototype_init = "xavier"               # Prototype initialization
```
- Double Classification Layer: Fixed architecture that was causing convergence issues
- Softmax Placement: Corrected `apply_softmax=False` in the Tversky layer
- Learning Rate: Optimized default learning rate from 0.001 → 0.01
- Initialization: Improved prototype and feature bank initialization
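The softmax-placement fix matters because `nn.CrossEntropyLoss` applies `log_softmax` internally, so the classifier should emit raw similarity scores. The toy comparison below is plain PyTorch, not repo code; it shows the double-softmax pattern the fix removes.

```python
import torch
import torch.nn as nn

logits = torch.randn(32, 10, requires_grad=True)  # raw class scores
targets = torch.randint(0, 10, (32,))
criterion = nn.CrossEntropyLoss()

good = criterion(logits, targets)                # intended usage: raw logits
bad = criterion(logits.softmax(dim=1), targets)  # softmax applied twice, gradients shrink
print(good.item(), bad.item())
```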
- Extended Datasets: Support for CIFAR-10/100, ImageNet
- Additional Architectures: Vision Transformers, EfficientNets
- Advanced Features:
  - Prototype visualization tools
  - Attention mechanisms
  - Multi-modal support
- Optimization:
  - Further convergence improvements
  - Memory optimization for large models
- Research Extensions:
  - Adaptive α, β parameters
  - Hierarchical prototypes
  - Ensemble methods
Our implementation includes several key optimizations discovered during development:
- Architectural Fixes:
  - Removed double classification layer causing gradient flow issues
  - Set `apply_softmax=False` in the Tversky layer for better optimization
  - Improved linear layer initialization with Xavier uniform
- Training Optimizations:
  - Increased learning rate to 0.01 for faster convergence
  - Mixed precision training for memory efficiency
  - Cosine annealing scheduler for better convergence
- Numerical Stability:
  - Proper theta parameter (1e-7) for numerical stability
  - Xavier initialization for all learnable parameters
  - Gradient clipping and proper loss scaling
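Put together, one epoch of training with these optimizations could look like the hedged sketch below. It reuses `model` and `train_loader` from the quick start and relies only on standard PyTorch AMP, scheduler, and gradient-clipping APIs; it is an illustration, not a copy of `train_resnet.py`.

```python
import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
scaler = GradScaler()

model.train().cuda()
for images, labels in train_loader:
    images, labels = images.cuda(), labels.cuda()
    optimizer.zero_grad()
    with autocast():                           # mixed-precision forward pass
        loss = criterion(model(images), labels)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                 # unscale before clipping
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
scheduler.step()                               # step the cosine schedule once per epoch
```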
We welcome contributions! Areas where help is needed:
- Additional dataset implementations
- New architecture support
- Performance optimizations
- Documentation improvements
- Bug fixes and testing
- Include GPT-2 implementation and benchmarks.
- Run ResNet18 benchmarks on the NABirds dataset (deferred until later).
- Add benchmarks for different datasets with different weight distributions.
- Unify training configuration instead of keeping several training files for different models.
- Add type checking and other software engineering practices to maintain robustness.
If you use this implementation in your research, please cite:
```bibtex
@software{tnn_pytorch,
  author = {Akshath Mangudi},
  title  = {TNN: A PyTorch Implementation of Tversky Neural Networks},
  year   = {2025},
  url    = {https://github.com/akshathmangudi/tnn}
}
```
For the original Tversky Neural Networks paper, please cite:
```bibtex
@article{tversky_neural_networks,
  title   = {Tversky Neural Networks: Psychologically Plausible Deep Learning with Differentiable Tversky Similarity},
  author  = {Doumbouya, Moussa Koulako Bala and Jurafsky, Dan and Manning, Christopher D.},
  journal = {NeurIPS},
  year    = {2025},
  url     = {https://arxiv.org/abs/2506.11035}
}
```
This project is licensed under the MIT License - see the LICENSE file for details.
- Original Tversky Neural Networks paper authors
- PyTorch team for the excellent deep learning framework
- torchvision for pretrained models and datasets
Built with ❤️ and PyTorch | Ready for production use | Optimized for research