Attributeless FakeRootModule #135696
Labels
dynamo-triage-jan2025
module: dynamo
oncall: pt2
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Comments
Would you be able to share the model repo and the running script?
On main d0f08dc, the
Minimal repro:
Way simpler repro
pytorchmergebot pushed a commit that referenced this issue on May 6, 2025:
Might also fix - #135696. Pull Request resolved: #152853. Approved by: https://github.com/Lucaskabela, https://github.com/mlazos, https://github.com/jansel
anijain2305 added four commits that referenced this issue on May 6, 2025 (ghstack updates, "…ues"):
Might also fix - #135696
pytorchmergebot pushed a commit that referenced this issue on May 7, 2025:
Might also fix - #135696. Pull Request resolved: #152853. Approved by: https://github.com/Lucaskabela, https://github.com/mlazos, https://github.com/jansel
anijain2305 added four commits that referenced this issue on May 7, 2025 (ghstack updates):
Fixes #135696
🐛 Describe the bug
When running a segmentation model (UNet) from segmentation_models_pytorch compiled with PyTorch's torch._dynamo, I encountered an internal error.
The error is not triggered by the compile call itself, but later, when the compiled model is actually executed, for example when I call summary or when I start training:
summary(
model,
(Params.channels, *Params.image_reshape),
batch_size=Params.batch_size,
device=Params.device,
)
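For context, here is a minimal self-contained sketch of the failing pattern, assuming a toy CNN in place of the UNet and a simplified torchsummary-style hook. The hook body, the summary_info dict, and the nb_params key are illustrative stand-ins, mirroring what the mangled attribute name in the error below appears to encode (forward_hooks → closure → cell_contents → nb_params):

import torch
import torch.nn as nn

# Toy stand-in for the UNet from segmentation_models_pytorch.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 3, 3, padding=1),
)
model = torch.compile(model)

# Simplified torchsummary-style hook: a closure that captures a
# local dict and records per-module stats such as "nb_params".
summary_info = {}

def register_hook(module):
    def hook(mod, inp, out):
        key = f"{mod.__class__.__name__}-{len(summary_info) + 1}"
        summary_info[key] = {
            "nb_params": sum(p.numel() for p in mod.parameters()),
        }
    if not isinstance(module, nn.Sequential):
        module.register_forward_hook(hook)

model.apply(register_hook)

# Running the hooked, compiled model (rather than compiling it)
# is what raises the InternalTorchDynamoError shown below.
model(torch.randn(1, 3, 32, 32))

The key ingredient is a forward hook whose closure captures local state; Dynamo appears to materialize the captured cell contents as attributes that are then missing on its FakeRootModule.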
Error logs
{
"name": "InternalTorchDynamoError",
"message": "'FakeRootModule' object has no attribute 'self___relu__forward_hooks_10___closure___1_cell_contents__Conv2d_1____nb_params'
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
",
"stack": "---------------------------------------------------------------------------
InternalTorchDynamoError Traceback (most recent call last)
Cell In[14], line 2
1 # Print model summary
----> 2 summary(
3 model,
4 (Params.channels, *Params.image_reshape),
5 batch_size=Params.batch_size,
6 device=Params.device,
7 )
File ~/.local/lib/python3.12/site-packages/torchsummary/torchsummary.py:72, in summary(model, input_size, batch_size, device)
68 model.apply(register_hook)
70 # make a forward pass
71 # print(x.shape)
---> 72 model(*x)
74 # remove these hooks
75 for h in hooks:
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:433, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
428 saved_dynamic_layer_stack_depth = (
429 torch._C._functorch.get_dynamic_layer_stack_depth()
430 )
432 try:
--> 433 return fn(*args, **kwargs)
434 finally:
435 # Restore the dynamic layer stack depth if necessary.
436 torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
437 saved_dynamic_layer_stack_depth
438 )
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
1600 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1601 args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
1604 if _global_forward_hooks or self._forward_hooks:
1605 for hook_id, hook in (
1606 *_global_forward_hooks.items(),
1607 *self._forward_hooks.items(),
1608 ):
1609 # mark that always called hook is run
File ~/.local/lib/python3.12/site-packages/segmentation_models_pytorch/base/model.py:33, in SegmentationModel.forward(self, x)
23 new_w = (
24 (w // output_stride + 1) * output_stride
25 if w % output_stride != 0
26 else w
27 )
28 raise RuntimeError(
29 f"Wrong input shape height={h}, width={w}. Expected image height and width "
30 f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})."
31 )
---> 33 def forward(self, x):
34 """Sequentially pass
x
trough model`s encoder, decoder and heads"""36 self.check_input_shape(x)
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
1600 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1601 args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
1604 if _global_forward_hooks or self._forward_hooks:
1605 for hook_id, hook in (
1606 *_global_forward_hooks.items(),
1607 *self._forward_hooks.items(),
1608 ):
1609 # mark that always called hook is run
File ~/.local/lib/python3.12/site-packages/segmentation_models_pytorch/encoders/resnet.py:58, in ResNetEncoder.forward(self, x)
48 def get_stages(self):
49 return [
50 nn.Identity(),
51 nn.Sequential(self.conv1, self.bn1, self.relu),
(...)
55 self.layer4,
56 ]
---> 58 def forward(self, x):
59 stages = self.get_stages()
61 features = []
File ~/.local/lib/python3.12/site-packages/segmentation_models_pytorch/encoders/resnet.py:63, in torch_dynamo_resume_in_forward_at_59(___stack0, self, x)
61 features = []
62 for i in range(self._depth + 1):
---> 63 x = stages[i](x)
64 features.append(x)
66 return features
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/container.py:219, in Sequential.forward(self, input)
217 def forward(self, input):
218 for module in self:
--> 219 input = module(input)
220 return input
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/container.py:219, in Sequential.forward(self, input)
217 def forward(self, input):
218 for module in self:
--> 219 input = module(input)
220 return input
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1603, in Module._call_impl(self, *args, **kwargs)
1600 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1601 args = bw_hook.setup_input_hook(args)
-> 1603 result = forward_call(*args, **kwargs)
1604 if _global_forward_hooks or self._forward_hooks:
1605 for hook_id, hook in (
1606 *_global_forward_hooks.items(),
1607 *self._forward_hooks.items(),
1608 ):
1609 # mark that always called hook is run
File ~/.local/lib/python3.12/site-packages/torchvision/models/resnet.py:143, in Bottleneck.forward(self, x)
140 self.downsample = downsample
141 self.stride = stride
--> 143 def forward(self, x: Tensor) -> Tensor:
144 identity = x
146 out = self.conv1(x)
File ~/.local/lib/python3.12/site-packages/torchvision/models/resnet.py:146, in torch_dynamo_resume_in_forward_at_146(___stack0, self, x, identity)
143 def forward(self, x: Tensor) -> Tensor:
144 identity = x
--> 146 out = self.conv1(x)
147 out = self.bn1(out)
148 out = self.relu(out)
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1116, in CatchErrorsWrapper.__call__(self, frame, cache_entry, frame_state)
1110 return hijacked_callback(
1111 frame, cache_entry, self.hooks, frame_state
1112 )
1114 with compile_lock, _disable_current_modes():
1115 # skip=1: skip this frame
-> 1116 return self._torchdynamo_orig_callable(
1117 frame, cache_entry, self.hooks, frame_state, skip=1
1118 )
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:948, in ConvertFrame.__call__(self, frame, cache_entry, hooks, frame_state, skip)
946 counters["frames"]["total"] += 1
947 try:
--> 948 result = self._inner_convert(
949 frame, cache_entry, hooks, frame_state, skip=skip + 1
950 )
951 counters["frames"]["ok"] += 1
952 return result
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:472, in ConvertFrameAssert.__call__(self, frame, cache_entry, hooks, frame_state, skip)
458 compile_id = CompileId(frame_id, frame_compile_id)
460 signpost_event(
461 "dynamo",
462 "_convert_frame_assert._compile",
(...)
469 },
470 )
--> 472 return _compile(
473 frame.f_code,
474 frame.f_globals,
475 frame.f_locals,
476 frame.f_builtins,
477 self._torchdynamo_orig_callable,
478 self._one_graph,
479 self._export,
480 self._export_constraints,
481 hooks,
482 cache_entry,
483 cache_size,
484 frame,
485 frame_state=frame_state,
486 compile_id=compile_id,
487 skip=skip + 1,
488 )
File ~/.local/lib/python3.12/site-packages/torch/_utils_internal.py:84, in compile_time_strobelight_meta.<locals>.compile_time_strobelight_meta_inner.<locals>.wrapper_function(*args, **kwargs)
82 if "skip" in kwargs:
83 kwargs["skip"] = kwargs["skip"] + 1
---> 84 return StrobelightCompileTimeProfiler.profile_compile_time(
85 function, phase_name, *args, **kwargs
86 )
File ~/.local/lib/python3.12/site-packages/torch/_strobelight/compile_time_profiler.py:129, in StrobelightCompileTimeProfiler.profile_compile_time(cls, func, phase_name, *args, **kwargs)
124 @classmethod
125 def profile_compile_time(
126 cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
127 ) -> Any:
128 if not cls.enabled:
--> 129 return func(*args, **kwargs)
131 if cls.profiler is None:
132 logger.error("profiler is not set")
File /usr/local/lib/python3.12/contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
78 @wraps(func)
79 def inner(*args, **kwds):
80 with self._recreate_cm():
---> 81 return func(*args, **kwds)
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:846, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
844 fail_user_frame_lineno = e.innermost_user_frame_summary.lineno # type: ignore[attr-defined]
845 e.compile_id = compile_id # type: ignore[attr-defined]
--> 846 raise InternalTorchDynamoError(str(e)).with_traceback(
847 e.traceback
848 ) from None
849 finally:
850 if tracer:
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:817, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
815 guarded_code = None
816 try:
--> 817 guarded_code = compile_inner(code, one_graph, hooks, transform)
818 return guarded_code
819 except (
820 Unsupported,
821 TorchRuntimeError,
(...)
828 BisectValidationException,
829 ) as e:
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/utils.py:231, in dynamo_timed.<locals>.dynamo_timed_inner.<locals>.time_wrapper(*args, **kwargs)
229 with torch.profiler.record_function(f"{key} (dynamo_timed)"):
230 t0 = time.time()
--> 231 r = func(*args, **kwargs)
232 time_spent = time.time() - t0
233 compilation_time_metrics[key].append(time_spent)
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:636, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
634 CompileContext.get().attempt = attempt
635 try:
--> 636 out_code = transform_code_object(code, transform)
637 break
638 except exc.RestartAnalysis as e:
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py:1185, in transform_code_object(code, transformations, safe)
1182 instructions = cleaned_instructions(code, safe)
1183 propagate_line_nums(instructions)
-> 1185 transformations(instructions, code_options)
1186 return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:178, in preserve_global_state.<locals>._fn(*args, **kwargs)
176 cleanup = setup_compile_debug()
177 try:
--> 178 return fn(*args, **kwargs)
179 finally:
180 cleanup.close()
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:582, in _compile.<locals>.transform(instructions, code_options)
580 try:
581 with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 582 tracer.run()
583 except exc.UnspecializeRestartAnalysis:
584 speculation_log.clear()
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2451, in InstructionTranslator.run(self)
2450 def run(self):
-> 2451 super().run()
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:893, in InstructionTranslatorBase.run(self)
891 try:
892 self.output.push_tx(self)
--> 893 while self.step():
894 pass
895 except BackendCompilerFailed:
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:805, in InstructionTranslatorBase.step(self)
802 self.update_block_stack(inst)
804 try:
--> 805 self.dispatch_table[inst.opcode](self, inst)
806 return not self.output.should_exit
807 except exc.ObservedException:
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:497, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
495 if speculation.failed:
496 assert speculation.reason is not None
--> 497 return handle_graph_break(self, inst, speculation.reason)
498 try:
499 return inner_fn(self, inst)
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:566, in break_graph_if_unsupported.<locals>.decorator.<locals>.handle_graph_break(self, inst, reason)
561 def handle_graph_break(
562 self: "InstructionTranslatorBase",
563 inst: Instruction,
564 reason: GraphCompileReason,
565 ):
--> 566 self.output.compile_subgraph(self, reason=reason)
567 cg = PyCodegen(self)
568 cleanup: List[Instruction] = []
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/output_graph.py:1123, in OutputGraph.compile_subgraph(self, tx, partial_convert, reason)
1120 output = []
1121 if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
1122 output.extend(
-> 1123 self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
1124 )
1126 if len(pass2.graph_outputs) != 0:
1127 output.append(pass2.create_store(graph_output_var))
File /usr/local/lib/python3.12/contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
78 @wraps(func)
79 def inner(*args, **kwds):
80 with self._recreate_cm():
---> 81 return func(*args, **kwds)
File ~/.local/lib/python3.12/site-packages/torch/_dynamo/output_graph.py:1269, in OutputGraph.compile_and_call_fx_graph(self, tx, rv, root)
1261 self.create_node(
1262 "output",
1263 "output",
1264 (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
1265 {},
1266 )
1267 if not config.do_not_emit_runtime_asserts:
1268 insert_deferred_runtime_asserts(
-> 1269 fx.GraphModule(root, self.graph),
1270 self.shape_env,
1271 name,
1272 )
1273 # NB: deferred runtime asserts can keep graphargs live, so make sure
1274 # those are inserted before pruning
1275 self.remove_unused_graphargs()
File ~/.local/lib/python3.12/site-packages/torch/fx/graph_module.py:399, in GraphModule.__init__(self, root, graph, class_name)
397 if node.op in ["get_attr", "call_module"]:
398 assert isinstance(node.target, str)
--> 399 _copy_attr(root, self, node.target)
400 elif isinstance(root, dict):
401 targets_to_copy = []
File ~/.local/lib/python3.12/site-packages/torch/fx/graph_module.py:229, in _copy_attr(from_module, to_module, target)
226 setattr(to_module, item, t)
227 from_module, to_module = f, t
--> 229 orig = getattr(from_module, field)
230 # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
231 # So, we register it as a named buffer in the target module.
232 if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
File ~/.local/lib/python3.12/site-packages/torch/nn/modules/module.py:1729, in Module.__getattr__(self, name)
1727 if name in modules:
1728 return modules[name]
-> 1729 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
InternalTorchDynamoError: 'FakeRootModule' object has no attribute 'self___relu__forward_hooks_10___closure___1_cell_contents__Conv2d_1____nb_params'
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
"
}
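As a stopgap, the error message above points at two knobs: verbose Dynamo logging to locate the failing frame, and suppress_errors to fall back to eager execution. A minimal sketch of that workaround, taken directly from the message (set the environment variables before importing torch):

import os

# Verbose Dynamo logs, as suggested by the error message.
os.environ["TORCH_LOGS"] = "+dynamo"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

import torch._dynamo

# Suppress the compilation error and fall back to eager execution
# instead of raising InternalTorchDynamoError.
torch._dynamo.config.suppress_errors = True

Note that suppressing errors only restores eager behavior; it does not compile the hooked forward pass.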
Minified repro
No response
Versions
OS: (sysname='Linux', release='6.10.6-2-liquorix-amd64', version='#1 ZEN SMP PREEMPT liquorix 6.10-6ubuntu1~jammy (2024-08-20)', machine='x86_64')
Python Version: 3.12.3 (main, Sep 4 2024, 12:08:24) [GCC 13.2.0]
PyTorch Version: 2.4.1+cu124
CUDA Version: 12.6
GPU: NVIDIA GeForce GTX 1650 Ti
NVIDIA Driver Version: 560.35.03
Active CUDA Device: GPU 0
Number of CUDA Devices: 1
Available Devices: 1
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Wed_Aug_14_10:10:22_PDT_2024
Cuda compilation tools, release 12.6, V12.6.68
Build cuda_12.6.r12.6/compiler.34714021_0
Collect output:
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @rec