Skip to content

Conversation

@toilaluan
Copy link

What does this PR do?

Adding TaylorSeer Caching method to accelerate inference speed mentioned in #12569

Author's codebase: https://github.com/Shenyi-Z/TaylorSeer

This PR structure will heavily mimic FasterCache (https://github.com/huggingface/diffusers/pull/10163/files) behaviour
I prioritze to make it work on image model pipelines (Flux, Qwen Image) for ease of evaluation

Expected Output

4->5x speeding up by these settings while keep output images are qualified

image

State Design

Core of this algorithm is about predict features of step t by using real computed features from previous step using Taylor Expansion Approximation.
We design a State class, include predict & update method and taylor_factors: Tensor to maintain iteration information. Each feature tensor will be bounded to a state instance (in double stream attention class in Flux & QwenImage, output of this module is image_features & txt_features, we will create 2 state instances for them)

  • update method will be called from real compute timestep and update taylor_factors using math formular referenced to original implementation
  • predict method will be called to predict feature from current taylor_factors using math formular referenced to original implementation

@seed93
Copy link

seed93 commented Nov 14, 2025

Will you adapt this great PR for flux kontext controlnet or flux controlnet? It would be nice if it is implemented and I am very eager to try it out.

@toilaluan
Copy link
Author

@seed93 yes, i am prioritizing for flux series and qwen image

@toilaluan
Copy link
Author

Here is analysis about TaylorSeer for Flux
Comparing with baseline, the output image is different, although PAB method give pretty close result
This result is match with author's implementation

model_id cache_method compute_dtype compile time model_memory model_max_memory_reserved inference_memory inference_max_memory_reserved
flux none fp16 False 22.318 33.313 33.322 33.322 34.305
flux pyramid_attention_broadcast fp16 False 18.394 33.313 33.322 33.322 35.789
flux taylorseer_cache fp16 False 6.457 33.313 33.322 33.322 38.18

Flux visual results

Baseline

image

Pyramid Attention Broadcast

image

TaylorSeer Cache (this implementation)

image

TaylorSeer Original (https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-Diffusers/taylorseer_flux/diffusers_taylorseer_flux.py)

image

Benchmark code is based on #10163

import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    MochiPipeline,
)
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/root/diffusers")
branch = repo.active_branch

from diffusers import (
    apply_taylorseer_cache, 
    TaylorSeerCacheConfig, 
    apply_faster_cache, 
    FasterCacheConfig, 
    apply_pyramid_attention_broadcast, 
    PyramidAttentionBroadcastConfig,
)

def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype) -> None:
    model_id = "black-forest-labs/FLUX.1-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
    }

    return pipe, generation_kwargs

def prepare_flux_config(cache_method: str, pipe: FluxPipeline):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 950),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
            current_timestep_callback=lambda: pipe.current_timestep,
        )
    elif cache_method == "taylorseer_cache":
        return TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float16, architecture="flux")
    elif cache_method == "fastercache":
        return FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 681),
        low_frequency_weight_update_timestep_range=(99, 641),
        high_frequency_weight_update_timestep_range=(-1, 301),
        spatial_attention_block_identifiers=["transformer_blocks"],
        attention_weight_callback=lambda _: 0.3,
        tensor_format="BFCHW",
    )
    elif cache_method == "none":
        return None


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


