Skip to content
This repository was archived by the owner on Dec 9, 2025. It is now read-only.

Commit 0215aa7

Browse files
committed
SDXL ImageToImage, ImageInpaint
1 parent 9692289 commit 0215aa7

File tree

5 files changed

+334
-12
lines changed

5 files changed

+334
-12
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
6+
using OnnxStack.Core.Services;
7+
using OnnxStack.StableDiffusion.Common;
8+
using OnnxStack.StableDiffusion.Config;
9+
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
11+
using SixLabors.ImageSharp;
12+
using System;
13+
using System.Collections.Generic;
14+
using System.Linq;
15+
using System.Threading.Tasks;
16+
17+
namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL
18+
{
19+
public sealed class ImageDiffuser : StableDiffusionXLDiffuser
{
    /// <summary>
    /// Initializes a new instance of the <see cref="ImageDiffuser"/> class.
    /// All arguments are forwarded unchanged to the base diffuser.
    /// </summary>
    /// <param name="onnxModelService">The onnx model service.</param>
    /// <param name="promptService">The prompt service.</param>
    /// <param name="logger">The logger.</param>
    public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<StableDiffusionXLDiffuser> logger)
        : base(onnxModelService, promptService, logger)
    {
    }


    /// <summary>
    /// Gets the type of the diffuser, which is always <c>ImageToImage</c> for this class.
    /// </summary>
    public override DiffuserType DiffuserType => DiffuserType.ImageToImage;


    /// <summary>
    /// Gets the timesteps to run. The number of steps is <c>InferenceSteps * Strength</c>
    /// (truncated and capped at <c>InferenceSteps</c>); the leading
    /// <c>InferenceSteps - steps</c> entries of the scheduler's timesteps are skipped.
    /// </summary>
    /// <param name="options">The scheduler options (reads InferenceSteps and Strength).</param>
    /// <param name="scheduler">The scheduler supplying the full timestep list.</param>
    /// <returns>The remaining timesteps as a new list.</returns>
    protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
    {
        // Strength decides how many of the configured steps are actually executed.
        var stepsToRun = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);

        // Never skip a negative count.
        var skipCount = Math.Max(options.InferenceSteps - stepsToRun, 0);

        var remaining = scheduler.Timesteps.Skip(skipCount);
        return remaining.ToList();
    }


    /// <summary>
    /// Prepares the latents for inference: runs the input image through the VaeEncoder model,
    /// adds a random sample at <c>InitialNoiseLevel</c>, multiplies by the model's scale factor
    /// and finally lets the scheduler add noise for the supplied timesteps.
    /// </summary>
    /// <param name="model">The model options.</param>
    /// <param name="prompt">The prompt options (supplies the input image).</param>
    /// <param name="options">The scheduler options.</param>
    /// <param name="scheduler">The scheduler.</param>
    /// <param name="timesteps">The timesteps passed to the scheduler when adding noise.</param>
    /// <returns>The noised latent tensor.</returns>
    protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
    {
        var encoderInputShape = new[] { 1, 3, options.Height, options.Width };
        var encoderInput = prompt.InputImage.ToDenseTensor(encoderInputShape);

        //TODO: Model Config, Channels
        var latentShape = options.GetScaledDimension();
        var encoderMetadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
        using (var parameters = new OnnxInferenceParameters(encoderMetadata))
        {
            parameters.AddInputTensor(encoderInput);
            parameters.AddOutputBuffer(latentShape);

            var encoderOutputs = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, parameters);
            using (var encoded = encoderOutputs.First())
            {
                // Encoder output + initial noise, then scaled by the model's factor.
                var encodedLatents = encoded.ToDenseTensor();
                var initialNoise = scheduler.CreateRandomSample(latentShape, options.InitialNoiseLevel);
                var scaledLatents = encodedLatents.Add(initialNoise).MultiplyBy(model.ScaleFactor);

                return scheduler.AddNoise(scaledLatents, scheduler.CreateRandomSample(scaledLatents.Dimensions), timesteps);
            }
        }
    }

}
87+
}
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
6+
using OnnxStack.Core.Services;
7+
using OnnxStack.StableDiffusion.Common;
8+
using OnnxStack.StableDiffusion.Config;
9+
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
11+
using SixLabors.ImageSharp;
12+
using SixLabors.ImageSharp.Processing;
13+
using System;
14+
using System.Collections.Generic;
15+
using System.Diagnostics;
16+
using System.Linq;
17+
using System.Threading;
18+
using System.Threading.Tasks;
19+
20+
namespace OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL
21+
{
22+
public sealed class InpaintLegacyDiffuser : StableDiffusionXLDiffuser
{
    /// <summary>
    /// Initializes a new instance of the <see cref="InpaintLegacyDiffuser"/> class.
    /// All arguments are forwarded unchanged to the base diffuser.
    /// </summary>
    /// <param name="onnxModelService">The onnx model service.</param>
    /// <param name="promptService">The prompt service.</param>
    /// <param name="logger">The logger.</param>
    public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<StableDiffusionXLDiffuser> logger)
        : base(onnxModelService, promptService, logger)
    {
    }


    /// <summary>
    /// Gets the type of the diffuser, which is always <c>ImageInpaintLegacy</c> for this class.
    /// </summary>
    public override DiffuserType DiffuserType => DiffuserType.ImageInpaintLegacy;


    /// <summary>
    /// Runs the scheduler steps. After every scheduler step the result is blended with a
    /// re-noised copy of the original image latents according to the mask, then the final
    /// latents are decoded.
    /// </summary>
    /// <param name="modelOptions">The model options.</param>
    /// <param name="promptOptions">The prompt options.</param>
    /// <param name="schedulerOptions">The scheduler options.</param>
    /// <param name="promptEmbeddings">The prompt embeddings.</param>
    /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
    /// <param name="progressCallback">The progress callback, invoked with (step, total steps) after each step.</param>
    /// <param name="cancellationToken">The cancellation token, checked at the start of each step.</param>
    /// <returns>The tensor returned by decoding the final latents.</returns>
    protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
    {
        using (var scheduler = GetScheduler(schedulerOptions))
        {
            // Get timesteps
            var timesteps = GetTimesteps(schedulerOptions, scheduler);

            // Create latent sample (encoded input image, not yet noised for the timesteps)
            var latentsOriginal = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

            // Create masks sample
            var maskImage = PrepareMask(modelOptions, promptOptions, schedulerOptions);

            // Generate some noise; the same noise tensor is reused for every per-step re-noising below
            var noise = scheduler.CreateRandomSample(latentsOriginal.Dimensions);

            // Add noise to original latent
            var latents = scheduler.AddNoise(latentsOriginal, noise, timesteps);

            // Get Model metadata
            var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);

            // Loop through the timesteps
            var step = 0;
            foreach (var timestep in timesteps)
            {
                step++;
                var stepTime = Stopwatch.GetTimestamp();
                cancellationToken.ThrowIfCancellationRequested();

                // Create input tensor (duplicated when guidance is performed).
                var inputLatent = performGuidance ? latents.Repeat(2) : latents;
                var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
                var timestepTensor = CreateTimestepTensor(timestep);
                var addTimeIds = GetAddTimeIds(schedulerOptions, performGuidance);

                var outputChannels = performGuidance ? 2 : 1;
                var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
                using (var inferenceParameters = new OnnxInferenceParameters(metadata))
                {
                    inferenceParameters.AddInputTensor(inputTensor);
                    inferenceParameters.AddInputTensor(timestepTensor);
                    inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
                    inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
                    inferenceParameters.AddInputTensor(addTimeIds);
                    inferenceParameters.AddOutputBuffer(outputDimension);

                    var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
                    using (var result = results.First())
                    {
                        var noisePred = result.ToDenseTensor();

                        // Perform guidance
                        if (performGuidance)
                        {
                            noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
                        }

                        // Scheduler Step
                        var steplatents = scheduler.Step(noisePred, timestep, latents).Result;

                        // Add noise to original latent for the current timestep only
                        var initLatentsProper = scheduler.AddNoise(latentsOriginal, noise, new[] { timestep });

                        // Apply mask and combine
                        latents = ApplyMaskedLatents(steplatents, initLatentsProper, maskImage);
                    }
                }

                progressCallback?.Invoke(step, timesteps.Count);
                _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
            }

            // Decode Latents
            return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
        }
    }


    /// <summary>
    /// Gets the timesteps to run. The number of steps is <c>InferenceSteps * Strength</c>
    /// (truncated and capped at <c>InferenceSteps</c>); the leading
    /// <c>InferenceSteps - steps</c> entries of the scheduler's timesteps are skipped.
    /// </summary>
    /// <param name="options">The scheduler options (reads InferenceSteps and Strength).</param>
    /// <param name="scheduler">The scheduler supplying the full timestep list.</param>
    /// <returns>The remaining timesteps as a new list.</returns>
    protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
    {
        var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
        var start = Math.Max(options.InferenceSteps - inittimestep, 0);
        return scheduler.Timesteps.Skip(start).ToList();
    }


    /// <summary>
    /// Prepares the latents for inference: runs the input image through the VaeEncoder model,
    /// adds a random sample at <c>InitialNoiseLevel</c> and multiplies by the model's scale factor.
    /// Unlike the image-to-image diffuser, no timestep noise is added here; the caller does that.
    /// </summary>
    /// <param name="model">The model options.</param>
    /// <param name="prompt">The prompt options (supplies the input image).</param>
    /// <param name="options">The scheduler options.</param>
    /// <param name="scheduler">The scheduler.</param>
    /// <param name="timesteps">The timesteps (not used by this override).</param>
    /// <returns>The scaled latent tensor of the input image.</returns>
    protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
    {
        var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });

        //TODO: Model Config, Channels
        var outputDimensions = options.GetScaledDimension();
        var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
        using (var inferenceParameters = new OnnxInferenceParameters(metadata))
        {
            inferenceParameters.AddInputTensor(imageTensor);
            inferenceParameters.AddOutputBuffer(outputDimensions);

            var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
            using (var result = results.First())
            {
                var outputResult = result.ToDenseTensor();
                var scaledSample = outputResult
                    .Add(scheduler.CreateRandomSample(outputDimensions, options.InitialNoiseLevel))
                    .MultiplyBy(model.ScaleFactor);

                return scaledSample;
            }
        }
    }


    /// <summary>
    /// Prepares the mask tensor. The mask image is converted to grayscale, resized to the
    /// scaled (latent) width and height, and each pixel's alpha is written as
    /// <c>1 - A / 255</c> into all four channels of a [1, 4, height, width] tensor.
    /// </summary>
    /// <param name="modelOptions">The model options (not used by this method).</param>
    /// <param name="promptOptions">The prompt options (supplies the mask image).</param>
    /// <param name="schedulerOptions">The scheduler options (supplies the scaled width and height).</param>
    /// <returns>The mask tensor.</returns>
    private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
    {
        using (var mask = promptOptions.InputImageMask.ToImage())
        {
            // Prepare the mask
            int width = schedulerOptions.GetScaledWidth();
            int height = schedulerOptions.GetScaledHeight();
            mask.Mutate(x => x.Grayscale());
            mask.Mutate(x => x.Resize(new Size(width, height), KnownResamplers.NearestNeighbor, true));

            // Dimensions must be height-then-width to match the [0, c, y, x] writes below.
            // ApplyMaskedLatents combines mask and latents by flat index, so with the axes
            // swapped a non-square mask would be misaligned against the latents.
            var maskTensor = new DenseTensor<float>(new[] { 1, 4, height, width });
            mask.ProcessPixelRows(img =>
            {
                for (int y = 0; y < height; y++)
                {
                    // Fetch each row once instead of once per pixel.
                    var pixelSpan = img.GetRowSpan(y);
                    for (int x = 0; x < width; x++)
                    {
                        var value = 1f - (pixelSpan[x].A / 255.0f);
                        maskTensor[0, 0, y, x] = value;
                        maskTensor[0, 1, y, x] = value; // Needed for shape only
                        maskTensor[0, 2, y, x] = value; // Needed for shape only
                        maskTensor[0, 3, y, x] = value; // Needed for shape only
                    }
                }
            });
            return maskTensor;
        }
    }


    /// <summary>
    /// Applies the masked latents. For every element the result is
    /// <c>initLatentsProper * mask + latents * (1 - mask)</c>, so a mask value of 1 keeps the
    /// (re-noised) original and a mask value of 0 keeps the newly generated latent.
    /// </summary>
    /// <param name="latents">The latents produced by the scheduler step.</param>
    /// <param name="initLatentsProper">The original latents with noise added for the current timestep.</param>
    /// <param name="mask">The mask; read by flat index, so it must have at least as many elements as <paramref name="latents"/>.</param>
    /// <returns>A new tensor with the dimensions of <paramref name="latents"/>.</returns>
    private DenseTensor<float> ApplyMaskedLatents(DenseTensor<float> latents, DenseTensor<float> initLatentsProper, DenseTensor<float> mask)
    {
        var result = new DenseTensor<float>(latents.Dimensions);
        for (int i = 0; i < result.Length; i++)
        {
            float maskValue = mask.GetValue(i);
            result.SetValue(i, initLatentsProper.GetValue(i) * maskValue + latents.GetValue(i) * (1f - maskValue));
        }
        return result;
    }
}
231+
}

