
Commit 91b2deb

Author: Taylor Robie

Make flagfile sharing robust to distributed filesystems and multi-worker setups. (tensorflow#5521)

* move flagfile into the cache_dir
* remove duplicate code
* delint

1 parent 0c5c3a7 · commit 91b2deb

File tree

3 files changed: +34, -18 lines


official/recommendation/data_async_generation.py

Lines changed: 15 additions & 9 deletions

@@ -440,10 +440,8 @@ def remove_alive_file():
   gc.collect()
 
 
-def _parse_flagfile():
+def _parse_flagfile(flagfile):
   """Fill flags with flagfile written by the main process."""
-  flagfile = os.path.join(flags.FLAGS.data_dir,
-                          rconst.FLAGFILE)
   tf.logging.info("Waiting for flagfile to appear at {}..."
                   .format(flagfile))
   start_time = time.time()
@@ -455,18 +453,26 @@ def _parse_flagfile():
       sys.exit()
     time.sleep(1)
   tf.logging.info("flagfile found.")
-  # This overrides FLAGS with flags from flagfile.
-  flags.FLAGS([__file__, "--flagfile", flagfile])
+
+  # `flags` module opens `flagfile` with `open`, which does not work on
+  # google cloud storage etc.
+  _, flagfile_temp = tempfile.mkstemp()
+  tf.gfile.Copy(flagfile, flagfile_temp, overwrite=True)
+
+  flags.FLAGS([__file__, "--flagfile", flagfile_temp])
+  tf.gfile.Remove(flagfile_temp)
 
 
 def main(_):
   global _log_file
-  _parse_flagfile()
-
-  redirect_logs = flags.FLAGS.redirect_logs
   cache_paths = rconst.Paths(
       data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
 
+  flagfile = os.path.join(cache_paths.cache_root, rconst.FLAGFILE)
+  _parse_flagfile(flagfile)
+
+  redirect_logs = flags.FLAGS.redirect_logs
+
   log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
   log_path = os.path.join(cache_paths.data_dir, log_file_name)
   if log_path.startswith("gs://") and redirect_logs:
@@ -518,7 +524,6 @@ def define_flags():
                       help="Size of the negative generation worker pool.")
   flags.DEFINE_string(name="data_dir", default=None,
                       help="The data root. (used to construct cache paths.)")
-  flags.mark_flags_as_required(["data_dir"])
   flags.DEFINE_string(name="cache_id", default=None,
                       help="The cache_id generated in the main process.")
   flags.DEFINE_integer(name="num_readers", default=4,
@@ -554,6 +559,7 @@ def define_flags():
                       help="NumPy random seed to set at startup. If not "
                            "specified, a seed will not be set.")
 
+  flags.mark_flags_as_required(["data_dir", "cache_id"])
 
 if __name__ == "__main__":
   define_flags()
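
The core change in this file is a copy-then-parse workaround: absl's `flags` module reads a flagfile with the builtin `open`, which cannot read distributed paths such as `gs://`, so the worker first copies the flagfile to local disk through `tf.gfile`. Below is a minimal, self-contained sketch of that pattern, assuming TF 1.x's `tf.gfile` API; `parse_remote_flagfile` is a hypothetical helper name, and the `try/finally` cleanup is an addition for safety, not code from this commit.

# A minimal sketch of the copy-then-parse workaround above. Assumes
# TF 1.x (tf.gfile) and absl; parse_remote_flagfile is a hypothetical
# helper name, and the try/finally cleanup is an addition for safety.
import os
import sys
import tempfile

import tensorflow as tf
from absl import flags


def parse_remote_flagfile(flagfile):
  """Copy `flagfile` to local disk so absl can `open` it, then parse."""
  fd, local_copy = tempfile.mkstemp()
  os.close(fd)  # only the path is needed; tf.gfile.Copy writes the bytes
  tf.gfile.Copy(flagfile, local_copy, overwrite=True)
  try:
    # Overrides FLAGS with the values recorded in the flagfile.
    flags.FLAGS([sys.argv[0], "--flagfile", local_copy])
  finally:
    tf.gfile.Remove(local_copy)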

official/recommendation/data_preprocessing.py

Lines changed: 11 additions & 8 deletions

@@ -357,8 +357,8 @@ def generate_train_eval_data(df, approx_num_shards, num_items, cache_paths,
 
 
 def construct_cache(dataset, data_dir, num_data_readers, match_mlperf,
-                    deterministic):
-  # type: (str, str, int, bool) -> NCFDataset
+                    deterministic, cache_id=None):
+  # type: (str, str, int, bool, typing.Optional[int]) -> NCFDataset
   """Load and digest data CSV into a usable form.
 
   Args:
@@ -371,7 +371,7 @@ def construct_cache(dataset, data_dir, num_data_readers, match_mlperf,
     deterministic: Try to enforce repeatable behavior, even at the cost of
       performance.
   """
-  cache_paths = rconst.Paths(data_dir=data_dir)
+  cache_paths = rconst.Paths(data_dir=data_dir, cache_id=cache_id)
   num_data_readers = (num_data_readers or int(multiprocessing.cpu_count() / 2)
                       or 1)
   approx_num_shards = int(movielens.NUM_RATINGS[dataset]
@@ -436,15 +436,16 @@ def _shutdown(proc):
 def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
                          num_data_readers=None, num_neg=4, epochs_per_cycle=1,
                          match_mlperf=False, deterministic=False,
-                         use_subprocess=True):
+                         use_subprocess=True, cache_id=None):
   # type: (...) -> (NCFDataset, typing.Callable)
   """Preprocess data and start negative generation subprocess."""
 
   tf.logging.info("Beginning data preprocessing.")
   ncf_dataset = construct_cache(dataset=dataset, data_dir=data_dir,
                                 num_data_readers=num_data_readers,
                                 match_mlperf=match_mlperf,
-                                deterministic=deterministic)
+                                deterministic=deterministic,
+                                cache_id=cache_id)
   # By limiting the number of workers we guarantee that the worker
   # pool underlying the training generation doesn't starve other processes.
   num_workers = int(multiprocessing.cpu_count() * 0.75) or 1
@@ -473,13 +474,14 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
   # We write to a temp file then atomically rename it to the final file,
   # because writing directly to the final file can cause the data generation
   # async process to read a partially written JSON file.
-  flagfile_temp = os.path.join(data_dir, rconst.FLAGFILE_TEMP)
+  flagfile_temp = os.path.join(ncf_dataset.cache_paths.cache_root,
+                               rconst.FLAGFILE_TEMP)
   tf.logging.info("Preparing flagfile for async data generation in {} ..."
                   .format(flagfile_temp))
   with tf.gfile.Open(flagfile_temp, "w") as f:
     for k, v in six.iteritems(flags_):
       f.write("--{}={}\n".format(k, v))
-  flagfile = os.path.join(data_dir, rconst.FLAGFILE)
+  flagfile = os.path.join(ncf_dataset.cache_paths.cache_root, rconst.FLAGFILE)
   tf.gfile.Rename(flagfile_temp, flagfile)
   tf.logging.info(
       "Wrote flagfile for async data generation in {}."
@@ -493,7 +495,8 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
   # contention with the main training process.
   subproc_env["CUDA_VISIBLE_DEVICES"] = ""
   subproc_args = popen_helper.INVOCATION + [
-      "--data_dir", data_dir]
+      "--data_dir", data_dir,
+      "--cache_id", str(ncf_dataset.cache_paths.cache_id)]
   tf.logging.info(
       "Generation subprocess command: {}".format(" ".join(subproc_args)))
   proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
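
The write-to-temp-then-rename idiom above is what keeps the handoff safe: the async worker polls only for the final flagfile name, and per the diff's own comment the rename is atomic, so the worker can never observe a half-written file. A hedged sketch of the idiom, with `publish_flagfile` as an illustrative name rather than a function from this commit:

# Sketch of the write-then-atomic-rename handoff used above. Assumes
# tf.gfile (TF 1.x); publish_flagfile is a hypothetical helper name and
# flag_values is a plain dict mapping flag names to values.
import six
import tensorflow as tf


def publish_flagfile(flag_values, temp_path, final_path):
  """Write flags to temp_path, then move them into place in one step."""
  with tf.gfile.Open(temp_path, "w") as f:
    for k, v in six.iteritems(flag_values):
      f.write("--{}={}\n".format(k, v))
  # The async worker polls only for final_path, so it never observes a
  # partially written file.
  tf.gfile.Rename(temp_path, final_path)

Pairing this with the worker's polling loop in data_async_generation.py gives a simple cross-process handoff without any explicit locking.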

official/recommendation/ncf_main.py

Lines changed: 8 additions & 1 deletion

@@ -152,7 +152,8 @@ def run_ncf(_):
       epochs_per_cycle=FLAGS.epochs_between_evals,
       match_mlperf=FLAGS.ml_perf,
       deterministic=FLAGS.seed is not None,
-      use_subprocess=FLAGS.use_subprocess)
+      use_subprocess=FLAGS.use_subprocess,
+      cache_id=FLAGS.cache_id)
   num_users = ncf_dataset.num_users
   num_items = ncf_dataset.num_items
   approx_train_steps = int(ncf_dataset.num_train_positives
@@ -387,6 +388,12 @@ def eval_size_check(eval_batch_size):
       "subprocess. If set to False, ncf_main.py will assume the async data "
       "generation process has already been started by the user."))
 
+  flags.DEFINE_integer(name="cache_id", default=None, help=flags_core.help_wrap(
+      "Use a specified cache_id rather than using a timestamp. This is only "
+      "needed to synchronize across multiple workers. Generally this flag will "
+      "not need to be set."
+  ))
+
 
 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
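
The new `--cache_id` flag exists so that several workers (and the async generation subprocess) can agree on one cache root. The flag's help text implies that when `cache_id` is unset, `rconst.Paths` falls back to a timestamp; the sketch below illustrates that reading, with the class body and the "cache_<id>" directory layout being assumptions, not code from this commit:

# Hedged sketch of cache_id-aware path construction. rconst.Paths itself
# is not shown in this diff; the timestamp default and the "cache_<id>"
# directory layout below are assumptions based on the flag's help text.
import os
import time


class Paths(object):
  def __init__(self, data_dir, cache_id=None):
    # Every process that passes the same cache_id resolves the same
    # cache_root, and therefore the same flagfile location.
    self.data_dir = data_dir
    self.cache_id = cache_id or int(time.time())
    self.cache_root = os.path.join(data_dir, "cache_{}".format(self.cache_id))

On this reading, every worker would construct the same cache_root by passing the same `--cache_id` value that the main process forwards to the generation subprocess via `subproc_args`.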