MODEL_MAPPING = {
    "flux": {
        "prepare": prepare_flux,
        "config": prepare_flux_config,
        "decode": decode_flux,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    print(f"Generator: {generator}")
    print(f"Generation kwargs: {generation_kwargs}")
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, cache_method: str, output_dir: str, dtype: str):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]

    try:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype)

        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        model_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 2. Apply attention approximation technique
        config = model["config"](cache_method, pipe)
        if cache_method == "pyramid_attention_broadcast":
            apply_pyramid_attention_broadcast(pipe.transformer, config)
        elif cache_method == "fastercache":
            apply_faster_cache(pipe.transformer, config)
        elif cache_method == "taylorseer_cache":
            apply_taylorseer_cache(pipe.transformer, config)
        elif cache_method == "none":
            pass
        else:
            raise ValueError(f"Invalid {cache_method=} provided.")

        # 4. Benchmark
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---dtype-{dtype}---cache_method-{cache_method}---compile-{compile}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": time,
            "model_memory": model_memory,
            "model_max_memory_reserved": model_max_memory_reserved,
            "inference_memory": inference_memory,
            "inference_max_memory_reserved": inference_max_memory_reserved,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": None,
            "model_memory": None,
            "model_max_memory_reserved": None,
            "inference_memory": None,
            "inference_max_memory_reserved": None,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--cache_method",
        type=str,
        default="pyramid_attention_broadcast",
        choices=["pyramid_attention_broadcast", "fastercache", "taylorseer_cache", "none"],
        help="Cache method to use.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.cache_method, args.output_dir, args.dtype)
    

@toilaluan
Copy link
Author

More comparison between this impl, baseline, author's impl

image

@toilaluan
Copy link
Author

I think current implementation is unified for every models that have attention modules, but to achieve full optimization, we have to config regex for which layer to cache or skip compute
Example in a sequence of Linear1, Act1, Linear2, Act2: we need to add hook for Linear1,act1,linear2 to do nothing (return an empty tensor) but cache output of act2
I already fix template for flux, but for other models, user have to write their own and pass it to the config init
@sayakpaul how do you think about this mechanism? I need some advises here

@sayakpaul sayakpaul requested a review from DN6 November 14, 2025 17:41
@toilaluan
Copy link
Author

toilaluan commented Nov 15, 2025

Tuning cache config really helps!

TaylorSeer cache configuration comparison

In the original code, they use 3 warmup steps and no cooldown. The output image differs significantly from the baseline, as shown in the report above.

As suggested in Shenyi-Z/TaylorSeer#12, increasing the warmup steps to 10 helps narrow the gap, but the cached output still has noticeable artifacts. This naturally suggested adding a cooldown phase (running the last steps without caching).

All runs below use the same prompt and 50 inference steps.

Visual comparison

Baseline vs. 3 warmup / 0 cooldown

Baseline (no cache) 3 warmup steps, 0 cooldown (cache)
Baseline output 3 warmup, 0 cooldown output

With only 3 warmup steps and 0 cooldown steps, the image content is not very close to the baseline.

10 warmup / 0 cooldown vs. 10 warmup / 5 cooldown

10 warmup steps, 0 cooldown (cache) 10 warmup steps, 5 cooldown (cache)
10 warmup, 0 cooldown output 10 warmup, 5 cooldown output

With 10 warmup steps, the content is closer to the baseline, but there are still many artifacts and noise.
By running the last 5 steps without caching (cooldown), most of these issues are resolved.


Hardware usage comparison

The table below shows the hardware usage comparison:

cache_method predict_steps max_order warmup_steps stop_predicts time (s) model_memory_gb inference_memory_gb max_memory_reserved_gb compute_dtype
none - - - - 22.781 33.313 33.321 37.943 fp16
taylorseer_cache 5.0 1.0 3.0 - 7.099 55.492 55.492 70.283 fp16
taylorseer_cache 5.0 1.0 3.0 45.0 9.024 55.490 55.490 70.283 fp16
taylorseer_cache 5.0 1.0 10.0 - 9.451 55.492 55.492 70.283 fp16
taylorseer_cache 5.0 1.0 10.0 45.0 11.000 55.490 55.490 70.283 fp16
taylorseer_cache 6.0 1.0 3.0 - 6.701 55.492 55.492 70.285 fp16
taylorseer_cache 6.0 1.0 3.0 45.0 8.651 55.490 55.490 70.285 fp16
taylorseer_cache 6.0 1.0 10.0 - 9.053 55.492 55.492 70.283 fp16
taylorseer_cache 6.0 1.0 10.0 45.0 11.001 55.490 55.490 70.283 fp16
image

Code

import gc
import pathlib
import pandas as pd
import torch
from itertools import product

from diffusers import FluxPipeline
from diffusers.utils.logging import set_verbosity_info

