torch-diode is a library for programmatically altering the performance-relevant decisions made by `torch.compile`. It makes it easy to gather data on the outcomes of those decisions and then train machine learning models on that data. It initially focuses on matmul kernel selection, but it will be expanded to other decisions in the future.
It is intended for:
- Developers looking to tailor how their model is compiled to their specific situation.
- Hardware vendors looking to optimize `torch.compile` heuristics for their hardware.
- OSS contributors looking to add support for less popular hardware.
It provides:
- Pre-trained models: benefit from community efforts to gather data and train models.
- Data collection: gather data from `torch`'s external interfaces.
- Stable type definitions: for storing the data gathered from those external interfaces.
- Model training code: train ML models on the gathered data and contribute them back to the `torch` community.
- Matrix multiplication kernel prediction: predict the runtime of matrix multiplication kernels. This model's predictions are enabled in `fast-autotune`.
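As a rough illustration of how a runtime-prediction model can drive kernel selection, the sketch below scores each candidate configuration with a predictor and picks the one with the lowest predicted runtime. `KernelConfig`, `toy_predictor` and `select_kernel` are hypothetical names. The toy cost formula (padded work plus a per-tile overhead) only stands in for a trained model and is not torch-diode's API.

```python
import math
from dataclasses import dataclass
from typing import Callable, Iterable, Tuple

Shape = Tuple[int, int, int]  # (M, N, K) of the matmul


@dataclass(frozen=True)
class KernelConfig:
    """Hypothetical tiling parameters of one matmul kernel candidate."""
    block_m: int
    block_n: int
    block_k: int


def toy_predictor(shape: Shape, cfg: KernelConfig) -> float:
    """Hypothetical stand-in for a trained runtime model (arbitrary units).

    Charges for the work done after each dimension is padded up to a multiple
    of its block size, plus a fixed overhead per tile iteration.
    """
    m, n, k = shape
    tiles_m = math.ceil(m / cfg.block_m)
    tiles_n = math.ceil(n / cfg.block_n)
    tiles_k = math.ceil(k / cfg.block_k)
    padded_work = (tiles_m * cfg.block_m) * (tiles_n * cfg.block_n) * (tiles_k * cfg.block_k)
    overhead = 1000 * tiles_m * tiles_n * tiles_k
    return padded_work + overhead


Predictor = Callable[[Shape, KernelConfig], float]


def select_kernel(shape: Shape, candidates: Iterable[KernelConfig], predict: Predictor) -> KernelConfig:
    """Return the candidate with the lowest predicted runtime for this shape."""
    return min(candidates, key=lambda cfg: predict(shape, cfg))


if __name__ == "__main__":
    candidates = [KernelConfig(b, b, b) for b in (16, 32, 64)]
    print(select_kernel((128, 128, 128), candidates, toy_predictor))
    print(select_kernel((100, 100, 100), candidates, toy_predictor))
```

With these candidates, the toy model picks 64×64×64 blocks for a 128×128×128 matmul, where no work is wasted on padding. For a 100×100×100 matmul it picks 16×16×16 blocks, because the larger blocks pad each dimension up to 128. A learned predictor replaces the hand-written formula but leaves the selection step unchanged.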
To get the pre-trained performance models as well as the library, install torch-diode:
```
$ pip install torch-diode
```
Then import `torch_diode` in Python:
```python
import torch_diode
```
This import has several side effects, each of which depends on the success of the previous step:
- Attempt to import `torch`.
- Register dummy models with the relevant `torch.compile` interfaces.
- For each registration that succeeds, load the actual model and register it.
- Enable the configs in `torch.compile` that engage the models.
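The rule that each step depends on the previous one can be modeled as a short chain of guarded steps. The sketch below is hypothetical. `run_side_effects` and its callback parameters are illustrative names, not torch-diode's actual internals. Registration, loading and config changes are injected as plain callables, which lets the control flow be shown without PyTorch.

```python
import importlib
from typing import Callable, Dict, List


def run_side_effects(
    register_dummy: Dict[str, Callable[[], bool]],
    load_and_register: Callable[[str], None],
    enable_configs: Callable[[List[str]], None],
    torch_module: str = "torch",
) -> List[str]:
    """Hypothetical model of the import-time side-effect chain.

    Returns the names of the interfaces whose real model was loaded.
    """
    # Step 1: if torch cannot be imported, stop before touching anything else.
    try:
        importlib.import_module(torch_module)
    except ImportError:
        return []

    # Step 2: try to register a dummy model with each interface; keep successes.
    registered = [name for name, register in register_dummy.items() if register()]

    # Step 3: only interfaces that accepted a dummy get the real model.
    for name in registered:
        load_and_register(name)

    # Step 4: enable the configs that engage the models that were registered.
    if registered:
        enable_configs(registered)
    return registered
```

Because a failure at any step only narrows what the later steps see, a missing `torch` or a rejected registration leaves the rest of the environment untouched.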
diode requires a nightly build of PyTorch, or PyTorch 2.9 or later.
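One way to check this requirement at runtime is to compare the major and minor components of `torch.__version__` against (2, 9). This sketch assumes the nightly build you use reports a version of 2.9 or higher (for example `2.10.0.dev20250901+cu126`) and does not special-case older nightlies. `major_minor` and `is_supported` are hypothetical helpers, not part of torch-diode.

```python
import re
from typing import Tuple

MIN_TORCH: Tuple[int, int] = (2, 9)


def major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '2.9.0' or '2.10.0.dev20250901+cu126'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def is_supported(version: str) -> bool:
    """True if the version is 2.9 or later (integer tuple comparison, so 2.10 > 2.9)."""
    return major_minor(version) >= MIN_TORCH
```

In practice you would call `is_supported(torch.__version__)`. Comparing integer tuples instead of strings matters here: as strings, `"2.10"` sorts before `"2.9"`.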
Developers who don't want the import-time side effects can install torch-diode-lib instead, which provides only the library:
```
$ pip install torch-diode-lib
```
The import remains the same:
```python
import torch_diode
```
To install from source:
```
git clone https://github.com/exclamaforte/diode.git
cd diode
pip install .
```
Models are organized in a structured directory format:
```
trained_models/
├── <model_purpose>/
│   ├── <model_name>.pt
│   └── ...
└── <other_model_file>.pt
```
Example:
```
trained_models/
├── matmul_kernel_runtime_prediction/
│   ├── v1_model.pt
│   └── v2_model.pt
└── matmul_model_exhaustive.pt
```
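Code could discover models laid out this way by walking the directory for `.pt` files and grouping them by their purpose subdirectory, with top-level files grouped under an empty key. The sketch below shows this. `index_models` is a hypothetical helper written for illustration, not a function provided by torch-diode.

```python
from collections import defaultdict
from pathlib import Path
from typing import Dict, List


def index_models(root: Path) -> Dict[str, List[str]]:
    """Group the .pt files under root by their <model_purpose> directory.

    Files placed directly in root (no purpose directory) are grouped under "".
    """
    index: Dict[str, List[str]] = defaultdict(list)
    for path in sorted(root.rglob("*.pt")):
        rel = path.relative_to(root)
        purpose = rel.parts[0] if len(rel.parts) > 1 else ""
        index[purpose].append(rel.name)
    return dict(index)
```

For the example tree, `index_models(Path("trained_models"))` would return `{"matmul_kernel_runtime_prediction": ["v1_model.pt", "v2_model.pt"], "": ["matmul_model_exhaustive.pt"]}`.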
The main entry point is in `workflows`.
Two package variants are available:
- `torch-diode`: full package with auto-registration to PyTorch Inductor.
- `torch-diode-lib`: library-only version without auto-registration.
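Both distributions provide the same `torch_diode` import, so the import alone doesn't reveal which variant is installed. A small sketch using the standard-library `importlib.metadata` can check the distribution metadata instead. `installed_variants` is a hypothetical helper, and the distribution names are the ones used in the `pip install` commands.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Tuple

# Distribution names as used in the pip install commands.
VARIANTS: Tuple[str, ...] = ("torch-diode", "torch-diode-lib")


def installed_variants() -> Dict[str, str]:
    """Map each installed torch-diode distribution to its version string."""
    found: Dict[str, str] = {}
    for name in VARIANTS:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            continue  # this variant is not installed
    return found
```

An empty result means neither variant is installed in the current environment.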