1
1
import copy
2
2
import logging
3
3
from dataclasses import dataclass
4
- from typing import Dict , List , Optional , Sequence , Tuple
4
+ from typing import Callable , Dict , List , Optional , Sequence , Tuple
5
5
6
6
import torch
7
7
from torch .ao .ns .fx .utils import compute_sqnr
@@ -19,7 +19,16 @@ def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
19
19
"""Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
20
20
The graph nodes of input model is modified inplace.
21
21
"""
22
- unique_id = 0
22
+ unique_id = - 1
23
+ # Find the max ID that exists in the graph first, in case part of the graph
24
+ # has already been annotated. This way we guarantee there are no duplicate
25
+ # handle IDs.
26
+ for node in graph_module .graph .nodes :
27
+ unique_id = max (
28
+ unique_id , node .meta .get (CUSTOM_KEY , {}).get (NUMERIC_DEBUG_HANDLE_KEY , - 1 )
29
+ )
30
+ unique_id += 1
31
+
23
32
for node in graph_module .graph .nodes :
24
33
if node .op in ["output" , "placeholder" ]:
25
34
continue
@@ -134,6 +143,17 @@ def sqnr(self) -> torch.Tensor:
134
143
self .actual .to (dtype = torch .float32 ), self .ref .to (dtype = torch .float32 )
135
144
)
136
145
146
def loss(
    self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
) -> torch.Tensor:
    """Apply a caller-supplied loss function to this actual/ref tensor pair.

    Both tensors are converted to ``torch.float32`` before being handed to
    ``loss_function``, which is called as ``loss_function(actual, ref)``.

    Args:
        loss_function: Callable taking ``(actual, ref)`` tensors and
            returning a tensor.

    Returns:
        Whatever ``loss_function`` returns for the float32 tensors.

    Raises:
        ValueError: If ``self.actual`` and ``self.ref`` differ in shape.
    """
    actual_shape = self.actual.shape
    ref_shape = self.ref.shape
    if actual_shape != ref_shape:
        message = (
            f"Cannot compare tensors with different shapes: {actual_shape} vs {ref_shape} "
        )
        raise ValueError(message)

    # Convert in the same order as the call arguments: actual first, then ref.
    actual_fp32 = self.actual.to(dtype=torch.float32)
    ref_fp32 = self.ref.to(dtype=torch.float32)
    return loss_function(actual_fp32, ref_fp32)
156
+
137
157
def __repr__ (self ) -> str :
138
158
# Don't include the tensors themselves as they are quite large to print
139
159
# out.
@@ -149,6 +169,10 @@ def __post_init__(self) -> None:
149
169
150
170
if not isinstance (self .ref , torch .Tensor ):
151
171
raise ValueError (f"`self.ref` value must be a Tensor, got: { self .ref } " )
172
+ if self .actual .shape != self .ref .shape :
173
+ raise ValueError (
174
+ f"Cannot compare tensors with different shapes: ref={ self .ref .shape } vs actual={ self .actual .shape } "
175
+ )
152
176
153
177
154
178
@dataclass (frozen = True )
@@ -197,8 +221,8 @@ def extract_results_from_loggers(
197
221
198
222
199
223
def compare_results (
200
- ref_results : Dict [int , Tuple [str , object , List [torch .Tensor ]]],
201
- actual_results : Dict [int , Tuple [str , object , List [torch .Tensor ]]],
224
+ ref_results : Dict [int , Tuple [Optional [ str ] , object , List [torch .Tensor ]]],
225
+ actual_results : Dict [int , Tuple [Optional [ str ] , object , List [torch .Tensor ]]],
202
226
) -> Dict [int , NodeAccuracySummary ]:
203
227
"""Given two dict mapping from `debug_handle_id` (int) to list of tensors
204
228
return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
@@ -220,16 +244,25 @@ def compare_results(
220
244
)
221
245
continue
222
246
actual_name , actual_stack , actual_stats = actual_results [debug_handle ]
247
+ try :
248
+ results = [
249
+ QuantizationComparisonResult (actual = a , ref = b )
250
+ for a , b in zip (actual_stats , ref_stats )
251
+ ]
252
+ except Exception as e :
253
+ # Add extra information for an exception from QuantizationComparisonResult
254
+ # if the shapes didn't match, to include the handle and the node names.
255
+ raise ValueError (
256
+ f"For numeric_debug_handle={ debug_handle } from ref node { ref_name } and actual node { actual_name } "
257
+ ) from e
258
+
223
259
comparisons [debug_handle ] = NodeAccuracySummary (
224
260
handle = debug_handle ,
225
- actual_node_name = actual_name ,
261
+ actual_node_name = actual_name or "" ,
226
262
actual_module_stack = _module_stack_to_str (actual_stack ),
227
- ref_node_name = ref_name ,
263
+ ref_node_name = ref_name or "" ,
228
264
ref_module_stack = _module_stack_to_str (ref_stack ),
229
- results = [
230
- QuantizationComparisonResult (actual = a , ref = b )
231
- for a , b in zip (actual_stats , ref_stats )
232
- ],
265
+ results = results ,
233
266
)
234
267
235
268
return comparisons
0 commit comments