diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py
index 25b2fc6d2e..3891fcbb9a 100644
--- a/examples/dynamo/torch_export_flux_dev.py
+++ b/examples/dynamo/torch_export_flux_dev.py
@@ -9,11 +9,11 @@
 
 **FLUX.1 [dev]** is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. It is an open-weight, guidance-distilled model for non-commercial applications.
 
-Install the following dependencies before compilation
+To run this demo, you need to have access to the Flux model (request access on the `FLUX.1-dev <https://huggingface.co/black-forest-labs/FLUX.1-dev>`_ page if you do not have it already) and install the following dependencies
 
 .. code-block:: python
 
-    pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2"
+    pip install sentencepiece=="0.2.0" transformers=="4.48.2" accelerate=="1.3.0" diffusers=="0.32.2" protobuf=="5.29.3"
 
 There are different components of the ``FLUX.1-dev`` pipeline such as ``transformer``, ``vae``, ``text_encoder``, ``tokenizer`` and ``scheduler``. In this example,
 we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency)
@@ -38,11 +38,10 @@
     "black-forest-labs/FLUX.1-dev",
     torch_dtype=torch.float16,
 )
-pipe.to(DEVICE).to(torch.float16)
+
 # Store the config and transformer backbone
 config = pipe.transformer.config
-backbone = pipe.transformer
-
+backbone = pipe.transformer.to(DEVICE)
 
 # %%
 # Export the backbone using torch.export
@@ -63,6 +62,8 @@
     "txt_ids": {0: SEQ_LEN},
     "img_ids": {0: IMG_ID},
     "guidance": {0: BATCH},
+    "joint_attention_kwargs": {},
+    "return_dict": None,
 }
 # The guidance factor is of type torch.float32
 dummy_inputs = {
@@ -79,6 +80,8 @@
     "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
     "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
     "guidance": torch.tensor([1.0, 1.0], dtype=torch.float32).to(DEVICE),
+    "joint_attention_kwargs": {},
+    "return_dict": False,
 }
 # This will create an exported program which is going to be compiled with Torch-TensorRT
 ep = _export(
@@ -116,8 +119,11 @@
 # ---------------------------
 # Release the GPU memory occupied by the exported program and the pipe.transformer
 # Set the transformer in the Flux pipeline to the Torch-TRT compiled model
-backbone.to("cpu")
+
 del ep
+backbone.to("cpu")
+pipe.to(DEVICE)
+torch.cuda.empty_cache()
 pipe.transformer = trt_gm
 pipe.transformer.config = config
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index f24fd8ec21..f611c90f51 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -10,7 +10,7 @@
 from .fuse_prims_broadcast import fuse_prims_broadcast
 from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
 from .pass_manager import DynamoPassManager
-from .remove_assert_scalar import remove_assert_scalar
+from .remove_assert_nodes import remove_assert_nodes
 from .remove_detach import remove_detach
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
@@ -27,7 +27,7 @@
         replace_max_pool_with_indices,
         lower_scaled_dot_product_attention,
         view_to_reshape,
-        remove_assert_scalar,
+        remove_assert_nodes,
         accumulate_fp32_matmul,
     ]
 )
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
index 5ffdf08b7d..e569c45cfa 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -9,17 +9,68 @@
 logger = logging.getLogger(__name__)
 
 
