Commit 02501d8

larryliu0820 and claude authored
Add multi-method support for ETRecord edge dialect programs (pytorch#17185)
## Summary

This PR implements proper multi-method support for ETRecord while fixing fundamental ETRecord lifecycle issues that were exposed during testing.

## Multi-Method ETRecord Support

- **VLM Model Support**: Added support for models with multiple methods (vision_encoder, text_decoder, etc.) required for Vision-Language Models
- **Flexible Storage**: ETRecord now stores `edge_dialect_program` as `Dict[str, ExportedProgram]` for multi-method cases
- **Backward Compatibility**: Single-method cases still store as `ExportedProgram` for seamless compatibility
- **Consistent Format**: Always uses new multi-method serialization format for future-proofing

## ETRecord Lifecycle Separation

- **Fixed Shared State Issue**: EdgeProgramManager and ExecutorchProgramManager were incorrectly sharing ETRecord instances
- **Proper Separation**: `EdgeProgramManager.to_executorch()` now creates separate ETRecord for ExecutorchProgramManager
- **Edge Stage Integrity**: EdgeProgramManager retains pure edge-stage ETRecord (`_debug_handle_map = None`)
- **Complete Records**: ExecutorchProgramManager gets complete ETRecord with executorch data populated
- **Immutable Design**: Follows immutable snapshot principle where each transformation stage owns appropriate data

## Test Corrections

- **Pure Edge Testing**: Fixed `test_to_edge_with_etrecord_generation` to test edge-stage behavior without executorch interference
- **Correct Comparisons**: Fixed graph comparison logic that incorrectly mixed edge/executorch transformation states
- **Full Coverage**: All 48 ETRecord tests pass with proper lifecycle separation

## Architecture Impact

The changes establish clear architectural boundaries between transformation stages:

- **Edge Stage**: ETRecord captures exported program + edge dialect program
- **Executorch Stage**: ETRecord captures everything from edge + executorch-specific data
- **No Shared State**: Each manager maintains its own appropriate ETRecord snapshot

This enables ETRecord to support complex multi-method models used in VLM workflows while maintaining clean separation of concerns and preventing state corruption between transformation stages.

## Testing

- ✅ All 48 ETRecord tests pass
- ✅ Multi-method serialization/deserialization works correctly
- ✅ Backward compatibility maintained for single-method cases
- ✅ Lifecycle separation prevents state corruption
- ✅ VLM workflows can now use ETRecord for debugging and inspection

Differential Revision: [D92215098](https://www.internalfb.com/diff/D92215098)

Co-authored-by: Claude Sonnet 4 <[email protected]>
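The storage rule described in the summary (a bare program for one method, a dict for several, and a dict-shaped layout on disk) can be sketched roughly as follows. `FakeProgram`, `to_method_dict`, and `from_method_dict` are hypothetical names used only for illustration; the real code works with `ExportedProgram` and does this inline in `save` and `parse_etrecord`.

```python
from typing import Dict, Union


class FakeProgram:
    """Hypothetical stand-in for torch.export.ExportedProgram."""

    def __init__(self, name: str) -> None:
        self.name = name


Programs = Union[FakeProgram, Dict[str, FakeProgram]]


def to_method_dict(programs: Programs) -> Dict[str, FakeProgram]:
    """Normalize either storage form to a method-name -> program mapping."""
    if isinstance(programs, FakeProgram):
        # A bare program is saved under the method name "forward".
        return {"forward": programs}
    return dict(programs)


def from_method_dict(programs: Dict[str, FakeProgram]) -> Programs:
    """Collapse a one-entry mapping back to a bare program."""
    if len(programs) == 1:
        return next(iter(programs.values()))
    return programs


single = FakeProgram("forward")
multi = {"vision_encoder": FakeProgram("v"), "text_decoder": FakeProgram("t")}
```

One consequence visible in this sketch: a one-entry mapping collapses to a bare program on load, so the method name of a single-method record is not carried on the in-memory object.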
1 parent f1bfb18 commit 02501d8

5 files changed

Lines changed: 288 additions & 58 deletions


.ci/scripts/unittest-linux-cmake.sh

Lines changed: 18 additions & 0 deletions
@@ -6,6 +6,24 @@
 # LICENSE file in the root directory of this source tree.
 set -eux

+# Some ARM/TOSA-adjacent tests import modules that require tosa_serializer.
+# Install from a local tosa-tools checkout when available. If absent in this
+# checkout layout, clone the pinned upstream tag and install from there.
+if ! python -c "import tosa_serializer" >/dev/null 2>&1; then
+  TOSA_SERIALIZATION_DIR="./examples/arm/arm-scratch/tosa-tools/serialization"
+  if [[ ! -d "${TOSA_SERIALIZATION_DIR}" ]]; then
+    TOSA_TOOLS_DIR="$(mktemp -d /tmp/tosa-tools.XXXXXX)"
+    git clone --depth 1 --branch v2025.11.0 \
+      https://git.gitlab.arm.com/tosa/tosa-tools.git "${TOSA_TOOLS_DIR}"
+    TOSA_SERIALIZATION_DIR="${TOSA_TOOLS_DIR}/serialization"
+  fi
+
+  CMAKE_POLICY_VERSION_MINIMUM=3.5 BUILD_PYBIND=1 \
+    python -m pip install --no-dependencies \
+      "${TOSA_SERIALIZATION_DIR}"
+  python -c "import tosa_serializer"
+fi
+
 # Run pytest with coverage
 pytest -n auto --cov=./ --cov-report=xml
 # Run gtest
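The guard added to the script follows a "probe, then pick a source" pattern. A minimal Python sketch of the same two decisions is below; `needs_install` and `choose_source` are hypothetical helpers, not part of the repository. Note that `importlib.util.find_spec` only locates a module without executing it, so it is a weaker check than the script's `python -c "import tosa_serializer"`.

```python
import importlib.util
from pathlib import Path


def needs_install(module_name: str) -> bool:
    """Return True when a top-level module cannot be located."""
    return importlib.util.find_spec(module_name) is None


def choose_source(local_dir: str) -> str:
    """Prefer an existing local checkout; otherwise fall back to cloning."""
    return "local" if Path(local_dir).is_dir() else "clone"
```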

devtools/etrecord/_etrecord.py

Lines changed: 138 additions & 11 deletions
@@ -62,7 +62,9 @@ def __init__(
         self,
         exported_program: Optional[ExportedProgram] = None,
         export_graph_id: Optional[int] = None,
-        edge_dialect_program: Optional[ExportedProgram] = None,
+        edge_dialect_program: Optional[
+            Union[ExportedProgram, Dict[str, ExportedProgram]]
+        ] = None,
         graph_map: Optional[Dict[str, ExportedProgram]] = None,
         _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None,
         _delegate_map: Optional[
@@ -88,18 +90,51 @@ def __init__(
         ```

         If user need to create an ETRecord manually, please use the `create_etrecord` function.
+
+        **EXPERIMENTAL**: This API supports multiple methods. For example:
+        ```python
+        lowered_and_edge = to_edge_transform_and_lower(
+            {
+                "vision_encoder": vision_encoder_ep,
+                "token_embedding": token_embedding_ep,
+                "text_decoder": causal_llm_ep,
+            },
+            partitioner={
+                "vision_encoder": [XnnpackPartitioner()],
+                "token_embedding": [XnnpackPartitioner()],
+                "text_decoder": [
+                    XnnpackPartitioner(
+                        config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
+                        per_op_mode=True,
+                    ),
+                    XnnpackPartitioner(),
+                ],
+            },
+            compile_config=EdgeCompileConfig(_check_ir_validity=False),
+            constant_methods=manager.metadata,
+            generate_etrecord=True,  # Enable ETRecord generation for all 3 methods
+        )
+        ```
         """

         self.exported_program = exported_program
         self.export_graph_id = export_graph_id
         self.edge_dialect_program = edge_dialect_program
         self.graph_map = graph_map
-        self._debug_handle_map = _debug_handle_map
+        self.__debug_handle_map = _debug_handle_map  # Use private attribute
         self._delegate_map = _delegate_map
         self._instruction_id_to_num_outs_map = _instruction_id_to_num_outs_map
         self._reference_outputs = _reference_outputs
         self._representative_inputs = _representative_inputs

+    @property
+    def _debug_handle_map(self):
+        return self.__debug_handle_map
+
+    @_debug_handle_map.setter
+    def _debug_handle_map(self, value):
+        self.__debug_handle_map = value
+
     def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None:
         """
         Serialize and save the ETRecord to the specified path for use in Inspector. The ETRecord
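The hunk above moves `_debug_handle_map` behind a property backed by a double-underscore attribute. A small sketch of how Python stores such an attribute is shown here; `Record` is a hypothetical miniature, not the real `ETRecord`.

```python
class Record:
    """Hypothetical miniature of the attribute layout, not the real ETRecord."""

    def __init__(self, debug_handle_map=None):
        # Inside a class body, a leading double underscore is name-mangled,
        # so this value is stored as `_Record__debug_handle_map`.
        self.__debug_handle_map = debug_handle_map

    @property
    def _debug_handle_map(self):
        return self.__debug_handle_map

    @_debug_handle_map.setter
    def _debug_handle_map(self, value):
        self.__debug_handle_map = value


record = Record()
record._debug_handle_map = {0: [1, 2]}
```

Callers that read or assign `record._debug_handle_map` keep working unchanged, while the instance dictionary holds the value under the mangled name.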
@@ -121,6 +156,14 @@ def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None:
                 "ETRecord must contain edge dialect program and executorch program to be saved"
             )

+        # Normalize edge_dialect_program to dict format for consistent handling
+        if isinstance(self.edge_dialect_program, ExportedProgram):
+            self._edge_dialect_programs_dict: Dict[str, ExportedProgram] = {
+                "forward": self.edge_dialect_program
+            }
+        else:
+            self._edge_dialect_programs_dict = self.edge_dialect_program
+
         etrecord_zip = ZipFile(path, "w")

         try:
@@ -136,7 +179,7 @@ def _write_identifier(self, etrecord_zip: ZipFile) -> None:
         etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")

     def _save_programs(self, etrecord_zip: ZipFile) -> None:
-        """Save exported program and edge dialect program."""
+        """Save exported program and edge dialect program(s)."""
         if self.exported_program is not None:
             self._save_exported_program(
                 etrecord_zip,
@@ -145,8 +188,9 @@ def _save_programs(self, etrecord_zip: ZipFile) -> None:
                 self.exported_program,
             )

-        if self.edge_dialect_program is not None:
-            self._save_edge_dialect_program(etrecord_zip, self.edge_dialect_program)
+        # Save all edge dialect programs (supports multiple methods)
+        for method_name, edge_program in self._edge_dialect_programs_dict.items():
+            self._save_edge_dialect_program(etrecord_zip, method_name, edge_program)

     def _save_graph_map(self, etrecord_zip: ZipFile) -> None:
         """Save graph map if present."""
@@ -201,6 +245,30 @@ def _save_metadata(self, etrecord_zip: ZipFile) -> None:
                 json.dumps(self.export_graph_id),
             )

+    def copy(self) -> "ETRecord":
+        """
+        Create a shallow copy of this ETRecord suitable for transformation stages.
+
+        This creates a new ETRecord instance with the same edge-stage data but without
+        executorch-specific data. Useful when transitioning between transformation stages
+        while preserving immutable snapshot semantics.
+
+        Returns:
+            ETRecord: A new ETRecord with edge-stage data copied over.
+        """
+        return ETRecord(
+            exported_program=self.exported_program,
+            export_graph_id=self.export_graph_id,
+            edge_dialect_program=self.edge_dialect_program,
+            graph_map=self.graph_map,
+            # Explicitly exclude executorch-specific fields for clean separation
+            _debug_handle_map=None,
+            _delegate_map=None,
+            _instruction_id_to_num_outs_map=None,
+            _reference_outputs=None,
+            _representative_inputs=None,
+        )
+
     def _save_exported_program(
         self,
         etrecord_zip: ZipFile,
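The `copy()` method above shares the edge-stage references and resets the executorch-stage fields. The sketch below illustrates that behaviour with a hypothetical `StageRecord` dataclass that has only three fields; it is an illustration of the pattern, not the real `ETRecord`.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class StageRecord:
    """Hypothetical stand-in with two edge-stage fields and one executorch field."""

    edge_dialect_program: Optional[Any] = None
    graph_map: Optional[Dict[str, Any]] = None
    _debug_handle_map: Optional[Dict[int, Any]] = None

    def copy(self) -> "StageRecord":
        # Share edge-stage references; drop the executorch-stage field.
        return StageRecord(
            edge_dialect_program=self.edge_dialect_program,
            graph_map=self.graph_map,
            _debug_handle_map=None,
        )


edge = StageRecord(
    edge_dialect_program=object(),
    graph_map={"forward": 1},
    _debug_handle_map={0: 1},
)
snapshot = edge.copy()
snapshot._debug_handle_map = {5: 6}  # does not touch `edge`
```

Because the copy is shallow, rebinding a field on the snapshot leaves the original alone, but an in-place mutation of a shared container such as `graph_map` would be visible through both records.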
@@ -223,13 +291,19 @@ def _save_exported_program(
         )

     def _save_edge_dialect_program(
-        self, etrecord_zip: ZipFile, edge_dialect_program: ExportedProgram
+        self,
+        etrecord_zip: ZipFile,
+        method_name: str,
+        edge_dialect_program: ExportedProgram,
     ) -> None:
         """Save the edge dialect program to the ETRecord zip file."""
         serialized_artifact = serialize(edge_dialect_program)
         assert isinstance(serialized_artifact.exported_program, bytes)

-        base_name = ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM
+        # Use format: edge_dialect_exported_program/method_name for multi-method support
+        base_name = (
+            f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}/{method_name}"
+        )
         etrecord_zip.writestr(base_name, serialized_artifact.exported_program)
         etrecord_zip.writestr(f"{base_name}_state_dict", serialized_artifact.state_dict)
         etrecord_zip.writestr(f"{base_name}_constants", serialized_artifact.constants)
@@ -381,6 +455,7 @@ def add_edge_dialect_program(
             )

         # Process edge dialect program and extract data
+
         processed_edge_dialect_program = _process_edge_dialect_program(
             edge_dialect_program
         )
@@ -591,13 +666,28 @@ def _add_module_to_graph_map(

 def _process_edge_dialect_program(
     edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram]
-) -> ExportedProgram:
-    """Process edge dialect program and return the exported program."""
+) -> Union[ExportedProgram, Dict[str, ExportedProgram]]:
+    """Process edge dialect program and return the exported program(s).
+
+    For EdgeProgramManager with multiple methods, returns a Dict[str, ExportedProgram]
+    mapping method names to their exported programs. For single-method cases or
+    ExirExportedProgram, returns a single ExportedProgram.
+    """
     if isinstance(
         edge_dialect_program,
         (EdgeProgramManager, exir.program._program.EdgeProgramManager),
     ):
-        return edge_dialect_program.exported_program()
+        methods = edge_dialect_program.methods
+        if len(methods) == 1:
+            # Single method case - return the ExportedProgram directly
+            method_name = next(iter(methods))
+            return edge_dialect_program.exported_program(method_name)
+        else:
+            # Multiple methods - return a dict of all methods
+            return {
+                method: edge_dialect_program.exported_program(method)
+                for method in methods
+            }
     elif isinstance(edge_dialect_program, ExirExportedProgram):
         return edge_dialect_program.exported_program
     else:
@@ -676,19 +766,26 @@ def parse_etrecord(etrecord_path: str) -> ETRecord:  # noqa: C901
         )

     graph_map: Dict[str, ExportedProgram] = {}
+    edge_dialect_programs: Dict[str, ExportedProgram] = {}
     debug_handle_map = None
     delegate_map = None
     instruction_id_to_num_outs_map = None
     exported_program = None
-    edge_dialect_program = None
+    edge_dialect_program: Optional[
+        Union[ExportedProgram, Dict[str, ExportedProgram]]
+    ] = None
     reference_outputs = None
     representative_inputs = None
     export_graph_id = 0

     serialized_exported_program_files = set()
+    serialized_edge_dialect_program_files = set()
     serialized_state_dict_files = set()
     serialized_constants_files = set()
     serialized_example_inputs_files = set()
+
+    edge_dialect_prefix = f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}/"
+
     for entry in file_list:
         if entry == ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME:
             debug_handle_map = json.loads(
@@ -707,6 +804,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord:  # noqa: C901
         elif entry == ETRecordReservedFileNames.ETRECORD_IDENTIFIER:
             continue
         elif entry == ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM:
+            # Old format: single edge dialect program (backward compatibility)
             serialized_artifact = SerializedArtifact(
                 etrecord_zip.read(
                     ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM
@@ -716,6 +814,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord:  # noqa: C901
                 etrecord_zip.read(f"{entry}_example_inputs"),
             )
             edge_dialect_program = deserialize(serialized_artifact)
+        elif entry.startswith(edge_dialect_prefix) and not entry.endswith(
+            ("_state_dict", "_constants", "_example_inputs")
+        ):
+            # New format: edge_dialect_exported_program/method_name
+            serialized_edge_dialect_program_files.add(entry)
         elif entry == ETRecordReservedFileNames.EXPORTED_PROGRAM:
             serialized_artifact = SerializedArtifact(
                 etrecord_zip.read(ETRecordReservedFileNames.EXPORTED_PROGRAM),
@@ -748,6 +851,30 @@ def parse_etrecord(etrecord_path: str) -> ETRecord:  # noqa: C901
             else:
                 serialized_exported_program_files.add(entry)

+    # Parse new format edge dialect programs (multi-method support)
+    for serialized_file in serialized_edge_dialect_program_files:
+        serialized_state_dict_file = f"{serialized_file}_state_dict"
+        serialized_constants_file = f"{serialized_file}_constants"
+        serialized_example_inputs_file = f"{serialized_file}_example_inputs"
+        serialized_artifact = SerializedArtifact(
+            etrecord_zip.read(serialized_file),
+            etrecord_zip.read(serialized_state_dict_file),
+            etrecord_zip.read(serialized_constants_file),
+            etrecord_zip.read(serialized_example_inputs_file),
+        )
+        # Extract method name from path: edge_dialect_exported_program/method_name -> method_name
+        method_name = serialized_file[len(edge_dialect_prefix) :]
+        edge_dialect_programs[method_name] = deserialize(serialized_artifact)
+
+    # If we found multi-method edge dialect programs, use them
+    if edge_dialect_programs:
+        if len(edge_dialect_programs) == 1:
+            # Single method - store as ExportedProgram for backward compatibility
+            edge_dialect_program = next(iter(edge_dialect_programs.values()))
+        else:
+            # Multiple methods - store as dict
+            edge_dialect_program = edge_dialect_programs
+
     for serialized_file in serialized_exported_program_files:
         serialized_state_dict_file = f"{serialized_file}_state_dict"
         serialized_constants_file = f"{serialized_file}_constants"
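The per-method archive layout and the prefix/suffix filter used when parsing can be exercised in isolation. The sketch below assumes the reserved name resolves to `edge_dialect_exported_program` (taken from the comment in the diff) and stores raw bytes rather than serialized programs; `write_methods` and `read_method_names` are hypothetical helpers.

```python
import io
import zipfile

# Assumed value of the reserved entry name plus "/", per the comment in the diff.
PREFIX = "edge_dialect_exported_program/"
SUFFIXES = ("_state_dict", "_constants", "_example_inputs")


def write_methods(buffer, payloads):
    """Write one main entry plus three companion entries per method."""
    with zipfile.ZipFile(buffer, "w") as archive:
        for method_name, payload in payloads.items():
            base = f"{PREFIX}{method_name}"
            archive.writestr(base, payload)
            for suffix in SUFFIXES:
                archive.writestr(f"{base}{suffix}", b"")


def read_method_names(buffer):
    """Recover method names using the same prefix/suffix test as the parser."""
    with zipfile.ZipFile(buffer, "r") as archive:
        return sorted(
            entry[len(PREFIX):]
            for entry in archive.namelist()
            if entry.startswith(PREFIX) and not entry.endswith(SUFFIXES)
        )


buffer = io.BytesIO()
write_methods(
    buffer,
    {"vision_encoder": b"v", "text_decoder": b"t", "get_constants": b"c"},
)
names = read_method_names(buffer)
```

In this sketch the method named `get_constants` is not recovered, because its main entry ends with one of the companion suffixes. This suggests that a method whose name ends in `_state_dict`, `_constants`, or `_example_inputs` would not be picked up by the new-format branch as written.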
