-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathasync_dofn.py
More file actions
454 lines (399 loc) · 17 KB
/
async_dofn.py
File metadata and controls
454 lines (399 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import
import logging
import random
import uuid
from concurrent.futures import ThreadPoolExecutor
from math import floor
from threading import RLock
from time import sleep
from time import time
from types import GeneratorType
import apache_beam as beam
from apache_beam import TimeDomain
from apache_beam.coders import coders
from apache_beam.transforms.userstate import BagStateSpec
from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
from apache_beam.transforms.userstate import TimerSpec
from apache_beam.transforms.userstate import on_timer
from apache_beam.utils.shared import Shared
from apache_beam.utils.timestamp import Duration
from apache_beam.utils.timestamp import Timestamp
# A wrapper around a dofn that processes that dofn in an asynchronous manner.
class AsyncWrapper(beam.DoFn):
  """Class that wraps a dofn and converts it from one which process elements
  synchronously to one which processes them asynchronously.

  For synchronous dofns the default settings mean that many (100s) of elements
  will be processed in parallel and that processing an element will block all
  other work on that key. In addition runners are optimized for latencies less
  than a few seconds and longer operations can result in high retry rates. Async
  should be considered when the default parallelism is not correct and/or items
  are expected to take longer than a few seconds to process.
  """

  # Real-time timer used to periodically commit work that has finished
  # processing for a key.
  TIMER = TimerSpec('timer', TimeDomain.REAL_TIME)
  # Whether the commit timer has been set for a key.
  TIMER_SET = ReadModifyWriteStateSpec('timer_set', coders.BooleanCoder())
  # Runner-managed bag of (key, value) elements queued for a key. This is the
  # durable source of truth used for exactly-once output: local in-memory
  # futures are reconciled against it in commit_finished_items().
  TO_PROCESS = BagStateSpec(
      'to_process',
      coders.TupleCoder(
          [coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder()]))

  # The below items are one per dofn (not instance) so are maps of UUID to
  # value.
  _processing_elements = {}  # uuid -> {element_id: (element, Future)}
  _items_in_buffer = {}  # uuid -> count of items submitted but not finished
  _pool = {}  # uuid -> Shared handle wrapping a ThreadPoolExecutor

  # Must be reentrant lock so that callbacks can either be called by a pool
  # thread OR the thread submitting the callback which must already hold this
  # lock.
  _lock = RLock()
  _verbose_logging = False

  def __init__(
      self,
      sync_fn,
      parallelism=1,
      callback_frequency=5,
      max_items_to_buffer=None,
      timeout=1,
      max_wait_time=0.5,
      id_fn=None,
  ):
    """Wraps the sync_fn to create an asynchronous version.

    Args:
      sync_fn: The dofn to wrap. Must take (K, V) as input.
      parallelism: The maximum number of elements to process in parallel per
        worker for this dofn. By default this is set to 1 as the most common
        case for async dofns are heavy CPU or GPU dofns where the dofn requires
        a significant fraction of the CPU/GPU.
      callback_frequency: The frequency with which the runner will check for
        elements to commit. A short callback frequency will mean items are
        committed shortly after processing but cause additional work for the
        worker. A large callback frequency will result in slower commits but
        less busy work. The default of 5s will result in a maximum added
        latency of 5s while requiring relatively few resources. If your
        messages take significantly longer than 5s to process it is recommended
        to raise this.
      max_items_to_buffer: We should ideally buffer enough to always be busy but
        not so much that the worker ooms. By default will be 2x the parallelism
        which should be good for most pipelines.
      timeout: The maximum amount of time an item should try to be scheduled
        locally before it goes in the queue of waiting work.
      max_wait_time: The maximum amount of sleep time while attempting to
        schedule an item. Used in testing to ensure timeouts are met.
      id_fn: A function that returns a hashable object from an element. This
        will be used to track items instead of the element's default hash.
    """
    self._sync_fn = sync_fn
    # Unique per AsyncWrapper instance; keys the class-level shared maps.
    self._uuid = uuid.uuid4().hex
    self._parallelism = parallelism
    self._timeout = timeout
    self._max_wait_time = max_wait_time
    self._timer_frequency = callback_frequency
    # Identity function by default: the element value is its own id.
    self._id_fn = id_fn or (lambda x: x)
    if max_items_to_buffer is None:
      # 2x parallelism keeps the pool busy; floor of 10 for small parallelism.
      self._max_items_to_buffer = max(parallelism * 2, 10)
    else:
      self._max_items_to_buffer = max_items_to_buffer
    AsyncWrapper._processing_elements[self._uuid] = {}
    AsyncWrapper._items_in_buffer[self._uuid] = 0
    # NOTE(review): duplicates self._max_wait_time above; schedule_item()
    # reads this public attribute — confirm whether both are intended.
    self.max_wait_time = max_wait_time
    self._shared_handle = Shared()

  @staticmethod
  def initialize_pool(parallelism):
    """Returns a zero-arg factory for a ThreadPoolExecutor, for Shared.acquire."""
    return lambda: ThreadPoolExecutor(max_workers=parallelism)

  @staticmethod
  def reset_state():
    """Shuts down all shared pools and clears all class-level state.

    NOTE(review): the pool iteration happens outside the lock while the maps
    are cleared under it — presumably only called when no work is in flight
    (e.g. tests); confirm before relying on it concurrently.
    """
    for pool in AsyncWrapper._pool.values():
      # acquire() only runs the initializer if the pool was never created, so
      # parallelism=1 here does not resize an existing pool.
      pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown(
          wait=True, cancel_futures=True)
    with AsyncWrapper._lock:
      AsyncWrapper._pool = {}
      AsyncWrapper._processing_elements = {}
      AsyncWrapper._items_in_buffer = {}

  def setup(self):
    """Forwards to the wrapped dofn's setup method."""
    self._sync_fn.setup()
    with AsyncWrapper._lock:
      # First setup() for this uuid on the worker creates the shared pool
      # handle and (re)initializes the bookkeeping maps.
      if not self._uuid in AsyncWrapper._pool:
        AsyncWrapper._pool[self._uuid] = Shared()
        AsyncWrapper._processing_elements[self._uuid] = {}
        AsyncWrapper._items_in_buffer[self._uuid] = 0

  def teardown(self):
    """Forwards to the wrapped dofn's teardown method."""
    self._sync_fn.teardown()

  def sync_fn_process(self, element, *args, **kwargs):
    """Makes the call to the wrapped dofn's start_bundle, process
    methods. It will then combine the results into a single generator.

    Runs on a pool thread; each element gets its own
    start_bundle/process/finish_bundle cycle on the wrapped dofn.

    Args:
      element: The element to process.
      *args: Any additional arguments to pass to the wrapped dofn's process
        method. Will be the same args that the async wrapper is called with.
      **kwargs: Any additional keyword arguments to pass to the wrapped dofn's
        process method. Will be the same kwargs that the async wrapper is
        called with.

    Returns:
      A list of elements produced by the input element.
    """
    self._sync_fn.start_bundle()
    process_result = self._sync_fn.process(element, *args, **kwargs)
    bundle_result = self._sync_fn.finish_bundle()
    # both process and finish bundle may or may not return generators. We want
    # to combine whatever results have been returned into a single generator. If
    # they are single elements then wrap them in lists so that we can combine
    # them.
    if not process_result:
      process_result = []
    elif not isinstance(process_result, GeneratorType):
      process_result = [process_result]
    if not bundle_result:
      bundle_result = []
    elif not isinstance(bundle_result, GeneratorType):
      bundle_result = [bundle_result]
    to_return = []
    for x in process_result:
      to_return.append(x)
    for x in bundle_result:
      to_return.append(x)
    return to_return

  def decrement_items_in_buffer(self, future):
    """Done-callback for pool futures; frees one slot in the buffer.

    Args:
      future: The completed future (unused; required by the callback API).
    """
    with AsyncWrapper._lock:
      AsyncWrapper._items_in_buffer[self._uuid] -= 1

  def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs):
    """Schedules an item to be processed asynchronously if there is room.

    Args:
      element: The element to process.
      ignore_buffer: If true will ignore the buffer limit and schedule the item
        regardless of the buffer size. Used when an item needs to skip to the
        front such as retries.
      *args: arguments that the wrapped dofn requires.
      **kwargs: keyword arguments that the wrapped dofn requires.

    Returns:
      True if the item was scheduled (or already scheduled), False otherwise.
    """
    with AsyncWrapper._lock:
      element_id = self._id_fn(element[1])
      # Dedupe: an element already in flight is reported as scheduled so
      # callers don't submit it twice.
      if element_id in AsyncWrapper._processing_elements[self._uuid]:
        logging.info('item %s already in processing elements', element)
        return True
      if self.accepting_items() or ignore_buffer:
        result = AsyncWrapper._pool[self._uuid].acquire(
            AsyncWrapper.initialize_pool(self._parallelism)).submit(
                lambda: self.sync_fn_process(element, *args, **kwargs),
            )
        result.add_done_callback(self.decrement_items_in_buffer)
        AsyncWrapper._processing_elements[self._uuid][element_id] = (
            element, result)
        AsyncWrapper._items_in_buffer[self._uuid] += 1
        return True
      else:
        return False

  # Add an item to the processing pool. Add the future returned by that item to
  # processing_elements_.
  def schedule_item(self, element, ignore_buffer=False, *args, **kwargs):
    """Schedules an item to be processed asynchronously.

    If the queue is full will block (with exponential backoff, capped at
    max_wait_time) until room opens up or self._timeout of total sleep is
    reached; an item that never fit is left for a later timer callback to
    reschedule from runner state.

    After calling AsyncWrapper will hold a future pointing to the
    result of this processing.

    Args:
      element: The element to process.
      ignore_buffer: If true will ignore the buffer limit and schedule the item
        regardless of the buffer size. Used when an item needs to skip to the
        front such as retries.
      *args: arguments that the wrapped dofn requires.
      **kwargs: keyword arguments that the wrapped dofn requires.
    """
    done = False
    sleep_time = 0.01
    total_sleep = 0
    while not done and total_sleep < self._timeout:
      done = self.schedule_if_room(element, ignore_buffer, *args, **kwargs)
      if not done:
        # Exponential backoff capped at max_wait_time.
        sleep_time = min(self.max_wait_time, sleep_time * 2)
        if self._verbose_logging or total_sleep > 10:
          logging.info(
              'buffer is full for item %s, %s waiting %s seconds. Have waited'
              ' for %s seconds.',
              element,
              AsyncWrapper._items_in_buffer[self._uuid],
              sleep_time,
              total_sleep,
          )
        total_sleep += sleep_time
        sleep(sleep_time)

  def next_time_to_fire(self, key):
    """Returns the next round multiple of the timer frequency, plus a
    per-key deterministic jitter so distinct keys don't all fire at once."""
    random.seed(key)
    return (
        floor((time() + self._timer_frequency) / self._timer_frequency) *
        self._timer_frequency) + (
            random.random() * self._timer_frequency)

  def accepting_items(self) -> bool:
    """Returns True while the in-flight buffer is below its limit."""
    with AsyncWrapper._lock:
      return (
          AsyncWrapper._items_in_buffer[self._uuid] < self._max_items_to_buffer)

  def is_empty(self) -> bool:
    """Returns True when no items are buffered for this instance."""
    with AsyncWrapper._lock:
      return AsyncWrapper._items_in_buffer[self._uuid] == 0

  # Add the incoming element to a pool of elements to process asynchronously.
  def process(
      self,
      element,
      timer=beam.DoFn.TimerParam(TIMER),
      to_process=beam.DoFn.StateParam(TO_PROCESS),
      *args,
      **kwargs):
    """Add the elements to the list of items to be processed asynchronously.

    Performs additional bookkeeping to maintain exactly once and set timers to
    commit item after it has finished processing.

    Args:
      element: The element to process.
      timer: Callback timer that will commit elements.
      to_process: State that keeps track of queued items for exactly once.
      *args: arguments that the wrapped dofn requires.
      **kwargs: keyword arguments that the wrapped dofn requires.

    Returns:
      An empty list. The elements will be output asynchronously.
    """
    self.schedule_item(element)
    to_process.add(element)
    # Set a timer to fire on the next round increment of timer_frequency_. Note
    # we do this so that each messages timer doesn't get overwritten by the
    # next.
    time_to_fire = self.next_time_to_fire(element[0])
    timer.set(time_to_fire)
    # Don't output any elements. This will be done in commit_finished_items.
    return []

  # Synchronises local state (processing_elements_) with SE state (to_process).
  # Then outputs all finished elements. Finally, sets a timer to fire on the
  # next round increment of timer_frequency_.
  def commit_finished_items(
      self,
      to_process=beam.DoFn.StateParam(TO_PROCESS),
      timer=beam.DoFn.TimerParam(TIMER),
  ):
    """Commits finished items and synchronizes local state with runner state.

    Note timer firings are per key while local state contains messages for all
    keys. Only messages for the given key will be output/cleaned up.

    Args:
      to_process: State that keeps track of queued messages for this key.
      timer: Timer that initiated this commit and can be reset if not all items
        have finished.

    Returns:
      A list of elements that have finished processing for this key.
    """
    # For all elements that are in processing state:
    # If the element is done processing, delete it from all state and yield the
    # output.
    # If the element is not yet done, print it. If the element is not in
    # local state, schedule it for processing.
    items_finished = 0
    items_not_yet_finished = 0
    items_rescheduled = 0
    items_cancelled = 0
    items_in_processing_state = 0
    items_in_se_state = 0
    to_process_local = list(to_process.read())
    key = None
    to_reschedule = []
    if to_process_local:
      # All state entries share the key this timer fired for.
      key = str(to_process_local[0][0])
    else:
      logging.error(
          'no elements in state during timer callback. Timer should not have'
          ' been set.')
    if self._verbose_logging:
      logging.info('procesing timer for key: %s', key)
    # processing state is per key so we expect this state to only contain a
    # given key. Skip items in processing_elements which are for a different
    # key.
    with AsyncWrapper._lock:
      processing_elements = AsyncWrapper._processing_elements[self._uuid]
      to_process_local_ids = {self._id_fn(e[1]) for e in to_process_local}
      to_remove_ids = []
      # Cancel local futures for this key that are no longer backed by runner
      # state (e.g. already committed elsewhere) — exactly-once cleanup.
      for element_id, (element, future) in processing_elements.items():
        if element[0] == key and element_id not in to_process_local_ids:
          items_cancelled += 1
          future.cancel()
          to_remove_ids.append(element_id)
          logging.info(
              'cancelling item %s which is no longer in processing state',
              element)
      for element_id in to_remove_ids:
        processing_elements.pop(element_id)
      # For all elements which have finished processing output their result.
      to_return = []
      finished_items = []
      for x in to_process_local:
        items_in_se_state += 1
        x_id = self._id_fn(x[1])
        if x_id in processing_elements:
          _, future = processing_elements[x_id]
          if future.done():
            to_return.append(future.result())
            finished_items.append(x)
            processing_elements.pop(x_id)
            items_finished += 1
          else:
            items_not_yet_finished += 1
        else:
          # In runner state but unknown locally (e.g. worker restart):
          # schedule it now so it eventually finishes.
          logging.info(
              'item %s found in processing state but not local state,'
              ' scheduling now',
              x)
          to_reschedule.append(x)
          items_rescheduled += 1
    # Reschedule the items not under a lock
    # NOTE(review): uses ignore_buffer=False, so a full buffer can delay these
    # retries until a later timer firing — confirm this is intended.
    for x in to_reschedule:
      self.schedule_item(x, ignore_buffer=False)
    # Update processing state to remove elements we've finished
    to_process.clear()
    for x in to_process_local:
      if x not in finished_items:
        items_in_processing_state += 1
        to_process.add(x)
    logging.info('items finished %d', items_finished)
    logging.info('items not yet finished %d', items_not_yet_finished)
    logging.info('items rescheduled %d', items_rescheduled)
    logging.info('items cancelled %d', items_cancelled)
    logging.info('items in processing state %d', items_in_processing_state)
    logging.info(
        'items in buffer %d', AsyncWrapper._items_in_buffer[self._uuid])
    # If there are items not yet finished then set a timer to fire in the
    # future.
    # NOTE(review): _next_time_to_fire appears unused — the timer below is set
    # from next_time_to_fire(key) instead; confirm whether this is vestigial.
    self._next_time_to_fire = Timestamp.now() + Duration(seconds=5)
    if items_in_processing_state > 0:
      time_to_fire = self.next_time_to_fire(key)
      timer.set(time_to_fire)
    # Each result is a list. We want to combine them into a single
    # list of all elements we wish to output.
    merged_return = []
    for x in to_return:
      merged_return.extend(x)
    return merged_return

  @on_timer(TIMER)
  def timer_callback(
      self,
      to_process=beam.DoFn.StateParam(TO_PROCESS),
      timer=beam.DoFn.TimerParam(TIMER),
  ):
    """Helper method to commit finished items in response to timer firing.

    Args:
      to_process: State that keeps track of queued items for exactly once.
      timer: Timer that initiated this commit and can be reset if not all items
        have finished.

    Returns:
      A list of elements that have finished processing for this key.
    """
    return self.commit_finished_items(to_process, timer)