OnnxStack.StableDiffusion/Diffusers/StableDiffusionXL/StableDiffusionXLDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptio
113113
/// </summary>
114114
/// <param name="schedulerOptions">The scheduler options.</param>
115115
/// <returns></returns>
116-
private DenseTensor<float> GetAddTimeIds(SchedulerOptions schedulerOptions, bool performGuidance)
116+
protected DenseTensor<float> GetAddTimeIds(SchedulerOptions schedulerOptions, bool performGuidance)
117117
{
118118
var addTimeIds = new float[]
119119
{

OnnxStack.StableDiffusion/Registration.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ public static void AddOnnxStackStableDiffusion(this IServiceCollection serviceCo
4343

4444
//StableDiffusionXL
4545
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.StableDiffusionXL.TextDiffuser>();
46+
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.StableDiffusionXL.ImageDiffuser>();
47+
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.StableDiffusionXL.InpaintLegacyDiffuser>();
4648

4749
//LatentConsistency
4850
serviceCollection.AddSingleton<IDiffuser, StableDiffusion.Diffusers.LatentConsistency.TextDiffuser>();

OnnxStack.UI/appsettings.json

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,10 @@
489489
]
490490
},
491491
{
492-
"Name": "DreamShaper XL",
492+
"Name": "Stable Diffusion XL Base 1.0",
493493
"Description": "",
494-
"Author": "softwareweaver",
495-
"Repository": "https://huggingface.co/softwareweaver/dreamshaper-xl-1-0-Olive-Onnx",
494+
"Author": "stabilityai",
495+
"Repository": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
496496
"ImageIcon": "",
497497
"Status": "Active",
498498
"PadTokenId": 1,
@@ -505,16 +505,18 @@
505505
"SampleSize": 1024,
506506
"PipelineType": "StableDiffusionXL",
507507
"Diffusers": [
508-
"TextToImage"
508+
"TextToImage",
509+
"ImageToImage",
510+
"ImageInpaint"
509511
],
510512
"ModelFiles": [
511-
"https://huggingface.co/softwareweaver/dreamshaper-xl-1-0-Olive-Onnx/resolve/main/text_encoder/model.onnx",
512-
"https://huggingface.co/softwareweaver/dreamshaper-xl-1-0-Olive-Onnx/resolve/main/text_encoder_2/model.onnx",
513-
"https://huggingface.co/softwareweaver/dreamshaper-xl-1-0-Olive-Onnx/resolve/main/text_encoder_2/model.onnx.data",
514-
"https://huggingface.co/softwareweaver/dreamshaper-xl-1-0-Olive-Onnx/resolve/main/unet/model.onnx",
515-
"https://huggingface.co/softwareweaver/dreamshaper-xl-1-0-Olive-Onnx/resolve/main/unet/model.onnx.data",
516-
"https://huggingface.co/softwareweaver/dreamshaper-xl-1-0-Olive-Onnx/resolve/main/vae_decoder/model.onnx",
517-
"https://huggingface.co/softwareweaver/dreamshaper-xl-1-0-Olive-Onnx/resolve/main/vae_encoder/model.onnx"
513+
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/text_encoder/model.onnx",
514+
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/text_encoder_2/model.onnx",
515+
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/text_encoder_2/model.onnx_data",
516+
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/unet/model.onnx",
517+
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/unet/model.onnx_data",
518+
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/vae_decoder/model.onnx",
519+
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/vae_encoder/model.onnx"
518520
],
519521
"Images": [
520522
"",

0 commit comments

Comments
 (0)