Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e087818

Browse files
author
Alec Glassford
authored
Fix AI Platform online prediction tests (GoogleCloudPlatform#4791)
## Description Fixes GoogleCloudPlatform#4776 and fixes GoogleCloudPlatform#4777 by using a new model version (created in the Cloud Console based on the same trained ML model) and updating code accordingly. It's unknown why the old model version that the test used stopped working. Fixes GoogleCloudPlatform#4778 by removing the code in question, which is no longer used in documentation. ## Checklist - [x] I have followed [Sample Guidelines from AUTHORING_GUIDE.MD](https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/AUTHORING_GUIDE.md) - [ ] README is updated to include [all relevant information](https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/AUTHORING_GUIDE.md#readme-file) - [x] **Tests** pass: `nox -s py-3.6` (see [Test Environment Setup](https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/AUTHORING_GUIDE.md#test-environment-setup)) - [x] **Lint** pass: `nox -s lint` (see [Test Environment Setup](https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/AUTHORING_GUIDE.md#test-environment-setup)) - [ ] These samples need a new **API enabled** in testing projects to pass (let us know which ones) - [ ] These samples need a new/updated **env vars** in testing projects set to pass (let us know which ones) - [x] Please **merge** this PR for me once it is approved. - [ ] This sample adds a new sample directory, and I updated the [CODEOWNERS file](https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/.github/CODEOWNERS) with the codeowners for this sample
1 parent f6ad120 commit e087818

File tree

2 files changed

+5
-118
lines changed

2 files changed

+5
-118
lines changed

ml_engine/online_prediction/predict.py

Lines changed: 3 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515

1616
"""Examples of using AI Platform's online prediction service."""
1717
import argparse
18-
import base64
1918
import json
2019

2120
# [START import_libraries]
2221
import googleapiclient.discovery
23-
import six
2422
# [END import_libraries]
2523

2624

@@ -61,83 +59,7 @@ def predict_json(project, model, instances, version=None):
6159
# [END predict_json]
6260

6361

64-
# [START predict_tf_records]
65-
def predict_examples(project,
66-
model,
67-
example_bytes_list,
68-
version=None):
69-
"""Send protocol buffer data to a deployed model for prediction.
70-
71-
Args:
72-
project (str): project where the AI Platform Model is deployed.
73-
model (str): model name.
74-
example_bytes_list ([str]): A list of bytestrings representing
75-
serialized tf.train.Example protocol buffers. The contents of this
76-
protocol buffer will change depending on the signature of your
77-
deployed model.
78-
version: str, version of the model to target.
79-
Returns:
80-
Mapping[str: any]: dictionary of prediction results defined by the
81-
model.
82-
"""
83-
service = googleapiclient.discovery.build('ml', 'v1')
84-
name = 'projects/{}/models/{}'.format(project, model)
85-
86-
if version is not None:
87-
name += '/versions/{}'.format(version)
88-
89-
response = service.projects().predict(
90-
name=name,
91-
body={'instances': [
92-
{'b64': base64.b64encode(example_bytes).decode('utf-8')}
93-
for example_bytes in example_bytes_list
94-
]}
95-
).execute()
96-
97-
if 'error' in response:
98-
raise RuntimeError(response['error'])
99-
100-
return response['predictions']
101-
# [END predict_tf_records]
102-
103-
104-
# [START census_to_example_bytes]
105-
def census_to_example_bytes(json_instance):
106-
"""Serialize a JSON example to the bytes of a tf.train.Example.
107-
This method is specific to the signature of the Census example.
108-
See: https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
109-
for details.
110-
111-
Args:
112-
json_instance (Mapping[str: Any]): Keys should be the names of Tensors
113-
your deployed model expects to parse using it's tf.FeatureSpec.
114-
Values should be datatypes convertible to Tensors, or (potentially
115-
nested) lists of datatypes convertible to tensors.
116-
Returns:
117-
str: A string as a container for the serialized bytes of
118-
tf.train.Example protocol buffer.
119-
"""
120-
import tensorflow as tf
121-
feature_dict = {}
122-
for key, data in six.iteritems(json_instance):
123-
if isinstance(data, six.string_types):
124-
feature_dict[key] = tf.train.Feature(
125-
bytes_list=tf.train.BytesList(value=[data.encode('utf-8')]))
126-
elif isinstance(data, float):
127-
feature_dict[key] = tf.train.Feature(
128-
float_list=tf.train.FloatList(value=[data]))
129-
elif isinstance(data, int):
130-
feature_dict[key] = tf.train.Feature(
131-
int64_list=tf.train.Int64List(value=[data]))
132-
return tf.train.Example(
133-
features=tf.train.Features(
134-
feature=feature_dict
135-
)
136-
).SerializeToString()
137-
# [END census_to_example_bytes]
138-
139-
140-
def main(project, model, version=None, force_tfrecord=False):
62+
def main(project, model, version=None):
14163
"""Send user input to the prediction service."""
14264
while True:
14365
try:
@@ -148,16 +70,8 @@ def main(project, model, version=None, force_tfrecord=False):
14870
if not isinstance(user_input, list):
14971
user_input = [user_input]
15072
try:
151-
if force_tfrecord:
152-
example_bytes_list = [
153-
census_to_example_bytes(e)
154-
for e in user_input
155-
]
156-
result = predict_examples(
157-
project, model, example_bytes_list, version=version)
158-
else:
159-
result = predict_json(
160-
project, model, user_input, version=version)
73+
result = predict_json(
74+
project, model, user_input, version=version)
16175
except RuntimeError as err:
16276
print(str(err))
16377
else:
@@ -183,16 +97,9 @@ def main(project, model, version=None, force_tfrecord=False):
18397
help='Name of the version.',
18498
type=str
18599
)
186-
parser.add_argument(
187-
'--force-tfrecord',
188-
help='Send predictions as TFRecords rather than raw JSON',
189-
action='store_true',
190-
default=False
191-
)
192100
args = parser.parse_args()
193101
main(
194102
args.project,
195103
args.model,
196104
version=args.version,
197-
force_tfrecord=args.force_tfrecord
198105
)

