Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 3be1506

Browse files
Riley Dulin authored and pytorchmergebot committed
[torch][ao] Add customizable loss function to NodeAccuracySummary (pytorch#136282)
Summary: Add a customizable loss function callback to NodeAccuracySummary to allow users to pass in their own loss function. Also, fix some type errors and propagate better exception messages when unexpected tensor comparisons occur. Finally, enhance the robustness of `generate_numeric_debug_handle` in the case where it is called multiple times on the same model, by avoiding reuse of the same IDs. Test Plan: Added a test for this case in `test_numeric_debugger`. Reviewed By: jerryzh168 Differential Revision: D62898297 Pull Request resolved: pytorch#136282 Approved by: https://github.com/jerryzh168
1 parent e09c5b6 commit 3be1506

File tree

2 files changed

+95
-12
lines changed

2 files changed

+95
-12
lines changed

test/quantization/pt2e/test_numeric_debugger.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase
2626

2727

28-
def _extract_debug_handles(model) -> Dict[torch.fx.Node, int]:
29-
debug_handle_map: Dict[torch.fx.Node, int] = {}
28+
def _extract_debug_handles(model) -> Dict[str, int]:
29+
debug_handle_map: Dict[str, int] = {}
3030

3131
for node in model.graph.nodes:
3232
if (
@@ -187,3 +187,53 @@ def test_extract_results_from_loggers(self):
187187
for node_summary in comparison_results.values():
188188
if len(node_summary.results) > 0:
189189
self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
190+
191+
def test_added_node_gets_unique_id(self) -> None:
192+
m = TestHelperModules.Conv2dThenConv1d()
193+
example_inputs = m.example_inputs()
194+
m = capture_pre_autograd_graph(m, example_inputs)
195+
assert isinstance(m, torch.fx.GraphModule)
196+
generate_numeric_debug_handle(m)
197+
ref_handles = _extract_debug_handles(m)
198+
ref_counter = Counter(ref_handles.values())
199+
for k, v in ref_counter.items():
200+
self.assertEqual(
201+
v,
202+
1,
203+
msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
204+
)
205+
206+
# Now that we have unique ids, add a new node into the graph and re-generate
207+
# to make sure that the new node gets a unique id.
208+
last_node = next(iter(reversed(m.graph.nodes)))
209+
with m.graph.inserting_before(last_node):
210+
arg = last_node.args[0]
211+
self.assertIsInstance(arg, (list, tuple))
212+
arg = arg[0]
213+
# Add a function that only requires a single tensor input.
214+
n = m.graph.call_function(torch.ops.aten.relu.default, args=(arg,))
215+
arg.replace_all_uses_with(n, lambda x: x != n)
216+
m.recompile()
217+
218+
# Regenerate handles, make sure only the new relu node has a new id, and
219+
# it doesn't clash with any of the existing ids.
220+
generate_numeric_debug_handle(m)
221+
handles_after_modification = _extract_debug_handles(m)
222+
handles_counter = Counter(handles_after_modification.values())
223+
for name, handle in ref_handles.items():
224+
self.assertIn(name, handles_after_modification)
225+
# Check that handle was unchanged.
226+
self.assertEqual(handles_after_modification[name], handle)
227+
# Check that total count was unchanged.
228+
ref_count = ref_counter[handle]
229+
after_count = handles_counter[handle]
230+
self.assertEqual(
231+
after_count,
232+
ref_count,
233+
msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}",
234+
)
235+
236+
# Check for relu specifically. Avoid hardcoding the handle id since it
237+
# may change with future node ordering changes.
238+
self.assertNotEqual(handles_after_modification["relu_default"], 0)
239+
self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)

torch/ao/quantization/pt2e/_numeric_debugger.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import logging
33
from dataclasses import dataclass
4-
from typing import Dict, List, Optional, Sequence, Tuple
4+
from typing import Callable, Dict, List, Optional, Sequence, Tuple
55

66
import torch
77
from torch.ao.ns.fx.utils import compute_sqnr
@@ -19,7 +19,16 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
1919
"""Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
2020
The graph nodes of input model is modified inplace.
2121
"""
22-
unique_id = 0
22+
unique_id = -1
23+
# Find the max ID that exists in the graph first, in case part of the graph
24+
# has already been annotated. This way we guarantee there are no duplicate
25+
# handle IDs.
26+
for node in graph_module.graph.nodes:
27+
unique_id = max(
28+
unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, -1)
29+
)
30+
unique_id += 1
31+
2332
for node in graph_module.graph.nodes:
2433
if node.op in ["output", "placeholder"]:
2534
continue
@@ -134,6 +143,17 @@ def sqnr(self) -> torch.Tensor:
134143
self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
135144
)
136145

146+
def loss(
147+
self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
148+
) -> torch.Tensor:
149+
if self.actual.shape != self.ref.shape:
150+
raise ValueError(
151+
f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
152+
)
153+
return loss_function(
154+
self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
155+
)
156+
137157
def __repr__(self) -> str:
138158
# Don't include the tensors themselves as they are quite large to print
139159
# out.
@@ -149,6 +169,10 @@ def __post_init__(self) -> None:
149169

150170
if not isinstance(self.ref, torch.Tensor):
151171
raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
172+
if self.actual.shape != self.ref.shape:
173+
raise ValueError(
174+
f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
175+
)
152176

153177

154178
@dataclass(frozen=True)
@@ -197,8 +221,8 @@ def extract_results_from_loggers(
197221

198222

199223
def compare_results(
200-
ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
201-
actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
224+
ref_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
225+
actual_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
202226
) -> Dict[int, NodeAccuracySummary]:
203227
"""Given two dict mapping from `debug_handle_id` (int) to list of tensors
204228
return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
@@ -220,16 +244,25 @@ def compare_results(
220244
)
221245
continue
222246
actual_name, actual_stack, actual_stats = actual_results[debug_handle]
247+
try:
248+
results = [
249+
QuantizationComparisonResult(actual=a, ref=b)
250+
for a, b in zip(actual_stats, ref_stats)
251+
]
252+
except Exception as e:
253+
# Add extra information for an exception from QuantizationComparisonResult
254+
# if the shapes didn't match, to include the handle and the node names.
255+
raise ValueError(
256+
f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
257+
) from e
258+
223259
comparisons[debug_handle] = NodeAccuracySummary(
224260
handle=debug_handle,
225-
actual_node_name=actual_name,
261+
actual_node_name=actual_name or "",
226262
actual_module_stack=_module_stack_to_str(actual_stack),
227-
ref_node_name=ref_name,
263+
ref_node_name=ref_name or "",
228264
ref_module_stack=_module_stack_to_str(ref_stack),
229-
results=[
230-
QuantizationComparisonResult(actual=a, ref=b)
231-
for a, b in zip(actual_stats, ref_stats)
232-
],
265+
results=results,
233266
)
234267

235268
return comparisons

0 commit comments

Comments (0)