15
15
16
16
"""Examples of using AI Platform's online prediction service."""
17
17
import argparse
18
- import base64
19
18
import json
20
19
21
20
# [START import_libraries]
22
21
import googleapiclient .discovery
23
- import six
24
22
# [END import_libraries]
25
23
26
24
@@ -61,83 +59,7 @@ def predict_json(project, model, instances, version=None):
61
59
# [END predict_json]
62
60
63
61
64
- # [START predict_tf_records]
65
- def predict_examples (project ,
66
- model ,
67
- example_bytes_list ,
68
- version = None ):
69
- """Send protocol buffer data to a deployed model for prediction.
70
-
71
- Args:
72
- project (str): project where the AI Platform Model is deployed.
73
- model (str): model name.
74
- example_bytes_list ([str]): A list of bytestrings representing
75
- serialized tf.train.Example protocol buffers. The contents of this
76
- protocol buffer will change depending on the signature of your
77
- deployed model.
78
- version: str, version of the model to target.
79
- Returns:
80
- Mapping[str: any]: dictionary of prediction results defined by the
81
- model.
82
- """
83
- service = googleapiclient .discovery .build ('ml' , 'v1' )
84
- name = 'projects/{}/models/{}' .format (project , model )
85
-
86
- if version is not None :
87
- name += '/versions/{}' .format (version )
88
-
89
- response = service .projects ().predict (
90
- name = name ,
91
- body = {'instances' : [
92
- {'b64' : base64 .b64encode (example_bytes ).decode ('utf-8' )}
93
- for example_bytes in example_bytes_list
94
- ]}
95
- ).execute ()
96
-
97
- if 'error' in response :
98
- raise RuntimeError (response ['error' ])
99
-
100
- return response ['predictions' ]
101
- # [END predict_tf_records]
102
-
103
-
104
- # [START census_to_example_bytes]
105
- def census_to_example_bytes (json_instance ):
106
- """Serialize a JSON example to the bytes of a tf.train.Example.
107
- This method is specific to the signature of the Census example.
108
- See: https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
109
- for details.
110
-
111
- Args:
112
- json_instance (Mapping[str: Any]): Keys should be the names of Tensors
113
- your deployed model expects to parse using it's tf.FeatureSpec.
114
- Values should be datatypes convertible to Tensors, or (potentially
115
- nested) lists of datatypes convertible to tensors.
116
- Returns:
117
- str: A string as a container for the serialized bytes of
118
- tf.train.Example protocol buffer.
119
- """
120
- import tensorflow as tf
121
- feature_dict = {}
122
- for key , data in six .iteritems (json_instance ):
123
- if isinstance (data , six .string_types ):
124
- feature_dict [key ] = tf .train .Feature (
125
- bytes_list = tf .train .BytesList (value = [data .encode ('utf-8' )]))
126
- elif isinstance (data , float ):
127
- feature_dict [key ] = tf .train .Feature (
128
- float_list = tf .train .FloatList (value = [data ]))
129
- elif isinstance (data , int ):
130
- feature_dict [key ] = tf .train .Feature (
131
- int64_list = tf .train .Int64List (value = [data ]))
132
- return tf .train .Example (
133
- features = tf .train .Features (
134
- feature = feature_dict
135
- )
136
- ).SerializeToString ()
137
- # [END census_to_example_bytes]
138
-
139
-
140
- def main (project , model , version = None , force_tfrecord = False ):
62
+ def main (project , model , version = None ):
141
63
"""Send user input to the prediction service."""
142
64
while True :
143
65
try :
@@ -148,16 +70,8 @@ def main(project, model, version=None, force_tfrecord=False):
148
70
if not isinstance (user_input , list ):
149
71
user_input = [user_input ]
150
72
try :
151
- if force_tfrecord :
152
- example_bytes_list = [
153
- census_to_example_bytes (e )
154
- for e in user_input
155
- ]
156
- result = predict_examples (
157
- project , model , example_bytes_list , version = version )
158
- else :
159
- result = predict_json (
160
- project , model , user_input , version = version )
73
+ result = predict_json (
74
+ project , model , user_input , version = version )
161
75
except RuntimeError as err :
162
76
print (str (err ))
163
77
else :
@@ -183,16 +97,9 @@ def main(project, model, version=None, force_tfrecord=False):
183
97
help = 'Name of the version.' ,
184
98
type = str
185
99
)
186
- parser .add_argument (
187
- '--force-tfrecord' ,
188
- help = 'Send predictions as TFRecords rather than raw JSON' ,
189
- action = 'store_true' ,
190
- default = False
191
- )
192
100
args = parser .parse_args ()
193
101
main (
194
102
args .project ,
195
103
args .model ,
196
104
version = args .version ,
197
- force_tfrecord = args .force_tfrecord
198
105
)
0 commit comments