Fix functionalization return mapping for mixed mutable out ops #186199

Draft
jansel wants to merge 1 commit into gh/jansel/1329/base from gh/jansel/1329/head

jansel commented Jun 4, 2026

Stack from ghstack (oldest at bottom):

Functionalization codegen assumed the functional replacement returned fresh
outputs first and then one updated value per mutable input. That ordering is
wrong for schemas that both return out= aliases and mutate additional inputs.
For _native_batch_norm_legit.out, the functional op returns the schema
outputs first, followed by updated running_mean and running_var. The old
mapping therefore copied the out tensor into running_mean, causing
AOTAutograd tracing to fail with a rank-mismatched expand error.

Fix this by mapping functional returns according to the outer schema return
order. Aliased returns update their corresponding mutable argument, fresh
returns are wrapped and returned, and only mutable arguments not already
returned are mapped from the trailing functional returns. This keeps the fix in
the shared functionalization codegen path rather than special-casing batch norm.
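The mapping rule described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the generated C++ or the torchgen code itself: `map_functional_returns` and its parameters are hypothetical names, and tensors are stood in for by strings.

```python
def map_functional_returns(outer_returns, mutable_args_not_returned, functional_values):
    """Hypothetical helper illustrating the return mapping.

    outer_returns: one entry per outer schema return, in schema order. Each
        entry is the name of the mutable argument that return aliases, or
        None when the return is fresh.
    mutable_args_not_returned: names of mutable arguments that no outer
        return aliases, in the order the functional op appends them.
    functional_values: values returned by the functional op.
    """
    num_outer = len(outer_returns)
    assert len(functional_values) == num_outer + len(mutable_args_not_returned)

    updates = {}  # mutable argument name -> value to write back
    fresh = []    # values to wrap and hand back to the caller

    # Leading functional returns follow the outer schema return order.
    for alias, value in zip(outer_returns, functional_values[:num_outer]):
        if alias is None:
            fresh.append(value)
        else:
            updates[alias] = value

    # Trailing functional returns belong to the remaining mutable arguments.
    for name, value in zip(mutable_args_not_returned, functional_values[num_outer:]):
        updates[name] = value

    return updates, fresh


# _native_batch_norm_legit.out: all three outer returns alias out= arguments.
updates, fresh = map_functional_returns(
    outer_returns=["out", "save_mean", "save_invstd"],
    mutable_args_not_returned=["running_mean", "running_var"],
    functional_values=[
        "out_val", "save_mean_val", "save_invstd_val",
        "running_mean_val", "running_var_val",
    ],
)
```

With this mapping, `running_mean` receives `running_mean_val` rather than `out_val`, and no fresh values are returned because every outer return is an alias.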

Fixes #138264
Generated by my agent

Test Plan:

  • ninja -C build torch_python
  • ninja -C build install
  • python - <<'PY' ... torch.compile(torch._native_batch_norm_legit, backend="aot_eager")(*args, out=out) ... PY
  • python test/dynamo/test_aot_autograd.py -k test_native_batch_norm_legit_out
  • lintrunner -a
  • git diff --check

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @jataylo @azahed98

[ghstack-poisoned]
pytorch-bot Bot commented Jun 4, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/186199

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 3 Pending

As of commit dae4d5b with merge base 16de8e3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jansel commented Jun 4, 2026

@claude review these changes

claude Bot commented Jun 4, 2026

Claude finished @jansel's task in 2m 8s


Code Review: Fix functionalization return mapping for mixed mutable out ops

  • Read the PR diff and changed files
  • Analyze the fix in torchgen/gen_functionalization_type.py
  • Review the test in test/dynamo/test_aot_autograd.py
  • Post review feedback

Summary: This is a correct and well-reasoned fix. The root cause analysis is accurate, and the fix correctly addresses the misaligned mapping between functional op returns and mutable arguments.

The Bug

For _native_batch_norm_legit.out, the schema is:

(Tensor input, ..., Tensor(a!) running_mean, Tensor(b!) running_var, ..., *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!))

The functional variant returns 5 tensors: [out_val, save_mean_val, save_invstd_val, running_mean_val, running_var_val]. The old code assumed fresh outputs came first, then mutated inputs — but here all 3 schema returns are aliased to out/save_mean/save_invstd, so non_aliased_outer_rets was empty. The old code then zipped all 5 functional returns against all 5 mutable args in flat_all order (running_mean, running_var, out, save_mean, save_invstd), causing out_val to be assigned to running_mean — a rank mismatch.
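The misalignment can be reproduced with a small Python sketch. The shapes below are assumptions chosen for illustration (an input of shape (2, 3, 4, 4) with 3 channels), and the lists stand in for the generated code's variables rather than quoting them.

```python
# Functional op returns, in the order described above: schema outputs first,
# then the updated running statistics. Shapes are illustrative assumptions.
functional_returns = [
    ("out_val", (2, 3, 4, 4)),
    ("save_mean_val", (3,)),
    ("save_invstd_val", (3,)),
    ("running_mean_val", (3,)),
    ("running_var_val", (3,)),
]

# Mutable arguments in the order the old code walked them.
mutable_args = [
    ("running_mean", (3,)),
    ("running_var", (3,)),
    ("out", (2, 3, 4, 4)),
    ("save_mean", (3,)),
    ("save_invstd", (3,)),
]

# Old behavior: pair the two lists purely by position.
old_mapping = {
    arg: ret for (arg, _), (ret, _) in zip(mutable_args, functional_returns)
}

# Arguments whose paired value has a different rank than the argument itself.
rank_mismatches = [
    arg
    for (arg, arg_shape), (_, ret_shape) in zip(mutable_args, functional_returns)
    if len(arg_shape) != len(ret_shape)
]
```

Here `old_mapping["running_mean"]` is `"out_val"`, a rank-4 value paired with a rank-1 buffer, which matches the failure mode reported in the PR description.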

The Fix

The new logic correctly:

  1. Iterates over the schema return positions, matching aliased returns to their corresponding mutable arguments via annotation matching (outer_aliased_return_names).
  2. Maps the remaining (trailing) functional returns to mutable_arg_names_not_returned — which uses the exact same iteration order as FunctionSchema.signature() at torchgen/model.py:1852-1861 (self_arg → out → post_self_positional), guaranteeing consistency.
  3. Uses a dict-based lookup for mutation propagation instead of positional zipping, making it robust to different schema shapes.
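The ordering in step 2 can be illustrated with a minimal Python sketch. The `Arg` class and `mutable_args_not_returned` function are hypothetical stand-ins rather than torchgen's real types, and only the arguments visible in the schema above are modeled (the elided ones are omitted).

```python
import itertools
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Arg:
    """Hypothetical stand-in for a schema argument."""
    name: str
    write_alias: Optional[str] = None  # "a" for Tensor(a!), None if not mutated


def mutable_args_not_returned(self_arg, out_args, post_self_positional, return_aliases):
    # Walk arguments in the order self_arg -> out -> post_self_positional and
    # keep the mutated ones that no schema return already aliases.
    candidates = itertools.chain(
        [self_arg] if self_arg is not None else [],
        out_args,
        post_self_positional,
    )
    return [
        a.name
        for a in candidates
        if a.write_alias is not None and a.write_alias not in return_aliases
    ]


trailing = mutable_args_not_returned(
    self_arg=None,
    out_args=[Arg("out", "d"), Arg("save_mean", "e"), Arg("save_invstd", "f")],
    post_self_positional=[
        Arg("input"),
        Arg("running_mean", "a"),
        Arg("running_var", "b"),
    ],
    return_aliases={"d", "e", "f"},
)
```

Because the out= arguments are already covered by the schema returns, only `running_mean` and `running_var` remain, in that order, to be matched against the trailing functional returns.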

Minor Observations

  1. Defensive assertion (line 564-567): The assertion that the inner functional op has no aliased returns is good — it documents an invariant and would catch future regressions early.

  2. itertools.chain ordering (lines 575-582): Correctly mirrors the ordering in model.py where returns_from_mutable_inputs is built. This is the critical correctness property and it's good that both places share the same pattern.

  3. Test coverage (lines 94-122): The test is well-structured — it verifies:

    • Return values match eager vs compiled
    • Out tensors are populated correctly
    • Mutated running stats (not in returns) are propagated correctly

    This directly exercises the bug path and confirms the rank mismatch no longer occurs.

Overall

LGTM. The fix is minimal, correct, and addresses the issue at the right abstraction level (shared codegen path rather than special-casing). The assertion messages are clear enough for debugging if a similar issue arises with another operator.
