@@ -268,9 +268,11 @@ def _verify_export(self, fused, fused_x):
268268 # check that export() is working
269269 onnx_str = torch .onnx .export_to_pretty_string (fused , (fused_x ,),
270270 input_names = ['x_in' ],
271+ opset_version = 18 ,
271272 )
273+ print (onnx_str )
272274 assert 'x_in' in onnx_str
273- assert 'ReduceMean' in onnx_str
275+ assert 'ReduceMean' in onnx_str or 'LayerNormalization' in onnx_str
274276
275277 def test_rms_export (self ):
276278 batch_size = 16
@@ -279,7 +281,7 @@ def test_rms_export(self):
279281 normalized_shape = normalized_shape , elementwise_affine = True
280282 ).cuda ()
281283 fused_m = MixedFusedRMSNorm (
282- normalized_shape = normalized_shape , elementwise_affine = True
284+ normalized_shape = normalized_shape
283285 ).cuda ()
284286 native_x , fused_x = _prep_inputs (batch_size , normalized_shape , torch .float32 )
285287 self ._verify_export (fused , fused_x )
@@ -292,7 +294,7 @@ def test_layer_norm_export(self):
292294 normalized_shape = normalized_shape , elementwise_affine = True
293295 ).cuda ()
294296 fused_m = MixedFusedLayerNorm (
295- normalized_shape = normalized_shape , elementwise_affine = True
297+ normalized_shape = normalized_shape
296298 ).cuda ()
297299 native_x , fused_x = _prep_inputs (batch_size , normalized_shape , torch .float32 )
298300 self ._verify_export (fused , fused_x )
0 commit comments