Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 89e19ed

Browse files
author
Alexander Gorban
committed
Fix demo_inference to properly normalize input.
Before this fix, demo_inference.py relied on batch_norm, which implicitly normalized the input image. If batch_norm was disabled at inference time, inference produced incorrect results. This fix normalizes the input image explicitly and disables batch_norm at inference time.
1 parent f893da6 commit 89e19ed

35 files changed

+184
-88
lines changed

research/attention_ocr/python/demo_inference.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""A script to run inference on a set of image files.
22
3-
NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
4-
it will work only for images which look more or less similar to french street
5-
names. In order to apply it to images from a different distribution you need
6-
to retrain (or at least fine-tune) it using images from that distribution.
3+
NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
4+
it will work only for images which look more or less similar to french street
5+
names. In order to apply it to images from a different distribution you need
6+
to retrain (or at least fine-tune) it using images from that distribution.
77
88
NOTE #2: This script exists for demo purposes only. It is highly recommended
99
to use tools and mechanisms provided by the TensorFlow Serving system to run
@@ -20,10 +20,11 @@
2020

2121
import tensorflow as tf
2222
from tensorflow.python.platform import flags
23+
from tensorflow.python.training import monitored_session
2324

2425
import common_flags
2526
import datasets
26-
import model as attention_ocr
27+
import data_provider
2728

2829
FLAGS = flags.FLAGS
2930
common_flags.define()
@@ -44,7 +45,7 @@ def get_dataset_image_size(dataset_name):
4445
def load_images(file_pattern, batch_size, dataset_name):
4546
width, height = get_dataset_image_size(dataset_name)
4647
images_actual_data = np.ndarray(shape=(batch_size, height, width, 3),
47-
dtype='float32')
48+
dtype='uint8')
4849
for i in range(batch_size):
4950
path = file_pattern % i
5051
print("Reading %s" % path)
@@ -53,35 +54,40 @@ def load_images(file_pattern, batch_size, dataset_name):
5354
return images_actual_data
5455

5556

56-
def load_model(checkpoint, batch_size, dataset_name):
57+
def create_model(batch_size, dataset_name):
5758
width, height = get_dataset_image_size(dataset_name)
5859
dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
5960
model = common_flags.create_model(
60-
num_char_classes=dataset.num_char_classes,
61-
seq_length=dataset.max_sequence_length,
62-
num_views=dataset.num_of_views,
63-
null_code=dataset.null_code,
64-
charset=dataset.charset)
65-
images_placeholder = tf.placeholder(tf.float32,
66-
shape=[batch_size, height, width, 3])
67-
endpoints = model.create_base(images_placeholder, labels_one_hot=None)
68-
init_fn = model.create_init_fn_to_restore(checkpoint)
69-
return images_placeholder, endpoints, init_fn
61+
num_char_classes=dataset.num_char_classes,
62+
seq_length=dataset.max_sequence_length,
63+
num_views=dataset.num_of_views,
64+
null_code=dataset.null_code,
65+
charset=dataset.charset)
66+
raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
67+
images = tf.map_fn(data_provider.preprocess_image, raw_images,
68+
dtype=tf.float32)
69+
endpoints = model.create_base(images, labels_one_hot=None)
70+
return raw_images, endpoints
71+
72+
73+
def run(checkpoint, batch_size, dataset_name, image_path_pattern):
74+
images_placeholder, endpoints = create_model(batch_size,
75+
dataset_name)
76+
images_data = load_images(image_path_pattern, batch_size,
77+
dataset_name)
78+
session_creator = monitored_session.ChiefSessionCreator(
79+
checkpoint_filename_with_path=checkpoint)
80+
with monitored_session.MonitoredSession(
81+
session_creator=session_creator) as sess:
82+
predictions = sess.run(endpoints.predicted_text,
83+
feed_dict={images_placeholder: images_data})
84+
return predictions.tolist()
7085

7186

