Commit d545ece

Merge pull request #11128: [BEAM-9524] Fix for ib.show() executing indefinitely
* Fix ib.show() spinning forever when re-executing cells without kernel restart
  Change-Id: I53aa32a75645086efffa091a53880a076c3a689d
* Add CacheKey class
  Change-Id: I1ab6e7036172d7e2d07c774778a50e165df6bdca
* Fix dep loop
  Change-Id: I247f37cd7acffb6ad796ce0fa8b54b0feff400d1
1 parent 12e6e4a commit d545ece

7 files changed

Lines changed: 182 additions & 130 deletions

sdks/python/apache_beam/runners/interactive/background_caching_job.py

Lines changed: 14 additions & 8 deletions

@@ -189,6 +189,20 @@ def is_background_caching_job_needed(user_pipeline):
           cache_changed))
 
 
+def is_cache_complete(pipeline_id):
+  # type: (str) -> bool
+
+  """Returns True if the backgrond cache for the given pipeline is done.
+  """
+  user_pipeline = ie.current_env().pipeline_id_to_pipeline(pipeline_id)
+  job = ie.current_env().get_background_caching_job(user_pipeline)
+  is_done = job and job.is_done()
+  cache_changed = is_source_to_cache_changed(
+      user_pipeline, update_cached_source_signature=False)
+
+  return is_done and not cache_changed
+
+
 def has_source_to_cache(user_pipeline):
   """Determines if a user-defined pipeline contains any source that need to be
   cached. If so, also immediately wrap current cache manager held by current
@@ -208,14 +222,6 @@ def has_source_to_cache(user_pipeline):
   if has_cache:
     if not isinstance(ie.current_env().cache_manager(),
                       streaming_cache.StreamingCache):
-      # Wrap the cache manager into a streaming cache manager. Note this
-      # does not invalidate the current cache manager.
-      def is_cache_complete():
-        job = ie.current_env().get_background_caching_job(user_pipeline)
-        is_done = job and job.is_done()
-        cache_changed = is_source_to_cache_changed(
-            user_pipeline, update_cached_source_signature=False)
-        return is_done and not cache_changed
 
       ie.current_env().set_cache_manager(
           streaming_cache.StreamingCache(
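
The hunk above hoists is_cache_complete out of has_source_to_cache: instead of a closure that permanently captures one user_pipeline object, it is now a module-level function that resolves the pipeline from an id string each time it is called, which is what the PR title is about (readers no longer spin forever after cells are re-executed without a kernel restart). A minimal usage sketch, not from the commit; cache_is_ready is an invented helper and user_pipeline is assumed to be a pipeline the interactive environment is already tracking:

    from apache_beam.runners.interactive import background_caching_job as bcj

    def cache_is_ready(user_pipeline):
      # The id string follows the str(id(...)) keying introduced in
      # interactive_environment.py further down in this commit.
      return bcj.is_cache_complete(str(id(user_pipeline)))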

sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py

Lines changed: 8 additions & 7 deletions

@@ -153,7 +153,10 @@ def __init__(
     self._coder = coder
     self._labels = labels
     self._is_cache_complete = (
-        is_cache_complete if is_cache_complete else lambda: True)
+        is_cache_complete if is_cache_complete else lambda _: True)
+
+    from apache_beam.runners.interactive.pipeline_instrument import CacheKey
+    self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id
 
   def _wait_until_file_exists(self, timeout_secs=30):
     """Blocks until the file exists for a maximum of timeout_secs.
@@ -186,7 +189,7 @@ def _emit_from_file(self, fh, tail):
         # Check if we are at EOF or if we have an incomplete line.
         if not line or (line and line[-1] != b'\n'[0]):
           # Complete reading only when the cache is complete.
-          if self._is_cache_complete():
+          if self._is_cache_complete(self._pipeline_id):
             break
 
           if not tail:
@@ -273,8 +276,7 @@ def read(self, *labels):
       return iter([]), -1
 
     reader = StreamingCacheSource(
-        self._cache_dir, labels,
-        is_cache_complete=self._is_cache_complete).read(tail=False)
+        self._cache_dir, labels, self._is_cache_complete).read(tail=False)
     header = next(reader)
     return StreamingCache.Reader([header], [reader]).read(), 1
 
@@ -286,9 +288,8 @@ def read_multiple(self, labels):
     pipeline runtime which needs to block.
     """
     readers = [
-        StreamingCacheSource(
-            self._cache_dir, l,
-            is_cache_complete=self._is_cache_complete).read(tail=True)
+        StreamingCacheSource(self._cache_dir, l,
+                             self._is_cache_complete).read(tail=True)
         for l in labels
     ]
     headers = [next(r) for r in readers]
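
Two details of the reader change are worth calling out: the source now parses the owning pipeline's id out of its last cache label with CacheKey.from_str(labels[-1]).pipeline_id, and it passes that id to the callback, so the callback contract goes from a zero-argument callable to a one-argument one (hence the new default lambda _: True). A hypothetical caller-supplied callback would therefore look like the sketch below; the name is invented and only the signature is taken from the hunks above:

    def my_is_cache_complete(pipeline_id):
      # type: (str) -> bool
      # Returning False keeps StreamingCacheSource tailing the cache file;
      # returning True lets _emit_from_file break out once it reaches EOF.
      return False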

sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py

Lines changed: 18 additions & 16 deletions

@@ -29,6 +29,7 @@
 from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
 from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
 from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
+from apache_beam.runners.interactive.pipeline_instrument import CacheKey
 from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import TestStream
@@ -56,7 +57,7 @@ def test_exists(self):
   def test_single_reader(self):
     """Tests that we expect to see all the correctly emitted TestStreamPayloads.
     """
-    CACHED_PCOLLECTION_KEY = 'arbitrary_key'
+    CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
 
     values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
               .add_element(element=0, event_time_secs=0)
@@ -109,9 +110,9 @@ def test_multiple_readers(self):
     """Tests that the service advances the clock with multiple outputs.
     """
 
-    CACHED_LETTERS = 'letters'
-    CACHED_NUMBERS = 'numbers'
-    CACHED_LATE = 'late'
+    CACHED_LETTERS = repr(CacheKey('letters', '', '', ''))
+    CACHED_NUMBERS = repr(CacheKey('numbers', '', '', ''))
+    CACHED_LATE = repr(CacheKey('late', '', '', ''))
 
     letters = (FileRecordsBuilder(CACHED_LETTERS)
                .advance_processing_time(1)
@@ -235,21 +236,22 @@ def test_read_and_write(self):
     This ensures that the sink and source speak the same language in terms of
     coders, protos, order, and units.
     """
+    CACHED_RECORDS = repr(CacheKey('records', '', '', ''))
 
     # Units here are in seconds.
     test_stream = (TestStream()
-                   .advance_watermark_to(0, tag='records')
+                   .advance_watermark_to(0, tag=CACHED_RECORDS)
                    .advance_processing_time(5)
-                   .add_elements(['a', 'b', 'c'], tag='records')
-                   .advance_watermark_to(10, tag='records')
+                   .add_elements(['a', 'b', 'c'], tag=CACHED_RECORDS)
+                   .advance_watermark_to(10, tag=CACHED_RECORDS)
                    .advance_processing_time(1)
                    .add_elements(
                        [
                            TimestampedValue('1', 15),
                            TimestampedValue('2', 15),
                            TimestampedValue('3', 15)
                        ],
-                       tag='records')) # yapf: disable
+                       tag=CACHED_RECORDS)) # yapf: disable
 
     coder = SafeFastPrimitivesCoder()
     cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)
@@ -259,9 +261,9 @@ def test_read_and_write(self):
         'passthrough_pcollection_output_ids')
     with TestPipeline(options=options) as p:
       # pylint: disable=expression-not-assigned
-      p | test_stream | cache.sink(['records'])
+      p | test_stream | cache.sink([CACHED_RECORDS])
 
-    reader, _ = cache.read('records')
+    reader, _ = cache.read(CACHED_RECORDS)
     actual_events = list(reader)
 
     # Units here are in microseconds.
@@ -271,7 +273,7 @@ def test_read_and_write(self):
                 advance_duration=5 * 10**6)),
         TestStreamPayload.Event(
             watermark_event=TestStreamPayload.Event.AdvanceWatermark(
-                new_watermark=0, tag='records')),
+                new_watermark=0, tag=CACHED_RECORDS)),
         TestStreamPayload.Event(
             element_event=TestStreamPayload.Event.AddElements(
                 elements=[
@@ -282,13 +284,13 @@ def test_read_and_write(self):
                     TestStreamPayload.TimestampedElement(
                         encoded_element=coder.encode('c'), timestamp=0),
                 ],
-                tag='records')),
+                tag=CACHED_RECORDS)),
         TestStreamPayload.Event(
             processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
                 advance_duration=1 * 10**6)),
         TestStreamPayload.Event(
             watermark_event=TestStreamPayload.Event.AdvanceWatermark(
-                new_watermark=10 * 10**6, tag='records')),
+                new_watermark=10 * 10**6, tag=CACHED_RECORDS)),
         TestStreamPayload.Event(
             element_event=TestStreamPayload.Event.AddElements(
                 elements=[
@@ -302,7 +304,7 @@ def test_read_and_write(self):
                         encoded_element=coder.encode('3'), timestamp=15 *
                         10**6),
                 ],
-                tag='records')),
+                tag=CACHED_RECORDS)),
     ]
     self.assertEqual(actual_events, expected_events)
 
@@ -312,8 +314,8 @@ def test_read_and_write_multiple_outputs(self):
     This tests the funcionatlity that the StreamingCache reads from multiple
     files and combines them into a single sorted output.
     """
-    LETTERS_TAG = 'letters'
-    NUMBERS_TAG = 'numbers'
+    LETTERS_TAG = repr(CacheKey('letters', '', '', ''))
+    NUMBERS_TAG = repr(CacheKey('numbers', '', '', ''))
 
     # Units here are in seconds.
     test_stream = (TestStream()
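
The tests switch from bare strings to repr(CacheKey(...)) tags because the source now parses its last label as a CacheKey (see the from_str call in streaming_cache.py above). Building such a tag for a hypothetical notebook variable looks like the sketch below; the empty strings fill CacheKey's remaining fields exactly as the tests do, and the meaning of those fields is defined in pipeline_instrument.py, which is not part of this page:

    from apache_beam.runners.interactive.pipeline_instrument import CacheKey

    # 'squares' is an invented variable name used only for illustration.
    CACHED_SQUARES = repr(CacheKey('squares', '', '', ''))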

sdks/python/apache_beam/runners/interactive/interactive_environment.py

Lines changed: 21 additions & 14 deletions

@@ -134,19 +134,19 @@ def __init__(self, cache_manager=None):
     self._watching_set = set()
     # Holds variables list of (Dict[str, object]).
     self._watching_dict_list = []
-    # Holds results of main jobs as Dict[Pipeline, PipelineResult].
+    # Holds results of main jobs as Dict[str, PipelineResult].
     # Each key is a pipeline instance defined by the end user. The
     # InteractiveRunner is responsible for populating this dictionary
     # implicitly.
     self._main_pipeline_results = {}
-    # Holds background caching jobs as Dict[Pipeline, BackgroundCachingJob].
+    # Holds background caching jobs as Dict[str, BackgroundCachingJob].
     # Each key is a pipeline instance defined by the end user. The
     # InteractiveRunner or its enclosing scope is responsible for populating
     # this dictionary implicitly when a background caching jobs is started.
     self._background_caching_jobs = {}
     # Holds TestStreamServiceControllers that controls gRPC servers serving
     # events as test stream of TestStreamPayload.Event.
-    # Dict[Pipeline, TestStreamServiceController]. Each key is a pipeline
+    # Dict[str, TestStreamServiceController]. Each key is a pipeline
     # instance defined by the end user. The InteractiveRunner or its enclosing
     # scope is responsible for populating this dictionary implicitly when a new
     # controller is created to start a new gRPC server. The server stays alive
@@ -301,15 +301,15 @@ def set_pipeline_result(self, pipeline, result):
     assert issubclass(type(result), runner.PipelineResult), (
         'result must be an instance of '
         'apache_beam.runners.runner.PipelineResult or its subclass')
-    self._main_pipeline_results[pipeline] = result
+    self._main_pipeline_results[str(id(pipeline))] = result
 
   def evict_pipeline_result(self, pipeline):
     """Evicts the tracking of given pipeline run. Noop if absent."""
-    return self._main_pipeline_results.pop(pipeline, None)
+    return self._main_pipeline_results.pop(str(id(pipeline)), None)
 
   def pipeline_result(self, pipeline):
     """Gets the pipeline run result. None if absent."""
-    return self._main_pipeline_results.get(pipeline, None)
+    return self._main_pipeline_results.get(str(id(pipeline)), None)
 
   def set_background_caching_job(self, pipeline, background_caching_job):
     """Sets the background caching job started from the given pipeline."""
@@ -318,32 +318,32 @@ def set_background_caching_job(self, pipeline, background_caching_job):
     from apache_beam.runners.interactive.background_caching_job import BackgroundCachingJob
     assert isinstance(background_caching_job, BackgroundCachingJob), (
         'background_caching job must be an instance of BackgroundCachingJob')
-    self._background_caching_jobs[pipeline] = background_caching_job
+    self._background_caching_jobs[str(id(pipeline))] = background_caching_job
 
   def get_background_caching_job(self, pipeline):
     """Gets the background caching job started from the given pipeline."""
-    return self._background_caching_jobs.get(pipeline, None)
+    return self._background_caching_jobs.get(str(id(pipeline)), None)
 
   def set_test_stream_service_controller(self, pipeline, controller):
     """Sets the test stream service controller that has started a gRPC server
     serving the test stream for any job started from the given user-defined
     pipeline.
     """
-    self._test_stream_service_controllers[pipeline] = controller
+    self._test_stream_service_controllers[str(id(pipeline))] = controller
 
   def get_test_stream_service_controller(self, pipeline):
     """Gets the test stream service controller that has started a gRPC server
     serving the test stream for any job started from the given user-defined
     pipeline.
     """
-    return self._test_stream_service_controllers.get(pipeline, None)
+    return self._test_stream_service_controllers.get(str(id(pipeline)), None)
 
   def evict_test_stream_service_controller(self, pipeline):
     """Evicts and pops the test stream service controller that has started a
     gRPC server serving the test stream for any job started from the given
     user-defined pipeline.
     """
-    return self._test_stream_service_controllers.pop(pipeline, None)
+    return self._test_stream_service_controllers.pop(str(id(pipeline)), None)
 
   def is_terminated(self, pipeline):
     """Queries if the most recent job (by executing the given pipeline) state
@@ -354,14 +354,14 @@ def is_terminated(self, pipeline):
     return True
 
   def set_cached_source_signature(self, pipeline, signature):
-    self._cached_source_signature[pipeline] = signature
+    self._cached_source_signature[str(id(pipeline))] = signature
 
   def get_cached_source_signature(self, pipeline):
-    return self._cached_source_signature.get(pipeline, set())
+    return self._cached_source_signature.get(str(id(pipeline)), set())
 
   def evict_cached_source_signature(self, pipeline=None):
     if pipeline:
-      self._cached_source_signature.pop(pipeline, None)
+      self._cached_source_signature.pop(str(id(pipeline)), None)
     else:
       self._cached_source_signature.clear()
 
@@ -395,6 +395,13 @@ def track_user_pipelines(self):
   def tracked_user_pipelines(self):
     return self._tracked_user_pipelines
 
+  def pipeline_id_to_pipeline(self, pid):
+    """Converts a pipeline id to a user pipeline.
+    """
+
+    pid_to_pipelines = {str(id(p)): p for p in self._tracked_user_pipelines}
+    return pid_to_pipelines[pid]
+
   def mark_pcollection_computed(self, pcolls):
     """Marks computation completeness for the given pcolls.
 
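
Keying these dictionaries by str(id(pipeline)) rather than by the Pipeline object itself, plus the new pipeline_id_to_pipeline helper, is what lets is_cache_complete(pipeline_id) in background_caching_job.py find the right job from nothing but an id string. Reduced to a Beam-free sketch (FakePipeline, set_job, get_job_by_id and the job value are all invented for illustration):

    class FakePipeline(object):
      pass

    jobs = {}

    def set_job(pipeline, job):
      # Store under a stable string key rather than the object itself.
      jobs[str(id(pipeline))] = job

    def get_job_by_id(pipeline_id):
      # Anything holding only the id string can still look the job up later.
      return jobs.get(pipeline_id, None)

    p = FakePipeline()
    set_job(p, 'background-caching-job')
    assert get_job_by_id(str(id(p))) == 'background-caching-job'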

sdks/python/apache_beam/runners/interactive/interactive_runner_test.py

Lines changed: 3 additions & 3 deletions

@@ -185,11 +185,11 @@ def test_mark_pcollection_completed_after_successful_run(self, cell):
     ib.watch(locals())
     result = p.run()
     self.assertTrue(init in ie.current_env().computed_pcollections)
-    self.assertEqual([0, 1, 2, 3, 4], list(result.get(init)))
+    self.assertEqual({0, 1, 2, 3, 4}, set(result.get(init)))
     self.assertTrue(square in ie.current_env().computed_pcollections)
-    self.assertEqual([0, 1, 4, 9, 16], list(result.get(square)))
+    self.assertEqual({0, 1, 4, 9, 16}, set(result.get(square)))
     self.assertTrue(cube in ie.current_env().computed_pcollections)
-    self.assertEqual([0, 1, 8, 27, 64], list(result.get(cube)))
+    self.assertEqual({0, 1, 8, 27, 64}, set(result.get(cube)))
 
 
 if __name__ == '__main__':
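
The switch from list to set comparison suggests the interactive result no longer guarantees element order; the assertions still fail on missing or extra elements. A tiny illustration with a hypothetical out-of-order result:

    actual = [4, 0, 3, 1, 2]  # elements arrived out of order
    assert set(actual) == {0, 1, 2, 3, 4}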
