Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

🐛 [Bug] Decomposing attention leads to shape errors (due to view op) in FLUX model #3333

@peri044

Description

@peri044

Bug Description

After merging PR #3296, I see the following error:

ValueError: Cannot view a tensor with shape torch.Size([s6, s2 + 4096, 24, 128]) and strides (3072*s2 + 12582912, 128, 128*s2 + 524288, 1) as a tensor with shape (s1, (s6*(s2 + 4096)//s1), 3072)!

While executing %view_52 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%transpose_10, [%sym_size_int_63, -1, 3072]), kwargs = {})
Original traceback:
File "/work/TensorRT/examples/dynamo/run_2.py", line 48, in forward
    return self.module.forward(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 438, in forward
    hidden_states = block(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 119, in forward
    attn_output = self.attn(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 490, in forward
    return self.processor(

To Reproduce

Here's the full script:

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import FluxPipeline, FluxTransformer2DModel
from utils import export_llm, generate
from torch.export import Dim
from typing import Optional, Dict, Any
import logging
# Module-level logger with its own stream handler; both are set to DEBUG so
# every message emitted through `logger` reaches the handler.
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
logger = logging.getLogger(__name__)
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)

import time
from contextlib import contextmanager

@contextmanager
def timer(logger, name: str):
    """Log the start, end and elapsed time of the wrapped ``with`` body.

    Args:
        logger: Object exposing ``info(msg)``; it receives all three messages.
        name: Label interpolated into every message.

    Yields:
        None. Control is handed to the body of the ``with`` statement.

    The end and elapsed-time messages are emitted from a ``finally`` clause,
    so they are logged even when the body raises; the exception then
    propagates to the caller unchanged.
    """
    logger.info(f"{name} section Start...")
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # adjustments to the system clock.
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        logger.info(f"{name} section End...")
        logger.info(f"{name} section elapsed time: {end - start} seconds")

class MyModule(torch.nn.Module):
    """Wrapper whose ``forward`` passes only six positional inputs to ``module``.

    The remaining parameters of ``forward`` are accepted so callers can use the
    wider signature, but they are not forwarded to the wrapped module.
    """

    def __init__(self, module):
        """Store ``module`` as the wrapped submodule."""
        super().__init__()
        self.module = module

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden_states: torch.Tensor = None,
                pooled_projections: torch.Tensor = None,
                timestep: torch.LongTensor = None,
                img_ids: torch.Tensor = None,
                txt_ids: torch.Tensor = None,
                guidance: torch.Tensor = None,
                joint_attention_kwargs: Optional[Dict[str, Any]] = None,
                return_dict: bool = False, **kwargs):
        """Call ``self.module.forward`` with the first six arguments, in order.

        ``guidance``, ``joint_attention_kwargs``, ``return_dict`` and any extra
        keyword arguments are ignored.

        Returns:
            Whatever ``self.module.forward`` returns.
        """
        forwarded = (
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
            img_ids,
            txt_ids,
        )
        return self.module.forward(*forwarded)

def wrap_pipeline_transformer_call(instance, prompt, max_sequence_length):
    """Run ``instance`` for one step and capture what its transformer receives.

    ``instance.transformer.forward`` is temporarily replaced by a mock that
    wraps the real method, so the call still executes while its arguments are
    recorded. The original attribute is restored when the patch exits.

    Args:
        instance: Callable pipeline object exposing a ``transformer`` attribute.
        prompt: Passed to ``instance`` as its first positional argument.
        max_sequence_length: Forwarded to ``instance`` as a keyword argument.

    Returns:
        ``(args, kwargs)`` of the first recorded ``forward`` call, or
        ``(None, None)`` when ``forward`` was never invoked.
    """
    from unittest.mock import patch

    with patch.object(instance.transformer, 'forward', wraps=instance.transformer.forward) as mock_transformer_call:
        # One inference step is enough to intercept the inputs; the pipeline
        # output itself is not needed, so it is discarded.
        instance(
            prompt,
            guidance_scale=0.0,
            num_inference_steps=1,
            max_sequence_length=max_sequence_length,
            generator=torch.Generator("cpu").manual_seed(0),
        )

        recorded_calls = mock_transformer_call.call_args_list
        if not recorded_calls:
            print("No calls were made to instance.transformer.forward")
            return (None, None)

        # Only the first call's arguments are returned.
        args, kwargs = recorded_calls[0]
        return (args, kwargs)


if __name__ == "__main__":
    # Reproduction script: load the FLUX pipeline, export its transformer with
    # dynamic shapes, compile the exported program with Torch-TensorRT, then
    # try to save the compiled module. Free/total GPU memory is printed between
    # stages.

    # config
    dryrun = False

    # parameter setting
    batch_size = 2
    max_seq_len = 256
    prompt = ["A cat holding a sign that says hello world" for _ in range(batch_size)]
    # NOTE(review): `cuda_device` and `device` hold the same string; `device`
    # is used for tensor/pipeline placement, `cuda_device` for memory queries
    # and compilation.
    cuda_device = "cuda:0"
    device="cuda:0"
    with torch.no_grad():
        pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", 
                                            torch_dtype=torch.float16)
        pipe.to(device)
        
        # Positional example inputs, in the order MyModule.forward accepts
        # them: hidden_states, encoder_hidden_states, pooled_projections,
        # timestep, img_ids, txt_ids.
        # NOTE(review): the timestep tensor is hard-coded with two elements and
        # the sequence length is the literal 256 rather than max_seq_len, so
        # these only line up with batch_size == 2 / max_seq_len == 256.
        example_inputs = (torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(device),
                  torch.randn((batch_size, 256, 4096), dtype=torch.float16).to(device),
                  torch.randn((batch_size, 768), dtype=torch.float16).to(device),
                  torch.tensor([1., 1.], dtype=torch.float16).to(device),
                  torch.randn((batch_size, 4096, 3), dtype=torch.float16).to(device),
                  torch.randn((batch_size, 256, 3), dtype=torch.float16).to(device),
        )
        # One dynamic-shape spec per example input: dim 0 is the dynamic batch
        # for every input; dim 1 is the dynamic sequence length for the second
        # and sixth inputs only.
        BATCH = Dim("batch", min=1, max=batch_size)
        SEQ_LEN = Dim("seq_len", min=1, max=max_seq_len)
        dynamic_shapes = ({0 : BATCH}, 
                        {0 : BATCH, 1 : SEQ_LEN},
                        {0 : BATCH},
                        {0 : BATCH},
                        {0 : BATCH},
                        {0 : BATCH, 1 : SEQ_LEN},
                        )
        # Memory checkpoint 1: after loading the pipeline, before export.
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"1 Free mem: {free}, Total mem: {total}")
        # breakpoint()
        with timer(logger=logger, name="ep_gen"):
                model = MyModule(pipe.transformer).eval().half()#.to(device)
                logger.info("Directly use _export because torch.export.export doesn't work")
                # This API is used to express the constraint violation guards as asserts in the graph.
                from torch.export._trace import _export
                ep = _export(
                    model,
                    args=example_inputs, 
                    dynamic_shapes=dynamic_shapes,
                    strict=False,
                    allow_complex_guards_as_runtime_asserts=True,
                )
        # Memory checkpoint 2: after export, before TensorRT compilation.
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"2 Free mem: {free}, Total mem: {total}")
        # breakpoint()
        logger.info(f"Generating TRT engine now, dryrun={dryrun}...")
        # print("Generating TRT engine now...")
        #TODO: if some non-tensor input, do we still need to provide them.
        with timer(logger, "trt_gen"):
            with torch_tensorrt.logging.debug():
                # NOTE(review): trt_start / trt_end are recorded but not read
                # anywhere in the visible script; the timer context manager
                # already logs the elapsed time.
                trt_start = time.time()
                trt_model = torch_tensorrt.dynamo.compile(
                                ep,
                                inputs=list(example_inputs),
                                enabled_precisions={torch.float32},
                                truncate_double=True,
                                device=torch.device(cuda_device),
                                disable_tf32=True,
                                use_explicit_typing=True,
                                dryrun=dryrun,
                                debug=True,
                                use_fp32_acc=True,
                            )
                trt_end = time.time()
        # Memory checkpoint 3: after compilation, with pipe/ep/model still alive.
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"3 Free mem: {free}, Total mem: {total}")
        # Debugging stop; the script pauses here until the debugger is resumed.
        breakpoint()
        # Drop the references to the pipeline, exported program and wrapper.
        del pipe
        del ep
        del model

        # Memory checkpoint 4: after the dels, before gc / cache emptying.
        free, total = torch.cuda.mem_get_info(cuda_device)
        print(f"4 Free mem: {free}, Total mem: {total}")
        breakpoint()
        import gc
        gc.collect()
        torch.cuda.empty_cache()
        example_inputs_cuda = [input.cuda() for input in example_inputs]
        with timer(logger, "trt_save"):
            # First attempt: re-export the compiled module and save it with
            # torch.export.save. On any exception, log the traceback and fall
            # back to torch_tensorrt.save; both paths write to "trt.ep".
            try:
                breakpoint()
                trt_ep = torch.export.export(trt_model, args=example_inputs_cuda,
                                    dynamic_shapes=dynamic_shapes, strict=False)
                torch.export.save(trt_ep, "trt.ep")
            except Exception as e:
                import traceback
                # Capture the full traceback
                tb = traceback.format_exc()
                logger.warning("An error occurred. Here's the traceback:")
                # print(tb)
                logger.warning(tb)
                breakpoint()
                torch_tensorrt.save(trt_model, "trt.ep")

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions