#TODO:
- coverage is not calculated correctly
This repository implements a modular framework for training photonic quantum neural networks with uncertainty quantification capabilities. The code is based on the photonic quantum memristor paper and provides both discrete-phase and continuous-swipe implementations for circuit simulation.
The codebase is organized into the following modules:
- `src/autograd.py`: Implements the parameter-shift rule (PSR) for photonic circuits
- `src/circuits.py`: Circuit construction for encoding and memristor components (array-based parameter design)
- `src/data.py`: Data generation and processing utilities (multiple synthetic functions)
- `src/loss.py`: Custom loss functions and PyTorch model implementations
- `src/simulation.py`: Circuit simulation with discrete and continuous modes
- `src/training.py`: Training algorithms with PyTorch optimization
- `src/utils.py`: Configuration and utility functions
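As background for the parameter-shift rule mentioned for `src/autograd.py`, the generic two-term rule can be sketched as follows. This is an illustration only, not the repository's implementation: the helper `parameter_shift_grad` and the toy `expectation` function are hypothetical names, and the sketch assumes the measured quantity depends on each phase as `a + b*cos(phi) + c*sin(phi)`, which is the condition under which the two-term rule is exact.

```python
import math
from typing import Callable, List, Sequence


def parameter_shift_grad(
    f: Callable[[Sequence[float]], float],
    params: Sequence[float],
    shift: float = math.pi / 2,
) -> List[float]:
    """Two-term parameter-shift gradient of f with respect to each parameter.

    Exact when f is of the form a + b*cos(p) + c*sin(p) in every parameter p.
    """
    grads = []
    for i in range(len(params)):
        plus = list(params)
        minus = list(params)
        plus[i] += shift
        minus[i] -= shift
        # Two evaluations of f per parameter, no finite-difference step size.
        grads.append((f(plus) - f(minus)) / (2.0 * math.sin(shift)))
    return grads


def expectation(params: Sequence[float]) -> float:
    """Toy stand-in for a circuit output (hypothetical, not a real circuit)."""
    phi1, phi3 = params
    return 0.5 * (1.0 + math.cos(phi1) * math.sin(phi3))


grad = parameter_shift_grad(expectation, [0.3, 1.1])
print(grad)
```

Unlike a finite-difference estimate, the shifted evaluations are taken at a large, fixed offset, which makes the estimate less sensitive to sampling noise in the circuit outputs.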
# Clone the repository
git clone https://github.com/username/uq-qnn.git
cd uq-qnn
# Install dependencies
pip install -r requirements.txt

Run the main script to train the quantum neural network:

python main.py --n-samples 1000

- `--continuous`: Use continuous-swipe training mode
- `--n-samples`: Number of samples for circuit simulation
- `--epochs`: Number of training epochs
- `--lr`: Learning rate
- `--measured-data`: Path to measured data pickle file
- `--datafunction`: Synthetic data function to use (see below)
- `--n-phases`: Number of phase parameters in the memristor circuit (default: 2)
- `--circuit-type`: Circuit architecture to use ('memristor' or 'clements')
- `--n-modes`: Number of modes for Clements architecture (default: 3)
- `--encoding-mode`: Mode to apply encoding to (default: 0)
- `--target-mode`: Target output mode(s) as comma-separated list (e.g., '2,3')

See `python main.py --help` for all available options.
The framework includes multiple synthetic data functions for regression tasks:
- `quartic_data`: Standard x⁴ function
- `sinusoid_data`: Sinusoidal function (sin(2πx) * 0.5 + 0.5)
- `multi_modal_data`: Sum of Gaussian peaks
- `step_function_data`: Smooth step function using tanh
- `oscillating_poly_data`: Oscillating polynomial (x³ - 0.5x² + 0.1sin(15x))
- `damped_cosine_data`: Damped cosine wave
Run the examples/function_comparison.py script to compare model performance across all functions.
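The three targets whose formulas are stated explicitly in the list above can be written down directly. The sketch below is a standalone NumPy illustration. The function names, the noise-free outputs, and the input range [0, 1] are assumptions, and the generators in `src/data.py` may differ (for example in noise or in train/test gaps).

```python
import numpy as np


def quartic(x: np.ndarray) -> np.ndarray:
    """Standard quartic target: x^4."""
    return x ** 4


def sinusoid(x: np.ndarray) -> np.ndarray:
    """Sinusoid rescaled to [0, 1]: sin(2*pi*x) * 0.5 + 0.5."""
    return np.sin(2.0 * np.pi * x) * 0.5 + 0.5


def oscillating_poly(x: np.ndarray) -> np.ndarray:
    """Oscillating polynomial: x^3 - 0.5*x^2 + 0.1*sin(15*x)."""
    return x ** 3 - 0.5 * x ** 2 + 0.1 * np.sin(15.0 * x)


# Evaluate each target on a small grid (the input range is an assumption).
x = np.linspace(0.0, 1.0, 5)
for name, fn in [("quartic", quartic),
                 ("sinusoid", sinusoid),
                 ("oscillating_poly", oscillating_poly)]:
    print(name, np.round(fn(x), 4))
```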
The repository includes several example scripts:
- `examples/simple_regression.py` - Basic regression with uncertainty quantification
- `examples/function_comparison.py` - Compare performance across different synthetic functions
- `examples/circuit_comparison.py` - Compare memristor vs. Clements circuit architectures
For the Clements architecture example, run:
python examples/circuit_comparison.py

This script demonstrates how to use both architectures on the same dataset and compares their performance.
The framework supports two distinct circuit architectures:
The photonic memristor circuit implementation uses an array-based approach:
- `encoding_circuit`: Builds a 2-mode encoding circuit with a phase shifter
- `memristor_circuit`: Takes an array of phases instead of individual parameters
- `build_circuit`: Combines encoding and memristor circuits with array-based parameters
Parameter structure for memristor circuit:
params = [phi1, phi3, w]
where phi1 and phi3 are phase parameters and w is the memory weight parameter.
The Clements architecture provides a more flexible, scalable approach:
- Configurable number of modes (use the `--n-modes` option)
- Mesh of Mach-Zehnder Interferometers (MZIs) in a rectangular grid pattern
- Each MZI has two phase shifters (internal and external)
- Supports arbitrary-sized photonic neural networks
Parameter structure for Clements circuit:
params = [phi1_int, phi1_ext, phi2_int, phi2_ext, ..., phiN_int, phiN_ext, w]
where each MZI has an internal phase (phi_int) and external phase (phi_ext), and w is the memory weight parameter.
The number of phase parameters is calculated automatically from the number of modes as `n_modes * (n_modes - 1)`: a Clements mesh on N modes contains N(N - 1)/2 MZIs, and each MZI contributes two phases.
Note: For the Clements architecture, you must ensure that `n_modes` ≥ 2 and `encoding_mode` < `n_modes`. The target mode(s) must also be valid for the given number of modes.
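The parameter layout and the constraints above can be checked with a few lines of Python. This is a minimal sketch under stated assumptions: all helper names are hypothetical and not taken from the repository, and "valid" target modes are assumed to mean zero-based indices in the range `0 .. n_modes - 1`, which the text above does not specify.

```python
from typing import List, Sequence, Tuple


def n_clements_phases(n_modes: int) -> int:
    """Phase count: n_modes*(n_modes-1)/2 MZIs with two phases each."""
    if n_modes < 2:
        raise ValueError("n_modes must be at least 2")
    return n_modes * (n_modes - 1)


def parse_target_modes(spec: str) -> List[int]:
    """Parse a comma-separated list such as '2,3' into [2, 3]."""
    return [int(tok) for tok in spec.split(",") if tok.strip()]


def validate_clements_config(n_modes: int, encoding_mode: int,
                             target_modes: Sequence[int]) -> None:
    """Raise ValueError if the configuration violates the constraints."""
    if n_modes < 2:
        raise ValueError("n_modes must be at least 2")
    if not 0 <= encoding_mode < n_modes:
        raise ValueError("encoding_mode must be smaller than n_modes")
    # Assumption: target modes are zero-based indices below n_modes.
    for t in target_modes:
        if not 0 <= t < n_modes:
            raise ValueError(f"target mode {t} is invalid for {n_modes} modes")


def split_clements_params(
    params: Sequence[float], n_modes: int
) -> Tuple[List[Tuple[float, float]], float]:
    """Split [phi1_int, phi1_ext, ..., phiN_int, phiN_ext, w] into MZI pairs and w."""
    n_phases = n_clements_phases(n_modes)
    if len(params) != n_phases + 1:
        raise ValueError(f"expected {n_phases + 1} parameters, got {len(params)}")
    mzi_phases = [(params[i], params[i + 1]) for i in range(0, n_phases, 2)]
    return mzi_phases, params[-1]


n_modes = 4
targets = parse_target_modes("2,3")
validate_clements_config(n_modes, encoding_mode=0, target_modes=targets)
params = [0.1] * n_clements_phases(n_modes) + [0.5]  # 12 phases plus w
mzi_phases, w = split_clements_params(params, n_modes)
print(len(mzi_phases), w)
```

Validating the configuration before building the circuit turns an out-of-range mode index into an immediate, readable error instead of a failure inside the simulation.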
- Initial study of QNNs with UQ in simulation: use the 1-d regression function from the photonic quantum memristor paper from Iris (ask Iris about simulations).
- Use existing QNN works for regression and classification to check whether "inherent quantum" UQ adds a benefit:
  - Classification task: over multiple forward passes, compute the mean of the logits and then take the softmax, as e.g. in Link
  - Compute the entropy over the softmax outputs as in Link
  - On a validation set (the splits should all be roughly i.i.d. and similar) compute the quantiles of the entropies (per prediction) as in Pandas Link. An example is given in Link - code box 12, titled "Selective Prediction"
:::info Selective prediction in a nutshell: UQ evaluation with selective prediction, as introduced in Paper. Here, samples with a predictive uncertainty (classification: entropies, regression: standard deviation) above a given threshold are omitted from prediction and referred to an expert and, optionally, another method. If the corresponding UQ method assigns higher uncertainties to inaccurate predictions, leaving out the predictions for these samples should increase the overall accuracy. This could resemble a deployment scenario in which predictions are monitored and, if the predictive uncertainty surpasses a given threshold, the sample is referred to an expert and/or additionally evaluated with another method. Instead of a fixed threshold on the predictive uncertainties across methods, one can choose a UQ-specific threshold based on the 0.8 quantile of the predictive uncertainties computed on a held-out validation dataset for each method. These method-specific thresholds are then applied to the separate test set for which we report results.
:::
- Regression task: over multiple forward passes, compute the mean of the predictions Link and the standard deviation, as e.g. in Link
- On a validation set (the splits should all be roughly i.i.d. and similar) compute the quantiles of the standard deviations (per prediction) as in Pandas Link. An example is given in Link - code box 6, titled "selective prediction thresholds based on validation set"
- Compute the accuracy at different quantile-based UQ thresholds on the test datasets; the plot should look like:
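The selective-prediction evaluation outlined in the list above can be sketched with NumPy as follows. This is a minimal sketch, not the code from the linked examples: all function names are hypothetical, the "forward passes" are simulated with synthetic noise on an x⁴ target, and RMSE on the retained samples stands in for accuracy in the regression case.

```python
import numpy as np


def predictive_entropy(logit_samples: np.ndarray):
    """Classification: mean of logits over passes, softmax, then entropy.

    logit_samples has shape (n_passes, n_samples, n_classes).
    """
    mean_logits = logit_samples.mean(axis=0)
    z = mean_logits - mean_logits.max(axis=-1, keepdims=True)  # stable softmax
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    entropy = -(probs * np.log(probs + 1e-12)).sum(axis=-1)
    return probs, entropy


def predictive_mean_std(pred_samples: np.ndarray):
    """Regression: mean and standard deviation over passes.

    pred_samples has shape (n_passes, n_samples).
    """
    return pred_samples.mean(axis=0), pred_samples.std(axis=0)


def uq_threshold(val_uncertainty: np.ndarray, q: float = 0.8) -> float:
    """Method-specific threshold: the q-quantile of validation uncertainties."""
    return float(np.quantile(val_uncertainty, q))


rng = np.random.default_rng(0)


def simulate_passes(n_samples: int, n_passes: int = 50):
    """Synthetic stand-in for repeated stochastic forward passes."""
    x = rng.uniform(0.0, 1.0, n_samples)
    noise_scale = 0.02 + 0.3 * x  # toy assumption: noisier for larger x
    passes = x ** 4 + noise_scale * rng.standard_normal((n_passes, n_samples))
    return x ** 4, passes


y_val, val_passes = simulate_passes(500)
y_test, test_passes = simulate_passes(500)
_, val_std = predictive_mean_std(val_passes)
test_mean, test_std = predictive_mean_std(test_passes)

# Thresholds come from the validation set and are applied to the test set.
for q in (0.5, 0.8, 1.0):
    threshold = uq_threshold(val_std, q)
    keep = test_std <= threshold
    rmse = np.sqrt(np.mean((test_mean[keep] - y_test[keep]) ** 2))
    print(f"q={q:.1f}  kept={keep.mean():.2f}  rmse={rmse:.4f}")
```

For classification, the same thresholding would be applied to the entropies returned by `predictive_entropy` instead of the standard deviations.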
