torch.export with dynamic shapes on Static Cache HF Llama model fails #152465

Open
peri044 opened this issue Apr 29, 2025 · 1 comment


peri044 commented Apr 29, 2025

πŸ› Describe the bug

I'm trying to export the HF Llama model with a StaticCache. StaticCache export is supported, but it can only be used with static input shapes (https://github.com/huggingface/transformers/blob/f39f4960f30e3eadd6d948e4dcb2da32eda253b5/tests/utils/test_cache_utils.py#L247-L271 and https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py#L493-L497).
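For reference, the supported static-shape path goes through the executorch integration linked above, roughly like the sketch below (assuming the convert_and_export_with_cache helper from that module; it wraps the model in TorchExportableModuleWithStaticCache so the cache tensors become buffers, and only fixed-shape input_ids / cache_position are export inputs):

import torch
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache

# Static-shape sketch: the model carries a generation_config that requests a
# static cache, and the helper exports it with fixed-shape example inputs.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        cache_config={"batch_size": 1, "max_cache_len": 2176},
    ),
).eval()

exported_static = convert_and_export_with_cache(
    model,
    example_input_ids=torch.tensor([[1]], dtype=torch.long),
    example_cache_position=torch.tensor([0], dtype=torch.long),
)

When I try to export it using dynamic shapes instead, I face the following error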

torch._dynamo.exc.UserError: Cannot associate shape {} specified at `dynamic_shapes['past_key_values']` to non-tensor type <class 'transformers.cache_utils.StaticCache'> at `inputs['past_key_values']` (expected None)
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`

If I change the past_key_values entry in dynamic_shapes to "past_key_values": None, the error is

ValueError: Unsupported input type <class 'transformers.cache_utils.StaticCache'>. Export only supports pytree containers of basic types (Tensor, int, float, ...) as input. To register a custom dataclass, use torch.export.register_dataclass. To register a custom container type, use torch.utils._pytree.register_pytree_node. To register a constant input, use torch.utils._pytree.register_constant
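
That error suggests registering the container type via torch.utils._pytree.register_pytree_node. A rough sketch of what such a registration might look like for StaticCache is below; stashing the original cache object in the pytree context is obviously a hack (a proper registration would rebuild the cache from the model config), so I'm not sure this is the intended route:

import torch.utils._pytree as pytree
from transformers.cache_utils import StaticCache

def _flatten_static_cache(cache):
    # Expose the per-layer key/value tensors as pytree leaves; keep the cache
    # object itself plus the layer count as context so it can be rebuilt.
    children = list(cache.key_cache) + list(cache.value_cache)
    return children, (cache, len(cache.key_cache))

def _unflatten_static_cache(children, context):
    cache, num_layers = context
    cache.key_cache = list(children[:num_layers])
    cache.value_cache = list(children[num_layers:])
    return cache

pytree.register_pytree_node(StaticCache, _flatten_static_cache, _unflatten_static_cache)

If something along these lines worked, I'd expect dynamic_shapes["past_key_values"] to then need a spec matching the flattened list of key/value tensors (e.g. None for each entry, since the cache tensors themselves have fixed shape).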

What is the correct way to export it with dynamic shapes? Is this supported? Thanks!
cc: @angelayi

import torch 
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StaticCache

with torch.inference_mode():
    max_seq_len = 2176
    DEVICE="cuda"
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=DEVICE,
            torch_dtype=torch.float16,
            attn_implementation="sdpa",
            num_hidden_layers=1,
            generation_config=GenerationConfig(
                use_cache=True,
                cache_implementation="static",
                max_length=max_seq_len,
                cache_config={
                    "batch_size": 1,
                    "max_cache_len": max_seq_len,
                    "device": DEVICE,
                },
            ),
        ).eval().cuda()
    
    static_cache = StaticCache(
            config=model.config,
            max_batch_size=model.generation_config.cache_config.batch_size,
            max_cache_len=model.generation_config.cache_config.max_cache_len,
            device=model.generation_config.cache_config.device,
            dtype=model.dtype,
        )
    
    for i in range(len(static_cache.key_cache)):
        model.register_buffer(f"key_cache_{i}", static_cache.key_cache[i], persistent=False)
        model.register_buffer(f"value_cache_{i}", static_cache.value_cache[i], persistent=False)

    model.is_causal = any("CausalLM" in arch for arch in model.model.config.architectures)
    if model.is_causal:
        causal_mask = torch.tril(
            torch.ones(
                static_cache.max_cache_len,
                static_cache.max_cache_len,
                dtype=torch.bool,
            )
        )
    model.register_buffer("mask", causal_mask, persistent=False)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    prompt = "What is parallel programming ?"
    model_inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = model_inputs["input_ids"].to(DEVICE)
    cache_position = torch.arange(input_ids.shape[1]).to(DEVICE)

    seq_len = torch.export.Dim("seq_len", min=1, max=max_seq_len)
    cache_len = torch.export.Dim("cache_len", min=1, max=max_seq_len)
    
    exported_program = torch.export.export(
        model,
        args=(),
        kwargs={"input_ids" : input_ids, "cache_position": cache_position, "past_key_values": static_cache},
        dynamic_shapes={"input_ids" : {1: seq_len}, "cache_position" : {0: cache_len}, "past_key_values" : {}}, 
        strict=False,
    )
    gm = exported_program.module()
    print(gm.graph)

Versions

import torch
torch.__version__
'2.8.0.dev20250423+cu128'
transformers.__version__
'4.49.0'

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

@angelayi
Contributor

cc @tugsbayasgalan
