diff --git a/diffsynth_engine/models/flux/flux_dit.py b/diffsynth_engine/models/flux/flux_dit.py
index 767991b8..0a3f9cd2 100644
--- a/diffsynth_engine/models/flux/flux_dit.py
+++ b/diffsynth_engine/models/flux/flux_dit.py
@@ -16,6 +16,7 @@
 from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
+from diffsynth_engine.utils.aiter_linear import use_swizzle_hipblaslt
 from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
 from diffsynth_engine.utils.parallel import (
     cfg_parallel,
@@ -409,6 +410,7 @@ def forward(
         use_cfg = hidden_states.shape[0] > 1
         with (
             fp8_inference(fp8_linear_enabled),
+            use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=fp8_linear_enabled),
             gguf_inference(),
             cfg_parallel(
                 (
diff --git a/diffsynth_engine/models/wan/wan_dit.py b/diffsynth_engine/models/wan/wan_dit.py
index 86dc9d68..10ebed53 100644
--- a/diffsynth_engine/models/wan/wan_dit.py
+++ b/diffsynth_engine/models/wan/wan_dit.py
@@ -21,6 +21,7 @@
 )
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
+from diffsynth_engine.utils.aiter_linear import use_swizzle_hipblaslt
 from diffsynth_engine.utils.parallel import (
     cfg_parallel,
     cfg_parallel_unshard,
@@ -390,6 +391,7 @@ def forward(
         use_cfg = x.shape[0] > 1
         with (
             fp8_inference(fp8_linear_enabled),
+            use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=fp8_linear_enabled),
             gguf_inference(),
             cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
         ):
diff --git a/diffsynth_engine/utils/aiter_linear.py b/diffsynth_engine/utils/aiter_linear.py
new file mode 100644
index 00000000..b599840a
--- /dev/null
+++ b/diffsynth_engine/utils/aiter_linear.py
@@ -0,0 +1,93 @@
+import torch
+import torch.nn.functional as F
+from contextlib import contextmanager
+from functools import lru_cache
+
+from aiter import hipb_mm, hipb_create_extension, per_tensor_quant_hip
+from aiter.ops.shuffle import shuffle_weight
+
+from diffsynth_engine.utils.platform import DTYPE_FP8
+
+
+@lru_cache(maxsize=1)
+def init_hipblas():
+    # hipb_create_extension only needs to run once; lru_cache makes later calls no-ops.
+    hipb_create_extension()
+
+
+@contextmanager
+def use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=True, use_scale_for_fp8=False):
+    """Temporarily route F.linear through a hipBLASLt GEMM on pre-shuffled
+    ("swizzled") weights. With use_fp8_linear, inputs and weights are cast to
+    FP8; use_scale_for_fp8 additionally quantizes activations per tensor and
+    passes the resulting scale to the GEMM."""
+    if not swizzle:
+        yield
+        return
+
+    # Preserve the original F.linear so it can be restored on exit.
+    _original_linear = F.linear
+
+    def optimized_linear(input, weight, bias=None, otype=torch.bfloat16,
+                         scaleA=None, scaleB=None, device="cuda"):
+        # hipb_mm expects a 2D activation matrix; fold the leading dimensions.
+        input_flat = input.reshape(-1, input.shape[-1])
+
+        init_hipblas()
+
+        # The (16, 16) tile layout is what hipb_mm consumes with bpreshuffle=True.
+        # Note that the weight is re-shuffled on every call.
+        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16), use_int4=False).to(device)
+        output_flat = hipb_mm(
+            input_flat,
+            weight_preshuffle.t(),
+            bias=bias,
+            solution_index=-1,
+            out_dtype=otype,
+            scaleA=scaleA,
+            scaleB=scaleB,
+            scaleOut=None,
+            bpreshuffle=True,
+        )
+
+        # Restore the folded leading dimensions.
+        new_shape = input.shape[:-1] + (weight.shape[0],)
+        return output_flat.view(new_shape)
+
+    def optimized_linear_fp8(input, weight, bias=None, otype=torch.bfloat16,
+                             scaleA=None, scaleB=None, device="cuda"):
+        input_flat = input.reshape(-1, input.shape[-1])
+
+        if use_scale_for_fp8:
+            # Per-tensor FP8 quantization; keep the activation scale for the GEMM.
+            input_flat, a_scale = per_tensor_quant_hip(input_flat, quant_dtype=DTYPE_FP8)
+        else:
+            # Plain cast to FP8 without a scale factor.
+            input_flat = input_flat.to(DTYPE_FP8)
+            a_scale = scaleA
+        weight = weight.to(DTYPE_FP8)
+
+        init_hipblas()
+
+        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
+        output_flat = hipb_mm(
+            input_flat,
+            weight_preshuffle.t(),
+            bias=bias,
+            solution_index=-1,
+            out_dtype=otype,
+            scaleA=a_scale,
+            scaleB=scaleB,
+            scaleOut=None,
+            bpreshuffle=True,
+        )
+
+        new_shape = input.shape[:-1] + (weight.shape[0],)
+        return output_flat.view(new_shape)
+
+    F.linear = optimized_linear_fp8 if use_fp8_linear else optimized_linear
+    try:
+        yield
+    finally:
+        # Restore F.linear even if the wrapped forward pass raised.
+        F.linear = _original_linear
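
Inside either DiT forward pass, the context manager monkey-patches torch.nn.functional.linear for the duration of the with block, so every nn.Linear submodule transparently routes through hipb_mm, and the patch is undone on exit even if the forward pass raises. A minimal standalone sketch of the same mechanism, assuming a ROCm build of PyTorch with aiter installed (the layer sizes are illustrative):

    import torch
    import torch.nn as nn
    from diffsynth_engine.utils.aiter_linear import use_swizzle_hipblaslt

    layer = nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16, device="cuda")
    x = torch.randn(2, 512, 4096, dtype=torch.bfloat16, device="cuda")

    # bf16 path: the weight is swizzled into the (16, 16) layout before hipb_mm.
    with use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=False):
        y = layer(x)

    # FP8 path with per-tensor activation quantization and scale.
    with use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=True, use_scale_for_fp8=True):
        y_fp8 = layer(x)

    print(y.shape, y_fp8.shape)  # torch.Size([2, 512, 4096]) in both cases

Because the patch is process-global for the lifetime of the with block, any F.linear call that happens concurrently (e.g. from another thread) would also be redirected; the DiT forward passes above apply it around a single model invocation.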
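
One caveat in the patch as written: shuffle_weight runs inside every F.linear call, so each weight is re-laid-out on every forward step. A hedged sketch of one way to amortize that cost; the helper name, cache, and data_ptr keying below are illustrative and not part of the patch:

    import torch
    from aiter.ops.shuffle import shuffle_weight

    # Hypothetical helper (not in the patch): memoize shuffled weights by
    # storage pointer so each parameter is re-laid-out only once per run.
    _shuffle_cache: dict[int, torch.Tensor] = {}

    def shuffled(weight: torch.Tensor, device: str = "cuda") -> torch.Tensor:
        key = weight.data_ptr()
        cached = _shuffle_cache.get(key)
        if cached is None:
            cached = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
            _shuffle_cache[key] = cached
        return cached

Keying on data_ptr assumes the weights are not updated in place during inference; anything that mutates parameters would need cache invalidation.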