|
51 | 51 | "int batch_size = 100;\n",
|
52 | 52 | "float learning_rate = 0.001f;\n",
|
53 | 53 | "int h1 = 200; // number of nodes in the 1st hidden layer\n",
|
| 54 | + "\n", |
| 55 | + "Tensor x, y;\n", |
| 56 | + "Tensor loss, accuracy;\n", |
| 57 | + "Operation optimizer;\n", |
| 58 | + "\n", |
54 | 59 | "int display_freq = 100;\n",
|
55 | 60 | "float accuracy_test = 0f;\n",
|
56 | 61 | "float loss_test = 1f;"
|
|
91 | 96 | "metadata": {},
|
92 | 97 | "outputs": [],
|
93 | 98 | "source": [
|
94 |
| - "var graph = new Graph().as_default();\n", |
95 |
| - "\n", |
96 |
| - "// Placeholders for inputs (x) and outputs(y)\n", |
97 |
| - "var x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: \"X\");\n", |
98 |
| - "var y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: \"Y\");\n", |
99 |
| - "\n", |
100 |
| - "// Create a fully-connected layer with h1 nodes as hidden layer\n", |
101 |
| - "var fc1 = fc_layer(x, h1, \"FC1\", use_relu: true);\n", |
102 |
| - "// Create a fully-connected layer with n_classes nodes as output layer\n", |
103 |
| - "var output_logits = fc_layer(fc1, n_classes, \"OUT\", use_relu: false);\n", |
104 |
| - "// Define the loss function, optimizer, and accuracy\n", |
105 |
| - "var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);\n", |
106 |
| - "var loss = tf.reduce_mean(logits, name: \"loss\");\n", |
107 |
| - "var optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: \"Adam-op\").minimize(loss);\n", |
108 |
| - "var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: \"correct_pred\");\n", |
109 |
| - "var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: \"accuracy\");\n", |
110 |
| - "\n", |
111 |
| - "// Network predictions\n", |
112 |
| - "var cls_prediction = tf.argmax(output_logits, axis: 1, name: \"predictions\");" |
| 99 | + "public Graph BuildGraph()\n", |
| 100 | + "{\n", |
| 101 | + " var graph = new Graph().as_default();\n", |
| 102 | + "\n", |
| 103 | + " // Placeholders for inputs (x) and outputs(y)\n", |
| 104 | + " x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: \"X\");\n", |
| 105 | + " y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: \"Y\");\n", |
| 106 | + "\n", |
| 107 | + " // Create a fully-connected layer with h1 nodes as hidden layer\n", |
| 108 | + " var fc1 = fc_layer(x, h1, \"FC1\", use_relu: true);\n", |
| 109 | + " // Create a fully-connected layer with n_classes nodes as output layer\n", |
| 110 | + " var output_logits = fc_layer(fc1, n_classes, \"OUT\", use_relu: false);\n", |
| 111 | + " // Define the loss function, optimizer, and accuracy\n", |
| 112 | + " var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);\n", |
| 113 | + " loss = tf.reduce_mean(logits, name: \"loss\");\n", |
| 114 | + " optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: \"Adam-op\").minimize(loss);\n", |
| 115 | + " var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: \"correct_pred\");\n", |
| 116 | + " accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: \"accuracy\");\n", |
| 117 | + "\n", |
| 118 | + " // Network predictions\n", |
| 119 | + " var cls_prediction = tf.argmax(output_logits, axis: 1, name: \"predictions\");\n", |
| 120 | + "\n", |
| 121 | + " return graph;\n", |
| 122 | + "}" |
113 | 123 | ]
|
114 | 124 | },
|
115 | 125 | {
|
|
118 | 128 | "metadata": {},
|
119 | 129 | "outputs": [],
|
120 | 130 | "source": [
|
121 |
| - "void Train(Session sess, Datasets<MnistDataSet> mnist)\n", |
| 131 | + "public void Train(Session sess, Datasets<MnistDataSet> mnist)\n", |
122 | 132 | "{\n",
|
123 | 133 | " // Number of training iterations in each epoch\n",
|
124 | 134 | " var num_tr_iter = mnist.Train.Labels.shape[0] / batch_size;\n",
|
|
129 | 139 | " float loss_val = 100.0f;\n",
|
130 | 140 | " float accuracy_val = 0f;\n",
|
131 | 141 | "\n",
|
| 142 | + " var sw = new Stopwatch();\n", |
| 143 | + " sw.Start();\n", |
| 144 | + "\n", |
132 | 145 | " foreach (var epoch in range(epochs))\n",
|
133 | 146 | " {\n",
|
134 |
| - " Console.WriteLine($\"Training epoch: {epoch + 1}\");\n", |
| 147 | + " print($\"Training epoch: {epoch + 1}\");\n", |
135 | 148 | " // Randomly shuffle the training data at the beginning of each epoch \n",
|
136 | 149 | " var (x_train, y_train) = mnist.Randomize(mnist.Train.Data, mnist.Train.Labels);\n",
|
137 | 150 | "\n",
|
|
142 | 155 | " var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);\n",
|
143 | 156 | "\n",
|
144 | 157 | " // Run optimization op (backprop)\n",
|
145 |
| - " sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));\n", |
| 158 | + " sess.run(optimizer, (x, x_batch), (y, y_batch));\n", |
146 | 159 | "\n",
|
147 | 160 | " if (iteration % display_freq == 0)\n",
|
148 | 161 | " {\n",
|
149 | 162 | " // Calculate and display the batch loss and accuracy\n",
|
150 |
| - " var result = sess.run(\n", |
151 |
| - " new[] { loss, accuracy },\n", |
152 |
| - " new FeedItem(x, x_batch),\n", |
153 |
| - " new FeedItem(y, y_batch));\n", |
154 |
| - " loss_val = result[0];\n", |
155 |
| - " accuracy_val = result[1];\n", |
156 |
| - " Console.WriteLine($\"iter {iteration.ToString(\"000\")}: Loss={loss_val.ToString(\"0.0000\")}, Training Accuracy={accuracy_val.ToString(\"P\")}\");\n", |
| 163 | + " (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_batch), (y, y_batch));\n", |
| 164 | + " print($\"iter {iteration.ToString(\"000\")}: Loss={loss_val.ToString(\"0.0000\")}, Training Accuracy={accuracy_val.ToString(\"P\")} {sw.ElapsedMilliseconds}ms\");\n", |
| 165 | + " sw.Restart();\n", |
157 | 166 | " }\n",
|
158 | 167 | " }\n",
|
159 | 168 | "\n",
|
160 | 169 | " // Run validation after every epoch\n",
|
161 |
| - " var results1 = sess.run(\n", |
162 |
| - " new[] { loss, accuracy },\n", |
163 |
| - " new FeedItem(x, mnist.Validation.Data),\n", |
164 |
| - " new FeedItem(y, mnist.Validation.Labels)\n", |
165 |
| - " );\n", |
166 |
| - " \n", |
167 |
| - " loss_val = results1[0];\n", |
168 |
| - " accuracy_val = results1[1];\n", |
169 |
| - " Console.WriteLine(\"---------------------------------------------------------\");\n", |
170 |
| - " Console.WriteLine($\"Epoch: {epoch + 1}, validation loss: {loss_val.ToString(\"0.0000\")}, validation accuracy: {accuracy_val.ToString(\"P\")}\");\n", |
171 |
| - " Console.WriteLine(\"---------------------------------------------------------\");\n", |
| 170 | + " (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, mnist.Validation.Data), (y, mnist.Validation.Labels));\n", |
| 171 | + " print(\"---------------------------------------------------------\");\n", |
| 172 | + " print($\"Epoch: {epoch + 1}, validation loss: {loss_val.ToString(\"0.0000\")}, validation accuracy: {accuracy_val.ToString(\"P\")}\");\n", |
| 173 | + " print(\"---------------------------------------------------------\");\n", |
172 | 174 | " }\n",
|
173 |
| - "}\n" |
| 175 | + "}" |
174 | 176 | ]
|
175 | 177 | },
|
176 | 178 | {
|
|
179 | 181 | "metadata": {},
|
180 | 182 | "outputs": [],
|
181 | 183 | "source": [
|
182 |
| - "void Test(Session sess, Datasets<MnistDataSet> mnist)\n", |
| 184 | + "public void Test(Session sess, Datasets<MnistDataSet> mnist)\n", |
183 | 185 | "{\n",
|
184 |
| - " var result = sess.run(\n", |
185 |
| - " new[] { loss, accuracy },\n", |
186 |
| - " new FeedItem(x, mnist.Test.Data),\n", |
187 |
| - " new FeedItem(y, mnist.Test.Labels)\n", |
188 |
| - " );\n", |
189 |
| - " \n", |
190 |
| - " loss_test = result[0];\n", |
191 |
| - " accuracy_test = result[1];\n", |
192 |
| - " Console.WriteLine(\"---------------------------------------------------------\");\n", |
193 |
| - " Console.WriteLine($\"Test loss: {loss_test.ToString(\"0.0000\")}, test accuracy: {accuracy_test.ToString(\"P\")}\");\n", |
194 |
| - " Console.WriteLine(\"---------------------------------------------------------\");\n", |
195 |
| - "}\n" |
| 186 | + " (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, mnist.Test.Data), (y, mnist.Test.Labels));\n", |
| 187 | + " print(\"---------------------------------------------------------\");\n", |
| 188 | + " print($\"Test loss: {loss_test.ToString(\"0.0000\")}, test accuracy: {accuracy_test.ToString(\"P\")}\");\n", |
| 189 | + " print(\"---------------------------------------------------------\");\n", |
| 190 | + "}" |
196 | 191 | ]
|
197 | 192 | },
|
198 | 193 | {
|
|
204 | 199 | "var mnist = await MnistModelLoader.LoadAsync(\"mnist\", true);"
|
205 | 200 | ]
|
206 | 201 | },
|
| 202 | + { |
| 203 | + "cell_type": "code", |
| 204 | + "execution_count": null, |
| 205 | + "metadata": {}, |
| 206 | + "outputs": [], |
| 207 | + "source": [ |
| 208 | + "BuildGraph();" |
| 209 | + ] |
| 210 | + }, |
207 | 211 | {
|
208 | 212 | "cell_type": "code",
|
209 | 213 | "execution_count": null,
|
|
0 commit comments