Fix functionalization return mapping for mixed mutable out ops #186199
jansel wants to merge 1 commit into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/186199

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 3 Pending
As of commit dae4d5b with merge base 16de8e3.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude review these changes
Claude finished @jansel's task in 2m 8s

Code Review: Fix functionalization return mapping for mixed mutable out ops

Summary: This is a correct and well-reasoned fix. The root cause analysis is accurate, and the fix correctly addresses the misaligned mapping between functional op returns and mutable arguments.

The Bug
For
The functional variant returns 5 tensors:

The Fix
The new logic correctly:

Minor Observations

Overall
LGTM. The fix is minimal, correct, and addresses the issue at the right abstraction level (shared codegen path rather than special-casing). The assertion messages are clear enough for debugging if a similar issue arises with another operator.
Stack from ghstack (oldest at bottom):
Functionalization codegen assumed the functional replacement returned fresh
outputs first and then one updated value per mutable input. That ordering is
wrong for schemas that both return out= aliases and mutate additional inputs.

For `_native_batch_norm_legit.out`, the functional op returns the schema
outputs first, followed by updated `running_mean` and `running_var`. The old
mapping therefore copied the `out` tensor into `running_mean`, causing
AOTAutograd tracing to fail with a rank-mismatched `expand` error.

Fix this by mapping functional returns according to the outer schema return
order. Aliased returns update their corresponding mutable argument, fresh
returns are wrapped and returned, and only mutable arguments not already
returned are mapped from the trailing functional returns. This keeps the fix in
the shared functionalization codegen path rather than special-casing batch norm.
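
The mapping described above can be illustrated with a small, self-contained Python model. This is only a sketch and not the actual codegen: `Ret`, `old_mapping`, and `new_mapping` are hypothetical names, strings stand in for tensors, and the argument list for `_native_batch_norm_legit.out` is assumed from the description (two running stats plus three `out=` tensors, all mutable).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Ret:
    # Name of the mutable argument this outer return aliases, or None if fresh.
    aliases: Optional[str]


def old_mapping(mutable_args, outer_returns, functional_returns):
    """Model of the old assumption: fresh outputs first, then one
    updated value per mutable argument, in argument order."""
    num_fresh = sum(1 for r in outer_returns if r.aliases is None)
    return dict(zip(mutable_args, functional_returns[num_fresh:]))


def new_mapping(mutable_args, outer_returns, functional_returns):
    """Model of the fix: walk the outer schema returns in order, then map
    only the not-yet-returned mutable arguments from the trailing returns."""
    updates, fresh = {}, []
    for ret, value in zip(outer_returns, functional_returns):
        if ret.aliases is None:
            fresh.append(value)            # fresh return: wrapped and returned
        else:
            updates[ret.aliases] = value   # aliased return: updates its argument
    remaining = [a for a in mutable_args if a not in updates]
    trailing = functional_returns[len(outer_returns):]
    assert len(remaining) == len(trailing), "unmapped mutable arguments"
    updates.update(zip(remaining, trailing))
    return updates, fresh


# Assumed shape of the batch norm case (names are illustrative stand-ins).
mutable_args = ["running_mean", "running_var", "out", "save_mean", "save_invstd"]
outer_returns = [Ret("out"), Ret("save_mean"), Ret("save_invstd")]
functional_returns = [
    "new_out", "new_save_mean", "new_save_invstd",
    "new_running_mean", "new_running_var",
]

old = old_mapping(mutable_args, outer_returns, functional_returns)
new, fresh = new_mapping(mutable_args, outer_returns, functional_returns)
print(old["running_mean"])  # the misaligned value under the old assumption
print(new["running_mean"])  # the value under the return-order mapping
```

In this model the old mapping pairs `running_mean` with the first functional return, which is the `out` result, while the new mapping pairs every mutable argument with its own updated value and produces no fresh returns for this schema.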
Fixes #138264
Generated by my agent
Test Plan:
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @jataylo @azahed98