Commit 75d592e

reedwm authored and Taylor Robie committed
Add --use_synthetic_data option to NCF. (tensorflow#5468)

* Add --use_synthetic_data option to NCF.
* Add comment to _SYNTHETIC_BATCHES_PER_EPOCH
* Fix test
* Hopefully fix lint issue
1 parent 42f9821 commit 75d592e
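
With this option, NCF can be benchmarked without the MovieLens data on disk: the download step is skipped and random batches are generated in-graph. A typical invocation would look something like the line below (only --use_synthetic_data is new in this commit; --dataset is the preexisting flag backing the FLAGS.dataset references in the diff):

python ncf_main.py --use_synthetic_data --dataset ml-1m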

File tree

3 files changed: +113 -18 lines changed

official/recommendation/data_preprocessing.py
official/recommendation/data_test.py
official/recommendation/ncf_main.py

official/recommendation/data_preprocessing.py

Lines changed: 84 additions & 2 deletions

@@ -48,6 +48,18 @@
 from official.recommendation import popen_helper
 
 
+DATASET_TO_NUM_USERS_AND_ITEMS = {
+    "ml-1m": (6040, 3706),
+    "ml-20m": (138493, 26744)
+}
+
+
+# Number of batches to run per epoch when using synthetic data. At high batch
+# sizes, we run for more batches than with real data, which is good since
+# running more batches reduces noise when measuring the average batches/second.
+_SYNTHETIC_BATCHES_PER_EPOCH = 2000
+
+
 class NCFDataset(object):
   """Container for training and testing data."""
 
@@ -376,6 +388,14 @@ def construct_cache(dataset, data_dir, num_data_readers, match_mlperf,
 
   raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
   df, user_map, item_map = _filter_index_sort(raw_rating_path, match_mlperf)
+  num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
+
+  if num_users != len(user_map):
+    raise ValueError("Expected to find {} users, but found {}".format(
+        num_users, len(user_map)))
+  if num_items != len(item_map):
+    raise ValueError("Expected to find {} items, but found {}".format(
+        num_items, len(item_map)))
 
   generate_train_eval_data(df=df, approx_num_shards=approx_num_shards,
                            num_items=len(item_map), cache_paths=cache_paths,
@@ -570,9 +590,12 @@ def hash_pipeline(dataset, deterministic):
 
 
 def make_train_input_fn(ncf_dataset):
-  # type: (NCFDataset) -> (typing.Callable, str, int)
+  # type: (typing.Optional[NCFDataset]) -> (typing.Callable, str, int)
   """Construct training input_fn for the current epoch."""
 
+  if ncf_dataset is None:
+    return make_train_synthetic_input_fn()
+
   if not tf.gfile.Exists(ncf_dataset.cache_paths.subproc_alive):
     # The generation subprocess must have been alive at some point, because we
     # earlier checked that the subproc_alive file existed.
@@ -644,10 +667,40 @@ def input_fn(params):
   return input_fn, record_dir, batch_count
 
 
+def make_train_synthetic_input_fn():
+  """Construct training input_fn that uses synthetic data."""
+  def input_fn(params):
+    """Generated input_fn for the given epoch."""
+    batch_size = params["batch_size"]
+    num_users = params["num_users"]
+    num_items = params["num_items"]
+
+    users = tf.random_uniform([batch_size], dtype=tf.int32, minval=0,
+                              maxval=num_users)
+    items = tf.random_uniform([batch_size], dtype=tf.int32, minval=0,
+                              maxval=num_items)
+    labels = tf.random_uniform([batch_size], dtype=tf.int32, minval=0,
+                               maxval=2)
+
+    data = {
+        movielens.USER_COLUMN: users,
+        movielens.ITEM_COLUMN: items,
+    }, labels
+    dataset = tf.data.Dataset.from_tensors(data).repeat(
+        _SYNTHETIC_BATCHES_PER_EPOCH)
+    dataset = dataset.prefetch(32)
+    return dataset
+
+  return input_fn, None, _SYNTHETIC_BATCHES_PER_EPOCH
+
+
 def make_pred_input_fn(ncf_dataset):
-  # type: (NCFDataset) -> typing.Callable
+  # type: (typing.Optional[NCFDataset]) -> typing.Callable
   """Construct input_fn for metric evaluation."""
 
+  if ncf_dataset is None:
+    return make_synthetic_pred_input_fn()
+
   def input_fn(params):
     """Input function based on eval batch size."""
 
@@ -672,3 +725,32 @@ def input_fn(params):
     return dataset
 
   return input_fn
+
+
+def make_synthetic_pred_input_fn():
+  """Construct input_fn for metric evaluation that uses synthetic data."""
+
+  def input_fn(params):
+    """Generated input_fn for the given epoch."""
+    batch_size = params["eval_batch_size"]
+    num_users = params["num_users"]
+    num_items = params["num_items"]
+
+    users = tf.random_uniform([batch_size], dtype=tf.int32, minval=0,
+                              maxval=num_users)
+    items = tf.random_uniform([batch_size], dtype=tf.int32, minval=0,
+                              maxval=num_items)
+    dupe_mask = tf.cast(tf.random_uniform([batch_size], dtype=tf.int32,
+                                          minval=0, maxval=2), tf.bool)
+
+    data = {
+        movielens.USER_COLUMN: users,
+        movielens.ITEM_COLUMN: items,
+        rconst.DUPLICATE_MASK: dupe_mask,
+    }
+    dataset = tf.data.Dataset.from_tensors(data).repeat(
+        _SYNTHETIC_BATCHES_PER_EPOCH)
+    dataset = dataset.prefetch(16)
+    return dataset
+
+  return input_fn
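
To make the new contract concrete, here is a minimal hypothetical driver that exercises the synthetic training input_fn outside the Estimator (which normally supplies params); the batch size and the ml-1m user/item counts below are illustrative:

# Hypothetical smoke test for the synthetic training path, using the same
# TF 1.x APIs as data_preprocessing.py itself.
import tensorflow as tf
from official.recommendation import data_preprocessing

# Passing ncf_dataset=None selects the synthetic branch added above.
input_fn, record_dir, batch_count = data_preprocessing.make_train_input_fn(
    ncf_dataset=None)
assert record_dir is None     # no TFRecord directory for synthetic data
assert batch_count == 2000    # _SYNTHETIC_BATCHES_PER_EPOCH

dataset = input_fn({"batch_size": 256, "num_users": 6040, "num_items": 3706})
features, labels = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  f, l = sess.run([features, labels])
  print(sorted(f.keys()), l.shape)  # user/item feature columns; labels (256,)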

official/recommendation/data_test.py

Lines changed: 2 additions & 0 deletions

@@ -80,6 +80,8 @@ def setUp(self):
 
     movielens.download = mock_download
     movielens.NUM_RATINGS[DATASET] = NUM_PTS
+    data_preprocessing.DATASET_TO_NUM_USERS_AND_ITEMS[DATASET] = (NUM_USERS,
+                                                                  NUM_ITEMS)
 
   def test_preprocessing(self):
     # For the most part the necessary checks are performed within
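
(Context for the new line: construct_cache, as changed in the data_preprocessing.py diff above, now looks up DATASET_TO_NUM_USERS_AND_ITEMS[dataset] and raises when the parsed user/item counts disagree, so the test must register the counts of its mocked dataset before running the pipeline, just as it already patches movielens.NUM_RATINGS.)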

official/recommendation/ncf_main.py

Lines changed: 27 additions & 16 deletions

@@ -118,7 +118,7 @@ def main(_):
 
 def run_ncf(_):
   """Run NCF training and eval loop."""
-  if FLAGS.download_if_missing:
+  if FLAGS.download_if_missing and not FLAGS.use_synthetic_data:
     movielens.download(FLAGS.dataset, FLAGS.data_dir)
 
   if FLAGS.seed is not None:
@@ -137,14 +137,25 @@ def run_ncf(_):
         "eval examples per user does not evenly divide eval_batch_size. "
         "Overriding to {}".format(eval_batch_size))
 
-  ncf_dataset, cleanup_fn = data_preprocessing.instantiate_pipeline(
-      dataset=FLAGS.dataset, data_dir=FLAGS.data_dir,
-      batch_size=batch_size,
-      eval_batch_size=eval_batch_size,
-      num_neg=FLAGS.num_neg,
-      epochs_per_cycle=FLAGS.epochs_between_evals,
-      match_mlperf=FLAGS.ml_perf,
-      deterministic=FLAGS.seed is not None)
+  if FLAGS.use_synthetic_data:
+    ncf_dataset = None
+    cleanup_fn = lambda: None
+    num_users, num_items = data_preprocessing.DATASET_TO_NUM_USERS_AND_ITEMS[
+        FLAGS.dataset]
+    approx_train_steps = None
+  else:
+    ncf_dataset, cleanup_fn = data_preprocessing.instantiate_pipeline(
+        dataset=FLAGS.dataset, data_dir=FLAGS.data_dir,
+        batch_size=batch_size,
+        eval_batch_size=eval_batch_size,
+        num_neg=FLAGS.num_neg,
+        epochs_per_cycle=FLAGS.epochs_between_evals,
+        match_mlperf=FLAGS.ml_perf,
+        deterministic=FLAGS.seed is not None)
+    num_users = ncf_dataset.num_users
+    num_items = ncf_dataset.num_items
+    approx_train_steps = int(ncf_dataset.num_train_positives
+                             * (1 + FLAGS.num_neg) // FLAGS.batch_size)
 
   model_helpers.apply_clean(flags.FLAGS)
 
@@ -153,9 +164,10 @@ def run_ncf(_):
       "use_seed": FLAGS.seed is not None,
       "hash_pipeline": FLAGS.hash_pipeline,
       "batch_size": batch_size,
+      "eval_batch_size": eval_batch_size,
       "learning_rate": FLAGS.learning_rate,
-      "num_users": ncf_dataset.num_users,
-      "num_items": ncf_dataset.num_items,
+      "num_users": num_users,
+      "num_items": num_items,
       "mf_dim": FLAGS.num_factors,
       "model_layers": [int(layer) for layer in FLAGS.layers],
       "mf_regularization": FLAGS.mf_regularization,
@@ -192,8 +204,6 @@ def run_ncf(_):
       run_params=run_params,
       test_id=FLAGS.benchmark_test_id)
 
-  approx_train_steps = int(ncf_dataset.num_train_positives
-                           * (1 + FLAGS.num_neg) // FLAGS.batch_size)
   pred_input_fn = data_preprocessing.make_pred_input_fn(ncf_dataset=ncf_dataset)
 
   total_training_cycle = FLAGS.train_epochs // FLAGS.epochs_between_evals
@@ -205,14 +215,15 @@ def run_ncf(_):
     train_input_fn, train_record_dir, batch_count = \
       data_preprocessing.make_train_input_fn(ncf_dataset=ncf_dataset)
 
-    if np.abs(approx_train_steps - batch_count) > 1:
+    if approx_train_steps and np.abs(approx_train_steps - batch_count) > 1:
       tf.logging.warning(
          "Estimated ({}) and reported ({}) number of batches differ by more "
          "than one".format(approx_train_steps, batch_count))
 
     train_estimator.train(input_fn=train_input_fn, hooks=train_hooks,
                           steps=batch_count)
-    tf.gfile.DeleteRecursively(train_record_dir)
+    if train_record_dir:
+      tf.gfile.DeleteRecursively(train_record_dir)
 
     tf.logging.info("Beginning evaluation.")
     eval_results = eval_estimator.evaluate(pred_input_fn)
@@ -245,7 +256,7 @@ def define_ncf_flags():
       num_parallel_calls=False,
       inter_op=False,
       intra_op=False,
-      synthetic_data=False,
+      synthetic_data=True,
       max_train_steps=False,
       dtype=False,
       all_reduce_alg=False
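
The eval side can be smoke-tested the same way; a hypothetical counterpart (note the synthetic branch reads eval_batch_size from params, which is why run_ncf now adds that key, and the returned features include the duplicate mask):

# Hypothetical smoke test for the synthetic eval path (TF 1.x APIs).
import tensorflow as tf
from official.recommendation import data_preprocessing

# With ncf_dataset=None, make_pred_input_fn returns just an input_fn.
pred_input_fn = data_preprocessing.make_pred_input_fn(ncf_dataset=None)
dataset = pred_input_fn(
    {"eval_batch_size": 160, "num_users": 6040, "num_items": 3706})
features = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  # Expect the user column, item column, and rconst.DUPLICATE_MASK keys.
  print(sorted(sess.run(features).keys()))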
