88using TensorStack . TextGeneration . Pipelines . Other ;
99using TensorStack . TextGeneration . Pipelines . Phi ;
1010using TensorStack . Providers ;
11+ using TensorStack . TextGeneration . Pipelines . Whisper ;
12+ using TensorStack . Common . Tensor ;
13+ using TensorStack . TextGeneration . Pipelines . Supertonic ;
1114
1215namespace TensorStack . Example . Services
1316{
1417 public class TextService : ServiceBase , ITextService
1518 {
1619 private readonly Settings _settings ;
17- private IPipeline < GenerateResult , GenerateOptions , GenerateProgress > _greedyPipeline ;
18- private IPipeline < GenerateResult [ ] , SearchOptions , GenerateProgress > _beamSearchPipeline ;
20+ private IPipeline _currentPipeline ;
1921 private CancellationTokenSource _cancellationTokenSource ;
2022 private bool _isLoaded ;
2123 private bool _isLoading ;
@@ -77,34 +79,41 @@ public async Task LoadAsync(TextModel model, Device device)
7779 using ( _cancellationTokenSource = new CancellationTokenSource ( ) )
7880 {
7981 var cancellationToken = _cancellationTokenSource . Token ;
80- if ( _greedyPipeline != null )
81- await _greedyPipeline . UnloadAsync ( cancellationToken ) ;
82+ if ( _currentPipeline != null )
83+ await _currentPipeline . UnloadAsync ( cancellationToken ) ;
8284
8385 var provider = device . GetProvider ( ) ;
8486 var providerCPU = Provider . GetProvider ( DeviceType . CPU ) ; // TODO: DirectML not working with decoder
8587 if ( model . Type == TextModelType . Phi3 )
8688 {
8789 if ( ! Enum . TryParse < PhiType > ( model . Version , true , out var phiType ) )
88- throw new ArgumentException ( "Invalid PhiType Version" ) ;
90+ throw new ArgumentException ( "Invalid Phi Version" ) ;
8991
90- var pipeline = Phi3Pipeline . Create ( providerCPU , model . Path , phiType ) ;
91- _greedyPipeline = pipeline ;
92- _beamSearchPipeline = pipeline ;
92+ _currentPipeline = Phi3Pipeline . Create ( providerCPU , model . Path , phiType ) ;
9393 }
9494 else if ( model . Type == TextModelType . Summary )
9595 {
96- var pipeline = SummaryPipeline . Create ( provider , providerCPU , model . Path ) ;
97- _greedyPipeline = pipeline ;
98- _beamSearchPipeline = pipeline ;
96+ _currentPipeline = SummaryPipeline . Create ( provider , providerCPU , model . Path ) ;
9997 }
100- await Task . Run ( ( ) => _greedyPipeline . LoadAsync ( cancellationToken ) , cancellationToken ) ;
98+ else if ( model . Type == TextModelType . Whisper )
99+ {
100+ if ( ! Enum . TryParse < WhisperType > ( model . Version , true , out var whisperType ) )
101+ throw new ArgumentException ( "Invalid Whisper Version" ) ;
102+
103+ _currentPipeline = WhisperPipeline . Create ( provider , providerCPU , model . Path , whisperType ) ;
104+ }
105+ else if ( model . Type == TextModelType . Supertonic )
106+ {
107+ _currentPipeline = SupertonicPipeline . Create ( model . Path , provider ) ;
108+ }
109+ await Task . Run ( ( ) => _currentPipeline . LoadAsync ( cancellationToken ) , cancellationToken ) ;
101110
102111 }
103112 }
104113 catch ( OperationCanceledException )
105114 {
106- _greedyPipeline ? . Dispose ( ) ;
107- _greedyPipeline = null ;
115+ _currentPipeline ? . Dispose ( ) ;
116+ _currentPipeline = null ;
108117 _currentConfig = null ;
109118 throw ;
110119 }
@@ -148,11 +157,63 @@ public async Task<GenerateResult[]> ExecuteAsync(TextRequest options)
148157 if ( options . Beams == 0 )
149158 {
150159 // Greedy Search
151- return [ await _greedyPipeline . RunAsync ( pipelineOptions , cancellationToken : _cancellationTokenSource . Token ) ] ;
160+ var greedyPipeline = _currentPipeline as IPipeline < GenerateResult , GenerateOptions , GenerateProgress > ;
161+ return [ await greedyPipeline . RunAsync ( pipelineOptions , cancellationToken : _cancellationTokenSource . Token ) ] ;
162+ }
163+
164+ // Beam Search
165+ var beamSearchPipeline = _currentPipeline as IPipeline < GenerateResult [ ] , SearchOptions , GenerateProgress > ;
166+ return await beamSearchPipeline . RunAsync ( new SearchOptions ( pipelineOptions ) , cancellationToken : _cancellationTokenSource . Token ) ;
167+ } ) ;
168+
169+ return pipelineResult ;
170+ }
171+ }
172+ finally
173+ {
174+ IsExecuting = false ;
175+ }
176+ }
177+
178+
179+ public async Task < GenerateResult [ ] > ExecuteAsync ( WhisperRequest options )
180+ {
181+ try
182+ {
183+ IsExecuting = true ;
184+ using ( _cancellationTokenSource = new CancellationTokenSource ( ) )
185+ {
186+ var pipelineOptions = new WhisperOptions
187+ {
188+ Prompt = options . Prompt ,
189+ Seed = options . Seed ,
190+ Beams = options . Beams ,
191+ TopK = options . TopK ,
192+ TopP = options . TopP ,
193+ Temperature = options . Temperature ,
194+ MaxLength = options . MaxLength ,
195+ MinLength = options . MinLength ,
196+ NoRepeatNgramSize = options . NoRepeatNgramSize ,
197+ LengthPenalty = options . LengthPenalty ,
198+ DiversityLength = options . DiversityLength ,
199+ EarlyStopping = options . EarlyStopping ,
200+ AudioInput = options . AudioInput ,
201+ Language = options . Language ,
202+ Task = options . Task
203+ } ;
204+
205+ var pipelineResult = await Task . Run ( async ( ) =>
206+ {
207+ if ( options . Beams == 0 )
208+ {
209+ // Greedy Search
210+ var greedyPipeline = _currentPipeline as IPipeline < GenerateResult , WhisperOptions , GenerateProgress > ;
211+ return [ await greedyPipeline . RunAsync ( pipelineOptions , cancellationToken : _cancellationTokenSource . Token ) ] ;
152212 }
153213
154214 // Beam Search
155- return await _beamSearchPipeline . RunAsync ( new SearchOptions ( pipelineOptions ) , cancellationToken : _cancellationTokenSource . Token ) ;
215+ var beamSearchPipeline = _currentPipeline as IPipeline < GenerateResult [ ] , WhisperSearchOptions , GenerateProgress > ;
216+ return await beamSearchPipeline . RunAsync ( new WhisperSearchOptions ( pipelineOptions ) , cancellationToken : _cancellationTokenSource . Token ) ;
156217 } ) ;
157218
158219 return pipelineResult ;
@@ -165,6 +226,34 @@ public async Task<GenerateResult[]> ExecuteAsync(TextRequest options)
165226 }
166227
167228
229+ public async Task < AudioTensor > ExecuteAsync ( SupertonicRequest options )
230+ {
231+ try
232+ {
233+ IsExecuting = true ;
234+ using ( _cancellationTokenSource = new CancellationTokenSource ( ) )
235+ {
236+ var pipeline = _currentPipeline as IPipeline < AudioTensor , SupertonicOptions , GenerateProgress > ;
237+ var pipelineOptions = new SupertonicOptions
238+ {
239+ TextInput = options . InputText ,
240+ VoiceStyle = options . VoiceStyle ,
241+ Steps = options . Steps ,
242+ Speed = options . Speed ,
243+ SilenceDuration = options . SilenceDuration ,
244+ Seed = options . Seed ,
245+ } ;
246+
247+ return await pipeline . RunAsync ( pipelineOptions , cancellationToken : _cancellationTokenSource . Token ) ;
248+ }
249+ }
250+ finally
251+ {
252+ IsExecuting = false ;
253+ }
254+ }
255+
256+
168257 /// <summary>
169258 /// Cancel the running task (Load or Execute)
170259 /// </summary>
@@ -179,12 +268,12 @@ public async Task CancelAsync()
179268 /// </summary>
180269 public async Task UnloadAsync ( )
181270 {
182- if ( _greedyPipeline != null )
271+ if ( _currentPipeline != null )
183272 {
184273 await _cancellationTokenSource . SafeCancelAsync ( ) ;
185- await _greedyPipeline . UnloadAsync ( ) ;
186- _greedyPipeline . Dispose ( ) ;
187- _greedyPipeline = null ;
274+ await _currentPipeline . UnloadAsync ( ) ;
275+ _currentPipeline . Dispose ( ) ;
276+ _currentPipeline = null ;
188277 _currentConfig = null ;
189278 }
190279
@@ -205,6 +294,8 @@ public interface ITextService
205294 Task UnloadAsync ( ) ;
206295 Task CancelAsync ( ) ;
207296 Task < GenerateResult [ ] > ExecuteAsync ( TextRequest options ) ;
297+ Task < GenerateResult [ ] > ExecuteAsync ( WhisperRequest options ) ;
298+ Task < AudioTensor > ExecuteAsync ( SupertonicRequest options ) ;
208299 }
209300
210301
@@ -224,4 +315,22 @@ public record TextRequest : ITransformerRequest
224315 public int DiversityLength { get ; set ; } = 5 ;
225316 }
226317
318+
319+ public record WhisperRequest : TextRequest
320+ {
321+ public AudioTensor AudioInput { get ; set ; }
322+ public LanguageType Language { get ; set ; } = LanguageType . EN ;
323+ public TaskType Task { get ; set ; } = TaskType . Transcribe ;
324+ }
325+
326+
327+ public record SupertonicRequest
328+ {
329+ public string InputText { get ; set ; }
330+ public string VoiceStyle { get ; set ; }
331+ public int Steps { get ; set ; } = 5 ;
332+ public float Speed { get ; set ; } = 1f ;
333+ public float SilenceDuration { get ; set ; } = 0.3f ;
334+ public int Seed { get ; set ; }
335+ }
227336}
0 commit comments