@@ -21,6 +21,7 @@
 import atexit
 import contextlib
 import gc
+import hashlib
 import multiprocessing
 import json
 import os
@@ -50,7 +51,7 @@ class NCFDataset(object):
   """Container for training and testing data."""

   def __init__(self, user_map, item_map, num_data_readers, cache_paths,
-               num_train_positives):
+               num_train_positives, deterministic=False):
     # type: (dict, dict, int, rconst.Paths) -> None
     """Assign key values for recommendation dataset.

@@ -61,6 +62,8 @@ def __init__(self, user_map, item_map, num_data_readers, cache_paths,
       cache_paths: Object containing locations for various cache files.
       num_train_positives: The number of positive training examples in the
         dataset.
+      deterministic: Operations should use deterministic, order preserving
+        methods, even at the cost of performance.
     """

     self.user_map = {int(k): int(v) for k, v in user_map.items()}
@@ -70,6 +73,7 @@ def __init__(self, user_map, item_map, num_data_readers, cache_paths,
     self.num_data_readers = num_data_readers
     self.cache_paths = cache_paths
     self.num_train_positives = num_train_positives
+    self.deterministic = deterministic


 def _filter_index_sort(raw_rating_path, match_mlperf):
@@ -340,7 +344,8 @@ def generate_train_eval_data(df, approx_num_shards, num_items, cache_paths,
     pickle.dump(eval_data, f, protocol=pickle.HIGHEST_PROTOCOL)


-def construct_cache(dataset, data_dir, num_data_readers, match_mlperf):
+def construct_cache(dataset, data_dir, num_data_readers, match_mlperf,
+                    deterministic):
   # type: (str, str, int, bool) -> NCFDataset
   """Load and digest data CSV into a usable form.

@@ -351,6 +356,8 @@ def construct_cache(dataset, data_dir, num_data_readers, match_mlperf):
       data during training.
     match_mlperf: If True, change the behavior of the cache construction to
       match the MLPerf reference implementation.
+    deterministic: Try to enforce repeatable behavior, even at the cost of
+      performance.
   """
   cache_paths = rconst.Paths(data_dir=data_dir)
   num_data_readers = (num_data_readers or int(multiprocessing.cpu_count() / 2)
@@ -377,7 +384,8 @@ def construct_cache(dataset, data_dir, num_data_readers, match_mlperf):
   ncf_dataset = NCFDataset(user_map=user_map, item_map=item_map,
                            num_data_readers=num_data_readers,
                            cache_paths=cache_paths,
-                           num_train_positives=len(df) - len(user_map))
+                           num_train_positives=len(df) - len(user_map),
+                           deterministic=deterministic)

   run_time = timeit.default_timer() - st
   tf.logging.info("Cache construction complete. Time: {:.1f} sec."
@@ -403,13 +411,15 @@ def _shutdown(proc):

 def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
                          num_data_readers=None, num_neg=4, epochs_per_cycle=1,
-                         match_mlperf=False):
+                         match_mlperf=False, deterministic=False):
+  # type: (...) -> (NCFDataset, typing.Callable)
   """Preprocess data and start negative generation subprocess."""

   tf.logging.info("Beginning data preprocessing.")
   ncf_dataset = construct_cache(dataset=dataset, data_dir=data_dir,
                                 num_data_readers=num_data_readers,
-                                match_mlperf=match_mlperf)
+                                match_mlperf=match_mlperf,
+                                deterministic=deterministic)

   tf.logging.info("Creating training file subprocess.")
@@ -439,18 +449,30 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
                              # guarantee batch size and significantly improves
                              # performance. (~5% increase in examples/sec on
                              # GPU, and needed for TPU XLA.)
-      "--redirect_logs", "True",
-      "--seed", str(int(stat_utils.random_int32()))
+      "--redirect_logs", "True"
   ]
+  if ncf_dataset.deterministic:
+    subproc_args.extend(["--seed", str(int(stat_utils.random_int32()))])

   tf.logging.info(
       "Generation subprocess command: {}".format(" ".join(subproc_args)))

   proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)

-  atexit.register(_shutdown, proc=proc)
-  atexit.register(tf.gfile.DeleteRecursively,
-                  ncf_dataset.cache_paths.cache_root)
+  cleanup_called = {"finished": False}
+  @atexit.register
+  def cleanup():
+    """Remove files and subprocess from data generation."""
+    if cleanup_called["finished"]:
+      return
+
+    _shutdown(proc)
+    try:
+      tf.gfile.DeleteRecursively(ncf_dataset.cache_paths.cache_root)
+    except tf.errors.NotFoundError:
+      pass
+
+    cleanup_called["finished"] = True

   for _ in range(300):
     if tf.gfile.Exists(ncf_dataset.cache_paths.subproc_alive):
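
Note: the two separate atexit.register calls are folded into a single closure so the same teardown can also be invoked explicitly and is safe to run more than once. A minimal standalone sketch of that guard pattern, with illustrative names that are not taken from this module:

    import atexit

    _cleanup_state = {"finished": False}

    @atexit.register
    def _cleanup():
      # Safe to call manually before interpreter exit; later calls are no-ops.
      if _cleanup_state["finished"]:
        return
      # ... terminate subprocesses and delete temporary files here ...
      _cleanup_state["finished"] = True
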
@@ -460,7 +482,7 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
     raise ValueError("Generation subprocess did not start correctly. Data will "
                      "not be available; exiting to avoid waiting forever.")

-  return ncf_dataset
+  return ncf_dataset, cleanup


 def make_deserialize(params, batch_size, training=False):
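
Note: because instantiate_pipeline now returns the cleanup handle alongside the dataset container, a caller can tear the pipeline down deterministically instead of relying solely on interpreter exit. A rough sketch of the intended call site (the dataset name, data_dir, and batch sizes below are illustrative, not taken from this diff):

    ncf_dataset, cleanup_fn = instantiate_pipeline(
        dataset="ml-20m", data_dir="/tmp/movielens-data",
        batch_size=16384, eval_batch_size=160000,
        deterministic=True)
    try:
      pass  # ... run training and evaluation with ncf_dataset ...
    finally:
      cleanup_fn()  # idempotent; the atexit registration remains as a fallback
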
@@ -498,6 +520,44 @@ def deserialize(examples_serialized):
   return deserialize


+def hash_pipeline(dataset, deterministic):
+  # type: (tf.data.Dataset, bool) -> None
+  """Utility function for detecting non-determinism in the data pipeline.
+
+  Args:
+    dataset: a tf.data.Dataset generated by the input_fn
+    deterministic: Does the input_fn expect the dataset to be deterministic.
+      (i.e. fixed seed, sloppy=False, etc.)
+  """
+  if not deterministic:
+    tf.logging.warning("Data pipeline is not marked as deterministic. Hash "
+                       "values are not expected to be meaningful.")
+
+  batch = dataset.make_one_shot_iterator().get_next()
+  md5 = hashlib.md5()
+  count = 0
+  first_batch_hash = b""
+  with tf.Session() as sess:
+    while True:
+      try:
+        result = sess.run(batch)
+        if isinstance(result, tuple):
+          result = result[0]  # only hash features
+      except tf.errors.OutOfRangeError:
+        break
+
+      count += 1
+      md5.update(memoryview(result[movielens.USER_COLUMN]).tobytes())
+      md5.update(memoryview(result[movielens.ITEM_COLUMN]).tobytes())
+      if count == 1:
+        first_batch_hash = md5.hexdigest()
+  overall_hash = md5.hexdigest()
+  tf.logging.info("Batch count: {}".format(count))
+  tf.logging.info(" [pipeline_hash] First batch hash: {}".format(
+      first_batch_hash))
+  tf.logging.info(" [pipeline_hash] All batches hash: {}".format(overall_hash))
+
+
 def make_train_input_fn(ncf_dataset):
   # type: (NCFDataset) -> (typing.Callable, str, int)
   """Construct training input_fn for the current epoch."""
@@ -556,14 +616,19 @@ def input_fn(params):
         tf.data.TFRecordDataset,
         cycle_length=4,
         block_length=100000,
-        sloppy=True,
+        sloppy=not ncf_dataset.deterministic,
         prefetch_input_elements=4,
     )

     deserialize = make_deserialize(params, batch_size, True)
     dataset = record_files.apply(interleave)
     dataset = dataset.map(deserialize, num_parallel_calls=4)
-    return dataset.prefetch(32)
+    dataset = dataset.prefetch(32)
+
+    if params.get("hash_pipeline"):
+      hash_pipeline(dataset, ncf_dataset.deterministic)
+
+    return dataset

   return input_fn, record_dir, batch_count

@@ -588,7 +653,11 @@ def input_fn(params):

     deserialize = make_deserialize(params, batch_size, False)
     dataset = dataset.map(deserialize, num_parallel_calls=4)
+    dataset = dataset.prefetch(16)
+
+    if params.get("hash_pipeline"):
+      hash_pipeline(dataset, ncf_dataset.deterministic)

-    return dataset.prefetch(16)
+    return dataset

   return input_fn