@@ -48,6 +48,18 @@
 from official.recommendation import popen_helper


+DATASET_TO_NUM_USERS_AND_ITEMS = {
+    "ml-1m": (6040, 3706),
+    "ml-20m": (138493, 26744)
+}
+
+
+# Number of batches to run per epoch when using synthetic data. At high batch
+# sizes, we run for more batches than with real data, which is good since
+# running more batches reduces noise when measuring the average batches/second.
+_SYNTHETIC_BATCHES_PER_EPOCH = 2000
+
+
 class NCFDataset(object):
   """Container for training and testing data."""
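A brief note on `_SYNTHETIC_BATCHES_PER_EPOCH`: it trades epoch realism for measurement stability, since the variance of an average over N batch timings shrinks roughly as 1/N. A minimal, pure-Python sketch of the kind of throughput measurement the comment has in mind (`run_one_batch` is a hypothetical stand-in for a single training step, not a function from this file):

```python
import time

def measure_batches_per_second(run_one_batch, num_batches=2000):
  """Average throughput over num_batches; more batches -> steadier estimate."""
  start = time.time()
  for _ in range(num_batches):
    run_one_batch()
  return num_batches / (time.time() - start)
```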
@@ -376,6 +388,14 @@ def construct_cache(dataset, data_dir, num_data_readers, match_mlperf,

   raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
   df, user_map, item_map = _filter_index_sort(raw_rating_path, match_mlperf)
+  num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
+
+  if num_users != len(user_map):
+    raise ValueError("Expected to find {} users, but found {}".format(
+        num_users, len(user_map)))
+  if num_items != len(item_map):
+    raise ValueError("Expected to find {} items, but found {}".format(
+        num_items, len(item_map)))

   generate_train_eval_data(df=df, approx_num_shards=approx_num_shards,
                            num_items=len(item_map), cache_paths=cache_paths,
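This guard fails fast on a bad download or parse. The expected counts are the distinct users/items that actually appear in the ratings file after `_filter_index_sort` (e.g. ml-20m catalogs 27,278 movies, but only 26,744 of them carry ratings). A self-contained sketch of the same check, with dummy maps standing in for the real ones:

```python
DATASET_TO_NUM_USERS_AND_ITEMS = {
    "ml-1m": (6040, 3706),
    "ml-20m": (138493, 26744),
}

def check_vocab_sizes(dataset, user_map, item_map):
  """Standalone copy of the guard above, for illustration only."""
  num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
  if num_users != len(user_map):
    raise ValueError("Expected to find {} users, but found {}".format(
        num_users, len(user_map)))
  if num_items != len(item_map):
    raise ValueError("Expected to find {} items, but found {}".format(
        num_items, len(item_map)))

# A truncated ratings file surfaces immediately rather than mid-training:
# check_vocab_sizes("ml-1m", {0: 0}, {0: 0})
# -> ValueError: Expected to find 6040 users, but found 1
```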
@@ -570,9 +590,12 @@ def hash_pipeline(dataset, deterministic):


 def make_train_input_fn(ncf_dataset):
-  # type: (NCFDataset) -> (typing.Callable, str, int)
+  # type: (typing.Optional[NCFDataset]) -> (typing.Callable, str, int)
   """Construct training input_fn for the current epoch."""

+  if ncf_dataset is None:
+    return make_train_synthetic_input_fn()
+
   if not tf.gfile.Exists(ncf_dataset.cache_paths.subproc_alive):
     # The generation subprocess must have been alive at some point, because we
     # earlier checked that the subproc_alive file existed.
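The widened type comment encodes the new contract: passing `None` selects the synthetic pipeline, and the caller gets back the same `(input_fn, record_dir, batch_count)` triple either way. An illustrative check of that contract (not part of the diff):

```python
# record_dir is None in the synthetic case because no cache files are written.
input_fn, record_dir, batch_count = make_train_input_fn(None)
assert record_dir is None
assert batch_count == _SYNTHETIC_BATCHES_PER_EPOCH  # 2000
```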
@@ -644,10 +667,40 @@ def input_fn(params):
   return input_fn, record_dir, batch_count


+def make_train_synthetic_input_fn():
+  """Construct training input_fn that uses synthetic data."""
+  def input_fn(params):
+    """Generated input_fn for the given epoch."""
+    batch_size = params["batch_size"]
+    num_users = params["num_users"]
+    num_items = params["num_items"]
+
+    users = tf.random_uniform([batch_size], dtype=tf.int32, minval=0,
+                              maxval=num_users)
+    items = tf.random_uniform([batch_size], dtype=tf.int32, minval=0,
+                              maxval=num_items)
+    labels = tf.random_uniform([batch_size], dtype=tf.int32, minval=0,
+                               maxval=2)
+
+    data = {
+        movielens.USER_COLUMN: users,
+        movielens.ITEM_COLUMN: items,
+    }, labels
+    dataset = tf.data.Dataset.from_tensors(data).repeat(
+        _SYNTHETIC_BATCHES_PER_EPOCH)
+    dataset = dataset.prefetch(32)
+    return dataset
+
+  return input_fn, None, _SYNTHETIC_BATCHES_PER_EPOCH
+
+
 def make_pred_input_fn(ncf_dataset):
-  # type: (NCFDataset) -> typing.Callable
+  # type: (typing.Optional[NCFDataset]) -> typing.Callable
   """Construct input_fn for metric evaluation."""

+  if ncf_dataset is None:
+    return make_synthetic_pred_input_fn()
+
   def input_fn(params):
     """Input function based on eval batch size."""
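A quick smoke test of the synthetic training pipeline, sketched under the TF 1.x session API that the diff's `tf.random_uniform`/`tf.gfile` usage implies (parameter values are illustrative; `params` is normally populated by the Estimator):

```python
import tensorflow as tf

input_fn, _, _ = make_train_input_fn(None)  # None -> synthetic data
dataset = input_fn({"batch_size": 8, "num_users": 6040, "num_items": 3706})
features, labels = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  feats, labs = sess.run((features, labels))
  print({k: v.shape for k, v in feats.items()}, labs.shape)  # all (8,)
```

One subtlety: `Dataset.from_tensors(...).repeat(...)` likely captures the random ops once per iterator, so every batch in the epoch repeats the same values; that is harmless here because only throughput, not data variety, is being measured.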
@@ -672,3 +725,32 @@ def input_fn(params):
     return dataset

   return input_fn
+
+
+def make_synthetic_pred_input_fn():
+  """Construct input_fn for metric evaluation that uses synthetic data."""
+
+  def input_fn(params):
+    """Generated input_fn for the given epoch."""
+    batch_size = params["eval_batch_size"]
+    num_users = params["num_users"]
+    num_items = params["num_items"]
+
+    users = tf.random_uniform([batch_size], dtype=tf.int32, minval=0,
+                              maxval=num_users)
+    items = tf.random_uniform([batch_size], dtype=tf.int32, minval=0,
+                              maxval=num_items)
+    dupe_mask = tf.cast(tf.random_uniform([batch_size], dtype=tf.int32,
+                                          minval=0, maxval=2), tf.bool)
+
+    data = {
+        movielens.USER_COLUMN: users,
+        movielens.ITEM_COLUMN: items,
+        rconst.DUPLICATE_MASK: dupe_mask,
+    }
+    dataset = tf.data.Dataset.from_tensors(data).repeat(
+        _SYNTHETIC_BATCHES_PER_EPOCH)
+    dataset = dataset.prefetch(16)
+    return dataset
+
+  return input_fn
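The eval variant mirrors the training one but keys off `eval_batch_size`, emits no labels, and adds the `rconst.DUPLICATE_MASK` feature the metric computation expects. The mask is drawn as int32 in {0, 1} and then cast, since `tf.random_uniform` does not support a bool dtype; the same trick in isolation:

```python
import tensorflow as tf

def random_bool(shape, seed=None):
  """Bernoulli(0.5) booleans via integer sampling plus a cast."""
  draws = tf.random_uniform(shape, minval=0, maxval=2,
                            dtype=tf.int32, seed=seed)
  return tf.cast(draws, tf.bool)
```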