From 02d603820fde6bf81bf8c8d537c122bdc818adf2 Mon Sep 17 00:00:00 2001
From: Sourcery AI <>
Date: Sat, 15 Jul 2023 09:24:27 +0000
Subject: [PATCH] 'Refactored by Sourcery'

---
 .github/scripts/run_cpp_linter.py | 4 +-
 .github/scripts/run_py_linter.py | 4 +-
 .../fx/hugging_face_torchdynamo_example.py | 20 +---
 examples/fx/lower_example.py | 3 +-
 examples/fx/lower_example_aten.py | 3 +-
 examples/fx/quantized_resnet_test.py | 8 +-
 examples/fx/torchdynamo_example.py | 3 +-
 examples/int8/training/vgg16/finetune_qat.py | 16 +--
 examples/int8/training/vgg16/main.py | 14 +--
 noxfile.py | 32 +++---
 py/setup.py | 63 ++++++-----
 py/torch_tensorrt/_Device.py | 60 +++++------
 py/torch_tensorrt/_Input.py | 68 +++------
 py/torch_tensorrt/__init__.py | 2 +-
 py/torch_tensorrt/_compile.py | 43 ++++----
 .../dynamo/_TorchTensorRTModule.py | 17 +--
 py/torch_tensorrt/dynamo/backend/backends.py | 18 ++--
 .../dynamo/backend/conversion.py | 23 ++--
 .../dynamo/backend/lowering/_partition.py | 9 +-
 .../backend/lowering/_pre_aot_lowering.py | 3 +-
 .../backend/lowering/substitutions/einsum.py | 5 +-
 .../lowering/substitutions/maxpool1d.py | 5 +-
 .../backend/test/test_backend_compiler.py | 9 +-
 .../backend/test/test_decompositions.py | 17 ++-
 .../dynamo/backend/test/test_partitioning.py | 21 ++--
 .../backend/test/test_pre_aot_lowering.py | 8 +-
 .../backend/test/test_specialized_models.py | 6 +-
 .../dynamo/backend/test/utils.py | 2 +-
 py/torch_tensorrt/dynamo/backend/utils.py | 22 +---
 .../dynamo/fx_ts_compat/fx2trt.py | 40 ++++---
 .../dynamo/fx_ts_compat/input_tensor_spec.py | 28 ++---
 .../passes/lower_pass_manager_builder.py | 57 +++++-----
 .../dynamo/fx_ts_compat/passes/pass_utils.py | 93 ++++++++--------
 .../fx_ts_compat/tools/common_fx2trt.py | 15 +--
 .../fx_ts_compat/tools/trt_minimizer.py | 33 +++---
 py/torch_tensorrt/fx/converter_registry.py | 5 +-
 .../fx/converters/acc_ops_converters.py | 100 +++++++----------
 .../fx/converters/converter_utils.py | 33 +++---
 .../fx/converters/impl/convolution.py | 6 +-
 .../fx/converters/nn_ops_converters.py | 6 +-
 .../fx/converters/transformation.py | 3 +-
 py/torch_tensorrt/fx/diagnostics.py | 11 +-
 py/torch_tensorrt/fx/fx2trt.py | 38 ++++---
 py/torch_tensorrt/fx/input_tensor_spec.py | 63 +++++------
 py/torch_tensorrt/fx/observer.py | 7 +-
 .../fx/passes/lower_basic_pass.py | 43 ++++----
 .../fx/passes/lower_basic_pass_aten.py | 101 ++++++++----------
 .../fx/passes/lower_pass_manager_builder.py | 57 +++++-----
 py/torch_tensorrt/fx/passes/pass_utils.py | 15 ++-
 .../fx/test/converters/acc_op/test_reshape.py | 7 +-
 .../fx/test/converters/acc_op/test_type_as.py | 11 +-
 ...test_fix_clamp_numerical_limits_to_fp16.py | 7 +-
 .../fx/test/passes/test_graph_opts.py | 7 +-
 .../test_remove_duplicate_output_args.py | 8 +-
 .../fx/test/passes/test_setitem_trt.py | 6 +-
 .../fx/test/quant/test_quant_trt.py | 20 ++--
 .../fx/test/tracer/test_acc_tracer.py | 17 +--
 .../fx/test/trt_lower/test_fx2trt_lower.py | 21 ++--
 .../fx/test/trt_lower/trt_splitter_test.py | 31 +++---
 py/torch_tensorrt/fx/tools/common_fx2trt.py | 15 +--
 .../fx/tools/engine_layer_visualize.py | 44 ++++----
 py/torch_tensorrt/fx/tools/model_packager.py | 6 +-
 .../fx/tools/timing_cache_utils.py | 2 +-
 py/torch_tensorrt/fx/tools/trt_minimizer.py | 5 +-
 py/torch_tensorrt/fx/tools/trt_splitter.py | 17 ++-
 .../fx/tracer/acc_tracer/acc_normalizer.py | 11 +-
 .../fx/tracer/acc_tracer/acc_ops.py | 39 +++----
 .../fx/tracer/acc_tracer/acc_shape_prop.py | 2 +-
 .../fx/tracer/acc_tracer/acc_tracer.py | 20 ++--
.../fx/tracer/acc_tracer/acc_utils.py | 15 ++- .../fx/tracer/dispatch_tracer/tracer.py | 35 +++--- py/torch_tensorrt/fx/trt_module.py | 19 ++-- py/torch_tensorrt/fx/utils.py | 16 +-- py/torch_tensorrt/ptq.py | 36 +++---- py/torch_tensorrt/ts/_compile_spec.py | 46 +++----- py/torch_tensorrt/ts/_compiler.py | 3 +- py/torch_tensorrt/ts/ts_input.py | 6 +- tests/modules/custom_models.py | 26 ++--- tests/modules/hub.py | 18 ++-- tests/py/api/test_classes.py | 23 ++-- tests/py/api/test_collections.py | 28 ++--- tests/py/api/utils.py | 2 +- tests/py/model_test_case.py | 2 +- tests/py/models/custom_models.py | 3 +- .../py/ptq/test_ptq_dataloader_calibrator.py | 9 +- tests/py/ptq/test_ptq_to_backend.py | 9 +- tests/py/ptq/test_ptq_trt_calibrator.py | 9 +- tests/py/qat/test_qat_trt_accuracy.py | 11 +- tools/linter/cpplint.py | 7 +- tools/linter/cpplint_diff.py | 5 +- tools/linter/pylint.py | 7 +- tools/linter/pylint_diff.py | 16 +-- tools/linter/utils.py | 6 +- tools/perf/custom_models.py | 3 +- tools/perf/hub.py | 18 ++-- tools/perf/perf_run.py | 34 +++--- tools/perf/utils.py | 9 +- 97 files changed, 895 insertions(+), 1084 deletions(-) diff --git a/.github/scripts/run_cpp_linter.py b/.github/scripts/run_cpp_linter.py index 44748c49f3..3fcfb7c053 100644 --- a/.github/scripts/run_cpp_linter.py +++ b/.github/scripts/run_cpp_linter.py @@ -25,9 +25,7 @@ comment = """Code conforms to C++ style guidelines""" approval = "APPROVE" if output.returncode != 0: - comment = """There are some changes that do not conform to C++ style guidelines:\n ```diff\n{}```""".format( - output.stdout.decode("utf-8") - ) + comment = f"""There are some changes that do not conform to C++ style guidelines:\n ```diff\n{output.stdout.decode("utf-8")}```""" approval = "REQUEST_CHANGES" try: diff --git a/.github/scripts/run_py_linter.py b/.github/scripts/run_py_linter.py index f8f2b7f567..7b961cf584 100644 --- a/.github/scripts/run_py_linter.py +++ b/.github/scripts/run_py_linter.py @@ -29,9 +29,7 @@ stdout=subprocess.PIPE, ) out_text = diff_output.stdout.decode("utf-8") - comment = """There are some changes that do not conform to Python style guidelines:\n ```diff\n{}```""".format( - out_text - ) + comment = f"""There are some changes that do not conform to Python style guidelines:\n ```diff\n{out_text}```""" approval = "REQUEST_CHANGES" try: diff --git a/examples/fx/hugging_face_torchdynamo_example.py b/examples/fx/hugging_face_torchdynamo_example.py index 388ccf2e47..0be759a3e2 100644 --- a/examples/fx/hugging_face_torchdynamo_example.py +++ b/examples/fx/hugging_face_torchdynamo_example.py @@ -173,9 +173,7 @@ def get_cur_memory(): gc.collect() torch.cuda.empty_cache() stats = torch.cuda.memory_stats() - peak_bytes_requirement = stats["allocated_bytes.all.current"] - # print(f"Current memory requirement: {peak_bytes_requirement / 1024 ** 3:.2f} GB") - return peak_bytes_requirement + return stats["allocated_bytes.all.current"] @torchdynamo.skip @@ -205,11 +203,7 @@ def check_correctness(args, mod, inputs, optimize_ctx, optimize_name): print("ERROR") return False - if optimize_name == "dynamo_fx2trt_fp16": - cos_similarity = True - else: - cos_similarity = False - + cos_similarity = optimize_name == "dynamo_fx2trt_fp16" if not same(correct_result, new_result, cos_similarity=cos_similarity, tol=1e-2): print("INCORRECT") return False @@ -252,9 +246,7 @@ def bench_model_eval(args, name, mod, eval_inputs, optimize_ctx): # Profile time iters = 50 synchronize() - timings = [] - for _ in range(iters): - timings.append(timed(mod, 
forward_pass, eval_inputs)) + timings = [timed(mod, forward_pass, eval_inputs) for _ in range(iters)] t = np.median(timings, axis=0) else: # does not need recompile for torchdynamo, demo for fx2trt only @@ -275,9 +267,7 @@ def bench_model_eval(args, name, mod, eval_inputs, optimize_ctx): # Profile time iters = 50 synchronize() - timings = [] - for _ in range(iters): - timings.append(timed(mod, forward_pass, eval_inputs)) + timings = [timed(mod, forward_pass, eval_inputs) for _ in range(iters)] t = np.median(timings, axis=0) print(name, t, m) @@ -413,7 +403,7 @@ def main(): # fp16 if optimize_name == "dynamo_fx2trt_fp16": experiment = partial(experiment, dtype=torch.float16) - if optimize_name == "dynamo_fx2trt_fp32": + elif optimize_name == "dynamo_fx2trt_fp32": experiment = partial(experiment, dtype=torch.float32) experiment = partial( diff --git a/examples/fx/lower_example.py b/examples/fx/lower_example.py index 81c1cd28bc..c5fa046758 100644 --- a/examples/fx/lower_example.py +++ b/examples/fx/lower_example.py @@ -194,8 +194,7 @@ def run_configuration_benchmark( else: print("Lowering with JIT is not available!", "red") - result = Result(module=module, input=input, conf=conf, time_sec=time) - return result + return Result(module=module, input=input, conf=conf, time_sec=time) if __name__ == "__main__": diff --git a/examples/fx/lower_example_aten.py b/examples/fx/lower_example_aten.py index 09a8e7cb85..b896368abc 100644 --- a/examples/fx/lower_example_aten.py +++ b/examples/fx/lower_example_aten.py @@ -186,8 +186,7 @@ def run_configuration_benchmark( else: print("Lowering with JIT is not available!", "red") - result = Result(module=module, input=input, conf=conf, time_sec=time) - return result + return Result(module=module, input=input, conf=conf, time_sec=time) if __name__ == "__main__": diff --git a/examples/fx/quantized_resnet_test.py b/examples/fx/quantized_resnet_test.py index 64d7579414..086c91fa34 100644 --- a/examples/fx/quantized_resnet_test.py +++ b/examples/fx/quantized_resnet_test.py @@ -128,13 +128,7 @@ def __init__(self): # self.conv = torch.nn.Conv2d(3, 3, 3, padding=1) def forward(self, x): - # out = self.conv(x) - out = self.linear(x) - # out = torch.nn.functional.relu(out) - # out += x - # out += out - # out = torch.nn.functional.relu(out) - return out + return self.linear(x) # rn18 = M().eval() diff --git a/examples/fx/torchdynamo_example.py b/examples/fx/torchdynamo_example.py index 0d640de68c..d8d10ec203 100644 --- a/examples/fx/torchdynamo_example.py +++ b/examples/fx/torchdynamo_example.py @@ -215,8 +215,7 @@ def run_configuration_benchmark( else: print("Lowering mode is not available!", "red") - result = Result(module=module, input=input, conf=conf, time_sec=time) - return result + return Result(module=module, input=input, conf=conf, time_sec=time) if __name__ == "__main__": diff --git a/examples/int8/training/vgg16/finetune_qat.py b/examples/int8/training/vgg16/finetune_qat.py index 6ec20a9a46..f5e645c599 100644 --- a/examples/int8/training/vgg16/finetune_qat.py +++ b/examples/int8/training/vgg16/finetune_qat.py @@ -61,8 +61,8 @@ args = PARSER.parse_args() for arg in vars(args): - print(" {} {}".format(arg, getattr(args, arg))) -state = {k: v for k, v in args._get_kwargs()} + print(f" {arg} {getattr(args, arg)}") +state = dict(args._get_kwargs()) if args.seed is None: args.seed = random.randint(1, 10000) @@ -75,7 +75,7 @@ timestamp = datetime.timestamp(now) -writer = SummaryWriter(args.tensorboard + "/test_" + str(timestamp)) +writer = 
SummaryWriter(f"{args.tensorboard}/test_{str(timestamp)}") classes = ( "plane", "car", @@ -156,7 +156,7 @@ def calibrate_model( with torch.no_grad(): collect_stats(model, data_loader, num_calib_batch) - if not calibrator == "histogram": + if calibrator != "histogram": compute_amax(model, method="max") calib_output = os.path.join( out_dir, @@ -244,8 +244,8 @@ def main(): ) if args.start_from != 0: - ckpt_file = args.ckpt_dir + "/ckpt_epoch" + str(args.start_from) + ".pth" - print("Loading from checkpoint {}".format(ckpt_file)) + ckpt_file = f"{args.ckpt_dir}/ckpt_epoch{str(args.start_from)}.pth" + print(f"Loading from checkpoint {ckpt_file}") assert os.path.isfile(ckpt_file) ckpt = torch.load(ckpt_file) modified_state_dict = {} @@ -365,7 +365,7 @@ def test(model, dataloader, crit, epoch): def save_checkpoint(state, ckpt_dir="checkpoint"): - print("Checkpoint {} saved".format(state["epoch"])) + print(f'Checkpoint {state["epoch"]} saved') filename = "ckpt_epoch" + str(state["epoch"]) + ".pth" filepath = os.path.join(ckpt_dir, filename) torch.save(state, filepath) @@ -376,7 +376,7 @@ def adjust_lr(optimizer, epoch): new_lr = state["lr"] * (0.5 ** (epoch // 40)) if state["lr"] > 1e-7 else state["lr"] if new_lr != state["lr"]: state["lr"] = new_lr - print("Updating learning rate: {}".format(state["lr"])) + print(f'Updating learning rate: {state["lr"]}') for param_group in optimizer.param_groups: param_group["lr"] = state["lr"] diff --git a/examples/int8/training/vgg16/main.py b/examples/int8/training/vgg16/main.py index 3f248a9283..74884527af 100644 --- a/examples/int8/training/vgg16/main.py +++ b/examples/int8/training/vgg16/main.py @@ -50,8 +50,8 @@ args = PARSER.parse_args() for arg in vars(args): - print(" {} {}".format(arg, getattr(args, arg))) -state = {k: v for k, v in args._get_kwargs()} + print(f" {arg} {getattr(args, arg)}") +state = dict(args._get_kwargs()) if args.seed is None: args.seed = random.randint(1, 10000) @@ -64,7 +64,7 @@ timestamp = datetime.timestamp(now) -writer = SummaryWriter(args.tensorboard + "/test_" + str(timestamp)) +writer = SummaryWriter(f"{args.tensorboard}/test_{str(timestamp)}") classes = ( "plane", "car", @@ -146,8 +146,8 @@ def main(): model = nn.DataParallel(model) if args.start_from != 0: - ckpt_file = args.ckpt_dir + "/ckpt_epoch" + str(args.start_from) + ".pth" - print("Loading from checkpoint {}".format(ckpt_file)) + ckpt_file = f"{args.ckpt_dir}/ckpt_epoch{str(args.start_from)}.pth" + print(f"Loading from checkpoint {ckpt_file}") assert os.path.isfile(ckpt_file) ckpt = torch.load(ckpt_file) model.load_state_dict(ckpt["model_state_dict"]) @@ -238,7 +238,7 @@ def test(model, dataloader, crit, epoch): def save_checkpoint(state, ckpt_dir="checkpoint"): - print("Checkpoint {} saved".format(state["epoch"])) + print(f'Checkpoint {state["epoch"]} saved') filename = "ckpt_epoch" + str(state["epoch"]) + ".pth" filepath = os.path.join(ckpt_dir, filename) torch.save(state, filepath) @@ -249,7 +249,7 @@ def adjust_lr(optimizer, epoch): new_lr = state["lr"] * (0.5 ** (epoch // 40)) if state["lr"] > 1e-7 else state["lr"] if new_lr != state["lr"]: state["lr"] = new_lr - print("Updating learning rate: {}".format(state["lr"])) + print(f'Updating learning rate: {state["lr"]}') for param_group in optimizer.param_groups: param_group["lr"] = state["lr"] diff --git a/noxfile.py b/noxfile.py index 2629c0391b..2d6fba8e6e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -6,7 +6,7 @@ # Use system installed Python packages PYT_PATH = ( "/usr/local/lib/python3.10/dist-packages" - if 
not "PYT_PATH" in os.environ + if "PYT_PATH" not in os.environ else os.environ["PYT_PATH"] ) print(f"Using python path {PYT_PATH}") @@ -15,18 +15,20 @@ # TOP_DIR TOP_DIR = ( os.path.dirname(os.path.realpath(__file__)) - if not "TOP_DIR" in os.environ + if "TOP_DIR" not in os.environ else os.environ["TOP_DIR"] ) print(f"Test root directory {TOP_DIR}") # Set the USE_CXX11=1 to use cxx11_abi -USE_CXX11 = 0 if not "USE_CXX11" in os.environ else os.environ["USE_CXX11"] +USE_CXX11 = 0 if "USE_CXX11" not in os.environ else os.environ["USE_CXX11"] if USE_CXX11: print("Using cxx11 abi") # Set the USE_HOST_DEPS=1 to use host dependencies for tests -USE_HOST_DEPS = 0 if not "USE_HOST_DEPS" in os.environ else os.environ["USE_HOST_DEPS"] +USE_HOST_DEPS = ( + 0 if "USE_HOST_DEPS" not in os.environ else os.environ["USE_HOST_DEPS"] +) if USE_HOST_DEPS: print("Using dependencies from host python") @@ -36,7 +38,7 @@ SUPPORTED_PYTHON_VERSIONS = ["3.7", "3.8", "3.9", "3.10"] nox.options.sessions = [ - "l0_api_tests-" + "{}.{}".format(sys.version_info.major, sys.version_info.minor) + f"l0_api_tests-{sys.version_info.major}.{sys.version_info.minor}" ] @@ -92,7 +94,7 @@ def train_model(session): session.run_always( "python", "export_ckpt.py", - "vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth", + f"vgg16_ckpts/ckpt_epoch{str(EPOCHS)}.pth", env={"PYTHONPATH": PYT_PATH}, ) else: @@ -112,7 +114,9 @@ def train_model(session): ) session.run_always( - "python", "export_ckpt.py", "vgg16_ckpts/ckpt_epoch" + str(EPOCHS) + ".pth" + "python", + "export_ckpt.py", + f"vgg16_ckpts/ckpt_epoch{str(EPOCHS)}.pth", ) @@ -146,7 +150,7 @@ def finetune_model(session): session.run_always( "python", "export_qat.py", - "vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth", + f"vgg16_ckpts/ckpt_epoch{str(EPOCHS + 1)}.pth", env={"PYTHONPATH": PYT_PATH}, ) else: @@ -171,7 +175,7 @@ def finetune_model(session): session.run_always( "python", "export_qat.py", - "vgg16_ckpts/ckpt_epoch" + str(EPOCHS + 1) + ".pth", + f"vgg16_ckpts/ckpt_epoch{str(EPOCHS + 1)}.pth", ) @@ -185,8 +189,8 @@ def cleanup(session): "tests/py/*.jit.pt", ] - target = " ".join(x for x in [os.path.join(TOP_DIR, i) for i in target]) - session.run_always("bash", "-c", str("rm -rf ") + target, external=True) + target = " ".join([os.path.join(TOP_DIR, i) for i in target]) + session.run_always("bash", "-c", f"rm -rf {target}", external=True) def run_base_tests(session): @@ -324,15 +328,13 @@ def copy_model(session): model_files = ["trained_vgg16.jit.pt", "trained_vgg16_qat.jit.pt"] for file_name in model_files: - src_file = os.path.join( - TOP_DIR, str("examples/int8/training/vgg16/") + file_name - ) + src_file = os.path.join(TOP_DIR, f"examples/int8/training/vgg16/{file_name}") if os.path.exists(src_file): session.run_always( "cp", "-rpf", os.path.join(TOP_DIR, src_file), - os.path.join(TOP_DIR, str("tests/modules/") + file_name), + os.path.join(TOP_DIR, f"tests/modules/{file_name}"), external=True, ) diff --git a/py/setup.py b/py/setup.py index e07e904f87..65bb42d522 100644 --- a/py/setup.py +++ b/py/setup.py @@ -55,7 +55,7 @@ def get_git_revision_short_hash() -> str: sys.argv.remove("--legacy") if "--release" not in sys.argv: - __version__ = __version__ + "+" + get_git_revision_short_hash() + __version__ = f"{__version__}+{get_git_revision_short_hash()}" else: RELEASE = True sys.argv.remove("--release") @@ -119,13 +119,12 @@ def is_exe(fpath): if BAZEL_EXE is None: BAZEL_EXE = which("bazel") - if BAZEL_EXE is None: - sys.exit("Could not find bazel in PATH") + if BAZEL_EXE is 
None: + sys.exit("Could not find bazel in PATH") def build_libtorchtrt_pre_cxx11_abi(develop=True, use_dist_dir=True, cxx11_abi=False): - cmd = [BAZEL_EXE, "build"] - cmd.append("//:libtorchtrt") + cmd = [BAZEL_EXE, "build", "//:libtorchtrt"] if develop: cmd.append("--compilation_mode=dbg") else: @@ -159,26 +158,26 @@ def build_libtorchtrt_pre_cxx11_abi(develop=True, use_dist_dir=True, cxx11_abi=F def gen_version_file(): - if not os.path.exists(dir_path + "/torch_tensorrt/_version.py"): - os.mknod(dir_path + "/torch_tensorrt/_version.py") + if not os.path.exists(f"{dir_path}/torch_tensorrt/_version.py"): + os.mknod(f"{dir_path}/torch_tensorrt/_version.py") - with open(dir_path + "/torch_tensorrt/_version.py", "w") as f: + with open(f"{dir_path}/torch_tensorrt/_version.py", "w") as f: print("creating version file") - f.write('__version__ = "' + __version__ + '"\n') - f.write('__cuda_version__ = "' + __cuda_version__ + '"\n') - f.write('__cudnn_version__ = "' + __cudnn_version__ + '"\n') - f.write('__tensorrt_version__ = "' + __tensorrt_version__ + '"\n') + f.write(f'__version__ = "{__version__}' + '"\n') + f.write(f'__cuda_version__ = "{__cuda_version__}' + '"\n') + f.write(f'__cudnn_version__ = "{__cudnn_version__}' + '"\n') + f.write(f'__tensorrt_version__ = "{__tensorrt_version__}' + '"\n') def copy_libtorchtrt(multilinux=False): - if not os.path.exists(dir_path + "/torch_tensorrt/lib"): - os.makedirs(dir_path + "/torch_tensorrt/lib") + if not os.path.exists(f"{dir_path}/torch_tensorrt/lib"): + os.makedirs(f"{dir_path}/torch_tensorrt/lib") print("copying library into module") if multilinux: copyfile( - dir_path + "/build/libtrtorch_build/libtrtorch.so", - dir_path + "/trtorch/lib/libtrtorch.so", + f"{dir_path}/build/libtrtorch_build/libtrtorch.so", + f"{dir_path}/trtorch/lib/libtrtorch.so", ) else: os.system( @@ -200,13 +199,13 @@ def finalize_options(self): def run(self): if FX_ONLY: gen_version_file() - develop.run(self) else: global CXX11_ABI build_libtorchtrt_pre_cxx11_abi(develop=True, cxx11_abi=CXX11_ABI) gen_version_file() copy_libtorchtrt() - develop.run(self) + + develop.run(self) class InstallCommand(install): @@ -221,13 +220,13 @@ def finalize_options(self): def run(self): if FX_ONLY: gen_version_file() - install.run(self) else: global CXX11_ABI build_libtorchtrt_pre_cxx11_abi(develop=False, cxx11_abi=CXX11_ABI) gen_version_file() copy_libtorchtrt() - install.run(self) + + install.run(self) class BdistCommand(bdist_wheel): @@ -284,8 +283,8 @@ def run(self): for path in [str(p) for p in abs_paths]: if not path.startswith(dir_path): # Die if path in CLEAN_FILES is absolute + outside this directory - raise ValueError("%s is not a path inside %s" % (path, dir_path)) - print("Removing %s" % os.path.relpath(path)) + raise ValueError(f"{path} is not a path inside {dir_path}") + print(f"Removing {os.path.relpath(path)}") rmtree(path) for path_spec in self.PY_CLEAN_FILES: @@ -294,8 +293,8 @@ def run(self): for path in [str(p) for p in abs_paths]: if not path.startswith(dir_path): # Die if path in CLEAN_FILES is absolute + outside this directory - raise ValueError("%s is not a path inside %s" % (path, dir_path)) - print("Removing %s" % os.path.relpath(path)) + raise ValueError(f"{path} is not a path inside {dir_path}") + print(f"Removing {os.path.relpath(path)}") os.remove(path) @@ -309,18 +308,18 @@ def run(self): "torch_tensorrt/csrc/register_tensorrt_classes.cpp", ], library_dirs=[ - (dir_path + "/torch_tensorrt/lib/"), + f"{dir_path}/torch_tensorrt/lib/", 
"/opt/conda/lib/python3.6/config-3.6m-x86_64-linux-gnu", ], libraries=["torchtrt"], include_dirs=[ - dir_path + "torch_tensorrt/csrc", - dir_path + "torch_tensorrt/include", - dir_path + "/../bazel-TRTorch/external/tensorrt/include", - dir_path + "/../bazel-Torch-TensorRT/external/tensorrt/include", - dir_path + "/../bazel-TensorRT/external/tensorrt/include", - dir_path + "/../bazel-tensorrt/external/tensorrt/include", - dir_path + "/../", + f"{dir_path}torch_tensorrt/csrc", + f"{dir_path}torch_tensorrt/include", + f"{dir_path}/../bazel-TRTorch/external/tensorrt/include", + f"{dir_path}/../bazel-Torch-TensorRT/external/tensorrt/include", + f"{dir_path}/../bazel-TensorRT/external/tensorrt/include", + f"{dir_path}/../bazel-tensorrt/external/tensorrt/include", + f"{dir_path}/../", ], extra_compile_args=[ "-Wno-deprecated", diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py index 3eaa5aad4e..c26b32c9b2 100644 --- a/py/torch_tensorrt/_Device.py +++ b/py/torch_tensorrt/_Device.py @@ -55,44 +55,40 @@ def __init__(self, *args, **kwargs): raise TypeError( "When specifying Device through positional argument, argument must be str" ) + (self.device_type, id) = Device._parse_device_str(args[0]) + if self.device_type == trt.DeviceType.GPU: + self.gpu_id = id else: - (self.device_type, id) = Device._parse_device_str(args[0]) - if self.device_type == trt.DeviceType.GPU: - self.gpu_id = id + self.dla_core = id + self.gpu_id = 0 + logging.log( + logging.Level.Warning, + "Setting GPU id to 0 for device because device 0 manages DLA on Xavier", + ) + + elif not args: + if "gpu_id" not in kwargs and "dla_core" not in kwargs: + raise ValueError( + "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg" + ) + + if "dla_core" in kwargs: + self.device_type = trt.DeviceType.DLA + self.dla_core = kwargs["dla_core"] + if "gpu_id" in kwargs: + self.gpu_id = kwargs["gpu_id"] else: - self.dla_core = id self.gpu_id = 0 logging.log( logging.Level.Warning, "Setting GPU id to 0 for device because device 0 manages DLA on Xavier", ) - - elif len(args) == 0: - if "gpu_id" in kwargs or "dla_core" in kwargs: - if "dla_core" in kwargs: - self.device_type = trt.DeviceType.DLA - self.dla_core = kwargs["dla_core"] - if "gpu_id" in kwargs: - self.gpu_id = kwargs["gpu_id"] - else: - self.gpu_id = 0 - logging.log( - logging.Level.Warning, - "Setting GPU id to 0 for device because device 0 manages DLA on Xavier", - ) - else: - self.gpu_id = kwargs["gpu_id"] - self.device_type = trt.DeviceType.GPU else: - raise ValueError( - "Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg" - ) - + self.gpu_id = kwargs["gpu_id"] + self.device_type = trt.DeviceType.GPU else: raise ValueError( - "Unexpected number of positional arguments for class Device \n Found {} arguments, expected either zero or a single positional arguments".format( - len(args) - ) + f"Unexpected number of positional arguments for class Device \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) if "allow_gpu_fallback" in kwargs: @@ -102,11 +98,9 @@ def __init__(self, *args, **kwargs): def __str__(self) -> str: return ( - "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) + ")" + f"Device(type={self.device_type}, gpu_id={self.gpu_id})" if self.device_type == trt.DeviceType.GPU - else ", dla_core={}, allow_gpu_fallback={}".format( - self.dla_core, self.allow_gpu_fallback - ) + else f", dla_core={self.dla_core}, 
allow_gpu_fallback={self.allow_gpu_fallback}" ) def _to_internal(self) -> _C.Device: @@ -149,7 +143,7 @@ def _current_device(cls): def _parse_device_str(s): s = s.lower() spec = s.split(":") - if spec[0] == "gpu" or spec[0] == "cuda": + if spec[0] in ["gpu", "cuda"]: return (trt.DeviceType.GPU, int(spec[1])) elif spec[0] == "dla": return (trt.DeviceType.DLA, int(spec[1])) diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index e76817e041..2ebe21e945 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -81,9 +81,9 @@ def __init__(self, *args, **kwargs): self.shape = tuple(args[0]) self.shape_mode = Input._ShapeMode.STATIC - elif len(args) == 0: - if not ("shape" in kwargs) and not ( - all(k in kwargs for k in ["min_shape", "opt_shape", "max_shape"]) + elif not args: + if "shape" not in kwargs and any( + k not in kwargs for k in ["min_shape", "opt_shape", "max_shape"] ): raise ValueError( "Missing required arguments for class Input\nEither shape or all three of min_shape, opt_shape, max_shape must be defined" @@ -132,9 +132,7 @@ def __init__(self, *args, **kwargs): else: raise ValueError( - "Unexpected number of positional arguments for class Input \n Found {} arguments, expected either zero or a single positional arguments".format( - len(args) - ) + f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) if "dtype" in kwargs: @@ -148,45 +146,20 @@ def __init__(self, *args, **kwargs): if "format" in kwargs: self.format = Input._parse_format(kwargs["format"]) - if "tensor_domain" in kwargs: - domain = kwargs["tensor_domain"] - else: - domain = None - + domain = kwargs.get("tensor_domain", None) self.tensor_domain = Input._parse_tensor_domain(domain) def __str__(self) -> str: if self.shape_mode == Input._ShapeMode.STATIC: - return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format( - self.shape, - str(self.dtype), - str(self.format), - str(self.tensor_domain[0]), - str(self.tensor_domain[1]), - ) + return f"Input(shape={self.shape}, dtype={str(self.dtype)}, format={str(self.format)}, domain=[{str(self.tensor_domain[0])}, {str(self.tensor_domain[1])}))" elif self.shape_mode == Input._ShapeMode.DYNAMIC: - return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={}, domain=[{}, {}))".format( - self.shape["min_shape"], - self.shape["opt_shape"], - self.shape["max_shape"], - str(self.dtype), - str(self.format), - str(self.tensor_domain[0]), - str(self.tensor_domain[1]), - ) + return f'Input(min_shape={self.shape["min_shape"]}, opt_shape={self.shape["opt_shape"]}, max_shape={self.shape["max_shape"]}, dtype={str(self.dtype)}, format={str(self.format)}, domain=[{str(self.tensor_domain[0])}, {str(self.tensor_domain[1])}))' else: raise RuntimeError("Unknown input shape mode") @staticmethod def _supported_input_size_type(input_size: Any) -> bool: - if isinstance(input_size, torch.Size): - return True - elif isinstance(input_size, tuple): - return True - elif isinstance(input_size, list): - return True - else: - return False + return isinstance(input_size, (torch.Size, tuple, list)) @staticmethod def _parse_dtype(dtype: Any) -> _enums.dtype: @@ -275,10 +248,9 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple: elif len(domain) == 2: domain_lo, domain_hi = domain - # Validate type and provided values for domain - valid_type_lo = isinstance(domain_lo, int) or isinstance(domain_lo, float) - valid_type_hi = 
isinstance(domain_hi, int) or isinstance(domain_hi, float) + valid_type_hi = isinstance(domain_hi, (int, float)) + valid_type_lo = isinstance(domain_lo, (int, float)) if not valid_type_lo: raise ValueError( f"Expected value for tensor domain low specifier, got {domain_lo}" @@ -290,8 +262,7 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple: if domain_hi <= domain_lo: raise ValueError( - "Expected provided integer range to have low tensor domain value " - + f"< high tensor domain value, got invalid range [{domain_lo}, {domain_hi})" + f"Expected provided integer range to have low tensor domain value < high tensor domain value, got invalid range [{domain_lo}, {domain_hi})" ) result_domain = (float(domain_lo), float(domain_hi)) else: @@ -357,23 +328,18 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor if optimization_profile_field is not None: try: assert any( - [ - optimization_profile_field == field_name - for field_name in ["min_shape", "opt_shape", "max_shape"] - ] + optimization_profile_field == field_name + for field_name in ["min_shape", "opt_shape", "max_shape"] ) except: raise ValueError( "Invalid field name, expected one of min_shape, opt_shape, max_shape" ) - if ( - optimization_profile_field is not None - and self.shape_mode == Input._ShapeMode.STATIC - ): - raise ValueError( - "Specified a optimization profile field but the input is static" - ) + if self.shape_mode == Input._ShapeMode.STATIC: + raise ValueError( + "Specified a optimization profile field but the input is static" + ) if ( optimization_profile_field is None diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py index 952563f5ca..92f3a94f63 100644 --- a/py/torch_tensorrt/__init__.py +++ b/py/torch_tensorrt/__init__.py @@ -101,7 +101,7 @@ def _find_lib(name, paths): def _register_with_torch(): trtorch_dir = os.path.dirname(__file__) - torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt.so") + torch.ops.load_library(f"{trtorch_dir}/lib/libtorchtrt.so") _register_with_torch() diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index de0aeb5308..e2bde40dbe 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -42,10 +42,14 @@ def _parse_module_type(module: Any) -> _ModuleType: def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: - module_is_tsable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.ts]]) - module_is_fxable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.fx]]) - - ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]]) + module_is_tsable = any( + module_type == t for t in [_ModuleType.nn, _ModuleType.ts] + ) + module_is_fxable = any( + module_type == t for t in [_ModuleType.nn, _ModuleType.fx] + ) + + ir_targets_torchscript = any(ir == opt for opt in ["torchscript", "ts"]) ir_targets_fx = ir == "fx" ir_targets_dynamo_compile = ir == "dynamo_compile" ir_targets_fx_ts_compat = ir == "fx_ts_compat" @@ -59,23 +63,22 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType: elif module_is_fxable and ir_targets_dynamo_compile: return _IRType.dynamo_compile else: - if ir == "default": - # Options are listed in order of preference - if module_is_tsable: - logging.log( - logging.Level.Info, "ir was set to default, using TorchScript as ir" - ) - return _IRType.ts - elif module_is_fxable: - raise ValueError( - "Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT" - ) - # 
logging.log(logging.Level.Info, "ir was set to default, using TorchScript as fx") - # return _IRType.fx - else: - raise ValueError("Module was provided with in an unsupported format") - else: + if ir != "default": raise ValueError("Unknown ir was requested") + # Options are listed in order of preference + if module_is_tsable: + logging.log( + logging.Level.Info, "ir was set to default, using TorchScript as ir" + ) + return _IRType.ts + elif module_is_fxable: + raise ValueError( + "Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT" + ) + # logging.log(logging.Level.Info, "ir was set to default, using TorchScript as fx") + # return _IRType.fx + else: + raise ValueError("Module was provided with in an unsupported format") def compile( diff --git a/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py index 8359bc62fb..c370bc3990 100644 --- a/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py @@ -82,11 +82,17 @@ def __init__( self.engine = torch.classes.tensorrt.Engine( [ torch.ops.tensorrt.ABI_VERSION(), - self.name + "_engine" if self.name != "" else "tensorrt_engine", + f"{self.name}_engine" + if self.name != "" + else "tensorrt_engine", target_device._to_serialized_rt_device(), serialized_engine, - TorchTensorRTModule._pack_binding_names(self.input_binding_names), - TorchTensorRTModule._pack_binding_names(self.output_binding_names), + TorchTensorRTModule._pack_binding_names( + self.input_binding_names + ), + TorchTensorRTModule._pack_binding_names( + self.output_binding_names + ), ] ) else: @@ -155,10 +161,7 @@ def is_non_tensor(i: Tuple[Any, bool]) -> bool: outputs = torch.ops.tensorrt.execute_engine(list(inputs), self.engine) - if len(outputs) == 1: - return outputs[0] - - return tuple(outputs) + return outputs[0] if len(outputs) == 1 else tuple(outputs) def enable_profiling(self, profiling_results_dir: str = None): """Enable the profiler to collect latency information about the execution of the engine diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index b97079948e..2bd14f7ee2 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -73,21 +73,13 @@ def _pretraced_backend( try: logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) - trt_compiled = _compile_module( + return _compile_module( gm, sample_inputs, settings=settings, ) - return trt_compiled except: - if not settings.pass_through_build_failures: - logger.warning( - "TRT conversion failed on the subgraph. See trace above. " - + "Returning GraphModule forward instead.", - exc_info=True, - ) - return gm.forward - else: + if settings.pass_through_build_failures: raise AssertionError( "Halting compilation on build failure since " + "pass_through_build_failures was specified as True. " @@ -95,6 +87,12 @@ def _pretraced_backend( + "halting compilation on engine build failures, " + "specify pass_through_build_failures=False." ) + logger.warning( + "TRT conversion failed on the subgraph. See trace above. 
" + + "Returning GraphModule forward instead.", + exc_info=True, + ) + return gm.forward def _compile_module( diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py index 425fb0941e..5925f65538 100644 --- a/py/torch_tensorrt/dynamo/backend/conversion.py +++ b/py/torch_tensorrt/dynamo/backend/conversion.py @@ -33,7 +33,7 @@ def convert_module( if not isinstance(module_outputs, (list, tuple)): module_outputs = [module_outputs] - output_dtypes = list(output.dtype for output in module_outputs) + output_dtypes = [output.dtype for output in module_outputs] interpreter = TRTInterpreter( module, @@ -62,15 +62,14 @@ def convert_module( output_names=interpreter_result.output_names, ) - else: - from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule + from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule - with io.BytesIO() as engine_bytes: - engine_bytes.write(interpreter_result.engine.serialize()) - engine_str = engine_bytes.getvalue() - return TorchTensorRTModule( - serialized_engine=engine_str, - name=name, - input_binding_names=interpreter_result.input_names, - output_binding_names=interpreter_result.output_names, - ) + with io.BytesIO() as engine_bytes: + engine_bytes.write(interpreter_result.engine.serialize()) + engine_str = engine_bytes.getvalue() + return TorchTensorRTModule( + serialized_engine=engine_str, + name=name, + input_binding_names=interpreter_result.input_names, + output_binding_names=interpreter_result.output_names, + ) diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index 4d82bf4be5..d3da48a55f 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -15,10 +15,10 @@ logger = logging.getLogger(__name__) -DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set( +DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = { _get_qualified_name(to_replace.new_operator) for to_replace in SUBSTITUTION_REGISTRY.values() -) +} class TRTPartitioner(CapabilityBasedPartitioner): @@ -59,7 +59,7 @@ def __init__( def propose_partitions(self) -> List[Partition]: # Propose partitions using the default, then refine the results initial_proposed_partitions = super().propose_partitions() - partitions = {i: part for i, part in enumerate(initial_proposed_partitions)} + partitions = dict(enumerate(initial_proposed_partitions)) # For each partition, determine whether or not the number of computational operators # exceeds the threshold, and if not, remove that partition @@ -99,8 +99,7 @@ def propose_partitions(self) -> List[Partition]: def partition_and_fuse(self) -> GraphModule: partitions = self.propose_partitions() - fused_gm = self.fuse_partitions(partitions) - return fused_gm + return self.fuse_partitions(partitions) class TorchTensorRTOperatorSupport(OperatorSupport): diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py index 8a47fc04d2..52eb13cc66 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py @@ -23,10 +23,11 @@ class Substitution: ] +# Dictionary mapping module to Substitution instance # Dictionary mapping module to Substitution instance SUBSTITUTION_REGISTRY: Dict[ Union[Type[torch.nn.Module], Callable], Substitution -] = dict() +] = {} def register_substitution( diff --git 
a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py index 57c4a93e62..b0d93b8db5 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py @@ -70,11 +70,8 @@ def einsum_insertion_fn( 1 <= len(inputs) <= 2 ), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors" - # Ensure the input is formatted as an equation and - new_node = gm.graph.call_function( + return gm.graph.call_function( torch.ops.tensorrt.einsum, args=(equation, inputs), kwargs=node.kwargs, ) - - return new_node diff --git a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py index 020d3a0ca9..710419b60d 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/substitutions/maxpool1d.py @@ -77,8 +77,7 @@ def maxpool1d_insertion_fn( node: torch.fx.Node, submodule: torch.nn.Module, ) -> torch.fx.Node: - # Defines insertion function for new node - new_node = gm.graph.call_function( + return gm.graph.call_function( torch.ops.tensorrt.maxpool1d, args=node.args, kwargs={ @@ -90,8 +89,6 @@ def maxpool1d_insertion_fn( }, ) - return new_node - # 4. The Accelerated Implementation # diff --git a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py b/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py index 2af251adbc..b8c6705445 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py @@ -9,6 +9,7 @@ class TestTRTModuleNextCompilation(TestCase): def test_trt_module_next_full_support(self): + class FullySupportedMultiOp(torch.nn.Module): def forward(self, x, y): out = x - y @@ -53,10 +54,11 @@ def forward(self, x, y): max_diff, 0, DECIMALS_OF_AGREEMENT, - f"TRT outputs don't match with the original model.", + "TRT outputs don't match with the original model.", ) def test_trt_module_next_partial_support(self): + class PartiallySupportedMultiOp(torch.nn.Module): def forward(self, x, y): out = x - y @@ -121,12 +123,13 @@ def forward(self, x, y): max_diff, 0, DECIMALS_OF_AGREEMENT, - f"TRT outputs don't match with the original model.", + "TRT outputs don't match with the original model.", ) class TestCompilationOptions(TestCase): def test_trt_specific_options(self): + class SupportedMultiOp(torch.nn.Module): def forward(self, x, y): out = x - y @@ -165,7 +168,7 @@ def forward(self, x, y): max_diff, 0, DECIMALS_OF_AGREEMENT, - f"TRT outputs don't match with the original model.", + "TRT outputs don't match with the original model.", ) diff --git a/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py b/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py index d947c955e0..6624dd8a54 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py @@ -41,13 +41,16 @@ def forward(self, x, y): ) def test_lowering_alias_replacement(self): + + + class Alias(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def forward(self, x): - y = torch.ops.aten.alias.default(x) - return y + return torch.ops.aten.alias.default(x) + # Operations expected to be removed in the traced graph after decompositions unexpected_ops = {torch.ops.aten.alias.default} @@ -70,13 
+73,16 @@ def forward(self, x): ) def test_lowering_rsqrt(self): + + + class Rsqrt(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) def forward(self, x): - y = torch.ops.aten.rsqrt.default(x) - return y + return torch.ops.aten.rsqrt.default(x) + # Operations expected to be removed in the traced graph after decompositions expected_ops = {torch.ops.aten.sqrt.default, torch.ops.aten.reciprocal.default} @@ -112,6 +118,7 @@ def forward(self, x): ) def test_lowering_addmm(self): + class AddMM(torch.nn.Module): def forward(self, x, y, z): return torch.addmm(x, y, z, beta=16, alpha=5) @@ -176,7 +183,7 @@ def forward(self, x, y, z): max_diff, 0, DECIMALS_OF_AGREEMENT, - f"AddMM TRT outputs don't match with the original model.", + "AddMM TRT outputs don't match with the original model.", ) diff --git a/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py index fb5430b384..0fb14ce077 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_partitioning.py @@ -24,6 +24,9 @@ def forward(self, x, y): ) def test_partition_fully_supported_multi_op(self): + + + class FullySupportedMultiOp(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -32,8 +35,8 @@ def forward(self, x, y): sum_ = torch.ops.aten.sub.Tensor(x, y) concat_ = torch.ops.aten.cat.default(x, sum_) relu_ = torch.ops.aten.relu.default(concat_) - pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) - return pow_ + return torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) partitioned_graph = partition(deepcopy(fx_graph), min_block_size=2) @@ -44,6 +47,9 @@ def forward(self, x, y): ) def test_partition_partially_supported_multi_op(self): + + + class PartiallySupportedMultiOp(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -53,8 +59,8 @@ def forward(self, x, y): sum_2 = torch.ops.aten.add.Tensor(x, sum_1) sum_ = np.sum(sum_1) + np.sum(sum_2) relu_ = torch.ops.aten.relu.default(sum_) - pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) - return pow_ + return torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) partitioned_graph = partition(deepcopy(fx_graph), min_block_size=2) @@ -65,6 +71,9 @@ def forward(self, x, y): ) def test_partition_partially_supported_with_torch_executed_ops(self): + + + class PartiallySupportedMultiOp(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -74,8 +83,8 @@ def forward(self, x, y): sum_2 = torch.ops.aten.add.Tensor(x, sum_1) sum_ = torch.ops.aten.add.Tensor(sum_1, sum_2) relu_ = torch.ops.aten.relu.default(sum_) - pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2) - return pow_ + return torch.ops.aten.pow.Tensor_Scalar(relu_, 2) + unexpected_ops = {torch.ops.aten.add.Tensor} diff --git a/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py index da44d6e826..025f4c922a 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_pre_aot_lowering.py @@ -6,6 +6,7 @@ class TestMaxPool1D(TestCase): def test_pre_aot_lowering_maxpool1d(self): + class MaxPool1D(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ 
-47,12 +48,15 @@ def forward(self, x): max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results)) self.assertAlmostEqual( - max_diff, 0, f"Maxpool1d TRT outputs don't match with the original model." + max_diff, + 0, + "Maxpool1d TRT outputs don't match with the original model.", ) class TestEinsum(TestCase): def test_pre_aot_lowering_einsum(self): + class Einsum(torch.nn.Module): def forward(self, x, y): return torch.einsum("ij,ji->ij", x, y) @@ -93,7 +97,7 @@ def forward(self, x, y): max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results)) self.assertAlmostEqual( - max_diff, 0, f"Einsum TRT outputs don't match with the original model." + max_diff, 0, "Einsum TRT outputs don't match with the original model." ) diff --git a/py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py b/py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py index 17df523ab8..5e6b9dc01f 100644 --- a/py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py +++ b/py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py @@ -6,6 +6,7 @@ class TestFakeTensors(TestCase): def test_lowering_mul_int(self): + class MulInt(torch.nn.Module): def forward(self, x): return x * 7 @@ -52,11 +53,12 @@ def forward(self, x): self.assertAlmostEqual( max_diff, 0, - msg=f"MulInt TRT outputs don't match with the original model.", + msg="MulInt TRT outputs don't match with the original model.", ) torch._dynamo.reset() def test_lowering_add_float(self): + class AddFloat(torch.nn.Module): def forward(self, x): return x + 84.0 @@ -104,7 +106,7 @@ def forward(self, x): self.assertAlmostEqual( max_diff, 0, - msg=f"AddFloat TRT outputs don't match with the original model.", + msg="AddFloat TRT outputs don't match with the original model.", ) torch._dynamo.reset() diff --git a/py/torch_tensorrt/dynamo/backend/test/utils.py b/py/torch_tensorrt/dynamo/backend/test/utils.py index 7c679b7d4d..10069f89bc 100644 --- a/py/torch_tensorrt/dynamo/backend/test/utils.py +++ b/py/torch_tensorrt/dynamo/backend/test/utils.py @@ -109,7 +109,7 @@ def same_output_format(trt_output, torch_output, enforce_tensor_type=True): for key in trt_output.keys() ) ) - elif isinstance(trt_output, set) or isinstance(trt_output, frozenset): + elif isinstance(trt_output, (set, frozenset)): raise AssertionError( "Unsupported output type 'set' encountered in output format check." ) diff --git a/py/torch_tensorrt/dynamo/backend/utils.py b/py/torch_tensorrt/dynamo/backend/utils.py index 23a1cd4795..4d0a06926b 100644 --- a/py/torch_tensorrt/dynamo/backend/utils.py +++ b/py/torch_tensorrt/dynamo/backend/utils.py @@ -27,29 +27,13 @@ def prepare_inputs( return inputs elif isinstance(inputs, list): - prepared_input = list() - - for input_obj in inputs: - prepared_input.append(prepare_inputs(input_obj)) - - return prepared_input - + return [prepare_inputs(input_obj) for input_obj in inputs] elif isinstance(inputs, tuple): - prepared_input = list() - - for input_obj in inputs: - prepared_input.append(prepare_inputs(input_obj)) - + prepared_input = [prepare_inputs(input_obj) for input_obj in inputs] return tuple(prepared_input) elif isinstance(inputs, dict): - prepared_input = dict() - - for key, input_obj in inputs.items(): - prepared_input[key] = prepare_inputs(input_obj) - - return prepared_input - + return {key: prepare_inputs(input_obj) for key, input_obj in inputs.items()} else: raise ValueError( f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. 
" diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py index a29cee509d..61e226717c 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -59,8 +59,7 @@ def __init__( self.network = self.builder.create_network(flag) - missing_ops = self.validate_conversion() - if missing_ops: + if missing_ops := self.validate_conversion(): warnings.warn( "Interpretation will fail due to missing operations \n" + "\n".join(f"{i}" for i in missing_ops) @@ -75,7 +74,7 @@ def __init__( self._output_names: List[str] = [] self._itensor_to_tensor_meta: Dict[ trt.tensorrt.ITensor, TensorMetadata - ] = dict() + ] = {} # Data types for TRT Module output Tensors self.output_dtypes = output_dtypes @@ -242,7 +241,7 @@ def run( _LOGGER.info(f"Setting max aux streams to {max_aux_streams}") builder_config.max_aux_streams = max_aux_streams if version_compatible: - _LOGGER.info(f"Using version compatible") + _LOGGER.info("Using version compatible") builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE) if optimization_level is not None: _LOGGER.info(f"Using optimization level {optimization_level}") @@ -386,24 +385,23 @@ def output(self, target, args, kwargs): ) for i, output in enumerate(outputs): - if any( - op_name in output.name.split("_") - for op_name in ( - "eq", - "gt", - "lt", - "or", - "xor", - "and", - "not", - "ne", - "isinf", - "any", + output_bool = any( + ( + op_name in output.name.split("_") + for op_name in ( + "eq", + "gt", + "lt", + "or", + "xor", + "and", + "not", + "ne", + "isinf", + "any", + ) ) - ): - output_bool = True - else: - output_bool = False + ) name = f"output{i}" output.name = name self.network.mark_output(output) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py b/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py index 7f67e8abbf..8b23a1bb35 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py @@ -79,6 +79,7 @@ def from_input(cls, input_obj: Input) -> "InputTensorSpec": """ assert isinstance(input_obj, Input) input_spec = None + dtype = input_obj.torch_dtype if isinstance(input_obj.shape, dict): min_shape = input_obj.shape["min_shape"] opt_shape = input_obj.shape["opt_shape"] @@ -89,18 +90,14 @@ def from_input(cls, input_obj: Input) -> "InputTensorSpec": dyn_shape.append(min) else: dyn_shape.append(-1) - dtype = input_obj.torch_dtype - input_spec = cls( + return cls( shape=dyn_shape, dtype=dtype, shape_ranges=[(min_shape, opt_shape, max_shape)], ) else: shape = input_obj.shape - dtype = input_obj.torch_dtype - input_spec = cls(shape=shape, dtype=dtype) - - return input_spec + return cls(shape=shape, dtype=dtype) @classmethod def from_tensors_with_dynamic_batch_size( @@ -145,7 +142,12 @@ def from_tensors_with_dynamic_batch_size( ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." 
shape = list(tensor.shape) shape[batch_dim] = -1 - shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] + shape_ranges: List[ShapeRange] = [ + tuple( + tuple(shape[:batch_dim] + [bs] + shape[batch_dim + 1 :]) + for bs in batch_size_range + ) + ] * opt_profile_replica input_specs.append( cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) ) @@ -166,16 +168,8 @@ def to_random_tensor(self, id=1): @staticmethod def create_inputs_from_specs(input_specs: Iterable["InputTensorSpec"]): - inputs = [] - for spec in input_specs: - inputs.append(spec.to_random_tensor()) - - return inputs + return [spec.to_random_tensor() for spec in input_specs] @staticmethod def create_inputs_from_max_specs(input_specs: Iterable["InputTensorSpec"]): - inputs = [] - for spec in input_specs: - inputs.append(spec.to_random_tensor(2)) - - return inputs + return [spec.to_random_tensor(2) for spec in input_specs] diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py index 0fd3777254..2b2c8b3e5d 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py @@ -109,10 +109,14 @@ def graph_optimization_pass(self) -> PassManager: passes = [ wrapper(self._trace_func, self._input), ] - for p in self.lower_setting.customized_fuse_pass.passes: - passes.append(wrapper(p, self._input)) - for p in self.lower_setting.lower_basic_fuse_pass.passes: - passes.append(wrapper(p, self._input)) + passes.extend( + wrapper(p, self._input) + for p in self.lower_setting.customized_fuse_pass.passes + ) + passes.extend( + wrapper(p, self._input) + for p in self.lower_setting.lower_basic_fuse_pass.passes + ) if ( hasattr(self.lower_setting, "lower_precision") and self.lower_setting.lower_precision is LowerPrecision.FP16 @@ -123,20 +127,25 @@ def graph_optimization_pass(self) -> PassManager: passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input)) passes.append(inplace_wrapper(common_subexpression_elimination)) - passes.append( - inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)) + passes.extend( + ( + inplace_wrapper( + lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input) + ), + fix_reshape_batch_dim, + ) ) - passes.append(fix_reshape_batch_dim) - return PassManager.build_from_passlist(passes) def graph_optimization_pass_aten(self) -> PassManager: - passes = [] - - for p in self.lower_setting.customized_fuse_pass.passes: - passes.append(wrapper(p, self._input)) - for p in self.lower_setting.lower_basic_fuse_pass.passes: - passes.append(wrapper(p, self._input)) + passes = [ + wrapper(p, self._input) + for p in self.lower_setting.customized_fuse_pass.passes + ] + passes.extend( + wrapper(p, self._input) + for p in self.lower_setting.lower_basic_fuse_pass.passes + ) # TODO fix this pass for aten graph # if ( # hasattr(self.lower_setting, "lower_precision") @@ -277,16 +286,14 @@ def build_trt_lower_pipeline( ) self._additional_input = additional_input - passes = [] + passes = [self._default_replace_mutable_op_pass()] - passes.append(self._default_replace_mutable_op_pass()) passes.append(self._const_fold_pass()) passes.append(self.graph_optimization_pass()) passes.append(self._split_pass()) passes.append(self._trt_lower_pass()) - pm = 
PassManager.build_from_passlist(passes) - return pm + return PassManager.build_from_passlist(passes) def build_aten2trt_lower_pipeline( self, input: Input, additional_input: Optional[Input] = None @@ -305,29 +312,23 @@ def build_aten2trt_lower_pipeline( ) self._additional_input = additional_input - passes = [] - passes.append( - wrapper(self._trace_func, self._input), - ) + passes = [wrapper(self._trace_func, self._input)] passes.append(self.graph_optimization_pass_aten()) passes.append(self._split_pass()) passes.append(self._trt_lower_pass()) - pm = PassManager.build_from_passlist(passes) - return pm + return PassManager.build_from_passlist(passes) def build_default_lower_pipeline( self, input: Input, additional_input: Optional[Input] = None ) -> PassManager: self._input = input self._additional_input = additional_input - passes = [] + passes = [self._default_replace_mutable_op_pass()] - passes.append(self._default_replace_mutable_op_pass()) passes.append(self._const_fold_pass()) passes.append(self.graph_optimization_pass()) passes.append(self._split_pass()) passes.append(self._default_lower_pass()) - pm = PassManager.build_from_passlist(passes) - return pm + return PassManager.build_from_passlist(passes) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py index 7d3046d617..ba7473e671 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py @@ -139,56 +139,55 @@ def _validate_inference(pass_: PassFunc) -> PassFunc: @wraps(pass_) def pass_with_validation( - module: fx.GraphModule, - input: Input, - *args, - **kwargs, - ) -> fx.GraphModule: + module: fx.GraphModule, + input: Input, + *args, + **kwargs, + ) -> fx.GraphModule: if suppress_accuracy_check_failure: return pass_(module, input, *args, **kwargs) - else: - input_tensors = extract_example_tensors_from_input(input, device) - res0 = module(*input_tensors) - processed_module = pass_(module, input, *args, **kwargs) - res1 = processed_module(*input_tensors) - tensor_res_0 = _collect_tensors(res0) - tensor_res_1 = _collect_tensors(res1) - relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE - - for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): - kwargs2 = {"equal_nan": True} - if rtol: - kwargs2["rtol"] = rtol - if atol: - kwargs2["atol"] = atol - kwargs2[ - "msg" - ] = ( - lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" - ) - # If tensors are on different devices, make sure to compare - # their copies that are on the same device. - if x.get_device() != y.get_device(): - x = x.cpu() - y = y.cpu() - try: + input_tensors = extract_example_tensors_from_input(input, device) + res0 = module(*input_tensors) + processed_module = pass_(module, input, *args, **kwargs) + res1 = processed_module(*input_tensors) + tensor_res_0 = _collect_tensors(res0) + tensor_res_1 = _collect_tensors(res1) + relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE + + for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): + kwargs2 = {"equal_nan": True} + if rtol: + kwargs2["rtol"] = rtol + if atol: + kwargs2["atol"] = atol + kwargs2[ + "msg" + ] = ( + lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" + ) + # If tensors are on different devices, make sure to compare + # their copies that are on the same device. 
+ if x.get_device() != y.get_device(): + x = x.cpu() + y = y.cpu() + try: + torch.testing.assert_close(x, y, **kwargs2) + except Exception as e: + if relax_accuracy_check_failure: + _LOGGER.error(f"{e}") + kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER + kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER + new_atol = kwargs2["atol"] + new_rtol = kwargs2["rtol"] + _LOGGER.info( + f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" + ) torch.testing.assert_close(x, y, **kwargs2) - except Exception as e: - if relax_accuracy_check_failure: - _LOGGER.error(f"{e}") - kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER - kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER - new_atol = kwargs2["atol"] - new_rtol = kwargs2["rtol"] - _LOGGER.info( - f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" - ) - torch.testing.assert_close(x, y, **kwargs2) - return processed_module - else: - raise e - - return processed_module + return processed_module + else: + raise e + + return processed_module return pass_with_validation diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py index 334243fef4..f741c10d0e 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py @@ -71,10 +71,7 @@ def run_test( precision=LowerPrecision.FP32, ): with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - + cuda_inputs = [i.cuda() for i in inputs] mod.eval() if len(expected_ops): self.assert_has_op(mod, expected_ops) @@ -144,10 +141,7 @@ def run_test_custom_compare_results( """ with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - + cuda_inputs = [i.cuda() for i in inputs] mod.eval() if len(expected_ops): self.assert_has_op(mod, expected_ops) @@ -176,10 +170,7 @@ def run_test_custom_compare_results( def run_test_with_error(self, mod, inputs, interpreter, expect_error): with self.assertRaises(expect_error): with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - + cuda_inputs = [i.cuda() for i in inputs] mod.eval() interpreter.run(lower_precision=LowerPrecision.FP32) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py index bfb1964de9..e1dab4b403 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py @@ -22,32 +22,29 @@ def lower_mod_default( ) interpreter_result = interp.run() if use_python_runtime: - res_mod = TRTModule( + return TRTModule( interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names, ) - else: - import io + import io - from torch_tensorrt._Device import Device - from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule + from torch_tensorrt._Device import Device + from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule - with io.BytesIO() as engine_bytes: - engine_bytes.write(interpreter_result.engine.serialize()) - engine_str = engine_bytes.getvalue() + with io.BytesIO() as engine_bytes: + engine_bytes.write(interpreter_result.engine.serialize()) + engine_str = engine_bytes.getvalue() - res_mod = TorchTensorRTModule( - engine_str, - name=str(type(mod)), - input_binding_names=interpreter_result.input_names, - 
output_binding_names=interpreter_result.output_names, - target_device=Device(f"cuda:{torch.cuda.current_device()}"), - # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do - ) - - return res_mod + return TorchTensorRTModule( + engine_str, + name=str(type(mod)), + input_binding_names=interpreter_result.input_names, + output_binding_names=interpreter_result.output_names, + target_device=Device(f"cuda:{torch.cuda.current_device()}"), + # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do + ) class TensorRTMinizerSetting(net_min_base._MinimizerSettingBase): diff --git a/py/torch_tensorrt/fx/converter_registry.py b/py/torch_tensorrt/fx/converter_registry.py index 0167f75f08..cb876337e1 100644 --- a/py/torch_tensorrt/fx/converter_registry.py +++ b/py/torch_tensorrt/fx/converter_registry.py @@ -24,7 +24,4 @@ def register_converter(converter): def disable_converter(converter): return converter - if enabled: - return register_converter - else: - return disable_converter + return register_converter if enabled else disable_converter diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 9532c7072c..42500683df 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -292,7 +292,7 @@ def acc_ops_pad_with_slice_layer( # cast value to TRTensor dt = unified_dtype_converter(input_val.dtype, Frameworks.TORCH) - value = 0 if value == None else value + value = 0 if value is None else value value_const = get_trt_tensor( network, torch.tensor([value], dtype=dt), f"{name}_value" ) @@ -346,12 +346,8 @@ def acc_ops_flatten( ) num_dims = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0) - start_dim = get_positive_dim( - cast(int, kwargs["start_dim"] if "start_dim" in kwargs else 0), num_dims - ) - end_dim = get_positive_dim( - cast(int, kwargs["end_dim"] if "end_dim" in kwargs else -1), num_dims - ) + start_dim = get_positive_dim(cast(int, kwargs.get("start_dim", 0)), num_dims) + end_dim = get_positive_dim(cast(int, kwargs.get("end_dim", -1)), num_dims) if network.has_implicit_batch_dimension: assert start_dim != 0, "Can't flatten batch dimension when it's implicit." @@ -426,8 +422,7 @@ def acc_ops_flatten( if i >= start_dim and i <= end_dim: flatten_dim *= s elif i == end_dim + 1: - final_shape.append(flatten_dim) - final_shape.append(s) + final_shape.extend((flatten_dim, s)) else: final_shape.append(s) if end_dim == len(input_val.shape) - 1: @@ -453,7 +448,7 @@ def acc_ops_size( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_t = kwargs["input"] - if type(input_t) == torch.nn.Parameter or type(input_t) == torch.Tensor: + if type(input_t) in [torch.nn.Parameter, torch.Tensor]: if ( not has_dynamic_shape(input_t.shape) and network.has_implicit_batch_dimension @@ -496,7 +491,7 @@ def acc_ops_numel( ) if has_dynamic_shape(input_val.shape): - raise RuntimeError(f"numel does not support dynamic shapes.") + raise RuntimeError("numel does not support dynamic shapes.") numel = np.prod(input_val.shape) layer = network.add_constant((1,), trt.Weights(np.array(numel, dtype=np.float32))) @@ -744,10 +739,7 @@ def acc_ops_softmax( # Used to get dim when dim is None. Copied from PyTorch softmax implementation. 
def get_softmax_dim(ndim: int) -> int: - if ndim == 0 or ndim == 1 or ndim == 3: - ret = 0 - else: - ret = 1 + ret = 0 if ndim in {0, 1, 3} else 1 return ret if kwargs["dim"] is None: @@ -1278,7 +1270,7 @@ def add_acc_ops_dim_reduce(network, target, args, kwargs, name, reduce_op): new_kwargs["sorted"] = False topk_out0, topk_out1 = acc_ops_topk( - network, target, args, new_kwargs, name + "_topk" + network, target, args, new_kwargs, f"{name}_topk" ) topk_out0.name = f"{name}_topk0" @@ -1297,12 +1289,7 @@ def add_acc_ops_dim_reduce(network, target, args, kwargs, name, reduce_op): input_val = topk_out0 shape = input_val.shape - output_shape = [] - for i, s in enumerate(shape): - if i == dim and s == 1: - continue - output_shape.append(s) - + output_shape = [s for i, s in enumerate(shape) if i != dim or s != 1] shuffle_layer0 = network.add_shuffle(input_val) shuffle_layer0.reshape_dims = tuple(output_shape) set_layer_name(shuffle_layer0, target, f"{name}_shuffle0") @@ -1740,15 +1727,15 @@ def acc_ops_any( ) if input_t.dtype in (trt.float32, trt.float16, trt.int32): - comp_t = torch.zeros(tuple([*input_t.shape])).to( + comp_t = torch.zeros((*input_t.shape,)).to( unified_dtype_converter(input_t.dtype, Frameworks.TORCH) ) comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t") kwargs_new = {"input": input_t, "other": comp_t} - eq_output = acc_ops_eq(network, target, None, kwargs_new, name + "_eq") + eq_output = acc_ops_eq(network, target, None, kwargs_new, f"{name}_eq") kwargs_new = {"input": eq_output} not_output = acc_ops_logical_not( - network, target, None, kwargs_new, name + "_not" + network, target, None, kwargs_new, f"{name}_not" ) else: not_output = input_t @@ -1763,10 +1750,10 @@ def acc_ops_any( } else: kwargs_new = {"input": int_output} - sum_output = acc_ops_sum(network, target, None, kwargs_new, name + "_sum") + sum_output = acc_ops_sum(network, target, None, kwargs_new, f"{name}_sum") # cast int to bool output = type_cast(network, target, f"{name}_cast_bool", sum_output, trt.bool) - output.name = output.name + "_any" + output.name = f"{output.name}_any" return output @@ -1780,7 +1767,7 @@ def acc_ops_fmod( ) -> Union[TRTTensor, Sequence[TRTTensor]]: # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it trunc_div_value = trunc_div( - kwargs["input"], kwargs["other"], network, target, name + "_trunc_div" + kwargs["input"], kwargs["other"], network, target, f"{name}_trunc_div" ) prod_value = add_binary_elementwise_layer( network, @@ -1788,17 +1775,16 @@ def acc_ops_fmod( kwargs["other"], trt.ElementWiseOperation.PROD, target, - name + "_prod", + f"{name}_prod", ) - sub_value = add_binary_elementwise_layer( + return add_binary_elementwise_layer( network, kwargs["input"], prod_value, trt.ElementWiseOperation.SUB, target, - name + "_sub", + f"{name}_sub", ) - return sub_value # T113156424 embedding implemenatation is very limited and shows no usage in hf models due to the indices are int64. @@ -1953,7 +1939,7 @@ def acc_ops_max_poolnd( dilation = extend_attr_to_tuple(kwargs["dilation"], extend_len) ceil_mode = kwargs["ceil_mode"] - if len(stride) == 0 or stride[0] == None: + if len(stride) == 0 or stride[0] is None: stride = kernel_size ones = (1,) * extend_len @@ -1991,7 +1977,7 @@ def acc_ops_squeeze( "of the TensorRT region!" 
) - dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None) + dim = cast(Optional[int], kwargs.get("dim", None)) # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic # dim, which is a very rare case. For now we just claim not supporting dim=None. assert dim is not None, "We don't support dim=None right now for squeeze." @@ -2008,11 +1994,7 @@ def acc_ops_squeeze( len(get_dynamic_dims(input_val.shape)) <= 1 ), "Currently more than one dynamic dim for input to squeeze is not supported." - output_shape = [] - for i, s in enumerate(input_val.shape): - if i == dim and s == 1: - continue - output_shape.append(s) + output_shape = [s for i, s in enumerate(input_val.shape) if i != dim or s != 1] layer = network.add_shuffle(input_val) layer.reshape_dims = tuple(output_shape) set_layer_name(layer, target, name) @@ -2287,12 +2269,12 @@ def acc_ops_avg_pool1d( ceil_mode = kwargs["ceil_mode"] count_include_pad = kwargs["count_include_pad"] - if len(stride) == 0 or stride[0] == None: + if len(stride) == 0 or stride[0] is None: stride = kernel_size shuffle_layer = network.add_shuffle(input_val) shuffle_layer.reshape_dims = tuple(input_val.shape) + (1,) - set_layer_name(shuffle_layer, target, name + "_shuffle1") + set_layer_name(shuffle_layer, target, f"{name}_shuffle1") shuffle_out = shuffle_layer.get_output(0) layer = network.add_pooling_nd( @@ -2301,7 +2283,7 @@ def acc_ops_avg_pool1d( layer.stride_nd = stride + (1,) layer.padding_nd = padding + (0,) - layer.average_count_excludes_padding = False if count_include_pad else True + layer.average_count_excludes_padding = not count_include_pad set_layer_name(layer, target, name) if ceil_mode: layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP @@ -2309,7 +2291,7 @@ def acc_ops_avg_pool1d( output = layer.get_output(0) layer = network.add_shuffle(output) layer.reshape_dims = tuple(output.shape)[:-1] - set_layer_name(layer, target, name + "_shuffle2") + set_layer_name(layer, target, f"{name}_shuffle2") return layer.get_output(0) @@ -2337,7 +2319,7 @@ def acc_ops_avg_pool2d( count_include_pad = kwargs["count_include_pad"] divisor_override = kwargs["divisor_override"] - if len(stride) == 0 or stride[0] == None: + if len(stride) == 0 or stride[0] is None: stride = kernel_size if divisor_override: @@ -2348,7 +2330,7 @@ def acc_ops_avg_pool2d( ) layer.stride = stride layer.padding = padding - layer.average_count_excludes_padding = False if count_include_pad else True + layer.average_count_excludes_padding = not count_include_pad set_layer_name(layer, target, name) if ceil_mode: @@ -2422,11 +2404,11 @@ def acc_ops_slice_tensor( raise RuntimeError( f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" ) - dim = dim - 1 - else: - if dynamic_shape: - # Check whether slice target dim is dynamic shape dim - assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" + else: + dim = dim - 1 + elif dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" 
start_int = cast(int, kwargs["start"]) stop_int = cast(int, kwargs["stop"]) @@ -2478,9 +2460,7 @@ def acc_ops_expand_tensor( inshape = tuple(input_val.shape) shape = tuple(shape) start = tuple([0] * ranks) - stride = tuple( - [int(i == o) for i, o in zip(inshape, shape)] - ) # stride == 1 if dimensions match, 0 otherwise + stride = tuple(int(i == o) for i, o in zip(inshape, shape)) layer = network.add_slice(input_val, start=start, shape=shape, stride=stride) set_layer_name(layer, target, name) return layer.get_output(0) @@ -2660,10 +2640,9 @@ def acc_ops_split( if network.has_implicit_batch_dimension: assert dim != 0, "Can't split on batch dim when it's implicit!" dim -= 1 - else: - if dynamic_shape > 0: - # Check whether slice target dim is dynamic shape dim - assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" + elif dynamic_shape > 0: + # Check whether slice target dim is dynamic shape dim + assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" split_size = cast(int, kwargs["split_size"]) start = [0] * len(input_val.shape) @@ -2787,8 +2766,7 @@ def add_clamp(network, input, val, op, name): acc_ops_clamp_trt = network.add_constant( acc_ops_clamp_shape, acc_ops_clamp_tensor ).get_output(0) - layer = network.add_elementwise(input, acc_ops_clamp_trt, op) - return layer + return network.add_elementwise(input, acc_ops_clamp_trt, op) @tensorrt_converter(acc_ops.clamp) @@ -2874,7 +2852,7 @@ def num_slice_types(slices): """ Gather the number of slice in getitem slices. """ - return sum(1 for s in slices if isinstance(s, slice) or isinstance(s, int)) + return sum(1 for s in slices if isinstance(s, (slice, int))) def slice_to_trt_params(py_slice, dim_size): """ @@ -3155,7 +3133,7 @@ def acc_ops_quantize_per_tensor( scale_layer = network.add_constant( (1,), trt.Weights(np.ascontiguousarray([float(q_scale)], dtype=np.float32)) ) - scale_layer.name = input_val.name + ".per_tensor_quant.scale" + scale_layer.name = f"{input_val.name}.per_tensor_quant.scale" scale = scale_layer.get_output(0) # assert trt.__version__ > "8.0", "Explicit quantize op is only supported in " # "TensorRT 8.0 or above, current TensorRT version:" + trt.__version__ diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 49bf401f58..95b1bde0fc 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -93,9 +93,7 @@ def get_positive_dim(dim: int, dim_size: int) -> int: Returns: A positive integer that represent the same dimension as the given dim. """ - if dim < 0: - return dim % dim_size - return dim + return dim % dim_size if dim < 0 else dim def set_layer_name( @@ -210,10 +208,7 @@ def has_dynamic_shape(shape: Shape) -> bool: Returns: A boolean value indicates whether there's dynamic dim in the shape. 
""" - count = 0 - for s in shape: - count += 1 if s == -1 else 0 - return count + return sum(1 if s == -1 else 0 for s in shape) def get_axes_for_reduce_op( @@ -304,13 +299,15 @@ def get_trt_tensor( if isinstance(input_val, bool): input_val = int(input_val) - if isinstance(input_val, torch.Tensor) and ( - input_val.dtype == torch.bool or input_val.dtype == torch.int64 - ): + if isinstance(input_val, torch.Tensor) and input_val.dtype in [ + torch.bool, + torch.int64, + ]: input_val = input_val.to(torch.int32) - elif isinstance(input_val, np.ndarray) and ( - input_val.dtype == np.bool_ or input_val.dtype == np.int64 - ): + elif isinstance(input_val, np.ndarray) and input_val.dtype in [ + np.bool_, + np.int64, + ]: input_val = input_val.to(np.int32) if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)): @@ -563,7 +560,7 @@ def add_binary_elementwise_layer( assert len(lhs_val.shape) >= len( rhs_val.shape ), f"{lhs_val.shape} >= {rhs_val.shape}" - elif not is_lhs_trt_tensor and is_rhs_trt_tensor: + elif not is_lhs_trt_tensor: assert len(rhs_val.shape) >= len( lhs_val.shape ), f"{rhs_val.shape} >= {lhs_val.shape}" @@ -574,7 +571,7 @@ def add_binary_elementwise_layer( layer = network.add_elementwise(lhs_val, rhs_val, op_type) set_layer_name(layer, target, name) output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ + output.name = f"{output.name}_{target.__name__}" return output @@ -622,7 +619,7 @@ def add_unary_layer( layer = network.add_unary(input_val, operation_type) set_layer_name(layer, target, name) output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ + output.name = f"{output.name}_{target.__name__}" return layer.get_output(0) @@ -837,7 +834,7 @@ def trunc_div( target, f"{name}_floor_div", ) - output = add_binary_elementwise_layer( + return add_binary_elementwise_layer( network, abs_floor_output, sign_output, @@ -846,8 +843,6 @@ def trunc_div( f"{name}_output", ) - return output - def get_python_op_from_trt_elementwise_op( trt_op: TRTElementWiseOp, diff --git a/py/torch_tensorrt/fx/converters/impl/convolution.py b/py/torch_tensorrt/fx/converters/impl/convolution.py index a0e7537fde..360564b162 100644 --- a/py/torch_tensorrt/fx/converters/impl/convolution.py +++ b/py/torch_tensorrt/fx/converters/impl/convolution.py @@ -51,7 +51,7 @@ def convNd( "dim": -1, } input_val = acc_ops_converters.acc_ops_unsqueeze( - network, target, tuple(), kwargs, name + "_unsqueeze" + network, target, tuple(), kwargs, f"{name}_unsqueeze" ) # Process bias terms @@ -77,7 +77,7 @@ def convNd( "dim": -1, } weight = acc_ops_converters.acc_ops_unsqueeze( - network, target, tuple(), kwargs, name + "_unsqueeze_weight" + network, target, tuple(), kwargs, f"{name}_unsqueeze_weight" ) elif isinstance(weight, torch.Tensor): @@ -139,7 +139,7 @@ def convNd( "dim": -1, } result = acc_ops_converters.acc_ops_squeeze( - network, target, tuple(), kwargs, name + "_squeeze" + network, target, tuple(), kwargs, f"{name}_squeeze" ) return result diff --git a/py/torch_tensorrt/fx/converters/nn_ops_converters.py b/py/torch_tensorrt/fx/converters/nn_ops_converters.py index 2aacaa9a68..2dfcca66bd 100644 --- a/py/torch_tensorrt/fx/converters/nn_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/nn_ops_converters.py @@ -221,5 +221,9 @@ def quantized_conv_relu2d(network, submod, args, kwargs, layer_name): ) return activation.relu( - network, submod._get_name(), SourceIR.NN, layer_name + "_relu", conv_out + network, + submod._get_name(), + SourceIR.NN, + 
f"{layer_name}_relu", + conv_out, ) diff --git a/py/torch_tensorrt/fx/converters/transformation.py b/py/torch_tensorrt/fx/converters/transformation.py index 62cfef8453..5e8a6aa9db 100644 --- a/py/torch_tensorrt/fx/converters/transformation.py +++ b/py/torch_tensorrt/fx/converters/transformation.py @@ -33,8 +33,7 @@ def torch_flatten(network, target, args, kwargs, name): if i < start_dim: new_shape.append(dim) elif i > end_dim: - new_shape.append(flatten_dim) - new_shape.append(dim) + new_shape.extend((flatten_dim, dim)) else: flatten_dim *= dim diff --git a/py/torch_tensorrt/fx/diagnostics.py b/py/torch_tensorrt/fx/diagnostics.py index 0d78513a81..0869233503 100644 --- a/py/torch_tensorrt/fx/diagnostics.py +++ b/py/torch_tensorrt/fx/diagnostics.py @@ -103,13 +103,12 @@ def write(self, file_name: str, data: WriteObj): res, err = _res_or_err(data) if err: to_write = err.encode("utf-8") + elif isinstance(res, str): + to_write = res.encode("utf-8") + elif isinstance(res, bytes): + to_write = res else: - if isinstance(res, str): - to_write = res.encode("utf-8") - elif isinstance(res, bytes): - to_write = res - else: - raise TypeError(f"Unknown data type: {type(res)}") + raise TypeError(f"Unknown data type: {type(res)}") self._write(file_name, to_write) except Exception as e: # Log the error and swallow the exception, as this should not diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index d7ef976fba..e8c0ec4560 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -61,8 +61,7 @@ def __init__( flag |= EXPLICIT_PRECISION self.network = self.builder.create_network(flag) - missing_ops = self.validate_conversion() - if missing_ops: + if missing_ops := self.validate_conversion(): warnings.warn( "Interpretation will fail due to missing operations \n" + "\n".join(f"{i}" for i in missing_ops) @@ -77,7 +76,7 @@ def __init__( self._output_names: List[str] = [] self._itensor_to_tensor_meta: Dict[ trt.tensorrt.ITensor, TensorMetadata - ] = dict() + ] = {} def validate_input_specs(self): for shape, _, _, shape_ranges, has_batch_dim in self.input_specs: @@ -360,24 +359,23 @@ def output(self, target, args, kwargs): raise RuntimeError("TensorRT requires all outputs to be Tensor!") for i, output in enumerate(outputs): - if any( - op_name in output.name.split("_") - for op_name in ( - "eq", - "gt", - "lt", - "or", - "xor", - "and", - "not", - "ne", - "isinf", - "any", + output_bool = any( + ( + op_name in output.name.split("_") + for op_name in ( + "eq", + "gt", + "lt", + "or", + "xor", + "and", + "not", + "ne", + "isinf", + "any", + ) ) - ): - output_bool = True - else: - output_bool = False + ) name = f"output{i}" output.name = name self.network.mark_output(output) diff --git a/py/torch_tensorrt/fx/input_tensor_spec.py b/py/torch_tensorrt/fx/input_tensor_spec.py index 8128fc1760..3f8bb27aed 100644 --- a/py/torch_tensorrt/fx/input_tensor_spec.py +++ b/py/torch_tensorrt/fx/input_tensor_spec.py @@ -22,18 +22,8 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None): if not isinstance(inputs, torch.Tensor) and len(inputs) > 1: bs = inputs[0].size(0) batch_dims = None - if not all(x.size(0) == bs for x in inputs): + if any(x.size(0) != bs for x in inputs): batch_dims = InputTensorSpec.find_batch_size_dim(inputs) - return InputTensorSpec.from_tensors_with_dynamic_batch_size( - inputs, - ( - 0, - lower_setting.max_batch_size, - lower_setting.max_batch_size, - ), - lower_setting.opt_profile_replica, - batch_dims, - ) else: batch_dims = [] @@ 
-53,16 +43,17 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None): f"Failed to find batch dimension because shapes are the same, {i.shape}" ) - return InputTensorSpec.from_tensors_with_dynamic_batch_size( - inputs, - ( - 0, - lower_setting.max_batch_size, - lower_setting.max_batch_size, - ), - lower_setting.opt_profile_replica, - batch_dims, - ) + + return InputTensorSpec.from_tensors_with_dynamic_batch_size( + inputs, + ( + 0, + lower_setting.max_batch_size, + lower_setting.max_batch_size, + ), + lower_setting.opt_profile_replica, + batch_dims, + ) class InputTensorSpec(NamedTuple): @@ -169,7 +160,12 @@ def from_tensors_with_dynamic_batch_size( batch_dim ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}." shape[batch_dim] = -1 - shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item] + shape_ranges: List[ShapeRange] = [ + tuple( + tuple(shape[:batch_dim] + [bs] + shape[batch_dim + 1 :]) + for bs in batch_size_range + ) + ] * opt_profile_replica input_specs.append( cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges) ) @@ -177,7 +173,6 @@ def from_tensors_with_dynamic_batch_size( return input_specs @classmethod - # pyre-ignore [2]: Parameter `sample_input` must have a type other than `Any` def find_batch_size_dim(cls, inputs: Any) -> []: if isinstance(inputs, torch.Tensor) or len(inputs) <= 1: return [0] @@ -207,12 +202,10 @@ def find_batch_size_dim(cls, inputs: Any) -> []: bs_dim = [] for i in inputs: - # Default batch size dim = -1, indicate no batch_size - dim = -1 - for index, val in enumerate(i.shape): - if val == batch_size: - dim = index - break + dim = next( + (index for index, val in enumerate(i.shape) if val == batch_size), + -1, + ) bs_dim.append(dim) return bs_dim @@ -231,16 +224,8 @@ def to_random_tensor(self, id=1): @staticmethod def create_inputs_from_specs(input_specs: Iterable["InputTensorSpec"]): - inputs = [] - for spec in input_specs: - inputs.append(spec.to_random_tensor()) - - return inputs + return [spec.to_random_tensor() for spec in input_specs] @staticmethod def create_inputs_from_max_specs(input_specs: Iterable["InputTensorSpec"]): - inputs = [] - for spec in input_specs: - inputs.append(spec.to_random_tensor(2)) - - return inputs + return [spec.to_random_tensor(2) for spec in input_specs] diff --git a/py/torch_tensorrt/fx/observer.py b/py/torch_tensorrt/fx/observer.py index 3742bd2840..df8e4e5a9b 100644 --- a/py/torch_tensorrt/fx/observer.py +++ b/py/torch_tensorrt/fx/observer.py @@ -53,13 +53,8 @@ def _add(): try: yield finally: - try: + with contextlib.suppress(ValueError): self._get_callbacks().remove(callback) - except ValueError: - # Callback should be in the callbacks list. I'm just being - # extra cautious here. I don't want it to throw and affect - # business logic. 
- pass return _add() diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index e98a9371c5..2f773649ac 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -170,7 +170,7 @@ def forward(self, x): continue weight_t = weight.transpose(0, 1) - weight_t_name = "weight_t_tensor_" + str(counter) + weight_t_name = f"weight_t_tensor_{str(counter)}" gm.register_buffer(weight_t_name, weight_t) counter += 1 @@ -206,8 +206,8 @@ def trt_transposed_linear( def check_permute(node: torch.fx.Node): ranks = len(node.meta["tensor_meta"].shape) - permutation = list(i % ranks for i in node.kwargs["permutation"]) # type: ignore[union-attr] - allowed_permutation = list(i for i in range(ranks)) + permutation = [i % ranks for i in node.kwargs["permutation"]] + allowed_permutation = list(range(ranks)) allowed_permutation[-1] = ranks - 2 allowed_permutation[-2] = ranks - 1 return permutation == allowed_permutation @@ -336,20 +336,20 @@ def list_gen( dim: int, ): if start_node: - if end_node: - concat_list = [start_node, input_node, end_node] - else: - concat_list = [start_node, input_node] - else: - if end_node: - concat_list = [input_node, end_node] - else: - concat_list = [input_node] - if len(concat_list) > 1: - concat_node = gm.graph.call_function(torch.cat, args=(concat_list, dim)) + concat_list = ( + [start_node, input_node, end_node] + if end_node + else [start_node, input_node] + ) + elif end_node: + concat_list = [input_node, end_node] else: - concat_node = concat_list[0] - return concat_node + concat_list = [input_node] + return ( + gm.graph.call_function(torch.cat, args=(concat_list, dim)) + if len(concat_list) > 1 + else concat_list[0] + ) def transform_setitem(gm: torch.fx.GraphModule, input: Input): @@ -501,9 +501,7 @@ def get_reshape_batch_size_as_node(maybe_reshape: fx.Node) -> Optional[fx.Node]: if not shape: return None batch_size = shape[0] - if isinstance(batch_size, fx.Node): - return batch_size - return None + return batch_size if isinstance(batch_size, fx.Node) else None def get_reshape_batch_size_inferred_source( batch_size_node: fx.Node, @@ -656,9 +654,10 @@ def remove_dtype_and_to_pattern( if "input" in next_node.kwargs else next_node.args[0] ) - if len(node.users) == 1 and ( - next_node.target == acc_ops.to_dtype or next_node.target == "to" - ): + if len(node.users) == 1 and next_node.target in [ + acc_ops.to_dtype, + "to", + ]: next_node.replace_all_uses_with(input) mod.graph.erase_node(next_node) mod.graph.erase_node(node) diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py index 00063c3e21..ea29dbdefb 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py @@ -40,7 +40,7 @@ def replace_inplace_ops( torch.ops.aten.add_.Tensor: torch.ops.aten.add.Tensor, } for n in module.graph.nodes: - if n.op == "call_function" and n.target in map_func.keys(): + if n.op == "call_function" and n.target in map_func: modified = True node = n with module.graph.inserting_after(node): @@ -68,19 +68,18 @@ def replace_native_layernorm_with_layernorm( and n.target == torch.ops.aten.native_layer_norm.default ): for v in n.users: - if v.op == "call_function" and v.target == operator.getitem: - if v.args[1] != 0: - raise RuntimeError( - f"Got args[{v.args[1]}]!!\n" - "layernorm can only generate output (args[0]), " - "not mean (args[1]) or std (args[2])!" 
- ) - new_op = torch.ops.aten.layer_norm.default - new_args = (*n.args, True) # cudnn_enable=True - modified = True - else: + if v.op != "call_function" or v.target != operator.getitem: continue + if v.args[1] != 0: + raise RuntimeError( + f"Got args[{v.args[1]}]!!\n" + "layernorm can only generate output (args[0]), " + "not mean (args[1]) or std (args[2])!" + ) + new_op = torch.ops.aten.layer_norm.default + new_args = (*n.args, True) # cudnn_enable=True + modified = True with module.graph.inserting_after(v): new_node = module.graph.create_node( "call_function", @@ -178,10 +177,10 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.Graph elif n.target == torch.ops.aten.max_pool3d_with_indices.default: new_op = torch.ops.aten.max_pool3d new_args = n.args - elif ( - n.target == torch.ops.aten.native_batch_norm.default - or n.target == torch.ops.aten._native_batch_norm_legit.default - ): + elif n.target in [ + torch.ops.aten.native_batch_norm.default, + torch.ops.aten._native_batch_norm_legit.default, + ]: new_op = torch.ops.aten.batch_norm new_args = list(n.args) new_args.append(False) @@ -310,7 +309,7 @@ def replace_builtin_ops( def aten_compose_getitem_slice(input, list_args): - for _, args in enumerate(list_args): + for args in list_args: input = torch.ops.aten.slice.Tensor(input, *args) return input @@ -333,10 +332,7 @@ def match_pattern(module, node): ): node = next(iter(node.users)) holder.append(node) - if len(holder) == 1: - return (False,) - else: - return (True, holder) + return (False, ) if len(holder) == 1 else (True, holder) return (False,) modified = False @@ -380,8 +376,7 @@ def aten_compose_bmm_2d(flat_args_1, flat_args_2): ) view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_3, sym_size_4]) bmm = torch.ops.aten.bmm.default(view, view_1) - view_2 = torch.ops.aten.view.default(bmm, [sym_size, sym_size_1, sym_size_4]) - return view_2 + return torch.ops.aten.view.default(bmm, [sym_size, sym_size_1, sym_size_4]) def aten_compose_bmm_3d(flat_args_1, flat_args_2): @@ -399,8 +394,7 @@ def aten_compose_bmm_3d(flat_args_1, flat_args_2): ) view_1 = torch.ops.aten.view.default(expand_1, [sym_size, sym_size_3, sym_size_4]) bmm = torch.ops.aten.bmm.default(view, view_1) - view_2 = torch.ops.aten.view.default(bmm, [sym_size, sym_size_1, sym_size_4]) - return view_2 + return torch.ops.aten.view.default(bmm, [sym_size, sym_size_1, sym_size_4]) def compose_bmm( @@ -472,38 +466,33 @@ def compose_chunk( """ def match_pattern(module, node): - if node.op == "call_function" and node.target in (torch.ops.aten.split.Tensor,): - div = node.args[1] - input = node.args[0] - if isinstance(div, int): - return (False,) - if div.target != operator.floordiv: - return (False,) - else: - div_const = div.args[1] - sub = div.args[0] - if sub.target != operator.sub: - return (False,) - else: - add = sub.args[0] - if add.target != operator.add: - return (False,) - else: - add_const = add.args[1] - if add_const != div_const: - return (False,) - symsize = add.args[0] - if symsize.target != torch.ops.aten.sym_size: - return (False,) - else: - symsize_input = symsize.args[0] - dim = symsize.args[1] - if symsize_input != input: - return (False,) - - return (True, div_const, dim) - else: + if node.op != "call_function" or node.target not in ( + torch.ops.aten.split.Tensor, + ): + return (False,) + + div = node.args[1] + input = node.args[0] + if isinstance(div, int): + return (False,) + if div.target != operator.floordiv: + return (False,) + div_const = div.args[1] + sub = 
div.args[0] + if sub.target != operator.sub: + return (False,) + add = sub.args[0] + if add.target != operator.add: + return (False,) + add_const = add.args[1] + if add_const != div_const: + return (False,) + symsize = add.args[0] + if symsize.target != torch.ops.aten.sym_size: return (False,) + symsize_input = symsize.args[0] + dim = symsize.args[1] + return (False, ) if symsize_input != input else (True, div_const, dim) modified = False for node in module.graph.nodes: diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 6e6b40d42f..0978d39620 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -107,10 +107,14 @@ def graph_optimization_pass(self) -> PassManager: passes = [ wrapper(self._trace_func, self._input), ] - for p in self.lower_setting.customized_fuse_pass.passes: - passes.append(wrapper(p, self._input)) - for p in self.lower_setting.lower_basic_fuse_pass.passes: - passes.append(wrapper(p, self._input)) + passes.extend( + wrapper(p, self._input) + for p in self.lower_setting.customized_fuse_pass.passes + ) + passes.extend( + wrapper(p, self._input) + for p in self.lower_setting.lower_basic_fuse_pass.passes + ) if ( hasattr(self.lower_setting, "lower_precision") and self.lower_setting.lower_precision is LowerPrecision.FP16 @@ -121,20 +125,25 @@ def graph_optimization_pass(self) -> PassManager: passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input)) passes.append(inplace_wrapper(common_subexpression_elimination)) - passes.append( - inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input)) + passes.extend( + ( + inplace_wrapper( + lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input) + ), + fix_reshape_batch_dim, + ) ) - passes.append(fix_reshape_batch_dim) - return PassManager.build_from_passlist(passes) def graph_optimization_pass_aten(self) -> PassManager: - passes = [] - - for p in self.lower_setting.customized_fuse_pass.passes: - passes.append(wrapper(p, self._input)) - for p in self.lower_setting.lower_basic_fuse_pass.passes: - passes.append(wrapper(p, self._input)) + passes = [ + wrapper(p, self._input) + for p in self.lower_setting.customized_fuse_pass.passes + ] + passes.extend( + wrapper(p, self._input) + for p in self.lower_setting.lower_basic_fuse_pass.passes + ) # TODO fix this pass for aten graph # if ( # hasattr(self.lower_setting, "lower_precision") @@ -267,45 +276,37 @@ def build_trt_lower_pipeline( ) -> PassManager: self._input = input self._additional_input = additional_input - passes = [] + passes = [self._default_replace_mutable_op_pass()] - passes.append(self._default_replace_mutable_op_pass()) passes.append(self._const_fold_pass()) passes.append(self.graph_optimization_pass()) passes.append(self._split_pass()) passes.append(self._trt_lower_pass()) - pm = PassManager.build_from_passlist(passes) - return pm + return PassManager.build_from_passlist(passes) def build_aten2trt_lower_pipeline( self, input: Input, additional_input: Optional[Input] = None ) -> PassManager: self._input = input self._additional_input = additional_input - passes = [] - passes.append( - wrapper(self._trace_func, self._input), - ) + passes = [wrapper(self._trace_func, self._input)] passes.append(self.graph_optimization_pass_aten()) passes.append(self._split_pass()) passes.append(self._trt_lower_pass()) - pm = PassManager.build_from_passlist(passes) - return pm + return 
PassManager.build_from_passlist(passes) def build_default_lower_pipeline( self, input: Input, additional_input: Optional[Input] = None ) -> PassManager: self._input = input self._additional_input = additional_input - passes = [] + passes = [self._default_replace_mutable_op_pass()] - passes.append(self._default_replace_mutable_op_pass()) passes.append(self._const_fold_pass()) passes.append(self.graph_optimization_pass()) passes.append(self._split_pass()) passes.append(self._default_lower_pass()) - pm = PassManager.build_from_passlist(passes) - return pm + return PassManager.build_from_passlist(passes) diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index 0b8578ffba..855073c5a8 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -257,11 +257,11 @@ def _run_alternative_batch_size(pass_: PassFunc) -> PassFunc: @wraps(pass_) def pass_with_validation( - module: fx.GraphModule, - input: Input, - *args, - **kwargs, - ) -> fx.GraphModule: + module: fx.GraphModule, + input: Input, + *args, + **kwargs, + ) -> fx.GraphModule: _run_alternative_batch_size = ( ALTERNATIVE_BATCH_SIZE_OVERRIDE if ALTERNATIVE_BATCH_SIZE_OVERRIDE is not None @@ -283,7 +283,7 @@ def pass_with_validation( ) return pass_(module, input, *args, **kwargs) - if not all(len(x.shape) > 0 for x in input): + if any(len(x.shape) <= 0 for x in input): _LOGGER.info( "Skip run_alternative_batch_size: some input tensor(s) are scalar" ) @@ -471,8 +471,7 @@ def _need_cast(self, node: Node, run_result) -> None: f"Encountered node: {node.format_node()} need dtype cast to float32." ) self.need_cast_to_float32.append(node) - # Process node that will be used as final output - elif "output" in set(i.name for i in node.users.keys()): + elif "output" in {i.name for i in node.users.keys()}: if run_result.dtype not in (torch.int32, torch.int64): _LOGGER.info( f"Encountered node: {node.format_node()} need dtype cast to bfloat16." 
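(Editor's note: the hunks throughout this patch repeat two Sourcery rewrites — an accumulate-by-append loop collapses into a list comprehension or a `passes.extend(...)` over a generator, and a temporary that is assigned only to be returned is inlined into the `return`. The following is a minimal standalone sketch of both patterns; the function and variable names are illustrative only and do not appear anywhere in this patch.)

def squares_before(values):
    """Accumulate with append, then return through a temporary name."""
    out = []
    for v in values:
        out.append(v * v)
    result = out
    return result

def squares_after(values):
    """The comprehension plus a direct return expresses the same loop in one step."""
    return [v * v for v in values]

def squares_with_header(values):
    """When the list already holds items, extend() over a generator replaces the loop."""
    out = ["header"]
    out.extend(v * v for v in values)
    return out

assert squares_before([2, 3]) == squares_after([2, 3]) == [4, 9]

(Both rewrites preserve behavior only when the loop body has no side effects beyond the append itself, which is the assumption behind the corresponding hunks above and below.)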
diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py index b7b4137e42..ccdf870304 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py @@ -102,14 +102,17 @@ def forward(self, x, y): ) def test_reshape_with_dynamic_shape_mul(self): + + + class TestModule(torch.nn.Module): def forward(self, x, y, z): t = 8000 a = torch.reshape(x, [-1, t, 64]) b = torch.reshape(y, [-1, t, 64]) c = torch.reshape(z, [-1, t, 64]) - d = a + b + c - return d + return a + b + c + input_specs = [ InputTensorSpec( diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py index 1f3a39d836..5ee2a9b647 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py @@ -9,17 +9,20 @@ class TestTypeAsConverter(AccTestCase): def test_device_fp32(self): + + + class Type_as(torch.nn.Module): def __init__(self): super().__init__() self.a = torch.randn(2, 2) def forward(self, x): - b = self.a.type_as(x) - return b + return self.a.type_as(x) + # self.a = self.a.type_as(x) # error is throw + # return self.a + - # self.a = self.a.type_as(x) # error is throw - # return self.a input = torch.randn(2, 2).cuda() inputs = [ diff --git a/py/torch_tensorrt/fx/test/passes/test_fix_clamp_numerical_limits_to_fp16.py b/py/torch_tensorrt/fx/test/passes/test_fix_clamp_numerical_limits_to_fp16.py index 457a9e415a..56c9ace2f5 100644 --- a/py/torch_tensorrt/fx/test/passes/test_fix_clamp_numerical_limits_to_fp16.py +++ b/py/torch_tensorrt/fx/test/passes/test_fix_clamp_numerical_limits_to_fp16.py @@ -23,13 +23,16 @@ def setUp(self): torch.manual_seed(0) def test_clamp_numerical_limits_to_fp16(self): + + + class TestModule(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): - y = torch.clamp(x + x, min=-1e8, max=1e8) - return y + return torch.clamp(x + x, min=-1e8, max=1e8) + module = TestModule() inputs = [torch.rand(3, 2, 1)] diff --git a/py/torch_tensorrt/fx/test/passes/test_graph_opts.py b/py/torch_tensorrt/fx/test/passes/test_graph_opts.py index c91c456eb3..587c9a6e29 100644 --- a/py/torch_tensorrt/fx/test/passes/test_graph_opts.py +++ b/py/torch_tensorrt/fx/test/passes/test_graph_opts.py @@ -172,10 +172,13 @@ def forward(self, a, b, c): ) def test_common_subexpression_elimination_string_arg(self): + + + class TestModule(torch.nn.Module): def forward(self, a): - x = _test_op(["foo", "bar"], a) - return x + return _test_op(["foo", "bar"], a) + self._test_opt_with_module( module=TestModule(), diff --git a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py index 5dc7d8572c..ab580d9b74 100644 --- a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py +++ b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py @@ -15,6 +15,7 @@ class TestFx2TrtPasses(TestCase): def test_remove_duplicate_output_args(self): + class Sub(nn.Module): def forward(self, x): return (x, x) @@ -28,11 +29,12 @@ def forward(self, x): a_res = self.a(x) return a_res[0] + a_res[1] + + class Tracer(fx.Tracer): def is_leaf_module(self, m, qn): - if isinstance(m, Sub): # don't trace into - return True - return False + return isinstance(m, Sub) + top = Top() ttop = fx.GraphModule(top, Tracer().trace(top), "top") diff --git 
a/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py b/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py index cb7ff8f906..d391921ff9 100644 --- a/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py +++ b/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py @@ -8,11 +8,15 @@ class TestTransformSetitem(AccTestCase): def test_setitem1d(self): + + + class TestModule(torch.nn.Module): def forward(self, x, y): - y[0:2] = x + y[:2] = x return y + inputs = [torch.randn(2), torch.randn(3)] m = TestModule() diff --git a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py index a78aaedf2e..21f5d68a2d 100644 --- a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py +++ b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py @@ -52,8 +52,7 @@ def lower_to_trt(model, inputs, shape_ranges): model, input_specs, explicit_batch_dimension=True, explicit_precision=True ) result = interp.run(lower_precision=LowerPrecision.INT8) - trt_mod = TRTModule(result.engine, result.input_names, result.output_names) - return trt_mod + return TRTModule(result.engine, result.input_names, result.output_names) class TestConvertFxDoNotUse(QuantizationTestCase): @@ -454,21 +453,21 @@ def test_conv_relu_module(self): conv3d_input = torch.rand(1, 3, 10, 10, 10) conv_input = {1: conv1d_input, 2: conv2d_input, 3: conv3d_input} + + class ConvNdModule(torch.nn.Module): def __init__(self, dim, has_relu=False, f_relu=False): super().__init__() self.conv = conv_module[dim](3, 3, 3).float() if has_relu: - if f_relu: - self.relu = F.relu - else: - self.relu = torch.nn.ReLU() + self.relu = F.relu if f_relu else torch.nn.ReLU() else: self.relu = torch.nn.Identity() def forward(self, x): return self.relu(self.conv(x)) + # just testing conv2d since conv1d and conv3d are not supported in fx2trt for dim, has_relu, f_relu, is_qat in itertools.product( [1, 2], [True, False], [True, False], [True, False] @@ -494,21 +493,22 @@ def forward(self, x): ) def test_linear_relu_module(self): + + + class LinearModule(torch.nn.Module): def __init__(self, has_relu=False, f_relu=False): super().__init__() self.linear = torch.nn.Linear(5, 10).float() if has_relu: - if f_relu: - self.relu = F.relu - else: - self.relu = torch.nn.ReLU() + self.relu = F.relu if f_relu else torch.nn.ReLU() else: self.relu = torch.nn.Identity() def forward(self, x): return self.relu(self.linear(x)) + linear_input = torch.rand(8, 5) shape_ranges = [((1, 5), (5, 5), (10, 5))] diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index 74715d6030..6559abc37a 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -943,7 +943,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor: is_relu = node.target == acc_ops.relu self.assertTrue(is_sigmoid or is_relu) else: - self.assertTrue(node.op == "placeholder" or node.op == "output") + self.assertTrue(node.op in ["placeholder", "output"]) self.assertTrue(torch.equal(m(input), traced(input))) @@ -1537,10 +1537,13 @@ def test_hardtanh(self): ) def test_hardswish(self): + + + class TestModule(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: - y = nn.functional.hardswish(x) - return y + return nn.functional.hardswish(x) + m = TestModule() x = torch.randn(3, 4, 5) @@ -1551,7 +1554,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ph_x = node elif node.op == "call_function" and node.target == acc_ops.hardsigmoid: hardsigmoid_y = node - 
self.assertEqual(node.kwargs["input"], ph_x) + self.assertEqual(hardsigmoid_y.kwargs["input"], ph_x) elif node.op == "call_function" and node.target == acc_ops.mul: res_y = node self.assertEqual(node.kwargs["input"], hardsigmoid_y) @@ -2258,7 +2261,7 @@ def forward(self, a: List[torch.Tensor]) -> torch.Tensor: self.assertEqual(str(node.target), "a") ph = node elif node.op == "call_function" and node.target == acc_ops.getitem: - self.assertTrue(node.kwargs["idx"] == 0 or node.kwargs["idx"] == 1) + self.assertTrue(node.kwargs["idx"] in [0, 1]) if node.kwargs["idx"] == 0: getitem_0 = node else: @@ -2303,9 +2306,7 @@ def forward(self, a: Dict[str, torch.Tensor]) -> torch.Tensor: self.assertEqual(str(node.target), "a") ph = node elif node.op == "call_function" and node.target == acc_ops.getitem: - self.assertTrue( - node.kwargs["idx"] == "foo" or node.kwargs["idx"] == "bar" - ) + self.assertTrue(node.kwargs["idx"] in ["foo", "bar"]) if node.kwargs["idx"] == "foo": getitem_0 = node else: diff --git a/py/torch_tensorrt/fx/test/trt_lower/test_fx2trt_lower.py b/py/torch_tensorrt/fx/test/trt_lower/test_fx2trt_lower.py index 7868ca40ad..ccb91d8447 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/test_fx2trt_lower.py +++ b/py/torch_tensorrt/fx/test/trt_lower/test_fx2trt_lower.py @@ -56,24 +56,30 @@ def forward(self, x): lower(TestModule(), [torch.randn([2, 2])]) def test_replace_mutable_op(self): + + + class TestModule(torch.nn.Module): def forward(self, x, y): xf = x.fill_(100) yf = y.fill_(200) - c = torch.cat([xf, yf], dim=1) - return c + return torch.cat([xf, yf], dim=1) + lower = Lowerer.create(LowerSetting()) mod_traced = fx.symbolic_trace(TestModule()) lower(mod_traced, [torch.randn(3, 4), torch.randn(3, 4)]) def test_replace_mutable_op_dont_apply(self): + + + class TestModule(torch.nn.Module): def forward(self, x): s = x + 1 t = s.fill_(5) - p = s + t - return p + return s + t + mod_traced = fx.symbolic_trace(TestModule()) old_code = mod_traced.code @@ -86,12 +92,15 @@ def forward(self, x): self.assertEqual(old_code, new_code) def test_replace_mutable_op_do_apply(self): + + + class TestModule(torch.nn.Module): def forward(self, x): s = x + 1 t = s.fill_(5) # s not used afterwards - p = x + t - return p + return x + t + mod_traced = fx.symbolic_trace(TestModule()) old_code = mod_traced.code diff --git a/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py b/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py index 96584c59bd..59588b9718 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py +++ b/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py @@ -78,12 +78,14 @@ def test_demo(self): ==> c ==> """ + + class SimpleModule(torch.nn.Module): def forward(self, a): b = torch.sin(a) c = torch.cos(a) - d = b + c - return d + return b + c + mod = acc_tracer.trace(SimpleModule(), [torch.randn(2, 3)]) @@ -632,10 +634,12 @@ def test_splitter(splitter): def test_decline_if_input_dtype(self): operator_support = create_trt_operator_support() + + class TestModule(torch.nn.Module): def forward(self, a): - b = torch.relu(a) - return b + return torch.relu(a) + test_mod = TestModule().cuda().eval() x = torch.randn(2, 3) @@ -692,8 +696,7 @@ def forward(self, a): c = torch.relu(a) d = torch.cos(a) e = b + c - f = e - d - return f + return e - d def test_split_complex_graph_1(self): mod = acc_tracer.trace(self.TestModule(), [torch.randn(2, 3)]) @@ -806,8 +809,7 @@ def forward(self, x): c = torch.cos(a) d = b2 + c - e = torch.sigmoid(d) - return e + return torch.sigmoid(d) def 
test_split_non_tensor_edges_1(self): test_data = torch.randn(2, 3) @@ -1009,6 +1011,8 @@ def test_acc_nodes_finder_1(self): """ # Make a return non-tensor data + + class TestModule(torch.nn.Module): def forward(self, x, y, z): a1 = x.size() @@ -1017,9 +1021,8 @@ def forward(self, x, y, z): b = y + a1 c = z - a1 - d = b + c + return b + c - return d module_nn = TestModule() module_fx = torch.fx.symbolic_trace(module_nn) @@ -1175,18 +1178,20 @@ def test_splitter(splitter): test_splitter(splitter) def test_exclude_support_node_by_name(self): + + + class TestModule(torch.nn.Module): def forward(self, a): b = torch.sin(a) c = torch.relu(b) d = torch.cos(c) e = torch.sigmoid(d) - f = torch.tanh(e) - return f + return torch.tanh(e) + mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) - # Set sin, cos and tanh as acc node and split with settings class CustomOpSupport(op_support.OperatorSupport): _support_dict = { "acc_ops.sin": None, diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py index 6d883a4f62..1ebc9fb478 100644 --- a/py/torch_tensorrt/fx/tools/common_fx2trt.py +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -72,10 +72,7 @@ def run_test( precision=LowerPrecision.FP32, ): with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - + cuda_inputs = [i.cuda() for i in inputs] mod.eval() if len(expected_ops): self.assert_has_op(mod, expected_ops) @@ -145,10 +142,7 @@ def run_test_custom_compare_results( """ with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - + cuda_inputs = [i.cuda() for i in inputs] mod.eval() if len(expected_ops): self.assert_has_op(mod, expected_ops) @@ -177,10 +171,7 @@ def run_test_custom_compare_results( def run_test_with_error(self, mod, inputs, interpreter, expect_error): with self.assertRaises(expect_error): with torch.no_grad(): - cuda_inputs = [] - for i in inputs: - cuda_inputs.append(i.cuda()) - + cuda_inputs = [i.cuda() for i in inputs] mod.eval() interpreter.run(lower_precision=LowerPrecision.FP32) diff --git a/py/torch_tensorrt/fx/tools/engine_layer_visualize.py b/py/torch_tensorrt/fx/tools/engine_layer_visualize.py index cecd1ecb20..c8f8529264 100644 --- a/py/torch_tensorrt/fx/tools/engine_layer_visualize.py +++ b/py/torch_tensorrt/fx/tools/engine_layer_visualize.py @@ -89,28 +89,27 @@ def build_edge(layer, graph, reformat_layers, output_name2node, layer_name2node) return for input_name, input_type in zip(layer.input_names, layer.input_types): - if input_name not in output_name2node: - if input_name in reformat_layers: - from_node = pydot.Node( - input_name, - label="{reformatter|kernel: Reformat\\l|tactic: 0\\l}", - **style, - ) - graph.add_node(from_node) - if reformat_layers[input_name][0] in output_name2node: - graph.add_edge( - pydot.Edge( - output_name2node[reformat_layers[input_name][0]], - from_node, - label=f"{reformat_layers[input_name][0]}\\l{reformat_layers[input_name][1]}\\l", - ) - ) - else: - _LOGGER.info(f"Missing node {input_name}") - from_node = input_name - else: + if input_name in output_name2node: from_node = output_name2node[input_name] + elif input_name in reformat_layers: + from_node = pydot.Node( + input_name, + label="{reformatter|kernel: Reformat\\l|tactic: 0\\l}", + **style, + ) + graph.add_node(from_node) + if reformat_layers[input_name][0] in output_name2node: + graph.add_edge( + pydot.Edge( + output_name2node[reformat_layers[input_name][0]], + from_node, + 
label=f"{reformat_layers[input_name][0]}\\l{reformat_layers[input_name][1]}\\l", + ) + ) + else: + _LOGGER.info(f"Missing node {input_name}") + from_node = input_name edge_name = input_name.replace(">", "\\>") graph.add_edge( pydot.Edge( @@ -184,8 +183,7 @@ def build_edge(layer, graph, reformat_layers, output_name2node, layer_name2node) } dot_graphs: List[Any] = [] - i = 0 - for layers, reformat_layers in graphs: + for i, (layers, reformat_layers) in enumerate(graphs): output_name2node = {} layer_name2node = {} dot_graph = pydot.Dot("Layer Graph") @@ -202,8 +200,6 @@ def build_edge(layer, graph, reformat_layers, output_name2node, layer_name2node) ) dot_graph.write_raw(f"EngineLayers_{i}.dot") - i += 1 - if args.profile_file != "": est_reformat_time = 0.0 est_total_time = 0.0 diff --git a/py/torch_tensorrt/fx/tools/model_packager.py b/py/torch_tensorrt/fx/tools/model_packager.py index b86c21e809..157e05cfed 100644 --- a/py/torch_tensorrt/fx/tools/model_packager.py +++ b/py/torch_tensorrt/fx/tools/model_packager.py @@ -64,10 +64,8 @@ def generate_standalone_repro( if k in line: sub_string = line.split("(")[0].split()[-1] if sub_string.startswith(k): - mod = sub_string.replace(k + "_", "") - import_modules.add( - "from " + v + " import " + mod + " as " + sub_string - ) + mod = sub_string.replace(f"{k}_", "") + import_modules.add(f"from {v} import {mod} as {sub_string}") for mod in sorted(import_modules): lines.append(mod) diff --git a/py/torch_tensorrt/fx/tools/timing_cache_utils.py b/py/torch_tensorrt/fx/tools/timing_cache_utils.py index 4580843e98..0c89612567 100644 --- a/py/torch_tensorrt/fx/tools/timing_cache_utils.py +++ b/py/torch_tensorrt/fx/tools/timing_cache_utils.py @@ -9,7 +9,7 @@ def __init__(self, timing_cache_prefix: str = "", save_timing_cache=False): # Setting timing cache for TRTInterpreter tc = os.environ.get("TRT_TIMING_CACHE_PREFIX", "") timing_cache_prefix_name = timing_cache_prefix - if not timing_cache_prefix and tc: + if not timing_cache_prefix_name and tc: timing_cache_prefix_name = tc self.timing_cache_prefix_name = timing_cache_prefix_name diff --git a/py/torch_tensorrt/fx/tools/trt_minimizer.py b/py/torch_tensorrt/fx/tools/trt_minimizer.py index 1c14b289cf..aa008aa644 100644 --- a/py/torch_tensorrt/fx/tools/trt_minimizer.py +++ b/py/torch_tensorrt/fx/tools/trt_minimizer.py @@ -30,7 +30,7 @@ def lower_mod_default( engine_bytes.write(interpreter_result.engine.serialize()) engine_str = engine_bytes.getvalue() - res_mod = TorchTensorRTModule( + return TorchTensorRTModule( engine_str, name=str(type(mod)), input_binding_names=interpreter_result.input_names, @@ -39,12 +39,11 @@ def lower_mod_default( # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do ) else: - res_mod = TRTModule( + return TRTModule( interpreter_result.engine, interpreter_result.input_names, interpreter_result.output_names, ) - return res_mod class TensorRTMinizerSetting(net_min_base._MinimizerSettingBase): diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py index 6fcb40c0d8..70bc09a061 100644 --- a/py/torch_tensorrt/fx/tools/trt_splitter.py +++ b/py/torch_tensorrt/fx/tools/trt_splitter.py @@ -21,15 +21,14 @@ def create_trt_operator_support( exclude_support_node_name: set = (), ) -> ops.OperatorSupportBase: """Creates an `OperatorSupportBase` instance used for TRT splitting purpose.""" - # Create an `OperatorSupport` that declares a node supported if it - # finds a registered TRT converter. 
- support_dict: Dict[str, None] = {} - for k in CONVERTERS.keys(): - if use_implicit_batch_dim: - if k not in NO_IMPLICIT_BATCH_DIM_SUPPORT.keys(): - support_dict[get_acc_ops_name(k)] = None - elif k not in NO_EXPLICIT_BATCH_DIM_SUPPORT.keys(): - support_dict[get_acc_ops_name(k)] = None + support_dict: Dict[str, None] = { + get_acc_ops_name(k): None + for k in CONVERTERS.keys() + if use_implicit_batch_dim + and k not in NO_IMPLICIT_BATCH_DIM_SUPPORT.keys() + or not use_implicit_batch_dim + and k not in NO_EXPLICIT_BATCH_DIM_SUPPORT.keys() + } supported_if_converter_registered = ops.OperatorSupport(support_dict=support_dict) return ops.chain( diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py index 1271b6f30c..f21487f37b 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py @@ -152,9 +152,10 @@ def _get_dup_signature_tuples(fn: Callable) -> List[Tuple[str, str]]: Helper that inspects the arg signature of `fn` and returns a list of tuples, where each tuple is a pair of duplicated names which is used for arg_replacement_tuples. """ - sig_tuples: List[Tuple[str, str]] = [] - for param in inspect.signature(inspect.unwrap(fn)).parameters: - sig_tuples.append((param, param)) + sig_tuples: List[Tuple[str, str]] = [ + (param, param) + for param in inspect.signature(inspect.unwrap(fn)).parameters + ] return sig_tuples @@ -327,9 +328,7 @@ def get_normalized_kwargs( if final_arg_is_varg: var_arg_idx = len(arg_replacement_tuples) - 1 new_kwarg_name = arg_replacement_tuples[var_arg_idx][1] - rest_of_args = [] - for i in range(var_arg_idx, len(node.args)): - rest_of_args.append(node.args[i]) + rest_of_args = [node.args[i] for i in range(var_arg_idx, len(node.args))] new_kwargs[new_kwarg_name] = rest_of_args return new_kwargs diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index 1ed25d66f1..6a7a2808cd 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -82,9 +82,7 @@ def flatten(*, input, start_dim=0, end_dim=-1): ) @register_acc_op def squeeze(*, input, dim=None): - if dim is None: - return input.squeeze() - return input.squeeze(dim=dim) + return input.squeeze() if dim is None else input.squeeze(dim=dim) @register_acc_op_mapping(op_and_target=("call_function", nn.functional.embedding)) @@ -383,15 +381,17 @@ def custom_getattr_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: torch.Tensor, torch.nn.parameter.Parameter, ], f"Expected torch.Tensor type for {input_obj_type}" - assert ( - attr_name == "shape" or attr_name == "device" or attr_name == "dtype" - ), f"Only supporting shape, device and dtype getattr for now, not {attr_name}" - if attr_name == "shape": - func = size - elif attr_name == "device": + assert attr_name in [ + "shape", + "device", + "dtype", + ], f"Only supporting shape, device and dtype getattr for now, not {attr_name}" + if attr_name == "device": func = device elif attr_name == "dtype": func = dtype + elif attr_name == "shape": + func = size with node.graph.inserting_before(node): size_node = node.graph.call_function(func, kwargs={"input": input_obj}) size_node.meta = node.meta.copy() @@ -516,7 +516,7 @@ def repeat_interleave_mapper(node: torch.fx.Node, _: nn.Module): input_node = node.kwargs["input"] repeats = cast(int, node.kwargs["repeats"]) dim = node.kwargs["dim"] - if not 
(type(repeats) is int): + if type(repeats) is not int: logger.info( "Not mapping repeat_interleave to an acc op. We currently only support `repeat_interleave` with int repeats" ) @@ -545,10 +545,7 @@ def repeat_interleave_mapper(node: torch.fx.Node, _: nn.Module): ) new_shape = [] if dim is not None: - if dim < 0: - repeat_dim = dim + rank - else: - repeat_dim = dim + repeat_dim = dim + rank if dim < 0 else dim size_node = node.graph.create_node( "call_function", size, @@ -2509,9 +2506,7 @@ def slice_tensor(*, input, dim, start, stop, step): slices: List[slice] = [slice(None, None, None) for _ in range(dim)] slices.append(slc) else: - slices = [Ellipsis, slc] # type: ignore[list-item] - slices.extend([slice(None, None, None) for _ in range(-dim - 1)]) - + slices = [Ellipsis, slc, *[slice(None, None, None) for _ in range(-dim - 1)]] return input[tuple(slices)] @@ -2713,13 +2708,9 @@ def custom_tensor_to_mapper(node: torch.fx.Node, _: nn.Module): raise RuntimeError(f"We currently do not support to({meta_type})") elif isinstance(dest, torch.device): # only device is set, dtype=None - if dest_other is None: - dest_device = dest - # device and dtype are both set - else: + if dest_other is not None: dest_dtype = dest_other - dest_device = dest - # only dtype is set + dest_device = dest else: dest_dtype = dest @@ -2764,7 +2755,7 @@ def custom_torch_add_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Nod "input": node.kwargs["other"], "other": node.kwargs["alpha"], }, - name=node.name + "_mul_alpha", + name=f"{node.name}_mul_alpha", ) other_node.meta = node.meta else: diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py index 96411246f0..40517f46dc 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py @@ -91,7 +91,7 @@ def run_node_with_xl_weights(self, n: torch.fx.Node) -> Any: xl_weights instead of just the ones treated here. """ - op = n.target.__module__ + "." 
+ n.target.__name__ + op = f"{n.target.__module__}.{n.target.__name__}" if op.endswith("acc_ops.int_nbit_split_embedding_codegen_lookup_function"): output_dtype_int = n.kwargs["output_dtype"] diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py index bc8c613fee..1d3f920c4c 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py @@ -327,16 +327,12 @@ def create_node( """ ## Hacky way to decide inplace ops - if type(target) != str: - name_target = target.__name__ - else: - name_target = target - + name_target = target.__name__ if type(target) != str else target allow_list = ["and_", "or_"] # python operator.and_, operator.or_ if ( name_target[-1] == "_" and name_target[0] != "_" - and not (name_target in allow_list) + and name_target not in allow_list and kind != "placeholder" ): raise RuntimeError( @@ -490,9 +486,12 @@ def _remove_exceptions(gm: torch.fx.GraphModule) -> bool: changed = False for node in reversed(gm.graph.nodes): if node.op == "call_module" and ( - isinstance(gm.get_submodule(node.target), ConditionalExceptionWrapper) - or isinstance( - gm.get_submodule(node.target), ConditionalExceptionBoolCondWrapper + isinstance( + gm.get_submodule(node.target), + ( + ConditionalExceptionWrapper, + ConditionalExceptionBoolCondWrapper, + ), ) ): gm.graph.erase_node(node) @@ -574,8 +573,7 @@ def _replace_transpose_last_dims(gm: torch.fx.GraphModule): if node.op == "call_method" and node.target == "transpose": if len(node.args) != 3: continue - changed = _replace_transpose_last_dims_impl(node) - if changed: + if changed := _replace_transpose_last_dims_impl(node): gm.graph.eliminate_dead_code() gm.graph.lint() gm.recompile() diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py index 75418034cb..efc964ed62 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py @@ -20,7 +20,7 @@ def get_target_from_module(mod: torch.nn.Module, target: str): """ Gets `target` from `mod` and returns it. If `target` is empty then returns `mod.` """ - if target == "": + if not target: return mod target_atoms = target.split(".") @@ -171,7 +171,7 @@ def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) while hasattr(mod_traced, name): match = re.match(r"(.*)_(\d+)$", name) if match is None: - name = name + "_1" + name = f"{name}_1" else: base, num = match.group(1, 2) name = f"{base}_{int(num) + 1}" @@ -184,9 +184,7 @@ def map_tensor_metadata(a: Any, fn: Callable): Map some `fn` to `a`, where `a` is either a TensorMetadata, or else a tuple/list/dict recursively containing TensorMetadata. """ - if isinstance(a, int): - return 1 - elif a is None: + if isinstance(a, int) or a is None: return 1 elif isinstance(a, TensorMetadata): return fn(a) @@ -203,11 +201,10 @@ def map_tensor_metadata(a: Any, fn: Callable): def get_tensor_meta(node: torch.fx.Node) -> TensorMetadata: - tensor_meta = node.meta.get("tensor_meta") - - if not tensor_meta: + if tensor_meta := node.meta.get("tensor_meta"): + return tensor_meta + else: raise RuntimeError( f"Node has no tensor metadata associated with it! " f"Check that shape propagation has run. 
{node.format_node()}" ) - return tensor_meta diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tracer.py index f3ba9abe3f..6703be4bf8 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/tracer.py @@ -47,7 +47,8 @@ def wrap_with_proxy(e, proxy): if isinstance(real_out, tuple): return tuple( - [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)] + wrap_with_proxy(e, proxy_out[idx]) + for idx, e in enumerate(real_out) ) elif isinstance(real_out, list): return [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)] @@ -149,22 +150,22 @@ def _module_getattr(self, attr, attr_val, parameter_proxy_cache): return attr_val def create_arg(self, a: Any): - if isinstance(a, torch.nn.Parameter): - for n, p in self.root.named_parameters(): - if a is p: - return self.create_node("get_attr", n, (), {}) - qualname: Optional[str] = None - - i = 0 - while True: - qualname = f"_param_constant{i}" - if not hasattr(self.root, qualname): - break - i += 1 - setattr(self.root, qualname, a) - - return self.create_node("get_attr", qualname, (), {}) - return super().create_arg(a) + if not isinstance(a, torch.nn.Parameter): + return super().create_arg(a) + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + qualname: Optional[str] = None + + i = 0 + while True: + qualname = f"_param_constant{i}" + if not hasattr(self.root, qualname): + break + i += 1 + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) def dispatch_trace( diff --git a/py/torch_tensorrt/fx/trt_module.py b/py/torch_tensorrt/fx/trt_module.py index ab2d9ac348..f69e53b0e0 100644 --- a/py/torch_tensorrt/fx/trt_module.py +++ b/py/torch_tensorrt/fx/trt_module.py @@ -93,10 +93,10 @@ def _check_initialized(self): def _on_state_dict(self, state_dict, prefix, local_metadata): self._check_initialized() - state_dict[prefix + "engine"] = bytearray(self.engine.serialize()) - state_dict[prefix + "input_names"] = self.input_names - state_dict[prefix + "output_names"] = self.output_names - state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size + state_dict[f"{prefix}engine"] = bytearray(self.engine.serialize()) + state_dict[f"{prefix}input_names"] = self.input_names + state_dict[f"{prefix}output_names"] = self.output_names + state_dict[f"{prefix}cuda_graph_batch_size"] = self.cuda_graph_batch_size def _load_from_state_dict( self, @@ -108,14 +108,14 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ): - engine_bytes = state_dict[prefix + "engine"] + engine_bytes = state_dict[f"{prefix}engine"] logger = trt.Logger() runtime = trt.Runtime(logger) self.engine = runtime.deserialize_cuda_engine(engine_bytes) - self.input_names = state_dict[prefix + "input_names"] - self.output_names = state_dict[prefix + "output_names"] + self.input_names = state_dict[f"{prefix}input_names"] + self.output_names = state_dict[f"{prefix}output_names"] self._initialize() def __getstate__(self): @@ -212,10 +212,7 @@ def forward(self, *inputs): bindings, torch.cuda.current_stream().cuda_stream ) - if len(outputs) == 1: - return outputs[0] - - return tuple(outputs) + return outputs[0] if len(outputs) == 1 else tuple(outputs) def enable_profiling(self, profiler: "trt.IProfiler" = None): """ diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py index e70fc862d0..892dd75834 100644 --- 
a/py/torch_tensorrt/fx/utils.py +++ b/py/torch_tensorrt/fx/utils.py @@ -63,13 +63,13 @@ class LowerPrecision(Enum): @staticmethod def from_str(label: str) -> Optional["LowerPrecision"]: - if label in ("fp32", "float32", "float", "torch.float32"): + if label in {"fp32", "float32", "float", "torch.float32"}: return LowerPrecision.FP32 - elif label in ("fp16", "float16", "half", "torch.half", "torch.float16"): + elif label in {"fp16", "float16", "half", "torch.half", "torch.float16"}: return LowerPrecision.FP16 elif label in ("int8"): return LowerPrecision.INT8 - elif label in ("bf16", "bfloat16", "torch.bfloat16"): + elif label in {"bf16", "bfloat16", "torch.bfloat16"}: return LowerPrecision.BF16 else: return None @@ -101,7 +101,7 @@ def unified_dtype_converter( elif dtype in (np.float32, torch.float32, trt.float32): return DataTypeEquivalence[trt.float32][to] else: - raise TypeError("%s is not a supported dtype" % dtype) + raise TypeError(f"{dtype} is not a supported dtype") def get_dynamic_dims(shape: Shape) -> List[int]: @@ -117,13 +117,7 @@ def get_dynamic_dims(shape: Shape) -> List[int]: A list of integers contains all the dynamic dimensions in the given shape """ - dynamic_dims = [] - - for i, s in enumerate(shape): - if s == -1: - dynamic_dims.append(i) - - return dynamic_dims + return [i for i, s in enumerate(shape) if s == -1] def proxytensor_trace(mod, inputs): diff --git a/py/torch_tensorrt/ptq.py b/py/torch_tensorrt/ptq.py index f60dd74b52..051dc2ba03 100644 --- a/py/torch_tensorrt/ptq.py +++ b/py/torch_tensorrt/ptq.py @@ -32,8 +32,7 @@ def get_batch(self, names): self.current_batch_idx += self.batch_size inputs_gpu = [] if isinstance(batch, list): - for example in batch: - inputs_gpu.append(example.to(self.device).data_ptr()) + inputs_gpu.extend(example.to(self.device).data_ptr() for example in batch) else: inputs_gpu.append(batch.to(self.device).data_ptr()) return inputs_gpu @@ -88,26 +87,20 @@ def __new__(cls, *args, **kwargs): if not isinstance(dataloader, torch.utils.data.DataLoader): log( Level.Error, - "Dataloader : {} is not a valid instance of torch.utils.data.DataLoader".format( - dataloader - ), + f"Dataloader : {dataloader} is not a valid instance of torch.utils.data.DataLoader", ) - if not cache_file: - if use_cache: - log( - Level.Debug, - "Using existing cache_file {} for calibration".format(cache_file), - ) - else: - log(Level.Debug, "Overwriting existing calibration cache file.") - else: + if cache_file: if use_cache: log( Level.Error, "Input cache file is None but use_cache is set to True in INT8 mode.", ) + elif use_cache: + log(Level.Debug, f"Using existing cache_file {cache_file} for calibration") + else: + log(Level.Debug, "Overwriting existing calibration cache file.") # Define attributes and member functions for the calibrator class attribute_mapping = { "data_loader": dataloader, @@ -166,10 +159,7 @@ def __new__(cls, *args, **kwargs): algo_type = kwargs.get("algo_type", CalibrationAlgo.ENTROPY_CALIBRATION_2) if os.path.isfile(cache_file): - log( - Level.Debug, - "Using existing cache_file {} for calibration".format(cache_file), - ) + log(Level.Debug, f"Using existing cache_file {cache_file} for calibration") else: log(Level.Error, "Invalid calibration cache file.") @@ -187,7 +177,11 @@ def __new__(cls, *args, **kwargs): return type( "DataLoaderCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping )() - elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2: + elif ( + algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2 + or algo_type != 
CalibrationAlgo.LEGACY_CALIBRATION + and algo_type == CalibrationAlgo.MINMAX_CALIBRATION + ): return type( "DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping )() @@ -195,10 +189,6 @@ def __new__(cls, *args, **kwargs): return type( "DataLoaderCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping )() - elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION: - return type( - "DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping - )() else: log( Level.Error, diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 08f18a22dd..a81794e52a 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -27,11 +27,7 @@ def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt. def _supported_input_size_type(input_size: Any) -> bool: - if isinstance(input_size, torch.Size): - return True - elif isinstance(input_size, tuple): - return True - elif isinstance(input_size, list): + if isinstance(input_size, (torch.Size, tuple, list)): return True else: raise TypeError( @@ -66,7 +62,7 @@ def _parse_op_precision(precision: Any) -> _enums.dtype: def _parse_enabled_precisions(precisions: Any) -> Set: parsed_precisions = set() - if any([isinstance(precisions, type) for type in [list, tuple, set]]): + if any(isinstance(precisions, type) for type in [list, tuple, set]): for p in precisions: parsed_precisions.add(_parse_op_precision(p)) else: @@ -87,18 +83,14 @@ def _parse_device_type(device: Any) -> _enums.DeviceType: elif isinstance(device, _C.DeviceType): return device elif isinstance(device, trt.DeviceType): - if device == trt.DeviceType.DLA: - return _C.DeviceType.DLA - return _C.DeviceType.GPU + return _C.DeviceType.DLA if device == trt.DeviceType.DLA else _C.DeviceType.GPU elif isinstance(device, str): - if device == "gpu" or device == "GPU": + if device in ["gpu", "GPU"]: return _C.DeviceType.GPU - elif device == "dla" or device == "DLA": + elif device in ["dla", "DLA"]: return _C.DeviceType.DLA else: - ValueError( - "Got a device type other than GPU or DLA (type: " + str(device) + ")" - ) + raise ValueError(f"Got a device type other than GPU or DLA (type: {str(device)})") else: raise TypeError( "Device specification must be of type torch.device, string or torch_tensorrt.DeviceType, but got: " @@ -141,9 +133,8 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback: info = _ts_C.TorchFallback() if "enabled" not in fallback_info: raise KeyError("Enabled is required parameter") - else: - assert isinstance(fallback_info["enabled"], bool) - info.enabled = fallback_info["enabled"] + assert isinstance(fallback_info["enabled"], bool) + info.enabled = fallback_info["enabled"] if "min_block_size" in fallback_info: assert isinstance(fallback_info["min_block_size"], int) info.min_block_size = fallback_info["min_block_size"] @@ -177,9 +168,7 @@ def _parse_input_signature(input_signature: Any, depth: int = 0): input = _parse_input_signature(item, depth + 1) input_list.append(input) return input_list - elif isinstance(input_signature, Input) or isinstance( - input_signature, torch.Tensor - ): + elif isinstance(input_signature, (Input, torch.Tensor)): i = ( Input.from_tensor(input_signature) if isinstance(input_signature, torch.Tensor) @@ -209,13 +198,10 @@ def _parse_input_signature(input_signature: Any, depth: int = 0): "Invalid shape mode detected for input while parsing the input_signature" ) - clone = _internal_input_to_torch_class_input(ts_i._to_internal())
- return clone + return _internal_input_to_torch_class_input(ts_i._to_internal()) else: raise KeyError( - "Input signature contains an unsupported type {}".format( - type(input_signature) - ) + f"Input signature contains an unsupported type {type(input_signature)}" ) @@ -226,15 +212,11 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec: if len(compile_spec["inputs"]) > 0: if not all( - [ - isinstance(i, torch.Tensor) or isinstance(i, Input) - for i in compile_spec["inputs"] - ] + isinstance(i, (torch.Tensor, Input)) + for i in compile_spec["inputs"] ): raise KeyError( - "Input specs should be either torch_tensorrt.Input or torch.Tensor, found types: {}".format( - [type(i) for i in compile_spec["inputs"]] - ) + f'Input specs should be either torch_tensorrt.Input or torch.Tensor, found types: {[type(i) for i in compile_spec["inputs"]]}' ) inputs = [ diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py index 9dc0731014..7adff57d48 100644 --- a/py/torch_tensorrt/ts/_compiler.py +++ b/py/torch_tensorrt/ts/_compiler.py @@ -137,8 +137,7 @@ def compile( } compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec)) - compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod) - return compiled_module + return torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod) def convert_method_to_trt_engine( diff --git a/py/torch_tensorrt/ts/ts_input.py b/py/torch_tensorrt/ts/ts_input.py index 00055d4f13..5858af9ca9 100644 --- a/py/torch_tensorrt/ts/ts_input.py +++ b/py/torch_tensorrt/ts/ts_input.py @@ -95,11 +95,7 @@ def _to_internal(self) -> _C.Input: internal_in.opt = self.shape internal_in.input_is_dynamic = False - if self.dtype != _enums.dtype.unknown: - self._explicit_set_dtype = True - else: - self._explicit_set_dtype = False - + self._explicit_set_dtype = self.dtype != _enums.dtype.unknown internal_in.dtype = Input._parse_dtype(self.dtype) internal_in._explicit_set_dtype = self._explicit_set_dtype internal_in.format = Input._parse_format(self.format) diff --git a/tests/modules/custom_models.py b/tests/modules/custom_models.py index 327fbc3fb9..ecbb9715e9 100644 --- a/tests/modules/custom_models.py +++ b/tests/modules/custom_models.py @@ -43,7 +43,7 @@ def __init__(self): def forward(self, x): add_list = torch.empty(0).to(x.device) - for i in range(x.shape[1]): + for _ in range(x.shape[1]): add_list = torch.cat((add_list, torch.tensor([x.shape[1]]).to(x.device)), 0) return x + add_list @@ -91,8 +91,7 @@ def forward(self, x, y): mod_list = [x] if x.sum() > y.sum(): mod_list.append(y) - z = torch.cat(mod_list) - return z + return torch.cat(mod_list) # Collection input/output models @@ -101,8 +100,7 @@ def __init__(self): super(StandardTensorInput, self).__init__() def forward(self, x, y): - r = x + y - return r + return x + y class TupleInput(nn.Module): @@ -110,8 +108,7 @@ def __init__(self): super(TupleInput, self).__init__() def forward(self, z: Tuple[torch.Tensor, torch.Tensor]): - r = z[0] + z[1] - return r + return z[0] + z[1] class ListInput(nn.Module): @@ -119,8 +116,7 @@ def __init__(self): super(ListInput, self).__init__() def forward(self, z: List[torch.Tensor]): - r = z[0] + z[1] - return r + return z[0] + z[1] class TupleInputOutput(nn.Module): @@ -131,8 +127,7 @@ def forward(self, z: Tuple[torch.Tensor, torch.Tensor]): r1 = z[0] + z[1] r2 = z[0] - z[1] r1 = r1 * 10 - r = (r1, r2) - return r + return r1, r2 class ListInputOutput(nn.Module): @@ -142,8 +137,7 @@ def __init__(self): def forward(self, z: 
List[torch.Tensor]): r1 = z[0] + z[1] r2 = z[0] - z[1] - r = [r1, r2] - return r + return [r1, r2] class ListInputTupleOutput(nn.Module): @@ -159,8 +153,7 @@ def forward(self, z: List[torch.Tensor]): r4 = [r2, r1] tuple_out = self.tuple_model(r3) list_out = self.list_model(r4) - r = (tuple_out[1], list_out[0]) - return r + return tuple_out[1], list_out[0] def BertModule(): @@ -185,5 +178,4 @@ def BertModule(): model = BertModel(config) model.eval() model = BertModel.from_pretrained(model_name, torchscript=True) - traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) - return traced_model + return torch.jit.trace(model, [tokens_tensor, segments_tensors]) diff --git a/tests/modules/hub.py b/tests/modules/hub.py index f4f68ffa99..14c1298cb9 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -71,9 +71,9 @@ def get(n, m, manifest): - print("Downloading {}".format(n)) - traced_filename = n + "_traced.jit.pt" - script_filename = n + "_scripted.jit.pt" + print(f"Downloading {n}") + traced_filename = f"{n}_traced.jit.pt" + script_filename = f"{n}_scripted.jit.pt" x = torch.ones((1, 3, 300, 300)).cuda() if n == "bert-base-uncased": traced_model = m["model"] @@ -81,11 +81,11 @@ def get(n, m, manifest): manifest.update({n: [traced_filename]}) else: m["model"] = m["model"].eval().cuda() - if m["path"] == "both" or m["path"] == "trace": + if m["path"] in ["both", "trace"]: trace_model = torch.jit.trace(m["model"], [x]) torch.jit.save(trace_model, traced_filename) manifest.update({n: [traced_filename]}) - if m["path"] == "both" or m["path"] == "script": + if m["path"] in ["both", "script"]: script_model = torch.jit.script(m["model"]) torch.jit.save(script_model, script_filename) if n in manifest.keys(): @@ -104,8 +104,8 @@ def download_models(version_matches, manifest): manifest = get(n, m, manifest) else: for n, m in models.items(): - scripted_filename = n + "_scripted.jit.pt" - traced_filename = n + "_traced.jit.pt" + scripted_filename = f"{n}_scripted.jit.pt" + traced_filename = f"{n}_traced.jit.pt" # Check if model file exists on disk if ( ( @@ -116,7 +116,7 @@ def download_models(version_matches, manifest): or (m["path"] == "script" and os.path.exists(scripted_filename)) or (m["path"] == "trace" and os.path.exists(traced_filename)) ): - print("Skipping {} ".format(n)) + print(f"Skipping {n} ") continue manifest = get(n, m, manifest) @@ -131,7 +131,7 @@ def main(): manifest = {"version": torch_version} # Creating an empty manifest file for overwriting post setup - os.system("touch {}".format(MANIFEST_FILE)) + os.system(f"touch {MANIFEST_FILE}") else: manifest_exists = True diff --git a/tests/py/api/test_classes.py b/tests/py/api/test_classes.py index 3d0cb5c5f9..78d78d31e0 100644 --- a/tests/py/api/test_classes.py +++ b/tests/py/api/test_classes.py @@ -57,7 +57,7 @@ class TestInput(unittest.TestCase): def _verify_correctness(self, struct: torchtrt.Input, target: Dict) -> bool: internal = struct._to_internal() - list_eq = lambda al, bl: all([a == b for (a, b) in zip(al, bl)]) + list_eq = lambda al, bl: all(a == b for (a, b) in zip(al, bl)) eq = lambda a, b: a == b @@ -242,6 +242,9 @@ def test_dynamic_shape(self): class TestTorchTensorRTModule(unittest.TestCase): @staticmethod def _get_trt_mod(): + + + class Test(torch.nn.Module): def __init__(self): super(Test, self).__init__() @@ -249,8 +252,8 @@ def __init__(self): self.fc2 = torch.nn.Linear(5, 5) def forward(self, x): - out = self.fc2(self.fc1(x)) - return out + return self.fc2(self.fc1(x)) + mod = 
torch.jit.script(Test()) test_mod_engine_str = torchtrt.ts.convert_method_to_trt_engine( @@ -264,6 +267,9 @@ def forward(self, x): ) def test_detect_invalid_input_binding(self): + + + class Test(torch.nn.Module): def __init__(self): super(Test, self).__init__() @@ -271,8 +277,8 @@ def __init__(self): self.fc2 = torch.nn.Linear(5, 5) def forward(self, x): - out = self.fc2(self.fc1(x)) - return out + return self.fc2(self.fc1(x)) + mod = torch.jit.script(Test()) test_mod_engine_str = torchtrt.ts.convert_method_to_trt_engine( @@ -287,6 +293,9 @@ def forward(self, x): ) def test_detect_invalid_output_binding(self): + + + class Test(torch.nn.Module): def __init__(self): super(Test, self).__init__() @@ -294,8 +303,8 @@ def __init__(self): self.fc2 = torch.nn.Linear(5, 5) def forward(self, x): - out = self.fc2(self.fc1(x)) - return out + return self.fc2(self.fc1(x)) + mod = torch.jit.script(Test()) test_mod_engine_str = torchtrt.ts.convert_method_to_trt_engine( diff --git a/tests/py/api/test_collections.py b/tests/py/api/test_collections.py index 64f46fa3e9..8fa04dee16 100644 --- a/tests/py/api/test_collections.py +++ b/tests/py/api/test_collections.py @@ -8,7 +8,7 @@ def find_repo_root(max_depth=10): dir_path = os.path.dirname(os.path.realpath(__file__)) - for i in range(max_depth): + for _ in range(max_depth): files = os.listdir(dir_path) if "WORKSPACE" in files: return dir_path @@ -18,7 +18,7 @@ def find_repo_root(max_depth=10): raise RuntimeError("Could not find repo root") -MODULE_DIR = find_repo_root() + "/tests/modules" +MODULE_DIR = f"{find_repo_root()}/tests/modules" class TestStandardTensorInput(unittest.TestCase): @@ -26,7 +26,7 @@ def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt") + torch.jit.load(f"{MODULE_DIR}/standard_tensor_input_scripted.jit.pt") .eval() .to("cuda") ) @@ -55,7 +55,7 @@ def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt") + torch.jit.load(f"{MODULE_DIR}/standard_tensor_input_scripted.jit.pt") .eval() .to("cuda") ) @@ -85,7 +85,7 @@ def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt") + torch.jit.load(f"{MODULE_DIR}/standard_tensor_input_scripted.jit.pt") .eval() .to("cuda") ) @@ -114,7 +114,7 @@ def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/tuple_input_scripted.jit.pt") + torch.jit.load(f"{MODULE_DIR}/tuple_input_scripted.jit.pt") .eval() .to("cuda") ) @@ -143,7 +143,9 @@ def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/list_input_scripted.jit.pt").eval().to("cuda") + torch.jit.load(f"{MODULE_DIR}/list_input_scripted.jit.pt") + .eval() + .to("cuda") ) compile_spec = { @@ -170,7 +172,7 @@ def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt") + torch.jit.load(f"{MODULE_DIR}/tuple_input_output_scripted.jit.pt") .eval() .to("cuda") ) @@ -197,7 +199,7 @@ def test_compile(self): def test_compile_full_compilation(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt") + 
torch.jit.load(f"{MODULE_DIR}/tuple_input_output_scripted.jit.pt") .eval() .to("cuda") ) @@ -228,7 +230,7 @@ def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt") + torch.jit.load(f"{MODULE_DIR}/list_input_output_scripted.jit.pt") .eval() .to("cuda") ) @@ -257,7 +259,7 @@ def test_compile_full_compilation(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt") + torch.jit.load(f"{MODULE_DIR}/list_input_output_scripted.jit.pt") .eval() .to("cuda") ) @@ -289,7 +291,7 @@ def test_compile(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt") + torch.jit.load(f"{MODULE_DIR}/list_input_tuple_output_scripted.jit.pt") .eval() .to("cuda") ) @@ -317,7 +319,7 @@ def test_compile_full_compilation(self): self.input = torch.randn((1, 3, 224, 224)).to("cuda") self.model = ( - torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt") + torch.jit.load(f"{MODULE_DIR}/list_input_tuple_output_scripted.jit.pt") .eval() .to("cuda") ) diff --git a/tests/py/api/utils.py b/tests/py/api/utils.py index ff6bc39158..bacbd6d739 100644 --- a/tests/py/api/utils.py +++ b/tests/py/api/utils.py @@ -46,7 +46,7 @@ def same_output_format(trt_output, torch_output): for key in trt_output.keys() ) ) - elif isinstance(trt_output, set) or isinstance(trt_output, frozenset): + elif isinstance(trt_output, (set, frozenset)): raise AssertionError( "Unsupported output type 'set' encountered in output format check." ) diff --git a/tests/py/model_test_case.py b/tests/py/model_test_case.py index 42073ef747..574b7c52eb 100644 --- a/tests/py/model_test_case.py +++ b/tests/py/model_test_case.py @@ -3,7 +3,7 @@ import torchvision.models as models import os -REPO_ROOT = os.path.abspath(os.getcwd()) + "/../../" +REPO_ROOT = f"{os.path.abspath(os.getcwd())}/../../" class ModelTestCase(unittest.TestCase): diff --git a/tests/py/models/custom_models.py b/tests/py/models/custom_models.py index a19b9ca81c..c02d8cd2c2 100644 --- a/tests/py/models/custom_models.py +++ b/tests/py/models/custom_models.py @@ -24,5 +24,4 @@ def BertModule(): model = BertModel(config) model.eval() model = BertModel.from_pretrained(model_name, torchscript=True) - traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) - return traced_model + return torch.jit.trace(model, [tokens_tensor, segments_tensors]) diff --git a/tests/py/ptq/test_ptq_dataloader_calibrator.py b/tests/py/ptq/test_ptq_dataloader_calibrator.py index 79c19dadbf..b13536b2a1 100644 --- a/tests/py/ptq/test_ptq_dataloader_calibrator.py +++ b/tests/py/ptq/test_ptq_dataloader_calibrator.py @@ -12,7 +12,7 @@ def find_repo_root(max_depth=10): dir_path = os.path.dirname(os.path.realpath(__file__)) - for i in range(max_depth): + for _ in range(max_depth): files = os.listdir(dir_path) if "WORKSPACE" in files: return dir_path @@ -22,7 +22,7 @@ def find_repo_root(max_depth=10): raise RuntimeError("Could not find repo root") -MODULE_DIR = find_repo_root() + "/tests/modules" +MODULE_DIR = f"{find_repo_root()}/tests/modules" def compute_accuracy(testing_dataloader, model): @@ -33,7 +33,6 @@ def compute_accuracy(testing_dataloader, model): class_preds = [] device = torch.device("cuda:0") with torch.no_grad(): - idx = 0 for data, labels in testing_dataloader: data, labels = data.to(device), labels.to(device) 
out = model(data) @@ -42,8 +41,6 @@ def compute_accuracy(testing_dataloader, model): class_preds.append(preds) total += labels.size(0) correct += (preds == labels).sum().item() - idx += 1 - test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) test_preds = torch.cat(class_preds) return correct / total @@ -53,7 +50,7 @@ class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( - torch.jit.load(MODULE_DIR + "/trained_vgg16.jit.pt").eval().to("cuda") + torch.jit.load(f"{MODULE_DIR}/trained_vgg16.jit.pt").eval().to("cuda") ) self.input = torch.randn((1, 3, 32, 32)).to("cuda") self.testing_dataset = torchvision.datasets.CIFAR10( diff --git a/tests/py/ptq/test_ptq_to_backend.py b/tests/py/ptq/test_ptq_to_backend.py index 3a0a5bf336..7fe7944dc3 100644 --- a/tests/py/ptq/test_ptq_to_backend.py +++ b/tests/py/ptq/test_ptq_to_backend.py @@ -11,7 +11,7 @@ def find_repo_root(max_depth=10): dir_path = os.path.dirname(os.path.realpath(__file__)) - for i in range(max_depth): + for _ in range(max_depth): files = os.listdir(dir_path) if "WORKSPACE" in files: return dir_path @@ -21,7 +21,7 @@ def find_repo_root(max_depth=10): raise RuntimeError("Could not find repo root") -MODULE_DIR = find_repo_root() + "/tests/modules" +MODULE_DIR = f"{find_repo_root()}/tests/modules" def compute_accuracy(testing_dataloader, model): @@ -32,7 +32,6 @@ def compute_accuracy(testing_dataloader, model): class_preds = [] device = torch.device("cuda:0") with torch.no_grad(): - idx = 0 for data, labels in testing_dataloader: data, labels = data.to(device), labels.to(device) out = model(data) @@ -41,8 +40,6 @@ def compute_accuracy(testing_dataloader, model): class_preds.append(preds) total += labels.size(0) correct += (preds == labels).sum().item() - idx += 1 - test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) test_preds = torch.cat(class_preds) return correct / total @@ -51,7 +48,7 @@ def compute_accuracy(testing_dataloader, model): class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( - torch.jit.load(MODULE_DIR + "/trained_vgg16.jit.pt").eval().to("cuda") + torch.jit.load(f"{MODULE_DIR}/trained_vgg16.jit.pt").eval().to("cuda") ) self.input = torch.randn((1, 3, 32, 32)).to("cuda") self.testing_dataset = torchvision.datasets.CIFAR10( diff --git a/tests/py/ptq/test_ptq_trt_calibrator.py b/tests/py/ptq/test_ptq_trt_calibrator.py index 93596c895d..c5730c8b9a 100644 --- a/tests/py/ptq/test_ptq_trt_calibrator.py +++ b/tests/py/ptq/test_ptq_trt_calibrator.py @@ -12,7 +12,7 @@ def find_repo_root(max_depth=10): dir_path = os.path.dirname(os.path.realpath(__file__)) - for i in range(max_depth): + for _ in range(max_depth): files = os.listdir(dir_path) if "WORKSPACE" in files: return dir_path @@ -22,7 +22,7 @@ def find_repo_root(max_depth=10): raise RuntimeError("Could not find repo root") -MODULE_DIR = find_repo_root() + "/tests/modules" +MODULE_DIR = f"{find_repo_root()}/tests/modules" def compute_accuracy(testing_dataloader, model): @@ -33,7 +33,6 @@ def compute_accuracy(testing_dataloader, model): class_preds = [] device = torch.device("cuda:0") with torch.no_grad(): - idx = 0 for data, labels in testing_dataloader: data, labels = data.to(device), labels.to(device) out = model(data) @@ -42,8 +41,6 @@ def compute_accuracy(testing_dataloader, model): class_preds.append(preds) total += labels.size(0) correct += (preds == labels).sum().item() - idx += 1 - test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) test_preds = 
torch.cat(class_preds) return correct / total @@ -97,7 +94,7 @@ def write_calibration_cache(self, cache): class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( - torch.jit.load(MODULE_DIR + "/trained_vgg16.jit.pt").eval().to("cuda") + torch.jit.load(f"{MODULE_DIR}/trained_vgg16.jit.pt").eval().to("cuda") ) self.input = torch.randn((1, 3, 32, 32)).to("cuda") self.testing_dataset = torchvision.datasets.CIFAR10( diff --git a/tests/py/qat/test_qat_trt_accuracy.py b/tests/py/qat/test_qat_trt_accuracy.py index ce574c57fe..83ab04c79f 100644 --- a/tests/py/qat/test_qat_trt_accuracy.py +++ b/tests/py/qat/test_qat_trt_accuracy.py @@ -12,7 +12,7 @@ def find_repo_root(max_depth=10): dir_path = os.path.dirname(os.path.realpath(__file__)) - for i in range(max_depth): + for _ in range(max_depth): files = os.listdir(dir_path) if "WORKSPACE" in files: return dir_path @@ -22,7 +22,7 @@ def find_repo_root(max_depth=10): raise RuntimeError("Could not find repo root") -MODULE_DIR = find_repo_root() + "/tests/modules" +MODULE_DIR = f"{find_repo_root()}/tests/modules" set_reportable_log_level(Level.Graph) @@ -35,7 +35,6 @@ def compute_accuracy(testing_dataloader, model): class_preds = [] device = torch.device("cuda:0") with torch.no_grad(): - idx = 0 for data, labels in testing_dataloader: data, labels = data.to(device), labels.to(device) out = model(data) @@ -44,8 +43,6 @@ def compute_accuracy(testing_dataloader, model): class_preds.append(preds) total += labels.size(0) correct += (preds == labels).sum().item() - idx += 1 - test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) test_preds = torch.cat(class_preds) return correct / total @@ -54,7 +51,9 @@ def compute_accuracy(testing_dataloader, model): class TestAccuracy(unittest.TestCase): def test_compile_script(self): self.model = ( - torch.jit.load(MODULE_DIR + "/trained_vgg16_qat.jit.pt").eval().to("cuda") + torch.jit.load(f"{MODULE_DIR}/trained_vgg16_qat.jit.pt") + .eval() + .to("cuda") ) self.testing_dataset = torchvision.datasets.CIFAR10( root="./data", diff --git a/tools/linter/cpplint.py b/tools/linter/cpplint.py index 43a6474305..4314cabbed 100644 --- a/tools/linter/cpplint.py +++ b/tools/linter/cpplint.py @@ -17,7 +17,7 @@ def lint(user, target_files, change_file=True): for f in target_files: cmd.append(f) subprocess.run(cmd) - subprocess.run(["chown", user + ":" + user, f]) + subprocess.run(["chown", f"{user}:{user}", f]) subprocess.run(["chmod", "644", f]) @@ -27,14 +27,15 @@ def lint(user, target_files, change_file=True): projects = utils.CHECK_PROJECTS(sys.argv[1:]) if "//..." in projects: projects = [ - p.replace(BAZEL_ROOT, "/")[:-1] for p in glob.glob(BAZEL_ROOT + "/*/") + p.replace(BAZEL_ROOT, "/")[:-1] + for p in glob.glob(f"{BAZEL_ROOT}/*/") ] projects = [p for p in projects if p not in utils.BLACKLISTED_BAZEL_TARGETS] for p in projects: if p.endswith("/..."): p = p[:-4] - path = BAZEL_ROOT + "/" + p[2:] + path = f"{BAZEL_ROOT}/{p[2:]}" files = utils.glob_files(path, utils.VALID_CPP_FILE_TYPES) if files != []: lint(USER, files) diff --git a/tools/linter/cpplint_diff.py b/tools/linter/cpplint_diff.py index 307978e43f..096c1cecee 100644 --- a/tools/linter/cpplint_diff.py +++ b/tools/linter/cpplint_diff.py @@ -33,7 +33,8 @@ def lint(target_files, color=True): projects = utils.CHECK_PROJECTS(sys.argv[1:]) if "//..." 
in projects: projects = [ - p.replace(BAZEL_ROOT, "/")[:-1] for p in glob.glob(BAZEL_ROOT + "/*/") + p.replace(BAZEL_ROOT, "/")[:-1] + for p in glob.glob(f"{BAZEL_ROOT}/*/") ] projects = [p for p in projects if p not in utils.BLACKLISTED_BAZEL_TARGETS] @@ -41,7 +42,7 @@ def lint(target_files, color=True): for p in projects: if p.endswith("/..."): p = p[:-4] - path = BAZEL_ROOT + "/" + p[2:] + path = f"{BAZEL_ROOT}/{p[2:]}" files = utils.glob_files(path, utils.VALID_CPP_FILE_TYPES) if files != []: if lint(files, color): diff --git a/tools/linter/pylint.py b/tools/linter/pylint.py index d5ce8f2e15..1a896f6970 100644 --- a/tools/linter/pylint.py +++ b/tools/linter/pylint.py @@ -19,7 +19,7 @@ def lint(user, target_files, change_file=True): for f in target_files: subprocess.run(cmd) - subprocess.run(["chown", user + ":" + user, f]) + subprocess.run(["chown", f"{user}:{user}", f]) subprocess.run(["chmod", "644", f]) @@ -29,14 +29,15 @@ def lint(user, target_files, change_file=True): projects = utils.CHECK_PROJECTS(sys.argv[1:]) if "//..." in projects: projects = [ - p.replace(BAZEL_ROOT, "/")[:-1] for p in glob.glob(BAZEL_ROOT + "/*/") + p.replace(BAZEL_ROOT, "/")[:-1] + for p in glob.glob(f"{BAZEL_ROOT}/*/") ] projects = [p for p in projects if p not in utils.BLACKLISTED_BAZEL_TARGETS] for p in projects: if p.endswith("/..."): p = p[:-4] - path = BAZEL_ROOT + "/" + p[2:] + path = f"{BAZEL_ROOT}/{p[2:]}" files = utils.glob_files(path, utils.VALID_PY_FILE_TYPES) if files != []: lint(USER, files) diff --git a/tools/linter/pylint_diff.py b/tools/linter/pylint_diff.py index de11bfa0af..386cd7c5e0 100644 --- a/tools/linter/pylint_diff.py +++ b/tools/linter/pylint_diff.py @@ -6,19 +6,12 @@ def lint(target_files, color=True): - failure = False cmd = ["black", "--diff"] - if color: - cmd += ["--color"] - else: - cmd += ["--no-color"] + cmd += ["--color"] if color else ["--no-color"] cmd += target_files output = subprocess.run(cmd) - if output.returncode != 0: - failure = True - - return failure + return output.returncode != 0 if __name__ == "__main__": @@ -31,7 +24,8 @@ def lint(target_files, color=True): projects = utils.CHECK_PROJECTS(sys.argv[1:]) if "//..." 
in projects: projects = [ - p.replace(BAZEL_ROOT, "/")[:-1] for p in glob.glob(BAZEL_ROOT + "/*/") + p.replace(BAZEL_ROOT, "/")[:-1] + for p in glob.glob(f"{BAZEL_ROOT}/*/") ] projects = [p for p in projects if p not in utils.BLACKLISTED_BAZEL_TARGETS] @@ -39,7 +33,7 @@ def lint(target_files, color=True): for p in projects: if p.endswith("/..."): p = p[:-4] - path = BAZEL_ROOT + "/" + p[2:] + path = f"{BAZEL_ROOT}/{p[2:]}" files = utils.glob_files(path, utils.VALID_PY_FILE_TYPES) if files != []: if lint(files, color): diff --git a/tools/linter/utils.py b/tools/linter/utils.py index d63642ccce..7fad815d90 100644 --- a/tools/linter/utils.py +++ b/tools/linter/utils.py @@ -35,7 +35,7 @@ def CHECK_PROJECTS(projs): for p in projs: if p[:2] != "//": - sys.exit(p + " is not a valid bazel target") + sys.exit(f"{p} is not a valid bazel target") return projs @@ -45,7 +45,7 @@ def find_bazel_root(): """ curdir = os.path.dirname(os.path.realpath(__file__)) while 1: - if os.path.exists(curdir + "/WORKSPACE"): + if os.path.exists(f"{curdir}/WORKSPACE"): return curdir if curdir == "/": sys.exit("Error: was unable to find a bazel workspace") @@ -55,5 +55,5 @@ def find_bazel_root(): def glob_files(project, file_types): files = [] for t in file_types: - files += glob.glob(project + "/**/*" + t, recursive=True) + files += glob.glob(f"{project}/**/*{t}", recursive=True) return files diff --git a/tools/perf/custom_models.py b/tools/perf/custom_models.py index a8b8a5dae0..9ac9ed9d5e 100644 --- a/tools/perf/custom_models.py +++ b/tools/perf/custom_models.py @@ -26,5 +26,4 @@ def BertModule(): model = BertModel(config) model.eval() model = BertModel.from_pretrained(model_name, torchscript=True) - traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) - return traced_model + return torch.jit.trace(model, [tokens_tensor, segments_tensors]) diff --git a/tools/perf/hub.py b/tools/perf/hub.py index a1f032212b..6606822c47 100644 --- a/tools/perf/hub.py +++ b/tools/perf/hub.py @@ -47,10 +47,10 @@ def get(n, m, manifest): - print("Downloading {}".format(n)) - traced_filename = "models/" + n + "_traced.jit.pt" - script_filename = "models/" + n + "_scripted.jit.pt" - pytorch_filename = "models/" + n + "_pytorch.pt" + print(f"Downloading {n}") + traced_filename = f"models/{n}_traced.jit.pt" + script_filename = f"models/{n}_scripted.jit.pt" + pytorch_filename = f"models/{n}_pytorch.pt" x = torch.ones((1, 3, 300, 300)).cuda() if n == "bert_base_uncased": traced_model = m["model"] @@ -97,9 +97,9 @@ def download_models(version_matches, manifest): manifest = get(n, m, manifest) else: for n, m in BENCHMARK_MODELS.items(): - scripted_filename = "models/" + n + "_scripted.jit.pt" - traced_filename = "models/" + n + "_traced.jit.pt" - pytorch_filename = "models/" + n + "_pytorch.pt" + scripted_filename = f"models/{n}_scripted.jit.pt" + traced_filename = f"models/{n}_traced.jit.pt" + pytorch_filename = f"models/{n}_pytorch.pt" # Check if model file exists on disk # Extract model specifications as list and ensure all desired formats exist @@ -129,7 +129,7 @@ def download_models(version_matches, manifest): and os.path.exists(pytorch_filename) ) ): - print("Skipping {} ".format(n)) + print(f"Skipping {n} ") continue manifest = get(n, m, manifest) @@ -143,7 +143,7 @@ def main(): manifest = {"version": torch_version} # Creating an empty manifest file for overwriting post setup - os.system("touch {}".format(MANIFEST_FILE)) + os.system(f"touch {MANIFEST_FILE}") else: # Load manifest if already exists diff --git 
a/tools/perf/perf_run.py b/tools/perf/perf_run.py index 00ddbabd22..d71f61268e 100644 --- a/tools/perf/perf_run.py +++ b/tools/perf/perf_run.py @@ -51,7 +51,7 @@ def read_config(self): # Retrieves the value from the configuration else uses default values def get(self, key, default_value=None): - if not key in self.params: + if key not in self.params: if not default_value: raise ValueError( "Key {} is not present and default_value is not configured. Please run it with default value", @@ -75,7 +75,7 @@ def run_torch(model, input_tensors, params, precision, batch_size): timings = [] with torch.no_grad(): - for i in range(iters): + for _ in range(iters): start_time = timeit.default_timer() features = model(*input_tensors) torch.cuda.synchronize() @@ -104,7 +104,7 @@ def run_torch_tensorrt( } if precision == "int8": - compile_settings.update({"calib": params.get("calibration_cache")}) + compile_settings["calib"] = params.get("calibration_cache") start_compile = time.time_ns() model = torchtrt.compile(model, **compile_settings) @@ -121,7 +121,7 @@ def run_torch_tensorrt( timings = [] with torch.no_grad(): - for i in range(iters): + for _ in range(iters): start_time = timeit.default_timer() features = model(*input_tensors) torch.cuda.synchronize() @@ -160,7 +160,7 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size): timings = [] with torch.no_grad(): - for i in range(iters): + for _ in range(iters): start_time = timeit.default_timer() features = model(*input_tensors) torch.cuda.synchronize() @@ -185,7 +185,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size): if precision == "fp16": input_tensors = [tensor.half() for tensor in input_tensors] - fp16_mode = True if precision == "fp16" else False + fp16_mode = precision == "fp16" # dynamo_backend_params = {"fp16_mode" : fp16_mode} # model = torch.compile( # model, @@ -221,7 +221,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size): torch.cuda.synchronize() print("============= DONE 2 ==================") timings = [] - for i in range(iters): + for _ in range(iters): start_time = timeit.default_timer() features = exported_model(*input_tensors) torch.cuda.synchronize() @@ -246,7 +246,7 @@ def torch_dtype_from_trt(dtype): elif dtype == trt.float32: return torch.float32 else: - raise TypeError("%s is not supported by torch" % dtype) + raise TypeError(f"{dtype} is not supported by torch") def torch_device_from_trt(device): @@ -255,7 +255,7 @@ def torch_device_from_trt(device): elif device == trt.TensorLocation.HOST: return torch.device("cpu") else: - return TypeError("%s is not supported by torch" % device) + raise TypeError(f"{device} is not supported by torch") def run_tensorrt( @@ -309,11 +309,11 @@ def run_tensorrt( timings = [] with engine.create_execution_context() as context: - for i in range(WARMUP_ITER): + for _ in range(WARMUP_ITER): context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream) torch.cuda.synchronize() - for i in range(iters): + for _ in range(iters): start_time = timeit.default_timer() context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream) torch.cuda.synchronize() @@ -338,7 +338,7 @@ def run( ): for backend in backends: if precision == "int8": - if backend == "all" or backend == "torch": + if backend in ["all", "torch"]: print( "int8 precision is not supported for torch runtime in this script yet" ) @@ -347,7 +347,7 @@ def run( if ( backend == "all" or backend == "torch_tensorrt" - or params.get("calibration_cache", None) == None + 
or params.get("calibration_cache", None) is None ): print("int8 precision expects calibration cache file for inference") return False @@ -606,12 +606,12 @@ def load_torch_model(params): input_tensors = [] num_input = params.get("input").get("num_inputs", 1) for i in range(num_input): - inp_tensor = params.get("input").get("input" + str(i)) + inp_tensor = params.get("input").get(f"input{str(i)}") input_tensors.append( torch.randint( 0, 2, - tuple(d for d in inp_tensor), + tuple(inp_tensor), dtype=precision_to_dtype(precision), ).cuda() ) @@ -621,7 +621,7 @@ def load_torch_model(params): "Warning, TensorRT engine file is configured. Please make sure the precision matches with the TRT engine for reliable results" ) - if not is_trt_engine and (precision == "fp16" or precision == "half"): + if not is_trt_engine and precision in ["fp16", "half"]: # If model is TensorRT serialized engine then model.half will report failure if model is not None: model = model.half() @@ -686,7 +686,7 @@ def load_torch_model(params): params["inputs"], precision_to_dtype(precision) ) - if not is_trt_engine and (precision == "fp16" or precision == "half"): + if not is_trt_engine and precision in ["fp16", "half"]: # If model is TensorRT serialized engine then model.half will report failure model = model.half() diff --git a/tools/perf/utils.py b/tools/perf/utils.py index 96a13ffbc2..a6be23124b 100644 --- a/tools/perf/utils.py +++ b/tools/perf/utils.py @@ -28,7 +28,7 @@ def precision_to_dtype(pr): if pr == "fp32": return torch.float - elif pr == "fp16" or pr == "half": + elif pr in ["fp16", "half"]: return torch.half elif pr == "int32": return torch.int32 @@ -42,15 +42,16 @@ def parse_inputs(user_inputs, dtype): parsed_inputs = user_inputs.split(";") torchtrt_inputs = [] for input in parsed_inputs: - input_shape = [] input_shape_and_dtype = input.split("@") dtype = ( precision_to_dtype(input_shape_and_dtype[1]) if len(input_shape_and_dtype) == 2 else dtype ) - for input_dim in input_shape_and_dtype[0][1:-1].split(","): - input_shape.append(int(input_dim)) + input_shape = [ + int(input_dim) + for input_dim in input_shape_and_dtype[0][1:-1].split(",") + ] torchtrt_inputs.append(torch.randint(0, 5, input_shape, dtype=dtype).cuda()) return torchtrt_inputs