Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 4bdecd0

Browse files
lirundongalpha0422
andauthored
[contrib/torchsched] Refactor cross-stream dependency handling (#1922)
* [contrib/torchsched] Refactor cross-stream dependency handling Co-authored-by: Wil Kong <[email protected]> Co-authored-by: Rundong Li <[email protected]> * [contrib/torchsched] Fix symbolic event ordering and config Co-authored-by: Wil Kong <[email protected]> --------- Co-authored-by: Wil Kong <[email protected]>
1 parent 446a6c1 commit 4bdecd0

10 files changed

Lines changed: 351 additions & 220 deletions

File tree

apex/contrib/torchsched/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from .backend import get_backend
1414

1515
if TYPE_CHECKING:
16+
from collections.abc import Callable
1617
from typing import Any
17-
from typing import Callable
1818

1919
from torch._ops import OpOverload
2020

@@ -50,7 +50,7 @@ def set_default_backend(backend: str) -> None:
5050
Parameters:
5151
backend (str): The backend to use as the default for torch.compile.
5252
"""
53-
global _SUPPORTED_BACKENDS, _DEFAULT_BACKEND
53+
global _DEFAULT_BACKEND
5454
assert backend in _SUPPORTED_BACKENDS, f"Unknown backend {backend}"
5555
_DEFAULT_BACKEND = backend
5656

apex/contrib/torchsched/backend.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,23 @@
55
import functools
66
from copy import copy
77
from typing import TYPE_CHECKING
8-
from typing import Callable
98
from typing import ParamSpec
109
from typing import TypeVar
1110

1211
if TYPE_CHECKING:
12+
from collections.abc import Callable
1313
from types import NotImplementedType
1414

1515
import torch
1616
from torch import Tensor
1717
from torch import _TorchCompileInductorWrapper
1818
from torch._dynamo import lookup_backend
19-
from torch._inductor.codegen.common import register_backend_for_device
20-
from torch._inductor.codegen.cuda_combined_scheduling import CUDACombinedScheduling
21-
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
2219
from torch._inductor.compile_fx import compile_fx
2320
from torch._inductor.compile_fx import compile_fx_inner
2421
from torch._inductor.decomposition import select_decomp_table
2522

2623
import apex.contrib.torchsched.config as config
2724
from apex.contrib.torchsched.inductor import patch_graph_lowering
28-
from apex.contrib.torchsched.inductor.wrapper import MultiStreamWrapperCodegen
2925
from apex.contrib.torchsched.passes import pre_grad_custom_pass
3026

3127
aten = torch.ops.aten
@@ -43,21 +39,9 @@ def enable_multi_stream_scheduling(compile_fn: Callable[P, R]) -> Callable[P, R]
4339

4440
@functools.wraps(compile_fn)
4541
def _compile_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
46-
register_backend_for_device("cuda", CUDACombinedScheduling, MultiStreamWrapperCodegen)
4742
patch_graph_lowering(patch=True)
48-
49-
# torch.compile explicitly calls `write_get_raw_stream` via wrapper's class method in its
50-
# lowering process to walk around the wrapper-stream LRU cache mechanism. To be compatible
51-
# with this, we got to patch wrapper's class method as well.
52-
_origin_write_get_raw_stream = PythonWrapperCodegen.write_get_raw_stream
53-
PythonWrapperCodegen.write_get_raw_stream = MultiStreamWrapperCodegen._write_get_raw_stream
54-
5543
compile_results = compile_fn(*args, **kwargs)
56-
57-
register_backend_for_device("cuda", CUDACombinedScheduling, PythonWrapperCodegen)
5844
patch_graph_lowering(patch=False)
59-
PythonWrapperCodegen.write_get_raw_stream = _origin_write_get_raw_stream
60-
6145
return compile_results
6246

6347
return _compile_wrapper
@@ -302,11 +286,49 @@ def get_backend(
302286
if backend == "torch":
303287
return lookup_backend("inductor")
304288

289+
# [NOTE] Disable buffer reuse and inplace buffers to avoid inter-stream conflicts.
290+
#
291+
# In PyTorch Inductor, the safety of buffer reuse and in-place buffer update is ensured by the
292+
# program's single-stream, serial execution. That is, if op2 is launched only after op1 has
293+
# completed execution, then these cases are safe:
294+
#
295+
# Case 1: Safe to reuse buffer `workspace1` as `op2`'s workspace.
296+
#
297+
# op1 -> op2 op1 -> op2
298+
# ↕ ↕ ⇒ ↕ ↑
299+
# workspace1 workspace2 workspace1 ←----┘
300+
#
301+
# Case 2: Safe to inpalace `op1`'s output to `buf1` then send to `op2` as input.
302+
#
303+
# buf1 -> op1 -> buf2 -> op2 ⇒ buf1 ↔ op1
304+
# └-------> op2
305+
#
306+
# However, if operators are dispatched to distinct CUDA Streams and execute in parallel, above
307+
# cases are not safe any more:
308+
#
309+
# Counter example 1: Case 1 is not safe if op1 and op2 are in parallel.
310+
#
311+
# op1
312+
# ↕
313+
# workspace1 (Buffer modified concurrently by op1 and op2.)
314+
# ↕
315+
# op2
316+
#
317+
# Counter example 2: Case 2 is not safe if op1 and op2 are in parallel.
318+
#
319+
# buf1 <--> op1
320+
# └------> op2 (Op2 could read op1's input data.)
321+
#
322+
# Thus currently we disable both buffer reuse and inplace buffer update to ensure multi-stream
323+
# correctness.
324+
#
325+
# TODO(@davidli): Add cross-stream dependency to Inductor scheduling's dependency system so we
326+
# can safely reuse and inplace update buffers even in multi-stream scenario.
327+
305328
if scheme == "dwb":
306329
return DecompositionsWrapper(
307330
mode="default",
308-
# TODO(@davidli): Elegantly solve cross-stream buffer reusing conflicts.
309-
options={"allow_buffer_reuse": False},
331+
options={"allow_buffer_reuse": False, "inplace_buffers": False},
310332
dynamic=False,
311333
decompositions={
312334
aten.convolution_backward.default: convolution_backward_decomp_dwb,
@@ -315,7 +337,7 @@ def get_backend(
315337
elif scheme == "wbd":
316338
return DecompositionsWrapper(
317339
mode="default",
318-
options={"allow_buffer_reuse": False},
340+
options={"allow_buffer_reuse": False, "inplace_buffers": False},
319341
dynamic=False,
320342
decompositions={
321343
aten.convolution_backward.default: convolution_backward_decomp_wbd,

apex/contrib/torchsched/config.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Configurations for graph scheduler."""
22

3+
import functools
34
import os
5+
import re
46
import sys
57

68
# Debug info and dump grpahs
@@ -17,6 +19,56 @@
1719
# scheduled to other streams in a round-robin way.
1820
num_streams = int(os.getenv("TORCH_SCHED_NUM_STREAMS", "8"))
1921

22+
23+
def _get_skip_post_grad_graph_ids() -> set[int]:
24+
if ids := os.environ.get("TORCH_SCHED_SKIP_GRAPH_IDS"):
25+
result: set[int] = set()
26+
for part in ids.split(","):
27+
if "-" in part:
28+
start, end = map(int, part.split("-"))
29+
result.update(range(start, end + 1))
30+
else:
31+
result.add(int(part))
32+
return result
33+
else:
34+
return set()
35+
36+
37+
# IDs of post AOT-autograd graphs that should be skipped for multi-stream scheduling. Can be
38+
# specified via TORCH_SCHED_SKIP_GRAPH_IDS environment variable in a SLURM-like scheme, e.g.,
39+
# TORCH_SCHED_SKIP_GRAPH_IDS=1,2,3-5,7-10
40+
skip_post_grad_graph_ids: set[int] = _get_skip_post_grad_graph_ids()
41+
42+
# Reduce the number of allocated CUDA Events in the generated program by:
43+
# 1. Track reference count of each CUDA Event in the scheduling phase. Skip generating CUDA Events
44+
# that have no reference counts, i.e., have not been waited by other streams;
45+
# 2. Reuse allocated CUDA Events when feasible.
46+
# This option is enable by default.
47+
reuse_cuda_event: bool = os.getenv("TORCH_SCHED_REUSE_CUDA_EVENT", "1") == "1"
48+
49+
50+
@functools.lru_cache
51+
def __get_dump_code_backends_and_dir(dump_code: str | None) -> tuple[list[str], str | None]:
52+
pattern = r"(?:\+(?P<backend>\w+),)?(?P<dir>[\w\/\.\-\s@#~]+)"
53+
backends, dir = ["torchsched"], None
54+
if dump_code and (match := re.match(pattern, dump_code)):
55+
if backend := match.group("backend"):
56+
backends.append(backend)
57+
dir = os.path.abspath(match.group("dir"))
58+
return backends, dir
59+
60+
61+
# Specify dump code backend types and output directory by::
62+
#
63+
# TORCH_SCHED_DUMP_CODE='+inductor,/dir/to/save/code'
64+
#
65+
# Where `+inductor` enables dump both Inductor and torchsched code. If omitted, only dump
66+
# torchsched code. `/dir/to/save/code` specifies a directory to dump code to.
67+
(
68+
dump_code_backends,
69+
dump_code_dir,
70+
) = __get_dump_code_backends_and_dir(os.getenv("TORCH_SCHED_DUMP_CODE"))
71+
2072
from torch.utils._config_module import install_config_module # noqa: E402
2173

2274
# adds patch, save_config, etc

apex/contrib/torchsched/inductor/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
STREAM_NAME_TEMPLATE: str = "stream{stream_idx:d}"
2727

2828

29-
@functools.lru_cache()
29+
@functools.lru_cache
3030
def get_stream_name(stream_idx: int) -> str:
3131
"""Generate CUDA Stream name from stream index number.
3232

apex/contrib/torchsched/inductor/event.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from torch._inductor.codegen.wrapper import IndentedBuffer
1919
from torch._inductor.codegen.wrapper import WrapperLine
2020

21+
import apex.contrib.torchsched.config as torchsched_config
22+
from apex.contrib.torchsched.inductor._utils import DEFAULT_STREAM_IDX
2123
from apex.contrib.torchsched.inductor._utils import ENTRANCE_EVENT
2224
from apex.contrib.torchsched.inductor._utils import EVENT_NAME_TEMPLATE
2325
from apex.contrib.torchsched.inductor._utils import get_stream_name
@@ -31,6 +33,7 @@ class CudaEventSym:
3133
Args:
3234
factory: The CUDAEventFactory that generate this event.
3335
idx: Indexing number assigned in chronological order during scheduling.
36+
originate_stream_idx: The index of the CUDA stream that this event originated from.
3437
ref_count: Reference count of this event instance.
3538
materialized_event: The actual CUDA Event name that will be used in the final PyTorch
3639
program. Only symbolic event with reference count larger than one will be materialized.
@@ -42,22 +45,30 @@ class CudaEventSym:
4245

4346
factory: CudaEventFactory
4447
idx: int
48+
originate_stream_idx: int
4549
ref_count: int = 0
4650
materialized_event: str | None = None
4751

4852
def __lt__(self, rhs: CudaEventSym) -> bool:
4953
"""Whether the current event is generated before the rhs event."""
50-
return self.idx < rhs.idx and self.factory is rhs.factory
54+
if self.factory is not rhs.factory:
55+
return NotImplemented
56+
return (self.idx, self.originate_stream_idx) < (rhs.idx, rhs.originate_stream_idx)
5157

5258
def __eq__(self, rhs: object) -> bool:
5359
"""Whether the current event is identical to the rhs event."""
5460
if not isinstance(rhs, CudaEventSym):
5561
return NotImplemented
56-
return self.idx == rhs.idx and self.factory is rhs.factory
62+
return (
63+
self.idx == rhs.idx
64+
and self.originate_stream_idx == rhs.originate_stream_idx
65+
and self.factory is rhs.factory
66+
)
5767

5868
def __str__(self) -> str:
5969
"""Represent this symbolic event in string."""
6070
ret = f"{self.__class__.__name__} (idx={self.idx}"
71+
ret += f", originate_stream_idx={self.originate_stream_idx}"
6172
if self.ref_count:
6273
ret += f", ref_count={self.ref_count}"
6374
if self.materialized_event:
@@ -67,7 +78,7 @@ def __str__(self) -> str:
6778

6879
def __hash__(self) -> int:
6980
"""Hash this symbolic event."""
70-
return hash(f"{id(self.factory)=},{self.idx=}")
81+
return hash((id(self.factory), self.idx, self.originate_stream_idx))
7182

7283
def record(self, stream_idx: int) -> _CudaEventRecordLine:
7384
"""Record this event on a given stream.
@@ -103,6 +114,7 @@ def wait(self, stream_idx: int) -> _CudaEventWaitLine:
103114
the reference count of this event. If an event object has called this method, it is
104115
guaranteed to be generated in the final program.
105116
"""
117+
assert stream_idx != self.originate_stream_idx
106118
self.ref_count += 1
107119
stream = get_stream_name(stream_idx)
108120
return _CudaEventWaitLine(self, stream)
@@ -113,11 +125,12 @@ class _CudaEventRecordLine(WrapperLine):
113125

114126
event: CudaEventSym
115127
stream: str
128+
_reuse_cuda_event: bool = torchsched_config.reuse_cuda_event
116129

117130
def codegen(self, code: IndentedBuffer) -> None:
118131
assert 0 <= self.event.ref_count
119132
assert self.event.materialized_event is None
120-
if self.event.ref_count:
133+
if self.event.ref_count or not self._reuse_cuda_event:
121134
self.event.materialized_event = self.event.factory.get_materialized_event(code)
122135
code.writeline(f"{self.event.materialized_event}.record({self.stream})")
123136

@@ -131,12 +144,13 @@ class _CudaEventWaitLine(WrapperLine):
131144
def codegen(self, code: IndentedBuffer) -> None:
132145
assert 0 < self.event.ref_count
133146
assert self.event.materialized_event is not None
134-
code.writeline(f"{self.event.materialized_event}.wait({self.stream})")
147+
code_line = f"{self.event.materialized_event}.wait({self.stream})"
135148
self.event.ref_count -= 1
136149
if self.event.ref_count == 0:
137150
self.event.factory.deposit_materialized_event(self.event.materialized_event)
138151
self.event.materialized_event = None
139-
code.writeline(f"# End lifecycle of {self.event}")
152+
code_line += f" # End lifecycle of {self.event}"
153+
code.writeline(code_line)
140154

141155

142156
class CudaEventFactory:
@@ -153,23 +167,32 @@ def __init__(self) -> None:
153167
self.materialized_event_idx: itertools.count = itertools.count(start=1)
154168
self.available_materialized_events: set[str] = set()
155169
self._entrance_event: CudaEventSym | None = None
170+
self._reuse_cuda_event: bool = torchsched_config.reuse_cuda_event
156171

157172
def get_entrance_event(self) -> CudaEventSym:
158173
"""Return the cuda event that corresponding to compute graph entering."""
159174
if self._entrance_event is None:
160-
self._entrance_event = CudaEventSym(factory=self, idx=0)
175+
self._entrance_event = CudaEventSym(
176+
factory=self,
177+
idx=0,
178+
originate_stream_idx=DEFAULT_STREAM_IDX,
179+
)
161180
# Code-gen for entrance event is almost hard-coded in device guard enter so the
162181
# materialization is slightly different here.
163182
self._entrance_event.materialized_event = ENTRANCE_EVENT
164183
return self._entrance_event
165184

166-
def get_sym_event(self) -> CudaEventSym:
185+
def get_sym_event(self, originate_stream_idx: int) -> CudaEventSym:
167186
"""Allocate a symbolic cuda event."""
168-
return CudaEventSym(factory=self, idx=next(self.symbolic_event_idx))
187+
return CudaEventSym(
188+
factory=self,
189+
idx=next(self.symbolic_event_idx),
190+
originate_stream_idx=originate_stream_idx,
191+
)
169192

170193
def get_materialized_event(self, code: IndentedBuffer) -> str:
171194
"""Allocate or reuse a materialized cuda event."""
172-
if self.available_materialized_events:
195+
if self._reuse_cuda_event and self.available_materialized_events:
173196
return self.available_materialized_events.pop()
174197
else:
175198
event = EVENT_NAME_TEMPLATE.format(event_idx=next(self.materialized_event_idx))

0 commit comments

Comments
 (0)