Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohamma

This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
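As a toy illustration of how the first two bounds relate (a hedged numpy sketch, not code from this repo; the log-weights below are synthetic stand-ins for the model's actual importance weights):

```python
import numpy as np

# Synthetic stand-ins for K log importance weights log[p(x, z_k) / q(z_k | x)];
# a real model would compute these from its generative and inference networks.
rng = np.random.default_rng(0)
K = 1000
log_w = rng.normal(loc=-1.0, scale=0.5, size=K)

# ELBO: the average log-weight (a K=1 estimator, averaged over samples).
elbo = log_w.mean()

# IWAE: the log of the average importance weight over K samples.
iwae = np.logaddexp.reduce(log_w) - np.log(K)

# Jensen's inequality guarantees IWAE >= ELBO; FIVO tightens the estimate
# further for sequential models by resampling particles with SMC.
assert iwae >= elbo
```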

Additionally it contains several sequential latent variable model implementations:

* Variational recurrent neural network (VRNN)
* Stochastic recurrent neural network (SRNN)
* Gaussian hidden Markov model with linear conditionals (GHMM)

The VRNN and SRNN can be trained for sequence modeling of pianoroll and speech data. The GHMM is trainable on a synthetic dataset, useful as a simple example of an analytically tractable model.

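As a rough sketch of the model class (hypothetical shapes and parameterization for illustration only, not the actual `models/vrnn.py` code), one VRNN step interleaves an RNN state update with sampling a latent variable that the transition then conditions on:

```python
import numpy as np

rng = np.random.default_rng(0)

def vrnn_step(h, x_prev, W_h, W_x, W_z, W_mu):
    mu = W_mu @ h                        # prior mean for z_t given state h
    z = mu + rng.normal(size=mu.shape)   # sample latent z_t (unit variance)
    h_new = np.tanh(W_h @ h + W_x @ x_prev + W_z @ z)  # update conditions on z_t
    return h_new, z

# Toy dimensions: state size 8, input size 88 (pianoroll pitches), latent size 4.
h = np.zeros(8)
W_h = rng.normal(size=(8, 8)) * 0.1
W_x = rng.normal(size=(8, 88)) * 0.1
W_z = rng.normal(size=(8, 4)) * 0.1
W_mu = rng.normal(size=(4, 8)) * 0.1
h, z = vrnn_step(h, np.zeros(88), W_h, W_x, W_z, W_mu)
```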
#### Directory Structure
The important parts of the code are organized as follows.

```
run_fivo.py         # main script, contains flag definitions
fivo
├─smc.py            # a sequential Monte Carlo implementation
├─bounds.py         # code for computing each bound, uses smc.py
├─runners.py        # code for VRNN and SRNN training and evaluation
├─ghmm_runners.py   # code for GHMM training and evaluation
├─data
| ├─datasets.py                  # readers for pianoroll and speech datasets
| ├─calculate_pianoroll_mean.py  # preprocesses the pianoroll datasets
| └─create_timit_dataset.py      # preprocesses the TIMIT dataset
└─models
  ├─base.py         # base classes used in other models
  ├─vrnn.py         # VRNN implementation
  ├─srnn.py         # SRNN implementation
  └─ghmm.py         # Gaussian hidden Markov model (GHMM) implementation
bin
├─run_train.sh      # an example script that runs training
├─run_eval.sh       # an example script that runs evaluation
├─run_sample.sh     # an example script that runs sampling
├─run_tests.sh      # a script that runs all tests
└─download_pianorolls.sh  # a script that downloads the pianoroll files
```
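To give a feel for how `smc.py` underlies `bounds.py`, here is a toy numpy sketch of an SMC sweep producing a FIVO-style bound. It assumes the incremental particle log-weights are handed in up front, whereas the real code computes them from the model and proposal at each step:

```python
import numpy as np

def fivo_bound(step_log_weights, rng):
    """step_log_weights: (T, K) incremental log-weights for K particles."""
    T, K = step_log_weights.shape
    bound = 0.0
    for t in range(T):
        log_w = step_log_weights[t]  # weights are uniform after each resample
        log_mean_w = np.logaddexp.reduce(log_w) - np.log(K)
        bound += log_mean_w          # accumulates log of p-hat(x_t | x_{1:t-1})
        probs = np.exp(log_w - log_w.max())
        probs /= probs.sum()
        ancestors = rng.choice(K, size=K, p=probs)  # multinomial resampling
        # a full implementation would re-index the particle states by `ancestors`
    return bound
```

For example, if every particle has incremental log-weight -2.0 at each of 5 steps, the bound is exactly -10.0.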

Now we can train a model. Here is the command for a standard training run, taken from `bin/run_train.sh`:

```
python run_fivo.py \
  --mode=train \
  --logdir=/tmp/fivo \
  --model=vrnn \
  ...
  --dataset_type="pianoroll"
```

You should see output that looks something like this (with extra logging cruft):

```
Saving checkpoints for 0 into /tmp/fivo/model.ckpt.
Step 1, fivo bound per timestep: -11.322491
global_step/sec: 7.49971
Step 101, fivo bound per timestep: -11.399275
global_step/sec: 8.04498
Step 201, fivo bound per timestep: -11.174991
global_step/sec: 8.03989
Step 301, fivo bound per timestep: -11.073008
```
#### Evaluation
You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example, here is a command, taken from `bin/run_eval.sh`, that evaluates a JSB model on the test set:
```
python run_fivo.py \
  --mode=eval \
  --split=test \
  --alsologtostderr \
  ...
```

You should see output like this:
```
Restoring parameters from /tmp/fivo/model.ckpt-0
Model restored from step 0, evaluating.
test elbo ll/t: -12.198834, iwae ll/t: -11.981187 fivo ll/t: -11.579776
test elbo ll/seq: -748.564789, iwae ll/seq: -735.209206 fivo ll/seq: -710.577141
```
The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.
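The two normalizations are consistent with each other: assuming ll/t divides total log-likelihood by total timesteps and ll/seq divides by the number of sequences, their ratio recovers the average sequence length. Checking with the ELBO numbers from the example output above:

```python
ll_seq = -748.564789  # test elbo ll/seq from the output above
ll_t = -12.198834     # test elbo ll/t from the output above
avg_len = ll_seq / ll_t
print(round(avg_len, 1))  # prints 61.4: about 61 timesteps per JSB sequence
```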

#### Sampling

You can also sample from trained models. The `sample` mode loads a model checkpoint, conditions the model on a prefix of a randomly chosen datapoint, samples a sequence of outputs from the conditioned model, and writes out the samples and prefix to a `.npz` file in `logdir`. For example, here is a command that samples from a model trained on JSB, taken from `bin/run_sample.sh`:

```
python run_fivo.py \
  --mode=sample \
  --alsologtostderr \
  --logdir="/tmp/fivo" \
  --model=vrnn \
  --bound=fivo \
  --batch_size=4 \
  --num_samples=4 \
  --split=test \
  --dataset_path="$PIANOROLL_DIR/jsb.pkl" \
  --dataset_type="pianoroll" \
  --prefix_length=25 \
  --sample_length=50
```

Here `num_samples` denotes the number of samples used when conditioning the model as well as the number of trajectories to sample for each prefix.

You should see very little output:

```
Restoring parameters from /tmp/fivo/model.ckpt-0
Running local_init_op.
Done running local_init_op.
```

Loading the samples with `np.load` confirms that we conditioned the model on 4 prefixes of length 25 and sampled 4 sequences of length 50 for each prefix:

```
>>> import numpy as np
>>> x = np.load("/tmp/fivo/samples.npz")
>>> x[()]['prefixes'].shape
(25, 4, 88)
>>> x[()]['samples'].shape
(50, 4, 4, 88)
```

### Training on TIMIT
The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.

This is very similar to training on pianoroll datasets, with just a few flags switched.

```
python run_fivo.py \
  --mode=train \
  --logdir=/tmp/fivo \
  --model=vrnn \
  ...
  --dataset_path="$TIMIT_DIR/train" \
  --dataset_type="speech"
```

Evaluation and sampling are similar.

### Tests

This codebase comes with a number of tests to verify correctness, runnable via `bin/run_tests.sh`. The tests are also useful to look at for examples of how to use the code.