from diffusers import apply_taylorseer_cache, TaylorSeerCacheConfig

def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype):
    model_id = "black-forest-labs/FLUX.1-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
    generation_kwargs = {
        "prompt": prompt,
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
    }

    return pipe, generation_kwargs

def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="pil", **generation_kwargs).images[0]
    torch.cuda.synchronize()
    return output

def main(output_dir: str):
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    compute_dtype = torch.float16
    taylor_factors_dtype = torch.float16

    param_grid = {
        'predict_steps': [5, 6],
        'max_order': [1],
        'warmup_steps': [3, 10],
        'stop_predicts': [None, 45]
    }
    combinations = list(product(*param_grid.values()))
    param_keys = list(param_grid.keys())

    results = []

    # Reset before each run
    def reset_cuda():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

    # Baseline (no cache)
    print("Running baseline...")
    reset_cuda()
    pipe, generation_kwargs = prepare_flux(compute_dtype)
    model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
    inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
    image_filename = output_dir / "baseline.png"
    image.save(image_filename)
    print(f"Baseline image saved to {image_filename}")

    info = {
        'cache_method': 'none',
        'predict_steps': None,
        'max_order': None,
        'warmup_steps': None,
        'stop_predicts': None,
        'time': time,
        'model_memory_gb': model_memory,
        'inference_memory_gb': inference_memory,
        'max_memory_reserved_gb': max_memory_reserved,
        'compute_dtype': 'fp16'
    }
    results.append(info)

    # TaylorSeer cache configurations
    for combo in combinations:
        ps, mo, ws, sp = combo
        sp_str = 'None' if sp is None else str(sp)
        print(f"Running TaylorSeer with predict_steps={ps}, max_order={mo}, warmup_steps={ws}, stop_predicts={sp}...")
        reset_cuda()
        pipe, generation_kwargs = prepare_flux(compute_dtype)
        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        config = TaylorSeerCacheConfig(
            predict_steps=ps,
            max_order=mo,
            warmup_steps=ws,
            stop_predicts=sp,
            taylor_factors_dtype=taylor_factors_dtype,
            architecture="flux"
        )
        apply_taylorseer_cache(pipe.transformer, config)
        time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
        image_filename = output_dir / f"taylorseer_p{ps}_o{mo}_w{ws}_s{sp_str}.jpg"
        image.save(image_filename)
        print(f"TaylorSeer image saved to {image_filename}")

        info = {
            'cache_method': 'taylorseer_cache',
            'predict_steps': ps,
            'max_order': mo,
            'warmup_steps': ws,
            'stop_predicts': sp,
            'time': time,
            'model_memory_gb': model_memory,
            'inference_memory_gb': inference_memory,
            'max_memory_reserved_gb': max_memory_reserved,
            'compute_dtype': 'fp16'
        }
        results.append(info)

    # Save CSV
    df = pd.DataFrame(results)
    csv_path = output_dir / 'benchmark_results.csv'
    df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    # Plot latency
    import matplotlib.pyplot as plt
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(20, 8))

    baseline_row = df[df['cache_method'] == 'none'].iloc[0]
    baseline_time = baseline_row['time']

    labels = ['baseline']
    times = [baseline_time]

    taylor_df = df[df['cache_method'] == 'taylorseer_cache']
    for _, row in taylor_df.iterrows():
        sp_str = 'None' if pd.isna(row['stop_predicts']) else str(int(row['stop_predicts']))
        label = f"p{row['predict_steps']}-o{row['max_order']}-w{row['warmup_steps']}-s{sp_str}"
        labels.append(label)
        times.append(row['time'])

    bars = ax.bar(labels, times)
    ax.set_xlabel('Configuration')
    ax.set_ylabel('Latency (s)')
    ax.set_title('Inference Latency: Baseline vs TaylorSeer Cache Configurations')
    ax.tick_params(axis='x', rotation=90)
    plt.tight_layout()

    plot_path = output_dir / 'latency_comparison.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to {plot_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save CSV, plot, and images.")
    args = parser.parse_args()

    set_verbosity_info()
    main(args.output_dir)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants