Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit badf311

Browse files
authored
Fix GroupNorm test: properly check graph breaks (#1894)
1 parent 2eb64da commit badf311

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

apex/contrib/test/group_norm/test_group_norm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import functools
1919
import importlib
2020
import pathlib
21+
import sys
2122
import torch
2223
import unittest
2324

@@ -202,7 +203,12 @@ def test_group_norm_inductor(self):
202203
y.backward(dy)
203204

204205
from torch._dynamo.utils import counters
205-
self.assertNotIn('graph_break', counters, "Shouldn't see any graph breaks.")
206+
# TODO: Remove this when 3.9 is no longer supported
207+
if sys.version_info < (3, 10):
208+
num_graph_breaks = sum(counters["graph_break"].values())
209+
else:
210+
num_graph_breaks = counters["graph_break"].total()
211+
self.assertEqual(num_graph_breaks, 0, "Shouldn't see any graph breaks.")
206212
self.assertEqual(counters['stats']['unique_graphs'], 1, "Expect only one graph.")
207213

208214
def test_16_groups(self):

0 commit comments

Comments
 (0)