7287
def main(_):
73-
images_placeholder, endpoints, init_fn = load_model(FLAGS.checkpoint,
74-
FLAGS.batch_size,
75-
FLAGS.dataset_name)
76-
images_data = load_images(FLAGS.image_path_pattern, FLAGS.batch_size,
77-
FLAGS.dataset_name)
78-
with tf.Session() as sess:
79-
tf.tables_initializer().run() # required by the CharsetMapper
80-
init_fn(sess)
81-
predictions = sess.run(endpoints.predicted_text,
82-
feed_dict={images_placeholder: images_data})
8388
print("Predicted strings:")
84-
for line in predictions:
89+
for line in run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
90+
FLAGS.image_path_pattern):
8591
print(line)
8692

8793

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#!/usr/bin/python
2+
# -*- coding: UTF-8 -*-
3+
import demo_inference
4+
import tensorflow as tf
5+
from tensorflow.python.training import monitored_session
6+
7+
_CHECKPOINT = 'model.ckpt-399731'
8+
_CHECKPOINT_URL = 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz'
9+
10+
11+
class DemoInferenceTest(tf.test.TestCase):
12+
def setUp(self):
13+
super(DemoInferenceTest, self).setUp()
14+
for suffix in ['.meta', '.index', '.data-00000-of-00001']:
15+
filename = _CHECKPOINT + suffix
16+
self.assertTrue(tf.gfile.Exists(filename),
17+
msg='Missing checkpoint file %s. '
18+
'Please download and extract it from %s' %
19+
(filename, _CHECKPOINT_URL))
20+
self._batch_size = 32
21+
22+
def test_moving_variables_properly_loaded_from_a_checkpoint(self):
23+
batch_size = 32
24+
dataset_name = 'fsns'
25+
images_placeholder, endpoints = demo_inference.create_model(batch_size,
26+
dataset_name)
27+
image_path_pattern = 'testdata/fsns_train_%02d.png'
28+
images_data = demo_inference.load_images(image_path_pattern, batch_size,
29+
dataset_name)
30+
tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
31+
moving_mean_tf = tf.get_default_graph().get_tensor_by_name(
32+
tensor_name + ':0')
33+
reader = tf.train.NewCheckpointReader(_CHECKPOINT)
34+
moving_mean_expected = reader.get_tensor(tensor_name)
35+
36+
session_creator = monitored_session.ChiefSessionCreator(
37+
checkpoint_filename_with_path=_CHECKPOINT)
38+
with monitored_session.MonitoredSession(
39+
session_creator=session_creator) as sess:
40+
moving_mean_np = sess.run(moving_mean_tf,
41+
feed_dict={images_placeholder: images_data})
42+
43+
self.assertAllEqual(moving_mean_expected, moving_mean_np)
44+
45+
def test_correct_results_on_test_data(self):
46+
image_path_pattern = 'testdata/fsns_train_%02d.png'
47+
predictions = demo_inference.run(_CHECKPOINT, self._batch_size,
48+
'fsns',
49+
image_path_pattern)
50+
self.assertEqual([
51+
'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
52+
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
53+
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
54+
'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
55+
'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░',
56+
'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░',
57+
'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░',
58+
'Rue Victor Schoulcher░░░░░░░░░░░░░░░░',
59+
'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░',
60+
'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░',
61+
'Rue André Maginot░░░░░░░░░░░░░░░░░░░░',
62+
'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░',
63+
'Rue des Landelles░░░░░░░░░░░░░░░░░░░░',
64+
'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░',
65+
'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░',
66+
'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░', # GT='Rue Thérésa'
67+
'Route de la Balme░░░░░░░░░░░░░░░░░░░░',
68+
'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░',
69+
'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░',
70+
'Place de la Mairie░░░░░░░░░░░░░░░░░░░',
71+
'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░',
72+
'Rue de la Libération░░░░░░░░░░░░░░░░░',
73+
'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░',
74+
'Avenue de la Grand Mare░░░░░░░░░░░░░░',
75+
'Rue Pierre Brossolette░░░░░░░░░░░░░░░',
76+
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
77+
'Rue du Docteur Mourre░░░░░░░░░░░░░░░░',
78+
'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░',
79+
'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░',
80+
'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░',
81+
'Impasse Pierre Mourgues░░░░░░░░░░░░░░',
82+
'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'
83+
], predictions)
84+
85+
86+
if __name__ == '__main__':
87+
tf.test.main()

0 commit comments

Comments
 (0)