|
| 1 | +# Copyright 2016 Google Inc. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""Tests for inception.""" |
| 16 | +from __future__ import absolute_import |
| 17 | +from __future__ import division |
| 18 | +from __future__ import print_function |
| 19 | + |
| 20 | +import tensorflow as tf |
| 21 | + |
| 22 | +from inception.slim import slim |
| 23 | + |
| 24 | + |
def get_variables(scope=None):
  """Returns the slim model variables, optionally restricted to `scope`."""
  return slim.variables.get_variables(scope=scope)
| 27 | + |
| 28 | + |
def get_variables_by_name(name):
  """Returns the slim model variables whose name matches `name`."""
  matching = slim.variables.get_variables_by_name(name)
  return matching
| 31 | + |
| 32 | + |
class CollectionsTest(tf.test.TestCase):
  """Verifies the variable and loss collections populated by inception_v3."""

  def _random_inputs(self):
    """Returns a random 5x299x299x3 image batch, the inception_v3 input size."""
    return tf.random_uniform((5, 299, 299, 3))

  def _build_with_batch_norm(self, batch_norm_params, **inception_kwargs):
    """Builds inception_v3 under a conv2d arg scope with the given BN params.

    Args:
      batch_norm_params: dict of batch-norm parameters, or None to disable
        batch norm (conv layers then get biases instead of beta/moving stats).
      **inception_kwargs: forwarded to slim.inception.inception_v3.

    Returns:
      Whatever inception_v3 returns (logits and end points).
    """
    inputs = self._random_inputs()
    with slim.arg_scope([slim.ops.conv2d],
                        batch_norm_params=batch_norm_params):
      return slim.inception.inception_v3(inputs, **inception_kwargs)

  def _assert_counts_by_name(self, expected):
    """Asserts the number of variables matching each name in `expected`."""
    for name, count in expected.items():
      self.assertEqual(len(get_variables_by_name(name)), count)

  def _add_cross_entropy_losses(self, logits, end_points, dense_labels):
    """Attaches the main and auxiliary softmax cross-entropy losses."""
    # Cross entropy loss for the main softmax prediction.
    slim.losses.cross_entropy_loss(logits,
                                   dense_labels,
                                   label_smoothing=0.1,
                                   weight=1.0)
    # Cross entropy loss for the auxiliary softmax head.
    slim.losses.cross_entropy_loss(end_points['aux_logits'],
                                   dense_labels,
                                   label_smoothing=0.1,
                                   weight=0.4,
                                   scope='aux_loss')

  def testVariables(self):
    with self.test_session():
      self._build_with_batch_norm({'decay': 0.9997})
      self.assertEqual(len(get_variables()), 388)
      self._assert_counts_by_name({
          'weights': 98,
          'biases': 2,
          'beta': 96,
          'gamma': 0,
          'moving_mean': 96,
          'moving_variance': 96,
      })

  def testVariablesWithoutBatchNorm(self):
    with self.test_session():
      self._build_with_batch_norm(None)
      self.assertEqual(len(get_variables()), 196)
      self._assert_counts_by_name({
          'weights': 98,
          'biases': 98,
          'beta': 0,
          'gamma': 0,
          'moving_mean': 0,
          'moving_variance': 0,
      })

  def testVariablesByLayer(self):
    with self.test_session():
      self._build_with_batch_norm({'decay': 0.9997})
      self.assertEqual(len(get_variables()), 388)
      # Expected variable count for each named layer scope.
      expected_per_scope = [
          ('conv0', 4),
          ('conv1', 4),
          ('conv2', 4),
          ('conv3', 4),
          ('conv4', 4),
          ('mixed_35x35x256a', 28),
          ('mixed_35x35x288a', 28),
          ('mixed_35x35x288b', 28),
          ('mixed_17x17x768a', 16),
          ('mixed_17x17x768b', 40),
          ('mixed_17x17x768c', 40),
          ('mixed_17x17x768d', 40),
          ('mixed_17x17x768e', 40),
          ('mixed_8x8x2048a', 36),
          ('mixed_8x8x2048b', 36),
          ('logits', 2),
          ('aux_logits', 10),
      ]
      for scope, count in expected_per_scope:
        self.assertEqual(len(get_variables(scope)), count)

  def testVariablesToRestore(self):
    with self.test_session():
      self._build_with_batch_norm({'decay': 0.9997})
      to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
      self.assertEqual(len(to_restore), 388)
      self.assertListEqual(to_restore, get_variables())

  def testVariablesToRestoreWithoutLogits(self):
    with self.test_session():
      self._build_with_batch_norm({'decay': 0.9997}, restore_logits=False)
      to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
      # The 4 logits variables (weights/biases of the two logit layers) are
      # excluded from the restore collection: 388 - 4 = 384.
      self.assertEqual(len(to_restore), 384)

  def testRegularizationLosses(self):
    with self.test_session():
      inputs = self._random_inputs()
      with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
        slim.inception.inception_v3(inputs)
      losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
      # One weight-decay term per weight matrix.
      self.assertEqual(len(losses), len(get_variables_by_name('weights')))

  def testTotalLossWithoutRegularization(self):
    num_classes = 1001
    with self.test_session():
      inputs = self._random_inputs()
      dense_labels = tf.random_uniform((5, num_classes))
      with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0):
        logits, end_points = slim.inception.inception_v3(
            inputs,
            num_classes=num_classes)
        self._add_cross_entropy_losses(logits, end_points, dense_labels)
        losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
        self.assertEqual(len(losses), 2)

  def testTotalLossWithRegularization(self):
    num_classes = 1000
    with self.test_session():
      inputs = self._random_inputs()
      dense_labels = tf.random_uniform((5, num_classes))
      with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
        logits, end_points = slim.inception.inception_v3(inputs, num_classes)
        self._add_cross_entropy_losses(logits, end_points, dense_labels)
        losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
        self.assertEqual(len(losses), 2)
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(len(reg_losses), 98)
| 178 | + |
| 179 | + |
if __name__ == '__main__':
  # Delegates to the TensorFlow test runner, which discovers and runs the
  # tf.test.TestCase subclasses defined above.
  tf.test.main()