+def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Decompose each ``aten.addmm`` node into ``aten.mm`` followed by ``aten.add``.
+
+    Scale factors found in the node's ``alpha`` / ``beta`` kwargs are applied through
+    ``aten.mul`` nodes on the matmul output and on the bias respectively.
+
+    Args:
+        gm: Graph module whose ``aten.addmm.default`` nodes are replaced in place
+
+    Returns:
+        The same graph module that was passed in
+    """
+    target = torch.ops.aten.addmm.default
+    addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
+    for addmm_node in addmm_nodes:
+        # Read the operands positionally: ``all_input_nodes`` de-duplicates, so it
+        # yields fewer than three nodes when an operand is repeated (addmm(b, x, x))
+        bias, mat1, mat2 = addmm_node.args[:3]
+        beta = addmm_node.kwargs.get("beta")
+        alpha = addmm_node.kwargs.get("alpha")
+
+        with gm.graph.inserting_before(addmm_node):
+            mm_node = gm.graph.call_function(
+                torch.ops.aten.mm.default,
+                args=(mat1, mat2),
+            )
+            # Compare against None so that an explicit scale of 0 is still applied
+            if alpha is not None:
+                mm_node = gm.graph.call_function(
+                    torch.ops.aten.mul.Tensor,
+                    args=(mm_node, alpha),
+                )
+
+            if beta is not None:
+                bias = gm.graph.call_function(
+                    torch.ops.aten.mul.Tensor,
+                    args=(bias, beta),
+                )
+            add_node = gm.graph.call_function(
+                torch.ops.aten.add.Tensor,
+                args=(bias, mm_node),
+            )
+
+        addmm_node.replace_all_uses_with(add_node, propagate_meta=True)
+        gm.graph.erase_node(addmm_node)
+
+    return gm
+
+
 def accumulate_fp32_matmul(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
-    """Replace a matmul layer with fp32 accumulation nodes"""
+    """Add cast to FP32/16 nodes around a matmul layer. This pattern is detected by TensorRT and will enable FP32 accumulation during execution."""
     if settings.use_fp32_acc:
         matmul_targets = [
             torch.ops.aten.mm.default,
             torch.ops.aten.bmm.default,
-            torch.ops.aten.addmm.default,
         ]
 
+        # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
+        split_addmm_nodes(gm)
+
         matmul_nodes = [
             node for node in gm.graph.nodes if node.target in matmul_targets
         ]
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index b8f4d7de48..6ebefc5509 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -51,6 +51,8 @@ def constant_fold(
         gm.graph.erase_node(node)
 
     gm = clean_up_graph_after_modifications(gm)
+    # Delete the constant folder instance which holds GPU memory
+    del cf
 
     logger.debug(f"Graph after constant folding:\n{gm.graph}")
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py
similarity index 87%
rename from py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py
rename to py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py
index 67d2ba6690..890391e280 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_scalar.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_assert_nodes.py
@@ -9,7 +9,7 @@
 logger = logging.getLogger(__name__)
 
 
-def remove_assert_scalar(
+def remove_assert_nodes(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
     """Remove assert_scalar ops in the graph"""
@@ -17,7 +17,7 @@ def remove_assert_scalar(
     for node in gm.graph.nodes:
         if (
             node.target == torch.ops.aten._assert_scalar.default
-            or node == torch.ops.aten._assert_tensor_metadata.default
+            or node.target == torch.ops.aten._assert_tensor_metadata.default
         ):
             gm.graph.erase_node(node)
             count += 1
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index f3d297a01c..1ff754532f 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import gc
 import logging
 import warnings
 from dataclasses import fields, replace
@@ -30,6 +31,7 @@
 DYNAMIC_DIM = -1
 RTOL = 5e-3
 ATOL = 5e-3
+CPU_DEVICE = "cpu"
 
 
 class Frameworks(Enum):
@@ -81,6 +83,23 @@ class Frameworks(Enum):
 }
 
 
+def delete_module(module: torch.fx.GraphModule) -> None:
+    """Move ``module`` to CPU, then release unused cached GPU memory.
+
+    Only the local reference is dropped here, so the caller must also drop its own
+    references for the module object itself to be reclaimed.
+
+    Args:
+        module: Graph module to move to CPU before the CUDA cache is emptied
+    """
+    module.to(CPU_DEVICE)
+    del module
+    # Collect garbage before emptying the cache so that memory freed by the
+    # collector is also handed back by the caching allocator
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
 def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
     """Parses a user-provided input argument regarding Python runtime
 
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index 76d47d24bd..868a092eb0 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -269,6 +269,48 @@ def forward(self, input, weight):
         )
         torch._dynamo.reset()
 
+    def test_fp32_acc_for_addmm(self):
+        class FP32Acc(torch.nn.Module):
+            def forward(self, input, mat1, mat2):
+                out = torch.ops.aten.addmm.default(input, mat1, mat2, beta=20, alpha=2)
+                return out
+
+        inputs = [
+            torch.rand((3, 5)).cuda(),
+            torch.rand((3, 4)).cuda(),
+            torch.rand((4, 5)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(FP32Acc())
+        expected_ops = {
+            torch.ops.aten._to_copy.default,
+            torch.ops.aten.mm.default,
+            torch.ops.aten.add.Tensor,
+        }
+        unexpected_ops = {torch.ops.aten.addmm.default}
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+            use_fp32_acc=True,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
 
 class TestLowerEfficientAttention(TestCase):
     def test_lower_efficient_attention(self):