diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 1ff754532f..557c01667f 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -12,6 +12,7 @@ import tensorrt as trt import torch from torch._subclasses.fake_tensor import FakeTensor +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype from torch_tensorrt._features import ENABLED_FEATURES @@ -256,48 +257,54 @@ def prepare_inputs( inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], disable_memory_format_check: bool = False, ) -> Any: - if inputs is None: - return None - - elif isinstance(inputs, Input): - return inputs + """ + We take a nested group of torch.Tensors or scalars and convert them into torchtrt.Input's + """ + # Any tensors created inside this call will be FakeTensors if it's inside a torch.compile session + # So, we disable fake mode temporarily. + with unset_fake_temporarily(): + if inputs is None: + return None - elif isinstance(inputs, (torch.Tensor, int, float, bool)): - return Input.from_tensor( - torch.tensor(inputs), - disable_memory_format_check=disable_memory_format_check, - ) + elif isinstance(inputs, Input): + return inputs - elif isinstance(inputs, (list, tuple)): - torchtrt_input_list = [] - for input_obj in inputs: - torchtrt_input = prepare_inputs( - input_obj, disable_memory_format_check=disable_memory_format_check + elif isinstance(inputs, (torch.Tensor, int, float, bool)): + return Input.from_tensor( + torch.tensor(inputs), + disable_memory_format_check=disable_memory_format_check, ) - torchtrt_input_list.append(torchtrt_input) - - return ( - torchtrt_input_list - if isinstance(inputs, list) - else tuple(torchtrt_input_list) - ) - elif isinstance(inputs, dict): - torchtrt_inputs_dict: Dict[Any, Any] = dict() + elif isinstance(inputs, (list, tuple)): + torchtrt_input_list = [] + for input_obj in inputs: + torchtrt_input = prepare_inputs( + input_obj, disable_memory_format_check=disable_memory_format_check + ) + torchtrt_input_list.append(torchtrt_input) - for key, input_obj in inputs.items(): - torchtrt_input = prepare_inputs( - input_obj, disable_memory_format_check=disable_memory_format_check + return ( + torchtrt_input_list + if isinstance(inputs, list) + else tuple(torchtrt_input_list) ) - torchtrt_inputs_dict[key] = torchtrt_input - return torchtrt_inputs_dict + elif isinstance(inputs, dict): + torchtrt_inputs_dict: Dict[Any, Any] = dict() - else: - raise ValueError( - f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " - + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" - ) + for key, input_obj in inputs.items(): + torchtrt_input = prepare_inputs( + input_obj, disable_memory_format_check=disable_memory_format_check + ) + torchtrt_inputs_dict[key] = torchtrt_input + + return torchtrt_inputs_dict + + else: + raise ValueError( + f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " + + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" + ) def parse_complex_tensor_structs(