Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d0e7be1

Browse files
crcrparCopilot
andauthored
Fix contrib openfold_triton tests (#1938)
* use `torch.backends.cudnn.flags` instead Signed-off-by: Masaki Kozuki <[email protected]> * import run_tests from common_utils Signed-off-by: Masaki Kozuki <[email protected]> * use TRITON_ALLOW_NON_CONSTEXPR_GLOBALS as a workaround Signed-off-by: Masaki Kozuki <[email protected]> * Update apex/contrib/test/openfold_triton/test_fused_adam_swa.py Co-authored-by: Copilot <[email protected]> --------- Signed-off-by: Masaki Kozuki <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 184ea24 commit d0e7be1

2 files changed

Lines changed: 17 additions & 3 deletions

File tree

apex/contrib/test/openfold_triton/test_fused_adam_swa.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import os
1515
from itertools import chain
1616
import random
1717
import unittest
@@ -88,9 +88,20 @@ def setUp(self):
8888
self._seed = 19260817
8989
random.seed(self._seed)
9090
torch.manual_seed(self._seed)
91-
torch.backends.cudnn.deterministic = True
91+
# FIXME: correctly fix: """NameError("Cannot access global variable _DTYPE2TRITON from within @jit'ed function.
92+
# Triton kernels can only access global variables that are instanstiated as constexpr (`x = triton.language.constexpr(42)`).
93+
# Note that this is different from annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported.
94+
# Alternatively, set the envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not promise to support this forever.")"""
95+
os.environ["TRITON_ALLOW_NON_CONSTEXPR_GLOBALS"] = "1"
96+
97+
def tearDown(self):
98+
os.environ.pop("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", None)
9299

93100
def test_fused_update_on_random_data(self):
101+
with torch.backends.cudnn.flags(deterministic=True):
102+
self._run_fused_update_on_random_data()
103+
104+
def _run_fused_update_on_random_data(self):
94105
device = torch.device("cuda:0")
95106
compute_dtype = torch.float32
96107
state_dtype = torch.float64

apex/contrib/test/openfold_triton/test_sync_triton_auto_tune_cache_across_gpus.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
import os
2+
23
import torch
34
import torch.distributed as dist
5+
from torch.testing._internal.common_utils import run_tests
46
from torch.testing._internal.common_distributed import (
57
MultiProcessTestCase,
68
requires_nccl,
79
skip_if_lt_x_gpu,
8-
run_tests,
910
)
11+
1012
from apex.contrib.openfold_triton import (
1113
LayerNormSmallShapeOptImpl,
1214
sync_triton_auto_tune_cache_across_gpus,
1315
_tuneable_triton_kernels,
1416
)
1517

18+
1619
class SyncTritonAutoTuneCacheTest(MultiProcessTestCase):
1720
device_type = "cuda"
1821
def __init__(self, *args, **kwargs) -> None:

0 commit comments

Comments
 (0)