From 2fdf6e8de9ca4ccc6fa5e25b8a7d4543174e7686 Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Sat, 22 Nov 2025 08:44:49 +1300 Subject: [PATCH] Wan2.1 Pipeline --- .../Config/AutoEncoderConfig.cs | 7 + .../Enums/PipelineType.cs | 3 +- .../Models/AutoEncoderModel.cs | 34 ++ .../Models/TransformerWanModel.cs | 55 +++ .../Pipelines/Wan/WanBase.cs | 421 ++++++++++++++++++ .../Pipelines/Wan/WanConfig.cs | 201 +++++++++ .../Pipelines/Wan/WanPipeline.cs | 113 +++++ 7 files changed, 833 insertions(+), 1 deletion(-) create mode 100644 TensorStack.StableDiffusion/Models/TransformerWanModel.cs create mode 100644 TensorStack.StableDiffusion/Pipelines/Wan/WanBase.cs create mode 100644 TensorStack.StableDiffusion/Pipelines/Wan/WanConfig.cs create mode 100644 TensorStack.StableDiffusion/Pipelines/Wan/WanPipeline.cs diff --git a/TensorStack.StableDiffusion/Config/AutoEncoderConfig.cs b/TensorStack.StableDiffusion/Config/AutoEncoderConfig.cs index e85251a..2a1be90 100644 --- a/TensorStack.StableDiffusion/Config/AutoEncoderConfig.cs +++ b/TensorStack.StableDiffusion/Config/AutoEncoderConfig.cs @@ -1,5 +1,6 @@ // Copyright (c) TensorStack. All rights reserved. // Licensed under the Apache 2.0 License. +using System.Text.Json.Serialization; using TensorStack.Common; namespace TensorStack.StableDiffusion.Config @@ -14,5 +15,11 @@ public record AutoEncoderModelConfig : ModelConfig public int LatentChannels { get; set; } = 4; public string DecoderModelPath { get; set; } public string EncoderModelPath { get; set; } + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float[] LatentsStd { get; set; } + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float[] LatentsMean { get; set; } } } diff --git a/TensorStack.StableDiffusion/Enums/PipelineType.cs b/TensorStack.StableDiffusion/Enums/PipelineType.cs index c7aa346..43ab651 100644 --- a/TensorStack.StableDiffusion/Enums/PipelineType.cs +++ b/TensorStack.StableDiffusion/Enums/PipelineType.cs @@ -11,6 +11,7 @@ public enum PipelineType StableCascade = 10, LatentConsistency = 20, Flux = 30, - Nitro = 40 + Nitro = 40, + Wan = 50 } } diff --git a/TensorStack.StableDiffusion/Models/AutoEncoderModel.cs b/TensorStack.StableDiffusion/Models/AutoEncoderModel.cs index d50be16..1049b33 100644 --- a/TensorStack.StableDiffusion/Models/AutoEncoderModel.cs +++ b/TensorStack.StableDiffusion/Models/AutoEncoderModel.cs @@ -142,6 +142,7 @@ public virtual async Task> DecodeAsync(Tensor inputTensor, if (!disableShift) inputTensor.Add(ShiftFactor); + ApplyNormalization(inputTensor, Configuration.LatentsMean, Configuration.LatentsStd); var outputDimensions = new[] { 1, OutChannels, inputTensor.Dimensions[2] * Scale, inputTensor.Dimensions[3] * Scale }; using (var modelParameters = new ModelParameters(Decoder.Metadata, cancellationToken)) { @@ -195,6 +196,39 @@ public virtual async Task> EncodeAsync(ImageTensor inputTensor, bo } + /// + /// Applies per-channel normalization to a latent tensor in-place, equivalent to: + /// latents = latents / latentsStd + latentsMean + /// + /// The latents. + /// Per-channel mean values. Length must equal the number of channels in . + /// Per-channel standard deviation values. Length must equal the number of channels in . Each value is inverted (1 / std) before applying to the tensor. + private static void ApplyNormalization(Tensor latents, ReadOnlySpan latentsMean, ReadOnlySpan latentsStd) + { + if (latentsMean.IsEmpty || latentsStd.IsEmpty) + return; + + var dimensions = latents.Dimensions; + var channels = dimensions[1]; + + Span invStd = stackalloc float[channels]; + for (int c = 0; c < channels; c++) + invStd[c] = 1f / latentsStd[c]; + + var data = latents.Memory.Span; + var strideC = data.Length / channels; + + for (int c = 0; c < channels; c++) + { + var mean = latentsMean[c]; + var inv = invStd[c]; + var slice = data.Slice(c * strideC, strideC); + for (int i = 0; i < slice.Length; i++) + slice[i] = slice[i] * inv + mean; + } + } + + /// /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. /// diff --git a/TensorStack.StableDiffusion/Models/TransformerWanModel.cs b/TensorStack.StableDiffusion/Models/TransformerWanModel.cs new file mode 100644 index 0000000..aab6f88 --- /dev/null +++ b/TensorStack.StableDiffusion/Models/TransformerWanModel.cs @@ -0,0 +1,55 @@ +// Copyright (c) TensorStack. All rights reserved. +// Licensed under the Apache 2.0 License. +using System.Threading; +using System.Threading.Tasks; +using TensorStack.Common; +using TensorStack.Common.Tensor; +using TensorStack.StableDiffusion.Config; + +namespace TensorStack.StableDiffusion.Models +{ + /// + /// TransformerModel: Wan Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + /// + public class TransformerWanModel : TransformerModel + { + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + public TransformerWanModel(TransformerModelConfig configuration) + : base(configuration) { } + + + /// + /// Runs the Transformer model with the specified inputs + /// + /// The timestep. + /// The hidden states. + /// The encoder hidden states. + /// The cancellation token. + public async Task> RunAsync(int timestep, Tensor hiddenStates, Tensor encoderHiddenStates, CancellationToken cancellationToken = default) + { + if (!Transformer.IsLoaded()) + await Transformer.LoadAsync(cancellationToken: cancellationToken); + + using (var transformerParams = new ModelParameters(Transformer.Metadata, cancellationToken)) + { + // Inputs + transformerParams.AddInput(hiddenStates); + transformerParams.AddScalarInput(timestep); + transformerParams.AddInput(encoderHiddenStates); + + // Outputs + transformerParams.AddOutput(hiddenStates.Dimensions); + + // Inference + using (var results = await Transformer.RunInferenceAsync(transformerParams)) + { + return results[0].ToTensor(); + } + } + } + + } +} diff --git a/TensorStack.StableDiffusion/Pipelines/Wan/WanBase.cs b/TensorStack.StableDiffusion/Pipelines/Wan/WanBase.cs new file mode 100644 index 0000000..68af217 --- /dev/null +++ b/TensorStack.StableDiffusion/Pipelines/Wan/WanBase.cs @@ -0,0 +1,421 @@ +// Copyright (c) TensorStack. All rights reserved. +// Licensed under the Apache 2.0 License. +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using TensorStack.Common; +using TensorStack.Common.Tensor; +using TensorStack.StableDiffusion.Common; +using TensorStack.StableDiffusion.Enums; +using TensorStack.StableDiffusion.Models; +using TensorStack.StableDiffusion.Schedulers; +using TensorStack.TextGeneration.Tokenizers; + +namespace TensorStack.StableDiffusion.Pipelines.Wan +{ + public abstract class WanBase : PipelineBase + { + private readonly int _outputFrameRate = 16; + + /// + /// Initializes a new instance of the class. + /// + /// The transformer. + /// The tokenizer. + /// The text encoder. + /// The automatic encoder. + /// The logger. + public WanBase(TransformerWanModel transformer, T5Tokenizer tokenizer, T5EncoderModel textEncoder, AutoEncoderModel autoEncoder, ILogger logger = default) : base(logger) + { + Transformer = transformer; + Tokenizer = tokenizer; + TextEncoder = textEncoder; + AutoEncoder = autoEncoder; + Initialize(); + Logger?.LogInformation("[WanPipeline] Name: {Name}", Name); + } + + + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The logger. + public WanBase(WanConfig configuration, ILogger logger = default) : this( + new TransformerWanModel(configuration.Transformer), + new T5Tokenizer(configuration.Tokenizer), + new T5EncoderModel(configuration.TextEncoder), + new AutoEncoderModel(configuration.AutoEncoder), + logger) + { + Name = configuration.Name; + } + + /// + /// Gets the type of the pipeline. + /// + public override PipelineType PipelineType => PipelineType.Wan; + + /// + /// Gets the friendly name. + /// + public override string Name { get; init; } = nameof(PipelineType.Wan); + + /// + /// Gets the tokenizer. + /// + public T5Tokenizer Tokenizer { get; init; } + + /// + /// Gets the TextEncoder. + /// + public T5EncoderModel TextEncoder { get; init; } + + /// + /// Gets the transformer. + /// + public TransformerWanModel Transformer { get; init; } + + /// + /// Gets the automatic encoder. + /// + public AutoEncoderModel AutoEncoder { get; init; } + + + /// + /// Loads the pipeline. + /// + /// The cancellation token. + public Task LoadAsync(CancellationToken cancellationToken = default) + { + // Wan pipelines are lazy loaded on first run + return Task.CompletedTask; + } + + + /// + /// Unloads the pipeline. + /// + /// The cancellation token. + public async Task UnloadAsync(CancellationToken cancellationToken = default) + { + await Task.WhenAll + ( + Transformer.UnloadAsync(), + TextEncoder.UnloadAsync(), + AutoEncoder.EncoderUnloadAsync(), + AutoEncoder.DecoderUnloadAsync() + ); + Logger?.LogInformation("[{PipeLineType}] Pipeline Unloaded", PipelineType); + } + + + /// + /// Validates the options. + /// + /// The options. + protected override void ValidateOptions(GenerateOptions options) + { + base.ValidateOptions(options); + if (!Transformer.HasControlNet && options.HasControlNet) + throw new ArgumentException("Model does not support ControlNet"); + } + + + /// + /// Creates the prompt input embeddings. + /// + /// The options. + /// The cancellation token. + protected async Task CreatePromptAsync(IPipelineOptions options, CancellationToken cancellationToken = default) + { + var cachedPrompt = GetPromptCache(options); + if (cachedPrompt is not null) + return cachedPrompt; + + // Tokenizer + var conditionalTokens = await TokenizePromptAsync(options.Prompt, cancellationToken); + var unconditionalTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken); + + // TextEncoder + var conditionalEmbeds = await EncodePromptAsync(conditionalTokens, cancellationToken); + var unconditionalEmbeds = await EncodePromptAsync(unconditionalTokens, cancellationToken); + if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled) + await TextEncoder.UnloadAsync(); + + return SetPromptCache(options, new PromptResult(conditionalEmbeds.HiddenStates, conditionalEmbeds.TextEmbeds, unconditionalEmbeds.HiddenStates, unconditionalEmbeds.TextEmbeds)); + } + + + /// + /// Tokenize prompt with Tokenizer3 + /// + /// The input text. + /// The cancellation token. + protected virtual async Task TokenizePromptAsync(string inputText, CancellationToken cancellationToken = default) + { + var timestamp = Logger.LogBegin(LogLevel.Debug, "[TokenizePrompt3Async] Begin Tokenizer"); + var tokenizerResult = await Tokenizer.EncodeAsync(inputText); + Logger.LogEnd(LogLevel.Debug, timestamp, "[TokenizePrompt3Async] Tokenizer Complete"); + return tokenizerResult; + } + + + /// + /// Encode prompt tokens with TextEncoder + /// + /// The prompt tokens. + /// The cancellation token that can be used by other objects or threads to receive notice of cancellation. + protected virtual async Task EncodePromptAsync(TokenizerResult promptTokens, CancellationToken cancellationToken = default) + { + var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodePrompt3Async] Begin TextEncoder3"); + var textEncoderResult = await TextEncoder.RunAsync(promptTokens, cancellationToken); + Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodePrompt3Async] TextEncoder3 Complete"); + return textEncoderResult; + } + + + /// + /// Decode the model latents to video + /// + /// The options. + /// The latents. + /// The cancellation token. + protected async Task DecodeLatentsAsync(IPipelineOptions options, Tensor latents, CancellationToken cancellationToken = default) + { + var timestamp = Logger.LogBegin(LogLevel.Debug, "[DecodeLatentsAsync] Begin AutoEncoder Decode"); + var decoderResult = await AutoEncoder.DecodeAsync(latents, disableShift: true, disableScale: true, cancellationToken: cancellationToken); + if (options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled) + await AutoEncoder.DecoderUnloadAsync(); + + decoderResult = decoderResult + .Permute([0, 2, 1, 3, 4]) + .Reshape(decoderResult.Dimensions[1..]); + Logger.LogEnd(LogLevel.Debug, timestamp, "[DecodeLatentsAsync] AutoEncoder Decode Complete"); + return decoderResult.AsVideoTensor(_outputFrameRate); + } + + + /// + /// Encode the image to model latents + /// + /// The options. + /// The latents. + /// The cancellation token. + private async Task> EncodeLatentsAsync(IPipelineOptions options, CancellationToken cancellationToken = default) + { + var timestamp = Logger.LogBegin(LogLevel.Debug, "[EncodeLatentsAsync] Begin AutoEncoder Encode"); + var cacheResult = GetEncoderCache(options); + if (cacheResult is not null) + { + Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete, Cached Result."); + return cacheResult; + } + + var inputTensor = options.InputImage.ResizeImage(options.Width, options.Height); + var encoderResult = await AutoEncoder.EncodeAsync(inputTensor, cancellationToken: cancellationToken); + if (options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) + await AutoEncoder.EncoderUnloadAsync(); + + Logger.LogEnd(LogLevel.Debug, timestamp, "[EncodeLatentsAsync] AutoEncoder Encode Complete"); + return SetEncoderCache(options, encoderResult); + } + + + /// + /// Run Transformer model inference + /// + /// The options. + /// The prompt. + /// The progress callback. + /// The cancellation token. + protected async Task> RunInferenceAsync(IPipelineOptions options, IScheduler scheduler, PromptResult prompt, IProgress progressCallback = null, CancellationToken cancellationToken = default) + { + var timestamp = Logger.LogBegin(LogLevel.Debug, "[RunInferenceAsync] Begin Transformer Inference"); + + // Prompt + var isGuidanceEnabled = IsGuidanceEnabled(options); + var promptEmbedsCond = prompt.PromptEmbeds; + var promptEmbedsUncond = prompt.NegativePromptEmbeds; + + // Latents + var latents = await CreateLatentInputAsync(options, scheduler, cancellationToken); + + // Load Model + await LoadTransformerAsync(options, progressCallback, cancellationToken); + + // Timesteps + var timesteps = scheduler.GetTimesteps(); + for (int i = 0; i < timesteps.Count; i++) + { + var timestep = timesteps[i]; + var steptime = Stopwatch.GetTimestamp(); + cancellationToken.ThrowIfCancellationRequested(); + + // Inputs. + var latentInput = scheduler.ScaleInput(timestep, latents); + + // Inference + var conditional = await Transformer.RunAsync(timestep, latentInput, promptEmbedsCond, cancellationToken: cancellationToken); + if (isGuidanceEnabled) + { + var unconditional = await Transformer.RunAsync(timestep, latentInput, promptEmbedsUncond, cancellationToken: cancellationToken); + conditional = ApplyGuidance(conditional, unconditional, options.GuidanceScale); + } + + // Scheduler + var stepResult = scheduler.Step(timestep, conditional, latents); + + // Result + latents = stepResult.Sample; + + // Progress + if (scheduler.IsFinalOrder) + progressCallback.Notify(scheduler.CurrentStep, scheduler.TotalSteps, latents, steptime); + + Logger.LogEnd(LogLevel.Debug, steptime, $"[RunInferenceAsync] Step: {i + 1}/{timesteps.Count}"); + } + + // Unload + if (options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) + await Transformer.UnloadAsync(); + + Logger.LogEnd(LogLevel.Debug, timestamp, "[RunInferenceAsync] Transformer Inference Complete"); + return latents; + } + + + /// + /// Create latent input. + /// + /// The options. + /// The scheduler. + /// The cancellation token. + private async Task> CreateLatentInputAsync(IPipelineOptions options, IScheduler scheduler, CancellationToken cancellationToken = default) + { + var dimensions = new int[] { 1, AutoEncoder.LatentChannels, 21, options.Height / AutoEncoder.Scale, options.Width / AutoEncoder.Scale }; + var noiseTensor = scheduler.CreateRandomSample(dimensions); + if (options.HasInputImage) + { + var timestep = scheduler.GetStartTimestep(); + var encoderResult = await EncodeLatentsAsync(options, cancellationToken); + return scheduler.ScaleNoise(timestep, encoderResult, noiseTensor); + } + return noiseTensor; + } + + + /// + /// Gets the model optimizations. + /// + /// The generate options. + /// The progress callback. + private ModelOptimization GetOptimizations(IPipelineOptions generateOptions, IProgress progressCallback = null) + { + var optimizations = new ModelOptimization(Optimization.None); + if (Transformer.HasOptimizationsChanged(optimizations)) + { + progressCallback.Notify("Optimizing Pipeline..."); + } + return optimizations; + } + + + /// + /// Determines whether classifier-free guidance is enabled + /// + /// The options. + private bool IsGuidanceEnabled(IPipelineOptions options) + { + return options.GuidanceScale > 1; + } + + + /// + /// Load Transformer with optimizations + /// + /// The options. + /// The progress callback. + /// The cancellation token. + private async Task LoadTransformerAsync(IPipelineOptions options, IProgress progressCallback = null, CancellationToken cancellationToken = default) + { + var optimizations = GetOptimizations(options, progressCallback); + await Transformer.LoadAsync(optimizations, cancellationToken); + } + + + /// + /// Checks the state of the pipeline. + /// + /// The options. + protected override async Task CheckPipelineState(IPipelineOptions options) + { + // Check Transformer/ControlNet status + if (options.HasControlNet && Transformer.IsLoaded()) + await Transformer.UnloadAsync(); + if (!options.HasControlNet && Transformer.IsControlNetLoaded()) + await Transformer.UnloadControlNetAsync(); + + // Check LowMemory status + if ((options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled) && TextEncoder.IsLoaded()) + await TextEncoder.UnloadAsync(); + if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) && Transformer.IsLoaded()) + await Transformer.UnloadAsync(); + if ((options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) && AutoEncoder.IsEncoderLoaded()) + await AutoEncoder.EncoderUnloadAsync(); + if ((options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled) && AutoEncoder.IsDecoderLoaded()) + await AutoEncoder.DecoderUnloadAsync(); + } + + + /// + /// Configures the supported schedulers. + /// + protected override IReadOnlyList ConfigureSchedulers() + { + return [SchedulerType.FlowMatchEulerDiscrete, SchedulerType.FlowMatchEulerDynamic]; + } + + + /// + /// Configures the default SchedulerOptions. + /// + protected override GenerateOptions ConfigureDefaultOptions() + { + var options = new GenerateOptions + { + Steps = 50, + Shift = 3f, + Width = 832, + Height = 480, + GuidanceScale = 5f, + Scheduler = SchedulerType.FlowMatchEulerDiscrete + }; + + return options; + } + + + /// + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + /// + private bool _disposed; + protected override void Dispose(bool disposing) + { + if (_disposed) + return; + if (disposing) + { + Tokenizer?.Dispose(); + TextEncoder?.Dispose(); + Transformer?.Dispose(); + AutoEncoder?.Dispose(); + } + _disposed = true; + } + + } +} diff --git a/TensorStack.StableDiffusion/Pipelines/Wan/WanConfig.cs b/TensorStack.StableDiffusion/Pipelines/Wan/WanConfig.cs new file mode 100644 index 0000000..c73ddc5 --- /dev/null +++ b/TensorStack.StableDiffusion/Pipelines/Wan/WanConfig.cs @@ -0,0 +1,201 @@ +// Copyright (c) TensorStack. All rights reserved. +// Licensed under the Apache 2.0 License. +using System; +using System.IO; +using System.Linq; +using TensorStack.Common; +using TensorStack.StableDiffusion.Config; +using TensorStack.StableDiffusion.Enums; +using TensorStack.TextGeneration.Tokenizers; + +namespace TensorStack.StableDiffusion.Pipelines.Wan +{ + public record WanConfig : PipelineConfig + { + /// + /// Initializes a new instance of the class. + /// + public WanConfig() + { + Tokenizer = new TokenizerConfig(); + TextEncoder = new CLIPModelConfig + { + PadTokenId = 0, + HiddenSize = 4096, + SequenceLength = 512, + IsFixedSequenceLength = true, + }; + Transformer = new TransformerModelConfig + { + JointAttention = 4096, + IsOptimizationSupported = true + }; + AutoEncoder = new AutoEncoderModelConfig + { + LatentChannels = 16, + ScaleFactor = 1, + ShiftFactor = 0, + LatentsMean = + [ + -0.7571f, + -0.7089f, + -0.9113f, + 0.1075f, + -0.1745f, + 0.9653f, + -0.1517f, + 1.5508f, + 0.4134f, + -0.0715f, + 0.5517f, + -0.3632f, + -0.1922f, + -0.9497f, + 0.2503f, + -0.2921f + ], + LatentsStd = + [ + 2.8184f, + 1.4541f, + 2.3275f, + 2.6558f, + 1.2196f, + 1.7708f, + 2.6052f, + 2.0743f, + 3.2687f, + 2.1526f, + 2.8652f, + 1.5579f, + 1.6382f, + 1.1253f, + 2.8251f, + 1.916f + ] + }; + } + + public string Name { get; init; } = "Wan"; + public override PipelineType Pipeline { get; } = PipelineType.Wan; + public TokenizerConfig Tokenizer { get; init; } + public CLIPModelConfig TextEncoder { get; init; } + public TransformerModelConfig Transformer { get; init; } + public AutoEncoderModelConfig AutoEncoder { get; init; } + + + /// + /// Sets the execution provider for all models. + /// + /// The execution provider. + public override void SetProvider(ExecutionProvider executionProvider) + { + TextEncoder.SetProvider(executionProvider); + Transformer.SetProvider(executionProvider); + AutoEncoder.SetProvider(executionProvider); + } + + + /// + /// Saves the configuration to file. + /// + /// The configuration file. + /// if set to true use relative paths. + public override void Save(string configFile, bool useRelativePaths = true) + { + ConfigService.Serialize(configFile, this, useRelativePaths); + } + + + /// + /// Create Wan configuration from default values + /// + /// The name. + /// Type of the model. + /// The execution provider. + /// WanConfig. + public static WanConfig FromDefault(string name, ModelType modelType, ExecutionProvider executionProvider = default) + { + var config = new WanConfig { Name = name }; + config.Transformer.ModelType = modelType; + config.SetProvider(executionProvider); + return config; + } + + + /// + /// Create StableDiffusionv configuration from json file + /// + /// The configuration file. + /// The execution provider. + /// WanConfig. + public static WanConfig FromFile(string configFile, ExecutionProvider executionProvider = default) + { + var config = ConfigService.Deserialize(configFile); + config.SetProvider(executionProvider); + return config; + } + + + /// + /// Create Wan configuration from folder structure + /// + /// The model folder. + /// Type of the model. + /// The execution provider. + /// WanConfig. + public static WanConfig FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider = default) + { + return CreateFromFolder(modelFolder, default, modelType, executionProvider); + } + + + /// + /// Create Wan configuration from folder structure + /// + /// The model folder. + /// The variant. + /// Type of the model. + /// The execution provider. + /// WanConfig. + public static WanConfig FromFolder(string modelFolder, string variant, ModelType modelType, ExecutionProvider executionProvider = default) + { + return CreateFromFolder(modelFolder, variant, modelType, executionProvider); + } + + + /// + /// Create Wan configuration from folder structure + /// + /// The model folder. + /// The variant. + /// The execution provider. + /// WanConfig. + public static WanConfig FromFolder(string modelFolder, string variant, ExecutionProvider executionProvider = default) + { + string[] typeOptions = ["Turbo", "Distilled", "Dist", "Flash"]; + var modelType = typeOptions.Any(v => variant.Contains(v, StringComparison.OrdinalIgnoreCase)) ? ModelType.Turbo : ModelType.Base; + return CreateFromFolder(modelFolder, variant, modelType, executionProvider); + } + + + /// + /// Create Wan configuration from folder structure + /// + /// The model folder. + /// The variant. + /// Type of the model. + /// The execution provider. + /// WanConfig. + private static WanConfig CreateFromFolder(string modelFolder, string variant, ModelType modelType, ExecutionProvider executionProvider = default) + { + var config = FromDefault(Path.GetFileNameWithoutExtension(modelFolder), modelType, executionProvider); + config.Tokenizer.Path = Path.Combine(modelFolder, "tokenizer", "spiece.model"); + config.TextEncoder.Path = GetVariantPath(modelFolder, "text_encoder", "model.onnx", variant); + config.Transformer.Path = GetVariantPath(modelFolder, "transformer", "model.onnx", variant); + config.AutoEncoder.DecoderModelPath = GetVariantPath(modelFolder, "vae_decoder", "model.onnx", variant); + //config.AutoEncoder.EncoderModelPath = GetVariantPath(modelFolder, "vae_encoder", "model.onnx", variant); + return config; + } + } +} diff --git a/TensorStack.StableDiffusion/Pipelines/Wan/WanPipeline.cs b/TensorStack.StableDiffusion/Pipelines/Wan/WanPipeline.cs new file mode 100644 index 0000000..98e16af --- /dev/null +++ b/TensorStack.StableDiffusion/Pipelines/Wan/WanPipeline.cs @@ -0,0 +1,113 @@ +// Copyright (c) TensorStack. All rights reserved. +// Licensed under the Apache 2.0 License. +using Microsoft.Extensions.Logging; +using System; +using System.Threading; +using System.Threading.Tasks; +using TensorStack.Common; +using TensorStack.Common.Pipeline; +using TensorStack.Common.Tensor; +using TensorStack.StableDiffusion.Common; +using TensorStack.StableDiffusion.Enums; +using TensorStack.StableDiffusion.Models; +using TensorStack.TextGeneration.Tokenizers; + +namespace TensorStack.StableDiffusion.Pipelines.Wan +{ + public class WanPipeline : WanBase, IPipeline + { + /// + /// Initializes a new instance of the class. + /// + /// The transformer. + /// The tokenizer. + /// The text encoder. + /// The automatic encoder. + /// The logger. + public WanPipeline(TransformerWanModel transformer, T5Tokenizer tokenizer, T5EncoderModel textEncoder, AutoEncoderModel autoEncoder, ILogger logger = null) + : base(transformer, tokenizer, textEncoder, autoEncoder, logger) { } + + /// + /// Initializes a new instance of the class. + /// + /// The configuration. + /// The logger. + public WanPipeline(WanConfig configuration, ILogger logger = null) + : base(configuration, logger) { } + + + /// + /// Run ImageTensor pipeline. + /// + /// The options. + /// The progress callback. + /// The cancellation token. + public async Task RunAsync(GenerateOptions options, IProgress progressCallback = null, CancellationToken cancellationToken = default) + { + ValidateOptions(options); + + var prompt = await CreatePromptAsync(options, cancellationToken); + using (var scheduler = CreateScheduler(options)) + { + var latents = await RunInferenceAsync(options, scheduler, prompt, progressCallback, cancellationToken); + return await DecodeLatentsAsync(options, latents, cancellationToken); + } + } + + + /// + /// Create Wan pipeline from StableDiffusionConfig file + /// + /// The configuration file. + /// The execution provider. + /// The logger. + /// WanPipeline. + public static WanPipeline FromConfig(string configFile, ExecutionProvider executionProvider, ILogger logger = default) + { + return new WanPipeline(WanConfig.FromFile(configFile, executionProvider), logger); + } + + + /// + /// Create Wan pipeline from folder structure + /// + /// The model folder. + /// Type of the model. + /// The execution provider. + /// The logger. + /// WanPipeline. + public static WanPipeline FromFolder(string modelFolder, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default) + { + return new WanPipeline(WanConfig.FromFolder(modelFolder, modelType, executionProvider), logger); + } + + + /// + /// Create Wan pipeline from folder structure + /// + /// The model folder. + /// The variant. + /// Type of the model. + /// The execution provider. + /// The logger. + /// WanPipeline. + public static WanPipeline FromFolder(string modelFolder, string variant, ModelType modelType, ExecutionProvider executionProvider, ILogger logger = default) + { + return new WanPipeline(WanConfig.FromFolder(modelFolder, variant, modelType, executionProvider), logger); + } + + + /// + /// Create Wan pipeline from folder structure + /// + /// The model folder. + /// The variant. + /// The execution provider. + /// The logger. + /// WanPipeline. + public static WanPipeline FromFolder(string modelFolder, string variant, ExecutionProvider executionProvider, ILogger logger = default) + { + return new WanPipeline(WanConfig.FromFolder(modelFolder, variant, executionProvider), logger); + } + } +}