Neural Ordinary Differential Equations
Ricky T. Q. Chen*, Yulia Rubanova*, Jesse Bettencourt*, David Duvenaud
University of Toronto
Background: Ordinary Differential Equations (ODEs)
- Model the instantaneous change of a state (explicit form):
  dz(t)/dt = f(z(t), t)
- Solving an initial value problem (IVP) corresponds to integration; the solution is a trajectory:
  z(t1) = z(t0) + ∫ f(z(t), t) dt, integrated from t0 to t1
- The Euler method approximates the solution with small steps of size h:
  z(t + h) = z(t) + h f(z(t), t)
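As a concrete illustration (not from the original slides), here is a minimal fixed-step Euler integrator; f, z0, and the step count are placeholders:

```python
def euler_solve(f, z0, t0, t1, n_steps=100):
    """Integrate dz/dt = f(z, t) from t0 to t1 with fixed-step Euler."""
    h = (t1 - t0) / n_steps          # small step size
    z, t = z0, t0
    for _ in range(n_steps):
        z = z + h * f(z, t)          # Euler update: z(t+h) = z(t) + h f(z(t), t)
        t = t + h
    return z                         # approximate z(t1)
```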
Residual Networks interpreted as an ODE Solver
- Hidden units look like:
  h_{t+1} = h_t + f(h_t, θ_t)
- Final output is the composition of these steps:
  h_T = (F_T ∘ … ∘ F_1)(h_0), where F_t(h) = h + f(h, θ_t)
- This can be interpreted as an Euler discretization of an ODE.
- In the limit of smaller steps:
  dh(t)/dt = f(h(t), t, θ)
Haber & Ruthotto (2017). Weinan E (2017).
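To make the analogy concrete, a sketch (my own, not from the slides) of a residual network's forward pass written as repeated Euler updates with step size 1; blocks is an assumed list of residual functions f(·, θ_t):

```python
def resnet_forward(blocks, h):
    """Each residual block is one forward-Euler step of size 1."""
    for f in blocks:
        h = h + f(h)   # h_{t+1} = h_t + f(h_t, theta_t)
    return h
```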
Deep Learning as Discretized Differential Equations
Many deep learning networks can be interpreted as ODE solvers.

Network                          Fixed-step Numerical Scheme
ResNet, RevNet, ResNeXt, etc.    Forward Euler (Lu et al. 2017; Chang et al. 2018)
PolyNet                          Approximation to Backward Euler (Zhu et al. 2018)
FractalNet                       Runge-Kutta
DenseNet                         Runge-Kutta

But:
(1) What are the underlying dynamics?
(2) Adaptive-step-size solvers provide better error handling.
“Neural” Ordinary Differential Equations
Instead of y = F(x), solve y = z(T)
given the initial condition z(0) = x.
Parameterize the dynamics with a neural network:
  dz(t)/dt = f(z(t), t, θ)
Solve the dynamics using any black-box ODE solver:
- Adaptive step size.
- Error estimates.
- O(1)-memory learning.
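A minimal usage sketch with the authors' torchdiffeq library (github.com/rtqichen/torchdiffeq); the two-layer dynamics network and the data here are made up for illustration:

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint  # black-box solver with adjoint backprop

class Dynamics(nn.Module):
    """Parameterized dynamics f(z(t), t, theta)."""
    def __init__(self, dim=2, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))

    def forward(self, t, z):
        return self.net(z)

f = Dynamics()
x = torch.randn(32, 2)        # initial condition z(0) = x
t = torch.tensor([0.0, 1.0])  # integrate from t = 0 to t = T
y = odeint(f, x, t)[-1]       # y = z(T)
```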
Backprop without Knowledge of the ODE Solver
Ultimately we want to optimize some loss
  L(z(T)) = L(ODESolve(z(0), f, 0, T, θ))
Naive approach: know the solver, and backprop through the solver.
- Memory-intensive.
- The family of “implicit” solvers performs an inner optimization, which is hard to backprop through.
Our approach: adjoint sensitivity analysis (reverse-mode autodiff through the ODE solution).
- Pontryagin (1962) + automatic differentiation.
- O(1) memory in the backward pass.
Continuous-time Backpropagation
Define the adjoint state a(t) = ∂L/∂z(t).

Residual network:
- Forward: z_{t+1} = z_t + f(z_t, θ)
- Backward: a_t = a_{t+1} + a_{t+1} ∂f(z_t, θ)/∂z
- Params: ∂L/∂θ = Σ_t a_{t+1} ∂f(z_t, θ)/∂θ

Adjoint method:
- Forward: z(t1) = z(t0) + ∫ f(z(t), t, θ) dt
- Backward (adjoint diffeq): da(t)/dt = −a(t) ∂f(z(t), t, θ)/∂z
- Params: ∂L/∂θ = −∫ a(t) ∂f(z(t), t, θ)/∂θ dt, integrated backwards from t1 to t0
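A rough sketch of the adjoint equations above written as dynamics one could hand to a solver; this is my own illustration, not the library's actual implementation:

```python
import torch

def augmented_dynamics(f, t, state):
    """Backward-pass dynamics for the adjoint method (a sketch):
    state = (z, a), with dz/dt = f(t, z) and da/dt = -a df/dz."""
    z, a = state
    with torch.enable_grad():
        z = z.detach().requires_grad_(True)
        dz = f(t, z)
        # a^T (df/dz) as a single reverse-mode vector-Jacobian product
        vjp = torch.autograd.grad(dz, z, grad_outputs=a)[0]
    # parameter gradients accumulate analogously from -a df/dtheta
    return dz, -vjp
```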
A Differentiable Primitive for AutoDiff
Forward: z(t1) = ODESolve(z(t0), f, t0, t1, θ)
Backward: a second ODESolve runs the augmented state [z(t), a(t), ∂L/∂θ] backwards from t1 to t0.
A Differentiable Primitive for AutoDiff
Don’t need to store layer activations for the reverse pass - just follow the dynamics in reverse!
Reversible networks (Gomez et al. 2017) also require only O(1) memory, but need very specific neural network architectures with partitioned dimensions.
Reverse versus Forward Cost
- Empirically, the reverse pass is roughly half as expensive as the forward pass.
- Adapts to instance difficulty.
- The number of evaluations can be viewed as the number of layers in a neural net.
NFE = Number of Function Evaluations.
Dynamics Become Increasingly Complex
- Dynamics become more demanding to compute during training.
- Computation time adapts to the complexity of the diffeq.
In contrast, Chang et al. (ICLR 2018) explicitly add layers during training.
Continuous-time RNNs for Time Series Modeling
- We often want arbitrary measurement times, i.e., irregular time intervals.
- Can do VAE-style inference with a latent ODE, as in the sketch below.
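For instance, with torchdiffeq a single solver call returns the latent state at arbitrary, irregularly spaced measurement times; func and z0 are assumed from the earlier sketch:

```python
import torch
from torchdiffeq import odeint

# Observation times need not be uniform.
t_obs = torch.tensor([0.0, 0.3, 0.35, 1.2, 2.0])
z_traj = odeint(func, z0, t_obs)   # latent state at every measurement time
```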
ODEs vs Recurrent Neural Networks (RNNs)
- RNNs learn very stiff dynamics and have exploding gradients.
- Whereas ODEs are guaranteed to be smooth.
Continuous Normalizing Flows
Instantaneous Change of Variables (iCOV):
- For a Lipschitz-continuous function f:
  ∂ log p(z(t)) / ∂t = −tr( ∂f/∂z(t) )
- In other words:
  log p(z(t1)) = log p(z(t0)) − ∫ tr( ∂f/∂z(t) ) dt
Compare with an invertible F in a discrete normalizing flow:
  log p(y) = log p(z) − log |det ∂F/∂z|
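A sketch of the iCOV dynamics with an exact trace, taking one vector-Jacobian product per dimension (my illustration; practical for small D):

```python
import torch

def cnf_dynamics(f, t, z):
    """Joint dynamics of state and log-density:
    dz/dt = f(t, z), d log p(z(t))/dt = -tr(df/dz)."""
    with torch.enable_grad():
        z = z.detach().requires_grad_(True)
        dz = f(t, z)
        # exact trace: sum the Jacobian's diagonal, one VJP per dimension
        trace = sum(torch.autograd.grad(dz[:, i].sum(), z, create_graph=True)[0][:, i]
                    for i in range(z.shape[1]))
    return dz, -trace
```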
Continuous Normalizing Flows
[Figure: learned 1D and 2D densities; columns compare Data, Discrete-NF, and CNF.]
Is the ODE being correctly solved?
Stochastic Unbiased Log Density
Can further reduce time complexity using stochastic estimators, e.g. Hutchinson’s trace estimator: tr(A) = E[εᵀ A ε].
Grathwohl et al. (2019)
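A sketch of that stochastic estimator: a single Rademacher probe replaces the D exact-trace terms with one vector-Jacobian product (my illustration, under the assumptions of the earlier sketches):

```python
import torch

def hutchinson_trace(f, t, z):
    """One-sample unbiased estimate of tr(df/dz): tr(A) = E[eps^T A eps]."""
    with torch.enable_grad():
        z = z.detach().requires_grad_(True)
        dz = f(t, z)
        eps = torch.randint(0, 2, z.shape, dtype=z.dtype) * 2 - 1  # Rademacher probe
        # eps^T (df/dz) via a single vector-Jacobian product
        vjp = torch.autograd.grad(dz, z, grad_outputs=eps, create_graph=True)[0]
    return (vjp * eps).sum(dim=1)   # eps^T (df/dz) eps, per batch element
```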
FFJORD - Stochastic Continuous Flows
[Figure: model samples on MNIST and CIFAR10.]
Grathwohl et al. (2019)
Variational Autoencoders with FFJORD
ODE Solving as a Modeling Primitive
Adaptive-step solvers with O(1) memory backprop.
github.com/rtqichen/torchdiffeq
Future directions we’re currently working on:
- Latent Stochastic Differential Equations.
- Network architectures suited for ODEs.
- Regularization of dynamics to require fewer evaluations.
Co-authors:
Yulia Rubanova, Jesse Bettencourt, David Duvenaud
Thanks!
Extra Slides
Latent Space Visualizations
• Released an implementation of reverse-mode autodiff through black-box ODE solvers.
• Solves a system of size 2D + K + 1.
• In contrast, a forward-mode implementation solves a system of size D^2 + KD.
• TensorFlow has Dormand-Prince-Shampine Runge-Kutta 5(4) implemented, but uses naive autodiff for backpropagation.
How much precision is needed?
Explicit Error Control
- More fine-grained control than low-precision floats.
- Cost scales with instance difficulty.
NFE = Number of Function Evaluations.
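With torchdiffeq, this control is exposed through solver tolerances; the values below are illustrative, with f, x, and t as in the earlier sketch:

```python
from torchdiffeq import odeint

# Tighter tolerances -> more function evaluations (NFE) and higher precision.
z = odeint(f, x, t, rtol=1e-3, atol=1e-4, method='dopri5')
```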
Computation Depends on Complexity of Dynamics
- Time cost is dominated by evaluation of the dynamics f.
NFE = Number of Function Evaluations.
Why not use an ODE solver as a modeling primitive?
- Solving an ODE is expensive.
Future Directions
- Stochastic differential equations and random ODEs, which approximate stochastic gradient descent.
- Scaling up ODE solvers with machine learning.
- Partial differential equations.
- Graphics, physics, simulations.