forked from IshmaelBelghazi/ALI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample
More file actions
executable file
·46 lines (37 loc) · 1.69 KB
/
sample
File metadata and controls
executable file
·46 lines (37 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#!/usr/bin/env python
import argparse
import theano
from blocks.serialization import load
from matplotlib import cm, pyplot
from mpl_toolkits.axes_grid1 import ImageGrid
def main(main_loop, nrows, ncols, save_path=None):
    """Draw an ``nrows`` x ``ncols`` grid of samples from a trained ALI model.

    Parameters
    ----------
    main_loop : object
        Presumably a Blocks main loop (the script loads it with
        ``blocks.serialization.load``). Its ``model.top_bricks`` must
        contain exactly one brick: the ALI brick to sample from.
    nrows, ncols : int
        Shape of the sample grid; ``nrows * ncols`` samples are drawn.
    save_path : str, optional
        If given, the figure is written to this path (transparent
        background, tight bounding box) and then closed. Otherwise it is
        shown interactively with ``pyplot.show()``.
    """
    # Single-element unpacking: fails loudly unless there is exactly one
    # top-level brick.
    ali, = main_loop.model.top_bricks
    # Per-sample shape of the latent z fed to the generator (it is the
    # encoder's output dimension, so the original name "input_shape" was
    # misleading).
    z_shape = ali.encoder.get_dim('output')
    z = ali.theano_rng.normal(size=(nrows * ncols,) + z_shape)
    x = ali.sample(z)
    # Compile the input-free sampling graph and evaluate it once.
    samples = theano.function([], x)()

    figure = pyplot.figure()
    grid = ImageGrid(figure, 111, (nrows, ncols), axes_pad=0.1)
    for sample, axis in zip(samples, grid):
        # transpose(1, 2, 0) moves the leading (channel) axis last, as
        # imshow expects; squeeze drops a singleton channel so that
        # single-channel images are drawn with the grey colormap.
        axis.imshow(sample.transpose(1, 2, 0).squeeze(),
                    cmap=cm.Greys_r, interpolation='nearest')
        # Hide ticks, tick labels and the frame. Blanking tick labels
        # separately is unnecessary (and warns on matplotlib >= 3.3).
        axis.axis('off')

    if save_path is None:
        pyplot.show()
    else:
        figure.savefig(save_path, transparent=True, bbox_inches='tight')
        # Release the figure once it has been written to disk.
        pyplot.close(figure)
if __name__ == "__main__":
    def _positive_int(text):
        """Parse a strictly positive integer (argparse ``type=`` helper)."""
        value = int(text)
        if value <= 0:
            raise argparse.ArgumentTypeError(
                "expected a positive integer, got {}".format(text))
        return value

    parser = argparse.ArgumentParser(description="Plot samples.")
    parser.add_argument("main_loop_path", type=str,
                        help="path to the pickled main loop.")
    # A grid with zero or negative rows/columns cannot be drawn, so reject
    # such values here rather than failing deep inside plotting/Theano.
    parser.add_argument("--nrows", type=_positive_int, default=10,
                        help="number of rows of samples to display.")
    parser.add_argument("--ncols", type=_positive_int, default=10,
                        help="number of columns of samples to display.")
    parser.add_argument("--save-path", type=str, default=None,
                        help="where to save the generated samples.")
    args = parser.parse_args()
    # NOTE(security): the main loop file is a pickle, and unpickling can
    # execute arbitrary code. Only load files from trusted sources.
    with open(args.main_loop_path, 'rb') as src:
        main(load(src), args.nrows, args.ncols, args.save_path)