forked from IshmaelBelghazi/ALI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_spiral_plots
More file actions
executable file
·92 lines (79 loc) · 3.52 KB
/
generate_spiral_plots
File metadata and controls
executable file
·92 lines (79 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python
import argparse
import theano
from blocks.bricks.interfaces import Random
from blocks.serialization import load
from matplotlib import pyplot, rc
from theano import tensor
from ali.streams import create_spiral_data_streams
# Global matplotlib styling: serif text using Computer Modern Roman, with all
# text rendered through LaTeX (the axis labels below use math mode).
rc('font', family='serif', serif='Computer Modern Roman')
rc('text', usetex=True)
def main(ali_main_loop, gan_main_loop, save_path=None):
    """Draw a 2x3 grid comparing ALI and GAN on the spiral data.

    The top row shows data space (``x``): ALI reconstructions, ALI samples
    and GAN samples, each drawn in blue over the data batch in black.  The
    bottom row shows latent space (``z``): the ALI encoding of the batch,
    followed by the prior samples (plotted twice, under the two sampling
    panels).

    Parameters
    ----------
    ali_main_loop : object
        Loaded ALI main loop.  Its ``model.top_bricks`` must contain exactly
        one brick, which is used through ``encoder.apply``,
        ``decoder.apply`` and ``decoder.input_dim``.
    gan_main_loop : object
        Loaded GAN main loop.  Its ``model.top_bricks`` must contain exactly
        one brick, which is used through ``decoder.apply``.
    save_path : str, optional
        Where to write the figure.  When ``None`` (the default) the figure
        is shown interactively instead.

    """
    # Single-element unpacking fails loudly if a model has != 1 top brick.
    (ali,) = ali_main_loop.model.top_bricks
    (gan,) = gan_main_loop.model.top_bricks
    random_brick = Random()

    # Only the third returned stream is used; its first batch is the data.
    _unused_a, _unused_b, data_stream = create_spiral_data_streams(1000, 1000)
    first_batch = next(data_stream.get_epoch_iterator())
    data_var = tensor.as_tensor_variable(first_batch[0])

    # Prior samples: one latent vector per data point.
    latent_dim = ali.decoder.input_dim
    prior_var = random_brick.theano_rng.normal(
        size=(data_var.shape[0], latent_dim), dtype=data_var.dtype)

    # The encoder output is split column-wise into a mean and a log standard
    # deviation; the encoding is then drawn as mu + exp(log_sigma) * noise.
    encoder_out = ali.encoder.apply(data_var)
    mu = encoder_out[:, :latent_dim]
    log_sigma = encoder_out[:, latent_dim:]
    noise = random_brick.theano_rng.normal(size=mu.shape, dtype=mu.dtype)
    posterior_var = mu + tensor.exp(log_sigma) * noise

    ali_sample_var = ali.decoder.apply(prior_var)
    ali_recon_var = ali.decoder.apply(posterior_var)
    gan_sample_var = gan.decoder.apply(prior_var)

    # Evaluate everything in one compiled call so that all panels share the
    # same random draws.
    outputs = [data_var, ali_sample_var, gan_sample_var, ali_recon_var,
               prior_var, posterior_var]
    evaluate = theano.function([], outputs)
    x, x_tilde, gan_x_tilde, x_hat, z, z_hat = evaluate()

    _, grid = pyplot.subplots(nrows=2, ncols=3)

    # Per-row axis set-up: (axes in the row, half-width of the view, labels).
    row_settings = (
        (grid[0], 2, '$x_1$', '$x_2$'),
        (grid[1], 4, '$z_1$', '$z_2$'),
    )
    for row, bound, x_label, y_label in row_settings:
        for ax in row:
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlim([-bound, bound])
            ax.set_ylim([-bound, bound])
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)

    # Each panel: (axis, title, layers); layers are drawn in the listed
    # order, so the black data points sit underneath the blue overlay.
    panels = (
        (grid[0, 0], 'ALI reconstructions', ((x, 'black'), (x_hat, 'blue'))),
        (grid[0, 1], 'ALI samples', ((x, 'black'), (x_tilde, 'blue'))),
        (grid[0, 2], 'GAN samples', ((x, 'black'), (gan_x_tilde, 'blue'))),
        (grid[1, 0], 'ALI encoding', ((z_hat, 'blue'),)),
        (grid[1, 1], 'Prior', ((z, 'blue'),)),
        (grid[1, 2], 'Prior', ((z, 'blue'),)),
    )
    for ax, title, layers in panels:
        ax.set_title(title)
        for points, colour in layers:
            ax.scatter(points[:, 0], points[:, 1], marker='o', c=colour,
                       alpha=0.3)

    pyplot.tight_layout()
    if save_path is not None:
        pyplot.savefig(save_path, transparent=True, bbox_inches='tight')
        return
    pyplot.show()
if __name__ == "__main__":
    # Command-line entry point: two positional paths to the pickled main
    # loops, plus an optional output path for the figure.
    arg_parser = argparse.ArgumentParser(description="Plot Spiral samples.")
    arg_parser.add_argument("ali_main_loop_path", type=str,
                            help="path to the pickled ALI main loop.")
    arg_parser.add_argument("gan_main_loop_path", type=str,
                            help="path to the pickled GAN main loop.")
    arg_parser.add_argument("--save-path", type=str, default=None,
                            help="where to save the generated samples.")
    options = arg_parser.parse_args()

    # Both files stay open while the main loops are loaded and plotted.
    with open(options.ali_main_loop_path, 'rb') as ali_file, \
            open(options.gan_main_loop_path, 'rb') as gan_file:
        ali_main_loop = load(ali_file)
        gan_main_loop = load(gan_file)
        main(ali_main_loop, gan_main_loop, options.save_path)