Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 7bcb6d3

Browse files
authored
Allow for matching debug handles with partial overlap between aten graph and runtime
Differential Revision: D82229367 Pull Request resolved: pytorch#14306
1 parent 03f436a commit 7bcb6d3

2 files changed

Lines changed: 31 additions & 14 deletions

File tree

devtools/inspector/_inspector_utils.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -657,13 +657,21 @@ def _combine_aot_overlapped_intermediate_outputs(
657657
# Combine all AOT debug_handles into a list
658658
aot_combined_debug_handle = [t[0] for t in aot_map.keys()]
659659

660-
if set(aot_combined_debug_handle) != set(runtime_debug_handle):
661-
# AOT combined debug_handle and runtime debug_handle do not match.
660+
# Reason we don't check for an exact match:
661+
# in some experiments where we want to rewrite the aten graph that was
662+
# lowered, so as to use custom ops like int4_matmul, we lose some nodes
663+
# on the graph and thus lose some debug handles, so we don't find an
664+
# exact match within connected components.
665+
if not set(aot_combined_debug_handle).issubset(set(runtime_debug_handle)):
666+
# AOT combined debug_handle is not a subset of runtime debug_handle.
662667
return (-1,), None
663668

664669
# Pick the last intermediate output
665670
last_int = runtime_debug_handle[negative_index]
666671
key = (last_int,)
672+
if key not in aot_map:
673+
# If the last intermediate output is not in the AOT map, return ((-1,), None)
674+
return (-1,), None
667675
return runtime_debug_handle, aot_map[key]
668676

669677

@@ -1059,11 +1067,16 @@ def _find_n_match_node(node: Node) -> None:
10591067
if node.op in ("output", "placeholder"):
10601068
return
10611069
node_id = f"{node.name}.{exported_program_graph_id}"
1062-
parent_node_id = get_parent_node_identifier(node)
1070+
parent_node_ids = get_ancestor_node_identifiers(node)
10631071
if node_id in ancestors_node_id_to_debug_handle:
10641072
matched_debug_handles.add(ancestors_node_id_to_debug_handle[node_id])
1065-
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1066-
matched_debug_handles.add(ancestors_node_id_to_debug_handle[parent_node_id])
1073+
elif parent_node_ids:
1074+
for parent_node_id in parent_node_ids:
1075+
if parent_node_id in ancestors_node_id_to_debug_handle:
1076+
matched_debug_handles.add(
1077+
ancestors_node_id_to_debug_handle[parent_node_id]
1078+
)
1079+
break
10671080

10681081
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
10691082
return matched_debug_handles
@@ -1097,15 +1110,17 @@ def _equip_debug_handle(node: Node) -> None:
10971110
if node.op in ("output", "placeholder"):
10981111
return
10991112
node_id = f"{node.name}.{exported_program_graph_id}"
1100-
parent_node_id = get_parent_node_identifier(node)
1113+
parent_node_ids = get_ancestor_node_identifiers(node)
1114+
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
11011115
if node_id in ancestors_node_id_to_debug_handle:
11021116
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id]
1103-
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1104-
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
1105-
parent_node_id
1106-
]
1107-
else:
1108-
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
1117+
elif parent_node_ids:
1118+
for parent_node_id in parent_node_ids:
1119+
if parent_node_id in ancestors_node_id_to_debug_handle:
1120+
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
1121+
parent_node_id
1122+
]
1123+
break
11091124

11101125
bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
11111126

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,15 @@ def test_map_runtime_aot_intermediate_outputs_no_overlaps(self):
334334
self.assertEqual(actual, expected)
335335

336336
def test_map_runtime_aot_intermediate_outputs_partial_match(self):
337-
# Partial match between aot and runtime debug_handles will return empty
337+
# Partial match between aot and runtime debug_handles will return
338+
# matching debug handles from runtime
338339
aot_intermediate_outputs = {(2,): 100, (9,): 300}
339340
runtime_intermediate_outputs = {(2, 3): (200, 1), (8, 9): (300, 1)}
340341
actual = map_runtime_aot_intermediate_outputs(
341342
aot_intermediate_outputs, runtime_intermediate_outputs
342343
)
343-
expected = {}
344+
# Runtime debug handle 9 is also an AOT debug handle, so (8, 9) is matched
345+
expected = {((8, 9), 300): ((8, 9), 300)}
344346
self.assertEqual(actual, expected)
345347

346348
def test_map_runtime_aot_intermediate_outputs_multiple_aot_to_one_runtime(self):

0 commit comments

Comments
 (0)