Skip to content

Commit 7a35320

Browse files
committed
Merge branch 'TextDemo'
2 parents 0fb822a + 6247915 commit 7a35320

22 files changed

+1234
-120
lines changed

Examples/TensorStack.Example.TextGeneration/Common/TextModel.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ public class TextModel : BaseModel
2020
public enum TextModelType
2121
{
2222
Summary = 0,
23-
Phi3 = 1
23+
Phi3 = 1,
24+
Whisper = 2,
25+
Supertonic = 3
2426
}
2527
}

Examples/TensorStack.Example.TextGeneration/MainWindow.xaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,17 @@
1515

1616
<!--Main Menu-->
1717
<Grid DockPanel.Dock="Top" WindowChrome.IsHitTestVisibleInChrome="True">
18-
<UniformGrid Columns="3" Height="30" Margin="2">
18+
<UniformGrid Columns="5" Height="30" Margin="2">
1919

2020
<!--Logo-->
2121
<Grid IsHitTestVisible="False">
2222
<Image Source="{StaticResource ImageTensorstackText}" Height="32" HorizontalAlignment="Left" Margin="4,2,50,0" />
2323
</Grid>
2424

2525
<!--Views-->
26-
<Button Command="{Binding NavigateCommand}" CommandParameter="{x:Static Views:View.TextSummary}" Content="Text Summary" />
26+
<Button Command="{Binding NavigateCommand}" CommandParameter="{x:Static Views:View.Summary}" Content="Summary" />
27+
<Button Command="{Binding NavigateCommand}" CommandParameter="{x:Static Views:View.Whisper}" Content="Whisper" />
28+
<Button Command="{Binding NavigateCommand}" CommandParameter="{x:Static Views:View.Supertonic}" Content="Supertonic" />
2729

2830
<!--Window Options-->
2931
<StackPanel Orientation="Horizontal" HorizontalAlignment="Right">

Examples/TensorStack.Example.TextGeneration/MainWindow.xaml.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public MainWindow(Settings configuration, NavigationService navigation)
1717
NavigateCommand = new AsyncRelayCommand<View>(NavigateAsync, CanNavigate);
1818
InitializeComponent();
1919

20-
NavigateCommand.Execute(View.TextSummary);
20+
NavigateCommand.Execute(View.Whisper);
2121
}
2222

2323
public NavigationService Navigation { get; }

Examples/TensorStack.Example.TextGeneration/Services/TextService.cs

Lines changed: 129 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@
88
using TensorStack.TextGeneration.Pipelines.Other;
99
using TensorStack.TextGeneration.Pipelines.Phi;
1010
using TensorStack.Providers;
11+
using TensorStack.TextGeneration.Pipelines.Whisper;
12+
using TensorStack.Common.Tensor;
13+
using TensorStack.TextGeneration.Pipelines.Supertonic;
1114

