2424
2525import mock
2626
27+ from apache_beam .io .gcp .pubsub import PubsubMessage
2728from apache_beam .io .gcp .tests import utils
28- from apache_beam .testing . test_utils import patch_retry
29+ from apache_beam .testing import test_utils
2930
3031# Protect against environments where bigquery library is not available.
3132try :
33+ from google .api_core import exceptions as gexc
3234 from google .cloud import bigquery
33- from google .cloud . exceptions import NotFound
35+ from google .cloud import pubsub
3436except ImportError :
37+ gexc = None
3538 bigquery = None
36- NotFound = None
39+ pubsub = None
3740
3841
3942@unittest .skipIf (bigquery is None , 'Bigquery dependencies are not installed.' )
4043@mock .patch .object (bigquery , 'Client' )
4144class UtilsTest (unittest .TestCase ):
4245
4346 def setUp (self ):
44- patch_retry (self , utils )
47+ test_utils . patch_retry (self , utils )
4548
4649 @mock .patch .object (bigquery , 'Dataset' )
4750 def test_create_bq_dataset (self , mock_dataset , mock_client ):
@@ -68,14 +71,201 @@ def test_delete_table_succeeds(self, mock_client):
6871 def test_delete_table_fails_not_found (self , mock_client ):
6972 mock_client .return_value .dataset .return_value .table .return_value = (
7073 'table_ref' )
71- mock_client .return_value .delete_table .side_effect = NotFound ('test' )
74+ mock_client .return_value .delete_table .side_effect = gexc . NotFound ('test' )
7275
7376 with self .assertRaisesRegexp (Exception , r'does not exist:.*table_ref' ):
7477 utils .delete_bq_table ('unused_project' ,
7578 'unused_dataset' ,
7679 'unused_table' )
7780
7881
82+ @unittest .skipIf (pubsub is None , 'GCP dependencies are not installed' )
83+ class PubSubUtilTest (unittest .TestCase ):
84+
85+ def test_write_to_pubsub (self ):
86+ mock_pubsub = mock .Mock ()
87+ topic_path = "project/fakeproj/topics/faketopic"
88+ data = b'data'
89+ utils .write_to_pubsub (mock_pubsub , topic_path , [data ])
90+ mock_pubsub .publish .assert_has_calls (
91+ [mock .call (topic_path , data ),
92+ mock .call ().result ()])
93+
94+ def test_write_to_pubsub_with_attributes (self ):
95+ mock_pubsub = mock .Mock ()
96+ topic_path = "project/fakeproj/topics/faketopic"
97+ data = b'data'
98+ attributes = {'key' : 'value' }
99+ message = PubsubMessage (data , attributes )
100+ utils .write_to_pubsub (
101+ mock_pubsub , topic_path , [message ], with_attributes = True )
102+ mock_pubsub .publish .assert_has_calls (
103+ [mock .call (topic_path , data , ** attributes ),
104+ mock .call ().result ()])
105+
106+ def test_write_to_pubsub_delay (self ):
107+ number_of_elements = 2
108+ chunk_size = 1
109+ mock_pubsub = mock .Mock ()
110+ topic_path = "project/fakeproj/topics/faketopic"
111+ data = b'data'
112+ with mock .patch ('apache_beam.io.gcp.tests.utils.time' ) as mock_time :
113+ utils .write_to_pubsub (
114+ mock_pubsub ,
115+ topic_path , [data ] * number_of_elements ,
116+ chunk_size = chunk_size ,
117+ delay_between_chunks = 123 )
118+ mock_time .sleep .assert_called_with (123 )
119+ mock_pubsub .publish .assert_has_calls (
120+ [mock .call (topic_path , data ),
121+ mock .call ().result ()] * number_of_elements )
122+
123+ def test_write_to_pubsub_many_chunks (self ):
124+ number_of_elements = 83
125+ chunk_size = 11
126+ mock_pubsub = mock .Mock ()
127+ topic_path = "project/fakeproj/topics/faketopic"
128+ data_list = [
129+ 'data {}' .format (i ).encode ("utf-8" ) for i in range (number_of_elements )
130+ ]
131+ utils .write_to_pubsub (
132+ mock_pubsub , topic_path , data_list , chunk_size = chunk_size )
133+ call_list = []
134+ for start in range (0 , number_of_elements , chunk_size ):
135+ # Publish a batch of messages
136+ call_list += [
137+ mock .call (topic_path , data )
138+ for data in data_list [start :start + chunk_size ]
139+ ]
140+ # Wait for those messages to be received
141+ call_list += [
142+ mock .call ().result () for _ in data_list [start :start + chunk_size ]
143+ ]
144+ mock_pubsub .publish .assert_has_calls (call_list )
145+
146+ def test_read_from_pubsub (self ):
147+ mock_pubsub = mock .Mock ()
148+ subscription_path = "project/fakeproj/subscriptions/fakesub"
149+ data = b'data'
150+ ack_id = 'ack_id'
151+ pull_response = test_utils .create_pull_response (
152+ [test_utils .PullResponseMessage (data , ack_id = ack_id )])
153+ mock_pubsub .pull .return_value = pull_response
154+ output = utils .read_from_pubsub (
155+ mock_pubsub , subscription_path , number_of_elements = 1 )
156+ self .assertEqual ([data ], output )
157+ mock_pubsub .acknowledge .assert_called_once_with (subscription_path , [ack_id ])
158+
159+ def test_read_from_pubsub_with_attributes (self ):
160+ mock_pubsub = mock .Mock ()
161+ subscription_path = "project/fakeproj/subscriptions/fakesub"
162+ data = b'data'
163+ ack_id = 'ack_id'
164+ attributes = {'key' : 'value' }
165+ message = PubsubMessage (data , attributes )
166+ pull_response = test_utils .create_pull_response (
167+ [test_utils .PullResponseMessage (data , attributes , ack_id = ack_id )])
168+ mock_pubsub .pull .return_value = pull_response
169+ output = utils .read_from_pubsub (
170+ mock_pubsub ,
171+ subscription_path ,
172+ with_attributes = True ,
173+ number_of_elements = 1 )
174+ self .assertEqual ([message ], output )
175+ mock_pubsub .acknowledge .assert_called_once_with (subscription_path , [ack_id ])
176+
177+ def test_read_from_pubsub_flaky (self ):
178+ number_of_elements = 10
179+ mock_pubsub = mock .Mock ()
180+ subscription_path = "project/fakeproj/subscriptions/fakesub"
181+ data = b'data'
182+ ack_id = 'ack_id'
183+ pull_response = test_utils .create_pull_response (
184+ [test_utils .PullResponseMessage (data , ack_id = ack_id )])
185+
186+ class FlakyPullResponse (object ):
187+
188+ def __init__ (self , pull_response ):
189+ self .pull_response = pull_response
190+ self ._state = - 1
191+
192+ def __call__ (self , * args , ** kwargs ):
193+ self ._state += 1
194+ if self ._state % 3 == 0 :
195+ raise gexc .RetryError ("" , "" )
196+ if self ._state % 3 == 1 :
197+ raise gexc .DeadlineExceeded ("" )
198+ if self ._state % 3 == 2 :
199+ return self .pull_response
200+
201+ mock_pubsub .pull .side_effect = FlakyPullResponse (pull_response )
202+ output = utils .read_from_pubsub (
203+ mock_pubsub , subscription_path , number_of_elements = number_of_elements )
204+ self .assertEqual ([data ] * number_of_elements , output )
205+ self ._assert_ack_ids_equal (mock_pubsub , [ack_id ] * number_of_elements )
206+
207+ def test_read_from_pubsub_many (self ):
208+ response_size = 33
209+ number_of_elements = 100
210+ mock_pubsub = mock .Mock ()
211+ subscription_path = "project/fakeproj/subscriptions/fakesub"
212+ data_list = [
213+ 'data {}' .format (i ).encode ("utf-8" ) for i in range (number_of_elements )
214+ ]
215+ attributes_list = [{
216+ 'key' : 'value {}' .format (i )
217+ } for i in range (number_of_elements )]
218+ ack_ids = ['ack_id_{}' .format (i ) for i in range (number_of_elements )]
219+ messages = [
220+ PubsubMessage (data , attributes )
221+ for data , attributes in zip (data_list , attributes_list )
222+ ]
223+ response_messages = [
224+ test_utils .PullResponseMessage (data , attributes , ack_id = ack_id )
225+ for data , attributes , ack_id in zip (data_list , attributes_list , ack_ids )
226+ ]
227+
228+ class SequentialPullResponse (object ):
229+
230+ def __init__ (self , response_messages , response_size ):
231+ self .response_messages = response_messages
232+ self .response_size = response_size
233+ self ._index = 0
234+
235+ def __call__ (self , * args , ** kwargs ):
236+ start = self ._index
237+ self ._index += self .response_size
238+ response = test_utils .create_pull_response (
239+ self .response_messages [start :start + self .response_size ])
240+ return response
241+
242+ mock_pubsub .pull .side_effect = SequentialPullResponse (
243+ response_messages , response_size )
244+ output = utils .read_from_pubsub (
245+ mock_pubsub ,
246+ subscription_path ,
247+ with_attributes = True ,
248+ number_of_elements = number_of_elements )
249+ self .assertEqual (messages , output )
250+ self ._assert_ack_ids_equal (mock_pubsub , ack_ids )
251+
252+ def test_read_from_pubsub_invalid_arg (self ):
253+ sub_client = mock .Mock ()
254+ subscription_path = "project/fakeproj/subscriptions/fakesub"
255+ with self .assertRaisesRegexp (ValueError , "number_of_elements" ):
256+ utils .read_from_pubsub (sub_client , subscription_path )
257+ with self .assertRaisesRegexp (ValueError , "number_of_elements" ):
258+ utils .read_from_pubsub (
259+ sub_client , subscription_path , with_attributes = True )
260+
261+ def _assert_ack_ids_equal (self , mock_pubsub , ack_ids ):
262+ actual_ack_ids = [
263+ ack_id for args_list in mock_pubsub .acknowledge .call_args_list
264+ for ack_id in args_list [0 ][1 ]
265+ ]
266+ self .assertEqual (actual_ack_ids , ack_ids )
267+
268+
79269if __name__ == '__main__' :
80270 logging .getLogger ().setLevel (logging .INFO )
81271 unittest .main ()
0 commit comments