15 changes: 9 additions & 6 deletions agents/src/stt/stream_adapter.ts
@@ -53,11 +53,14 @@ export class StreamAdapterWrapper extends SpeechStream {

   async #run() {
     const forwardInput = async () => {
-      for await (const input of this.input) {
-        if (input === SpeechStream.FLUSH_SENTINEL) {
+      while (true) {
+        const { done, value } = await this.inputReader.read();
+        if (done) break;
+
+        if (value === SpeechStream.FLUSH_SENTINEL) {
           this.#vadStream.flush();
         } else {
-          this.#vadStream.pushFrame(input);
+          this.#vadStream.pushFrame(value);
         }
       }
       this.#vadStream.endInput();
@@ -67,18 +70,18 @@ export class StreamAdapterWrapper extends SpeechStream {
       for await (const ev of this.#vadStream) {
         switch (ev.type) {
           case VADEventType.START_OF_SPEECH:
-            this.output.put({ type: SpeechEventType.START_OF_SPEECH });
+            this.outputWriter.write({ type: SpeechEventType.START_OF_SPEECH });
             break;
           case VADEventType.END_OF_SPEECH:
-            this.output.put({ type: SpeechEventType.END_OF_SPEECH });
+            this.outputWriter.write({ type: SpeechEventType.END_OF_SPEECH });

             try {
               const event = await this.#stt.recognize(ev.frames);
               if (!event.alternatives![0].text) {
                 continue;
               }

-              this.output.put(event);
+              this.outputWriter.write(event);
               break;
             } catch (error) {
               let logger = log();
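The hunk above swaps iteration over the old AsyncIterableQueue for an explicit ReadableStreamDefaultReader loop. Below is a minimal standalone sketch of that consumption pattern, using an identity TransformStream as a stand-in channel; the string chunks and the local FLUSH_SENTINEL are illustrative, not the PR's actual types:

```ts
import { TransformStream } from 'node:stream/web';

const FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL');
type Chunk = string | typeof FLUSH_SENTINEL;

// An identity TransformStream plays the role of the old AsyncIterableQueue:
// whatever is written to `writable` comes out of `readable` unchanged.
const channel = new TransformStream<Chunk, Chunk>();
const writer = channel.writable.getWriter();
const reader = channel.readable.getReader();

async function forwardInput() {
  // read() resolves with { done: true } once the writer closes; that is
  // what ends the loop, where the old for-await relied on queue closure.
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    if (value === FLUSH_SENTINEL) {
      console.log('flush requested');
    } else {
      console.log('chunk:', value);
    }
  }
  console.log('input ended');
}

const task = forwardInput();
await writer.write('frame-1');
await writer.write(FLUSH_SENTINEL);
await writer.close();
await task;
```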
104 changes: 72 additions & 32 deletions agents/src/stt/stt.ts
@@ -4,12 +4,16 @@
 import type { AudioFrame } from '@livekit/rtc-node';
 import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter';
 import { EventEmitter } from 'node:events';
-import type { ReadableStream } from 'node:stream/web';
+import type {
+  ReadableStream,
+  ReadableStreamDefaultReader,
+  WritableStreamDefaultWriter,
+} from 'node:stream/web';
 import { log } from '../log.js';
 import type { STTMetrics } from '../metrics/base.js';
 import { DeferredReadableStream } from '../stream/deferred_stream.js';
+import { IdentityTransform } from '../stream/identity_transform.js';
 import type { AudioBuffer } from '../utils.js';
-import { AsyncIterableQueue } from '../utils.js';

 /** Indicates start/middle/end of speech */
 export enum SpeechEventType {
Expand Down Expand Up @@ -140,102 +144,138 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter<STTCal
*/
export abstract class SpeechStream implements AsyncIterableIterator<SpeechEvent> {
protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL');
protected input = new AsyncIterableQueue<AudioFrame | typeof SpeechStream.FLUSH_SENTINEL>();
protected output = new AsyncIterableQueue<SpeechEvent>();
protected queue = new AsyncIterableQueue<SpeechEvent>();
abstract label: string;
protected input = new IdentityTransform<AudioFrame | typeof SpeechStream.FLUSH_SENTINEL>();
protected output = new IdentityTransform<SpeechEvent>();

protected inputReader: ReadableStreamDefaultReader<
AudioFrame | typeof SpeechStream.FLUSH_SENTINEL
>;
protected outputWriter: WritableStreamDefaultWriter<SpeechEvent>;
protected closed = false;
protected inputClosed = false;
abstract label: string;
#stt: STT;
private deferredInputStream: DeferredReadableStream<AudioFrame>;
private logger = log();
private inputWriter: WritableStreamDefaultWriter<AudioFrame | typeof SpeechStream.FLUSH_SENTINEL>;
private outputReader: ReadableStreamDefaultReader<SpeechEvent>;
private metricsStream: ReadableStream<SpeechEvent>;

constructor(stt: STT) {
this.#stt = stt;
this.deferredInputStream = new DeferredReadableStream<AudioFrame>();

this.inputWriter = this.input.writable.getWriter();
this.inputReader = this.input.readable.getReader();
this.outputWriter = this.output.writable.getWriter();

const [outputStream, metricsStream] = this.output.readable.tee();
this.metricsStream = metricsStream;
this.outputReader = outputStream.getReader();

this.pumpDeferredStream();
this.monitorMetrics();
this.mainTask();
}

protected async mainTask() {
// TODO(AJS-35): Implement STT with webstreams API
/**
* Reads from the deferred input stream and forwards chunks to the input writer.
*
* Note: we can't just do this.deferredInputStream.stream.pipeTo(this.input.writable)
* because the inputWriter locks the this.input.writable stream. All writes must go through
* the inputWriter.
*/
private async pumpDeferredStream() {
const reader = this.deferredInputStream.stream.getReader();
try {
const inputStream = this.deferredInputStream.stream;
const reader = inputStream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done) break;
this.pushFrame(value);
await this.inputWriter.write(value);
}
} catch (error) {
this.logger.error('Error in STTStream mainTask:', error);
} catch (e) {
this.logger.error(`Error pumping deferred stream: ${e}`);
throw e;
} finally {
reader.releaseLock();
}
}
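The pumpDeferredStream doc comment explains why the obvious pipeTo one-liner is off the table: acquiring inputWriter locks this.input.writable, and per WHATWG streams semantics pipeTo rejects when its destination is locked. A self-contained demo of that behavior (illustrative, not PR code):

```ts
import { ReadableStream, TransformStream } from 'node:stream/web';

const channel = new TransformStream<number, number>();

// Holding a writer locks the writable side for exclusive use.
const writer = channel.writable.getWriter();

const source = new ReadableStream<number>({
  start(controller) {
    controller.enqueue(1);
    controller.close();
  },
});

// pipeTo() rejects with a TypeError because its destination is locked,
// which is why the PR pumps chunks manually through the held writer.
try {
  await source.pipeTo(channel.writable);
} catch (err) {
  console.log('pipeTo failed as expected:', (err as Error).message);
}
```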

   protected async monitorMetrics() {
     const startTime = process.hrtime.bigint();
+    const metricsReader = this.metricsStream.getReader();
+
+    while (true) {
+      const { done, value } = await metricsReader.read();
+      if (done) {
+        break;
+      }
+
+      if (value.type !== SpeechEventType.RECOGNITION_USAGE) continue;

-    for await (const event of this.queue) {
-      this.output.put(event);
-      if (event.type !== SpeechEventType.RECOGNITION_USAGE) continue;
       const duration = process.hrtime.bigint() - startTime;
       const metrics: STTMetrics = {
         timestamp: Date.now(),
-        requestId: event.requestId!,
+        requestId: value.requestId!,
         duration: Math.trunc(Number(duration / BigInt(1000000))),
         label: this.label,
-        audioDuration: event.recognitionUsage!.audioDuration,
+        audioDuration: value.recognitionUsage!.audioDuration,
         streamed: true,
       };
       this.#stt.emit(SpeechEventType.METRICS_COLLECTED, metrics);
     }
-    this.output.close();
   }

   updateInputStream(audioStream: ReadableStream<AudioFrame>) {
     this.deferredInputStream.setSource(audioStream);
   }

   /** Push an audio frame to the STT */
+  /** @deprecated Use `updateInputStream` instead */
   pushFrame(frame: AudioFrame) {
-    if (this.input.closed) {
+    // TODO: remove this method in future version
+    if (this.inputClosed) {
       throw new Error('Input is closed');
     }
     if (this.closed) {
       throw new Error('Stream is closed');
     }
-    this.input.put(frame);
+    this.inputWriter.write(frame);
   }

   /** Flush the STT, causing it to process all pending text */
   flush() {
-    if (this.input.closed) {
+    if (this.inputClosed) {
       throw new Error('Input is closed');
     }
     if (this.closed) {
       throw new Error('Stream is closed');
     }
-    this.input.put(SpeechStream.FLUSH_SENTINEL);
+    this.inputWriter.write(SpeechStream.FLUSH_SENTINEL);
   }

   /** Mark the input as ended and forbid additional pushes */
   endInput() {
-    if (this.input.closed) {
+    if (this.inputClosed) {
       throw new Error('Input is closed');
     }
     if (this.closed) {
       throw new Error('Stream is closed');
     }
-    this.input.close();
+    this.inputClosed = true;
+    this.inputWriter.close();
   }

-  next(): Promise<IteratorResult<SpeechEvent>> {
-    return this.output.next();
+  async next(): Promise<IteratorResult<SpeechEvent>> {
+    return this.outputReader.read().then(({ done, value }) => {
+      if (done) {
+        return { done: true, value: undefined };
+      }
+      return { done: false, value };
+    });
   }

   /** Close both the input and output of the STT stream */
   close() {
-    this.input.close();
-    this.queue.close();
-    this.output.close();
+    this.input.writable.close();
     this.closed = true;
   }
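With this change, monitorMetrics no longer re-publishes every event it inspects; the output readable is tee'd so the consumer and the metrics loop each read an independent branch. A standalone sketch of that split, with a simplified event shape (the types and event names here are placeholders):

```ts
import { TransformStream } from 'node:stream/web';

type SpeechEvent = { type: string; audioDuration?: number };

const output = new TransformStream<SpeechEvent, SpeechEvent>();
const writer = output.writable.getWriter();

// tee() yields two independent branches over the same events, so the
// metrics loop no longer has to forward what it reads to the consumer.
const [consumerStream, metricsStream] = output.readable.tee();

const consume = (async () => {
  const reader = consumerStream.getReader();
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    console.log('consumer saw', value.type);
  }
})();

const measure = (async () => {
  const reader = metricsStream.getReader();
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    if (value.type !== 'RECOGNITION_USAGE') continue;
    console.log('audioDuration:', value.audioDuration);
  }
})();

await writer.write({ type: 'FINAL_TRANSCRIPT' });
await writer.write({ type: 'RECOGNITION_USAGE', audioDuration: 1.5 });
await writer.close();
await Promise.all([consume, measure]);
```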
20 changes: 11 additions & 9 deletions plugins/deepgram/src/stt.ts
@@ -125,7 +125,6 @@ export class SpeechStream extends stt.SpeechStream {
   constructor(stt: STT, opts: STTOptions) {
     super(stt);
     this.#opts = opts;
-    this.closed = false;
     this.#audioEnergyFilter = new AudioEnergyFilter();

     this.#run();
@@ -134,7 +133,7 @@ export class SpeechStream extends stt.SpeechStream {
   async #run(maxRetry = 32) {
     let retries = 0;
     let ws: WebSocket;
-    while (!this.input.closed) {
+    while (!this.inputClosed) {
       const streamURL = new URL(API_BASE_URL_V1);
       const params = {
         model: this.#opts.model,
@@ -193,7 +192,7 @@ export class SpeechStream extends stt.SpeechStream {
       }
     }

-    this.closed = true;
+    this.close();
   }

   updateOptions(opts: Partial<STTOptions>) {
@@ -222,7 +221,10 @@ export class SpeechStream extends stt.SpeechStream {
       samples100Ms,
     );

-    for await (const data of this.input) {
+    while (true) {
+      const { done, value: data } = await this.inputReader.read();
+      if (done) break;
+
       let frames: AudioFrame[];
       if (data === SpeechStream.FLUSH_SENTINEL) {
         frames = stream.flush();
@@ -270,7 +272,7 @@ export class SpeechStream extends stt.SpeechStream {
           // It's also possible we receive a transcript without a SpeechStarted event.
           if (this.#speaking) return;
           this.#speaking = true;
-          this.queue.put({ type: stt.SpeechEventType.START_OF_SPEECH });
+          this.outputWriter.write({ type: stt.SpeechEventType.START_OF_SPEECH });
           break;
         }
         // see this page:
@@ -288,16 +290,16 @@ export class SpeechStream extends stt.SpeechStream {
           if (alternatives[0] && alternatives[0].text) {
             if (!this.#speaking) {
               this.#speaking = true;
-              this.queue.put({ type: stt.SpeechEventType.START_OF_SPEECH });
+              this.outputWriter.write({ type: stt.SpeechEventType.START_OF_SPEECH });
             }

             if (isFinal) {
-              this.queue.put({
+              this.outputWriter.write({
                 type: stt.SpeechEventType.FINAL_TRANSCRIPT,
                 alternatives: [alternatives[0], ...alternatives.slice(1)],
               });
             } else {
-              this.queue.put({
+              this.outputWriter.write({
                 type: stt.SpeechEventType.INTERIM_TRANSCRIPT,
                 alternatives: [alternatives[0], ...alternatives.slice(1)],
               });
@@ -309,7 +311,7 @@ export class SpeechStream extends stt.SpeechStream {
           // a non-empty transcript (deepgram doesn't have a SpeechEnded event)
           if (isEndpoint && this.#speaking) {
             this.#speaking = false;
-            this.queue.put({ type: stt.SpeechEventType.END_OF_SPEECH });
+            this.outputWriter.write({ type: stt.SpeechEventType.END_OF_SPEECH });
           }

           break;
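Although the Deepgram handlers now write through outputWriter, consumers can still iterate the stream with for await, because the base class's next() delegates to the tee'd output reader. A runnable sketch of that bridge under simplified types (the event shape and names here are placeholders, not the PR's exports):

```ts
import { TransformStream } from 'node:stream/web';

type SpeechEvent = { type: string };

// Stand-ins for the SpeechStream output plumbing: producers (like the
// Deepgram WebSocket handlers above) write into `writer`, while next()
// delegates to `reader`, keeping the object an AsyncIterableIterator.
const output = new TransformStream<SpeechEvent, SpeechEvent>();
const writer = output.writable.getWriter();
const reader = output.readable.getReader();

const events: AsyncIterableIterator<SpeechEvent> = {
  async next(): Promise<IteratorResult<SpeechEvent>> {
    const { done, value } = await reader.read();
    return done ? { done: true, value: undefined } : { done: false, value };
  },
  [Symbol.asyncIterator]() {
    return this;
  },
};

// Producer side, mirroring outputWriter.write(...) in the diff. Run it
// concurrently so the consumer's reads relieve writer backpressure.
const producer = (async () => {
  await writer.write({ type: 'START_OF_SPEECH' });
  await writer.write({ type: 'END_OF_SPEECH' });
  await writer.close();
})();

// Consumer side: for await still works over the stream-backed iterator.
for await (const ev of events) {
  console.log(ev.type);
}
await producer;
```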