Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f792e2e

Browse files
Alexey Strokachaaltay
authored andcommitted
Add helper functions for reading and writing to PubSub directly from Python (#9212)
* Add helper functions for reading and writing to PubSub directly from Python These functions are helpful when writing tests and when working with streaming pipelines interactively (e.g. inside a Jupyter notebook). Notes: - Not sure if apache_beam/testing/test_utils.py is a better place for the helper functions than apache_beam/io/gcp/tests/utils.py? - google.cloud.exceptions seems to have moved to google.api_core.exceptions. Currently, google.cloud.exceptions re-imports some, but not all, of the exceptions defined in google.api_core.exceptions.
1 parent 5504aa7 commit f792e2e

2 files changed

Lines changed: 249 additions & 8 deletions

File tree

sdks/python/apache_beam/io/gcp/tests/utils.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,16 @@
2525
import time
2626

2727
from apache_beam.io import filesystems
28+
from apache_beam.io.gcp.pubsub import PubsubMessage
2829
from apache_beam.utils import retry
2930

3031
# Protect against environments where bigquery library is not available.
3132
try:
33+
from google.api_core import exceptions as gexc
3234
from google.cloud import bigquery
33-
from google.cloud.exceptions import NotFound
3435
except ImportError:
36+
gexc = None
3537
bigquery = None
36-
NotFound = None
3738

3839

3940
class GcpTestIOError(retry.PermanentException):
@@ -98,7 +99,7 @@ def delete_bq_table(project, dataset_id, table_id):
9899
table_ref = client.dataset(dataset_id).table(table_id)
99100
try:
100101
client.delete_table(table_ref)
101-
except NotFound:
102+
except gexc.NotFound:
102103
raise GcpTestIOError('BigQuery table does not exist: %s' % table_ref)
103104

104105

@@ -113,3 +114,53 @@ def delete_directory(directory):
113114
"gs://mybucket/mydir/", "s3://...", ...)
114115
"""
115116
filesystems.FileSystems.delete([directory])
117+
118+
119+
def write_to_pubsub(pub_client,
120+
topic_path,
121+
messages,
122+
with_attributes=False,
123+
chunk_size=100,
124+
delay_between_chunks=0.1):
125+
for start in range(0, len(messages), chunk_size):
126+
message_chunk = messages[start:start + chunk_size]
127+
if with_attributes:
128+
futures = [
129+
pub_client.publish(topic_path, message.data, **message.attributes)
130+
for message in message_chunk
131+
]
132+
else:
133+
futures = [
134+
pub_client.publish(topic_path, message) for message in message_chunk
135+
]
136+
for future in futures:
137+
future.result()
138+
time.sleep(delay_between_chunks)
139+
140+
141+
def read_from_pubsub(sub_client,
142+
subscription_path,
143+
with_attributes=False,
144+
number_of_elements=None,
145+
timeout=None):
146+
if number_of_elements is None and timeout is None:
147+
raise ValueError("Either number_of_elements or timeout must be specified.")
148+
messages = []
149+
start_time = time.time()
150+
151+
while ((number_of_elements is None or len(messages) < number_of_elements) and
152+
(timeout is None or (time.time() - start_time) < timeout)):
153+
try:
154+
response = sub_client.pull(
155+
subscription_path, max_messages=1000, retry=None, timeout=10)
156+
except (gexc.RetryError, gexc.DeadlineExceeded):
157+
continue
158+
ack_ids = [msg.ack_id for msg in response.received_messages]
159+
sub_client.acknowledge(subscription_path, ack_ids)
160+
for msg in response.received_messages:
161+
message = PubsubMessage._from_message(msg.message)
162+
if with_attributes:
163+
messages.append(message)
164+
else:
165+
messages.append(message.data)
166+
return messages

sdks/python/apache_beam/io/gcp/tests/utils_test.py

Lines changed: 195 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,27 @@
2424

2525
import mock
2626

27+
from apache_beam.io.gcp.pubsub import PubsubMessage
2728
from apache_beam.io.gcp.tests import utils
28-
from apache_beam.testing.test_utils import patch_retry
29+
from apache_beam.testing import test_utils
2930

3031
# Protect against environments where bigquery library is not available.
3132
try:
33+
from google.api_core import exceptions as gexc
3234
from google.cloud import bigquery
33-
from google.cloud.exceptions import NotFound
35+
from google.cloud import pubsub
3436
except ImportError:
37+
gexc = None
3538
bigquery = None
36-
NotFound = None
39+
pubsub = None
3740

3841

3942
@unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.')
4043
@mock.patch.object(bigquery, 'Client')
4144
class UtilsTest(unittest.TestCase):
4245

4346
def setUp(self):
44-
patch_retry(self, utils)
47+
test_utils.patch_retry(self, utils)
4548

4649
@mock.patch.object(bigquery, 'Dataset')
4750
def test_create_bq_dataset(self, mock_dataset, mock_client):
@@ -68,14 +71,201 @@ def test_delete_table_succeeds(self, mock_client):
6871
def test_delete_table_fails_not_found(self, mock_client):
6972
mock_client.return_value.dataset.return_value.table.return_value = (
7073
'table_ref')
71-
mock_client.return_value.delete_table.side_effect = NotFound('test')
74+
mock_client.return_value.delete_table.side_effect = gexc.NotFound('test')
7275

7376
with self.assertRaisesRegexp(Exception, r'does not exist:.*table_ref'):
7477
utils.delete_bq_table('unused_project',
7578
'unused_dataset',
7679
'unused_table')
7780

7881

82+
@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
83+
class PubSubUtilTest(unittest.TestCase):
84+
85+
def test_write_to_pubsub(self):
86+
mock_pubsub = mock.Mock()
87+
topic_path = "project/fakeproj/topics/faketopic"
88+
data = b'data'
89+
utils.write_to_pubsub(mock_pubsub, topic_path, [data])
90+
mock_pubsub.publish.assert_has_calls(
91+
[mock.call(topic_path, data),
92+
mock.call().result()])
93+
94+
def test_write_to_pubsub_with_attributes(self):
95+
mock_pubsub = mock.Mock()
96+
topic_path = "project/fakeproj/topics/faketopic"
97+
data = b'data'
98+
attributes = {'key': 'value'}
99+
message = PubsubMessage(data, attributes)
100+
utils.write_to_pubsub(
101+
mock_pubsub, topic_path, [message], with_attributes=True)
102+
mock_pubsub.publish.assert_has_calls(
103+
[mock.call(topic_path, data, **attributes),
104+
mock.call().result()])
105+
106+
def test_write_to_pubsub_delay(self):
107+
number_of_elements = 2
108+
chunk_size = 1
109+
mock_pubsub = mock.Mock()
110+
topic_path = "project/fakeproj/topics/faketopic"
111+
data = b'data'
112+
with mock.patch('apache_beam.io.gcp.tests.utils.time') as mock_time:
113+
utils.write_to_pubsub(
114+
mock_pubsub,
115+
topic_path, [data] * number_of_elements,
116+
chunk_size=chunk_size,
117+
delay_between_chunks=123)
118+
mock_time.sleep.assert_called_with(123)
119+
mock_pubsub.publish.assert_has_calls(
120+
[mock.call(topic_path, data),
121+
mock.call().result()] * number_of_elements)
122+
123+
def test_write_to_pubsub_many_chunks(self):
124+
number_of_elements = 83
125+
chunk_size = 11
126+
mock_pubsub = mock.Mock()
127+
topic_path = "project/fakeproj/topics/faketopic"
128+
data_list = [
129+
'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
130+
]
131+
utils.write_to_pubsub(
132+
mock_pubsub, topic_path, data_list, chunk_size=chunk_size)
133+
call_list = []
134+
for start in range(0, number_of_elements, chunk_size):
135+
# Publish a batch of messages
136+
call_list += [
137+
mock.call(topic_path, data)
138+
for data in data_list[start:start + chunk_size]
139+
]
140+
# Wait for those messages to be received
141+
call_list += [
142+
mock.call().result() for _ in data_list[start:start + chunk_size]
143+
]
144+
mock_pubsub.publish.assert_has_calls(call_list)
145+
146+
def test_read_from_pubsub(self):
147+
mock_pubsub = mock.Mock()
148+
subscription_path = "project/fakeproj/subscriptions/fakesub"
149+
data = b'data'
150+
ack_id = 'ack_id'
151+
pull_response = test_utils.create_pull_response(
152+
[test_utils.PullResponseMessage(data, ack_id=ack_id)])
153+
mock_pubsub.pull.return_value = pull_response
154+
output = utils.read_from_pubsub(
155+
mock_pubsub, subscription_path, number_of_elements=1)
156+
self.assertEqual([data], output)
157+
mock_pubsub.acknowledge.assert_called_once_with(subscription_path, [ack_id])
158+
159+
def test_read_from_pubsub_with_attributes(self):
160+
mock_pubsub = mock.Mock()
161+
subscription_path = "project/fakeproj/subscriptions/fakesub"
162+
data = b'data'
163+
ack_id = 'ack_id'
164+
attributes = {'key': 'value'}
165+
message = PubsubMessage(data, attributes)
166+
pull_response = test_utils.create_pull_response(
167+
[test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)])
168+
mock_pubsub.pull.return_value = pull_response
169+
output = utils.read_from_pubsub(
170+
mock_pubsub,
171+
subscription_path,
172+
with_attributes=True,
173+
number_of_elements=1)
174+
self.assertEqual([message], output)
175+
mock_pubsub.acknowledge.assert_called_once_with(subscription_path, [ack_id])
176+
177+
def test_read_from_pubsub_flaky(self):
178+
number_of_elements = 10
179+
mock_pubsub = mock.Mock()
180+
subscription_path = "project/fakeproj/subscriptions/fakesub"
181+
data = b'data'
182+
ack_id = 'ack_id'
183+
pull_response = test_utils.create_pull_response(
184+
[test_utils.PullResponseMessage(data, ack_id=ack_id)])
185+
186+
class FlakyPullResponse(object):
187+
188+
def __init__(self, pull_response):
189+
self.pull_response = pull_response
190+
self._state = -1
191+
192+
def __call__(self, *args, **kwargs):
193+
self._state += 1
194+
if self._state % 3 == 0:
195+
raise gexc.RetryError("", "")
196+
if self._state % 3 == 1:
197+
raise gexc.DeadlineExceeded("")
198+
if self._state % 3 == 2:
199+
return self.pull_response
200+
201+
mock_pubsub.pull.side_effect = FlakyPullResponse(pull_response)
202+
output = utils.read_from_pubsub(
203+
mock_pubsub, subscription_path, number_of_elements=number_of_elements)
204+
self.assertEqual([data] * number_of_elements, output)
205+
self._assert_ack_ids_equal(mock_pubsub, [ack_id] * number_of_elements)
206+
207+
def test_read_from_pubsub_many(self):
208+
response_size = 33
209+
number_of_elements = 100
210+
mock_pubsub = mock.Mock()
211+
subscription_path = "project/fakeproj/subscriptions/fakesub"
212+
data_list = [
213+
'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
214+
]
215+
attributes_list = [{
216+
'key': 'value {}'.format(i)
217+
} for i in range(number_of_elements)]
218+
ack_ids = ['ack_id_{}'.format(i) for i in range(number_of_elements)]
219+
messages = [
220+
PubsubMessage(data, attributes)
221+
for data, attributes in zip(data_list, attributes_list)
222+
]
223+
response_messages = [
224+
test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)
225+
for data, attributes, ack_id in zip(data_list, attributes_list, ack_ids)
226+
]
227+
228+
class SequentialPullResponse(object):
229+
230+
def __init__(self, response_messages, response_size):
231+
self.response_messages = response_messages
232+
self.response_size = response_size
233+
self._index = 0
234+
235+
def __call__(self, *args, **kwargs):
236+
start = self._index
237+
self._index += self.response_size
238+
response = test_utils.create_pull_response(
239+
self.response_messages[start:start + self.response_size])
240+
return response
241+
242+
mock_pubsub.pull.side_effect = SequentialPullResponse(
243+
response_messages, response_size)
244+
output = utils.read_from_pubsub(
245+
mock_pubsub,
246+
subscription_path,
247+
with_attributes=True,
248+
number_of_elements=number_of_elements)
249+
self.assertEqual(messages, output)
250+
self._assert_ack_ids_equal(mock_pubsub, ack_ids)
251+
252+
def test_read_from_pubsub_invalid_arg(self):
253+
sub_client = mock.Mock()
254+
subscription_path = "project/fakeproj/subscriptions/fakesub"
255+
with self.assertRaisesRegexp(ValueError, "number_of_elements"):
256+
utils.read_from_pubsub(sub_client, subscription_path)
257+
with self.assertRaisesRegexp(ValueError, "number_of_elements"):
258+
utils.read_from_pubsub(
259+
sub_client, subscription_path, with_attributes=True)
260+
261+
def _assert_ack_ids_equal(self, mock_pubsub, ack_ids):
262+
actual_ack_ids = [
263+
ack_id for args_list in mock_pubsub.acknowledge.call_args_list
264+
for ack_id in args_list[0][1]
265+
]
266+
self.assertEqual(actual_ack_ids, ack_ids)
267+
268+
79269
if __name__ == '__main__':
80270
logging.getLogger().setLevel(logging.INFO)
81271
unittest.main()

0 commit comments

Comments
 (0)