Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0783f1c

Browse files
authored
Merge pull request tensorflow#5227 from mikaelsouza/adding-fuse-batch-norm-parameter
Added fused_batch_norm parameter
2 parents 23b5b42 + 84577d6 commit 0783f1c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

research/gan/cifar/networks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def generator(noise, is_training=True):
4646
Returns:
4747
A single Tensor with a batch of generated CIFAR images.
4848
"""
49-
images, _ = dcgan.generator(noise, is_training=is_training)
49+
images, _ = dcgan.generator(noise, is_training=is_training, fused_batch_norm=True)
5050

5151
# Make sure output lies between [-1, 1].
5252
return tf.tanh(images)
@@ -68,7 +68,7 @@ def conditional_generator(inputs, is_training=True):
6868
noise, one_hot_labels = inputs
6969
noise = tfgan.features.condition_tensor_from_onehot(noise, one_hot_labels)
7070

71-
images, _ = dcgan.generator(noise, is_training=is_training)
71+
images, _ = dcgan.generator(noise, is_training=is_training, fused_batch_norm=True)
7272

7373
# Make sure output lies between [-1, 1].
7474
return tf.tanh(images)
@@ -94,7 +94,7 @@ def discriminator(img, unused_conditioning, is_training=True):
9494
images are real. The output can lie in [-inf, inf], with positive values
9595
indicating high confidence that the images are real.
9696
"""
97-
logits, _ = dcgan.discriminator(img, is_training=is_training)
97+
logits, _ = dcgan.discriminator(img, is_training=is_training, fused_batch_norm=True)
9898
return logits
9999

100100

@@ -118,7 +118,7 @@ def conditional_discriminator(img, conditioning, is_training=True):
118118
images are real. The output can lie in [-inf, inf], with positive values
119119
indicating high confidence that the images are real.
120120
"""
121-
logits, end_points = dcgan.discriminator(img, is_training=is_training)
121+
logits, end_points = dcgan.discriminator(img, is_training=is_training, fused_batch_norm=True)
122122

123123
# Condition the last convolution layer.
124124
_, one_hot_labels = conditioning

0 commit comments

Comments
 (0)