@@ -71,38 +71,61 @@ def test_dataset_input_fn(self):
71
71
for pixel in row :
72
72
self .assertAllClose (pixel , np .array ([- 1.225 , 0. , 1.225 ]), rtol = 1e-3 )
73
73
74
+ def _cifar10_model_fn_helper (self , mode , version , dtype , multi_gpu = False ):
75
+ with tf .Graph ().as_default () as g :
76
+ input_fn = cifar10_main .get_synth_input_fn ()
77
+ dataset = input_fn (True , '' , _BATCH_SIZE )
78
+ iterator = dataset .make_one_shot_iterator ()
79
+ features , labels = iterator .get_next ()
80
+ spec = cifar10_main .cifar10_model_fn (
81
+ features , labels , mode , {
82
+ 'dtype' : dtype ,
83
+ 'resnet_size' : 32 ,
84
+ 'data_format' : 'channels_last' ,
85
+ 'batch_size' : _BATCH_SIZE ,
86
+ 'version' : version ,
87
+ 'loss_scale' : 128 if dtype == tf .float16 else 1 ,
88
+ 'multi_gpu' : multi_gpu
89
+ })
90
+
91
+ predictions = spec .predictions
92
+ self .assertAllEqual (predictions ['probabilities' ].shape ,
93
+ (_BATCH_SIZE , 10 ))
94
+ self .assertEqual (predictions ['probabilities' ].dtype , tf .float32 )
95
+ self .assertAllEqual (predictions ['classes' ].shape , (_BATCH_SIZE ,))
96
+ self .assertEqual (predictions ['classes' ].dtype , tf .int64 )
97
+
98
+ if mode != tf .estimator .ModeKeys .PREDICT :
99
+ loss = spec .loss
100
+ self .assertAllEqual (loss .shape , ())
101
+ self .assertEqual (loss .dtype , tf .float32 )
102
+
103
+ if mode == tf .estimator .ModeKeys .EVAL :
104
+ eval_metric_ops = spec .eval_metric_ops
105
+ self .assertAllEqual (eval_metric_ops ['accuracy' ][0 ].shape , ())
106
+ self .assertAllEqual (eval_metric_ops ['accuracy' ][1 ].shape , ())
107
+ self .assertEqual (eval_metric_ops ['accuracy' ][0 ].dtype , tf .float32 )
108
+ self .assertEqual (eval_metric_ops ['accuracy' ][1 ].dtype , tf .float32 )
109
+
110
+ for v in tf .trainable_variables ():
111
+ self .assertEqual (v .dtype .base_dtype , tf .float32 )
112
+
113
+ tensors_to_check = ('initial_conv:0' , 'block_layer1:0' , 'block_layer2:0' ,
114
+ 'block_layer3:0' , 'final_reduce_mean:0' ,
115
+ 'final_dense:0' )
116
+
117
+ for tensor_name in tensors_to_check :
118
+ tensor = g .get_tensor_by_name ('resnet_model/' + tensor_name )
119
+ self .assertEqual (tensor .dtype , dtype ,
120
+ 'Tensor {} has dtype {}, while dtype {} was '
121
+ 'expected' .format (tensor , tensor .dtype ,
122
+ dtype ))
123
+
74
124
def cifar10_model_fn_helper (self , mode , version , multi_gpu = False ):
75
- input_fn = cifar10_main .get_synth_input_fn ()
76
- dataset = input_fn (True , '' , _BATCH_SIZE )
77
- iterator = dataset .make_one_shot_iterator ()
78
- features , labels = iterator .get_next ()
79
- spec = cifar10_main .cifar10_model_fn (
80
- features , labels , mode , {
81
- 'resnet_size' : 32 ,
82
- 'data_format' : 'channels_last' ,
83
- 'batch_size' : _BATCH_SIZE ,
84
- 'version' : version ,
85
- 'multi_gpu' : multi_gpu
86
- })
87
-
88
- predictions = spec .predictions
89
- self .assertAllEqual (predictions ['probabilities' ].shape ,
90
- (_BATCH_SIZE , 10 ))
91
- self .assertEqual (predictions ['probabilities' ].dtype , tf .float32 )
92
- self .assertAllEqual (predictions ['classes' ].shape , (_BATCH_SIZE ,))
93
- self .assertEqual (predictions ['classes' ].dtype , tf .int64 )
94
-
95
- if mode != tf .estimator .ModeKeys .PREDICT :
96
- loss = spec .loss
97
- self .assertAllEqual (loss .shape , ())
98
- self .assertEqual (loss .dtype , tf .float32 )
99
-
100
- if mode == tf .estimator .ModeKeys .EVAL :
101
- eval_metric_ops = spec .eval_metric_ops
102
- self .assertAllEqual (eval_metric_ops ['accuracy' ][0 ].shape , ())
103
- self .assertAllEqual (eval_metric_ops ['accuracy' ][1 ].shape , ())
104
- self .assertEqual (eval_metric_ops ['accuracy' ][0 ].dtype , tf .float32 )
105
- self .assertEqual (eval_metric_ops ['accuracy' ][1 ].dtype , tf .float32 )
125
+ self ._cifar10_model_fn_helper (mode = mode , version = version , dtype = tf .float32 ,
126
+ multi_gpu = multi_gpu )
127
+ self ._cifar10_model_fn_helper (mode = mode , version = version , dtype = tf .float16 ,
128
+ multi_gpu = multi_gpu )
106
129
107
130
def test_cifar10_model_fn_train_mode_v1 (self ):
108
131
self .cifar10_model_fn_helper (tf .estimator .ModeKeys .TRAIN , version = 1 )
@@ -130,19 +153,22 @@ def test_cifar10_model_fn_predict_mode_v1(self):
130
153
def test_cifar10_model_fn_predict_mode_v2 (self ):
131
154
self .cifar10_model_fn_helper (tf .estimator .ModeKeys .PREDICT , version = 2 )
132
155
133
- def test_cifar10model_shape (self ):
156
+ def _test_cifar10model_shape (self , version ):
134
157
batch_size = 135
135
158
num_classes = 246
136
159
137
- for version in (1 , 2 ):
138
- model = cifar10_main .Cifar10Model (
139
- 32 , data_format = 'channels_last' , num_classes = num_classes ,
140
- version = version )
141
- fake_input = tf .random_uniform (
142
- [batch_size , _HEIGHT , _WIDTH , _NUM_CHANNELS ])
143
- output = model (fake_input , training = True )
160
+ model = cifar10_main .Cifar10Model (32 , data_format = 'channels_last' ,
161
+ num_classes = num_classes , version = version )
162
+ fake_input = tf .random_uniform ([batch_size , _HEIGHT , _WIDTH , _NUM_CHANNELS ])
163
+ output = model (fake_input , training = True )
164
+
165
+ self .assertAllEqual (output .shape , (batch_size , num_classes ))
166
+
167
+ def test_cifar10model_shape_v1 (self ):
168
+ self ._test_cifar10model_shape (version = 1 )
144
169
145
- self .assertAllEqual (output .shape , (batch_size , num_classes ))
170
+ def test_cifar10model_shape_v2 (self ):
171
+ self ._test_cifar10model_shape (version = 2 )
146
172
147
173
def test_cifar10_end_to_end_synthetic_v1 (self ):
148
174
integration .run_synthetic (
0 commit comments