Thanks to visit codestin.com
Credit goes to github.com

Skip to content
This repository was archived by the owner on Dec 9, 2024. It is now read-only.

Commit 119534b

Browse files
haoyuztensorflower-gardener
authored andcommitted
Move color distortion ops in SSD model from CPU to GPU.
PiperOrigin-RevId: 229664805
1 parent 747e8da commit 119534b

File tree

5 files changed

+41
-14
lines changed

5 files changed

+41
-14
lines changed

scripts/tf_cnn_benchmarks/models/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,10 @@ def get_synthetic_inputs(self, input_name, nclass):
257257
name=self.model_name + '_synthetic_labels')
258258
return (inputs, labels)
259259

260+
def gpu_preprocess_nhwc(self, images, phase_train=True):
261+
del phase_train
262+
return images
263+
260264
def build_network(self,
261265
inputs,
262266
phase_train=True,
@@ -273,6 +277,7 @@ def build_network(self,
273277
information.
274278
"""
275279
images = inputs[0]
280+
images = self.gpu_preprocess_nhwc(images, phase_train)
276281
if self.data_format == 'NCHW':
277282
images = tf.transpose(images, [0, 3, 1, 2])
278283
var_type = tf.float32

scripts/tf_cnn_benchmarks/models/ssd_model.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,26 @@ def __init__(self, label_num=ssd_constants.NUM_CLASSES, batch_size=32,
116116
def skip_final_affine_layer(self):
117117
return True
118118

119+
def gpu_preprocess_nhwc(self, images, phase_train=True):
120+
try:
121+
import ssd_dataloader # pylint: disable=g-import-not-at-top
122+
except ImportError:
123+
raise ImportError('To use the COCO dataset, you must clone the '
124+
'repo https://github.com/tensorflow/models and add '
125+
'tensorflow/models and tensorflow/models/research to '
126+
'the PYTHONPATH, and compile the protobufs by '
127+
'following https://github.com/tensorflow/models/blob/'
128+
'master/research/object_detection/g3doc/installation.md'
129+
'#protobuf-compilation ; To evaluate using COCO'
130+
'metric, download and install Python COCO API from'
131+
'https://github.com/cocodataset/cocoapi')
132+
133+
if phase_train:
134+
images = ssd_dataloader.color_jitter(
135+
images, brightness=0.125, contrast=0.5, saturation=0.5, hue=0.05)
136+
images = ssd_dataloader.normalize_image(images)
137+
return images
138+
119139
def add_backbone_model(self, cnn):
120140
# --------------------------------------------------------------------------
121141
# Resnet-34 backbone model -- modified for SSD
@@ -152,8 +172,8 @@ def add_backbone_model(self, cnn):
152172
resnet_model.residual_block(cnn, 256, stride, version, i == 0)
153173

154174
# ResNet-34 block group 4: removed final block group
155-
# The following 3 lines are intentially commented out to differentiate from
156-
# the original ResNet-34 model
175+
# The following 3 lines are intentionally commented out to differentiate
176+
# from the original ResNet-34 model
157177
# for i in range(resnet34_layers[3]):
158178
# stride = 2 if i == 0 else 1
159179
# resnet_model.residual_block(cnn, 512, stride, version, i == 0)

scripts/tf_cnn_benchmarks/models/trivial_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,18 @@ class TrivialSSD300Model(model.CNNModel):
5353

5454
def __init__(self, params=None):
5555
super(TrivialSSD300Model, self).__init__(
56-
'trivial', 300, 32, 0.005, params=params)
56+
'trivial', 300, params.batch_size, 0.005, params=params)
5757

5858
def add_inference(self, cnn):
5959
cnn.reshape([-1, 300 * 300 * 3])
6060
cnn.affine(1)
6161
cnn.affine(4096)
6262

6363
def get_input_shapes(self, subset):
64-
return [[32, 300, 300, 3], [32, 8732, 4], [32, 8732, 1], [32]]
64+
return [[self.batch_size, 300, 300, 3],
65+
[self.batch_size, 8732, 4],
66+
[self.batch_size, 8732, 1],
67+
[self.batch_size]]
6568

6669
def loss_function(self, inputs, build_network_result):
6770
images, _, _, labels = inputs

scripts/tf_cnn_benchmarks/preprocessing.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -971,9 +971,6 @@ def preprocess(self, data):
971971
# See https://github.com/tensorflow/models/blob/master/research/object_detection/core/preprocessor.py # pylint: disable=line-too-long
972972
mlperf.logger.log(key=mlperf.tags.RANDOM_FLIP_PROBABILITY, value=0.5)
973973

974-
image = ssd_dataloader.color_jitter(
975-
image, brightness=0.125, contrast=0.5, saturation=0.5, hue=0.05)
976-
image = ssd_dataloader.normalize_image(image)
977974
image = tf.cast(image, self.dtype)
978975

979976
encoded_returns = ssd_encoder.encode_labels(boxes, classes)

scripts/tf_cnn_benchmarks/ssd_dataloader.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -356,22 +356,24 @@ def color_jitter(image, brightness=0, contrast=0, saturation=0, hue=0):
356356
return image
357357

358358

359-
def normalize_image(image):
360-
"""Normalize the image to zero mean and unit variance.
359+
def normalize_image(images):
360+
"""Normalize image to zero mean and unit variance.
361361
362362
Args:
363-
image: 3D tensor of type float32, value in [0, 1]
363+
images: a tensor representing images, at least 3-D.
364364
Returns:
365-
image normalized by mean and stdev.
365+
images normalized by mean and stdev.
366366
"""
367-
image = tf.subtract(image, ssd_constants.NORMALIZATION_MEAN)
368-
image = tf.divide(image, ssd_constants.NORMALIZATION_STD)
367+
data_type = images.dtype
368+
mean = tf.constant(ssd_constants.NORMALIZATION_MEAN, data_type)
369+
std = tf.constant(ssd_constants.NORMALIZATION_STD, data_type)
370+
images = tf.divide(tf.subtract(images, mean), std)
369371

370372
mlperf.logger.log(key=mlperf.tags.DATA_NORMALIZATION_MEAN,
371373
value=ssd_constants.NORMALIZATION_MEAN)
372374
mlperf.logger.log(key=mlperf.tags.DATA_NORMALIZATION_STD,
373375
value=ssd_constants.NORMALIZATION_STD)
374-
return image
376+
return images
375377

376378

377379
class Encoder(object):

0 commit comments

Comments
 (0)