Commit db87e2b

Replace launching explicit threads with a ThreadPoolExecutor
Differential Revision: D72825598
Pull Request resolved: pytorch#1469
1 parent ac98992 commit db87e2b
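
The change follows the standard refactor of swapping hand-managed threading.Thread objects for a ThreadPoolExecutor. A minimal sketch of that before/after shape (the do_work function below is illustrative only, not part of torchdata):

import threading
from concurrent.futures import ThreadPoolExecutor


def do_work(worker_id: int) -> None:
    # Illustrative stand-in for the per-worker target function.
    print(f"worker {worker_id} running")


# Before: explicit threads that must be created, started, and joined by hand.
threads = [threading.Thread(target=do_work, args=(i,), daemon=True) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# After: submit the same work to a pool; shutdown(wait=True) owns the join logic.
pool = ThreadPoolExecutor(max_workers=4)
for i in range(4):
    pool.submit(do_work, i)
pool.shutdown(wait=True)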

2 files changed: 122 additions and 24 deletions

test/nodes/test_map.py (64 additions, 3 deletions)

@@ -5,9 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
-
 import unittest
 from typing import List, Optional
+from unittest import mock
 
 from parameterized import parameterized
 from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase
@@ -131,7 +131,11 @@ def test_out_of_order_process_prebatch(self):
         )
     )
     def test_save_load_state_thread(
-        self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
+        self,
+        midpoint: int,
+        in_order: bool,
+        snapshot_frequency: int,
+        prebatch: Optional[int],
     ):
         method = "thread"
         batch_size = 6
@@ -159,7 +163,11 @@ def test_save_load_state_thread(
         )
     )
     def test_save_load_state_process(
-        self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
+        self,
+        midpoint: int,
+        in_order: bool,
+        snapshot_frequency: int,
+        prebatch: Optional[int],
     ):
         method = "process"
         batch_size = 6
@@ -179,3 +187,56 @@ def test_save_load_state_process(
         )
         node = Prefetcher(node, prefetch_factor=2)
         run_test_save_load_state(self, node, midpoint)
+
+    def test_thread_pool_executor_shutdown_on_del(self):
+        """Test that the ThreadPoolExecutor is properly shut down when the iterator is deleted."""
+        # Create a ParallelMapper with method="thread"
+        src = MockSource(num_samples=10)
+        node = ParallelMapper(
+            src,
+            RandomSleepUdf(),
+            num_workers=2,
+            method="thread",
+        )
+
+        # Reset the node to create the iterator
+        node.reset()
+
+        # We need to consume some items to ensure the ThreadPoolExecutor is created
+        # and the worker threads are started
+        for _ in range(5):
+            next(node)
+
+        # Use mock.patch to intercept the ThreadPoolExecutor.shutdown method
+        with mock.patch("concurrent.futures.ThreadPoolExecutor.shutdown") as mock_shutdown:
+            # Delete the node, which should trigger the shutdown of the ThreadPoolExecutor
+            del node
+
+            # Verify that shutdown was called
+            mock_shutdown.assert_called()
+
+    def test_thread_pool_executor_shutdown_on_exception(self):
+        """Test that the ThreadPoolExecutor is properly shut down when the map function raises."""
+        # Create a ParallelMapper with method="thread"
+        src = MockSource(num_samples=10)
+        node = ParallelMapper(
+            src,
+            udf_raises,
+            num_workers=2,
+            method="thread",
+        )
+
+        # Reset the node to create the iterator
+        node.reset()
+
+        # Use mock.patch to intercept the ThreadPoolExecutor.shutdown method
+        with mock.patch("concurrent.futures.ThreadPoolExecutor.shutdown") as mock_shutdown:
+            # Consume the iterator to ensure the ThreadPoolExecutor is created
+            # and the exception is raised
+            try:
+                next(node)
+            except ValueError:
+                pass
+
+            # Verify that shutdown was called
+            mock_shutdown.assert_called()
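The two new tests hinge on patching ThreadPoolExecutor.shutdown and asserting it was invoked. A self-contained sketch of that verification pattern, independent of torchdata (PoolOwner is a hypothetical stand-in for the thread-based mapper iterator):

from concurrent.futures import ThreadPoolExecutor
from unittest import mock


class PoolOwner:
    """Hypothetical object that owns a ThreadPoolExecutor, like the thread-based mapper iterator."""

    def __init__(self, num_workers: int = 2) -> None:
        self.pool = ThreadPoolExecutor(max_workers=num_workers)

    def __del__(self) -> None:
        # Mirrors the cleanup responsibility of a _shutdown() hook: release the pool's threads.
        if hasattr(self, "pool"):
            self.pool.shutdown(wait=True)


if __name__ == "__main__":
    owner = PoolOwner()
    with mock.patch("concurrent.futures.ThreadPoolExecutor.shutdown") as mock_shutdown:
        del owner  # dropping the last reference runs __del__, which calls pool.shutdown(...)
        mock_shutdown.assert_called()
    print("shutdown was called on deletion")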

torchdata/nodes/map.py (58 additions, 21 deletions)

@@ -7,6 +7,8 @@
 import queue
 import threading
 import time
+
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, Protocol, Sequence, TypeVar, Union
 
 import torch.multiprocessing as mp
@@ -65,7 +67,11 @@ def __call__(self, xlist: Sequence[X]) -> Sequence[T]:
         return [self.map_fn(x) for x in xlist]
 
 
-def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_event: threading.Event):
+def _sort_worker(
+    in_q: Union[queue.Queue, mp.Queue],
+    out_q: queue.Queue,
+    stop_event: threading.Event,
+):
     buffer: Dict[int, Any] = {}
     cur_idx = 0
     while not stop_event.is_set():
@@ -91,6 +97,25 @@ def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_ev
             cur_idx += 1
 
 
+def _transformation_pool(
+    pool: ThreadPoolExecutor,
+    num_workers: int,
+    in_q: queue.Queue,
+    out_q: queue.Queue,
+    map_fn: Callable[[X], T],
+    stop_event: threading.Event,
+):
+    for worker_id in range(num_workers):
+        args = (
+            worker_id,
+            in_q,
+            out_q,
+            map_fn,
+            stop_event,
+        )
+        pool.submit(_apply_udf, *args)
+
+
 class _InlineMapperIter(Iterator[T]):
     """Non-Parallel implementation of Mapper"""
 
@@ -186,40 +211,49 @@ def __init__(
             name="read_thread(target=_populate_queue)",
             daemon=self.daemonic_reading,
         )
-        self._workers: List[Union[threading.Thread, mp.Process]] = []
-        for worker_id in range(self.num_workers):
-            args = (
-                worker_id,
+        self._read_thread.start()
+
+        if self.method == "thread":
+            self.pool = ThreadPoolExecutor(max_workers=self.num_workers)
+
+            _transformation_pool(
+                self.pool,
+                self.num_workers,
                 self._in_q,
                 self._intermed_q,
                 self.map_fn,
-                self._stop if self.method == "thread" else self._mp_stop,
+                self._stop,
             )
-            self._workers.append(
-                threading.Thread(
-                    target=_apply_udf,
-                    args=args,
-                    daemon=True,
-                    name=f"worker_thread_{worker_id}(target=_apply_udf)",
+
+        elif self.method == "process":
+            self._workers: List[mp.Process] = []
+            for worker_id in range(self.num_workers):
+                args = (
+                    worker_id,
+                    self._in_q,
+                    self._intermed_q,
+                    self.map_fn,
+                    self._mp_stop,
                 )
-                if self.method == "thread"
-                else mp_context.Process(target=_apply_udf, args=args, daemon=True)
-            )
+                self._workers.append(mp_context.Process(target=_apply_udf, args=args, daemon=True))
+            for t in self._workers:
+                t.start()
 
         self._out_q = self._intermed_q
         if self.in_order:
             self._sort_q: queue.Queue = queue.Queue()
             self._sort_thread = threading.Thread(
                 target=_sort_worker,
-                args=(self._intermed_q, self._sort_q, self._stop),
+                args=(
+                    self._intermed_q,
+                    self._sort_q,
+                    self._stop,
+                ),
                 daemon=True,
                 name="sort_thread(target=_sort_worker)",
            )
            self._out_q = self._sort_q
 
-        self._read_thread.start()
-        for t in self._workers:
-            t.start()
         if self.in_order:
             self._sort_thread.start()
 
@@ -260,6 +294,7 @@ def __next__(self) -> T:
         elif isinstance(item, ExceptionWrapper):
             if not isinstance(item, StartupExceptionWrapper):
                 self._sem.release()
+            self._shutdown()
             item.reraise()
 
         self._steps_since_snapshot += 1
@@ -286,12 +321,14 @@ def _shutdown(self):
         self._mp_stop.set()
         if hasattr(self, "_read_thread") and self._read_thread.is_alive():
             self._read_thread.join(timeout=QUEUE_TIMEOUT * 5)
-        if hasattr(self, "_sort_thread") and self._sort_thread.is_alive():
-            self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5)
+        if hasattr(self, "pool"):
+            self.pool.shutdown(wait=True)
         if hasattr(self, "_workers"):
             for t in self._workers:
                 if t.is_alive():
                     t.join(timeout=QUEUE_TIMEOUT * 5)
+        if hasattr(self, "_sort_thread") and self._sort_thread.is_alive():
+            self._sort_thread.join(timeout=QUEUE_TIMEOUT * 5)
 
 
 class _ParallelMapperImpl(BaseNode[T]):
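
For context, the core pattern adopted by _transformation_pool and _shutdown is: submit one long-running worker loop per slot to a ThreadPoolExecutor, signal the loops with a stop event, then shut the pool down to join its threads. A minimal, self-contained sketch of that pattern (the worker function, map function, and queues here are illustrative stand-ins for _apply_udf and the mapper's internal queues, not torchdata's actual code):

import queue
import threading
from concurrent.futures import ThreadPoolExecutor

QUEUE_TIMEOUT = 0.1


def worker(worker_id: int, in_q: queue.Queue, out_q: queue.Queue, map_fn, stop_event: threading.Event) -> None:
    # Long-running loop, analogous to what pool.submit(_apply_udf, ...) schedules.
    while not stop_event.is_set():
        try:
            x = in_q.get(block=True, timeout=QUEUE_TIMEOUT)
        except queue.Empty:
            continue
        out_q.put(map_fn(x))


if __name__ == "__main__":
    in_q: queue.Queue = queue.Queue()
    out_q: queue.Queue = queue.Queue()
    stop = threading.Event()
    num_workers = 2

    pool = ThreadPoolExecutor(max_workers=num_workers)
    for worker_id in range(num_workers):
        # One submission per worker slot, as _transformation_pool does.
        pool.submit(worker, worker_id, in_q, out_q, lambda x: x * 2, stop)

    for i in range(5):
        in_q.put(i)
    print(sorted(out_q.get() for _ in range(5)))  # [0, 2, 4, 6, 8]

    stop.set()                # ask the worker loops to exit
    pool.shutdown(wait=True)  # join the pool's threads, as _shutdown now does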
