From 7853fff2014ae502924ed7e0433b6854fee90181 Mon Sep 17 00:00:00 2001
From: bobrenjc93
Date: Thu, 17 Jul 2025 08:52:16 -0700
Subject: [PATCH] Unify torch.tensor and torch.ops.aten.scalar_tensor behavior
 (#158537)

Fixes #158376

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158537
Approved by: https://github.com/atalman
---
 aten/src/ATen/ScalarOps.cpp | 23 ++++++++++++++++++++++-
 test/dynamo/test_misc.py    | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp
index 693fb46e639f..da4f7a35a2f4 100644
--- a/aten/src/ATen/ScalarOps.cpp
+++ b/aten/src/ATen/ScalarOps.cpp
@@ -8,7 +8,28 @@ namespace at {
 namespace {
 template <typename scalar_t>
 inline void fill_inplace(Tensor& self, const Scalar& value_scalar) {
-  auto value = value_scalar.to<scalar_t>();
+  scalar_t value{};
+
+  if constexpr (std::is_same_v<scalar_t, at::Half> ||
+                std::is_same_v<scalar_t, at::BFloat16> ||
+                std::is_same_v<scalar_t, at::Float8_e5m2> ||
+                std::is_same_v<scalar_t, at::Float8_e4m3fn> ||
+                std::is_same_v<scalar_t, at::Float8_e5m2fnuz> ||
+                std::is_same_v<scalar_t, at::Float8_e4m3fnuz> ||
+                std::is_same_v<scalar_t, at::Float8_e8m0fnu>) {
+    // relaxed float cast: allow inf similar to the torch.tensor constructor
+    //
+    // without this, we had the following divergence:
+    // torch.tensor(1123581321.0, dtype=torch.float16)
+    // => tensor(inf, dtype=torch.float16)
+    // torch.ops.aten.scalar_tensor.default(1123581321, dtype=torch.float16)
+    // => RuntimeError: value cannot be converted to type at::Half without overflow
+
+    value = static_cast<scalar_t>(value_scalar.to<double>());
+  } else {
+    value = value_scalar.to<scalar_t>();
+  }
+
   scalar_t* dptr = static_cast<scalar_t*>(self.data_ptr());
   *dptr = value;
 }
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 50ede0b54656..f3cc410d9eed 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -12955,6 +12955,38 @@ def f(actions, n_act, epsilon=0.1):
         y = torch.tensor(5)
         f(x, y)
 
+    def test_dynamic_float_scalar_tensor_coersion(self):
+        # Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367
+        class Foo:
+            def __init__(self):
+                self.config = type(
+                    "Config", (), {"pad_val": 1123581321.0, "tolerance": 1e-6}
+                )
+
+            @torch.compile(fullgraph=True)
+            def forward(self, input):
+                outputs = torch.where(
+                    torch.abs(input - self.config.pad_val) < self.config.tolerance,
+                    torch.tensor(
+                        self.config.pad_val, dtype=input.dtype, device=input.device
+                    ),
+                    torch.tensor(
+                        self.config.pad_val + 1, dtype=input.dtype, device=input.device
+                    ),
+                )
+                return outputs
+
+        foo = Foo()
+        inputs = torch.randn(3, 4)
+        result = foo.forward(inputs)
+
+        original_pad_val = foo.config.pad_val
+        foo.config.pad_val += 1.0
+        result2 = foo.forward(inputs)
+
+        # Previously would crash with:
+        # RuntimeError: value cannot be converted to type at::Half without overflow
+
 devices = ("cuda", "hpu")
 instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)