Thanks to visit codestin.com
Credit goes to github.com

Skip to content

feat: add refresh cache context api#542

Merged
DefTruth merged 18 commits into main from dev
Dec 9, 2025
Merged

feat: add refresh cache context api#542
DefTruth merged 18 commits into main from dev

Conversation

@DefTruth
Copy link
Member

@DefTruth DefTruth commented Dec 8, 2025

Fixes #540: adds a refresh-cache-context API to reduce the dependency on num_inference_steps.

Qwen-Image

import cache_dit
from cache_dit import DBCacheConfig
from diffusers import DiffusionPipeline

# Init cache context with num_inference_steps=None (default).
# The step count is supplied later via cache_dit.refresh_context().
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
pipe = cache_dit.enable_cache(pipe.transformer, cache_config=DBCacheConfig(num_inference_steps=None))

# Assume num_inference_steps is 28, and we want to refresh the context.
# NOTE: fixed — the original example passed an undefined name `transformer`;
# the cached module is `pipe.transformer`.
cache_dit.refresh_context(pipe.transformer, num_inference_steps=28, verbose=True)
output = pipe(...)  # Just call the pipe as normal.
stats = cache_dit.summary(pipe.transformer)  # Then, get the summary

# Update the cache context with new num_inference_steps=50.
cache_dit.refresh_context(pipe.transformer, num_inference_steps=50, verbose=True)
output = pipe(...)  # Just call the pipe as normal.
stats = cache_dit.summary(pipe.transformer)  # Then, get the summary

# Update the cache context with an entirely new cache_config.
cache_dit.refresh_context(
    pipe.transformer,
    cache_config=DBCacheConfig(
        residual_diff_threshold=0.1,
        max_warmup_steps=10,
        max_cached_steps=20,
        max_continuous_cached_steps=4,
        num_inference_steps=50,
    ),
    verbose=True,
)
output = pipe(...)  # Just call the pipe as normal.
stats = cache_dit.summary(pipe.transformer)  # Then, get the summary

Wan 2.2 T2V

# NOTE: fixed — the original snippet used `os.environ`, `torch.bfloat16`
# and `torch.cuda` without importing `os` or `torch`.
import os

import torch

import diffusers
from diffusers import WanPipeline, AutoencoderKLWan, WanTransformer3DModel
import cache_dit

# `args` and `GiB()` are assumed to be defined elsewhere in the full
# example script (argument parsing / GPU-memory helper) — TODO confirm.
pipe = WanPipeline.from_pretrained(
    (
        args.model_path
        if args.model_path is not None
        else os.environ.get(
            "WAN_2_2_DIR",
            "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
        )
    ),
    torch_dtype=torch.bfloat16,
    # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
    device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
)

if args.cache:
    from cache_dit import (
        BlockAdapter,
        ForwardPattern,
        ParamsModifier,
        DBCacheConfig,
    )

    assert isinstance(pipe.transformer, WanTransformer3DModel)
    assert isinstance(pipe.transformer_2, WanTransformer3DModel)

    # Dual transformer caching with transformer-only api in cache-dit.
    cache_dit.enable_cache(
        BlockAdapter(
            transformer=[
                pipe.transformer,
                pipe.transformer_2,
            ],
            blocks=[
                pipe.transformer.blocks,
                pipe.transformer_2.blocks,
            ],
            forward_pattern=[
                ForwardPattern.Pattern_2,
                ForwardPattern.Pattern_2,
            ],
            params_modifiers=[
                # high-noise transformer only have 30% steps
                ParamsModifier(
                    cache_config=DBCacheConfig().reset(
                        max_warmup_steps=4,
                        max_cached_steps=8,
                    ),
                ),
                ParamsModifier(
                    cache_config=DBCacheConfig().reset(
                        max_warmup_steps=2,
                        max_cached_steps=20,
                    ),
                ),
            ],
            has_separate_cfg=True,
        ),
        cache_config=DBCacheConfig(
            Fn_compute_blocks=args.Fn,
            Bn_compute_blocks=args.Bn,
            max_warmup_steps=args.max_warmup_steps,
            max_cached_steps=args.max_cached_steps,
            max_continuous_cached_steps=args.max_continuous_cached_steps,
            residual_diff_threshold=args.rdt,
            # NOTE: num_inference_steps can be None here, we will
            # set it properly during cache refreshing.
            num_inference_steps=None,
        ),
    )

def split_inference_steps(num_inference_steps: int = 30) -> tuple[int, int]:
    """Partition the total step count into (high-noise, low-noise) steps.

    High-noise steps (timestep >= boundary) are served by ``pipe.transformer``;
    the remainder are low-noise steps for ``pipe.transformer_2``. When the
    pipeline config has no ``boundary_ratio``, all steps count as low-noise.
    Side effect: calls ``pipe.scheduler.set_timesteps`` on device "cuda".
    """
    boundary_timestep = None
    if pipe.config.boundary_ratio is not None:
        boundary_timestep = (
            pipe.config.boundary_ratio * pipe.scheduler.config.num_train_timesteps
        )
    pipe.scheduler.set_timesteps(num_inference_steps, device="cuda")
    if boundary_timestep is None:
        num_high_noise_steps = 0
    else:
        # Count timesteps at or above the boundary (high-noise phase).
        num_high_noise_steps = sum(
            1 for t in pipe.scheduler.timesteps if t >= boundary_timestep
        )
    return num_high_noise_steps, num_inference_steps - num_high_noise_steps


def run_pipe(steps: int = 30):
    """Generate a video with `steps` inference steps and return its frames.

    When caching is enabled, each transformer's cache context is refreshed
    first so it knows exactly how many steps it will execute.
    """
    if args.cache:
        # Re-sync each cached transformer with its share of the step budget.
        high_steps, low_steps = split_inference_steps(num_inference_steps=steps)
        for module, module_steps in (
            (pipe.transformer, high_steps),
            (pipe.transformer_2, low_steps),
        ):
            cache_dit.refresh_context(
                module,
                num_inference_steps=module_steps,
                verbose=True,
            )

    result = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_frames=81,
        num_inference_steps=steps,
        generator=torch.Generator("cpu").manual_seed(0),
    )
    return result.frames[0]

@DefTruth DefTruth self-assigned this Dec 8, 2025
@DefTruth DefTruth merged commit b583763 into main Dec 9, 2025
@DefTruth DefTruth deleted the dev branch December 9, 2025 03:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[feature request] Improve transformer-only cache-dit to reduce dependency on num_inference_steps

1 participant