 import queue
 import threading
 import time
+
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, Protocol, Sequence, TypeVar, Union

 import torch.multiprocessing as mp
@@ -65,7 +67,11 @@ def __call__(self, xlist: Sequence[X]) -> Sequence[T]:
        return [self.map_fn(x) for x in xlist]


-def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_event: threading.Event):
+def _sort_worker(
+    in_q: Union[queue.Queue, mp.Queue],
+    out_q: queue.Queue,
+    stop_event: threading.Event,
+):
    buffer: Dict[int, Any] = {}
    cur_idx = 0
    while not stop_event.is_set():
@@ -91,6 +97,25 @@ def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_ev
            cur_idx += 1


+def _transformation_pool(
+    pool: ThreadPoolExecutor,
+    num_workers: int,
+    in_q: queue.Queue,
+    out_q: queue.Queue,
+    map_fn: Callable[[X], T],
+    stop_event: threading.Event,
+):
+    for worker_id in range(num_workers):
+        args = (
+            worker_id,
+            in_q,
+            out_q,
+            map_fn,
+            stop_event,
+        )
+        pool.submit(_apply_udf, *args)
+
+
class _InlineMapperIter(Iterator[T]):
    """Non-Parallel implementation of Mapper"""

@@ -186,40 +211,49 @@ def __init__(
            name="read_thread(target=_populate_queue)",
            daemon=self.daemonic_reading,
        )
-        self._workers: List[Union[threading.Thread, mp.Process]] = []
-        for worker_id in range(self.num_workers):
-            args = (
-                worker_id,
+        self._read_thread.start()
+
+        if self.method == "thread":
+            self.pool = ThreadPoolExecutor(max_workers=self.num_workers)
+
+            _transformation_pool(
+                self.pool,
+                self.num_workers,
                self._in_q,
                self._intermed_q,
                self.map_fn,
-                self._stop if self.method == "thread" else self._mp_stop,
+                self._stop,
            )
-            self._workers.append(
-                threading.Thread(
-                    target=_apply_udf,
-                    args=args,
-                    daemon=True,
-                    name=f"worker_thread_{worker_id}(target=_apply_udf)",
+
+        elif self.method == "process":
+            self._workers: List[mp.Process] = []
+            for worker_id in range(self.num_workers):
+                args = (
+                    worker_id,
+                    self._in_q,
+                    self._intermed_q,
+                    self.map_fn,
+                    self._mp_stop,
                )
-                if self.method == "thread"
-                else mp_context.Process(target=_apply_udf, args=args, daemon=True)
-            )
+                self._workers.append(mp_context.Process(target=_apply_udf, args=args, daemon=True))
+            for t in self._workers:
+                t.start()

        self._out_q = self._intermed_q
        if self.in_order:
            self._sort_q: queue.Queue = queue.Queue()
            self._sort_thread = threading.Thread(
                target=_sort_worker,
-                args=(self._intermed_q, self._sort_q, self._stop),
+                args=(
+                    self._intermed_q,
+                    self._sort_q,
+                    self._stop,
+                ),
                daemon=True,
                name="sort_thread(target=_sort_worker)",
            )
            self._out_q = self._sort_q

-        self._read_thread.start()
-        for t in self._workers:
-            t.start()
        if self.in_order:
            self._sort_thread.start()

@@ -260,6 +294,7 @@ def __next__(self) -> T:
        elif isinstance(item, ExceptionWrapper):
            if not isinstance(item, StartupExceptionWrapper):
                self._sem.release()
+            self._shutdown()
            item.reraise()

        self._steps_since_snapshot += 1
@@ -286,12 +321,14 @@ def _shutdown(self):
        self._mp_stop.set()
        if hasattr(self, "_read_thread") and self._read_thread.is_alive():
            self._read_thread.join(timeout=QUEUE_TIMEOUT * 5)
-        if hasattr(self, "_sort_thread") and self._sort_thread.is_alive():
-            self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5)
+        if hasattr(self, "pool"):
+            self.pool.shutdown(wait=True)
        if hasattr(self, "_workers"):
            for t in self._workers:
                if t.is_alive():
                    t.join(timeout=QUEUE_TIMEOUT * 5)
+        if hasattr(self, "_sort_thread") and self._sort_thread.is_alive():
+            self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5)


class _ParallelMapperImpl(BaseNode[T]):
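For readers skimming the diff, here is a minimal standalone sketch (not code from this PR) of the pattern _transformation_pool sets up: a fixed number of worker tasks submitted to a ThreadPoolExecutor, each draining an input queue and writing mapped results to an output queue until a stop event is set. The worker function, the toy map_fn, and the QUEUE_TIMEOUT value below are illustrative stand-ins; the real workers run _apply_udf, which is defined elsewhere in map.py.

import queue
import threading
from concurrent.futures import ThreadPoolExecutor

QUEUE_TIMEOUT = 0.1  # assumed value; the real module defines its own constant


def worker(worker_id, in_q, out_q, map_fn, stop_event):
    # Hypothetical stand-in for _apply_udf: pull items until the stop event is set.
    while not stop_event.is_set():
        try:
            item = in_q.get(block=True, timeout=QUEUE_TIMEOUT)
        except queue.Empty:
            continue
        out_q.put(map_fn(item))


def start_pool(pool, num_workers, in_q, out_q, map_fn, stop_event):
    # Mirrors _transformation_pool: one submitted task per logical worker.
    for worker_id in range(num_workers):
        pool.submit(worker, worker_id, in_q, out_q, map_fn, stop_event)


if __name__ == "__main__":
    in_q, out_q = queue.Queue(), queue.Queue()
    stop = threading.Event()
    pool = ThreadPoolExecutor(max_workers=2)
    start_pool(pool, 2, in_q, out_q, map_fn=lambda x: x * x, stop_event=stop)
    for x in range(5):
        in_q.put(x)
    print(sorted(out_q.get() for _ in range(5)))  # [0, 1, 4, 9, 16]
    stop.set()
    pool.shutdown(wait=True)  # same shutdown path the new _shutdown branch relies on

One consequence of using a pool instead of hand-started threads, visible in the _shutdown hunk above, is that teardown for the thread case can go through pool.shutdown(wait=True) rather than joining individual worker threads.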