Commit 46dea62

Updating fivo codebase
Author: Dieterich Lawson
1 parent 5856878, commit 46dea62

45 files changed, with 19724 additions and 935 deletions.

research/fivo/.gitattributes

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+*.pkl binary
+*.tfrecord binary

research/fivo/.gitignore

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+.static_storage/
+.media/
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/

research/fivo/README.md

Lines changed: 88 additions & 32 deletions
@@ -8,28 +8,42 @@ Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohamma

 This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).

-Additionally it contains an implementation of the variational recurrent neural network (VRNN), a sequential latent variable model that can be trained using these three objectives. This repo provides code for training a VRNN to do sequence modeling of pianoroll and speech data.
+Additionally it contains several sequential latent variable model implementations:
+
+* Variational recurrent neural network (VRNN)
+* Stochastic recurrent neural network (SRNN)
+* Gaussian hidden Markov model with linear conditionals (GHMM)
+
+The VRNN and SRNN can be trained for sequence modeling of pianoroll and speech data. The GHMM is trainable on a synthetic dataset, useful as a simple example of an analytically tractable model.

 #### Directory Structure
 The important parts of the code are organized as follows.

 ```
-fivo.py # main script, contains flag definitions
-runners.py # graph construction code for training and evaluation
-bounds.py # code for computing each bound
-data
-├── datasets.py # readers for pianoroll and speech datasets
-├── calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
-└── create_timit_dataset.py # preprocesses the TIMIT dataset
-models
-└── vrnn.py # variational RNN implementation
+run_fivo.py # main script, contains flag definitions
+fivo
+├─smc.py # a sequential Monte Carlo implementation
+├─bounds.py # code for computing each bound, uses smc.py
+├─runners.py # code for VRNN and SRNN training and evaluation
+├─ghmm_runners.py # code for GHMM training and evaluation
+├─data
+| ├─datasets.py # readers for pianoroll and speech datasets
+| ├─calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
+| └─create_timit_dataset.py # preprocesses the TIMIT dataset
+└─models
+├─base.py # base classes used in other models
+├─vrnn.py # VRNN implementation
+├─srnn.py # SRNN implementation
+└─ghmm.py # Gaussian hidden Markov model (GHMM) implementation
 bin
-├── run_train.sh # an example script that runs training
-├── run_eval.sh # an example script that runs evaluation
-└── download_pianorolls.sh # a script that downloads the pianoroll files
+├─run_train.sh # an example script that runs training
+├─run_eval.sh # an example script that runs evaluation
+├─run_sample.sh # an example script that runs sampling
+├─run_tests.sh # a script that runs all tests
+└─download_pianorolls.sh # a script that downloads pianoroll files
 ```

-### Training on Pianorolls
+### Pianorolls

 Requirements before we start:

@@ -60,9 +74,9 @@ python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/jsb.pkl

 #### Training

-Now we can train a model. Here is a standard training run, taken from `bin/run_train.sh`:
+Now we can train a model. Here is the command for a standard training run, taken from `bin/run_train.sh`:
 ```
-python fivo.py \
+python run_fivo.py \
 --mode=train \
 --logdir=/tmp/fivo \
 --model=vrnn \
@@ -75,26 +89,24 @@ python fivo.py \
 --dataset_type="pianoroll"
 ```

-You should see output that looks something like this (with a lot of extra logging cruft):
+You should see output that looks something like this (with extra logging cruft):

 ```
-Step 1, fivo bound per timestep: -11.801050
-global_step/sec: 9.89825
-Step 101, fivo bound per timestep: -11.198309
-global_step/sec: 9.55475
-Step 201, fivo bound per timestep: -11.287262
-global_step/sec: 9.68146
-step 301, fivo bound per timestep: -11.316490
-global_step/sec: 9.94295
-Step 401, fivo bound per timestep: -11.151743
+Saving checkpoints for 0 into /tmp/fivo/model.ckpt.
+Step 1, fivo bound per timestep: -11.322491
+global_step/sec: 7.49971
+Step 101, fivo bound per timestep: -11.399275
+global_step/sec: 8.04498
+Step 201, fivo bound per timestep: -11.174991
+global_step/sec: 8.03989
+Step 301, fivo bound per timestep: -11.073008
 ```
-You will also see lines saying `Out of range: exceptions.StopIteration: Iteration finished`. This is not an error and is fine.
 #### Evaluation

 You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example here is a command, taken from `bin/run_eval.sh`, that will evaluate a JSB model on the test set:

 ```
-python fivo.py \
+python run_fivo.py \
 --mode=eval \
 --split=test \
 --alsologtostderr \
@@ -108,12 +120,52 @@ python fivo.py \

 You should see output like this:
 ```
-Model restored from step 1, evaluating.
-test elbo ll/t: -12.299635, iwae ll/t: -12.128336 fivo ll/t: -11.656939
-test elbo ll/seq: -754.750312, iwae ll/seq: -744.238773 fivo ll/seq: -715.3121490
+Restoring parameters from /tmp/fivo/model.ckpt-0
+Model restored from step 0, evaluating.
+test elbo ll/t: -12.198834, iwae ll/t: -11.981187 fivo ll/t: -11.579776
+test elbo ll/seq: -748.564789, iwae ll/seq: -735.209206 fivo ll/seq: -710.577141
 ```
 The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.

+#### Sampling
+
+You can also sample from trained models. The `sample` mode loads a model checkpoint, conditions the model on a prefix of a randomly chosen datapoint, samples a sequence of outputs from the conditioned model, and writes out the samples and prefix to a `.npz` file in `logdir`. For example here is a command that samples from a model trained on JSB, taken from `bin/run_sample.sh`:
+```
+python run_fivo.py \
+--mode=sample \
+--alsologtostderr \
+--logdir="/tmp/fivo" \
+--model=vrnn \
+--bound=fivo \
+--batch_size=4 \
+--num_samples=4 \
+--split=test \
+--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
+--dataset_type="pianoroll" \
+--prefix_length=25 \
+--sample_length=50
+```
+
+Here `num_samples` denotes the number of samples used when conditioning the model as well as the number of trajectories to sample for each prefix.
+
+You should see very little output.
+```
+Restoring parameters from /tmp/fivo/model.ckpt-0
+Running local_init_op.
+Done running local_init_op.
+```
+
+Loading the samples with `np.load` confirms that we conditioned the model on 4
+prefixes of length 25 and sampled 4 sequences of length 50 for each prefix.
+```
+>>> import numpy as np
+>>> x = np.load("/tmp/fivo/samples.npz")
+>>> x[()]['prefixes'].shape
+(25, 4, 88)
+>>> x[()]['samples'].shape
+(50, 4, 4, 88)
+```
+
 ### Training on TIMIT

 The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.
@@ -137,7 +189,7 @@ train mean: 0.006060 train std: 548.136169
 #### Training on TIMIT
 This is very similar to training on pianoroll datasets, with just a few flags switched.
 ```
-python fivo.py \
+python run_fivo.py \
 --mode=train \
 --logdir=/tmp/fivo \
 --model=vrnn \
@@ -149,6 +201,10 @@ python fivo.py \
 --dataset_path="$TIMIT_DIR/train" \
 --dataset_type="speech"
 ```
+Evaluation and sampling are similar.
+
+### Tests
+This codebase comes with a number of tests to verify correctness, runnable via `bin/run_tests.sh`. The tests are also useful to look at for examples of how to use the code.

 ### Contact
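
Reviewer note on the new Sampling section: the README reports `prefixes` with shape `(25, 4, 88)` and `samples` with shape `(50, 4, 4, 88)`. The sketch below is one way to see how those shapes could follow from the flags in `bin/run_sample.sh`. It is a minimal illustration under several assumptions, not the repository's code: the arrays are zero-filled stand-ins written to a temporary file, they are stored as plain named arrays (the README indexes the loaded object with `x[()]`, so the real file may be laid out differently), and the order of the two middle axes of `samples` is a guess, since `batch_size` and `num_samples` are both 4 and the printed shape cannot tell them apart. The helper names `make_fake_samples` and `load_shapes` are hypothetical.

```python
import os
import tempfile

import numpy as np

# Values taken from the flags in bin/run_sample.sh.
PREFIX_LENGTH = 25
SAMPLE_LENGTH = 50
BATCH_SIZE = 4
NUM_SAMPLES = 4
NUM_FEATURES = 88  # last dimension of the shapes printed in the README


def make_fake_samples(path):
    """Writes zero-filled stand-ins for the arrays described in the README."""
    prefixes = np.zeros(
        (PREFIX_LENGTH, BATCH_SIZE, NUM_FEATURES), dtype=np.float32)
    # Assumed axis order: [time, batch, sample, feature].
    samples = np.zeros(
        (SAMPLE_LENGTH, BATCH_SIZE, NUM_SAMPLES, NUM_FEATURES),
        dtype=np.float32)
    np.savez(path, prefixes=prefixes, samples=samples)


def load_shapes(path):
    """Returns {array name: shape} for every array stored in an .npz file."""
    with np.load(path) as archive:
        return {name: archive[name].shape for name in archive.files}


path = os.path.join(tempfile.mkdtemp(), "fake_samples.npz")
make_fake_samples(path)
shapes = load_shapes(path)
print(shapes)
```

Running the sampling command with distinct values for `--batch_size` and `--num_samples` would be a simple way to confirm the actual axis order.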

research/fivo/bin/run_eval.sh

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@

 PIANOROLL_DIR=$HOME/pianorolls

-python fivo.py \
+python run_fivo.py \
 --mode=eval \
 --logdir=/tmp/fivo \
 --model=vrnn \

research/fivo/bin/run_sample.sh

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# An example of sampling from the model.
+
+PIANOROLL_DIR=$HOME/pianorolls
+
+python run_fivo.py \
+--mode=sample \
+--alsologtostderr \
+--logdir="/tmp/fivo" \
+--model=vrnn \
+--bound=fivo \
+--batch_size=4 \
+--num_samples=4 \
+--split=test \
+--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
+--dataset_type="pianoroll" \
+--prefix_length=25 \
+--sample_length=50

research/fivo/bin/run_tests.sh

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#!/bin/bash
+# Copyright 2018 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+python -m fivo.smc_test && \
+python -m fivo.bounds_test && \
+python -m fivo.nested_utils_test && \
+python -m fivo.data.datasets_test && \
+python -m fivo.models.ghmm_test && \
+python -m fivo.models.vrnn_test && \
+python -m fivo.models.srnn_test && \
+python -m fivo.ghmm_runners_test && \
+python -m fivo.runners_test
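
Reviewer note: the `&&` chain in `bin/run_tests.sh` stops at the first test module that exits with a non-zero status. The sketch below shows the same stop-at-first-failure behaviour in Python. It is only an illustration, not part of the commit: the helper names `run_until_failure` and `test_commands` are hypothetical, and the demonstration uses stand-in commands so that it runs without the `fivo` package installed.

```python
import subprocess
import sys

# Module names copied from bin/run_tests.sh.
TEST_MODULES = [
    "fivo.smc_test",
    "fivo.bounds_test",
    "fivo.nested_utils_test",
    "fivo.data.datasets_test",
    "fivo.models.ghmm_test",
    "fivo.models.vrnn_test",
    "fivo.models.srnn_test",
    "fivo.ghmm_runners_test",
    "fivo.runners_test",
]


def run_until_failure(commands):
    """Runs commands in order and stops at the first non-zero exit code.

    Returns (number of commands run, exit code of the last command run).
    """
    ran = 0
    code = 0
    for cmd in commands:
        ran += 1
        code = subprocess.call(cmd)
        if code != 0:
            break
    return ran, code


def test_commands(modules=TEST_MODULES):
    """Builds one `python -m <module>` command per test module."""
    return [[sys.executable, "-m", module] for module in modules]


# Stand-in commands: the second one fails, so the third never runs.
demo = [
    [sys.executable, "-c", "pass"],
    [sys.executable, "-c", "import sys; sys.exit(3)"],
    [sys.executable, "-c", "pass"],
]
result = run_until_failure(demo)
print(result)
```

From the `research/fivo` directory, `run_until_failure(test_commands())` would mirror the shell script, assuming the test modules are importable there.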

research/fivo/bin/run_train.sh

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@

 PIANOROLL_DIR=$HOME/pianorolls

-python fivo.py \
+python run_fivo.py \
 --mode=train \
 --logdir=/tmp/fivo \
 --model=vrnn \
