From 00c63cbd2387b9f2b75caacac9cde6354f112154 Mon Sep 17 00:00:00 2001 From: Szymon Sypniewicz Date: Mon, 30 Mar 2026 01:40:14 +0200 Subject: [PATCH] fix stabilize transcription session state and diarization fallback --- .../Meeting/NotificationService.swift | 6 +- .../OpenOats/Settings/SettingsStore.swift | 8 +- .../Transcription/TranscriptionEngine.swift | 110 ++++++++++--- .../NotificationServiceTests.swift | 14 ++ .../OpenOatsTests/SettingsStoreTests.swift | 30 ++++ .../TranscriptionEngineTests.swift | 154 ++++++++++++++++++ 6 files changed, 297 insertions(+), 25 deletions(-) create mode 100644 OpenOats/Tests/OpenOatsTests/NotificationServiceTests.swift create mode 100644 OpenOats/Tests/OpenOatsTests/TranscriptionEngineTests.swift diff --git a/OpenOats/Sources/OpenOats/Meeting/NotificationService.swift b/OpenOats/Sources/OpenOats/Meeting/NotificationService.swift index 8528d700..9d984cea 100644 --- a/OpenOats/Sources/OpenOats/Meeting/NotificationService.swift +++ b/OpenOats/Sources/OpenOats/Meeting/NotificationService.swift @@ -29,6 +29,8 @@ final class NotificationService: NSObject, UNUserNotificationCenterDelegate { private static let notMeetingAction = "NOT_A_MEETING" private static let ignoreAppAction = "IGNORE_APP" private static let dismissAction = "DISMISS" + static let batchCompletedTitle = "Re-transcription Complete" + static let batchCompletedBody = "Re-transcription is complete. Your meeting transcript has been updated with higher-quality text." override init() { super.init() @@ -144,8 +146,8 @@ final class NotificationService: NSObject, UNUserNotificationCenterDelegate { guard await ensurePermission() else { return } let content = UNMutableNotificationContent() - content.title = "Re-transcription Complete" - content.body = "Batch transcription is complete. Your meeting transcript has been updated with higher-quality text." + content.title = Self.batchCompletedTitle + content.body = Self.batchCompletedBody content.sound = .default let request = UNNotificationRequest( diff --git a/OpenOats/Sources/OpenOats/Settings/SettingsStore.swift b/OpenOats/Sources/OpenOats/Settings/SettingsStore.swift index a11d6456..8750a3a5 100644 --- a/OpenOats/Sources/OpenOats/Settings/SettingsStore.swift +++ b/OpenOats/Sources/OpenOats/Settings/SettingsStore.swift @@ -9,6 +9,8 @@ import Security final class SettingsStore { private let defaults: UserDefaults private let secretStore: AppSecretStore + private static let enableLiveTranscriptCleanupLegacyKey = "enableTranscriptRefinement" + private static let enableBatchRetranscriptionLegacyKey = "enableBatchRefinement" // MARK: - AI Settings @@ -206,6 +208,7 @@ final class SettingsStore { withMutation(keyPath: \.enableLiveTranscriptCleanup) { _enableLiveTranscriptCleanup = newValue defaults.set(newValue, forKey: "enableLiveTranscriptCleanup") + defaults.set(newValue, forKey: Self.enableLiveTranscriptCleanupLegacyKey) } } } @@ -373,6 +376,7 @@ final class SettingsStore { withMutation(keyPath: \.enableBatchRetranscription) { _enableBatchRetranscription = newValue defaults.set(newValue, forKey: "enableBatchRetranscription") + defaults.set(newValue, forKey: Self.enableBatchRetranscriptionLegacyKey) } } } @@ -614,11 +618,11 @@ final class SettingsStore { // Migrate renamed settings keys (old -> new) if defaults.object(forKey: "enableLiveTranscriptCleanup") == nil, - let oldValue = defaults.object(forKey: "enableTranscriptRefinement") { + let oldValue = defaults.object(forKey: Self.enableLiveTranscriptCleanupLegacyKey) { defaults.set(oldValue, forKey: "enableLiveTranscriptCleanup") } if defaults.object(forKey: "enableBatchRetranscription") == nil, - let oldValue = defaults.object(forKey: "enableBatchRefinement") { + let oldValue = defaults.object(forKey: Self.enableBatchRetranscriptionLegacyKey) { defaults.set(oldValue, forKey: "enableBatchRetranscription") } diff --git a/OpenOats/Sources/OpenOats/Transcription/TranscriptionEngine.swift b/OpenOats/Sources/OpenOats/Transcription/TranscriptionEngine.swift index e48da036..eec3c91a 100644 --- a/OpenOats/Sources/OpenOats/Transcription/TranscriptionEngine.swift +++ b/OpenOats/Sources/OpenOats/Transcription/TranscriptionEngine.swift @@ -4,17 +4,6 @@ import FluidAudio import Observation import os -enum TranscriptionEngineError: LocalizedError { - case transcriberNotInitialized - - var errorDescription: String? { - switch self { - case .transcriberNotInitialized: - "Transcription engine is not initialized. Please check your audio settings." - } - } -} - /// Enriched download progress info computed from fraction changes over time. struct DownloadProgressDetail: Sendable { let fraction: Double @@ -26,6 +15,41 @@ struct DownloadProgressDetail: Sendable { let etaText: String? } +/// Session-scoped transcription settings captured at start time. +struct ActiveTranscriptionSession: Sendable, Equatable { + let transcriptionModel: TranscriptionModel + + var flushIntervalSamples: Int { + transcriptionModel.flushIntervalSamples + } + + func clearModelCache( + using makeBackend: (TranscriptionModel) -> any TranscriptionBackend = { $0.makeBackend() } + ) { + makeBackend(transcriptionModel).clearModelCache() + } +} + +/// Stops forwarding diarization samples after the first feed failure. +struct DiarizationFeedRelay: Sendable { + private(set) var hasFailed = false + + mutating func feedAudio( + _ samples: [Float], + into feedAudio: @Sendable ([Float]) async throws -> Void, + onFailure: @Sendable (Error) async -> Void + ) async { + guard !hasFailed else { return } + + do { + try await feedAudio(samples) + } catch { + hasFailed = true + await onFailure(error) + } + } +} + /// Orchestrates dual StreamingTranscriber instances for mic (you) and system audio (them). @Observable @MainActor @@ -134,6 +158,9 @@ final class TranscriptionEngine { /// Speaker diarization manager for system audio (nil when diarization is disabled). private var diarizationManager: DiarizationManager? + /// Active transcription model captured for the current session/startup. + @ObservationIgnored nonisolated(unsafe) var activeTranscriptionSession: ActiveTranscriptionSession? + /// Tracks the resolved mic device ID currently in use. private var currentMicDeviceID: AudioDeviceID = 0 @@ -204,7 +231,14 @@ final class TranscriptionEngine { return } - guard await ensureMicrophonePermission() else { return } + activeTranscriptionSession = ActiveTranscriptionSession( + transcriptionModel: transcriptionModel + ) + + guard await ensureMicrophonePermission() else { + activeTranscriptionSession = nil + return + } isRunning = true @@ -276,7 +310,7 @@ final class TranscriptionEngine { Log.transcription.info("Transcription model loaded") } catch { let msg = "Failed to load models: \(error.localizedDescription)" - Log.transcription.error("Failed to load models: \(msg, privacy: .public)") + Log.transcription.error("Failed to load models: \(error, privacy: .public)") lastError = msg assetStatus = "Ready" isRunning = false @@ -285,14 +319,20 @@ final class TranscriptionEngine { downloadStartTime = nil downloadTotalBytes = nil // Clear corrupt cache so the next attempt triggers a fresh download - settings.transcriptionModel.makeBackend().clearModelCache() - Log.transcription.info("Cleared model cache for \(self.settings.transcriptionModel.rawValue, privacy: .public)") + activeTranscriptionSession?.clearModelCache() + Log.transcription.info( + "Cleared model cache for \(transcriptionModel.rawValue, privacy: .public)" + ) needsModelDownload = true downloadConfirmed = false + activeTranscriptionSession = nil return } - guard let vadManager else { return } + guard let vadManager else { + activeTranscriptionSession = nil + return + } // 2. Start mic capture userSelectedDeviceID = inputDeviceID @@ -302,6 +342,7 @@ final class TranscriptionEngine { lastError = msg assetStatus = "Ready" isRunning = false + activeTranscriptionSession = nil return } currentMicDeviceID = targetMicID @@ -501,6 +542,7 @@ final class TranscriptionEngine { assetStatus = "Ready" transcriptStore.volatileYouText = "" transcriptStore.volatileThemText = "" + activeTranscriptionSession = nil return } @@ -536,8 +578,10 @@ final class TranscriptionEngine { micBackend = nil systemBackend = nil + vadManager = nil transcriptStore.volatileYouText = "" transcriptStore.volatileThemText = "" + activeTranscriptionSession = nil isRunning = false assetStatus = "Ready" } @@ -570,8 +614,10 @@ final class TranscriptionEngine { currentMicDeviceID = 0 micBackend = nil systemBackend = nil + vadManager = nil transcriptStore.volatileYouText = "" transcriptStore.volatileThemText = "" + activeTranscriptionSession = nil isRunning = false assetStatus = "Ready" } @@ -684,6 +730,7 @@ final class TranscriptionEngine { lastError = "Failed to create transcriber. Try restarting." isRunning = false assetStatus = "Ready" + activeTranscriptionSession = nil return } micTask = Task.detached { @@ -704,7 +751,7 @@ final class TranscriptionEngine { clearSystemAudioErrorIfPresent() } catch { let msg = "Failed to start system audio: \(error.localizedDescription)" - Log.transcription.error("Failed to start system audio: \(msg, privacy: .public)") + Log.transcription.error("Failed to start system audio: \(error, privacy: .public)") lastError = msg return } @@ -725,7 +772,8 @@ final class TranscriptionEngine { let originalSysStream = sysStream let (diarTapped, diarContinuation) = AsyncStream.makeStream() Task { - nonisolated(unsafe) let safeDm = dm + let safeDm = dm + var diarizationRelay = DiarizationFeedRelay() var diarBuf: [Float] = [] for await buffer in originalSysStream { nonisolated(unsafe) let b = buffer @@ -737,12 +785,28 @@ final class TranscriptionEngine { if diarBuf.count >= diarFlushSize { let batch = diarBuf diarBuf.removeAll(keepingCapacity: true) - try? await safeDm.feedAudio(batch) + await diarizationRelay.feedAudio( + batch, + into: { samples in try await safeDm.feedAudio(samples) }, + onFailure: { error in + Log.transcription.error( + "Diarization feed failed: \(error, privacy: .public)" + ) + } + ) } } // Flush tail if !diarBuf.isEmpty { - try? await safeDm.feedAudio(diarBuf) + await diarizationRelay.feedAudio( + diarBuf, + into: { samples in try await safeDm.feedAudio(samples) }, + onFailure: { error in + Log.transcription.error( + "Diarization feed failed: \(error, privacy: .public)" + ) + } + ) } diarContinuation.finish() } @@ -799,12 +863,16 @@ final class TranscriptionEngine { locale: locale, vadManager: vadManager, speaker: speaker, - flushInterval: settings.transcriptionModel.flushIntervalSamples, + flushInterval: currentTranscriptionModel().flushIntervalSamples, onPartial: onPartial, onFinal: onFinal ) } + func currentTranscriptionModel() -> TranscriptionModel { + activeTranscriptionSession?.transcriptionModel ?? settings.transcriptionModel + } + private func resolvedMicDeviceID(for inputDeviceID: AudioDeviceID) -> AudioDeviceID? { if inputDeviceID > 0 { let availableDeviceIDs = Set(MicCapture.availableInputDevices().map(\.id)) diff --git a/OpenOats/Tests/OpenOatsTests/NotificationServiceTests.swift b/OpenOats/Tests/OpenOatsTests/NotificationServiceTests.swift new file mode 100644 index 00000000..f8b77449 --- /dev/null +++ b/OpenOats/Tests/OpenOatsTests/NotificationServiceTests.swift @@ -0,0 +1,14 @@ +import XCTest +@testable import OpenOatsKit + +@MainActor +final class NotificationServiceTests: XCTestCase { + + func testBatchCompletedNotificationCopyUsesReTranscriptionWording() { + XCTAssertEqual(NotificationService.batchCompletedTitle, "Re-transcription Complete") + XCTAssertEqual( + NotificationService.batchCompletedBody, + "Re-transcription is complete. Your meeting transcript has been updated with higher-quality text." + ) + } +} diff --git a/OpenOats/Tests/OpenOatsTests/SettingsStoreTests.swift b/OpenOats/Tests/OpenOatsTests/SettingsStoreTests.swift index acf35fc2..4db03648 100644 --- a/OpenOats/Tests/OpenOatsTests/SettingsStoreTests.swift +++ b/OpenOats/Tests/OpenOatsTests/SettingsStoreTests.swift @@ -135,6 +135,21 @@ final class SettingsStoreTests: XCTestCase { XCTAssertTrue(store.enableLiveTranscriptCleanup) } + func testEnableLiveTranscriptCleanupDualWritesLegacyKey() { + let suiteName = "com.openoats.test.\(UUID().uuidString)" + let defaults = UserDefaults(suiteName: suiteName)! + defaults.removePersistentDomain(forName: suiteName) + + let store = makeStore(defaults: defaults) + store.enableLiveTranscriptCleanup = true + + XCTAssertEqual(defaults.bool(forKey: "enableLiveTranscriptCleanup"), true) + XCTAssertEqual(defaults.bool(forKey: "enableTranscriptRefinement"), true) + + let reopened = makeStore(defaults: defaults) + XCTAssertTrue(reopened.enableLiveTranscriptCleanup) + } + // MARK: - Capture Settings Group func testDefaultInputDeviceID() { @@ -175,6 +190,21 @@ final class SettingsStoreTests: XCTestCase { XCTAssertFalse(store.enableBatchRetranscription) } + func testEnableBatchRetranscriptionDualWritesLegacyKey() { + let suiteName = "com.openoats.test.\(UUID().uuidString)" + let defaults = UserDefaults(suiteName: suiteName)! + defaults.removePersistentDomain(forName: suiteName) + + let store = makeStore(defaults: defaults) + store.enableBatchRetranscription = true + + XCTAssertEqual(defaults.bool(forKey: "enableBatchRetranscription"), true) + XCTAssertEqual(defaults.bool(forKey: "enableBatchRefinement"), true) + + let reopened = makeStore(defaults: defaults) + XCTAssertTrue(reopened.enableBatchRetranscription) + } + func testDefaultBatchTranscriptionModel() { let store = makeStore() XCTAssertEqual(store.batchTranscriptionModel, .whisperLargeV3Turbo) diff --git a/OpenOats/Tests/OpenOatsTests/TranscriptionEngineTests.swift b/OpenOats/Tests/OpenOatsTests/TranscriptionEngineTests.swift new file mode 100644 index 00000000..73a1889a --- /dev/null +++ b/OpenOats/Tests/OpenOatsTests/TranscriptionEngineTests.swift @@ -0,0 +1,154 @@ +import XCTest +@testable import OpenOatsKit + +@MainActor +final class TranscriptionEngineTests: XCTestCase { + // MARK: - Helpers + + private func makeSettings() -> AppSettings { + let suiteName = "com.openoats.tests.transcription-engine.\(UUID().uuidString)" + let defaults = UserDefaults(suiteName: suiteName) ?? .standard + defaults.removePersistentDomain(forName: suiteName) + let storage = AppSettingsStorage( + defaults: defaults, + secretStore: .ephemeral, + defaultNotesDirectory: URL(fileURLWithPath: NSTemporaryDirectory()), + runMigrations: false + ) + return AppSettings(storage: storage) + } + + // MARK: - Active Session Model + + func testActiveTranscriptionSessionCapturesModelForFlushAndCacheClearing() { + let session = ActiveTranscriptionSession(transcriptionModel: .whisperLargeV3Turbo) + XCTAssertEqual( + session.flushIntervalSamples, + TranscriptionModel.whisperLargeV3Turbo.flushIntervalSamples + ) + + let backend = CacheClearingBackend() + var capturedModel: TranscriptionModel? + session.clearModelCache(using: { model in + capturedModel = model + return backend + }) + + XCTAssertEqual(capturedModel, .whisperLargeV3Turbo) + XCTAssertEqual(backend.clearModelCacheCallCount, 1) + } + + func testCurrentTranscriptionModelPrefersActiveSessionOverMutableSettings() { + let settings = makeSettings() + settings.transcriptionModel = .parakeetV2 + + let engine = TranscriptionEngine( + transcriptStore: TranscriptStore(), + settings: settings, + mode: .scripted([]) + ) + engine.activeTranscriptionSession = ActiveTranscriptionSession( + transcriptionModel: .whisperBase + ) + + settings.transcriptionModel = .qwen3ASR06B + + XCTAssertEqual(engine.currentTranscriptionModel(), .whisperBase) + } + + // MARK: - Diarization Feed Gate + + func testDiarizationFeedRelayStopsAfterFirstFailure() async { + var relay = DiarizationFeedRelay() + let recorder = FeedRecorder(failOnCall: 2) + let errorRecorder = ErrorRecorder() + + await relay.feedAudio( + [1.0, 2.0], + into: { samples in try await recorder.feed(samples) }, + onFailure: { error in await errorRecorder.record(error) } + ) + await relay.feedAudio( + [3.0, 4.0], + into: { samples in try await recorder.feed(samples) }, + onFailure: { error in await errorRecorder.record(error) } + ) + await relay.feedAudio( + [5.0, 6.0], + into: { samples in try await recorder.feed(samples) }, + onFailure: { error in await errorRecorder.record(error) } + ) + + let recordedBatches = await recorder.snapshotBatches() + let failureCount = await errorRecorder.snapshotCount() + + XCTAssertEqual(recordedBatches, [[1.0, 2.0], [3.0, 4.0]]) + XCTAssertTrue(relay.hasFailed) + XCTAssertEqual(failureCount, 1) + } +} + +// MARK: - Test Helpers + +private final class CacheClearingBackend: TranscriptionBackend, @unchecked Sendable { + let displayName = "Mock cache clearing backend" + private(set) var clearModelCacheCallCount = 0 + + func checkStatus() -> BackendStatus { + .ready + } + + func prepare( + onStatus: @Sendable (String) -> Void, + onProgress: @escaping @Sendable (Double) -> Void + ) async throws { + } + + func transcribe( + _ samples: [Float], + locale: Locale, + previousContext: String? + ) async throws -> String { + "" + } + + func clearModelCache() { + clearModelCacheCallCount += 1 + } +} + +private actor FeedRecorder { + private(set) var batches: [[Float]] = [] + private let failOnCall: Int? + private var callCount = 0 + + init(failOnCall: Int?) { + self.failOnCall = failOnCall + } + + func feed(_ batch: [Float]) throws { + callCount += 1 + batches.append(batch) + + if callCount == failOnCall { + struct RelayFailure: Error {} + throw RelayFailure() + } + } + + func snapshotBatches() -> [[Float]] { + batches + } +} + +private actor ErrorRecorder { + private(set) var count = 0 + + func record(_ error: Error) { + count += 1 + } + + func snapshotCount() -> Int { + count + } +}