diff --git a/StabilityMatrix.Avalonia/App.axaml b/StabilityMatrix.Avalonia/App.axaml index 95f8e0d5..3bbd1981 100644 --- a/StabilityMatrix.Avalonia/App.axaml +++ b/StabilityMatrix.Avalonia/App.axaml @@ -95,6 +95,7 @@ + diff --git a/StabilityMatrix.Avalonia/Controls/Inference/TiledVAECard.axaml b/StabilityMatrix.Avalonia/Controls/Inference/TiledVAECard.axaml new file mode 100644 index 00000000..9c47cbb3 --- /dev/null +++ b/StabilityMatrix.Avalonia/Controls/Inference/TiledVAECard.axaml @@ -0,0 +1,79 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/StabilityMatrix.Avalonia/Controls/Inference/TiledVAECard.axaml.cs b/StabilityMatrix.Avalonia/Controls/Inference/TiledVAECard.axaml.cs new file mode 100644 index 00000000..a211431b --- /dev/null +++ b/StabilityMatrix.Avalonia/Controls/Inference/TiledVAECard.axaml.cs @@ -0,0 +1,9 @@ +using Avalonia; +using Avalonia.Controls; +using Avalonia.Controls.Primitives; +using Injectio.Attributes; + +namespace StabilityMatrix.Avalonia.Controls; + +[RegisterTransient] +public class TiledVAECard : TemplatedControlBase { } diff --git a/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs b/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs index a42f7391..26a87502 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs @@ -27,6 +27,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base; [JsonDerivedType(typeof(PlasmaNoiseCardViewModel), PlasmaNoiseCardViewModel.ModuleKey)] [JsonDerivedType(typeof(NrsCardViewModel), NrsCardViewModel.ModuleKey)] [JsonDerivedType(typeof(CfzCudnnToggleCardViewModel), CfzCudnnToggleCardViewModel.ModuleKey)] +[JsonDerivedType(typeof(TiledVAECardViewModel), TiledVAECardViewModel.ModuleKey)] [JsonDerivedType(typeof(FreeUModule))] [JsonDerivedType(typeof(HiresFixModule))] [JsonDerivedType(typeof(FluxHiresFixModule))] @@ -43,6 +44,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base; [JsonDerivedType(typeof(PlasmaNoiseModule))] [JsonDerivedType(typeof(NRSModule))] [JsonDerivedType(typeof(CfzCudnnToggleModule))] +[JsonDerivedType(typeof(TiledVAEModule))] public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/TiledVAEModule.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/TiledVAEModule.cs new file mode 100644 index 00000000..3199a53f --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/Modules/TiledVAEModule.cs @@ -0,0 +1,55 @@ +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Models.Inference; +using StabilityMatrix.Avalonia.Services; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Attributes; +using StabilityMatrix.Core.Models.Api.Comfy.Nodes; + +namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules; + +[ManagedService] +[RegisterTransient] +public class TiledVAEModule : ModuleBase +{ + public TiledVAEModule(IServiceManager vmFactory) + : base(vmFactory) + { + Title = "Tiled VAE Decode"; + AddCards(vmFactory.Get()); + } + + protected override void OnApplyStep(ModuleApplyStepEventArgs e) + { + var card = GetCard(); + + // Register a pre-output action that replaces standard VAE decode with tiled decode + e.PreOutputActions.Add(args => + { + var builder = args.Builder; + + // Only apply if primary is in latent space + if (builder.Connections.Primary?.IsT0 != true) + return; + + var latent = builder.Connections.Primary.AsT0; + var vae = builder.Connections.GetDefaultVAE(); + + // Use tiled VAE decode instead of standard decode + var tiledDecode = builder.Nodes.AddTypedNode( + new ComfyNodeBuilder.TiledVAEDecode + { + Name = builder.Nodes.GetUniqueName("TiledVAEDecode"), + Samples = latent, + Vae = vae, + TileSize = card.TileSize, + Overlap = card.Overlap, + TemporalSize = card.TemporalSize, + TemporalOverlap = card.TemporalOverlap + } + ); + + // Update primary connection to the decoded image + builder.Connections.Primary = tiledDecode.Output; + }); + } +} diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs index 0e165faf..8cdb313a 100644 --- a/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs @@ -155,6 +155,7 @@ TabContext tabContext typeof(RescaleCfgModule), typeof(PlasmaNoiseModule), typeof(NRSModule), + typeof(TiledVAEModule), ]; }); } diff --git a/StabilityMatrix.Avalonia/ViewModels/Inference/TiledVAECardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/Inference/TiledVAECardViewModel.cs new file mode 100644 index 00000000..ae292077 --- /dev/null +++ b/StabilityMatrix.Avalonia/ViewModels/Inference/TiledVAECardViewModel.cs @@ -0,0 +1,40 @@ +using System.ComponentModel.DataAnnotations; +using CommunityToolkit.Mvvm.ComponentModel; +using Injectio.Attributes; +using StabilityMatrix.Avalonia.Controls; +using StabilityMatrix.Avalonia.ViewModels.Base; +using StabilityMatrix.Core.Attributes; + +namespace StabilityMatrix.Avalonia.ViewModels.Inference; + +[View(typeof(TiledVAECard))] +[ManagedService] +[RegisterTransient] +public partial class TiledVAECardViewModel : LoadableViewModelBase +{ + public const string ModuleKey = "TiledVAE"; + + [ObservableProperty] + [NotifyDataErrorInfo] + [Required] + [Range(64, 4096)] + private int tileSize = 512; + + [ObservableProperty] + [NotifyDataErrorInfo] + [Required] + [Range(0, 4096)] + private int overlap = 64; + + [ObservableProperty] + [NotifyDataErrorInfo] + [Required] + [Range(8, 4096)] + private int temporalSize = 64; + + [ObservableProperty] + [NotifyDataErrorInfo] + [Required] + [Range(4, 4096)] + private int temporalOverlap = 8; +} diff --git a/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs b/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs index 181cfc70..c164a632 100644 --- a/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs +++ b/StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs @@ -62,6 +62,25 @@ public record VAEDecode : ComfyTypedNodeBase public required VAENodeConnection Vae { get; init; } } + [TypedNodeOptions(Name = "VAEDecodeTiled")] + public record TiledVAEDecode : ComfyTypedNodeBase + { + public required LatentNodeConnection Samples { get; init; } + public required VAENodeConnection Vae { get; init; } + + [Range(64, 4096)] + public int TileSize { get; init; } = 512; + + [Range(0, 4096)] + public int Overlap { get; init; } = 64; + + [Range(8, 4096)] + public int TemporalSize { get; init; } = 64; + + [Range(4, 4096)] + public int TemporalOverlap { get; init; } = 8; + } + public record KSampler : ComfyTypedNodeBase { public required ModelNodeConnection Model { get; init; }