|
11 | 11 | # TODO: temporarily disabling warnings to mute a pydantic warning from liteLLM
|
12 | 12 | import warnings
|
13 | 13 | from asyncio import AbstractEventLoop
|
| 14 | +from concurrent.futures import ThreadPoolExecutor |
14 | 15 | from functools import partial, reduce
|
15 | 16 | from itertools import count
|
16 | 17 | from os import getenv
|
@@ -986,11 +987,18 @@ def loop_body(iidx, items):
|
986 | 987 | map_loc,
|
987 | 988 | )
|
988 | 989 |
|
989 |
| - # with ThreadPoolExecutor(max_workers=4) as executor: |
990 |
| - # map_output = executor.map( |
991 |
| - map_output = map( # pylint: disable=bad-builtin |
992 |
| - loop_body, index_iterator, items_iterator |
993 |
| - ) |
| 990 | + map_output: Iterable[ |
| 991 | + Tuple[PdlLazy[Any], LazyMessages, ScopeType, BlockType] |
| 992 | + ] |
| 993 | + if block.maxWorkers == 0: |
| 994 | + map_output = map( # pylint: disable=bad-builtin |
| 995 | + loop_body, index_iterator, items_iterator |
| 996 | + ) |
| 997 | + else: |
| 998 | + with ThreadPoolExecutor(block.maxWorkers) as executor: |
| 999 | + map_output = executor.map( |
| 1000 | + loop_body, index_iterator, items_iterator |
| 1001 | + ) |
994 | 1002 | results, _, _, traces = _split_map_output(map_output)
|
995 | 1003 | # saved_background = IndependentContext(backgrounds)
|
996 | 1004 | except PDLRuntimeError as exc:
|
|
0 commit comments