Commit 39b5ba3

Fix Nitro padding
1 parent 23ceda3 commit 39b5ba3

2 files changed: +20 -4 lines changed

TensorStack.StableDiffusion/Pipelines/Nitro/NitroBase.cs

Lines changed: 6 additions & 4 deletions
@@ -133,6 +133,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
 {
     Seed = options.Seed,
     Prompt = options.Prompt,
+    MinLength = 128,
     MaxLength = 128
 }, cancellationToken);

@@ -144,6 +145,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
 {
     Seed = options.Seed,
     Prompt = options.NegativePrompt,
+    MinLength = 128,
     MaxLength = 128
 }, cancellationToken);
 }

@@ -348,8 +350,8 @@ protected override GenerateOptions ConfigureDefaultOptions()
 {
     Steps = 20,
     Shift = 1f,
-    Width = 1024,
-    Height = 1024,
+    Width = 512,
+    Height = 512,
     GuidanceScale = 0f,
     Scheduler = SchedulerType.FlowMatchEulerDiscrete
 };

@@ -361,8 +363,8 @@ protected override GenerateOptions ConfigureDefaultOptions()
 {
     Steps = 4,
     Shift = 1f,
-    Width = 1024,
-    Height = 1024,
+    Width = 512,
+    Height = 512,
     GuidanceScale = 0,
     Scheduler = SchedulerType.FlowMatchEulerDiscrete
 };
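
With MinLength and MaxLength both set to 128, the prompt and the negative prompt should now always be encoded to exactly 128 tokens, which keeps the text-encoder output at a fixed sequence length regardless of how short the prompt is. A minimal stand-alone sketch of that pad-or-truncate behaviour, assuming an EOS-style pad id; the PadOrTruncate helper below is hypothetical and only illustrates the intent, not the library's actual Pad extension:

using System;
using System.Linq;

static class FixedLengthPromptSketch
{
    // Hypothetical helper: truncate above maxLength, pad with padId up to minLength.
    static long[] PadOrTruncate(long[] tokenIds, long padId, int minLength, int maxLength)
    {
        var kept = Math.Min(tokenIds.Length, maxLength);
        var padCount = Math.Max(0, minLength - kept);
        return tokenIds.Take(kept).Concat(Enumerable.Repeat(padId, padCount)).ToArray();
    }

    static void Main()
    {
        long[] shortPrompt = { 101, 7592, 2088, 102 };                      // 4 example token ids
        var fixedIds = PadOrTruncate(shortPrompt, padId: 2, minLength: 128, maxLength: 128);
        Console.WriteLine(fixedIds.Length);                                  // always prints 128
    }
}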

TensorStack.TextGeneration/Pipelines/Llama/LlamaPipeline.cs

Lines changed: 14 additions & 0 deletions
@@ -104,6 +104,20 @@ public async Task<Tensor<float>> GetLastHiddenState(GenerateOptions options, Can
 }


+/// <summary>
+/// Tokenize the prompt
+/// </summary>
+/// <param name="options">The options.</param>
+/// <returns>A Task representing the asynchronous operation.</returns>
+protected override async Task TokenizePromptAsync(GenerateOptions options)
+{
+    var tokenizerResult = await Tokenizer.EncodeAsync(options.Prompt);
+    var inputIds = tokenizerResult.InputIds.Span.Pad(Tokenizer.EOS, options.MinLength);
+    var mask = tokenizerResult.Mask.Span.Pad(0, options.MinLength);
+    TokenizerOutput = new TokenizerResult(inputIds, mask);
+}
+
+
 /// <summary>
 /// Gets the token processors.
 /// </summary>
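
The new override right-pads the encoded token ids with Tokenizer.EOS and the attention mask with 0 up to options.MinLength, so the MinLength = 128 requested by the Nitro pipeline yields fixed-length inputs where the mask marks which positions are real tokens. A small illustrative snippet under those assumptions (the token values and the EOS id below are made up, and the LINQ-based padding only mimics what the Span.Pad extension is expected to do):

using System;
using System.Linq;

class PaddedPromptSketch
{
    static void Main()
    {
        long eos = 2;                                    // hypothetical EOS id
        long[] ids  = { 101, 7592, 2088, 999, 102 };     // 5 example token ids
        long[] mask = { 1, 1, 1, 1, 1 };                 // 1 marks a real token

        // Pad ids with EOS and the mask with 0 up to a minimum length of 128.
        var paddedIds  = ids.Concat(Enumerable.Repeat(eos, 128 - ids.Length)).ToArray();
        var paddedMask = mask.Concat(Enumerable.Repeat(0L, 128 - mask.Length)).ToArray();

        Console.WriteLine($"{paddedIds.Length} ids, {paddedMask.Sum()} attended positions");
        // Prints: 128 ids, 5 attended positions
    }
}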
