[Inductor] Dynamo hangs when processing an operator, seemingly depending on a logical argument value #151743


Closed
alexsamardzic opened this issue Apr 19, 2025 · 2 comments
Labels
module: dynamo, oncall: pt2, triaged

Comments

@alexsamardzic
Collaborator

alexsamardzic commented Apr 19, 2025

πŸ› Describe the bug

Here is a reproducer:

import torch

device = "cuda"
group_size = 4
M, N, K = 16, 32, 64
dtype_AB = torch.float8_e4m3fn
dtype_scale = torch.float32
dtype_offset = torch.int32
dtype_C = torch.bfloat16

# group_size groups are stacked along the K dimension; offs holds each
# group's cumulative end offset along K ([K, 2K, 3K, 4K]).
A = torch.ones(M, K * group_size, device=device).to(dtype_AB)
B = torch.ones(N, K * group_size, device=device).to(dtype_AB)
A_scale = torch.ones(group_size * M, device=device, dtype=dtype_scale)
B_scale = torch.ones(group_size * N, device=device, dtype=dtype_scale)
offs = torch.arange(K, group_size * K + 1, K, device=device, dtype=dtype_offset)

# Compare the eager operator against its compiled counterpart; allow_in_graph
# makes Dynamo put the op into the graph as-is instead of tracing into it.
f_ref = torch._scaled_grouped_mm
f = torch.compile(f_ref)
torch.compiler.allow_in_graph(f_ref)

for use_fast_accum in [False, True]:
    print("use_fast_accum =", use_fast_accum)
    C_ref = f_ref(
        A,
        B.transpose(-2, -1),
        A_scale,
        B_scale,
        offs,
        out_dtype=dtype_C,
        use_fast_accum=use_fast_accum,
    )
    C = f(
        A,
        B.transpose(-2, -1),
        A_scale,
        B_scale,
        offs,
        out_dtype=dtype_C,
        use_fast_accum=use_fast_accum,
    )
    assert torch.allclose(C, C_ref, atol=1e-3, rtol=1e-3)
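
To make the layout concrete, here is a plain-float sketch of what the grouped GEMM computes, under my reading of this 2d-2d case (an assumption, not authoritative): offs partitions the shared K dimension, and each group contributes one (M, N) slice of the output.

import torch

# Same shapes as the reproducer, but plain float32 on CPU, so it runs anywhere.
group_size, M, N, K = 4, 16, 32, 64
A = torch.randn(M, K * group_size)
B = torch.randn(N, K * group_size)
offs = torch.arange(K, group_size * K + 1, K)  # [64, 128, 192, 256]

start, outs = 0, []
for end in offs.tolist():
    # One group's GEMM over its K-slice; scales are omitted (all-ones above).
    outs.append(A[:, start:end] @ B[:, start:end].T)
    start = end
out = torch.stack(outs)  # (group_size, M, N)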

The first iteration of the loop, with the use_fast_accum argument of the _scaled_grouped_mm operator set to False, goes fine; but in the second iteration, with the argument set to True, the compilation hangs. If a breakpoint is set here, and one then steps over and returns from this function, it seems that the hang happens at this place.

(Note: the _scaled_grouped_mm operator works on Hopper only.)
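
For triage, here is a minimal sketch of generic debugging aids that can be prepended to the reproducer (standard tooling only; the 60-second timeout is an arbitrary choice of mine):

import faulthandler
import logging
import torch

# _scaled_grouped_mm is Hopper-only (see the note above), so skip elsewhere
# rather than failing for an unrelated reason.
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0):
    raise SystemExit("skipped: needs an SM 9.0 (Hopper) GPU")

# Dump Python-level stacks of all threads if the process is still stuck
# after 60 seconds, showing where in Dynamo the compilation is sitting.
faulthandler.dump_traceback_later(60, exit=True)

# Verbose Dynamo logging (equivalent to TORCH_LOGS="+dynamo"); the last
# line printed before the hang narrows down the stalled phase.
torch._logging.set_logs(dynamo=logging.DEBUG)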

Background: Initial support for auto-tuning of this operator was added through #150421, and I encountered the issue while working on extending it through #150944. However, the problem is not related to auto-tuning: it can be reproduced at c3bc6b3, which predates #150421.

Error logs

Here is a backtrace from gdb, taken after the reproducer had been hanging for some time. Apparently, it hangs in a cudaStreamSynchronize() call; see the localization sketch after the backtrace.

Gdb backtrace
#0  0x00007f95417203bf in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#1  0x00007f95413d368c in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007f954149699a in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007f95416f0029 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f954153d89d in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#5  0x00007f95b12143a5 in ?? () from /scratch/pytorch-dev/lib/libcudart.so.12
#6  0x00007f95b12757d8 in cudaStreamSynchronize () from /scratch/pytorch-dev/lib/libcudart.so.12
#7  0x00007f959a673f3c in at::native::_local_scalar_dense_cuda(at::Tensor const&)::{lambda()#1}::operator()() const [clone .isra.0] ()
   from /scratch/pytorch/torch/lib/libtorch_cuda.so
#8  0x00007f959a675995 in at::native::_local_scalar_dense_cuda(at::Tensor const&) () from /scratch/pytorch/torch/lib/libtorch_cuda.so
#9  0x00007f959c298788 in at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___local_scalar_dense(at::Tensor const&) ()
   from /scratch/pytorch/torch/lib/libtorch_cuda.so
#10 0x00007f959c298810 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<c10::Scalar (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___local_scalar_dense>, c10::Scalar, c10::guts::typelist::typelist<at::Tensor const&> >, c10::Scalar (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) () from /scratch/pytorch/torch/lib/libtorch_cuda.so
#11 0x00007f95a5b5d93a in at::_ops::_local_scalar_dense::call(at::Tensor const&) () from /scratch/pytorch/torch/lib/libtorch_cpu.so
#12 0x00007f95a512eff3 in at::native::item(at::Tensor const&) () from /scratch/pytorch/torch/lib/libtorch_cpu.so
#13 0x00007f95a624cb31 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<c10::Scalar (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__item>, c10::Scalar, c10::guts::typelist::typelist<at::Tensor const&> >, c10::Scalar (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) () from /scratch/pytorch/torch/lib/libtorch_cpu.so
#14 0x00007f95a599133a in at::_ops::item::call(at::Tensor const&) () from /scratch/pytorch/torch/lib/libtorch_cpu.so
#15 0x00007f95a6808057 in unsigned char at::Tensor::item<unsigned char>() const () from /scratch/pytorch/torch/lib/libtorch_cpu.so
#16 0x00007f95a51d2899 in at::native::allclose(at::Tensor const&, at::Tensor const&, double, double, bool) () from /scratch/pytorch/torch/lib/libtorch_cpu.so
#17 0x00007f95a79742df in torch::autograd::VariableType::(anonymous namespace)::allclose(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, double, double, bool) ()
   from /scratch/pytorch/torch/lib/libtorch_cpu.so
#18 0x00007f95a577cceb in at::_ops::allclose::call(at::Tensor const&, at::Tensor const&, double, double, bool) () from /scratch/pytorch/torch/lib/libtorch_cpu.so
#19 0x00007f95b0634f2d in torch::autograd::THPVariable_allclose(_object*, _object*, _object*) () from /scratch/pytorch/torch/lib/libtorch_python.so
#20 0x000055bf35f0e4b6 in cfunction_call (func=<built-in method allclose of type object at remote 0x7f95b1187fe0>, args=<optimized out>, kwargs=<optimized out>)
    at /usr/local/src/conda/python-3.9.22/Objects/methodobject.c:543
#21 0x000055bf35ef6d4c in _PyObject_MakeTpCall (tstate=0x55bf36327ca0, callable=callable@entry=<built-in method allclose of type object at remote 0x7f95b1187fe0>, 
    args=<optimized out>, nargs=<optimized out>, keywords=keywords@entry=('atol', 'rtol')) at /usr/local/src/conda/python-3.9.22/Objects/call.c:191
#22 0x000055bf35ef3488 in _PyObject_VectorcallTstate (kwnames=('atol', 'rtol'), nargsf=<optimized out>, args=<optimized out>, 
    callable=<built-in method allclose of type object at remote 0x7f95b1187fe0>, tstate=<optimized out>) at /usr/local/src/conda/python-3.9.22/Include/cpython/abstract.h:116
#23 _PyObject_VectorcallTstate (kwnames=('atol', 'rtol'), nargsf=<optimized out>, args=<optimized out>, 
    callable=<built-in method allclose of type object at remote 0x7f95b1187fe0>, tstate=<optimized out>) at /usr/local/src/conda/python-3.9.22/Include/cpython/abstract.h:103
#24 PyObject_Vectorcall (kwnames=('atol', 'rtol'), nargsf=<optimized out>, args=<optimized out>, callable=<built-in method allclose of type object at remote 0x7f95b1187fe0>)
    at /usr/local/src/conda/python-3.9.22/Include/cpython/abstract.h:127
#25 call_function (kwnames=('atol', 'rtol'), oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=<optimized out>) at /usr/local/src/conda/python-3.9.22/Python/ceval.c:5077
#26 _PyEval_EvalFrameDefault (tstate=<optimized out>, f=Frame 0x55bf36384a90, for file /scratch/pytorch/repro.py, line 34, in <module> (), throwflag=<optimized out>)
    at /usr/local/src/conda/python-3.9.22/Python/ceval.c:3537
#27 0x000055bf35eed685 in _PyEval_EvalFrame (throwflag=0, f=Frame 0x55bf36384a90, for file /scratch/pytorch/repro.py, line 34, in <module> (), tstate=0x55bf36327ca0)
    at /usr/local/src/conda/python-3.9.22/Include/internal/pycore_ceval.h:40
#28 _PyEval_EvalCode (tstate=0x55bf36327ca0, _co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=argcount@entry=0, kwnames=0x0, 
    kwargs=0x0, kwcount=<optimized out>, kwstep=2, defs=0x0, defcount=<optimized out>, kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0)
    at /usr/local/src/conda/python-3.9.22/Python/ceval.c:4329
#29 0x000055bf35eed338 in _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=argcount@entry=0, 
    kwnames=<optimized out>, kwargs=0x0, kwcount=0, kwstep=2, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0)
    at /usr/local/src/conda/python-3.9.22/Python/ceval.c:4361
#30 0x000055bf35eed2e9 in PyEval_EvalCodeEx (_co=_co@entry=<code at remote 0x7f95b84a45b0>, 
    globals=globals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/scratch/pytorch/repro.py') at remote 0x7f95b857dc10>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7f95b8568ae0>, '__file__': '/scratch/pytorch/repro.py', '__cached__': None, 'torch': <module at remote 0x7f95b835a220>, 'device': 'cuda', 'group_size': 4, 'M': 16, 'N': 32, 'K': 64, 'dtype_AB': <torch.dtype at remote 0x7f94a02574b0>, 'dtype_scale': <torch.dtype at remote 0x7f94a02cbdb0>, 'dtype_offset': <torch.dtype at remote 0x7f94a02cbc90>, 'dtype_C': <torch.dtype at remote 0x7f94a0257150>, 'A': <Tensor() at remote 0x7f949dd4c9f0>, 'B': <Tensor at remote 0x7f95b835a860>, 'A_scale': <Tensor() at remote 0x7f95b82fa040>, 'B_scale': <Tensor() at remote 0x7f95b835a810>, 'offs': <Tensor() at remote 0x7f95b835a8b0>, 'f_ref': <built-in method _scaled_grouped_mm of type object at remote 0x7f95b1187fe0>, 'f': <function at remote 0x7f9533bea820>, 'use_f...(truncated), 
    locals=locals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/scratch/pytorch/repro.py') at remote 0x7f95b857dc10>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7f95b8568ae0>, '__file__': '/scratch/pytorch/repro.py', '__cached__': None, 'torch': <module at remote 0x7f95b835a220>, 'device': 'cuda', 'group_size': 4, 'M': 16, 'N': 32, 'K': 64, 'dtype_AB': <torch.dtype at remote 0x7f94a02574b0>, 'dtype_scale': <torch.dtype at remote 0x7f94a02cbdb0>, 'dtype_offset': <torch.dtype at remote 0x7f94a02cbc90>, 'dtype_C': <torch.dtype at remote 0x7f94a0257150>, 'A': <Tensor() at remote 0x7f949dd4c9f0>, 'B': <Tensor at remote 0x7f95b835a860>, 'A_scale': <Tensor() at remote 0x7f95b82fa040>, 'B_scale': <Tensor() at remote 0x7f95b835a810>, 'offs': <Tensor() at remote 0x7f95b835a8b0>, 'f_ref': <built-in method _scaled_grouped_mm of type object at remote 0x7f95b1187fe0>, 'f': <function at remote 0x7f9533bea820>, 'use_f...(truncated), args=args@entry=0x0, 
    argcount=argcount@entry=0, kws=kws@entry=0x0, kwcount=0, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0) at /usr/local/src/conda/python-3.9.22/Python/ceval.c:4377
#31 0x000055bf35f97ddb in PyEval_EvalCode (co=co@entry=<code at remote 0x7f95b84a45b0>, 
    globals=globals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/scratch/pytorch/repro.py') at remote 0x7f95b857dc10>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7f95b8568ae0>, '__file__': '/scratch/pytorch/repro.py', '__cached__': None, 'torch': <module at remote 0x7f95b835a220>, 'device': 'cuda', 'group_size': 4, 'M': 16, 'N': 32, 'K': 64, 'dtype_AB': <torch.dtype at remote 0x7f94a02574b0>, 'dtype_scale': <torch.dtype at remote 0x7f94a02cbdb0>, 'dtype_offset': <torch.dtype at remote 0x7f94a02cbc90>, 'dtype_C': <torch.dtype at remote 0x7f94a0257150>, 'A': <Tensor() at remote 0x7f949dd4c9f0>, 'B': <Tensor at remote 0x7f95b835a860>, 'A_scale': <Tensor() at remote 0x7f95b82fa040>, 'B_scale': <Tensor() at remote 0x7f95b835a810>, 'offs': <Tensor() at remote 0x7f95b835a8b0>, 'f_ref': <built-in method _scaled_grouped_mm of type object at remote 0x7f95b1187fe0>, 'f': <function at remote 0x7f9533bea820>, 'use_f...(truncated), 
    locals=locals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/scratch/pytorch/repro.py') at remote 0x7f95b857dc10>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7f95b8568ae0>, '__file__': '/scratch/pytorch/repro.py', '__cached__': None, 'torch': <module at remote 0x7f95b835a220>, 'device': 'cuda', 'group_size': 4, 'M': 16, 'N': 32, 'K': 64, 'dtype_AB': <torch.dtype at remote 0x7f94a02574b0>, 'dtype_scale': <torch.dtype at remote 0x7f94a02cbdb0>, 'dtype_offset': <torch.dtype at remote 0x7f94a02cbc90>, 'dtype_C': <torch.dtype at remote 0x7f94a0257150>, 'A': <Tensor() at remote 0x7f949dd4c9f0>, 'B': <Tensor at remote 0x7f95b835a860>, 'A_scale': <Tensor() at remote 0x7f95b82fa040>, 'B_scale': <Tensor() at remote 0x7f95b835a810>, 'offs': <Tensor() at remote 0x7f95b835a8b0>, 'f_ref': <built-in method _scaled_grouped_mm of type object at remote 0x7f95b1187fe0>, 'f': <function at remote 0x7f9533bea820>, 'use_f...(truncated))
    at /usr/local/src/conda/python-3.9.22/Python/ceval.c:828
#32 0x000055bf35fc4eaa in run_eval_code_obj (tstate=tstate@entry=0x55bf36327ca0, co=co@entry=0x7f95b84a45b0, 
    globals=globals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/scratch/pytorch/repro.py') at remote 0x7f95b857dc10>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7f95b8568ae0>, '__file__': '/scratch/pytorch/repro.py', '__cached__': None, 'torch': <module at remote 0x7f95b835a220>, 'device': 'cuda', 'group_size': 4, 'M': 16, 'N': 32, 'K': 64, 'dtype_AB': <torch.dtype at remote 0x7f94a02574b0>, 'dtype_scale': <torch.dtype at remote 0x7f94a02cbdb0>, 'dtype_offset': <torch.dtype at remote 0x7f94a02cbc90>, 'dtype_C': <torch.dtype at remote 0x7f94a0257150>, 'A': <Tensor() at remote 0x7f949dd4c9f0>, 'B': <Tensor at remote 0x7f95b835a860>, 'A_scale': <Tensor() at remote 0x7f95b82fa040>, 'B_scale': <Tensor() at remote 0x7f95b835a810>, 'offs': <Tensor() at remote 0x7f95b835a8b0>, 'f_ref': <built-in method _scaled_grouped_mm of type object at remote 0x7f95b1187fe0>, 'f': <function at remote 0x7f9533bea820>, 'use_f...(truncated), 
    locals=locals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/scratch/pytorch/repro.py') at remote 0x7f95b857dc10>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7f95b8568ae0>, '__file__': '/scratch/pytorch/repro.py', '__cached__': None, 'torch': <module at remote 0x7f95b835a220>, 'device': 'cuda', 'group_size': 4, 'M': 16, 'N': 32, 'K': 64, 'dtype_AB': <torch.dtype at remote 0x7f94a02574b0>, 'dtype_scale': <torch.dtype at remote 0x7f94a02cbdb0>, 'dtype_offset': <torch.dtype at remote 0x7f94a02cbc90>, 'dtype_C': <torch.dtype at remote 0x7f94a0257150>, 'A': <Tensor() at remote 0x7f949dd4c9f0>, 'B': <Tensor at remote 0x7f95b835a860>, 'A_scale': <Tensor() at remote 0x7f95b82fa040>, 'B_scale': <Tensor() at remote 0x7f95b835a810>, 'offs': <Tensor() at remote 0x7f95b835a8b0>, 'f_ref': <built-in method _scaled_grouped_mm of type object at remote 0x7f95b1187fe0>, 'f': <function at remote 0x7f9533bea820>, 'use_f...(truncated))
    at /usr/local/src/conda/python-3.9.22/Python/pythonrun.c:1221
#33 0x000055bf35fc1353 in run_mod (mod=mod@entry=0x55bf363fe360, filename=filename@entry='/scratch/pytorch/repro.py', 
    globals=globals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/scratch/pytorch/repro.py') at remote 0x7f95b857dc10>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7f95b8568ae0>, '__file__': '/scratch/pytorch/repro.py', '__cached__': None, 'torch': <module at remote 0x7f95b835a220>, 'device': 'cuda', 'group_size': 4, 'M': 16, 'N': 32, 'K': 64, 'dtype_AB': <torch.dtype at remote 0x7f94a02574b0>, 'dtype_scale': <torch.dtype at remote 0x7f94a02cbdb0>, 'dtype_offset': <torch.dtype at remote 0x7f94a02cbc90>, 'dtype_C': <torch.dtype at remote 0x7f94a0257150>, 'A': <Tensor() at remote 0x7f949dd4c9f0>, 'B': <Tensor at remote 0x7f95b835a860>, 'A_scale': <Tensor() at remote 0x7f95b82fa040>, 'B_scale': <Tensor() at remote 0x7f95b835a810>, 'offs': <Tensor() at remote 0x7f95b835a8b0>, 'f_ref': <built-in method _scaled_grouped_mm of type object at remote 0x7f95b1187fe0>, 'f': <function at remote 0x7f9533bea820>, 'use_f...(truncated), 
    locals=locals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/scratch/pytorch/repro.py') at remote 0x7f95b857dc10>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7f95b8568ae0>, '__file__': '/scratch/pytorch/repro.py', '__cached__': None, 'torch': <module at remote 0x7f95b835a220>, 'device': 'cuda', 'group_size': 4, 'M': 16, 'N': 32, 'K': 64, 'dtype_AB': <torch.dtype at remote 0x7f94a02574b0>, 'dtype_scale': <torch.dtype at remote 0x7f94a02cbdb0>, 'dtype_offset': <torch.dtype at remote 0x7f94a02cbc90>, 'dtype_C': <torch.dtype at remote 0x7f94a0257150>, 'A': <Tensor() at remote 0x7f949dd4c9f0>, 'B': <Tensor at remote 0x7f95b835a860>, 'A_scale': <Tensor() at remote 0x7f95b82fa040>, 'B_scale': <Tensor() at remote 0x7f95b835a810>, 'offs': <Tensor() at remote 0x7f95b835a8b0>, 'f_ref': <built-in method _scaled_grouped_mm of type object at remote 0x7f95b1187fe0>, 'f': <function at remote 0x7f9533bea820>, 'use_f...(truncated), 
    flags=flags@entry=0x7ffce610ce08, arena=arena@entry=0x7f95b855b950) at /usr/local/src/conda/python-3.9.22/Python/pythonrun.c:1242
#34 0x000055bf35e5c347 in pyrun_file (fp=fp@entry=0x55bf363602f0, filename=filename@entry='/scratch/pytorch/repro.py', start=start@entry=257, 
    globals=globals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/scratch/pytorch/repro.py') at remote 0x7f95b857dc10>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7f95b8568ae0>, '__file__': '/scratch/pytorch/repro.py', '__cached__': None, 'torch': <module at remote 0x7f95b835a220>, 'device': 'cuda', 'group_size': 4, 'M': 16, 'N': 32, 'K': 64, 'dtype_AB': <torch.dtype at remote 0x7f94a02574b0>, 'dtype_scale': <torch.dtype at remote 0x7f94a02cbdb0>, 'dtype_offset': <torch.dtype at remote 0x7f94a02cbc90>, 'dtype_C': <torch.dtype at remote 0x7f94a0257150>, 'A': <Tensor() at remote 0x7f949dd4c9f0>, 'B': <Tensor at remote 0x7f95b835a860>, 'A_scale': <Tensor() at remote 0x7f95b82fa040>, 'B_scale': <Tensor() at remote 0x7f95b835a810>, 'offs': <Tensor() at remote 0x7f95b835a8b0>, 'f_ref': <built-in method _scaled_grouped_mm of type object at remote 0x7f95b1187fe0>, 'f': <function at remote 0x7f9533bea820>, 'use_f...(truncated), 
    locals=locals@entry={'__name__': '__main__', '__doc__': None, '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/scratch/pytorch/repro.py') at remote 0x7f95b857dc10>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7f95b8568ae0>, '__file__': '/scratch/pytorch/repro.py', '__cached__': None, 'torch': <module at remote 0x7f95b835a220>, 'device': 'cuda', 'group_size': 4, 'M': 16, 'N': 32, 'K': 64, 'dtype_AB': <torch.dtype at remote 0x7f94a02574b0>, 'dtype_scale': <torch.dtype at remote 0x7f94a02cbdb0>, 'dtype_offset': <torch.dtype at remote 0x7f94a02cbc90>, 'dtype_C': <torch.dtype at remote 0x7f94a0257150>, 'A': <Tensor() at remote 0x7f949dd4c9f0>, 'B': <Tensor at remote 0x7f95b835a860>, 'A_scale': <Tensor() at remote 0x7f95b82fa040>, 'B_scale': <Tensor() at remote 0x7f95b835a810>, 'offs': <Tensor() at remote 0x7f95b835a8b0>, 'f_ref': <built-in method _scaled_grouped_mm of type object at remote 0x7f95b1187fe0>, 'f': <function at remote 0x7f9533bea820>, 'use_f...(truncated), 
    closeit=closeit@entry=1, flags=0x7ffce610ce08) at /usr/local/src/conda/python-3.9.22/Python/pythonrun.c:1140
#35 0x000055bf35fbb270 in pyrun_simple_file (flags=0x7ffce610ce08, closeit=1, filename='/scratch/pytorch/repro.py', fp=0x55bf363602f0)
    at /usr/local/src/conda/python-3.9.22/Python/pythonrun.c:450
#36 PyRun_SimpleFileExFlags (fp=0x55bf363602f0, filename=<optimized out>, closeit=1, flags=0x7ffce610ce08) at /usr/local/src/conda/python-3.9.22/Python/pythonrun.c:483
#37 0x000055bf35fb88a4 in pymain_run_file (cf=0x7ffce610ce08, config=0x55bf363266e0) at /usr/local/src/conda/python-3.9.22/Modules/main.c:377
#38 pymain_run_python (exitcode=0x7ffce610ce00) at /usr/local/src/conda/python-3.9.22/Modules/main.c:606
#39 Py_RunMain () at /usr/local/src/conda/python-3.9.22/Modules/main.c:685
#40 0x000055bf35f8bc57 in Py_BytesMain (argc=<optimized out>, argv=<optimized out>) at /usr/local/src/conda/python-3.9.22/Modules/main.c:1105
#41 0x00007f95b865cd90 in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#42 0x00007f95b865ce40 in __libc_start_main () from /usr/lib/x86_64-linux-gnu/libc.so.6
#43 0x000055bf35f8bb6e in _start ()
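
The trace shows the main thread blocked in cudaStreamSynchronize() reached from torch.allclose (via item()/_local_scalar_dense), i.e. the host is waiting on device work that never completes. Below is a sketch for localizing the offending launch; the names (f, A, B, ...) refer to the reproducer above, and CUDA_LAUNCH_BLOCKING only takes effect if set before CUDA is initialized:

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must happen before CUDA init

import torch
# ... set up A, B, A_scale, B_scale, offs and f as in the reproducer above ...

C = f(A, B.transpose(-2, -1), A_scale, B_scale, offs,
      out_dtype=torch.bfloat16, use_fast_accum=True)
# With blocking launches, a device-side hang surfaces at the offending
# launch; otherwise this explicit sync attributes the hang to the compiled
# call instead of to the later torch.allclose() comparison.
torch.cuda.synchronize()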

Versions

The collect_env.py output
Collecting environment information...
PyTorch version: 2.8.0a0+git92d0c40
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (conda-forge gcc 13.3.0-2) 13.3.0
Clang version: 20.1.3 (https://github.com/conda-forge/clangdev-feedstock 3e9dfa811865fe27bcd95c0004d27603f2ec4a73)
CMake version: version 4.0.1
Libc version: glibc-2.35

Python version: 3.9.22 | packaged by conda-forge | (main, Apr 14 2025, 23:35:59)  [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-119-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.6.85
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 560.35.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               128
On-line CPU(s) list:                  0-127
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Gold 6448Y
CPU family:                           6
Model:                                143
Thread(s) per core:                   2
Core(s) per socket:                   32
Socket(s):                            2
Stepping:                             8
BogoMIPS:                             4200.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                       VT-x
L1d cache:                            3 MiB (64 instances)
L1i cache:                            2 MiB (64 instances)
L2 cache:                             128 MiB (64 instances)
L3 cache:                             120 MiB (2 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46,48,50,52,54,56,58,60,62,64,66,68,70,72,74,76,78,80,82,84,86,88,90,92,94,96,98,100,102,104,106,108,110,112,114,116,118,120,122,124,126
NUMA node1 CPU(s):                    1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47,49,51,53,55,57,59,61,63,65,67,69,71,73,75,77,79,81,83,85,87,89,91,93,95,97,99,101,103,105,107,109,111,113,115,117,119,121,123,125,127
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.14.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] optree==0.13.0
[pip3] pytorch-triton==3.3.0+git96316ce5
[pip3] torch==2.8.0a0+git92d0c40
[conda] cuda-cudart               12.6.77              h5888daf_0    conda-forge
[conda] cuda-cudart-dev           12.6.77              h5888daf_0    conda-forge
[conda] cuda-cudart-dev_linux-64  12.6.77              h3f2d84a_0    conda-forge
[conda] cuda-cudart-static        12.6.77              h5888daf_0    conda-forge
[conda] cuda-cudart-static_linux-64 12.6.77              h3f2d84a_0    conda-forge
[conda] cuda-cudart_linux-64      12.6.77              h3f2d84a_0    conda-forge
[conda] cuda-cupti                12.6.80              hbd13f7d_0    conda-forge
[conda] cuda-cupti-dev            12.6.80              h5888daf_0    conda-forge
[conda] cuda-libraries-dev        12.6.3               ha770c72_0    conda-forge
[conda] cuda-nvrtc                12.6.85              hbd13f7d_0    conda-forge
[conda] cuda-nvrtc-dev            12.6.85              h5888daf_0    conda-forge
[conda] cuda-nvtx                 12.6.77              hbd13f7d_0    conda-forge
[conda] cuda-nvtx-dev             12.6.77              ha770c72_0    conda-forge
[conda] cuda-opencl               12.6.77              hbd13f7d_0    conda-forge
[conda] cuda-opencl-dev           12.6.77              h5888daf_0    conda-forge
[conda] cudnn                     9.8.0.87             h81d5506_1    conda-forge
[conda] libcublas                 12.6.4.1             h5888daf_1    conda-forge
[conda] libcublas-dev             12.6.4.1             h5888daf_1    conda-forge
[conda] libcufft                  11.3.0.4             hbd13f7d_0    conda-forge
[conda] libcufft-dev              11.3.0.4             h5888daf_0    conda-forge
[conda] libcurand                 10.3.7.77            hbd13f7d_0    conda-forge
[conda] libcurand-dev             10.3.7.77            h5888daf_0    conda-forge
[conda] libcusolver               11.7.1.2             h5888daf_1    conda-forge
[conda] libcusolver-dev           11.7.1.2             h5888daf_1    conda-forge
[conda] libcusparse               12.5.4.2             hbd13f7d_0    conda-forge
[conda] libcusparse-dev           12.5.4.2             h5888daf_0    conda-forge
[conda] libmagma                  2.9.0                h19665d7_1    conda-forge
[conda] libmagma_sparse           2.9.0                h19665d7_0    conda-forge
[conda] libnvjitlink              12.6.85              hbd13f7d_0    conda-forge
[conda] libnvjitlink-dev          12.6.85              h5888daf_0    conda-forge
[conda] magma                     2.9.0                h3d470c8_0    conda-forge
[conda] mkl                       2024.2.2            ha957f24_16    conda-forge
[conda] mkl-include               2025.1.0           hf2ce2f3_808    conda-forge
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] optree                    0.13.0                   pypi_0    pypi
[conda] pytorch-triton            3.3.0+git96316ce5          pypi_0    pypi
[conda] torch                     2.8.0a0+git92d0c40           dev_0    <develop>
[conda] torchfix                  0.4.0                    pypi_0    pypi

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

@mlazos added the triaged and module: dynamo labels on Apr 21, 2025
@mlazos
Contributor

mlazos commented Apr 21, 2025

cc @anijain2305: if we're hanging in output_graph.py outside of call_user_compile, this looks like an important Dynamo issue.

@alexsamardzic
Collaborator Author

After a meta registration was implemented for the given operator, the problem disappeared, so I'm closing the issue for now.
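
For reference, the general meta/fake-registration pattern looks like the sketch below, shown on a hypothetical custom op (mylib::toy_mm) rather than on aten::_scaled_grouped_mm itself; the actual fix registers a meta function for the aten op inside PyTorch:

import torch

# Hypothetical stand-in op; everything about "mylib::toy_mm" is illustrative.
@torch.library.custom_op("mylib::toy_mm", mutates_args=())
def toy_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return (a @ b).to(torch.bfloat16)

# The fake (meta) implementation propagates shapes/dtypes without touching
# the device, which is what lets Dynamo/Inductor trace the op symbolically
# instead of falling back to running (and possibly synchronizing on) it.
@toy_mm.register_fake
def _(a, b):
    return a.new_empty((a.shape[0], b.shape[1]), dtype=torch.bfloat16)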
