@@ -1,9 +1,9 @@
 """A script to run inference on a set of image files.
 
-NOTE #1: The Attention OCR model was trained only using FSNS train dataset and
-it will work only for images which look more or less similar to french street
-names. In order to apply it to images from a different distribution you need
-to retrain (or at least fine-tune) it using images from that distribution.
+NOTE #1: The Attention OCR model was trained only using the FSNS train dataset,
+so it will work only for images that look more or less similar to French street
+names. To apply it to images from a different distribution, you need to
+retrain (or at least fine-tune) it using images from that distribution.
 
 NOTE #2: This script exists for demo purposes only. It is highly recommended
 to use tools and mechanisms provided by the TensorFlow Serving system to run
@@ -21,9 +21,10 @@
 import tensorflow as tf
 from tensorflow.python.platform import flags
+from tensorflow.python.training import monitored_session
 
 import common_flags
 import datasets
-import model as attention_ocr
+import data_provider
 
 FLAGS = flags.FLAGS
 common_flags.define()
@@ -44,7 +45,7 @@ def get_dataset_image_size(dataset_name):
 def load_images(file_pattern, batch_size, dataset_name):
   width, height = get_dataset_image_size(dataset_name)
   images_actual_data = np.ndarray(shape=(batch_size, height, width, 3),
-                                  dtype='float32')
+                                  dtype='uint8')
   for i in range(batch_size):
     path = file_pattern % i
     print("Reading %s" % path)
@@ -53,35 +54,40 @@ def load_images(file_pattern, batch_size, dataset_name):
   return images_actual_data
 
 
-def load_model(checkpoint, batch_size, dataset_name):
+def create_model(batch_size, dataset_name):
   width, height = get_dataset_image_size(dataset_name)
   dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
   model = common_flags.create_model(
-      num_char_classes=dataset.num_char_classes,
-      seq_length=dataset.max_sequence_length,
-      num_views=dataset.num_of_views,
-      null_code=dataset.null_code,
-      charset=dataset.charset)
-  images_placeholder = tf.placeholder(tf.float32,
-                                      shape=[batch_size, height, width, 3])
-  endpoints = model.create_base(images_placeholder, labels_one_hot=None)
-  init_fn = model.create_init_fn_to_restore(checkpoint)
-  return images_placeholder, endpoints, init_fn
+      num_char_classes=dataset.num_char_classes,
+      seq_length=dataset.max_sequence_length,
+      num_views=dataset.num_of_views,
+      null_code=dataset.null_code,
+      charset=dataset.charset)
+  raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
+  images = tf.map_fn(data_provider.preprocess_image, raw_images,
+                     dtype=tf.float32)
+  endpoints = model.create_base(images, labels_one_hot=None)
+  return raw_images, endpoints
+
+
+def run(checkpoint, batch_size, dataset_name, image_path_pattern):
+  images_placeholder, endpoints = create_model(batch_size,
+                                               dataset_name)
+  images_data = load_images(image_path_pattern, batch_size,
+                            dataset_name)
+  session_creator = monitored_session.ChiefSessionCreator(
+      checkpoint_filename_with_path=checkpoint)
+  with monitored_session.MonitoredSession(
+      session_creator=session_creator) as sess:
+    predictions = sess.run(endpoints.predicted_text,
+                           feed_dict={images_placeholder: images_data})
+  return predictions.tolist()
 
 
 def main(_):
-  images_placeholder, endpoints, init_fn = load_model(FLAGS.checkpoint,
-                                                      FLAGS.batch_size,
-                                                      FLAGS.dataset_name)
-  images_data = load_images(FLAGS.image_path_pattern, FLAGS.batch_size,
-                            FLAGS.dataset_name)
-  with tf.Session() as sess:
-    tf.tables_initializer().run()  # required by the CharsetMapper
-    init_fn(sess)
-    predictions = sess.run(endpoints.predicted_text,
-                           feed_dict={images_placeholder: images_data})
   print("Predicted strings:")
-  for line in predictions:
+  for line in run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
+                  FLAGS.image_path_pattern):
     print(line)
 
 
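The create_model() rewrite above feeds raw uint8 pixels into the graph and applies data_provider.preprocess_image to every image in the batch with tf.map_fn, so callers no longer preprocess images themselves. Below is a minimal, self-contained sketch of that map_fn pattern, assuming TF 1.x; to_float_and_rescale() is an invented stand-in for the real preprocessing, and the 4x150x600x3 batch shape is only an example.

import numpy as np
import tensorflow as tf

def to_float_and_rescale(image_uint8):
  # Stand-in per-image preprocessing: uint8 [H, W, 3] -> float32 in [-1, 1].
  image = tf.cast(image_uint8, tf.float32)
  return image * (2.0 / 255.0) - 1.0

# Batch placeholder holds raw uint8 pixels, as in the new create_model().
raw_images = tf.placeholder(tf.uint8, shape=[4, 150, 600, 3])
# tf.map_fn applies the per-image function to every element of the batch;
# dtype must be passed because the output dtype differs from the input dtype.
images = tf.map_fn(to_float_and_rescale, raw_images, dtype=tf.float32)

with tf.Session() as sess:
  batch = np.random.randint(0, 256, size=(4, 150, 600, 3), dtype=np.uint8)
  out = sess.run(images, feed_dict={raw_images: batch})
  print(out.dtype, out.min() >= -1.0, out.max() <= 1.0)  # float32 True True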
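The new run() helper restores the checkpoint through ChiefSessionCreator and MonitoredSession, whose default scaffold also runs the lookup-table initializer that the CharsetMapper needs; that is why the explicit tf.tables_initializer() and init_fn(sess) calls disappear from main(). A minimal sketch of that restore pattern, assuming TF 1.x, a throwaway variable, and a placeholder checkpoint path under /tmp:

import tensorflow as tf
from tensorflow.python.training import monitored_session

# A single variable so there is something to save and restore.
v = tf.get_variable('v', initializer=tf.constant([1.0, 2.0, 3.0]))
saver = tf.train.Saver()

# Write a checkpoint with a plain session first (path is a placeholder).
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  checkpoint = saver.save(sess, '/tmp/demo_model.ckpt')

# ChiefSessionCreator restores the given checkpoint and runs the scaffold's
# local init op (which includes tf.tables_initializer()) when the
# MonitoredSession is created, so no manual restore or init calls are needed.
session_creator = monitored_session.ChiefSessionCreator(
    checkpoint_filename_with_path=checkpoint)
with monitored_session.MonitoredSession(
    session_creator=session_creator) as sess:
  print(sess.run(v))  # [1. 2. 3.], restored from the checkpoint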