|
| 1 | +"""Graph scheduler package.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import TYPE_CHECKING |
| 6 | + |
1 | 7 | import torch |
| 8 | +import torch._inductor |
| 9 | +from torch._dynamo import list_backends |
| 10 | +from torch._dynamo import register_backend |
| 11 | +from torch._inductor.compile_fx import compile_fx_inner |
| 12 | + |
| 13 | +from .backend import get_backend |
| 14 | + |
| 15 | +if TYPE_CHECKING: |
| 16 | + from typing import Any |
| 17 | + from typing import Callable |
2 | 18 |
|
| 19 | + from torch._ops import OpOverload |
| 20 | + |
| 21 | +__all__ = ["get_backend", "set_default_backend"] |
| 22 | + |
| 23 | +# Register custom operators |
3 | 24 | torch.ops.import_module("apex.contrib.torchsched.ops") |
| 25 | + |
| 26 | + |
| 27 | +# Register torch-sched backend |
| 28 | +# Same API as torch._inductor.compile_fx |
| 29 | +@register_backend |
| 30 | +def torchsched( |
| 31 | + model_: torch.fx.GraphModule, |
| 32 | + example_inputs_: list[torch.Tensor], |
| 33 | + inner_compile: Callable[..., Any] = compile_fx_inner, |
| 34 | + config_patches: dict[str, Any] | None = None, |
| 35 | + decompositions: dict[OpOverload, Callable[..., Any]] | None = None, |
| 36 | +) -> Callable: |
| 37 | + backend = get_backend(backend="torchsched", scheme="dwb") |
| 38 | + return backend(model_, example_inputs_, inner_compile, config_patches, decompositions) |
| 39 | + |
| 40 | + |
| 41 | +_SUPPORTED_BACKENDS = list_backends() |
| 42 | +_DEFAULT_BACKEND = "inductor" |
| 43 | +__torch_compile__ = torch.compile |
| 44 | + |
| 45 | + |
| 46 | +def set_default_backend(backend: str) -> None: |
| 47 | + """ |
| 48 | + Set the default backend for torch.compile. |
| 49 | +
|
| 50 | + Parameters: |
| 51 | + backend (str): The backend to use as the default for torch.compile. |
| 52 | + """ |
| 53 | + global _SUPPORTED_BACKENDS, _DEFAULT_BACKEND |
| 54 | + assert backend in _SUPPORTED_BACKENDS, f"Unknown backend {backend}" |
| 55 | + _DEFAULT_BACKEND = backend |
| 56 | + |
| 57 | + |
| 58 | +def torchsched_compile( |
| 59 | + *args: object, |
| 60 | + backend: str | Callable | None = None, |
| 61 | + **kwargs: object, |
| 62 | +) -> object: |
| 63 | + """ |
| 64 | + Wrap around the original torch.compile to support default backend. |
| 65 | +
|
| 66 | + Parameters: |
| 67 | + *args (object): Positional arguments for torch.compile. |
| 68 | + backend (Union[str, Callable, None]): The backend to use. |
| 69 | + If None, the default backend is used. |
| 70 | + **kwargs (object): Additional keyword arguments for torch.compile. |
| 71 | +
|
| 72 | + Returns: |
| 73 | + object: Compiler or compiled model. |
| 74 | + """ |
| 75 | + if backend is None: |
| 76 | + backend = _DEFAULT_BACKEND |
| 77 | + return __torch_compile__(*args, backend=backend, **kwargs) |
| 78 | + |
| 79 | + |
| 80 | +# Monkey patch torch.compile to set default backend |
| 81 | +torch.compile = torchsched_compile |
0 commit comments