Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit babbb38

Browse files
author
Victor Bittorf
committed
Adding flag to set random seeds.
1 parent 2661eb9 commit babbb38

File tree

6 files changed

+70
-1
lines changed

6 files changed

+70
-1
lines changed

official/mnist/mnist.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,9 @@ def main(argv):
182182

183183
model_function = model_fn
184184

185+
if flags.seed is not None:
186+
model_helpers.set_random_seed(flags.seed)
187+
185188
if flags.multi_gpu:
186189
validate_batch_size_for_multi_gpu(flags.batch_size)
187190

official/resnet/imagenet_main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from official.resnet import imagenet_preprocessing
2727
from official.resnet import resnet_model
2828
from official.resnet import resnet_run_loop
29+
from official.utils.misc import model_helpers
2930

3031
_DEFAULT_IMAGE_SIZE = 224
3132
_NUM_CHANNELS = 3
@@ -315,6 +316,9 @@ def main(argv):
315316

316317
flags = parser.parse_args(args=argv[1:])
317318

319+
if flags.seed is not None:
320+
model_helpers.set_random_seed(flags.seed)
321+
318322
input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
319323

320324
resnet_run_loop.resnet_main(

official/utils/arg_parsers/parsers.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,13 @@ class BaseParser(argparse.ArgumentParser):
104104
batch_size: Create a flag to specify the batch size.
105105
multi_gpu: Create a flag to allow the use of all available GPUs.
106106
hooks: Create a flag to specify hooks for logging.
107+
seed: Create a flag to set random seeds.
107108
"""
108109

109110
def __init__(self, add_help=False, data_dir=True, model_dir=True,
110111
train_epochs=True, epochs_between_evals=True,
111112
stop_threshold=True, batch_size=True, multi_gpu=True,
112-
hooks=True):
113+
hooks=True, seed=True):
113114
super(BaseParser, self).__init__(add_help=add_help)
114115

115116
if data_dir:
@@ -176,6 +177,16 @@ def __init__(self, add_help=False, data_dir=True, model_dir=True,
176177
metavar="<HK>"
177178
)
178179

180+
if seed:
181+
self.add_argument(
182+
"--seed", "-s", nargs="+", type=int, default=None,
183+
help="[default: %(default)s] An integer to seed random number"
184+
"generators. If unset, RNGs choose their own seeds resulting "
185+
"in each run having a different seed.",
186+
metavar="<SEED>"
187+
)
188+
189+
179190

180191
class PerformanceParser(argparse.ArgumentParser):
181192
"""Default parser for specifying performance tuning arguments.

official/utils/misc/model_helpers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from __future__ import print_function
2020

2121
import numbers
22+
import random
2223

2324
import tensorflow as tf
2425

@@ -53,3 +54,25 @@ def past_stop_threshold(stop_threshold, eval_metric):
5354
return True
5455

5556
return False
57+
58+
59+
def set_random_seed(seed):
60+
"""Sets the random seeds for available RNGs.
61+
This seeds RNGs for python's random and for Tensorflow. The intended
62+
use case is for this to be called exactly once at the start of execution
63+
to improve stability and reproducability between runs.
64+
65+
Successive calls to re-seed will not behave as expected. This should
66+
be called at most once.
67+
68+
Args:
69+
seed: integer, a seed which will be passed to the RNGs.
70+
71+
Raises:
72+
ValueError: if the seed is not an integer or if deemed unsuitable for
73+
seeding a the RNGs.
74+
"""
75+
if not isinstance(seed, int):
76+
raise ValueError("Random seed is not an integer: {}".format(seed))
77+
random.seed(seed)
78+
tf.set_random_seed(seed)

official/utils/misc/model_helpers_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import random
22+
2123
import tensorflow as tf # pylint: disable=g-bad-import-order
2224

2325
from official.utils.misc import model_helpers
@@ -64,6 +66,29 @@ def test_past_stop_threshold_not_number(self):
6466
with self.assertRaises(ValueError):
6567
model_helpers.past_stop_threshold(tf.constant(4), None)
6668

69+
def test_random_seed(self):
70+
"""It is unclear if this test is a good idea or stable.
71+
If tests are run in parallel, this could be flakey."""
72+
model_helpers.set_random_seed(42)
73+
expected_py_random = [int(random.random() * 1000) for i in range(10)]
74+
tf_random = []
75+
with tf.Session() as sess:
76+
for i in range(10):
77+
a = tf.random_uniform([1])
78+
tf_random.append(int(sess.run(a)[0] * 1000))
79+
80+
model_helpers.set_random_seed(42)
81+
py_random = [int(random.random() * 1000) for i in range(10)]
82+
83+
# Instead of concerning ourselves with the particular results, we simply
84+
# want to ensure that the results are reproducible. So, we seed, read,
85+
# re-seed, re-read.
86+
self.assertAllEqual(expected_py_random, py_random)
87+
88+
# TF does not accept being re-seeded.
89+
expected_tf_random = [637, 689, 961, 969, 321, 390, 919, 681, 112, 187]
90+
self.assertAllEqual(expected_tf_random, tf_random)
91+
6792

6893
if __name__ == "__main__":
6994
tf.test.main()

official/wide_deep/wide_deep.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ def main(argv):
179179
parser = WideDeepArgParser()
180180
flags = parser.parse_args(args=argv[1:])
181181

182+
if flags.seed is not None:
183+
model_helpers.set_random_seed(flags.seed)
184+
182185
# Clean up the model directory if present
183186
shutil.rmtree(flags.model_dir, ignore_errors=True)
184187
model = build_estimator(flags.model_dir, flags.model_type)

0 commit comments

Comments
 (0)