-
Notifications
You must be signed in to change notification settings - Fork 198
Expand file tree
/
Copy pathtfkeras_integration.py
More file actions
147 lines (107 loc) · 4.14 KB
/
tfkeras_integration.py
File metadata and controls
147 lines (107 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Optuna example that demonstrates a pruner for tf.keras.
In this example, we optimize the validation accuracy of hand-written digit recognition
using tf.keras and MNIST, where the architecture of the neural network
and the parameters of the optimizer are optimized.
Throughout the training of neural networks,
a pruner observes intermediate results and stops unpromising trials.
You can run this example as follows:
$ python tfkeras_integration.py
"""
import urllib
import urllib.request

import optuna
from optuna.integration import TFKerasPruningCallback
from optuna.trial import TrialState
import tensorflow as tf
import tensorflow_datasets as tfds
# Register a global custom opener to avoid HTTP Error 403: Forbidden when downloading MNIST.
# The opener's header list is replaced with a single browser-like User-agent entry and the
# opener is then installed process-wide for urllib.request.
# NOTE: `urllib.request` has to be imported explicitly at the top of the file; a bare
# `import urllib` does not load the `request` submodule and only works by accident when
# another imported package happens to have loaded it already.
# TODO(crcrpar): Remove the below three lines once everything is ok.
opener = urllib.request.build_opener()
opener.addheaders = [("User-agent", "Mozilla/5.0")]
urllib.request.install_opener(opener)
# Training configuration shared by the dataset builders, the model and the objective.
BATCHSIZE = 128  # Batch size applied to both the training and the evaluation dataset.
CLASSES = 10  # Width of the final softmax layer.
EPOCHS = 20  # Upper bound on training epochs per trial.
N_TRAIN_EXAMPLES = 3000
# Whole number of batches that fit into N_TRAIN_EXAMPLES, scaled down by a factor of 10.
STEPS_PER_EPOCH = N_TRAIN_EXAMPLES // (BATCHSIZE * 10)
VALIDATION_STEPS = 30  # Number of evaluation batches run after every epoch.
def train_dataset():
    """Build the MNIST training input pipeline.

    The TRAIN split is loaded with file shuffling enabled, every record is turned
    into an ``(image, label)`` pair with the image cast to float32 and divided by
    255.0, and the stream is then repeated, shuffled with a buffer of 1024,
    batched by ``BATCHSIZE`` and prefetched.

    Returns:
        The resulting dataset object.
    """

    def _to_pair(record):
        # Only the image is transformed; the label is passed through untouched.
        pixels = tf.cast(record["image"], tf.float32) / 255.0
        return pixels, record["label"]

    records = tfds.load("mnist", split=tfds.Split.TRAIN, shuffle_files=True)
    batches = records.map(_to_pair).repeat().shuffle(1024).batch(BATCHSIZE)
    return batches.prefetch(tf.data.experimental.AUTOTUNE)
def eval_dataset():
    """Build the MNIST evaluation input pipeline.

    The TEST split is loaded without file shuffling, every record is turned into
    an ``(image, label)`` pair with the image cast to float32 and divided by
    255.0, and the stream is then repeated, batched by ``BATCHSIZE`` and
    prefetched. No shuffling is applied.

    Returns:
        The resulting dataset object.
    """
    raw = tfds.load("mnist", split=tfds.Split.TEST, shuffle_files=False)
    scaled = raw.map(
        lambda sample: (tf.cast(sample["image"], tf.float32) / 255.0, sample["label"])
    )
    batched = scaled.repeat().batch(BATCHSIZE)
    return batched.prefetch(tf.data.experimental.AUTOTUNE)
def create_model(trial):
    """Build and compile a one-hidden-layer classifier configured by ``trial``.

    Args:
        trial: Object used to sample the hyperparameters through
            ``suggest_float`` and ``suggest_categorical``.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    # Hyperparameters to be tuned by Optuna (sampled in this fixed order).
    lr = trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True)
    sgd_momentum = trial.suggest_float("momentum", 0.0, 1.0)
    hidden_units = trial.suggest_categorical("units", [32, 64, 128, 256, 512])

    # Flatten -> hidden ReLU layer -> softmax over CLASSES outputs.
    network = tf.keras.Sequential()
    for layer in (
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=hidden_units, activation=tf.nn.relu),
        tf.keras.layers.Dense(CLASSES, activation=tf.nn.softmax),
    ):
        network.add(layer)

    # SGD with Nesterov momentum; labels are integer class ids, hence the sparse loss.
    sgd = tf.keras.optimizers.SGD(learning_rate=lr, momentum=sgd_momentum, nesterov=True)
    network.compile(
        optimizer=sgd,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return network
def objective(trial):
    """Train one model configured by ``trial`` and return its last validation accuracy.

    Args:
        trial: Trial object handed to ``create_model`` for hyperparameter sampling
            and to ``TFKerasPruningCallback`` for reporting the monitored metric.

    Returns:
        The final entry of the monitored validation-accuracy history.
    """
    # Clear clutter from previous TensorFlow graphs.
    tf.keras.backend.clear_session()

    # Metrics to be monitored by Optuna.
    # Compare the major version numerically: comparing version strings is
    # lexicographic, so e.g. "10.0.0" >= "2" would evaluate to False.
    tf_major = int(tf.__version__.split(".", 1)[0])
    monitor = "val_accuracy" if tf_major >= 2 else "val_acc"

    # Create tf.keras model instance.
    model = create_model(trial)

    # Create dataset instance.
    ds_train = train_dataset()
    ds_eval = eval_dataset()

    # Create callbacks for early stopping and pruning.
    callbacks = [
        tf.keras.callbacks.EarlyStopping(patience=3),
        TFKerasPruningCallback(trial, monitor),
    ]

    # Train model. Both datasets repeat, so explicit step counts bound each epoch.
    history = model.fit(
        ds_train,
        epochs=EPOCHS,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_data=ds_eval,
        validation_steps=VALIDATION_STEPS,
        callbacks=callbacks,
    )

    # The value of the last epoch that actually ran is the trial's score.
    return history.history[monitor][-1]
def show_result(study):
    """Print trial counts for ``study`` followed by its best trial.

    Args:
        study: Study whose ``trials``, ``get_trials`` and ``best_trial`` are read.
    """
    pruned = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    completed = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    finished_count = len(study.trials)
    for label, count in (
        (" Number of finished trials: ", finished_count),
        (" Number of pruned trials: ", len(pruned)),
        (" Number of complete trials: ", len(completed)),
    ):
        print(label, count)

    # The heading is printed before the best trial is looked up.
    print("Best trial:")
    best = study.best_trial

    print(" Value: ", best.value)

    print(" Params: ")
    for name, setting in best.params.items():
        line = " {}: {}".format(name, setting)
        print(line)
def main():
    """Run the hyperparameter search and print a summary of the study."""
    # Maximize the objective's return value; the median pruner is created with
    # n_startup_trials=2.
    median_pruner = optuna.pruners.MedianPruner(n_startup_trials=2)
    study = optuna.create_study(direction="maximize", pruner=median_pruner)

    # The search is limited to n_trials=25 and timeout=600.
    study.optimize(objective, n_trials=25, timeout=600)

    show_result(study)


# Only run the search when the file is executed as a script.
if __name__ == "__main__":
    main()