Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e13873d

Browse files
alpha0422lirundong
andauthored
Graph Compiler For Multi-stream (#1886)
* Graph compiler for multi-stream. Co-authored-by: Wil Kong <[email protected]> Co-authored-by: Rundong Li <[email protected]> * Reuse streams across torch compile graphs. --------- Co-authored-by: Rundong Li <[email protected]>
1 parent b9d758c commit e13873d

9 files changed

Lines changed: 1637 additions & 1 deletion

File tree

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,81 @@
1+
"""Graph scheduler package."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
17
import torch
8+
import torch._inductor
9+
from torch._dynamo import list_backends
10+
from torch._dynamo import register_backend
11+
from torch._inductor.compile_fx import compile_fx_inner
12+
13+
from .backend import get_backend
14+
15+
if TYPE_CHECKING:
16+
from typing import Any
17+
from typing import Callable
218

19+
from torch._ops import OpOverload
20+
21+
__all__ = ["get_backend", "set_default_backend"]
22+
23+
# Register custom operators
324
torch.ops.import_module("apex.contrib.torchsched.ops")
25+
26+
27+
# Register torch-sched backend
28+
# Same API as torch._inductor.compile_fx
29+
@register_backend
30+
def torchsched(
31+
model_: torch.fx.GraphModule,
32+
example_inputs_: list[torch.Tensor],
33+
inner_compile: Callable[..., Any] = compile_fx_inner,
34+
config_patches: dict[str, Any] | None = None,
35+
decompositions: dict[OpOverload, Callable[..., Any]] | None = None,
36+
) -> Callable:
37+
backend = get_backend(backend="torchsched", scheme="dwb")
38+
return backend(model_, example_inputs_, inner_compile, config_patches, decompositions)
39+
40+
41+
_SUPPORTED_BACKENDS = list_backends()
42+
_DEFAULT_BACKEND = "inductor"
43+
__torch_compile__ = torch.compile
44+
45+
46+
def set_default_backend(backend: str) -> None:
47+
"""
48+
Set the default backend for torch.compile.
49+
50+
Parameters:
51+
backend (str): The backend to use as the default for torch.compile.
52+
"""
53+
global _SUPPORTED_BACKENDS, _DEFAULT_BACKEND
54+
assert backend in _SUPPORTED_BACKENDS, f"Unknown backend {backend}"
55+
_DEFAULT_BACKEND = backend
56+
57+
58+
def torchsched_compile(
59+
*args: object,
60+
backend: str | Callable | None = None,
61+
**kwargs: object,
62+
) -> object:
63+
"""
64+
Wrap around the original torch.compile to support default backend.
65+
66+
Parameters:
67+
*args (object): Positional arguments for torch.compile.
68+
backend (Union[str, Callable, None]): The backend to use.
69+
If None, the default backend is used.
70+
**kwargs (object): Additional keyword arguments for torch.compile.
71+
72+
Returns:
73+
object: Compiler or compiled model.
74+
"""
75+
if backend is None:
76+
backend = _DEFAULT_BACKEND
77+
return __torch_compile__(*args, backend=backend, **kwargs)
78+
79+
80+
# Monkey patch torch.compile to set default backend
81+
torch.compile = torchsched_compile

0 commit comments

Comments
 (0)