1215
namespace TensorStack.Example.Services
1316
{
1417
public class TextService : ServiceBase, ITextService
1518
{
1619
private readonly Settings _settings;
17-
private IPipeline<GenerateResult, GenerateOptions, GenerateProgress> _greedyPipeline;
18-
private IPipeline<GenerateResult[], SearchOptions, GenerateProgress> _beamSearchPipeline;
20+
private IPipeline _currentPipeline;
1921
private CancellationTokenSource _cancellationTokenSource;
2022
private bool _isLoaded;
2123
private bool _isLoading;
@@ -77,34 +79,41 @@ public async Task LoadAsync(TextModel model, Device device)
7779
using (_cancellationTokenSource = new CancellationTokenSource())
7880
{
7981
var cancellationToken = _cancellationTokenSource.Token;
80-
if (_greedyPipeline != null)
81-
await _greedyPipeline.UnloadAsync(cancellationToken);
82+
if (_currentPipeline != null)
83+
await _currentPipeline.UnloadAsync(cancellationToken);
8284

8385
var provider = device.GetProvider();
8486
var providerCPU = Provider.GetProvider(DeviceType.CPU); // TODO: DirectML not working with decoder
8587
if (model.Type == TextModelType.Phi3)
8688
{
8789
if (!Enum.TryParse<PhiType>(model.Version, true, out var phiType))
88-
throw new ArgumentException("Invalid PhiType Version");
90+
throw new ArgumentException("Invalid Phi Version");
8991

90-
var pipeline = Phi3Pipeline.Create(providerCPU, model.Path, phiType);
91-
_greedyPipeline = pipeline;
92-
_beamSearchPipeline = pipeline;
92+
_currentPipeline = Phi3Pipeline.Create(providerCPU, model.Path, phiType);
9393
}
9494
else if (model.Type == TextModelType.Summary)
9595
{
96-
var pipeline = SummaryPipeline.Create(provider, providerCPU, model.Path);
97-
_greedyPipeline = pipeline;
98-
_beamSearchPipeline = pipeline;
96+
_currentPipeline = SummaryPipeline.Create(provider, providerCPU, model.Path);
9997
}
100-
await Task.Run(() => _greedyPipeline.LoadAsync(cancellationToken), cancellationToken);
98+
else if (model.Type == TextModelType.Whisper)
99+
{
100+
if (!Enum.TryParse<WhisperType>(model.Version, true, out var whisperType))
101+
throw new ArgumentException("Invalid Whisper Version");
102+
103+
_currentPipeline = WhisperPipeline.Create(provider, providerCPU, model.Path, whisperType);
104+
}
105+
else if (model.Type == TextModelType.Supertonic)
106+
{
107+
_currentPipeline = SupertonicPipeline.Create(model.Path, provider);
108+
}
109+
await Task.Run(() => _currentPipeline.LoadAsync(cancellationToken), cancellationToken);
101110

102111
}
103112
}
104113
catch (OperationCanceledException)
105114
{
106-
_greedyPipeline?.Dispose();
107-
_greedyPipeline = null;
115+
_currentPipeline?.Dispose();
116+
_currentPipeline = null;
108117
_currentConfig = null;
109118
throw;
110119
}
@@ -148,11 +157,63 @@ public async Task<GenerateResult[]> ExecuteAsync(TextRequest options)
148157
if (options.Beams == 0)
149158
{
150159
// Greedy Search
151-
return [await _greedyPipeline.RunAsync(pipelineOptions, cancellationToken: _cancellationTokenSource.Token)];
160+
var greedyPipeline = _currentPipeline as IPipeline<GenerateResult, GenerateOptions, GenerateProgress>;
161+
return [await greedyPipeline.RunAsync(pipelineOptions, cancellationToken: _cancellationTokenSource.Token)];
162+
}
163+
164+
// Beam Search
165+
var beamSearchPipeline = _currentPipeline as IPipeline<GenerateResult[], SearchOptions, GenerateProgress>;
166+
return await beamSearchPipeline.RunAsync(new SearchOptions(pipelineOptions), cancellationToken: _cancellationTokenSource.Token);
167+
});
168+
169+
return pipelineResult;
170+
}
171+
}
172+
finally
173+
{
174+
IsExecuting = false;
175+
}
176+
}
177+
178+
179+
public async Task<GenerateResult[]> ExecuteAsync(WhisperRequest options)
180+
{
181+
try
182+
{
183+
IsExecuting = true;
184+
using (_cancellationTokenSource = new CancellationTokenSource())
185+
{
186+
var pipelineOptions = new WhisperOptions
187+
{
188+
Prompt = options.Prompt,
189+
Seed = options.Seed,
190+
Beams = options.Beams,
191+
TopK = options.TopK,
192+
TopP = options.TopP,
193+
Temperature = options.Temperature,
194+
MaxLength = options.MaxLength,
195+
MinLength = options.MinLength,
196+
NoRepeatNgramSize = options.NoRepeatNgramSize,
197+
LengthPenalty = options.LengthPenalty,
198+
DiversityLength = options.DiversityLength,
199+
EarlyStopping = options.EarlyStopping,
200+
AudioInput = options.AudioInput,
201+
Language = options.Language,
202+
Task = options.Task
203+
};
204+
205+
var pipelineResult = await Task.Run(async () =>
206+
{
207+
if (options.Beams == 0)
208+
{
209+
// Greedy Search
210+
var greedyPipeline = _currentPipeline as IPipeline<GenerateResult, WhisperOptions, GenerateProgress>;
211+
return [await greedyPipeline.RunAsync(pipelineOptions, cancellationToken: _cancellationTokenSource.Token)];
152212
}
153213

154214
// Beam Search
155-
return await _beamSearchPipeline.RunAsync(new SearchOptions(pipelineOptions), cancellationToken: _cancellationTokenSource.Token);
215+
var beamSearchPipeline = _currentPipeline as IPipeline<GenerateResult[], WhisperSearchOptions, GenerateProgress>;
216+
return await beamSearchPipeline.RunAsync(new WhisperSearchOptions(pipelineOptions), cancellationToken: _cancellationTokenSource.Token);
156217
});
157218

158219
return pipelineResult;
@@ -165,6 +226,34 @@ public async Task<GenerateResult[]> ExecuteAsync(TextRequest options)
165226
}
166227

167228

229+
/// <summary>
/// Runs the currently loaded pipeline as a Supertonic pipeline and returns the audio it produces.
/// </summary>
/// <param name="options">Request whose values are copied one-to-one into a <see cref="SupertonicOptions"/> instance.</param>
/// <returns>The <see cref="AudioTensor"/> returned by the pipeline's <c>RunAsync</c> call.</returns>
/// <exception cref="InvalidOperationException">
/// No pipeline is loaded, or the loaded pipeline does not implement
/// <c>IPipeline&lt;AudioTensor, SupertonicOptions, GenerateProgress&gt;</c>.
/// </exception>
public async Task<AudioTensor> ExecuteAsync(SupertonicRequest options)
{
    try
    {
        IsExecuting = true;
        using (_cancellationTokenSource = new CancellationTokenSource())
        {
            // _currentPipeline is shared by every model type, so it may be null or a
            // text/Whisper pipeline here. Fail with a clear message instead of the
            // NullReferenceException an unchecked 'as' cast would produce.
            if (_currentPipeline is not IPipeline<AudioTensor, SupertonicOptions, GenerateProgress> pipeline)
                throw new InvalidOperationException("The loaded pipeline does not support Supertonic audio generation.");

            var pipelineOptions = new SupertonicOptions
            {
                TextInput = options.InputText,
                VoiceStyle = options.VoiceStyle,
                Steps = options.Steps,
                Speed = options.Speed,
                SilenceDuration = options.SilenceDuration,
                Seed = options.Seed,
            };

            return await pipeline.RunAsync(pipelineOptions, cancellationToken: _cancellationTokenSource.Token);
        }
    }
    finally
    {
        // Always clear the flag, including when the guard above throws.
        IsExecuting = false;
    }
}
255+
256+
168257
/// <summary>
169258
/// Cancel the running task (Load or Execute)
170259
/// </summary>
@@ -179,12 +268,12 @@ public async Task CancelAsync()
179268
/// </summary>
180269
public async Task UnloadAsync()
181270
{
182-
if (_greedyPipeline != null)
271+
if (_currentPipeline != null)
183272
{
184273
await _cancellationTokenSource.SafeCancelAsync();
185-
await _greedyPipeline.UnloadAsync();
186-
_greedyPipeline.Dispose();
187-
_greedyPipeline = null;
274+
await _currentPipeline.UnloadAsync();
275+
_currentPipeline.Dispose();
276+
_currentPipeline = null;
188277
_currentConfig = null;
189278
}
190279

@@ -205,6 +294,8 @@ public interface ITextService
205294
Task UnloadAsync();
206295
Task CancelAsync();
207296
Task<GenerateResult[]> ExecuteAsync(TextRequest options);
297+
Task<GenerateResult[]> ExecuteAsync(WhisperRequest options);
298+
Task<AudioTensor> ExecuteAsync(SupertonicRequest options);
208299
}
209300

210301

@@ -224,4 +315,22 @@ public record TextRequest : ITransformerRequest
224315
public int DiversityLength { get; set; } = 5;
225316
}
226317

318+
319+
/// <summary>
/// Request accepted by the Whisper overload of <c>ExecuteAsync</c>. Inherits the
/// generation settings of <see cref="TextRequest"/> and adds the audio input,
/// language and task values that are copied into <c>WhisperOptions</c>.
/// </summary>
public record WhisperRequest : TextRequest
320+
{
321+
/// <summary>Audio forwarded to the pipeline as <c>WhisperOptions.AudioInput</c>.</summary>
public AudioTensor AudioInput { get; set; }
322+
/// <summary>Language forwarded as <c>WhisperOptions.Language</c>; defaults to <c>LanguageType.EN</c>.</summary>
public LanguageType Language { get; set; } = LanguageType.EN;
323+
/// <summary>Task forwarded as <c>WhisperOptions.Task</c>; defaults to <c>TaskType.Transcribe</c>.</summary>
public TaskType Task { get; set; } = TaskType.Transcribe;
324+
}
325+
326+
327+
/// <summary>
/// Request accepted by the Supertonic overload of <c>ExecuteAsync</c>. Each property
/// is copied into the matching member of <c>SupertonicOptions</c> before the pipeline runs.
/// </summary>
public record SupertonicRequest
328+
{
329+
/// <summary>Text forwarded as <c>SupertonicOptions.TextInput</c>.</summary>
public string InputText { get; set; }
330+
/// <summary>Value forwarded as <c>SupertonicOptions.VoiceStyle</c>. NOTE(review): presumably a voice name or style file; confirm against the pipeline.</summary>
public string VoiceStyle { get; set; }
331+
/// <summary>Value forwarded as <c>SupertonicOptions.Steps</c>; defaults to 5. NOTE(review): exact meaning is defined by the pipeline; confirm.</summary>
public int Steps { get; set; } = 5;
332+
/// <summary>Value forwarded as <c>SupertonicOptions.Speed</c>; defaults to 1.</summary>
public float Speed { get; set; } = 1f;
333+
/// <summary>Value forwarded as <c>SupertonicOptions.SilenceDuration</c>; defaults to 0.3. NOTE(review): unit is presumably seconds; confirm.</summary>
public float SilenceDuration { get; set; } = 0.3f;
334+
/// <summary>Seed forwarded as <c>SupertonicOptions.Seed</c>.</summary>
public int Seed { get; set; }
335+
}
227336
}

Examples/TensorStack.Example.TextGeneration/Settings.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ public class Settings : IUIConfiguration
1919
public string VideoCodec { get; set; } = "mp4v";
2020
public string DirectoryTemp { get; set; }
2121
public IReadOnlyList<Device> Devices { get; set; }
22-
public ObservableCollection<TextModel> TextModels { get; set; }
22+
public ObservableCollection<TextModel> TextToTextModels { get; set; }
23+
public ObservableCollection<TextModel> TextToAudioModels { get; set; }
24+
public ObservableCollection<TextModel> AudioToTextModels { get; set; }
2325

2426

2527
public void Initialize()

0 commit comments

Comments
 (0)