ml_engine/online_prediction/predict_test.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@
2020
import predict
2121

2222
MODEL = 'census'
23-
JSON_VERSION = 'v1json'
24-
EXAMPLES_VERSION = 'v1example'
23+
JSON_VERSION = 'v2json'
2524
PROJECT = 'python-docs-samples-tests'
2625
EXPECTED_OUTPUT = {
27-
u'confidence': 0.7760371565818787,
26+
u'confidence': 0.7760370969772339,
2827
u'predictions': u' <=50K'
2928
}
3029

@@ -37,10 +36,6 @@
3736
JSON = json.load(f)
3837

3938

40-
with open('resources/census_example_bytes.pb', 'rb') as f:
41-
BYTESTRING = f.read()
42-
43-
4439
@pytest.mark.flaky
4540
def test_predict_json():
4641
result = predict.predict_json(
@@ -53,18 +48,3 @@ def test_predict_json_error():
5348
with pytest.raises(RuntimeError):
5449
predict.predict_json(
5550
PROJECT, MODEL, [{"foo": "bar"}], version=JSON_VERSION)
56-
57-
58-
@pytest.mark.flaky
59-
def test_census_example_to_bytes():
60-
import tensorflow as tf
61-
b = predict.census_to_example_bytes(JSON)
62-
assert tf.train.Example.FromString(b) == tf.train.Example.FromString(
63-
BYTESTRING)
64-
65-
66-
@pytest.mark.flaky(max_runs=6)
67-
def test_predict_examples():
68-
result = predict.predict_examples(
69-
PROJECT, MODEL, [BYTESTRING, BYTESTRING], version=EXAMPLES_VERSION)
70-
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result

0 commit comments

Comments
 (0)