Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit c49b3b3

Browse files
committed
updated the sample to follow the latest sample
1 parent ce30694 commit c49b3b3

File tree

1 file changed

+58
-54
lines changed

1 file changed

+58
-54
lines changed

home/samples/DigitsRecognitionNeuralNetwork.ipynb

Lines changed: 58 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@
5151
"int batch_size = 100;\n",
5252
"float learning_rate = 0.001f;\n",
5353
"int h1 = 200; // number of nodes in the 1st hidden layer\n",
54+
"\n",
55+
"Tensor x, y;\n",
56+
"Tensor loss, accuracy;\n",
57+
"Operation optimizer;\n",
58+
"\n",
5459
"int display_freq = 100;\n",
5560
"float accuracy_test = 0f;\n",
5661
"float loss_test = 1f;"
@@ -91,25 +96,30 @@
9196
"metadata": {},
9297
"outputs": [],
9398
"source": [
94-
"var graph = new Graph().as_default();\n",
95-
"\n",
96-
"// Placeholders for inputs (x) and outputs(y)\n",
97-
"var x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: \"X\");\n",
98-
"var y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: \"Y\");\n",
99-
"\n",
100-
"// Create a fully-connected layer with h1 nodes as hidden layer\n",
101-
"var fc1 = fc_layer(x, h1, \"FC1\", use_relu: true);\n",
102-
"// Create a fully-connected layer with n_classes nodes as output layer\n",
103-
"var output_logits = fc_layer(fc1, n_classes, \"OUT\", use_relu: false);\n",
104-
"// Define the loss function, optimizer, and accuracy\n",
105-
"var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);\n",
106-
"var loss = tf.reduce_mean(logits, name: \"loss\");\n",
107-
"var optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: \"Adam-op\").minimize(loss);\n",
108-
"var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: \"correct_pred\");\n",
109-
"var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: \"accuracy\");\n",
110-
"\n",
111-
"// Network predictions\n",
112-
"var cls_prediction = tf.argmax(output_logits, axis: 1, name: \"predictions\");"
99+
"public Graph BuildGraph()\n",
100+
"{\n",
101+
" var graph = new Graph().as_default();\n",
102+
"\n",
103+
" // Placeholders for inputs (x) and outputs(y)\n",
104+
" x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: \"X\");\n",
105+
" y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: \"Y\");\n",
106+
"\n",
107+
" // Create a fully-connected layer with h1 nodes as hidden layer\n",
108+
" var fc1 = fc_layer(x, h1, \"FC1\", use_relu: true);\n",
109+
" // Create a fully-connected layer with n_classes nodes as output layer\n",
110+
" var output_logits = fc_layer(fc1, n_classes, \"OUT\", use_relu: false);\n",
111+
" // Define the loss function, optimizer, and accuracy\n",
112+
" var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);\n",
113+
" loss = tf.reduce_mean(logits, name: \"loss\");\n",
114+
" optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: \"Adam-op\").minimize(loss);\n",
115+
" var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: \"correct_pred\");\n",
116+
" accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: \"accuracy\");\n",
117+
"\n",
118+
" // Network predictions\n",
119+
" var cls_prediction = tf.argmax(output_logits, axis: 1, name: \"predictions\");\n",
120+
"\n",
121+
" return graph;\n",
122+
"}"
113123
]
114124
},
115125
{
@@ -118,7 +128,7 @@
118128
"metadata": {},
119129
"outputs": [],
120130
"source": [
121-
"void Train(Session sess, Datasets<MnistDataSet> mnist)\n",
131+
"public void Train(Session sess, Datasets<MnistDataSet> mnist)\n",
122132
"{\n",
123133
" // Number of training iterations in each epoch\n",
124134
" var num_tr_iter = mnist.Train.Labels.shape[0] / batch_size;\n",
@@ -129,9 +139,12 @@
129139
" float loss_val = 100.0f;\n",
130140
" float accuracy_val = 0f;\n",
131141
"\n",
142+
" var sw = new Stopwatch();\n",
143+
" sw.Start();\n",
144+
"\n",
132145
" foreach (var epoch in range(epochs))\n",
133146
" {\n",
134-
" Console.WriteLine($\"Training epoch: {epoch + 1}\");\n",
147+
" print($\"Training epoch: {epoch + 1}\");\n",
135148
" // Randomly shuffle the training data at the beginning of each epoch \n",
136149
" var (x_train, y_train) = mnist.Randomize(mnist.Train.Data, mnist.Train.Labels);\n",
137150
"\n",
@@ -142,35 +155,24 @@
142155
" var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);\n",
143156
"\n",
144157
" // Run optimization op (backprop)\n",
145-
" sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));\n",
158+
" sess.run(optimizer, (x, x_batch), (y, y_batch));\n",
146159
"\n",
147160
" if (iteration % display_freq == 0)\n",
148161
" {\n",
149162
" // Calculate and display the batch loss and accuracy\n",
150-
" var result = sess.run(\n",
151-
" new[] { loss, accuracy },\n",
152-
" new FeedItem(x, x_batch),\n",
153-
" new FeedItem(y, y_batch));\n",
154-
" loss_val = result[0];\n",
155-
" accuracy_val = result[1];\n",
156-
" Console.WriteLine($\"iter {iteration.ToString(\"000\")}: Loss={loss_val.ToString(\"0.0000\")}, Training Accuracy={accuracy_val.ToString(\"P\")}\");\n",
163+
" (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_batch), (y, y_batch));\n",
164+
" print($\"iter {iteration.ToString(\"000\")}: Loss={loss_val.ToString(\"0.0000\")}, Training Accuracy={accuracy_val.ToString(\"P\")} {sw.ElapsedMilliseconds}ms\");\n",
165+
" sw.Restart();\n",
157166
" }\n",
158167
" }\n",
159168
"\n",
160169
" // Run validation after every epoch\n",
161-
" var results1 = sess.run(\n",
162-
" new[] { loss, accuracy },\n",
163-
" new FeedItem(x, mnist.Validation.Data),\n",
164-
" new FeedItem(y, mnist.Validation.Labels)\n",
165-
" );\n",
166-
" \n",
167-
" loss_val = results1[0];\n",
168-
" accuracy_val = results1[1];\n",
169-
" Console.WriteLine(\"---------------------------------------------------------\");\n",
170-
" Console.WriteLine($\"Epoch: {epoch + 1}, validation loss: {loss_val.ToString(\"0.0000\")}, validation accuracy: {accuracy_val.ToString(\"P\")}\");\n",
171-
" Console.WriteLine(\"---------------------------------------------------------\");\n",
170+
" (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, mnist.Validation.Data), (y, mnist.Validation.Labels));\n",
171+
" print(\"---------------------------------------------------------\");\n",
172+
" print($\"Epoch: {epoch + 1}, validation loss: {loss_val.ToString(\"0.0000\")}, validation accuracy: {accuracy_val.ToString(\"P\")}\");\n",
173+
" print(\"---------------------------------------------------------\");\n",
172174
" }\n",
173-
"}\n"
175+
"}"
174176
]
175177
},
176178
{
@@ -179,20 +181,13 @@
179181
"metadata": {},
180182
"outputs": [],
181183
"source": [
182-
"void Test(Session sess, Datasets<MnistDataSet> mnist)\n",
184+
"public void Test(Session sess, Datasets<MnistDataSet> mnist)\n",
183185
"{\n",
184-
" var result = sess.run(\n",
185-
" new[] { loss, accuracy },\n",
186-
" new FeedItem(x, mnist.Test.Data),\n",
187-
" new FeedItem(y, mnist.Test.Labels)\n",
188-
" );\n",
189-
" \n",
190-
" loss_test = result[0];\n",
191-
" accuracy_test = result[1];\n",
192-
" Console.WriteLine(\"---------------------------------------------------------\");\n",
193-
" Console.WriteLine($\"Test loss: {loss_test.ToString(\"0.0000\")}, test accuracy: {accuracy_test.ToString(\"P\")}\");\n",
194-
" Console.WriteLine(\"---------------------------------------------------------\");\n",
195-
"}\n"
186+
" (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, mnist.Test.Data), (y, mnist.Test.Labels));\n",
187+
" print(\"---------------------------------------------------------\");\n",
188+
" print($\"Test loss: {loss_test.ToString(\"0.0000\")}, test accuracy: {accuracy_test.ToString(\"P\")}\");\n",
189+
" print(\"---------------------------------------------------------\");\n",
190+
"}"
196191
]
197192
},
198193
{
@@ -204,6 +199,15 @@
204199
"var mnist = await MnistModelLoader.LoadAsync(\"mnist\", true);"
205200
]
206201
},
202+
{
203+
"cell_type": "code",
204+
"execution_count": null,
205+
"metadata": {},
206+
"outputs": [],
207+
"source": [
208+
"BuildGraph();"
209+
]
210+
},
207211
{
208212
"cell_type": "code",
209213
"execution_count": null,

0 commit comments

Comments
 (0)