Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 379937e

Browse files
authored
[contrib] torch.compile pass for replacing layer norm with cudnn version. (#1892)
1 parent c02c6c8 commit 379937e

6 files changed

Lines changed: 601 additions & 0 deletions

File tree

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import torch
2+
3+
torch.ops.import_module("apex.contrib.torchsched.ops")

apex/contrib/torchsched/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Configurations for graph scheduler."""""
2+
3+
import sys
4+
5+
# Pre grad pass patterns
6+
pre_grad_pass_options: list[str] = ["cudnn_layer_norm"]
7+
8+
from torch.utils._config_module import install_config_module # noqa: E402
9+
10+
# adds patch, save_config, etc
11+
install_config_module(sys.modules[__name__])
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Custom PyTorch operators."""
2+
3+
import torch
4+
5+
__all__: list[str] = []
6+
7+
# Register custom operators
8+
torch.ops.import_module("apex.contrib.torchsched.ops.layer_norm")

0 commit comments

Comments
 (0)