-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathdask_runner.py
More file actions
253 lines (217 loc) · 9.07 KB
/
dask_runner.py
File metadata and controls
253 lines (217 loc) · 9.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""DaskRunner, executing remote jobs on Dask.distributed.
The DaskRunner is a runner implementation that executes a graph of
transformations across processes and workers via Dask distributed's
scheduler.
"""
import argparse
import collections
import dataclasses
import typing as t
from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
from apache_beam.runners.dask.transform_evaluator import DaskBagWindowedIterator
from apache_beam.runners.dask.transform_evaluator import Flatten
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
from apache_beam.transforms.sideinputs import SideInputMap
from apache_beam.utils.interactive_utils import is_in_notebook
try:
# Added to try to prevent threading related issues, see
# https://github.com/pytest-dev/pytest/issues/3216#issuecomment-1502451456
import dask.distributed as ddist
except ImportError:
ddist = {}
class DaskOptions(PipelineOptions):
    """Pipeline options controlling the Dask client connection and the
    partitioning of the `dask.Bag`s built during graph translation."""

    @staticmethod
    def _parse_timeout(candidate):
        """Coerce *candidate* to an integer number of seconds.

        Falls back to dask's own "no default" sentinel when the value is
        missing or not parseable, letting dask apply its configured default.
        """
        try:
            seconds = int(candidate)
        except (TypeError, ValueError):
            import dask
            return dask.config.no_default
        return seconds

    @staticmethod
    def _extract_bag_kwargs(dask_options: t.Dict) -> t.Dict:
        """Parse keyword arguments for `dask.Bag`s; used in graph translation."""
        bag_kwargs = {}
        # Both keys are removed from `dask_options` so they are not forwarded
        # to the distributed Client; only truthy values are kept.
        for key in ('npartitions', 'partition_size'):
            value = dask_options.pop(key, None)
            if value:
                bag_kwargs[key] = value
        return bag_kwargs

    @classmethod
    def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
        """Register the Dask-specific command-line flags on *parser*."""
        add = parser.add_argument
        add(
            '--dask_client_address',
            dest='address',
            type=str,
            default=None,
            help='Address of a dask Scheduler server. Will default to a '
            '`dask.LocalCluster()`.')
        add(
            '--dask_connection_timeout',
            dest='timeout',
            type=DaskOptions._parse_timeout,
            help='Timeout duration for initial connection to the scheduler.')
        add(
            '--dask_scheduler_file',
            dest='scheduler_file',
            type=str,
            default=None,
            help='Path to a file with scheduler information if available.')
        # TODO(alxr): Add options for security.
        add(
            '--dask_client_name',
            dest='name',
            type=str,
            default=None,
            help='Gives the client a name that will be included in logs generated '
            'on the scheduler for matters relating to this client.')
        add(
            '--dask_connection_limit',
            dest='connection_limit',
            type=int,
            default=512,
            help='The number of open comms to maintain at once in the connection '
            'pool.')
        # The two partitioning knobs are mutually exclusive by construction.
        partition_group = parser.add_mutually_exclusive_group()
        partition_group.add_argument(
            '--dask_npartitions',
            dest='npartitions',
            type=int,
            default=None,
            help='The desired number of `dask.Bag` partitions. When unspecified, '
            'an educated guess is made.')
        partition_group.add_argument(
            '--dask_partition_size',
            dest='partition_size',
            type=int,
            default=None,
            help='The length of each `dask.Bag` partition. When unspecified, '
            'an educated guess is made.')
@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
    """Handle to a pipeline submitted to a Dask cluster.

    Wraps the distributed client together with the futures returned by
    `client.compute`, and tracks the pipeline state as they resolve.
    """
    # The distributed client the pipeline was submitted through.
    client: ddist.Client
    # Futures for the computed Bag graph; awaited in `wait_until_finish`.
    futures: t.Sequence[ddist.Future]

    def __post_init__(self):
        # A freshly-constructed result represents a running pipeline.
        super().__init__(PipelineState.RUNNING)

    def wait_until_finish(self, duration=None) -> str:
        """Block until all futures complete (or *duration* elapses).

        Args:
          duration: optional timeout, in milliseconds.

        Returns:
          The terminal pipeline state (DONE, or FAILED before re-raising).
        """
        timeout = None if duration is None else duration / 1000  # ms -> seconds
        try:
            completed = ddist.as_completed(
                self.futures, timeout=timeout, with_results=True)
            # without gathering results, worker errors are not raised on the client:
            # https://distributed.dask.org/en/stable/resilience.html#user-code-failures
            # so we want to gather results to raise errors client-side, but we do
            # not actually need to use the results here, so we just pass. to gather,
            # we use the iterative `as_completed(..., with_results=True)`, instead
            # of aggregate `client.gather`, to minimize memory footprint of results.
            for _ in completed:
                pass
        except:  # pylint: disable=broad-except
            self._state = PipelineState.FAILED
            raise
        else:
            self._state = PipelineState.DONE
        return self._state

    def cancel(self) -> str:
        """Cancel all outstanding futures; returns the final state."""
        self._state = PipelineState.CANCELLING
        self.client.cancel(self.futures)
        self._state = PipelineState.CANCELLED
        return self._state

    def metrics(self):
        # TODO(alxr): Collect and return metrics...
        raise NotImplementedError('collecting metrics will come later!')
class DaskRunner(BundleBasedDirectRunner):
    """Executes a pipeline on a Dask distributed client."""
    @staticmethod
    def to_dask_bag_visitor(bag_kwargs=None) -> PipelineVisitor:
        """Build a visitor that translates each Beam transform into a
        `dask.Bag` operation, keyed by its AppliedPTransform.

        Args:
          bag_kwargs: optional dict of `dask.Bag` construction kwargs
            (e.g. npartitions / partition_size), forwarded to each op.
        """
        from dask import bag as db
        if bag_kwargs is None:
            bag_kwargs = {}

        @dataclasses.dataclass
        class DaskBagVisitor(PipelineVisitor):
            # Ordered map of visited transform -> the Bag produced for it.
            # Insertion order matters: `run_pipeline` computes only the LAST
            # entry, relying on it being connected to the full task graph.
            bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field(
                default_factory=collections.OrderedDict)

            def visit_transform(self, transform_node: AppliedPTransform) -> None:
                # Map the Beam transform type to its Dask translation; unknown
                # transforms fall back to NoOp.
                op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
                op = op_class(transform_node, bag_kwargs=bag_kwargs)
                op_kws = {"input_bag": None, "side_inputs": None}
                inputs = list(transform_node.inputs)
                if inputs:
                    bag_inputs = []
                    for input_value in inputs:
                        # A PBegin input contributes a None placeholder.
                        if isinstance(input_value, pvalue.PBegin):
                            bag_inputs.append(None)
                        # NOTE(review): there is deliberately no `continue`
                        # after the PBegin branch — presumably a PBegin's
                        # producer never appears in `self.bags`, so only the
                        # placeholder is appended for it; confirm.
                        prev_op = input_value.producer
                        if prev_op in self.bags:
                            bag_inputs.append(self.bags[prev_op])
                    # Input to `Flatten` could be of length 1, e.g. a single-element
                    # tuple: `(pcoll, ) | beam.Flatten()`. If so, we still pass it as
                    # an iterable, because `Flatten.apply` always takes an iterable.
                    if len(bag_inputs) == 1 and not isinstance(op, Flatten):
                        op_kws["input_bag"] = bag_inputs[0]
                    else:
                        op_kws["input_bag"] = bag_inputs
                side_inputs = list(transform_node.side_inputs)
                if side_inputs:
                    bag_side_inputs = []
                    for si in side_inputs:
                        # The side input's producer may not have been visited
                        # yet; `.get` yields None in that case, which is
                        # forwarded to DaskBagWindowedIterator as-is.
                        si_asbag = self.bags.get(si.pvalue.producer)
                        bag_side_inputs.append(
                            SideInputMap(
                                type(si),
                                si._view_options(),
                                DaskBagWindowedIterator(si_asbag, si._window_mapping_fn)))
                    op_kws["side_inputs"] = bag_side_inputs
                # Record this transform's Bag so downstream transforms can
                # look it up as their input.
                self.bags[transform_node] = op.apply(**op_kws)

        return DaskBagVisitor()

    @staticmethod
    def is_fnapi_compatible():
        # This runner does not use the portable Fn API execution path.
        return False

    def run_pipeline(self, pipeline, options):
        """Translate *pipeline* into a Dask Bag graph and submit it.

        Args:
          pipeline: the Beam Pipeline to execute.
          options: PipelineOptions; the DaskOptions view is forwarded to
            `distributed.Client` (minus the Bag partitioning kwargs).

        Returns:
          A DaskRunnerResult holding the client and its pending futures.

        Raises:
          NotImplementedError: when run inside a notebook.
          ImportError: when dask[distributed] is not installed.
        """
        import dask

        # TODO(alxmrs): Create interactive notebook support.
        if is_in_notebook():
            raise NotImplementedError('interactive support will come later!')

        try:
            import dask.distributed as ddist
        except ImportError:
            raise ImportError(
                'DaskRunner is not available. Please install apache_beam[dask].')

        dask_options = options.view_as(DaskOptions).get_all_options(
            drop_default=True, current_only=True)
        # Pops npartitions/partition_size out of `dask_options` so the
        # remaining options can be passed straight to the Client.
        bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)

        client = ddist.Client(**dask_options)
        pipeline.replace_all(dask_overrides())

        dask_visitor = self.to_dask_bag_visitor(bag_kwargs)
        pipeline.visit(dask_visitor)
        # The dictionary in this visitor keeps a mapping of every Beam
        # PTransform to the equivalent Bag operation. This is highly
        # redundant. Thus, we can get away with computing just the last
        # value, which should be connected to the full Bag Task Graph.
        opt_graph = dask.optimize(list(dask_visitor.bags.values())[-1])
        futures = client.compute(opt_graph)
        return DaskRunnerResult(client, futures)