This implementation is based on DiGress's excellent work.
The code has been tested with:
- PyTorch 2.2
- CUDA 11.8
- PyTorch Geometric 2.3.1
```shell
conda create -n THGD python=3.10 -y
pip install torch==2.2 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
pip install torch-scatter==2.1.2
pip install -e .
```
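After installation, a quick sanity check (assuming the `THGD` environment is active) confirms that PyTorch and PyTorch Geometric import cleanly and that CUDA is visible:

```shell
# Verify installed versions and CUDA availability
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import torch_geometric; print(torch_geometric.__version__)"
```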
All executions are launched via `python3 main.py`. Refer to the Hydra documentation for parameter customization.
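Hydra lets you override any config value from the command line using dotted `key=value` syntax. A hypothetical example (the parameter names below are illustrative and may not match this repo's actual config keys):

```shell
# Override nested config values at launch time (illustrative keys)
python main.py +guacamol_exp=coarse train.n_epochs=100 train.batch_size=64
```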
The model training consists of two sequential phases:
- Coarse Model Training (learns high-level molecular topologies)
- Refinement Model Training (recovers atomic-level details)
Pre-configured training files are provided in `configs/`. Example for GuacaMol:
- Train the coarse model:

```shell
cd src
python main.py +guacamol_exp=coarse
```

- Train the refinement model:

```shell
python main.py +guacamol_exp=refine
```
Notes:
- Dataset preprocessing (feature extraction for the coarse and expanded graphs) runs on CPU and typically takes ~30 minutes, depending on hardware. We recommend using our preprocessed dataset here:
- This preprocessing must complete before training begins.
Sampling requires:
- Trained coarse model checkpoint
- Trained refinement model checkpoint
- Precomputed optimal prior distribution tensor (obtained during preprocessing)
Use pre-configured sampling profiles:
```shell
python sample.py sample=guacamol
```
❗ Configuration Alignment:
Ensure the checkpoint's training config matches your sampling config (e.g., `coarse_cfg` in the YAML must correspond to the checkpoint's original training config).
For scaffold-conditioned sampling, modify the `scaffold` field in the sampling config.
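Assuming the `scaffold` field lives under the `sample` config group, it can also be overridden directly on the command line; the key path and the SMILES string below are purely illustrative:

```shell
# Hypothetical scaffold-conditioned run (benzene scaffold as an example)
python sample.py sample=guacamol sample.scaffold='c1ccccc1'
```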
Checkpoints for all three datasets are available (place the downloaded files in the `checkpoints/` folder):
https://drive.google.com/file/d/1nOrOq6Jf7adqG0vbitnbFPJ9_-AzdHcX/view?usp=sharing
| Dataset  | Coarse Model (NLL) | Refinement Model (NLL) |
|----------|--------------------|------------------------|
| ZINC250k | 44.9               | 82.9                   |
| MOSES    | 37.5               | 72.0                   |
| GuacaMol | 59.8               | 94.6                   |
We provide a few example outputs in `sample-results/`.