
6.874, 6.802, 20.390, 20.490, HST.506
Computational Systems Biology
Deep Learning in the Life Sciences

Lecture 4:
Recurrent Neural Networks
LSTMs, Transformers,
Graph Neural Networks
Prof. Manolis Kellis
Guest lecture: Neil Band

Slides credit: Geoffrey Hinton, Ian Goodfellow, David Gifford, 6.S191 (Ava Soleimany, Alex Amini)
http://mit6874.github.io
Recurrent Neural Networks (RNNs) + Generalization
1. How do you read/listen/understand/write? Can machines do that?
– Context matters: characters, words, letters, sounds, completion, multi-modal
– Predicting next word/image: from unsupervised learning to supervised learning
2. Encoding temporal context: Hidden Markov Models (HMMs), RNNs
– Primitives: hidden state, memory of previous experiences, limitations of HMMs
– RNN architectures, unrolling, back-propagation through time (BPTT), parameter reuse
3. Vanishing gradients, Long Short-Term Memory (LSTM), initialization
– Key idea: gated input/output/memory nodes; the model chooses to forget/remember
– Example: online character recognition with an LSTM recurrent neural network
4. Transformer modules
– Learning temporal relationships without unrolling and without RNNs
– Encoder/decoder architecture and multi-head attention module
5. Graph Neural Networks
– Applications: social, brain, chemical drug design, graphics, transport, knowledge
– Define each node’s computation graph from its neighborhood
– Classical network/graph problems: node/graph classification, link prediction
– Research frontiers: deep generative models, latent graph inference
1a. What do you hear and why?
Context matters: top-down processing, phonemic restoration

“Hearing lips and seeing voices” (McGurk & MacDonald, Nature 1976)
https://youtu.be/PWGeUztTkRA?t=35
Split class into 4 groups: (1) close your eyes, (2) look left, (3) middle, (4) right

Audiovisual delay: adults show maximum disruption at a 200 ms delay; children at 500 ms.
Everyday delays: typing in Google Docs, Zoom video screen sharing, a slow computer.
https://www.sciencedaily.com/releases/2018/11/181129142352.htm
Recurrent Neural Networks (RNNs) + Generalization
1. How do you read/listen/understand/write? Can machines do that?
– Context matters: characters, words, letters, sounds, completion, multi-modal
– Predicting next word/image: from unsupervised learning to supervised learning
2. Encoding temporal context: Hidden Markov Models (HMMs), RNNs
– Primitives: hidden state, memory of previous experiences, limitations of HMMs
– RNN architectures, unrolling, back-propagation through time (BPTT), parameter reuse
3. Vanishing gradients, Long Short-Term Memory (LSTM), initialization
– Key idea: gated input/output/memory nodes; the model chooses to forget/remember
– Example: online character recognition with an LSTM recurrent neural network
4. Transformer modules
– Learning temporal relationships without unrolling and without RNNs
– Encoder/decoder architecture and multi-head attention module
5. Graph Neural Networks
– Applications: social, brain, chemical drug design, graphics, transport, knowledge
– Define each node’s computation graph, from its neighborhood
– Classical network/graph problems: node/graph classification, link prediction
– Research frontiers: deep generative models, latent graph inference
2a. Encoding time
Getting targets when modeling sequences
•When applying machine learning to sequences, we often want to turn an input
sequence into an output sequence that lives in a different domain.
– E.g. turn a sequence of sound pressures into a sequence of word identities.

•When there is no separate target sequence, we can get a teaching signal by trying to
predict the next term in the input sequence.
– The target output sequence is the input sequence with an advance of 1 step.
– This seems much more natural than trying to predict one pixel in an image
from the other pixels, or one patch of an image from the rest of the image.
– For temporal sequences there is a natural order for the predictions.

•Predicting the next term in a sequence blurs the distinction between supervised and
unsupervised learning.
– It uses methods designed for supervised learning, but it doesn’t require a
separate teaching signal.
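To make the "advance of 1 step" concrete, here is a minimal Python sketch (the helper name is ours, not from the lecture):

def next_step_targets(sequence):
    """Pair each element with the element that follows it: the inputs are
    sequence[:-1] and the targets are the same sequence advanced by one step."""
    return sequence[:-1], sequence[1:]

# Example: next-character prediction targets for "hello"
xs, ys = next_step_targets(list("hello"))
print(list(zip(xs, ys)))  # [('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')]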
Memoryless models for sequences

• Autoregressive models: predict the next term in a sequence from a fixed number of previous terms using “delay taps” (e.g. predict input(t) from input(t-1) and input(t-2)).

• Feed-forward neural nets: these generalize autoregressive models by using one or more layers of non-linear hidden units.
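A minimal NumPy sketch of an autoregressive model with delay taps (the function and variable names are ours): fit a linear predictor of x(t) from the previous two values by least squares.

import numpy as np

def fit_autoregressive(series, num_taps=2):
    """Fit a linear autoregressive model: predict x(t) from the previous
    `num_taps` values (the "delay taps") by ordinary least squares."""
    X = np.stack([series[i:len(series) - num_taps + i] for i in range(num_taps)], axis=1)
    y = series[num_taps:]
    coeffs, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coeffs

# Example: predict the next sample of a noisy sine wave from its two previous samples
t = np.arange(200)
series = np.sin(0.1 * t) + 0.01 * np.random.default_rng(0).normal(size=200)
w = fit_autoregressive(series, num_taps=2)
next_value = w[0] * series[-2] + w[1] * series[-1]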


Beyond memoryless models
• If we give our generative model some hidden state, and if we give
this hidden state its own internal dynamics, we get a much more
interesting kind of model.
– It can store information in its hidden state for a long time.
– If the dynamics is noisy and the way it generates outputs from its
hidden state is noisy, we can never know its exact hidden state.
– The best we can do is to infer a probability distribution over the
space of hidden state vectors.
• This inference is only tractable for two types of hidden state model.
Linear Dynamical Systems (engineers love them!)
• These are generative models. They have a real-valued hidden state that cannot be observed directly.
– The hidden state has linear dynamics with Gaussian noise and produces the observations using a linear model with Gaussian noise.
– There may also be driving inputs.
• To predict the next output (so that we can shoot down the missile) we need to infer the hidden state.
– A linearly transformed Gaussian is a Gaussian. So the distribution over the hidden state given the data so far is Gaussian. It can be computed using “Kalman filtering”.
[Figure: driving inputs → hidden states → outputs, unrolled over time]
Hidden Markov Models (computer scientists love them!)
• Hidden Markov Models have a discrete one-of-N hidden state. Transitions between states are stochastic and controlled by a transition matrix. The outputs produced by a state are stochastic.
– We cannot be sure which state produced a given output. So the state is “hidden”.
– It is easy to represent a probability distribution across N states with N numbers.
• To predict the next output we need to infer the probability distribution over hidden states.
– HMMs have efficient algorithms for inference and learning.
[Figure: hidden states emitting outputs, unrolled over time]
A fundamental limitation of HMMs
• Consider what happens when a hidden Markov model generates
data.
– At each time step it must select one of its hidden states. So with N
hidden states it can only remember log(N) bits about what it generated
so far.
• Consider the information that the first half of an utterance contains
about the second half:
– The syntax needs to fit (e.g. number and tense agreement).
– The semantics needs to fit. The intonation needs to fit.
– The accent, rate, volume, and vocal tract characteristics must all fit.
• All these aspects combined could be 100 bits of information that the
first half of an utterance needs to convey to the second half. 2^100
is big!
2b. Recurrent Neural Networks
(RNNs)
Recurrent neural networks
• RNNs are very powerful, because they combine two properties:
– Distributed hidden state that allows them to store a lot of information about the past efficiently.
– Non-linear dynamics that allows them to update their hidden state in complicated ways.
• With enough neurons and time, RNNs can compute anything that can be computed by your computer.
[Figure: inputs → hidden states → outputs, with recurrent hidden-to-hidden connections, unrolled over time]
Do generative models need to be stochastic?
• Linear dynamical systems and hidden Markov models are stochastic models.
– But the posterior probability distribution over their hidden states given the observed data so far is a deterministic function of the data.
• Recurrent neural networks are deterministic.
– So think of the hidden state of an RNN as the equivalent of the deterministic probability distribution over hidden states in a linear dynamical system or hidden Markov model.
Recurrent neural networks

• What kinds of behaviour can RNNs exhibit?


– They can oscillate. Good for motor control?
– They can settle to point attractors. Good for retrieving memories?
– They can behave chaotically. Bad for information processing?
– RNNs could potentially learn to implement lots of small programs
that each capture a nugget of knowledge and run in parallel,
interacting to produce very complicated effects.
• But the computational power of RNNs makes them very hard to train.
– For many years we could not exploit the computational power of
RNNs despite some heroic efforts (e.g. Tony Robinson’s speech
recognizer).
The equivalence between feedforward nets and recurrent nets
• Assume that there is a time delay of 1 in using each connection.
• The recurrent net is just a layered net that keeps reusing the same weights.
[Figure: a recurrent net with weights w1, w2, w3, w4 unrolled into layers for time = 0, 1, 2, 3, with the same four weights reused in every layer]
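A minimal NumPy sketch of this equivalence, assuming a simple tanh RNN (names are ours): the loop below is exactly a layered net in which the same weight matrices are reused at every "layer" (time step).

import numpy as np

def rnn_unrolled_forward(inputs, W_xh, W_hh, h0):
    """Unroll a simple RNN over time: one 'layer' per time step,
    with the same weights W_xh and W_hh reused in every layer."""
    h, hidden_states = h0, []
    for x_t in inputs:
        h = np.tanh(W_xh @ x_t + W_hh @ h)  # identical computation at each step
        hidden_states.append(h)
    return hidden_states

# Example: 3 input units, 4 hidden units, 5 time steps
rng = np.random.default_rng(0)
W_xh, W_hh = rng.normal(size=(4, 3)), rng.normal(size=(4, 4))
inputs = [rng.normal(size=3) for _ in range(5)]
states = rnn_unrolled_forward(inputs, W_xh, W_hh, h0=np.zeros(4))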
2c. Alternative architectures
for RNNs
Different RNN remembering architectures
• Recurrent network with no outputs.
• Single output after the entire sequence.
• Recurrence through the hidden state (o: output, y: target, L: loss): memory h(t-1) → h(t); can only be trained sequentially.
• Teacher forcing (o: output, y: target, L: loss): memory o(t-1) → h(t); train from y and x in parallel.
2d. Back-propagation through
time (BPTT)
Reminder: Backpropagation with weight constraints
• It is easy to modify the backprop algorithm to incorporate linear constraints between the weights.
• We compute the gradients as usual, and then modify the gradients so that they satisfy the constraints.
– So if the weights started off satisfying the constraints, they will continue to satisfy them.

To constrain w1 = w2, we need ∆w1 = ∆w2:
compute ∂E/∂w1 and ∂E/∂w2, then use ∂E/∂w1 + ∂E/∂w2 for both w1 and w2.
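A small sketch of the constraint trick (hypothetical names, scalar weights for clarity): compute the two gradients separately, then apply their sum to both tied copies so the equality keeps holding.

def tied_weight_update(w, grad_w1, grad_w2, lr=0.1):
    """Gradient step for two weights constrained to be equal (w1 = w2):
    use dE/dw1 + dE/dw2 for both, so if they start equal they stay equal."""
    shared_grad = grad_w1 + grad_w2
    return w - lr * shared_grad  # the same updated value is used for w1 and w2

# Example: both copies currently equal 0.5
new_w = tied_weight_update(0.5, grad_w1=0.2, grad_w2=-0.05)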
Backpropagation through time

• We can think of the recurrent net as a layered, feed-forward


net with shared weights and then train the feed-forward net
with weight constraints.
• We can also think of this training algorithm in the time domain:
– The forward pass builds up a stack of the activities of all
the units at each time step.
– The backward pass peels activities off the stack to
compute the error derivatives at each time step.
– After the backward pass we add together the derivatives at
all the different times for each weight.
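The sketch below spells this out for a vanilla tanh RNN with a linear readout and squared error (a minimal NumPy illustration with our own naming, not the lecture's code): the forward pass stores the per-step activities, and the backward pass adds each weight's derivatives together over all time steps.

import numpy as np

def bptt(inputs, targets, W_xh, W_hh, W_hy, h0):
    """Backpropagation through time for a vanilla tanh RNN with a linear
    readout and squared error summed over time steps."""
    # Forward pass: build up a stack of activities, one entry per time step.
    hs, ys, loss = {-1: h0}, {}, 0.0
    for t, (x, target) in enumerate(zip(inputs, targets)):
        hs[t] = np.tanh(W_xh @ x + W_hh @ hs[t - 1])
        ys[t] = W_hy @ hs[t]
        loss += 0.5 * np.sum((ys[t] - target) ** 2)
    # Backward pass: peel activities off the stack and add the derivatives
    # from all time steps together for each (shared) weight matrix.
    dW_xh, dW_hh, dW_hy = map(np.zeros_like, (W_xh, W_hh, W_hy))
    dh_next = np.zeros_like(h0)
    for t in reversed(range(len(inputs))):
        dy = ys[t] - targets[t]
        dW_hy += np.outer(dy, hs[t])
        dh = W_hy.T @ dy + dh_next          # error from the output and from the future
        dpre = (1.0 - hs[t] ** 2) * dh      # backprop through tanh
        dW_xh += np.outer(dpre, inputs[t])
        dW_hh += np.outer(dpre, hs[t - 1])
        dh_next = W_hh.T @ dpre             # pass the error one step back in time
    return loss, dW_xh, dW_hh, dW_hy

# Example: random data with 3 inputs, 4 hidden units, 2 outputs, 5 time steps
rng = np.random.default_rng(0)
W_xh, W_hh, W_hy = (rng.normal(scale=0.1, size=s) for s in [(4, 3), (4, 4), (2, 4)])
xs = [rng.normal(size=3) for _ in range(5)]
ts = [rng.normal(size=2) for _ in range(5)]
loss, dW_xh, dW_hh, dW_hy = bptt(xs, ts, W_xh, W_hh, W_hy, h0=np.zeros(4))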
Getting targets when modeling sequences
•When applying machine learning to sequences, we often want to turn an input
sequence into an output sequence that lives in a different domain.
– E.g. turn a sequence of sound pressures into a sequence of word identities.

•When there is no separate target sequence, we can get a teaching signal by trying to
predict the next term in the input sequence.
– The target output sequence is the input sequence with an advance of 1 step.
– This seems much more natural than trying to predict one pixel in an image
from the other pixels, or one patch of an image from the rest of the image.
– For temporal sequences there is a natural order for the predictions.

•Predicting the next term in a sequence blurs the distinction between supervised and
unsupervised learning.
– It uses methods designed for supervised learning, but it doesn’t require a
separate teaching signal.
Recurrent Neural Networks (RNNs) + Generalization
1. How do you read/listen/understand/write? Can machines do that?
– Context matters: characters, words, letters, sounds, completion, multi-modal
– Predicting next word/image: from unsupervised learning to supervised learning
2. Encoding temporal context: Hidden Markov Models (HMMs), RNNs
– Primitives: hidden state, memory of previous experiences, limitations of HMMs
– RNN architectures, unrolling, back-propagation through time (BPTT), parameter reuse
3. Vanishing gradients, Long Short-Term Memory (LSTM), initialization
– Key idea: gated input/output/memory nodes; the model chooses to forget/remember
– Example: online character recognition with an LSTM recurrent neural network
4. Transformer modules
– Learning temporal relationships without unrolling and without RNNs
– Encoder/decoder architecture and multi-head attention module
5. Graph Neural Networks
– Applications: social, brain, chemical drug design, graphics, transport, knowledge
– Define each node’s computation graph, from its neighborhood
– Classical network/graph problems: node/graph classification, link prediction
– Research frontiers: deep generative models, latent graph inference
3a. Remembering for
longer time periods
Four effective ways to increase length of memory
• Long Short-Term Memory: make the RNN out of little modules that are designed to remember values for a long time.
• Hessian-Free Optimization: deal with the vanishing gradients problem by using a fancy optimizer that can detect directions with a tiny gradient but even smaller curvature.
– The HF optimizer (Martens & Sutskever, 2011) is good at this.
• Echo State Networks: initialize the input→hidden, hidden→hidden and output→hidden connections very carefully so that the hidden state has a huge reservoir of weakly coupled oscillators which can be selectively driven by the input.
– ESNs only need to learn the hidden→output connections.
• Good initialization with momentum: initialize like in Echo State Networks, but then learn all of the connections using momentum.
Long Short Term Memory (LSTM)
• Hochreiter & Schmidhuber (1997) solved the problem of getting an RNN to remember things for a long time (like hundreds of time steps).
• They designed a memory cell using logistic and linear units with multiplicative interactions.
• Information gets into the cell whenever its “write” gate is on.
• The information stays in the cell so long as its “keep” gate is on.
• Information can be read from the cell by turning on its “read” gate.
Implementing a memory cell in a neural network
• To preserve information for a long time in the activities of an RNN, we use a circuit that implements an analog memory cell.
– A linear unit that has a self-link with a weight of 1 will maintain its state.
– Information is stored in the cell by activating its write gate.
– Information is retrieved by activating the read gate.
– We can backpropagate through this circuit because logistics have nice derivatives.
[Figure: a memory cell holding the value 1.73, with keep, write and read gates; input comes from the rest of the RNN and output goes to the rest of the RNN]
Backpropagation through a memory cell
[Figure: a value of 1.7 is written into the cell when the write gate is 1, retained while the keep gate is 1, and read out later when the read gate is 1; the keep, write and read gates take values 0 or 1 at successive time steps]
Reading cursive handwriting
• This is a natural task for an RNN.
• The input is a sequence of (x, y, p) coordinates of the tip of the pen, where p indicates whether the pen is up or down.
• The output is a sequence of characters.
• Graves & Schmidhuber (2009) showed that RNNs with LSTM are currently the best systems for reading cursive writing.
– They used a sequence of small images as input rather than pen coordinates.
Demonstration of online handwriting recognition by an RNN with
Long Short Term Memory (from Alex Graves)
• Row 1: Shows when characters are
recognized.
– It never revises its output so
difficult decisions are more
delayed.
• Row 2: Shows the states of a subset
of the memory cells.
– Notice how they get reset when it
recognizes a character.

• Row 3: Shows the writing. The net


sees the x and y coordinates.
– Optical input actually works a bit
better than pen coordinates.
• Row 4: Shows the gradient
backpropagated all the way to the x
and y inputs from the currently most
active character.
– This lets you see which bits of the
data are influencing the decision.
https://youtu.be/9T2X6WRUwFU?t=2791
3b. Initialization
Initialization: Dealing with boundary cases

• We need to specify the initial activity state of all the hidden and output
units.
• We could just fix these initial states to have some default value like 0.5.
• But it is better to treat the initial states as learned parameters.
• We learn them in the same way as we learn the weights.
– Start off with an initial random guess for the initial states.
– At the end of each training sequence, backpropagate through time all
the way to the initial states to get the gradient of the error function
with respect to each initial state.
– Adjust the initial states by following the negative gradient.
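A minimal PyTorch sketch of treating the initial hidden state as a learned parameter (the class name and sizes are our own choices):

import torch
import torch.nn as nn

class RNNWithLearnedInit(nn.Module):
    """A vanilla RNN whose initial hidden state h0 is a parameter, so BPTT
    reaches it and it is updated just like the weights."""
    def __init__(self, input_size=8, hidden_size=16):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.h0 = nn.Parameter(torch.zeros(1, 1, hidden_size))  # learned initial state

    def forward(self, x):                                   # x: (batch, time, input_size)
        h0 = self.h0.expand(1, x.size(0), -1).contiguous()   # share h0 across the batch
        out, _ = self.rnn(x, h0)
        return out

y = RNNWithLearnedInit()(torch.randn(4, 10, 8))  # gradients flow back to h0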
Teaching signals for recurrent networks
• We can specify targets in several ways:
– Specify desired final activities of all the units.
– Specify desired activities of all units for the last few steps.
• Good for learning attractors.
• It is easy to add in extra error derivatives as we backpropagate.
– Specify the desired activity of a subset of the units.
• The other units are input or hidden units.
[Figure: an unrolled net reusing weights w1, w2, w3, w4 at each time step, with targets attached at different layers]
What the network learns
• It learns four distinct patterns of activity for the 3 hidden units. These patterns correspond to the nodes in the finite state automaton.
– Do not confuse units in a neural network with nodes in a finite state automaton. Nodes are like activity vectors.
– The automaton is restricted to be in exactly one state at each time. The hidden units are restricted to have exactly one vector of activity at each time.
• A recurrent network can emulate a finite state automaton, but it is exponentially more powerful. With N hidden neurons it has 2^N possible binary activity vectors (but only N^2 weights).
– This is important when the input stream has two separate things going on at once.
– A finite state automaton needs to square its number of states.
– An RNN needs to double its number of units.
The backward pass is linear
• There is a big difference between the
forward and backward passes.
• In the forward pass we use squashing
functions (like the logistic) to prevent the
activity vectors from exploding.
• The backward pass is completely linear. If you double the error derivatives at the final layer, all the error derivatives will double.
– The forward pass determines the slope
of the linear function used for
backpropagating through each neuron.
The problem of exploding or vanishing gradients
• What happens to the magnitude of the gradients as we backpropagate through many layers?
– If the weights are small, the gradients shrink exponentially.
– If the weights are big, the gradients grow exponentially.
• Typical feed-forward neural nets can cope with these exponential effects because they only have a few hidden layers.
• In an RNN trained on long sequences (e.g. 100 time steps) the gradients can easily explode or vanish.
– We can avoid this by initializing the weights very carefully.
• Even with good initial weights, it’s very hard to detect that the current target output depends on an input from many time steps ago.
– So RNNs have difficulty dealing with long-range dependencies.
– Can use ideas from residual networks (ResNet): pass information from the input to far-away nodes.
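A small numerical illustration of the effect (ignoring the nonlinearity, with hypothetical sizes): backpropagating through time repeatedly multiplies the error signal by the transpose of the same recurrent weight matrix, so its norm shrinks or grows exponentially depending on the weight scale.

import numpy as np

def backpropagated_norm(w_scale, steps=50, size=4, seed=0):
    """Norm of an error signal after being multiplied by the same recurrent
    Jacobian (linear part only) at every step back through time."""
    rng = np.random.default_rng(seed)
    W = w_scale * rng.normal(size=(size, size)) / np.sqrt(size)
    grad = np.ones(size)
    for _ in range(steps):
        grad = W.T @ grad                 # one step back through time
    return np.linalg.norm(grad)

print(backpropagated_norm(w_scale=0.5))   # small weights: the gradient vanishes
print(backpropagated_norm(w_scale=2.0))   # large weights: the gradient explodes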
Recurrent Neural Networks (RNNs) + Generalization
1. How do you read/listen/understand/write? Can machines do that?
– Context matters: characters, words, letters, sounds, completion, multi-modal
– Predicting next word/image: from unsupervised learning to supervised learning
2. Encoding temporal context: Hidden Markov Models (HMMs), RNNs
– Primitives: hidden state, memory of previous experiences, limitations of HMMs
– RNN architectures, unrolling, back-propagation through time (BPTT), parameter reuse
3. Vanishing gradients, Long Short-Term Memory (LSTM), initialization
– Key idea: gated input/output/memory nodes; the model chooses to forget/remember
– Example: online character recognition with an LSTM recurrent neural network
4. Transformer modules
– Learning temporal relationships without unrolling and without RNNs
– Encoder/decoder architecture and multi-head attention module
5. Graph Neural Networks
– Applications: social, brain, chemical drug design, graphics, transport, knowledge
– Define each node’s computation graph, from its neighborhood
– Classical network/graph problems: node/graph classification, link prediction
– Research frontiers: deep generative models, latent graph inference
4. Attention and
transformer models
Encoder/Decoder/Attention modules in Transformer

Q: query matrix, vector representation of one word in the sequence
K: all keys, vector representations of all words in the sequence
V: values, vector representations of all words in the sequence

Encoder and decoder multi-head attention modules: V comes from the same word sequence as Q
Encoder-decoder attention module: V differs from Q; it uses both the encoder and decoder sequences
Time is explicitly encoded, so there is no need for the RNN structure
Training setup: predict the next word each time, with the decoder input shifted by one
Transforms one sequence into another sequence, using the full context for each position
(e.g. sentence translation, or any other sequential task)
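A minimal single-head sketch of the Q/K/V attention computation (scaled dot-product attention from Vaswani et al., 2017; the NumPy code and names are ours):

import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V: each query position takes a weighted
    sum of the values, weighted by its similarity to every key."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over key positions
    return weights @ V

# Self-attention example: K and V come from the same 5-word sequence as Q
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(5, 8)) for _ in range(3))
context = scaled_dot_product_attention(Q, K, V)      # shape (5, 8), full context per word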
Recurrent Neural Networks (RNNs) + Generalization
1. How do you read/listen/understand/write? Can machines do that?
– Context matters: characters, words, letters, sounds, completion, multi-modal
– Predicting next word/image: from unsupervised learning to supervised learning
2. Encoding temporal context: Hidden Markov Models (HMMs), RNNs
– Primitives: hidden state, memory of previous experiences, limitations of HMMs
– RNN architectures, unrolling, back-propagation through time (BPTT), parameter reuse
3. Vanishing gradients, Long Short-Term Memory (LSTM), initialization
– Key idea: gated input/output/memory nodes; the model chooses to forget/remember
– Example: online character recognition with an LSTM recurrent neural network
4. Transformer modules
– Learning temporal relationships without unrolling and without RNNs
– Encoder/decoder architecture and multi-head attention module
5. Graph Neural Networks
– Applications: social, brain, chemical drug design, graphics, transport, knowledge
– Define each node’s computation graph, from its neighborhood
– Classical network/graph problems: node/graph classification, link prediction
– Research frontiers: deep generative models, latent graph inference
5. Graph Neural Networks

Guest lecture by Neil Band


Sources / Further Reading
● Adapted from
○ Thomas Kipf’s presentations (Cambridge CompBio, IPAM UCLA)

○ Graph Neural Networks by Xavier Bresson (Guest lecture in Yann LeCun’s NYU DL course)

○ CS224W: Machine Learning with Graphs by Jure Leskovec (Course @ Stanford)

○ Junction Tree Variational Autoencoder (Wengong Jin, ICML 2018)

● Mining and Learning with Graphs at Scale (Google Graph Mining team @ NIPS 2020)

● Graph Representation Learning (Book by Will Hamilton, 2020)

● Thomas Kipf’s thesis (Deep Learning with Graph Structured Representations, 2020)
● Further reading: Petar Veličković’s thread of resources

Jraph! (GNNs with Jax)


Outline
1. Motivation
2. Graph neural nets (GNNs): introduction and history
3. GNNs for classic network problems: node classification, graph classification, link prediction, ...
Outline
1. Motivation
2. Graph neural nets (GNNs): introduction and history
3. GNNs for classic network problems: node classification, graph classification, link prediction, ...
4. Research frontiers: deep generative graph models, latent graph inference
With applications in...
● Chemical synthesis
● Interacting systems (physical, multi-agent, biological)
● Causal inference
● Program induction
The ML canon lives in grid world
● Images, volumes, videos lie on 2D, 3D, 2D+1 grids
● Sentences, words, speech lie on 1D grids
● Deep neural nets on grids exploit:
- translation equivariance (weight sharing)
- hierarchical compositionality
But there’s so much more...
But there’s so much more...
Cool applications (in the last year!)
● DeepMind / Google Maps: ETA improvements across the world
● SuperGlue (Magic Leap): feature matching
● Large Hadron Collider: real-time collision analysis
● MaSIF: predicts protein-protein interactions
Setup
Naive approach
1. Join adjacency matrix and node features

2. Plug them into a deep neural net

● Issues with this idea:


○ O(N) parameters → 6 billion nodes in Pinterest

○ Not applicable to graphs of different sizes → graphs change!

○ Not invariant to node ordering → expensive sorting


1 Graph neural nets
Aggregating neighbors
Recap: CNN (on grids) as message passing

Single CNN layer with a 3x3 filter
[Animation by Vincent Dumoulin]


Aggregating neighbors

Aggregating neighbors
Graph convolutional networks (GCNs)
Kipf & Welling (ICLR 2017), related previous works: Duvenaud et al. (NIPS 2015) and Li et al. (ICLR 2016)
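A minimal NumPy sketch of the Kipf & Welling propagation rule, H' = ReLU(D^-1/2 (A + I) D^-1/2 H W) (the function and example graph are our own illustration):

import numpy as np

def gcn_layer(A, H, W):
    """One graph convolutional layer: add self-loops, symmetrically normalise
    the adjacency matrix, aggregate neighbour features, then apply a shared
    linear map and a ReLU."""
    A_hat = A + np.eye(A.shape[0])                        # self-loops
    d_inv_sqrt = np.diag(1.0 / np.sqrt(A_hat.sum(axis=1)))
    return np.maximum(0.0, d_inv_sqrt @ A_hat @ d_inv_sqrt @ H @ W)

# Example: a 4-node path graph with 3-d node features mapped to 2-d
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
rng = np.random.default_rng(0)
H_next = gcn_layer(A, rng.normal(size=(4, 3)), rng.normal(size=(3, 2)))  # shape (4, 2)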
GCN classification on citation networks
Kipf & Welling, Semi-Supervised Classification with Graph Convolutional Networks, ICLR 2017
GNNs with edge embeddings (Neural message passing)
Battaglia et al. (NIPS 2016), Gilmer et al. (ICML 2017), Kipf et al. (ICML 2018)
Graph neural nets with attention
Monti et al. (CVPR 2017), Hoshen (NIPS 2017), Veličković et al. (ICLR 2018)
2 Application to “classical” network problems
Node classification, graph classification, link prediction
One fits all: Classification and link prediction with GNNs/GCNs
3 Research frontiers
Deep generative graph models; latent graph inference
Unsupervised learning with GNNs
Unsupervised learning with GNNs

● Sampling strategies
e.g. positive: neighbor; negative: random node

● Encoder variants
GCN, GAT, MLP, lookup table

● Node representations
Geometry of latent space, distributional embeddings (e.g. Hyperbolic GCNN, Chami et al. 2019)

● Score functions
Inner/bilinear product, local vs. global (e.g. Deep Graph Infomax, Velickovic et al. 2019)

● Loss
(Cross-entropy, MSE, exponential)
Unsupervised learning takeaways
A Modular Framework for Unsupervised Graph Representation Learning, Daza & Kipf (WIP)
Likelihood-based (deep) graph generation

VGAE generative model (with ELBO loss)
● Graph Auto-Encoders: Kipf & Welling (NIPS BDL 2016)
● (Variational) Graph2Gauss: Bojchevski & Günnemann (ICLR 2018)
● Graphite: Grover et al. (NIPS BDL 2017)
● Hyperspherical VAEs: Davidson et al. (UAI 2018)
Likelihood-based (deep) graph generation
● Learning Graphical State Transitions: Johnson (ICLR 2017)
● Deep Generative Models of Graphs: Li et al. (arXiv 2018)
● GraphVAE: Simonovsky et al. (arXiv 2018)
● GraphRNN: You et al. (ICML 2018)
Graph generation for drug discovery
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018)

Aim: generate molecules with high potency


How should we decode the graph?
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018)

● Not every graph is chemically valid


● Invalid intermediate states → hard to validate
● Many intermediate states (i.e. long sequences) → difficult to train (Li et al. 2018)
How should we decode the graph?
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018)

Tree Decomposition

● Shorter action sequence


● Easy to check validity as we construct
● Vocabulary size: ~800 for 250K
molecules
High-level approach
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018)
Focus on the cool part: tree decoding
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018)
Tree decoding
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018); Tree-Structured Decoding, Alvarez-Melis & Jaakkola (ICLR 2017)
Tree decoding
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018); Tree-Structured Decoding, Alvarez-Melis & Jaakkola (ICLR 2017)

Topological Prediction: Should we add a child node, or backtrack?


Label Prediction: What do we label the new node?
Tree decoding
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018); Tree-Structured Decoding, Alvarez-Melis & Jaakkola (ICLR 2017)

Topological Prediction: Should we add a child node, or backtrack?


Label Prediction: What do we label the new node?
Tree decoding
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018); Tree-Structured Decoding, Alvarez-Melis & Jaakkola (ICLR 2017)

[Figure labels: “Encodes state of subtree thus far”, “Functional group features”]


JTVAE evaluation
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018)

1 Molecule Reconstruction
100 forward passes per molecule, report portion of
decoded molecules identical to input

2 Molecule Validity
Random samples from latent z, report portion that
are chemically valid (RDKit)
JTVAE without validity checking: 93.5%

3 Bayesian Optimization
1. Train a VAE, associate each molecule with
latent vector (mean of encoding distribution)
2. Train a sparse GP to predict target chemical
property y(m) given the latent representation
3. Use property predictor for BO
3 Research frontiers
Deep generative graph models; latent graph inference
Modeling implicit/hidden structure
Neural Relational Inference for Interacting Systems, Kipf & Fetaya et al. (ICML 2018)
Neural Relational Inference with GNNs
Neural Relational Inference for Interacting Systems, Kipf & Fetaya et al. (ICML 2018)

Discrete (Gumbel softmax trick)


[Jang et al., 2016, Maddison et al., 2016]
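A minimal NumPy sketch of the Gumbel-softmax sampling step used for discrete edge types (the relaxation that makes the sample differentiable in an autodiff framework; names are ours):

import numpy as np

def gumbel_softmax_sample(logits, temperature=0.5, seed=None):
    """Draw an approximately one-hot sample from a categorical distribution:
    add Gumbel noise to the logits, then take a temperature-scaled softmax."""
    rng = np.random.default_rng(seed)
    gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))
    y = (logits + gumbel) / temperature
    y = np.exp(y - y.max())
    return y / y.sum()

# Example: a relaxed sample over 4 candidate edge types
print(gumbel_softmax_sample(np.array([2.0, 0.5, 0.1, -1.0]), temperature=0.5, seed=0))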
NRI evaluation - toy data
Neural Relational Inference for Interacting Systems, Kipf & Fetaya et al. (ICML 2018)
NRI evaluation - CMU Motion Capture (e.g. walking)
Neural Relational Inference for Interacting Systems, Kipf & Fetaya et al. (ICML 2018)
NRI applications - causal discovery
Amortized Causal Discovery: Learning to Infer Causal Graphs from Time-Series Data, Lowe et al. 2020
Challenges and future work in graph neural nets

● Problems of neighborhood aggregation / message passing


○ Theoretical relation to WL isomorphism, simple graph convolutions;
tree-structured computation graphs → bounded power
○ Oversmoothing (residual/gated updates help, but don’t solve)
○ See recent work from Max Welling e.g. Natural Graph Networks
● Scalable, stable generative models
● Learning on large, evolving data
● (Mostly) assume a graph structure is provided as input
○ Neural Relational Inference is a preliminary work here, also see Pointer
Graph Networks (Velickovic et al., NeurIPS 2020)
● Multi-modal and cross-modal learning (e.g. sequence2graph)
Recurrent Neural Networks (RNNs) + Generalization
1. How do you read/listen/understand/write? Can machines do that?
– Context matters: characters, words, letters, sounds, completion, multi-modal
– Predicting next word/image: from unsupervised learning to supervised learning
2. Encoding temporal context: Hidden Markov Models (HMMs), RNNs
– Primitives: hidden state, memory of previous experiences, limitations of HMMs
– RNN architectures, unrolling, back-propagation through time (BPTT), parameter reuse
3. Vanishing gradients, Long Short-Term Memory (LSTM), initialization
– Key idea: gated input/output/memory nodes; the model chooses to forget/remember
– Example: online character recognition with an LSTM recurrent neural network
4. Transformer modules
– Learning temporal relationships without unrolling and without RNNs
– Encoder/decoder architecture and multi-head attention module
5. Graph Neural Networks
– Applications: social, brain, chemical drug design, graphics, transport, knowledge
– Define each node’s computation graph, from its neighborhood
– Classical network/graph problems: node/graph classification, link prediction
– Research frontiers: deep generative models, latent graph inference
Appendix
Graph Transformers (Li et al. 2018)
A Vaswani, N Shazeer, N Parmar, J Uszkoreit, L Jones, A Gomez, L Kaiser, I Polosukhin, Attention is all you need (2017)
[Figure: query, key and value vectors with attention restricted to the 1-hop neighborhood]
Graph Transformers (Li et al. 2018)
A Vaswani, N Shazeer, N Parmar, J Uszkoreit, L Jones, A Gomez, L Kaiser, I Polosukhin, Attention is all you need (2017)

● We can frame transformers as a special case of GCNs when the


graph is fully connected.
● The neighborhood is the whole graph.
A brief history of graph neural nets
● Original GNN: Gori et al. (2005)
● “Spectral methods”: Spectral Graph CNN, Bruna et al. (ICLR 2015); ChebNet, Defferrard et al. (NIPS 2016)
● “Spatial methods”: GG-NN, Li et al. (ICLR 2016); MoNet, Monti et al. (CVPR 2017); Neural MP, Gilmer et al. (ICML 2017)
● “DL on graphs explosion”: Relation Nets, Santoro et al. (NIPS 2017); GraphSAGE, Hamilton et al. (NIPS 2017); GAT, Veličković et al. (ICLR 2018); Programs as Graphs, Allamanis et al. (ICLR 2018); NRI, Kipf et al. (ICML 2018)
● Other early work: Duvenaud et al. (NIPS 2015), Dai et al. (ICML 2016), Niepert et al. (ICML 2016), Battaglia et al. (NIPS 2016), Atwood & Towsley (NIPS 2016), Sukhbaatar et al. (NIPS 2016)
MoNet & Relational GCN for modeling (multi-)relational data
Monti et al. (CVPR 2017), Schlichtkrull & Kipf et al. (ESWC 2018)

Relational GCN update rule


Semi-supervised classification on graphs
Toy example (semi-supervised learning)
from tkipf.github.io/graph-convolutional-networks
Latent space dynamics for 300 training iterations. Labeled nodes are highlighted.
GCN model manages to linearly separate classes with only 1 training example per class, no node features!
Tree decoding
Junction Tree Variational Autoencoder, Jin et al. (ICML 2018); Tree-Structured Decoding, Alvarez-Melis & Jaakkola (ICLR 2017)

Topological Prediction

Label Prediction

Training

+ Teacher forcing -- replace topological and label


predictions with ground truth at train time
Generalizing the space of ML approaches on graphs
Machine Learning on Graphs: A Model and Comprehensive Taxonomy (Chami et al., preprint)

GraphEDM model
