From e4b314c2f1f69a120c1b549e2a5b6675bacaf4ba Mon Sep 17 00:00:00 2001 From: Chris Watson Date: Sat, 28 Mar 2026 22:19:34 -0600 Subject: [PATCH 1/5] Add KMP-backed local transcription runtime - Move local model catalog and install/delete flows into the shared runtime - Bridge Swift model management to the KMP transcription APIs - Add runtime module wiring and tests --- Pindrop/Services/ModelManager.swift | 55 +- .../KMPTranscriptionBridge.swift | 120 +++- .../NativeTranscriptionAdapters.swift | 583 ++++++++++++++++++ .../TranscriptionModelCatalog.swift | 297 +-------- Pindrop/Services/TranscriptionService.swift | 173 +++++- justfile | 2 +- shared/README.md | 5 +- shared/core/build.gradle.kts | 2 + .../shared/core/TranscriptionContracts.kt | 1 + shared/feature-transcription/build.gradle.kts | 6 +- shared/runtime-transcription/build.gradle.kts | 43 ++ .../LocalTranscriptionCatalog.kt | 174 ++++++ .../LocalTranscriptionContracts.kt | 153 +++++ .../LocalTranscriptionRuntime.kt | 272 ++++++++ .../LocalTranscriptionRuntimeTest.kt | 241 ++++++++ shared/settings.gradle.kts | 1 + 16 files changed, 1815 insertions(+), 313 deletions(-) create mode 100644 shared/runtime-transcription/build.gradle.kts create mode 100644 shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt create mode 100644 shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionContracts.kt create mode 100644 shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntime.kt create mode 100644 shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntimeTest.kt diff --git a/Pindrop/Services/ModelManager.swift b/Pindrop/Services/ModelManager.swift index 9cf0cde..c64c6a3 100644 --- a/Pindrop/Services/ModelManager.swift +++ 
b/Pindrop/Services/ModelManager.swift @@ -213,6 +213,13 @@ class ModelManager { private(set) var downloadedFeatureModels: Set = [] private let fileManager = FileManager.default + #if canImport(PindropSharedTranscription) + @ObservationIgnored + private lazy var localRuntimeBridge = KMPTranscriptionRuntimeBridge( + modelManager: self, + engineFactory: TranscriptionService.defaultEngineFactory(provider:) + ) + #endif /// Last decile (0...10) logged for WhisperKit file download progress to avoid log spam. private var whisperKitDownloadLastLoggedDecile: Int = -1 @@ -261,6 +268,18 @@ class ModelManager { } func refreshDownloadedModels() async { + #if canImport(PindropSharedTranscription) + let downloaded = await localRuntimeBridge.refreshInstalledModelNames() + if downloaded != downloadedModelNames { + Log.model.debug("Found \(downloaded.count) downloaded models via KMP runtime: \(downloaded)") + } + downloadedModelNames = downloaded + #else + await refreshDownloadedModelsFromDisk() + #endif + } + + private func refreshDownloadedModelsFromDisk() async { var downloaded: Set = [] let whisperKitPath = whisperKitModelsURL @@ -339,13 +358,31 @@ class ModelManager { isDownloading = false currentDownloadModel = nil } - + + #if canImport(PindropSharedTranscription) + if model.provider.isLocal { + try await localRuntimeBridge.installModel(named: modelName, onProgress: onProgress) + } else { + throw ModelError.downloadNotImplemented(model.provider.rawValue) + } + #else + try await installModelArtifacts(named: modelName, onProgress: onProgress) + #endif + Log.boot.info("ModelManager.downloadModel finished OK name=\(modelName) wallClock=\(String(format: "%.2fs", CFAbsoluteTimeGetCurrent() - downloadWallClock))") + } + + func installModelArtifacts(named modelName: String, onProgress: ((Double) -> Void)? 
= nil) async throws { + guard let model = availableModels.first(where: { $0.name == modelName }) else { + throw ModelError.modelNotFound(modelName) + } + if model.provider == .parakeet { try await downloadParakeetModel(named: modelName, onProgress: onProgress) - } else { + } else if model.provider == .whisperKit { try await downloadWhisperKitModel(named: modelName, onProgress: onProgress) + } else { + throw ModelError.downloadNotImplemented(model.provider.rawValue) } - Log.boot.info("ModelManager.downloadModel finished OK name=\(modelName) wallClock=\(String(format: "%.2fs", CFAbsoluteTimeGetCurrent() - downloadWallClock))") } private func downloadWhisperKitModel(named modelName: String, onProgress: ((Double) -> Void)? = nil) async throws { @@ -474,6 +511,17 @@ class ModelManager { } func deleteModel(named modelName: String) async throws { + #if canImport(PindropSharedTranscription) + try await localRuntimeBridge.deleteModel(named: modelName) + await refreshDownloadedModels() + return + #else + try await deleteModelArtifacts(named: modelName) + await refreshDownloadedModels() + #endif + } + + func deleteModelArtifacts(named modelName: String) async throws { guard let model = availableModels.first(where: { $0.name == modelName }) else { throw ModelError.modelNotFound(modelName) } @@ -488,7 +536,6 @@ class ModelManager { do { try fileManager.removeItem(at: modelPath) - await refreshDownloadedModels() } catch { throw ModelError.deleteFailed(error.localizedDescription) } diff --git a/Pindrop/Services/Transcription/KMPTranscriptionBridge.swift b/Pindrop/Services/Transcription/KMPTranscriptionBridge.swift index 2cdbd24..82838b9 100644 --- a/Pindrop/Services/Transcription/KMPTranscriptionBridge.swift +++ b/Pindrop/Services/Transcription/KMPTranscriptionBridge.swift @@ -66,6 +66,25 @@ struct SharedTranscriptionStateTransition: Equatable, Sendable { } enum KMPTranscriptionBridge { + static func localAvailableModels() -> [ModelManager.WhisperModel] { + #if 
canImport(PindropSharedTranscription) + LocalTranscriptionCatalog.shared.models(platform: localPlatform()).map(localModel(from:)) + #else + [] + #endif + } + + static func recommendedLocalModels(for language: AppLanguage) -> [ModelManager.WhisperModel] { + #if canImport(PindropSharedTranscription) + LocalTranscriptionCatalog.shared.recommendedModels( + platform: localPlatform(), + language: coreLanguage(from: language) + ).map(localModel(from:)) + #else + [] + #endif + } + static func normalizeTranscriptionText(_ text: String) -> String { #if canImport(PindropSharedTranscription) SharedTranscriptionOrchestrator.shared.normalizeTranscriptionText(text: text) @@ -187,7 +206,7 @@ enum KMPTranscriptionBridge { #if canImport(PindropSharedTranscription) let policy = TranscriptionRuntimePolicy( selectedProvider: coreProvider(from: selectedProvider), - selectedModelId: CoreTranscriptionModelId(value: selectedModelName), + selectedModelId: TranscriptionModelId(value: selectedModelName), streamingFeatureEnabled: streamingFeatureEnabled, diarizationFeatureEnabled: diarizationFeatureEnabled, outputMode: outputMode.kmpValue, @@ -256,7 +275,7 @@ enum KMPTranscriptionBridge { #if canImport(PindropSharedTranscription) let plan = SharedTranscriptionOrchestrator.shared.planTranscriptionExecution( selectedProvider: coreProvider(from: selectedProvider), - selectedModelId: CoreTranscriptionModelId(value: selectedModelName), + selectedModelId: TranscriptionModelId(value: selectedModelName), diarizationRequested: diarizationRequested, isStreamingSessionActive: isStreamingSessionActive ) @@ -441,7 +460,7 @@ enum KMPTranscriptionBridge { ) -> [ModelManager.WhisperModel] { #if canImport(PindropSharedTranscription) let orchestrator = SharedTranscriptionOrchestrator.shared - let curatedIds = recommendedModelNames(for: language).map { CoreTranscriptionModelId(value: $0) } + let curatedIds = recommendedModelNames(for: language).map { TranscriptionModelId(value: $0) } let descriptors = 
availableModels.map(coreDescriptor(from:)) let language = coreLanguage(from: language) @@ -482,10 +501,10 @@ enum KMPTranscriptionBridge { let modelsByName = Dictionary(uniqueKeysWithValues: availableModels.map { ($0.name, $0) }) let resolution = orchestrator.resolveStartupModel( - selectedModelId: CoreTranscriptionModelId(value: selectedModelId), - defaultModelId: CoreTranscriptionModelId(value: defaultModelId), + selectedModelId: TranscriptionModelId(value: selectedModelId), + defaultModelId: TranscriptionModelId(value: defaultModelId), availableModels: descriptors, - downloadedModelIds: downloadedModelIds.map(CoreTranscriptionModelId.init(value:)) + downloadedModelIds: downloadedModelIds.map { TranscriptionModelId(value: $0) } ) let resolvedModel = modelsByName[resolution.resolvedModel.id.value] ?? availableModels.first! @@ -610,7 +629,7 @@ enum KMPTranscriptionBridge { #if canImport(PindropSharedTranscription) private extension KMPTranscriptionBridge { - static func coreProvider(from provider: ModelManager.ModelProvider) -> CoreTranscriptionProviderId { + static func coreProvider(from provider: ModelManager.ModelProvider) -> TranscriptionProviderId { switch provider { case .whisperKit: .whisperKit @@ -625,7 +644,7 @@ private extension KMPTranscriptionBridge { } } - static func modelProvider(from provider: CoreTranscriptionProviderId) -> ModelManager.ModelProvider { + static func modelProvider(from provider: TranscriptionProviderId) -> ModelManager.ModelProvider { switch provider { case .whisperKit: .whisperKit @@ -642,7 +661,7 @@ private extension KMPTranscriptionBridge { } } - static func coreLanguage(from language: AppLanguage) -> CoreTranscriptionLanguage { + static func coreLanguage(from language: AppLanguage) -> TranscriptionLanguage { switch language { case .automatic: .automatic @@ -673,7 +692,7 @@ private extension KMPTranscriptionBridge { static func coreLanguageSupport( from support: ModelManager.LanguageSupport - ) -> CoreModelLanguageSupport { + ) 
-> ModelLanguageSupport { switch support { case .englishOnly: .englishOnly @@ -686,7 +705,7 @@ private extension KMPTranscriptionBridge { static func coreAvailability( from availability: ModelManager.ModelAvailability - ) -> CoreModelAvailability { + ) -> ModelAvailability { switch availability { case .available: .available @@ -697,9 +716,9 @@ private extension KMPTranscriptionBridge { } } - static func coreDescriptor(from model: ModelManager.WhisperModel) -> CoreModelDescriptor { - CoreModelDescriptor( - id: CoreTranscriptionModelId(value: model.name), + static func coreDescriptor(from model: ModelManager.WhisperModel) -> ModelDescriptor { + ModelDescriptor( + id: TranscriptionModelId(value: model.name), displayName: model.displayName, provider: coreProvider(from: model.provider), languageSupport: coreLanguageSupport(from: model.languageSupport), @@ -711,7 +730,76 @@ private extension KMPTranscriptionBridge { ) } - static func coreState(from state: TranscriptionService.State) -> CoreSharedTranscriptionState { + static func localProvider( + from provider: LocalModelProvider + ) -> ModelManager.ModelProvider { + switch provider { + case .whisperKit, .wcpp: + .whisperKit + case .parakeetCoreml, .parakeetNative: + .parakeet + default: + .whisperKit + } + } + + static func localAvailability( + from availability: ModelAvailability + ) -> ModelManager.ModelAvailability { + switch availability { + case .available: + .available + case .comingSoon: + .comingSoon + case .requiresSetup: + .requiresSetup + default: + .available + } + } + + static func localLanguageSupport( + from support: ModelLanguageSupport + ) -> ModelManager.LanguageSupport { + switch support { + case .englishOnly: + .englishOnly + case .fullMultilingual: + .fullMultilingual + case .parakeetV3European: + .parakeetV3European + default: + .fullMultilingual + } + } + + static func localModel( + from descriptor: LocalModelDescriptor + ) -> ModelManager.WhisperModel { + ModelManager.WhisperModel( + name: 
descriptor.id.value, + displayName: descriptor.displayName, + sizeInMB: Int(descriptor.sizeInMb), + description: descriptor.description_, + speedRating: descriptor.speedRating, + accuracyRating: descriptor.accuracyRating, + languageSupport: localLanguageSupport(from: descriptor.languageSupport), + provider: localProvider(from: descriptor.provider), + availability: localAvailability(from: descriptor.availability) + ) + } + + static func localPlatform() -> LocalPlatformId { + #if os(macOS) + .macos + #elseif os(Windows) + .windows + #else + .linux + #endif + } + + static func coreState(from state: TranscriptionService.State) -> SharedTranscriptionState { switch state { case .unloaded: .unloaded @@ -726,7 +814,7 @@ private extension KMPTranscriptionBridge { } } - static func serviceState(from state: CoreSharedTranscriptionState) -> TranscriptionService.State { + static func serviceState(from state: SharedTranscriptionState) -> TranscriptionService.State { switch state { case .unloaded: .unloaded diff --git a/Pindrop/Services/Transcription/NativeTranscriptionAdapters.swift b/Pindrop/Services/Transcription/NativeTranscriptionAdapters.swift index 55ef31d..a597782 100644 --- a/Pindrop/Services/Transcription/NativeTranscriptionAdapters.swift +++ b/Pindrop/Services/Transcription/NativeTranscriptionAdapters.swift @@ -172,3 +172,586 @@ final class MacOSSettingsSnapshotAdapter: SettingsSnapshotProvider { ) } } + +#if canImport(PindropSharedTranscription) +import PindropSharedTranscription + +@MainActor +final class KMPTranscriptionRuntimeBridge { + private let modelManager: ModelManager + private let backendRegistry: MacOSRuntimeBackendRegistry + private let runtime: LocalTranscriptionRuntime + + init( + modelManager: ModelManager, + engineFactory: @escaping @MainActor (ModelManager.ModelProvider) throws -> any TranscriptionEngine + ) { + self.modelManager = modelManager + self.backendRegistry = MacOSRuntimeBackendRegistry(engineFactory: engineFactory) + let installedIndex = 
MacOSInstalledModelIndexAdapter(modelManager: modelManager) + let installer = MacOSModelInstallerAdapter(modelManager: modelManager) + self.runtime = LocalTranscriptionRuntime( + platform: .macos, + installedModelIndex: installedIndex, + modelInstaller: installer, + backendRegistry: backendRegistry, + observer: nil + ) + } + + func refreshInstalledModelNames() async -> Set { + do { + let records = try await refreshInstalledModels() + return Set(records.map(\.modelId.value)) + } catch { + Log.model.error("KMP runtime refresh failed: \(error.localizedDescription)") + return [] + } + } + + func refreshInstalledModels() async throws -> [InstalledModelRecord] { + try await withCheckedThrowingContinuation { continuation in + runtime.refreshInstalledModels { records, error in + if let error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: records ?? []) + } + } + } + } + + func installModel( + named modelName: String, + onProgress: ((Double) -> Void)? = nil + ) async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + runtime.installModel(modelId: TranscriptionModelId(value: modelName)) { _, error in + if let error { + continuation.resume(throwing: error) + } else { + onProgress?(1.0) + continuation.resume(returning: ()) + } + } + } + } + + func loadModel( + named modelName: String, + provider: ModelManager.ModelProvider + ) async throws -> (any TranscriptionEnginePort) { + let backendProvider = effectiveRuntimeProvider(for: provider) + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + runtime.loadModel(modelId: TranscriptionModelId(value: modelName)) { error in + if let error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: ()) + } + } + } + + guard let engine = backendRegistry.engine(for: backendProvider) else { + throw TranscriptionService.TranscriptionError.modelLoadFailed( + "No runtime engine available for 
\(backendProvider.rawValue)" + ) + } + return engine + } + + func loadModel(fromPath path: String) async throws -> (any TranscriptionEnginePort) { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + runtime.loadModelFromPath(path: path, family: .whisper) { error in + if let error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: ()) + } + } + } + + guard let engine = backendRegistry.engine(for: .whisperKit) else { + throw TranscriptionService.TranscriptionError.modelLoadFailed( + "No runtime engine available for WhisperKit" + ) + } + return engine + } + + func transcribe(audioData: Data, options: TranscriptionOptions) async throws -> String { + let request = TranscriptionRequest( + audioData: KotlinByteArray.from(data: audioData), + language: options.language.transcriptionLanguage, + diarizationEnabled: false, + customVocabularyWords: options.customVocabularyWords + ) + + let result = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + runtime.transcribe(request: request) { result, error in + if let error { + continuation.resume(throwing: error) + } else if let result { + continuation.resume(returning: result) + } else { + continuation.resume( + throwing: TranscriptionService.TranscriptionError.transcriptionFailed( + "Runtime returned no transcription result" + ) + ) + } + } + } + + return result.text + } + + func unloadModel() async { + await withCheckedContinuation { (continuation: CheckedContinuation) in + runtime.unloadModel { _ in + continuation.resume(returning: ()) + } + } + } + + func engine(for provider: ModelManager.ModelProvider) -> (any TranscriptionEnginePort)? 
{ + backendRegistry.engine(for: effectiveRuntimeProvider(for: provider)) + } + + func effectiveRuntimeProvider(for provider: ModelManager.ModelProvider) -> ModelManager.ModelProvider { + switch provider { + case .whisperKit, .parakeet: + provider + case .openAI, .elevenLabs, .groq: + .whisperKit + } + } + + func deleteModel(named modelName: String) async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + runtime.deleteModel(modelId: TranscriptionModelId(value: modelName)) { error in + if let error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: ()) + } + } + } + } +} + +private final class MacOSInstalledModelIndexAdapter: NSObject, InstalledModelIndexPort { + private let modelManager: ModelManager + private let fileManager = FileManager.default + + init(modelManager: ModelManager) { + self.modelManager = modelManager + } + + func refreshInstalledModels(completionHandler: @escaping ([InstalledModelRecord]?, (any Error)?) -> Void) { + let installRoot = fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first! + .appendingPathComponent("Pindrop", isDirectory: true) + let whisperRoot = installRoot + .appendingPathComponent("models", isDirectory: true) + .appendingPathComponent("argmaxinc", isDirectory: true) + .appendingPathComponent("whisperkit-coreml", isDirectory: true) + let parakeetRoot = installRoot + .appendingPathComponent("FluidInference", isDirectory: true) + .appendingPathComponent("parakeet-coreml", isDirectory: true) + + let records = modelManager.availableModels.compactMap { model -> InstalledModelRecord? 
in + switch model.provider { + case .whisperKit: + let modelPath = whisperRoot.appendingPathComponent(model.name, isDirectory: true) + guard directoryExists(at: modelPath) else { return nil } + return InstalledModelRecord( + modelId: TranscriptionModelId(value: model.name), + state: .installed, + storage: ModelStorageLayout( + installRootPath: whisperRoot.path, + modelPath: modelPath.path + ), + installedProvider: .whisperKit, + lastError: nil + ) + case .parakeet: + let folderName = model.name.hasSuffix("-coreml") ? model.name : "\(model.name)-coreml" + let modelPath = parakeetRoot.appendingPathComponent(folderName, isDirectory: true) + guard directoryExists(at: modelPath) else { return nil } + return InstalledModelRecord( + modelId: TranscriptionModelId(value: model.name), + state: .installed, + storage: ModelStorageLayout( + installRootPath: parakeetRoot.path, + modelPath: modelPath.path + ), + installedProvider: .parakeetCoreml, + lastError: nil + ) + case .openAI, .elevenLabs, .groq: + return nil + } + } + + completionHandler(records, nil) + } + + private func directoryExists(at url: URL) -> Bool { + var isDirectory: ObjCBool = false + return fileManager.fileExists(atPath: url.path, isDirectory: &isDirectory) && isDirectory.boolValue + } +} + +@MainActor +private final class MacOSModelInstallerAdapter: NSObject, @preconcurrency ModelInstallerPort { + private let modelManager: ModelManager + + init(modelManager: ModelManager) { + self.modelManager = modelManager + } + + func installModel( + model: LocalModelDescriptor, + onProgress: @escaping (ModelInstallProgress) -> Void, + completionHandler: @escaping (InstalledModelRecord?, (any Error)?) -> Void + ) { + Task { @MainActor in + do { + try await modelManager.installModelArtifacts(named: model.id.value) { progress in + onProgress( + ModelInstallProgress( + modelId: model.id, + progress: progress, + state: progress >= 1.0 ? 
.installed : .installing, + message: nil + ) + ) + } + + let records = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<[InstalledModelRecord], Error>) in + MacOSInstalledModelIndexAdapter(modelManager: modelManager) + .refreshInstalledModels { records, error in + if let error { + continuation.resume(throwing: error) + } else { + continuation.resume(returning: records ?? []) + } + } + } + completionHandler(records.first { $0.modelId.value == model.id.value }, nil) + } catch { + completionHandler(nil, error) + } + } + } + + func deleteModel(model: LocalModelDescriptor, completionHandler: @escaping ((any Error)?) -> Void) { + Task { @MainActor in + do { + try await modelManager.deleteModelArtifacts(named: model.id.value) + completionHandler(nil) + } catch { + completionHandler(error) + } + } + } +} + +@MainActor +private final class MacOSRuntimeBackendRegistry: NSObject, @preconcurrency BackendRegistryPort { + private let whisperBackend: RuntimeBackedLocalEngineAdapter + private let parakeetBackend: RuntimeBackedLocalEngineAdapter + + init( + engineFactory: @escaping @MainActor (ModelManager.ModelProvider) throws -> any TranscriptionEngine + ) { + self.whisperBackend = RuntimeBackedLocalEngineAdapter( + backendId: .whisperKit, + provider: .whisperKit, + supportedFamilies: [.whisper], + supportsPathLoading: true, + engineFactory: engineFactory + ) + self.parakeetBackend = RuntimeBackedLocalEngineAdapter( + backendId: .parakeetApple, + provider: .parakeet, + supportedFamilies: [.parakeet], + supportsPathLoading: false, + engineFactory: engineFactory + ) + } + + func preferredBackend(model: LocalModelDescriptor) -> LocalBackendId? { + switch model.family { + case .whisper: + return .whisperKit + case .parakeet: + return .parakeetApple + default: + return nil + } + } + + func backend(id: LocalBackendId) -> LocalInferenceBackendPort? 
{ + switch id { + case .whisperKit: + whisperBackend + case .parakeetApple: + parakeetBackend + default: + nil + } + } + + func engine(for provider: ModelManager.ModelProvider) -> (any TranscriptionEnginePort)? { + switch provider { + case .whisperKit: + whisperBackend.serviceEngine + case .parakeet: + parakeetBackend.serviceEngine + case .openAI, .elevenLabs, .groq: + nil + } + } +} + +@MainActor +private final class RuntimeBackedLocalEngineAdapter: NSObject, @preconcurrency LocalInferenceBackendPort { + let backendId: LocalBackendId + let supportedFamilies: Set + let supportsPathLoading: Bool + + private let provider: ModelManager.ModelProvider + private let engineFactory: @MainActor (ModelManager.ModelProvider) throws -> any TranscriptionEngine + private var engine: (any TranscriptionEnginePort)? + fileprivate lazy var serviceEngine = RuntimeBackedTranscriptionEngineProxy(owner: self) + + init( + backendId: LocalBackendId, + provider: ModelManager.ModelProvider, + supportedFamilies: Set, + supportsPathLoading: Bool, + engineFactory: @escaping @MainActor (ModelManager.ModelProvider) throws -> any TranscriptionEngine + ) { + self.backendId = backendId + self.provider = provider + self.supportedFamilies = supportedFamilies + self.supportsPathLoading = supportsPathLoading + self.engineFactory = engineFactory + } + + fileprivate var state: TranscriptionEngineState { + engine?.state ?? .unloaded + } + + fileprivate func loadModel(path: String) async throws { + try await resolvedEngine().loadModel(path: path) + } + + fileprivate func loadModel(name: String, downloadBase: URL?) 
async throws { + try await resolvedEngine().loadModel(name: name, downloadBase: downloadBase) + } + + fileprivate func transcribe(audioData: Data, options: TranscriptionOptions) async throws -> String { + try await resolvedEngine().transcribe(audioData: audioData, options: options) + } + + fileprivate func unloadServiceEngine() async { + await engine?.unloadModel() + engine = nil + } + + func loadModel( + model: LocalModelDescriptor, + installedRecord: InstalledModelRecord?, + completionHandler: @escaping ((any Error)?) -> Void + ) { + Task { @MainActor in + do { + switch provider { + case .whisperKit: + if let modelPath = installedRecord?.storage.modelPath { + try await loadModel(path: modelPath) + } else { + let downloadBase = installedRecord.map { + URL(fileURLWithPath: $0.storage.installRootPath) + .deletingLastPathComponent() + .deletingLastPathComponent() + .deletingLastPathComponent() + } + try await loadModel(name: model.id.value, downloadBase: downloadBase) + } + case .parakeet: + try await loadModel(name: model.id.value, downloadBase: nil) + case .openAI, .elevenLabs, .groq: + throw TranscriptionService.TranscriptionError.modelLoadFailed( + "Provider \(provider.rawValue) not supported locally" + ) + } + completionHandler(nil) + } catch { + completionHandler(error) + } + } + } + + func loadModelFromPath(path: String, completionHandler: @escaping ((any Error)?) -> Void) { + Task { @MainActor in + do { + try await loadModel(path: path) + completionHandler(nil) + } catch { + completionHandler(error) + } + } + } + + func transcribe( + request: TranscriptionRequest, + completionHandler: @escaping (TranscriptionResult?, (any Error)?) 
-> Void + ) { + Task { @MainActor in + do { + let text = try await transcribe( + audioData: request.audioData.dataValue, + options: TranscriptionOptions( + language: request.language.appLanguage, + customVocabularyWords: request.customVocabularyWords + ) + ) + completionHandler(TranscriptionResult(text: text, diarizedSegments: []), nil) + } catch { + completionHandler(nil, error) + } + } + } + + func unloadModel(completionHandler: @escaping ((any Error)?) -> Void) { + Task { @MainActor in + await unloadServiceEngine() + completionHandler(nil) + } + } + + private func resolvedEngine() throws -> any TranscriptionEnginePort { + if let engine { + return engine + } + + let created = try engineFactory(provider) + engine = created + return created + } +} + +@MainActor +private final class RuntimeBackedTranscriptionEngineProxy: TranscriptionEnginePort { + private unowned let owner: RuntimeBackedLocalEngineAdapter + + init(owner: RuntimeBackedLocalEngineAdapter) { + self.owner = owner + } + + var state: TranscriptionEngineState { + owner.state + } + + func loadModel(path: String) async throws { + try await owner.loadModel(path: path) + } + + func loadModel(name: String, downloadBase: URL?) async throws { + try await owner.loadModel(name: name, downloadBase: downloadBase) + } + + func transcribe(audioData: Data, options: TranscriptionOptions) async throws -> String { + try await owner.transcribe(audioData: audioData, options: options) + } + + func unloadModel() async { + await owner.unloadServiceEngine() + } +} + +private extension KotlinByteArray { + static func from(data: Data) -> KotlinByteArray { + let bytes = [UInt8](data) + return KotlinByteArray(size: Int32(bytes.count)) { index in + KotlinByte(char: Int8(truncatingIfNeeded: bytes[Int(index)])) + } + } + + var dataValue: Data { + var bytes = [UInt8]() + bytes.reserveCapacity(Int(size)) + for index in 0.. 
[ModelManager.WhisperModel] { - KMPTranscriptionBridge.recommendedModels( - availableModels: availableModels, - for: language - ) + let localRecommended = KMPTranscriptionBridge.recommendedLocalModels(for: language) + + guard !localRecommended.isEmpty else { + return KMPTranscriptionBridge.recommendedModels( + availableModels: availableModels, + for: language + ) + } + + return localRecommended } } diff --git a/Pindrop/Services/TranscriptionService.swift b/Pindrop/Services/TranscriptionService.swift index a4b5036..38e090e 100644 --- a/Pindrop/Services/TranscriptionService.swift +++ b/Pindrop/Services/TranscriptionService.swift @@ -75,15 +75,21 @@ class TranscriptionService { private var currentModelIdentifier: String? private var streamingPartialCallback: (@Sendable (String) -> Void)? private var streamingFinalUtteranceCallback: (@Sendable (String) -> Void)? + #if canImport(PindropSharedTranscription) + @ObservationIgnored + private lazy var localRuntimeBridge = KMPTranscriptionRuntimeBridge( + modelManager: ModelManager(), + engineFactory: engineFactory + ) + #endif private let engineFactory: @MainActor (ModelManager.ModelProvider) throws -> any TranscriptionEngine private let speakerDiarizerFactory: @MainActor () -> any SpeakerDiarizer private let streamingEngineFactory: @MainActor () -> any StreamingTranscriptionEngine + private let usesLegacyLocalExecution: Bool init( - engineFactory: @escaping @MainActor (ModelManager.ModelProvider) throws -> any TranscriptionEngine = { - try TranscriptionService.defaultEngineFactory(provider: $0) - }, + engineFactory: (@MainActor (ModelManager.ModelProvider) throws -> any TranscriptionEngine)? = nil, diarizerFactory: @escaping @MainActor () -> any SpeakerDiarizer = { FluidSpeakerDiarizer() }, @@ -91,12 +97,18 @@ class TranscriptionService { ParakeetStreamingEngine() } ) { - self.engineFactory = engineFactory + self.engineFactory = engineFactory ?? 
TranscriptionService.defaultEngineFactory(provider:) self.speakerDiarizerFactory = diarizerFactory self.streamingEngineFactory = streamingEngineFactory + self.usesLegacyLocalExecution = engineFactory != nil } func loadModel(modelName: String = "tiny", provider: ModelManager.ModelProvider = .whisperKit) async throws { + if !usesLegacyLocalExecution { + try await loadModelUsingRuntime(modelName: modelName, provider: provider) + return + } + try applyStateTransition( KMPTranscriptionBridge.beginModelLoad(currentState: state) ) @@ -170,6 +182,11 @@ class TranscriptionService { } func loadModel(modelPath: String) async throws { + if !usesLegacyLocalExecution { + try await loadModelFromPathUsingRuntime(modelPath: modelPath) + return + } + try applyStateTransition( KMPTranscriptionBridge.beginModelLoad(currentState: state) ) @@ -310,7 +327,11 @@ class TranscriptionService { } func unloadModel() async { - await engine?.unloadModel() + if !usesLegacyLocalExecution { + await localRuntimeBridge.unloadModel() + } else { + await engine?.unloadModel() + } await speakerDiarizer?.unloadModels() await streamingEngine?.unloadModel() engine = nil @@ -526,7 +547,13 @@ class TranscriptionService { shouldNormalizeOutput: Bool, options: TranscriptionOptions ) async throws -> TranscriptionOutput { - let text = try await engine.transcribe(audioData: audioData, options: options) + let text: String + if usesLegacyLocalExecution { + text = try await engine.transcribe(audioData: audioData, options: options) + } else { + text = try await localRuntimeBridge.transcribe(audioData: audioData, options: options) + } + return TranscriptionOutput(text: text, diarizedSegments: nil) .normalized(shouldNormalizeOutput: shouldNormalizeOutput) } @@ -834,7 +861,7 @@ class TranscriptionService { return created } - private static func defaultEngineFactory( + static func defaultEngineFactory( provider: ModelManager.ModelProvider ) throws -> any TranscriptionEngine { switch provider { @@ -944,6 +971,138 @@ class 
TranscriptionService { .streamingNotReady } } + + private func loadModelUsingRuntime( + modelName: String, + provider: ModelManager.ModelProvider + ) async throws { + try applyStateTransition( + KMPTranscriptionBridge.beginModelLoad(currentState: state) + ) + + let loadPlan = KMPTranscriptionBridge.planModelLoad( + requestedProvider: provider, + currentProvider: currentProvider, + loadsFromPath: false + ) + + if loadPlan.shouldUnloadCurrentModel { + await unloadModel() + } + + guard loadPlan.supportsLocalModelLoading else { + let loadError = TranscriptionError.modelLoadFailed( + "Provider \(loadPlan.resolvedProvider.rawValue) not supported locally" + ) + self.error = loadError + state = KMPTranscriptionBridge.completeModelLoad(success: false) + throw loadError + } + + error = nil + + let loadStarted = CFAbsoluteTimeGetCurrent() + Log.transcription.info("Loading model: \(modelName) with provider: \(loadPlan.resolvedProvider.rawValue)...") + Log.boot.info("TranscriptionService.loadModel(runtime) begin name=\(modelName) provider=\(loadPlan.resolvedProvider.rawValue)") + + do { + let loadedEngine = try await withThrowingTaskGroup(of: (any TranscriptionEnginePort).self) { group in + group.addTask { + try await self.localRuntimeBridge.loadModel(named: modelName, provider: loadPlan.resolvedProvider) + } + + group.addTask { + try await Task.sleep(for: .seconds(120)) + throw TranscriptionError.modelLoadFailed("Model loading timed out after 120s. This can happen on first launch after an update. 
Try restarting the app, or delete and re-download the model from Settings.") + } + + let result = try await group.next() + group.cancelAll() + guard let result else { + throw TranscriptionError.modelLoadFailed("Model loading finished without an engine result") + } + return result + } + + engine = loadedEngine + currentProvider = loadPlan.resolvedProvider + currentModelIdentifier = modelName + state = KMPTranscriptionBridge.completeModelLoad(success: true) + Log.boot.info("TranscriptionService.loadModel(runtime) success totalElapsed=\(String(format: "%.2fs", CFAbsoluteTimeGetCurrent() - loadStarted))") + } catch let error as TranscriptionError { + self.error = error + state = KMPTranscriptionBridge.completeModelLoad(success: false) + throw error + } catch { + let loadError = TranscriptionError.modelLoadFailed(error.localizedDescription) + self.error = loadError + state = KMPTranscriptionBridge.completeModelLoad(success: false) + throw loadError + } + } + + private func loadModelFromPathUsingRuntime( + modelPath: String + ) async throws { + try applyStateTransition( + KMPTranscriptionBridge.beginModelLoad(currentState: state) + ) + + let loadPlan = KMPTranscriptionBridge.planModelLoad( + requestedProvider: .whisperKit, + currentProvider: currentProvider, + loadsFromPath: true + ) + + if loadPlan.shouldUnloadCurrentModel { + await unloadModel() + } + + guard loadPlan.supportsLocalModelLoading else { + let loadError = TranscriptionError.modelLoadFailed( + "Provider \(loadPlan.resolvedProvider.rawValue) not supported locally" + ) + self.error = loadError + state = KMPTranscriptionBridge.completeModelLoad(success: false) + throw loadError + } + + error = nil + + do { + let loadedEngine = try await withThrowingTaskGroup(of: (any TranscriptionEnginePort).self) { group in + group.addTask { + try await self.localRuntimeBridge.loadModel(fromPath: modelPath) + } + + group.addTask { + try await Task.sleep(for: .seconds(120)) + throw TranscriptionError.modelLoadFailed("Model 
loading timed out after 120s. This can happen on first launch after an update. Try restarting the app, or delete and re-download the model from Settings.") + } + + let result = try await group.next() + group.cancelAll() + guard let result else { + throw TranscriptionError.modelLoadFailed("Model loading finished without an engine result") + } + return result + } + + engine = loadedEngine + currentProvider = loadPlan.resolvedProvider + currentModelIdentifier = URL(fileURLWithPath: modelPath).lastPathComponent + state = KMPTranscriptionBridge.completeModelLoad(success: true) + } catch let error as TranscriptionError { + self.error = error + state = KMPTranscriptionBridge.completeModelLoad(success: false) + throw error + } catch { + let loadError = TranscriptionError.modelLoadFailed(error.localizedDescription) + self.error = loadError + state = KMPTranscriptionBridge.completeModelLoad(success: false) + throw loadError + } + } } extension TranscriptionService: TranscriptionOrchestrating {} diff --git a/justfile b/justfile index 7b7b56f..32100bf 100644 --- a/justfile +++ b/justfile @@ -106,7 +106,7 @@ test: # Run Kotlin Multiplatform shared-module tests shared-test: @echo "🧪 Running shared Kotlin tests..." 
- ./shared/gradlew --no-daemon --console=plain -p shared :core:jvmTest :feature-transcription:jvmTest :ui-theme:jvmTest :ui-shell:jvmTest :ui-settings:jvmTest :ui-workspace:jvmTest + ./shared/gradlew --no-daemon --console=plain -p shared :core:jvmTest :runtime-transcription:jvmTest :feature-transcription:jvmTest :ui-theme:jvmTest :ui-shell:jvmTest :ui-settings:jvmTest :ui-workspace:jvmTest @echo "✅ Shared Kotlin tests complete" # Build Apple XCFrameworks for the shared Kotlin modules diff --git a/shared/README.md b/shared/README.md index 0947c1a..fc97da5 100644 --- a/shared/README.md +++ b/shared/README.md @@ -6,16 +6,17 @@ Layout: - `build.gradle.kts`, `settings.gradle.kts`, `gradle.properties`, `gradlew`: Gradle workspace root - `core/`: shared domain types and cross-platform ports - `feature-transcription/`: shared transcription policy and orchestration logic +- `runtime-transcription/`: shared local model catalog and executable local-runtime orchestration Current target status: - `macosArm64` / `macosX64`: actively built and embedded into the app +- `linuxX64` / `mingwX64`: compile-time targets for the shared local transcription runtime - `jvm`: used for shared unit tests -- `desktopLinuxStub` / `desktopWindowsStub`: explicit placeholder tasks that fail with a clear "not implemented yet" error until real Linux/Windows targets land Common commands from the repo root: - `just shared-test` - `just shared-xcframework` Direct commands from this directory: -- `./gradlew :core:jvmTest :feature-transcription:jvmTest` +- `./gradlew :core:jvmTest :runtime-transcription:jvmTest :feature-transcription:jvmTest` - `./gradlew :core:assemblePindropSharedCoreXCFramework :feature-transcription:assemblePindropSharedTranscriptionXCFramework` diff --git a/shared/core/build.gradle.kts b/shared/core/build.gradle.kts index 7ceba64..a73d44f 100644 --- a/shared/core/build.gradle.kts +++ b/shared/core/build.gradle.kts @@ -13,6 +13,8 @@ kotlin { } val macosArm64Target = macosArm64() val 
macosX64Target = macosX64() + linuxX64() + mingwX64() val xcframework = XCFramework("PindropSharedCore") diff --git a/shared/core/src/commonMain/kotlin/tech/watzon/pindrop/shared/core/TranscriptionContracts.kt b/shared/core/src/commonMain/kotlin/tech/watzon/pindrop/shared/core/TranscriptionContracts.kt index 1edd9cf..c5a3e93 100644 --- a/shared/core/src/commonMain/kotlin/tech/watzon/pindrop/shared/core/TranscriptionContracts.kt +++ b/shared/core/src/commonMain/kotlin/tech/watzon/pindrop/shared/core/TranscriptionContracts.kt @@ -29,6 +29,7 @@ data class TranscriptionRequest( val audioData: ByteArray, val language: TranscriptionLanguage = TranscriptionLanguage.AUTOMATIC, val diarizationEnabled: Boolean = false, + val customVocabularyWords: List<String> = emptyList(), ) data class StreamingTranscriptionConfig( diff --git a/shared/feature-transcription/build.gradle.kts b/shared/feature-transcription/build.gradle.kts index a3fe79f..e91705f 100644 --- a/shared/feature-transcription/build.gradle.kts +++ b/shared/feature-transcription/build.gradle.kts @@ -22,13 +22,17 @@ kotlin { ).forEach { target -> target.binaries.framework { baseName = "PindropSharedTranscription" + export(project(":runtime-transcription")) + export(project(":core")) + transitiveExport = true xcframework.add(this) } } sourceSets { commonMain.dependencies { - implementation(project(":core")) + api(project(":runtime-transcription")) + api(project(":core")) } commonTest.dependencies { implementation(kotlin("test")) diff --git a/shared/runtime-transcription/build.gradle.kts b/shared/runtime-transcription/build.gradle.kts new file mode 100644 index 0000000..c8b19c5 --- /dev/null +++ b/shared/runtime-transcription/build.gradle.kts @@ -0,0 +1,43 @@ +import org.jetbrains.kotlin.gradle.plugin.mpp.apple.XCFramework +import org.jetbrains.kotlin.gradle.dsl.JvmTarget + +plugins { + kotlin("multiplatform") +} + +kotlin { + jvm { + compilerOptions { + jvmTarget.set(JvmTarget.JVM_21) + } + } + val macosArm64Target =
macosArm64() + val macosX64Target = macosX64() + linuxX64() + mingwX64() + + val xcframework = XCFramework("PindropSharedRuntimeTranscription") + + listOf( + macosArm64Target, + macosX64Target, + ).forEach { target -> + target.binaries.framework { + baseName = "PindropSharedRuntimeTranscription" + export(project(":core")) + transitiveExport = true + xcframework.add(this) + } + } + + sourceSets { + commonMain.dependencies { + api(project(":core")) + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.10.2") + } + commonTest.dependencies { + implementation(kotlin("test")) + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.10.2") + } + } +} diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt new file mode 100644 index 0000000..e8fd776 --- /dev/null +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt @@ -0,0 +1,174 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import tech.watzon.pindrop.shared.core.ModelAvailability +import tech.watzon.pindrop.shared.core.ModelLanguageSupport +import tech.watzon.pindrop.shared.core.TranscriptionLanguage +import tech.watzon.pindrop.shared.core.TranscriptionModelId + +object LocalTranscriptionCatalog { + private val englishRecommendedIds = listOf( + TranscriptionModelId("openai_whisper-base.en"), + TranscriptionModelId("openai_whisper-small.en"), + TranscriptionModelId("openai_whisper-medium"), + TranscriptionModelId("openai_whisper-large-v3_turbo"), + TranscriptionModelId("parakeet-tdt-0.6b-v2"), + ) + + private val multilingualRecommendedIds = listOf( + TranscriptionModelId("openai_whisper-base"), + TranscriptionModelId("openai_whisper-small"), + TranscriptionModelId("openai_whisper-medium"), 
+ TranscriptionModelId("openai_whisper-large-v3_turbo"), + TranscriptionModelId("parakeet-tdt-0.6b-v3"), + ) + + private val localModels = listOf( + whisper("openai_whisper-tiny", "Whisper Tiny", 75, "Fastest model, ideal for quick dictation with acceptable accuracy", 10.0, 6.0), + whisper("openai_whisper-tiny.en", "Whisper Tiny (English)", 75, "English-optimized tiny model with slightly better accuracy", 10.0, 6.5, ModelLanguageSupport.ENGLISH_ONLY), + whisper("openai_whisper-base", "Whisper Base", 145, "Good balance between speed and accuracy for everyday use", 9.0, 7.0), + whisper("openai_whisper-base.en", "Whisper Base (English)", 145, "English-optimized base model, recommended for most users", 9.0, 7.5, ModelLanguageSupport.ENGLISH_ONLY), + whisper("openai_whisper-small", "Whisper Small", 483, "Higher accuracy for complex vocabulary and technical terms", 7.5, 8.0), + whisper("openai_whisper-small_216MB", "Whisper Small (Quantized)", 216, "Quantized small model - half the size with similar accuracy", 8.0, 7.8), + whisper("openai_whisper-small.en", "Whisper Small (English)", 483, "English-optimized with excellent accuracy for professional use", 7.5, 8.5, ModelLanguageSupport.ENGLISH_ONLY), + whisper("openai_whisper-small.en_217MB", "Whisper Small (English, Quantized)", 217, "Quantized English small model - compact and fast", 8.0, 8.3, ModelLanguageSupport.ENGLISH_ONLY), + whisper("openai_whisper-medium", "Whisper Medium", 1530, "Excellent for multilingual and code-switching (e.g. 
Chinese/English mix)", 6.5, 8.8), + whisper("openai_whisper-medium.en", "Whisper Medium (English)", 1530, "English-optimized medium model with high accuracy", 6.5, 9.0, ModelLanguageSupport.ENGLISH_ONLY), + whisper("openai_whisper-large-v2", "Whisper Large v2", 3100, "Previous generation large model, still very capable", 5.0, 9.3), + whisper("openai_whisper-large-v2_949MB", "Whisper Large v2 (Quantized)", 949, "Quantized large v2 - much smaller with minimal accuracy loss", 6.0, 9.1), + whisper("openai_whisper-large-v2_turbo", "Whisper Large v2 Turbo", 3100, "Turbo-optimized large v2 for faster inference", 6.5, 9.3), + whisper("openai_whisper-large-v2_turbo_955MB", "Whisper Large v2 Turbo (Quantized)", 955, "Quantized turbo large v2 - fast and compact", 7.0, 9.1), + whisper("openai_whisper-large-v3", "Whisper Large v3", 3100, "Maximum accuracy for demanding transcription tasks", 5.0, 9.7), + whisper("openai_whisper-large-v3_947MB", "Whisper Large v3 (Quantized)", 947, "Quantized large v3 - great accuracy in a smaller package", 6.0, 9.5), + whisper("openai_whisper-large-v3_turbo", "Whisper Large v3 Turbo", 809, "Near large-model accuracy with significantly faster processing", 7.5, 9.5), + whisper("openai_whisper-large-v3_turbo_954MB", "Whisper Large v3 Turbo (Quantized)", 954, "Quantized turbo v3 - balanced speed and accuracy", 7.5, 9.3), + whisper("openai_whisper-large-v3-v20240930", "Whisper Large v3 (Sep 2024)", 3100, "Updated large v3 with improved multilingual performance", 5.0, 9.8), + whisper("openai_whisper-large-v3-v20240930_547MB", "Whisper Large v3 Sep 2024 (Q 547MB)", 547, "Heavily quantized - smallest large v3 variant", 7.0, 9.3), + whisper("openai_whisper-large-v3-v20240930_626MB", "Whisper Large v3 Sep 2024 (Q 626MB)", 626, "Quantized Sep 2024 large v3 - compact with great accuracy", 6.5, 9.5), + whisper("openai_whisper-large-v3-v20240930_turbo", "Whisper Large v3 Sep 2024 Turbo", 3100, "Latest turbo-optimized large v3 - best overall performance", 6.5, 
9.8), + whisper("openai_whisper-large-v3-v20240930_turbo_632MB", "Whisper Large v3 Sep 2024 Turbo (Quantized)", 632, "Quantized latest turbo - excellent accuracy in ~600MB", 7.5, 9.5), + whisper("distil-whisper_distil-large-v3", "Distil Large v3", 1510, "Distilled large v3 - faster with minimal accuracy loss", 7.5, 9.3), + whisper("distil-whisper_distil-large-v3_594MB", "Distil Large v3 (Quantized)", 594, "Quantized distilled model - great speed/accuracy tradeoff", 8.0, 9.0), + whisper("distil-whisper_distil-large-v3_turbo", "Distil Large v3 Turbo", 1510, "Turbo-optimized distilled model for fastest large-class inference", 8.0, 9.3), + whisper("distil-whisper_distil-large-v3_turbo_600MB", "Distil Large v3 Turbo (Quantized)", 600, "Quantized turbo distilled - fastest large-class model at ~600MB", 8.5, 9.0), + parakeet("parakeet-tdt-0.6b-v2", "Parakeet TDT 0.6B V2", 2580, "NVIDIA's state-of-the-art speech recognition model, English-only", 8.5, 9.8, ModelLanguageSupport.ENGLISH_ONLY), + parakeet("parakeet-tdt-0.6b-v3", "Parakeet TDT 0.6B V3", 2670, "Latest Parakeet model with multilingual support", 8.0, 9.9, ModelLanguageSupport.PARAKEET_V3_EUROPEAN), + parakeet("parakeet-tdt-1.1b", "Parakeet TDT 1.1B", 4400, "Larger Parakeet model with exceptional accuracy", 7.0, 9.95, ModelLanguageSupport.ENGLISH_ONLY, ModelAvailability.COMING_SOON), + ) + + fun models(platform: LocalPlatformId): List<LocalModelDescriptor> { + return localModels.map { descriptor -> + when (descriptor.family) { + LocalModelFamily.WHISPER -> descriptor.copy( + provider = if (platform == LocalPlatformId.MACOS) { + LocalModelProvider.WHISPER_KIT + } else { + LocalModelProvider.WCPP + }, + ) + LocalModelFamily.PARAKEET -> descriptor.copy( + provider = if (platform == LocalPlatformId.MACOS) { + LocalModelProvider.PARAKEET_COREML + } else { + LocalModelProvider.PARAKEET_NATIVE + }, + ) + } + } + } + + fun recommendedModelIds(language: TranscriptionLanguage): List<TranscriptionModelId> { + return if (language == TranscriptionLanguage.ENGLISH) {
englishRecommendedIds + } else { + multilingualRecommendedIds + } + } + + fun recommendedModels( + platform: LocalPlatformId, + language: TranscriptionLanguage, + ): List<LocalModelDescriptor> { + val models = models(platform) + val ranks = recommendedModelIds(language).withIndex().associate { it.value to it.index } + return models + .filter { it.id in ranks.keys } + .filter { supportsLanguage(it.languageSupport, language) } + .sortedBy { ranks[it.id] ?: Int.MAX_VALUE } + } + + fun model(platform: LocalPlatformId, modelId: TranscriptionModelId): LocalModelDescriptor? { + return models(platform).firstOrNull { it.id == modelId } + } + + fun supportsLanguage( + support: ModelLanguageSupport, + language: TranscriptionLanguage, + ): Boolean { + if (language == TranscriptionLanguage.AUTOMATIC) { + return true + } + + return when (support) { + ModelLanguageSupport.ENGLISH_ONLY -> language == TranscriptionLanguage.ENGLISH + ModelLanguageSupport.FULL_MULTILINGUAL -> true + ModelLanguageSupport.PARAKEET_V3_EUROPEAN -> language in setOf( + TranscriptionLanguage.ENGLISH, + TranscriptionLanguage.SPANISH, + TranscriptionLanguage.FRENCH, + TranscriptionLanguage.GERMAN, + TranscriptionLanguage.PORTUGUESE_BRAZIL, + TranscriptionLanguage.ITALIAN, + TranscriptionLanguage.DUTCH, + TranscriptionLanguage.TURKISH, + ) + } + } + + private fun whisper( + id: String, + displayName: String, + sizeInMb: Int, + description: String, + speedRating: Double, + accuracyRating: Double, + languageSupport: ModelLanguageSupport = ModelLanguageSupport.FULL_MULTILINGUAL, + availability: ModelAvailability = ModelAvailability.AVAILABLE, + ): LocalModelDescriptor { + return LocalModelDescriptor( + id = TranscriptionModelId(id), + family = LocalModelFamily.WHISPER, + provider = LocalModelProvider.WHISPER_KIT, + displayName = displayName, + languageSupport = languageSupport, + sizeInMb = sizeInMb, + description = description, + speedRating = speedRating, + accuracyRating = accuracyRating, + availability = availability, + ) + } + + 
private fun parakeet( + id: String, + displayName: String, + sizeInMb: Int, + description: String, + speedRating: Double, + accuracyRating: Double, + languageSupport: ModelLanguageSupport, + availability: ModelAvailability = ModelAvailability.AVAILABLE, + ): LocalModelDescriptor { + return LocalModelDescriptor( + id = TranscriptionModelId(id), + family = LocalModelFamily.PARAKEET, + provider = LocalModelProvider.PARAKEET_COREML, + displayName = displayName, + languageSupport = languageSupport, + sizeInMb = sizeInMb, + description = description, + speedRating = speedRating, + accuracyRating = accuracyRating, + availability = availability, + ) + } +} diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionContracts.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionContracts.kt new file mode 100644 index 0000000..ff755ce --- /dev/null +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionContracts.kt @@ -0,0 +1,153 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import tech.watzon.pindrop.shared.core.ModelAvailability +import tech.watzon.pindrop.shared.core.ModelLanguageSupport +import tech.watzon.pindrop.shared.core.TranscriptionModelId +import tech.watzon.pindrop.shared.core.TranscriptionRequest +import tech.watzon.pindrop.shared.core.TranscriptionResult + +enum class LocalModelFamily { + WHISPER, + PARAKEET, +} + +enum class LocalModelProvider { + WHISPER_KIT, + WCPP, + PARAKEET_COREML, + PARAKEET_NATIVE, +} + +enum class LocalBackendId { + WHISPER_KIT, + WHISPER_CPP, + PARAKEET_APPLE, + PARAKEET_NATIVE, +} + +enum class LocalPlatformId { + MACOS, + WINDOWS, + LINUX, +} + +enum class ModelInstallState { + NOT_INSTALLED, + INSTALLING, + INSTALLED, + FAILED, +} + +enum class LocalRuntimeState { + UNLOADED, + LOADING, + READY, + 
TRANSCRIBING, + INSTALLING, + ERROR, +} + +enum class LocalRuntimeErrorCode { + MODEL_NOT_FOUND, + MODEL_NOT_INSTALLED, + BACKEND_UNAVAILABLE, + UNSUPPORTED_ON_PLATFORM, + ENGINE_SWITCH_DURING_TRANSCRIPTION, + INVALID_AUDIO_DATA, + TRANSCRIPTION_ALREADY_IN_PROGRESS, + TRANSCRIPTION_FAILED, + INSTALL_FAILED, + DELETE_FAILED, + LOAD_FAILED, +} + +data class ModelStorageLayout( + val installRootPath: String, + val modelPath: String?, +) + +data class LocalModelDescriptor( + val id: TranscriptionModelId, + val family: LocalModelFamily, + val provider: LocalModelProvider, + val displayName: String, + val languageSupport: ModelLanguageSupport, + val sizeInMb: Int, + val description: String, + val speedRating: Double, + val accuracyRating: Double, + val availability: ModelAvailability, +) + +data class ModelInstallProgress( + val modelId: TranscriptionModelId, + val progress: Double, + val state: ModelInstallState, + val message: String? = null, +) + +data class InstalledModelRecord( + val modelId: TranscriptionModelId, + val state: ModelInstallState, + val storage: ModelStorageLayout, + val installedProvider: LocalModelProvider? = null, + val lastError: String? = null, +) + +enum class LocalModelSelectionAction { + LOAD_SELECTED, + LOAD_FALLBACK, + DOWNLOAD_SELECTED, +} + +data class LocalModelSelectionResolution( + val action: LocalModelSelectionAction, + val resolvedModel: LocalModelDescriptor, + val updatedSelectedModelId: TranscriptionModelId, +) + +data class ActiveLocalModel( + val descriptor: LocalModelDescriptor, + val installedRecord: InstalledModelRecord? = null, + val loadedFromPath: String? 
= null, +) + +interface InstalledModelIndexPort { + suspend fun refreshInstalledModels(): List<InstalledModelRecord> +} + +interface ModelInstallerPort { + suspend fun installModel( + model: LocalModelDescriptor, + onProgress: (ModelInstallProgress) -> Unit, + ): InstalledModelRecord + + suspend fun deleteModel(model: LocalModelDescriptor) +} + +interface LocalInferenceBackendPort { + val backendId: LocalBackendId + val supportedFamilies: Set<LocalModelFamily> + val supportsPathLoading: Boolean + + suspend fun loadModel( + model: LocalModelDescriptor, + installedRecord: InstalledModelRecord?, + ) + + suspend fun loadModelFromPath(path: String) + suspend fun transcribe(request: TranscriptionRequest): TranscriptionResult + suspend fun unloadModel() +} + +interface BackendRegistryPort { + fun preferredBackend(model: LocalModelDescriptor): LocalBackendId? + fun backend(id: LocalBackendId): LocalInferenceBackendPort? +} + +interface RuntimeObserver { + fun onStateChanged(state: LocalRuntimeState) + fun onActiveModelChanged(model: ActiveLocalModel?) + fun onInstallProgress(progress: ModelInstallProgress?) + fun onErrorChanged(errorCode: LocalRuntimeErrorCode?, message: String?) 
+} diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntime.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntime.kt new file mode 100644 index 0000000..cb0c5ed --- /dev/null +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntime.kt @@ -0,0 +1,272 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import tech.watzon.pindrop.shared.core.ModelAvailability +import tech.watzon.pindrop.shared.core.TranscriptionLanguage +import tech.watzon.pindrop.shared.core.TranscriptionModelId +import tech.watzon.pindrop.shared.core.TranscriptionRequest +import tech.watzon.pindrop.shared.core.TranscriptionResult + +class LocalTranscriptionRuntime( + private val platform: LocalPlatformId, + private val installedModelIndex: InstalledModelIndexPort, + private val modelInstaller: ModelInstallerPort, + private val backendRegistry: BackendRegistryPort, + private val observer: RuntimeObserver? = null, +) { + var state: LocalRuntimeState = LocalRuntimeState.UNLOADED + private set + + var activeModel: ActiveLocalModel? = null + private set + + var lastProgress: ModelInstallProgress? = null + private set + + var lastErrorCode: LocalRuntimeErrorCode? = null + private set + + var lastErrorMessage: String? = null + private set + + private var installedModels: List<InstalledModelRecord> = emptyList() + private var activeBackend: LocalInferenceBackendPort? 
= null + + fun catalog(): List<LocalModelDescriptor> = LocalTranscriptionCatalog.models(platform) + + fun recommendedModels(language: TranscriptionLanguage): List<LocalModelDescriptor> { + return LocalTranscriptionCatalog.recommendedModels(platform, language) + } + + suspend fun refreshInstalledModels(): List<InstalledModelRecord> { + installedModels = installedModelIndex.refreshInstalledModels() + return installedModels + } + + fun installedModels(): List<InstalledModelRecord> = installedModels + + fun resolveStartupModel( + selectedModelId: TranscriptionModelId, + defaultModelId: TranscriptionModelId, + ): LocalModelSelectionResolution { + val models = catalog() + val normalizedSelectedModel = models.firstOrNull { it.id == selectedModelId } + ?: models.firstOrNull { it.id == defaultModelId } + ?: models.first() + + val installedSet = installedModels + .filter { it.state == ModelInstallState.INSTALLED } + .map { it.modelId } + .toSet() + + if (normalizedSelectedModel.id in installedSet) { + return LocalModelSelectionResolution( + action = LocalModelSelectionAction.LOAD_SELECTED, + resolvedModel = normalizedSelectedModel, + updatedSelectedModelId = normalizedSelectedModel.id, + ) + } + + val fallbackModel = models.firstOrNull { it.id in installedSet } + if (fallbackModel != null) { + return LocalModelSelectionResolution( + action = LocalModelSelectionAction.LOAD_FALLBACK, + resolvedModel = fallbackModel, + updatedSelectedModelId = fallbackModel.id, + ) + } + + return LocalModelSelectionResolution( + action = LocalModelSelectionAction.DOWNLOAD_SELECTED, + resolvedModel = normalizedSelectedModel, + updatedSelectedModelId = normalizedSelectedModel.id, + ) + } + + suspend fun installModel(modelId: TranscriptionModelId): InstalledModelRecord { + val model = requireModel(modelId) + transitionTo(LocalRuntimeState.INSTALLING) + clearError() + + return runCatching { + val record = modelInstaller.installModel(model) { progress -> + lastProgress = progress + observer?.onInstallProgress(progress) + } + installedModels = refreshInstalledModels() + transitionTo(if
(activeModel != null) LocalRuntimeState.READY else LocalRuntimeState.UNLOADED) + record + }.getOrElse { error -> + setError(LocalRuntimeErrorCode.INSTALL_FAILED, error.message) + transitionTo(LocalRuntimeState.ERROR) + throw error + } + } + + suspend fun deleteModel(modelId: TranscriptionModelId) { + val model = requireModel(modelId) + clearError() + + runCatching { + if (activeModel?.descriptor?.id == modelId) { + unloadModel() + } + modelInstaller.deleteModel(model) + installedModels = refreshInstalledModels() + }.getOrElse { error -> + setError(LocalRuntimeErrorCode.DELETE_FAILED, error.message) + throw error + } + } + + suspend fun loadModel(modelId: TranscriptionModelId) { + if (state == LocalRuntimeState.TRANSCRIBING) { + setError(LocalRuntimeErrorCode.ENGINE_SWITCH_DURING_TRANSCRIPTION, null) + return + } + + val model = requireModel(modelId) + if (model.availability != ModelAvailability.AVAILABLE) { + setError(LocalRuntimeErrorCode.UNSUPPORTED_ON_PLATFORM, "Model ${model.id.value} is not available") + transitionTo(LocalRuntimeState.ERROR) + return + } + + val installedRecord = installedModels.firstOrNull { + it.modelId == modelId && it.state == ModelInstallState.INSTALLED + } + if (installedRecord == null) { + setError(LocalRuntimeErrorCode.MODEL_NOT_INSTALLED, null) + transitionTo(LocalRuntimeState.ERROR) + return + } + + val backendId = backendRegistry.preferredBackend(model) + val backend = backendId?.let(backendRegistry::backend) + if (backend == null || model.family !in backend.supportedFamilies) { + setError(LocalRuntimeErrorCode.BACKEND_UNAVAILABLE, null) + transitionTo(LocalRuntimeState.ERROR) + return + } + + clearError() + transitionTo(LocalRuntimeState.LOADING) + + runCatching { + if (activeBackend != null && activeBackend !== backend) { + activeBackend?.unloadModel() + } + + backend.loadModel(model, installedRecord) + activeBackend = backend + activeModel = ActiveLocalModel(model, installedRecord = installedRecord) + 
observer?.onActiveModelChanged(activeModel) + transitionTo(LocalRuntimeState.READY) + }.getOrElse { error -> + setError(LocalRuntimeErrorCode.LOAD_FAILED, error.message) + transitionTo(LocalRuntimeState.ERROR) + throw error + } + } + + suspend fun loadModelFromPath(path: String, family: LocalModelFamily = LocalModelFamily.WHISPER) { + if (state == LocalRuntimeState.TRANSCRIBING) { + setError(LocalRuntimeErrorCode.ENGINE_SWITCH_DURING_TRANSCRIPTION, null) + return + } + + val backend = backendRegistry.backend( + when (family) { + LocalModelFamily.WHISPER -> LocalBackendId.WHISPER_KIT + LocalModelFamily.PARAKEET -> LocalBackendId.PARAKEET_APPLE + }, + ) + + if (backend == null || !backend.supportsPathLoading) { + setError(LocalRuntimeErrorCode.BACKEND_UNAVAILABLE, null) + transitionTo(LocalRuntimeState.ERROR) + return + } + + clearError() + transitionTo(LocalRuntimeState.LOADING) + + runCatching { + if (activeBackend != null && activeBackend !== backend) { + activeBackend?.unloadModel() + } + backend.loadModelFromPath(path) + activeBackend = backend + activeModel = null + observer?.onActiveModelChanged(null) + transitionTo(LocalRuntimeState.READY) + }.getOrElse { error -> + setError(LocalRuntimeErrorCode.LOAD_FAILED, error.message) + transitionTo(LocalRuntimeState.ERROR) + throw error + } + } + + suspend fun transcribe(request: TranscriptionRequest): TranscriptionResult { + val backend = activeBackend + if (backend == null) { + setError(LocalRuntimeErrorCode.MODEL_NOT_INSTALLED, "No model is loaded") + transitionTo(LocalRuntimeState.ERROR) + error("No local transcription backend is loaded") + } + if (request.audioData.isEmpty()) { + setError(LocalRuntimeErrorCode.INVALID_AUDIO_DATA, null) + throw IllegalArgumentException("Audio data must not be empty") + } + if (state == LocalRuntimeState.TRANSCRIBING) { + setError(LocalRuntimeErrorCode.TRANSCRIPTION_ALREADY_IN_PROGRESS, null) + throw IllegalStateException("Transcription already in progress") + } + + clearError() + 
transitionTo(LocalRuntimeState.TRANSCRIBING) + + return runCatching { + backend.transcribe(request) + }.onSuccess { + transitionTo(LocalRuntimeState.READY) + }.getOrElse { error -> + setError(LocalRuntimeErrorCode.TRANSCRIPTION_FAILED, error.message) + transitionTo(LocalRuntimeState.ERROR) + throw error + } + } + + suspend fun unloadModel() { + activeBackend?.unloadModel() + activeBackend = null + activeModel = null + observer?.onActiveModelChanged(null) + clearError() + transitionTo(LocalRuntimeState.UNLOADED) + } + + private fun requireModel(modelId: TranscriptionModelId): LocalModelDescriptor { + return LocalTranscriptionCatalog.model(platform, modelId) + ?: run { + setError(LocalRuntimeErrorCode.MODEL_NOT_FOUND, modelId.value) + error("Model ${modelId.value} not found") + } + } + + private fun transitionTo(newState: LocalRuntimeState) { + state = newState + observer?.onStateChanged(newState) + } + + private fun setError(errorCode: LocalRuntimeErrorCode, message: String?) { + lastErrorCode = errorCode + lastErrorMessage = message + observer?.onErrorChanged(errorCode, message) + } + + private fun clearError() { + lastErrorCode = null + lastErrorMessage = null + observer?.onErrorChanged(null, null) + } +} diff --git a/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntimeTest.kt b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntimeTest.kt new file mode 100644 index 0000000..5356718 --- /dev/null +++ b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntimeTest.kt @@ -0,0 +1,241 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import 
tech.watzon.pindrop.shared.core.ModelLanguageSupport +import tech.watzon.pindrop.shared.core.TranscriptionLanguage +import tech.watzon.pindrop.shared.core.TranscriptionModelId +import tech.watzon.pindrop.shared.core.TranscriptionRequest +import tech.watzon.pindrop.shared.core.TranscriptionResult + +class LocalTranscriptionRuntimeTest { + @Test + fun catalogPreservesRecommendedOrder() { + val models = LocalTranscriptionCatalog.recommendedModels( + platform = LocalPlatformId.MACOS, + language = TranscriptionLanguage.ENGLISH, + ) + + assertEquals( + listOf( + "openai_whisper-base.en", + "openai_whisper-small.en", + "openai_whisper-medium", + "openai_whisper-large-v3_turbo", + "parakeet-tdt-0.6b-v2", + ), + models.map { it.id.value }, + ) + } + + @Test + fun catalogMapsProviderByPlatform() { + val macWhisper = LocalTranscriptionCatalog.model(LocalPlatformId.MACOS, TranscriptionModelId("openai_whisper-base")) + val linuxWhisper = LocalTranscriptionCatalog.model(LocalPlatformId.LINUX, TranscriptionModelId("openai_whisper-base")) + val windowsParakeet = LocalTranscriptionCatalog.model(LocalPlatformId.WINDOWS, TranscriptionModelId("parakeet-tdt-0.6b-v3")) + + assertEquals(LocalModelProvider.WHISPER_KIT, macWhisper?.provider) + assertEquals(LocalModelProvider.WCPP, linuxWhisper?.provider) + assertEquals(LocalModelProvider.PARAKEET_NATIVE, windowsParakeet?.provider) + } + + @Test + fun startupResolutionMatchesSelectedFallbackRules() = runTest { + val runtime = runtimeWith( + installed = listOf( + InstalledModelRecord( + modelId = TranscriptionModelId("openai_whisper-base"), + state = ModelInstallState.INSTALLED, + storage = ModelStorageLayout("/tmp", "/tmp/base"), + installedProvider = LocalModelProvider.WHISPER_KIT, + ), + ), + ) + + runtime.refreshInstalledModels() + val resolution = runtime.resolveStartupModel( + selectedModelId = TranscriptionModelId("missing"), + defaultModelId = TranscriptionModelId("openai_whisper-base.en"), + ) + + 
assertEquals(LocalModelSelectionAction.LOAD_FALLBACK, resolution.action) + assertEquals("openai_whisper-base", resolution.updatedSelectedModelId.value) + } + + @Test + fun runtimeLoadsTranscribesAndUnloadsThroughBackend() = runTest { + val backend = FakeBackend() + val runtime = runtimeWith( + installed = listOf( + InstalledModelRecord( + modelId = TranscriptionModelId("openai_whisper-base"), + state = ModelInstallState.INSTALLED, + storage = ModelStorageLayout("/tmp", "/tmp/base"), + installedProvider = LocalModelProvider.WHISPER_KIT, + ), + ), + backendRegistry = FakeBackendRegistry( + preferredByModelId = mapOf("openai_whisper-base" to LocalBackendId.WHISPER_KIT), + backends = mapOf(LocalBackendId.WHISPER_KIT to backend), + ), + ) + + runtime.refreshInstalledModels() + runtime.loadModel(TranscriptionModelId("openai_whisper-base")) + val result = runtime.transcribe(TranscriptionRequest(audioData = byteArrayOf(1, 2, 3))) + runtime.unloadModel() + + assertEquals("ok", result.text) + assertEquals(LocalRuntimeState.UNLOADED, runtime.state) + assertEquals(1, backend.transcribeCalls) + } + + @Test + fun runtimeRejectsEmptyAudio() = runTest { + val backend = FakeBackend() + val runtime = runtimeWith( + installed = listOf( + InstalledModelRecord( + modelId = TranscriptionModelId("openai_whisper-base"), + state = ModelInstallState.INSTALLED, + storage = ModelStorageLayout("/tmp", "/tmp/base"), + installedProvider = LocalModelProvider.WHISPER_KIT, + ), + ), + backendRegistry = FakeBackendRegistry( + preferredByModelId = mapOf("openai_whisper-base" to LocalBackendId.WHISPER_KIT), + backends = mapOf(LocalBackendId.WHISPER_KIT to backend), + ), + ) + + runtime.refreshInstalledModels() + runtime.loadModel(TranscriptionModelId("openai_whisper-base")) + + assertFailsWith { + runtime.transcribe(TranscriptionRequest(audioData = byteArrayOf())) + } + assertEquals(LocalRuntimeErrorCode.INVALID_AUDIO_DATA, runtime.lastErrorCode) + } + + @Test + fun runtimeReportsBackendUnavailable() 
= runTest { + val runtime = runtimeWith( + installed = listOf( + InstalledModelRecord( + modelId = TranscriptionModelId("parakeet-tdt-0.6b-v3"), + state = ModelInstallState.INSTALLED, + storage = ModelStorageLayout("/tmp", "/tmp/v3"), + installedProvider = LocalModelProvider.PARAKEET_NATIVE, + ), + ), + backendRegistry = FakeBackendRegistry( + preferredByModelId = emptyMap(), + backends = emptyMap(), + ), + ) + + runtime.refreshInstalledModels() + runtime.loadModel(TranscriptionModelId("parakeet-tdt-0.6b-v3")) + + assertEquals(LocalRuntimeState.ERROR, runtime.state) + assertEquals(LocalRuntimeErrorCode.BACKEND_UNAVAILABLE, runtime.lastErrorCode) + } + + @Test + fun languageSupportMatchesCurrentSemantics() { + assertNotNull( + LocalTranscriptionCatalog.recommendedModels(LocalPlatformId.MACOS, TranscriptionLanguage.SPANISH) + .firstOrNull { it.languageSupport == ModelLanguageSupport.PARAKEET_V3_EUROPEAN }, + ) + assertNull( + LocalTranscriptionCatalog.recommendedModels(LocalPlatformId.MACOS, TranscriptionLanguage.SPANISH) + .firstOrNull { it.languageSupport == ModelLanguageSupport.ENGLISH_ONLY }, + ) + } + + private fun runtimeWith( + installed: List, + backendRegistry: BackendRegistryPort = FakeBackendRegistry( + preferredByModelId = mapOf("openai_whisper-base" to LocalBackendId.WHISPER_KIT), + backends = mapOf(LocalBackendId.WHISPER_KIT to FakeBackend()), + ), + ): LocalTranscriptionRuntime { + return LocalTranscriptionRuntime( + platform = LocalPlatformId.MACOS, + installedModelIndex = FakeInstalledModelIndex(installed), + modelInstaller = FakeInstaller(installed.toMutableList()), + backendRegistry = backendRegistry, + ) + } +} + +private class FakeInstalledModelIndex( + private val installed: List, +) : InstalledModelIndexPort { + override suspend fun refreshInstalledModels(): List = installed +} + +private class FakeInstaller( + private val installed: MutableList, +) : ModelInstallerPort { + override suspend fun installModel( + model: LocalModelDescriptor, + 
onProgress: (ModelInstallProgress) -> Unit, + ): InstalledModelRecord { + onProgress( + ModelInstallProgress( + modelId = model.id, + progress = 1.0, + state = ModelInstallState.INSTALLED, + ), + ) + return InstalledModelRecord( + modelId = model.id, + state = ModelInstallState.INSTALLED, + storage = ModelStorageLayout("/tmp", "/tmp/${model.id.value}"), + installedProvider = model.provider, + ).also(installed::add) + } + + override suspend fun deleteModel(model: LocalModelDescriptor) { + installed.removeAll { it.modelId == model.id } + } +} + +private class FakeBackend : LocalInferenceBackendPort { + override val backendId: LocalBackendId = LocalBackendId.WHISPER_KIT + override val supportedFamilies: Set = setOf(LocalModelFamily.WHISPER, LocalModelFamily.PARAKEET) + override val supportsPathLoading: Boolean = true + var transcribeCalls: Int = 0 + + override suspend fun loadModel( + model: LocalModelDescriptor, + installedRecord: InstalledModelRecord?, + ) = Unit + + override suspend fun loadModelFromPath(path: String) = Unit + + override suspend fun transcribe(request: TranscriptionRequest): TranscriptionResult { + transcribeCalls += 1 + return TranscriptionResult(text = "ok") + } + + override suspend fun unloadModel() = Unit +} + +private class FakeBackendRegistry( + private val preferredByModelId: Map, + private val backends: Map, +) : BackendRegistryPort { + override fun preferredBackend(model: LocalModelDescriptor): LocalBackendId? { + return preferredByModelId[model.id.value] + } + + override fun backend(id: LocalBackendId): LocalInferenceBackendPort? 
{ + return backends[id] + } +} diff --git a/shared/settings.gradle.kts b/shared/settings.gradle.kts index e4d0f4b..b58cec9 100644 --- a/shared/settings.gradle.kts +++ b/shared/settings.gradle.kts @@ -17,6 +17,7 @@ rootProject.name = "pindrop-shared" include(":core") include(":feature-transcription") +include(":runtime-transcription") include(":ui-shell") include(":ui-settings") include(":ui-theme") From 7504129c2252267560ae23aebd1a11f8a4a75759 Mon Sep 17 00:00:00 2001 From: Chris Watson Date: Sat, 28 Mar 2026 23:09:36 -0600 Subject: [PATCH 2/5] Add local transcription session coordination - Add shared VoiceSessionCoordinator for recording, permissions, and results - Introduce filesystem-backed model storage and Okio test support - Cover startup, recording, transcription, and storage flows with tests --- shared/feature-transcription/build.gradle.kts | 1 + .../transcription/VoiceSessionCoordinator.kt | 491 +++++++++++++++++ .../VoiceSessionCoordinatorTest.kt | 510 ++++++++++++++++++ shared/runtime-transcription/build.gradle.kts | 2 + .../transcription/FileSystemModelStorage.kt | 229 ++++++++ .../LocalTranscriptionCatalog.kt | 61 ++- .../LocalTranscriptionContracts.kt | 1 + .../LocalTranscriptionRuntime.kt | 18 +- .../FileSystemModelStorageTest.kt | 175 ++++++ 9 files changed, 1469 insertions(+), 19 deletions(-) create mode 100644 shared/feature-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/feature/transcription/VoiceSessionCoordinator.kt create mode 100644 shared/feature-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/feature/transcription/VoiceSessionCoordinatorTest.kt create mode 100644 shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/FileSystemModelStorage.kt create mode 100644 shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/FileSystemModelStorageTest.kt diff --git a/shared/feature-transcription/build.gradle.kts 
b/shared/feature-transcription/build.gradle.kts index e91705f..0f0e2f4 100644 --- a/shared/feature-transcription/build.gradle.kts +++ b/shared/feature-transcription/build.gradle.kts @@ -36,6 +36,7 @@ kotlin { } commonTest.dependencies { implementation(kotlin("test")) + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.10.2") } } } diff --git a/shared/feature-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/feature/transcription/VoiceSessionCoordinator.kt b/shared/feature-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/feature/transcription/VoiceSessionCoordinator.kt new file mode 100644 index 0000000..c2116d6 --- /dev/null +++ b/shared/feature-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/feature/transcription/VoiceSessionCoordinator.kt @@ -0,0 +1,491 @@ +package tech.watzon.pindrop.shared.feature.transcription + +import tech.watzon.pindrop.shared.core.TranscriptionLanguage +import tech.watzon.pindrop.shared.core.TranscriptionModelId +import tech.watzon.pindrop.shared.core.TranscriptionRequest +import tech.watzon.pindrop.shared.runtime.transcription.LocalModelSelectionAction +import tech.watzon.pindrop.shared.runtime.transcription.LocalModelSelectionResolution +import tech.watzon.pindrop.shared.runtime.transcription.LocalRuntimeErrorCode +import tech.watzon.pindrop.shared.runtime.transcription.LocalTranscriptionRuntime + +enum class VoiceOutputMode { + CLIPBOARD, + DIRECT_INSERT, +} + +enum class PermissionStatus { + GRANTED, + DENIED, + NOT_DETERMINED, + RESTRICTED, + UNSUPPORTED, +} + +enum class HotkeyModifier { + CTRL, + ALT, + SHIFT, + META, +} + +enum class HotkeyMode { + TOGGLE, + PUSH_TO_TALK, +} + +data class HotkeyBinding( + val key: String, + val modifiers: Set = emptySet(), +) + +data class VoiceSettingsSnapshot( + val selectedModelId: TranscriptionModelId, + val selectedLanguage: TranscriptionLanguage = TranscriptionLanguage.AUTOMATIC, + val preferredInputDeviceId: String? 
= null, + val outputMode: VoiceOutputMode = VoiceOutputMode.CLIPBOARD, + val toggleHotkey: HotkeyBinding? = null, + val pushToTalkHotkey: HotkeyBinding? = null, + val launchOnStartupEnabled: Boolean = false, + val hasCompletedOnboarding: Boolean = false, +) + +data class TranscriptHistoryEntry( + val id: String, + val timestampEpochMillis: Long, + val text: String, + val durationMs: Long, + val modelId: TranscriptionModelId, +) + +enum class VoiceSessionState { + IDLE, + STARTING, + RECORDING, + PROCESSING, + COMPLETED, + ERROR, +} + +enum class VoiceSessionError { + MICROPHONE_PERMISSION_DENIED, + AUDIO_START_FAILED, + AUDIO_STOP_FAILED, + MODEL_NOT_INSTALLED, + MODEL_LOAD_FAILED, + TRANSCRIPTION_FAILED, + CLIPBOARD_WRITE_FAILED, + UNSUPPORTED_PLATFORM_INTEGRATION, +} + +data class VoiceSessionUiState( + val state: VoiceSessionState, + val activeModelId: TranscriptionModelId? = null, + val requiresModelInstallation: Boolean = false, + val canRecord: Boolean = true, + val message: String? = null, +) + +data class VoiceSessionBootstrapResult( + val settings: VoiceSettingsSnapshot, + val startupModel: LocalModelSelectionResolution?, + val requiresModelInstallation: Boolean, +) + +enum class VoiceSessionStopReason { + TRANSCRIPT_READY, + NO_SPEECH_DETECTED, + MODEL_INSTALL_REQUIRED, + FAILED, +} + +data class VoiceSessionStopResult( + val reason: VoiceSessionStopReason, + val transcript: String? = null, + val modelId: TranscriptionModelId? = null, + val durationMs: Long = 0, +) + +interface AudioCapturePort { + suspend fun startCapture() + suspend fun stopCapture(): ByteArray + suspend fun cancelCapture() + fun isCapturing(): Boolean + fun setPreferredInputDevice(deviceId: String?) 
+} + +interface ClipboardPort { + fun copyText(text: String): Boolean +} + +interface HotkeyRegistrationPort { + fun register(binding: HotkeyBinding, actionId: String, mode: HotkeyMode) + fun unregisterAll() +} + +interface SettingsStorePort { + fun load(): VoiceSettingsSnapshot + fun save(snapshot: VoiceSettingsSnapshot) +} + +interface TranscriptHistoryPort { + suspend fun save(entry: TranscriptHistoryEntry) + suspend fun latest(): TranscriptHistoryEntry? +} + +interface PermissionPort { + suspend fun microphoneStatus(): PermissionStatus + suspend fun requestMicrophonePermission(): PermissionStatus +} + +interface VoiceSessionEventSink { + fun onStateChanged(state: VoiceSessionUiState) + fun onError(error: VoiceSessionError) + fun onTranscriptReady(text: String) +} + +fun interface TimestampProvider { + fun nowEpochMillis(): Long +} + +class VoiceSessionCoordinator( + private val runtime: LocalTranscriptionRuntime, + private val audioCapture: AudioCapturePort, + private val clipboard: ClipboardPort, + private val permissions: PermissionPort, + private val settingsStore: SettingsStorePort, + private val eventSink: VoiceSessionEventSink, + private val history: TranscriptHistoryPort? = null, + private val timestampProvider: TimestampProvider = TimestampProvider { 0L }, + private val supportsDirectInsert: Boolean = false, +) { + private var currentSettings: VoiceSettingsSnapshot? = null + private var startupModel: LocalModelSelectionResolution? = null + private var activeModelId: TranscriptionModelId? = null + private var recordingStartedAtEpochMillis: Long? 
= null + private var initialized = false + + suspend fun initialize(): VoiceSessionBootstrapResult { + val settings = settingsStore.load() + currentSettings = settings + runtime.refreshInstalledModels() + + val resolution = runtime.catalog().takeIf { it.isNotEmpty() }?.let { + runtime.resolveStartupModel( + selectedModelId = settings.selectedModelId, + defaultModelId = settings.selectedModelId, + ) + } + startupModel = resolution + activeModelId = when (resolution?.action) { + LocalModelSelectionAction.LOAD_SELECTED, + LocalModelSelectionAction.LOAD_FALLBACK, + -> resolution.updatedSelectedModelId + LocalModelSelectionAction.DOWNLOAD_SELECTED, + null, + -> null + } + initialized = true + + val requiresModelInstallation = resolution?.action == LocalModelSelectionAction.DOWNLOAD_SELECTED + eventSink.onStateChanged( + VoiceSessionUiState( + state = VoiceSessionState.IDLE, + activeModelId = activeModelId, + requiresModelInstallation = requiresModelInstallation, + canRecord = !requiresModelInstallation, + message = if (requiresModelInstallation) { + "Install a local transcription model to begin." 
+ } else { + null + }, + ), + ) + + return VoiceSessionBootstrapResult( + settings = settings, + startupModel = resolution, + requiresModelInstallation = requiresModelInstallation, + ) + } + + suspend fun startRecording(): Boolean { + ensureInitialized() + + if (audioCapture.isCapturing()) { + eventSink.onError(VoiceSessionError.AUDIO_START_FAILED) + return false + } + + val settings = currentSettings ?: settingsStore.load().also { currentSettings = it } + if (startupModel?.action == LocalModelSelectionAction.DOWNLOAD_SELECTED || activeModelId == null) { + transitionToError( + VoiceSessionError.MODEL_NOT_INSTALLED, + message = "Install a local transcription model before recording.", + ) + return false + } + + val microphoneStatus = permissions.microphoneStatus() + val resolvedPermission = when (microphoneStatus) { + PermissionStatus.NOT_DETERMINED -> permissions.requestMicrophonePermission() + else -> microphoneStatus + } + if (resolvedPermission != PermissionStatus.GRANTED) { + transitionToError( + VoiceSessionError.MICROPHONE_PERMISSION_DENIED, + message = "Microphone permission is required to start recording.", + ) + return false + } + + audioCapture.setPreferredInputDevice(settings.preferredInputDeviceId) + eventSink.onStateChanged( + VoiceSessionUiState( + state = VoiceSessionState.STARTING, + activeModelId = activeModelId, + message = "Starting microphone capture…", + ), + ) + + return runCatching { + audioCapture.startCapture() + recordingStartedAtEpochMillis = timestampProvider.nowEpochMillis() + eventSink.onStateChanged( + VoiceSessionUiState( + state = VoiceSessionState.RECORDING, + activeModelId = activeModelId, + ), + ) + true + }.getOrElse { + transitionToError( + VoiceSessionError.AUDIO_START_FAILED, + message = "Unable to start microphone capture.", + ) + false + } + } + + suspend fun stopRecording(): VoiceSessionStopResult { + ensureInitialized() + + if (!audioCapture.isCapturing()) { + transitionToError( + VoiceSessionError.AUDIO_STOP_FAILED, + 
message = "Recording is not active.", + ) + return VoiceSessionStopResult(reason = VoiceSessionStopReason.FAILED) + } + + val settings = currentSettings ?: settingsStore.load().also { currentSettings = it } + val modelId = activeModelId ?: startupModel?.updatedSelectedModelId + if (modelId == null) { + transitionToError( + VoiceSessionError.MODEL_NOT_INSTALLED, + message = "Install a local transcription model before recording.", + ) + return VoiceSessionStopResult(reason = VoiceSessionStopReason.MODEL_INSTALL_REQUIRED) + } + + eventSink.onStateChanged( + VoiceSessionUiState( + state = VoiceSessionState.PROCESSING, + activeModelId = modelId, + message = "Transcribing locally…", + ), + ) + + val audioData = runCatching { audioCapture.stopCapture() }.getOrElse { + transitionToError( + VoiceSessionError.AUDIO_STOP_FAILED, + activeModelId = modelId, + message = "Unable to stop microphone capture.", + ) + return VoiceSessionStopResult(reason = VoiceSessionStopReason.FAILED) + } + + val durationMs = durationSinceStart() + if (audioData.isEmpty()) { + val result = VoiceSessionStopResult( + reason = VoiceSessionStopReason.NO_SPEECH_DETECTED, + modelId = modelId, + durationMs = durationMs, + ) + transitionToCompleted( + activeModelId = modelId, + message = "No speech detected.", + ) + return result + } + + if (runtime.activeModel?.descriptor?.id != modelId) { + runCatching { runtime.loadModel(modelId) }.getOrElse { + transitionToError( + mapRuntimeError(runtime.lastErrorCode), + activeModelId = modelId, + message = "Unable to load the selected transcription model.", + ) + return VoiceSessionStopResult(reason = VoiceSessionStopReason.FAILED) + } + activeModelId = modelId + } + + val normalizedTranscript = runCatching { + runtime.transcribe( + TranscriptionRequest( + audioData = audioData, + language = settings.selectedLanguage, + ), + ).text + }.map { + SharedTranscriptionOrchestrator.normalizeTranscriptionText(it) + }.getOrElse { + transitionToError( + 
mapRuntimeError(runtime.lastErrorCode), + activeModelId = modelId, + message = "Transcription failed.", + ) + return VoiceSessionStopResult(reason = VoiceSessionStopReason.FAILED) + } + + if (SharedTranscriptionOrchestrator.isTranscriptionEffectivelyEmpty(normalizedTranscript)) { + val result = VoiceSessionStopResult( + reason = VoiceSessionStopReason.NO_SPEECH_DETECTED, + modelId = modelId, + durationMs = durationMs, + ) + transitionToCompleted( + activeModelId = modelId, + message = "No speech detected.", + ) + return result + } + + val copied = when (settings.outputMode) { + VoiceOutputMode.CLIPBOARD -> clipboard.copyText(normalizedTranscript) + VoiceOutputMode.DIRECT_INSERT -> clipboard.copyText(normalizedTranscript) + } + + if (!copied) { + transitionToError( + VoiceSessionError.CLIPBOARD_WRITE_FAILED, + activeModelId = modelId, + message = "Transcription completed, but copying to the clipboard failed.", + ) + return VoiceSessionStopResult(reason = VoiceSessionStopReason.FAILED) + } + + history?.save( + TranscriptHistoryEntry( + id = generatedTranscriptId(), + timestampEpochMillis = timestampProvider.nowEpochMillis(), + text = normalizedTranscript, + durationMs = durationMs, + modelId = modelId, + ), + ) + + eventSink.onTranscriptReady(normalizedTranscript) + transitionToCompleted( + activeModelId = modelId, + message = if (settings.outputMode == VoiceOutputMode.DIRECT_INSERT && !supportsDirectInsert) { + "Direct insert is unavailable on this platform. Copied transcript to the clipboard instead." + } else { + "Copied transcript to the clipboard." 
+ }, + ) + + return VoiceSessionStopResult( + reason = VoiceSessionStopReason.TRANSCRIPT_READY, + transcript = normalizedTranscript, + modelId = modelId, + durationMs = durationMs, + ) + } + + suspend fun cancelRecording() { + if (!audioCapture.isCapturing()) { + return + } + + audioCapture.cancelCapture() + recordingStartedAtEpochMillis = null + eventSink.onStateChanged( + VoiceSessionUiState( + state = VoiceSessionState.IDLE, + activeModelId = activeModelId, + message = "Recording canceled.", + ), + ) + } + + fun isRecording(): Boolean = audioCapture.isCapturing() + + private suspend fun ensureInitialized() { + if (!initialized) { + initialize() + } + } + + private fun transitionToCompleted( + activeModelId: TranscriptionModelId?, + message: String, + ) { + recordingStartedAtEpochMillis = null + eventSink.onStateChanged( + VoiceSessionUiState( + state = VoiceSessionState.COMPLETED, + activeModelId = activeModelId, + message = message, + ), + ) + } + + private fun transitionToError( + error: VoiceSessionError, + activeModelId: TranscriptionModelId? 
= this.activeModelId, + message: String, + ) { + recordingStartedAtEpochMillis = null + eventSink.onError(error) + eventSink.onStateChanged( + VoiceSessionUiState( + state = VoiceSessionState.ERROR, + activeModelId = activeModelId, + canRecord = error != VoiceSessionError.MODEL_NOT_INSTALLED, + requiresModelInstallation = error == VoiceSessionError.MODEL_NOT_INSTALLED, + message = message, + ), + ) + } + + private fun durationSinceStart(): Long { + val startedAt = recordingStartedAtEpochMillis ?: return 0 + return (timestampProvider.nowEpochMillis() - startedAt).coerceAtLeast(0) + } + + private fun mapRuntimeError(errorCode: LocalRuntimeErrorCode?): VoiceSessionError { + return when (errorCode) { + LocalRuntimeErrorCode.MODEL_NOT_FOUND, + LocalRuntimeErrorCode.MODEL_NOT_INSTALLED, + -> VoiceSessionError.MODEL_NOT_INSTALLED + LocalRuntimeErrorCode.LOAD_FAILED, + LocalRuntimeErrorCode.BACKEND_UNAVAILABLE, + LocalRuntimeErrorCode.UNSUPPORTED_ON_PLATFORM, + -> VoiceSessionError.MODEL_LOAD_FAILED + LocalRuntimeErrorCode.TRANSCRIPTION_FAILED, + LocalRuntimeErrorCode.INVALID_AUDIO_DATA, + LocalRuntimeErrorCode.TRANSCRIPTION_ALREADY_IN_PROGRESS, + LocalRuntimeErrorCode.ENGINE_SWITCH_DURING_TRANSCRIPTION, + LocalRuntimeErrorCode.INSTALL_FAILED, + LocalRuntimeErrorCode.DELETE_FAILED, + null, + -> VoiceSessionError.TRANSCRIPTION_FAILED + } + } + + private fun generatedTranscriptId(): String { + val modelComponent = activeModelId?.value ?: "transcript" + return "$modelComponent-${timestampProvider.nowEpochMillis()}" + } +} diff --git a/shared/feature-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/feature/transcription/VoiceSessionCoordinatorTest.kt b/shared/feature-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/feature/transcription/VoiceSessionCoordinatorTest.kt new file mode 100644 index 0000000..80ef6cb --- /dev/null +++ 
b/shared/feature-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/feature/transcription/VoiceSessionCoordinatorTest.kt @@ -0,0 +1,510 @@ +package tech.watzon.pindrop.shared.feature.transcription + +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue +import tech.watzon.pindrop.shared.core.ModelLanguageSupport +import tech.watzon.pindrop.shared.core.TranscriptionLanguage +import tech.watzon.pindrop.shared.core.TranscriptionModelId +import tech.watzon.pindrop.shared.core.TranscriptionRequest +import tech.watzon.pindrop.shared.core.TranscriptionResult +import tech.watzon.pindrop.shared.runtime.transcription.BackendRegistryPort +import tech.watzon.pindrop.shared.runtime.transcription.InstalledModelIndexPort +import tech.watzon.pindrop.shared.runtime.transcription.InstalledModelRecord +import tech.watzon.pindrop.shared.runtime.transcription.LocalBackendId +import tech.watzon.pindrop.shared.runtime.transcription.LocalInferenceBackendPort +import tech.watzon.pindrop.shared.runtime.transcription.LocalModelDescriptor +import tech.watzon.pindrop.shared.runtime.transcription.LocalModelFamily +import tech.watzon.pindrop.shared.runtime.transcription.LocalModelProvider +import tech.watzon.pindrop.shared.runtime.transcription.LocalPlatformId +import tech.watzon.pindrop.shared.runtime.transcription.LocalRuntimeErrorCode +import tech.watzon.pindrop.shared.runtime.transcription.LocalTranscriptionRuntime +import tech.watzon.pindrop.shared.runtime.transcription.ModelInstallProgress +import tech.watzon.pindrop.shared.runtime.transcription.ModelInstallState +import tech.watzon.pindrop.shared.runtime.transcription.ModelInstallerPort +import tech.watzon.pindrop.shared.runtime.transcription.ModelStorageLayout + +class VoiceSessionCoordinatorTest { + @Test + fun initializeUsesInstalledSelectedModel() = 
runTest { + val fixture = fixture( + settings = defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + ) + + val result = fixture.coordinator.initialize() + + assertFalse(result.requiresModelInstallation) + assertEquals("openai_whisper-base", result.startupModel?.updatedSelectedModelId?.value) + assertEquals(VoiceSessionState.IDLE, fixture.eventSink.states.last().state) + } + + @Test + fun initializeFallsBackToInstalledModelWhenSelectionIsMissing() = runTest { + val fixture = fixture( + settings = defaultSettings(selectedModelId = "missing-model"), + installed = listOf(installedRecord("openai_whisper-base")), + ) + + val result = fixture.coordinator.initialize() + + assertFalse(result.requiresModelInstallation) + assertEquals("openai_whisper-base", result.startupModel?.updatedSelectedModelId?.value) + } + + @Test + fun initializeRequiresInstallationWhenNoModelIsInstalled() = runTest { + val fixture = fixture(settings = defaultSettings(selectedModelId = "openai_whisper-base")) + + val result = fixture.coordinator.initialize() + + assertTrue(result.requiresModelInstallation) + assertTrue(fixture.eventSink.states.last().requiresModelInstallation) + } + + @Test + fun startRecordingRequestsPermissionLazily() = runTest { + val permissions = FakePermissionPort( + status = PermissionStatus.NOT_DETERMINED, + requestedStatus = PermissionStatus.GRANTED, + ) + val fixture = fixture( + settings = defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + permissions = permissions, + ) + fixture.coordinator.initialize() + + val didStart = fixture.coordinator.startRecording() + + assertTrue(didStart) + assertEquals(1, permissions.requestCalls) + assertEquals(VoiceSessionState.RECORDING, fixture.eventSink.states.last().state) + } + + @Test + fun startRecordingFailsWhenPermissionDenied() = runTest { + val fixture = fixture( + settings = 
defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + permissions = FakePermissionPort(status = PermissionStatus.DENIED), + ) + fixture.coordinator.initialize() + + val didStart = fixture.coordinator.startRecording() + + assertFalse(didStart) + assertEquals(VoiceSessionError.MICROPHONE_PERMISSION_DENIED, fixture.eventSink.errors.single()) + assertEquals(VoiceSessionState.ERROR, fixture.eventSink.states.last().state) + } + + @Test + fun stopWithoutActiveRecordingFails() = runTest { + val fixture = fixture( + settings = defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + ) + fixture.coordinator.initialize() + + val result = fixture.coordinator.stopRecording() + + assertEquals(VoiceSessionStopReason.FAILED, result.reason) + assertEquals(VoiceSessionError.AUDIO_STOP_FAILED, fixture.eventSink.errors.single()) + } + + @Test + fun successfulTranscriptionCopiesTranscriptAndSavesHistory() = runTest { + val history = FakeTranscriptHistoryPort() + val fixture = fixture( + settings = defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + audioCapture = FakeAudioCapturePort(stopAudio = byteArrayOf(1, 2, 3)), + history = history, + timestampProvider = FakeTimestampProvider(now = 2_000L), + ) + fixture.coordinator.initialize() + fixture.coordinator.startRecording() + fixture.timestampProvider.now = 3_250L + + val result = fixture.coordinator.stopRecording() + + assertEquals(VoiceSessionStopReason.TRANSCRIPT_READY, result.reason) + assertEquals("transcribed text", result.transcript) + assertEquals("transcribed text", fixture.clipboard.lastCopiedText) + assertEquals("transcribed text", fixture.eventSink.transcripts.single()) + assertEquals("transcribed text", history.latest()?.text) + assertEquals(1_250L, result.durationMs) + assertEquals(VoiceSessionState.COMPLETED, 
fixture.eventSink.states.last().state) + } + + @Test + fun emptyAudioReturnsNoSpeechDetected() = runTest { + val fixture = fixture( + settings = defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + audioCapture = FakeAudioCapturePort(stopAudio = byteArrayOf()), + ) + fixture.coordinator.initialize() + fixture.coordinator.startRecording() + + val result = fixture.coordinator.stopRecording() + + assertEquals(VoiceSessionStopReason.NO_SPEECH_DETECTED, result.reason) + assertNull(fixture.clipboard.lastCopiedText) + assertTrue(fixture.eventSink.states.last().message?.contains("No speech detected") == true) + } + + @Test + fun blankTranscriptReturnsNoSpeechDetected() = runTest { + val backend = FakeBackend(transcript = "[BLANK AUDIO]") + val fixture = fixture( + settings = defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + backend = backend, + audioCapture = FakeAudioCapturePort(stopAudio = byteArrayOf(9, 9, 9)), + ) + fixture.coordinator.initialize() + fixture.coordinator.startRecording() + + val result = fixture.coordinator.stopRecording() + + assertEquals(VoiceSessionStopReason.NO_SPEECH_DETECTED, result.reason) + assertNull(fixture.clipboard.lastCopiedText) + } + + @Test + fun clipboardFailureReturnsFailedResult() = runTest { + val fixture = fixture( + settings = defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + clipboard = FakeClipboardPort(shouldCopySucceed = false), + audioCapture = FakeAudioCapturePort(stopAudio = byteArrayOf(1, 2, 3)), + ) + fixture.coordinator.initialize() + fixture.coordinator.startRecording() + + val result = fixture.coordinator.stopRecording() + + assertEquals(VoiceSessionStopReason.FAILED, result.reason) + assertEquals(VoiceSessionError.CLIPBOARD_WRITE_FAILED, fixture.eventSink.errors.last()) + } + + @Test + fun 
modelLoadFailureIsSurfacedAsError() = runTest { + val backend = FakeBackend(loadError = IllegalStateException("load failed")) + val fixture = fixture( + settings = defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + backend = backend, + audioCapture = FakeAudioCapturePort(stopAudio = byteArrayOf(1, 2, 3)), + ) + fixture.coordinator.initialize() + fixture.coordinator.startRecording() + + val result = fixture.coordinator.stopRecording() + + assertEquals(VoiceSessionStopReason.FAILED, result.reason) + assertEquals(VoiceSessionError.MODEL_LOAD_FAILED, fixture.eventSink.errors.last()) + } + + @Test + fun concurrentStartRequestIsRejected() = runTest { + val fixture = fixture( + settings = defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + ) + fixture.coordinator.initialize() + + val first = fixture.coordinator.startRecording() + val second = fixture.coordinator.startRecording() + + assertTrue(first) + assertFalse(second) + assertEquals(VoiceSessionError.AUDIO_START_FAILED, fixture.eventSink.errors.last()) + } + + @Test + fun cancelWhileRecordingResetsToIdle() = runTest { + val fixture = fixture( + settings = defaultSettings(selectedModelId = "openai_whisper-base"), + installed = listOf(installedRecord("openai_whisper-base")), + ) + fixture.coordinator.initialize() + fixture.coordinator.startRecording() + + fixture.coordinator.cancelRecording() + + assertFalse(fixture.audioCapture.isCapturing()) + assertEquals(VoiceSessionState.IDLE, fixture.eventSink.states.last().state) + } + + @Test + fun settingsAreRestoredOnInitialize() = runTest { + val settings = defaultSettings( + selectedModelId = "openai_whisper-base", + selectedLanguage = TranscriptionLanguage.GERMAN, + preferredInputDeviceId = "usb-mic", + outputMode = VoiceOutputMode.DIRECT_INSERT, + ) + val fixture = fixture( + settings = settings, + installed = 
listOf(installedRecord("openai_whisper-base")), + ) + + val result = fixture.coordinator.initialize() + + assertEquals(settings, result.settings) + fixture.coordinator.startRecording() + assertEquals("usb-mic", fixture.audioCapture.preferredInputDeviceId) + } + + private fun defaultSettings( + selectedModelId: String, + selectedLanguage: TranscriptionLanguage = TranscriptionLanguage.AUTOMATIC, + preferredInputDeviceId: String? = null, + outputMode: VoiceOutputMode = VoiceOutputMode.CLIPBOARD, + ): VoiceSettingsSnapshot { + return VoiceSettingsSnapshot( + selectedModelId = TranscriptionModelId(selectedModelId), + selectedLanguage = selectedLanguage, + preferredInputDeviceId = preferredInputDeviceId, + outputMode = outputMode, + ) + } + + private fun installedRecord(modelId: String): InstalledModelRecord { + return InstalledModelRecord( + modelId = TranscriptionModelId(modelId), + state = ModelInstallState.INSTALLED, + storage = ModelStorageLayout("/tmp", "/tmp/$modelId"), + installedProvider = LocalModelProvider.WCPP, + ) + } + + private fun fixture( + settings: VoiceSettingsSnapshot, + installed: List = emptyList(), + backend: FakeBackend = FakeBackend(), + audioCapture: FakeAudioCapturePort = FakeAudioCapturePort(stopAudio = byteArrayOf(1, 2, 3)), + clipboard: FakeClipboardPort = FakeClipboardPort(), + permissions: FakePermissionPort = FakePermissionPort(status = PermissionStatus.GRANTED), + history: FakeTranscriptHistoryPort? 
= null, + timestampProvider: FakeTimestampProvider = FakeTimestampProvider(now = 2_000L), + ): Fixture { + val backendRegistry = FakeBackendRegistry( + preferredByModelId = mapOf("openai_whisper-base" to LocalBackendId.WHISPER_CPP), + backends = mapOf(LocalBackendId.WHISPER_CPP to backend), + ) + val runtime = LocalTranscriptionRuntime( + platform = LocalPlatformId.LINUX, + installedModelIndex = FakeInstalledModelIndex(installed), + modelInstaller = FakeInstaller(installed.toMutableList()), + backendRegistry = backendRegistry, + ) + val eventSink = FakeVoiceSessionEventSink() + val settingsStore = FakeSettingsStorePort(settings) + val coordinator = VoiceSessionCoordinator( + runtime = runtime, + audioCapture = audioCapture, + clipboard = clipboard, + permissions = permissions, + settingsStore = settingsStore, + eventSink = eventSink, + history = history, + timestampProvider = timestampProvider, + ) + return Fixture( + coordinator = coordinator, + audioCapture = audioCapture, + clipboard = clipboard, + eventSink = eventSink, + timestampProvider = timestampProvider, + ) + } + + private data class Fixture( + val coordinator: VoiceSessionCoordinator, + val audioCapture: FakeAudioCapturePort, + val clipboard: FakeClipboardPort, + val eventSink: FakeVoiceSessionEventSink, + val timestampProvider: FakeTimestampProvider, + ) +} + +private class FakeSettingsStorePort( + private var snapshot: VoiceSettingsSnapshot, +) : SettingsStorePort { + override fun load(): VoiceSettingsSnapshot = snapshot + + override fun save(snapshot: VoiceSettingsSnapshot) { + this.snapshot = snapshot + } +} + +private class FakePermissionPort( + private val status: PermissionStatus, + private val requestedStatus: PermissionStatus = status, +) : PermissionPort { + var requestCalls: Int = 0 + + override suspend fun microphoneStatus(): PermissionStatus = status + + override suspend fun requestMicrophonePermission(): PermissionStatus { + requestCalls += 1 + return requestedStatus + } +} + +private 
class FakeAudioCapturePort( + private val stopAudio: ByteArray, +) : AudioCapturePort { + private var capturing = false + var preferredInputDeviceId: String? = null + + override suspend fun startCapture() { + if (capturing) { + error("already capturing") + } + capturing = true + } + + override suspend fun stopCapture(): ByteArray { + if (!capturing) { + error("not capturing") + } + capturing = false + return stopAudio + } + + override suspend fun cancelCapture() { + capturing = false + } + + override fun isCapturing(): Boolean = capturing + + override fun setPreferredInputDevice(deviceId: String?) { + preferredInputDeviceId = deviceId + } +} + +private class FakeClipboardPort( + private val shouldCopySucceed: Boolean = true, +) : ClipboardPort { + var lastCopiedText: String? = null + + override fun copyText(text: String): Boolean { + if (!shouldCopySucceed) { + return false + } + lastCopiedText = text + return true + } +} + +private class FakeTranscriptHistoryPort : TranscriptHistoryPort { + private val entries = mutableListOf() + + override suspend fun save(entry: TranscriptHistoryEntry) { + entries += entry + } + + override suspend fun latest(): TranscriptHistoryEntry? 
= entries.lastOrNull() +} + +private class FakeVoiceSessionEventSink : VoiceSessionEventSink { + val states = mutableListOf() + val errors = mutableListOf() + val transcripts = mutableListOf() + + override fun onStateChanged(state: VoiceSessionUiState) { + states += state + } + + override fun onError(error: VoiceSessionError) { + errors += error + } + + override fun onTranscriptReady(text: String) { + transcripts += text + } +} + +private class FakeTimestampProvider( + var now: Long, +) : TimestampProvider { + override fun nowEpochMillis(): Long = now +} + +private class FakeInstalledModelIndex( + private val installed: List, +) : InstalledModelIndexPort { + override suspend fun refreshInstalledModels(): List = installed +} + +private class FakeInstaller( + private val installed: MutableList, +) : ModelInstallerPort { + override suspend fun installModel( + model: LocalModelDescriptor, + onProgress: (ModelInstallProgress) -> Unit, + ): InstalledModelRecord { + onProgress( + ModelInstallProgress( + modelId = model.id, + progress = 1.0, + state = ModelInstallState.INSTALLED, + ), + ) + return InstalledModelRecord( + modelId = model.id, + state = ModelInstallState.INSTALLED, + storage = ModelStorageLayout("/tmp", "/tmp/${model.id.value}"), + installedProvider = model.provider, + ).also(installed::add) + } + + override suspend fun deleteModel(model: LocalModelDescriptor) { + installed.removeAll { it.modelId == model.id } + } +} + +private class FakeBackend( + private val transcript: String = "transcribed text", + private val loadError: Throwable? 
= null, +) : LocalInferenceBackendPort { + override val backendId: LocalBackendId = LocalBackendId.WHISPER_CPP + override val supportedFamilies: Set = setOf(LocalModelFamily.WHISPER) + override val supportsPathLoading: Boolean = true + + override suspend fun loadModel( + model: LocalModelDescriptor, + installedRecord: InstalledModelRecord?, + ) { + loadError?.let { throw it } + } + + override suspend fun loadModelFromPath(path: String) = Unit + + override suspend fun transcribe(request: TranscriptionRequest): TranscriptionResult { + return TranscriptionResult(text = transcript) + } + + override suspend fun unloadModel() = Unit +} + +private class FakeBackendRegistry( + private val preferredByModelId: Map, + private val backends: Map, +) : BackendRegistryPort { + override fun preferredBackend(model: LocalModelDescriptor): LocalBackendId? { + return preferredByModelId[model.id.value] + } + + override fun backend(id: LocalBackendId): LocalInferenceBackendPort? { + return backends[id] + } +} diff --git a/shared/runtime-transcription/build.gradle.kts b/shared/runtime-transcription/build.gradle.kts index c8b19c5..77682ae 100644 --- a/shared/runtime-transcription/build.gradle.kts +++ b/shared/runtime-transcription/build.gradle.kts @@ -34,10 +34,12 @@ kotlin { commonMain.dependencies { api(project(":core")) implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.10.2") + implementation("com.squareup.okio:okio:3.9.0") } commonTest.dependencies { implementation(kotlin("test")) implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.10.2") + implementation("com.squareup.okio:okio-fakefilesystem:3.9.0") } } } diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/FileSystemModelStorage.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/FileSystemModelStorage.kt new file mode 100644 index 0000000..95d56b2 --- /dev/null +++ 
b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/FileSystemModelStorage.kt @@ -0,0 +1,229 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import okio.FileSystem +import okio.Path +import okio.Path.Companion.toPath +import okio.buffer +import okio.use +import tech.watzon.pindrop.shared.core.TranscriptionModelId + +data class RemoteModelArtifact( + val fileName: String, + val downloadUrl: String, + val sizeBytes: Long? = null, + val sha256: String? = null, +) + +interface RemoteModelRepositoryPort { + fun artifactsFor(model: LocalModelDescriptor): List +} + +interface DownloadClientPort { + suspend fun download( + artifact: RemoteModelArtifact, + destination: Path, + onProgress: (bytesDownloaded: Long, totalBytes: Long?) -> Unit, + ) +} + +class FileSystemInstalledModelIndex( + private val fileSystem: FileSystem, + private val installRoot: Path, +) : InstalledModelIndexPort { + override suspend fun refreshInstalledModels(): List { + if (!fileSystem.exists(installRoot)) { + return emptyList() + } + + return fileSystem.list(installRoot) + .filter { candidate -> + fileSystem.metadataOrNull(candidate)?.isDirectory == true && + !candidate.name.startsWith(".") + } + .map(::recordForModelDirectory) + } + + private fun recordForModelDirectory(modelDirectory: Path): InstalledModelRecord { + val modelId = TranscriptionModelId(modelDirectory.name) + val provider = readProvider(modelDirectory) + val modelFile = fileSystem.list(modelDirectory) + .firstOrNull { child -> + fileSystem.metadataOrNull(child)?.isRegularFile == true && + child.name !in RESERVED_FILE_NAMES + } + + val installState = when { + fileSystem.exists(modelDirectory / INSTALL_FAILED_FILE_NAME) -> ModelInstallState.FAILED + fileSystem.exists(modelDirectory / INSTALL_COMPLETE_FILE_NAME) -> ModelInstallState.INSTALLED + else -> ModelInstallState.NOT_INSTALLED + } + + return InstalledModelRecord( + modelId = modelId, + state = installState, + 
storage = ModelStorageLayout( + installRootPath = modelDirectory.toString(), + modelPath = modelFile?.toString(), + ), + installedProvider = provider, + lastError = readOptionalText(modelDirectory / INSTALL_FAILED_FILE_NAME), + ) + } + + private fun readProvider(modelDirectory: Path): LocalModelProvider? { + val providerText = readOptionalText(modelDirectory / PROVIDER_FILE_NAME) ?: return null + return LocalModelProvider.entries.firstOrNull { it.name == providerText } + } + + private fun readOptionalText(path: Path): String? { + if (!fileSystem.exists(path)) { + return null + } + + return fileSystem.source(path).buffer().use { source -> + source.readUtf8().trim().takeIf { it.isNotEmpty() } + } + } + + companion object { + internal const val INSTALL_COMPLETE_FILE_NAME = ".installed" + internal const val INSTALL_FAILED_FILE_NAME = ".failed" + internal const val PROVIDER_FILE_NAME = ".provider" + internal val RESERVED_FILE_NAMES = setOf( + INSTALL_COMPLETE_FILE_NAME, + INSTALL_FAILED_FILE_NAME, + PROVIDER_FILE_NAME, + ) + } +} + +class FileSystemModelInstaller( + private val fileSystem: FileSystem, + private val installRoot: Path, + private val repository: RemoteModelRepositoryPort, + private val downloadClient: DownloadClientPort, +) : ModelInstallerPort { + override suspend fun installModel( + model: LocalModelDescriptor, + onProgress: (ModelInstallProgress) -> Unit, + ): InstalledModelRecord { + val artifacts = repository.artifactsFor(model) + require(artifacts.isNotEmpty()) { "No artifacts configured for ${model.id.value}" } + + fileSystem.createDirectories(installRoot) + + val modelDirectory = modelInstallDirectory(model.id) + val tempDirectory = tempInstallDirectory(model.id) + cleanup(tempDirectory) + fileSystem.createDirectories(tempDirectory) + + emitProgress(model, onProgress, 0.0, ModelInstallState.INSTALLING, "Starting download") + + return runCatching { + artifacts.forEachIndexed { index, artifact -> + val tempPath = tempDirectory / artifact.fileName + 
downloadClient.download(artifact, tempPath) { downloadedBytes, totalBytes -> + val artifactProgress = when { + totalBytes == null || totalBytes <= 0L -> 0.0 + else -> downloadedBytes.toDouble() / totalBytes.toDouble() + }.coerceIn(0.0, 1.0) + val overallProgress = (index.toDouble() + artifactProgress) / artifacts.size.toDouble() + emitProgress( + model = model, + onProgress = onProgress, + progress = overallProgress, + state = ModelInstallState.INSTALLING, + message = "Downloading ${artifact.fileName}", + ) + } + + if (artifact.sizeBytes != null) { + val actualSize = fileSystem.metadata(tempPath).size ?: 0L + check(actualSize == artifact.sizeBytes) { + "Downloaded size mismatch for ${artifact.fileName}: expected ${artifact.sizeBytes}, got $actualSize" + } + } + } + + cleanup(modelDirectory) + fileSystem.createDirectories(modelDirectory) + artifacts.forEach { artifact -> + fileSystem.atomicMove( + source = tempDirectory / artifact.fileName, + target = modelDirectory / artifact.fileName, + ) + } + writeText(modelDirectory / FileSystemInstalledModelIndex.PROVIDER_FILE_NAME, model.provider.name) + writeText(modelDirectory / FileSystemInstalledModelIndex.INSTALL_COMPLETE_FILE_NAME, "ok") + cleanup(tempDirectory) + + emitProgress(model, onProgress, 1.0, ModelInstallState.INSTALLED, "Install complete") + InstalledModelRecord( + modelId = model.id, + state = ModelInstallState.INSTALLED, + storage = ModelStorageLayout( + installRootPath = modelDirectory.toString(), + modelPath = (modelDirectory / artifacts.first().fileName).toString(), + ), + installedProvider = model.provider, + ) + }.getOrElse { error -> + cleanup(modelDirectory) + fileSystem.createDirectories(modelDirectory) + writeText(modelDirectory / FileSystemInstalledModelIndex.INSTALL_FAILED_FILE_NAME, error.message ?: "install failed") + cleanup(tempDirectory) + emitProgress( + model = model, + onProgress = onProgress, + progress = 0.0, + state = ModelInstallState.FAILED, + message = error.message ?: "Install 
failed", + ) + throw error + } + } + + override suspend fun deleteModel(model: LocalModelDescriptor) { + cleanup(modelInstallDirectory(model.id)) + cleanup(tempInstallDirectory(model.id)) + } + + private fun modelInstallDirectory(modelId: TranscriptionModelId): Path { + return installRoot / modelId.value + } + + private fun tempInstallDirectory(modelId: TranscriptionModelId): Path { + return (installRoot / ".tmp").resolve(modelId.value) + } + + private fun cleanup(path: Path) { + if (fileSystem.exists(path)) { + fileSystem.deleteRecursively(path, mustExist = false) + } + } + + private fun writeText(path: Path, text: String) { + fileSystem.sink(path).buffer().use { sink -> + sink.writeUtf8(text) + } + } + + private fun emitProgress( + model: LocalModelDescriptor, + onProgress: (ModelInstallProgress) -> Unit, + progress: Double, + state: ModelInstallState, + message: String, + ) { + onProgress( + ModelInstallProgress( + modelId = model.id, + progress = progress.coerceIn(0.0, 1.0), + state = state, + message = message, + ), + ) + } +} + +private fun Path.resolve(child: String): Path = (toString().trimEnd('/') + "/" + child).toPath() diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt index e8fd776..1675685 100644 --- a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt @@ -57,22 +57,18 @@ object LocalTranscriptionCatalog { fun models(platform: LocalPlatformId): List { return localModels.map { descriptor -> - when (descriptor.family) { - LocalModelFamily.WHISPER -> descriptor.copy( - provider = if (platform == LocalPlatformId.MACOS) { - 
LocalModelProvider.WHISPER_KIT - } else { - LocalModelProvider.WCPP - }, - ) - LocalModelFamily.PARAKEET -> descriptor.copy( - provider = if (platform == LocalPlatformId.MACOS) { - LocalModelProvider.PARAKEET_COREML - } else { - LocalModelProvider.PARAKEET_NATIVE - }, - ) + val preferredBackend = preferredBackendFor(platform, descriptor.family) + val provider = providerFor(preferredBackend) + val availability = if (preferredBackend in descriptor.supportedBackends) { + descriptor.availability + } else { + ModelAvailability.COMING_SOON } + + descriptor.copy( + provider = provider, + availability = availability, + ) } } @@ -92,6 +88,7 @@ object LocalTranscriptionCatalog { val ranks = recommendedModelIds(language).withIndex().associate { it.value to it.index } return models .filter { it.id in ranks.keys } + .filter { it.availability == ModelAvailability.AVAILABLE } .filter { supportsLanguage(it.languageSupport, language) } .sortedBy { ranks[it.id] ?: Int.MAX_VALUE } } @@ -138,6 +135,7 @@ object LocalTranscriptionCatalog { id = TranscriptionModelId(id), family = LocalModelFamily.WHISPER, provider = LocalModelProvider.WHISPER_KIT, + supportedBackends = setOf(LocalBackendId.WHISPER_KIT, LocalBackendId.WHISPER_CPP), displayName = displayName, languageSupport = languageSupport, sizeInMb = sizeInMb, @@ -162,6 +160,7 @@ object LocalTranscriptionCatalog { id = TranscriptionModelId(id), family = LocalModelFamily.PARAKEET, provider = LocalModelProvider.PARAKEET_COREML, + supportedBackends = setOf(LocalBackendId.PARAKEET_APPLE), displayName = displayName, languageSupport = languageSupport, sizeInMb = sizeInMb, @@ -171,4 +170,36 @@ object LocalTranscriptionCatalog { availability = availability, ) } + + private fun preferredBackendFor( + platform: LocalPlatformId, + family: LocalModelFamily, + ): LocalBackendId { + return when (family) { + LocalModelFamily.WHISPER -> { + if (platform == LocalPlatformId.MACOS) { + LocalBackendId.WHISPER_KIT + } else { + LocalBackendId.WHISPER_CPP + 
} + } + + LocalModelFamily.PARAKEET -> { + if (platform == LocalPlatformId.MACOS) { + LocalBackendId.PARAKEET_APPLE + } else { + LocalBackendId.PARAKEET_NATIVE + } + } + } + } + + private fun providerFor(backendId: LocalBackendId): LocalModelProvider { + return when (backendId) { + LocalBackendId.WHISPER_KIT -> LocalModelProvider.WHISPER_KIT + LocalBackendId.WHISPER_CPP -> LocalModelProvider.WCPP + LocalBackendId.PARAKEET_APPLE -> LocalModelProvider.PARAKEET_COREML + LocalBackendId.PARAKEET_NATIVE -> LocalModelProvider.PARAKEET_NATIVE + } + } } diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionContracts.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionContracts.kt index ff755ce..5f4dd6e 100644 --- a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionContracts.kt +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionContracts.kt @@ -70,6 +70,7 @@ data class LocalModelDescriptor( val id: TranscriptionModelId, val family: LocalModelFamily, val provider: LocalModelProvider, + val supportedBackends: Set, val displayName: String, val languageSupport: ModelLanguageSupport, val sizeInMb: Int, diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntime.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntime.kt index cb0c5ed..5b7f9f2 100644 --- a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntime.kt +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntime.kt @@ -49,8 +49,10 @@ class 
LocalTranscriptionRuntime( defaultModelId: TranscriptionModelId, ): LocalModelSelectionResolution { val models = catalog() - val normalizedSelectedModel = models.firstOrNull { it.id == selectedModelId } - ?: models.firstOrNull { it.id == defaultModelId } + val selectableModels = models.filter { it.availability == ModelAvailability.AVAILABLE } + val normalizedSelectedModel = selectableModels.firstOrNull { it.id == selectedModelId } + ?: selectableModels.firstOrNull { it.id == defaultModelId } + ?: selectableModels.firstOrNull() ?: models.first() val installedSet = installedModels @@ -176,8 +178,16 @@ class LocalTranscriptionRuntime( val backend = backendRegistry.backend( when (family) { - LocalModelFamily.WHISPER -> LocalBackendId.WHISPER_KIT - LocalModelFamily.PARAKEET -> LocalBackendId.PARAKEET_APPLE + LocalModelFamily.WHISPER -> if (platform == LocalPlatformId.MACOS) { + LocalBackendId.WHISPER_KIT + } else { + LocalBackendId.WHISPER_CPP + } + LocalModelFamily.PARAKEET -> if (platform == LocalPlatformId.MACOS) { + LocalBackendId.PARAKEET_APPLE + } else { + LocalBackendId.PARAKEET_NATIVE + } }, ) diff --git a/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/FileSystemModelStorageTest.kt b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/FileSystemModelStorageTest.kt new file mode 100644 index 0000000..33f81a3 --- /dev/null +++ b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/FileSystemModelStorageTest.kt @@ -0,0 +1,175 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlinx.coroutines.test.runTest +import okio.ByteString.Companion.encodeUtf8 +import okio.Path.Companion.toPath +import okio.fakefilesystem.FakeFileSystem +import 
tech.watzon.pindrop.shared.core.ModelAvailability +import tech.watzon.pindrop.shared.core.ModelLanguageSupport +import tech.watzon.pindrop.shared.core.TranscriptionModelId + +class FileSystemModelStorageTest { + @Test + fun installedModelIndexReadsInstalledDirectories() = runTest { + val fileSystem = FakeFileSystem() + val installRoot = "/models".toPath() + fileSystem.createDirectories(installRoot / "openai_whisper-base") + fileSystem.write((installRoot / "openai_whisper-base" / ".installed")) { + write("ok".encodeUtf8()) + } + fileSystem.write((installRoot / "openai_whisper-base" / ".provider")) { + write("WCPP".encodeUtf8()) + } + fileSystem.write((installRoot / "openai_whisper-base" / "model.gguf")) { + write("binary".encodeUtf8()) + } + + val index = FileSystemInstalledModelIndex(fileSystem = fileSystem, installRoot = installRoot) + val records = index.refreshInstalledModels() + + assertEquals(1, records.size) + assertEquals(ModelInstallState.INSTALLED, records.single().state) + assertEquals(LocalModelProvider.WCPP, records.single().installedProvider) + assertTrue(records.single().storage.modelPath?.endsWith("model.gguf") == true) + } + + @Test + fun installerDownloadsArtifactsAtomicallyAndIndexSeesThem() = runTest { + val fileSystem = FakeFileSystem() + val installRoot = "/models".toPath() + val model = whisperModel() + val installer = FileSystemModelInstaller( + fileSystem = fileSystem, + installRoot = installRoot, + repository = FakeRepository( + artifacts = listOf( + RemoteModelArtifact( + fileName = "model.gguf", + downloadUrl = "https://example.invalid/model.gguf", + sizeBytes = 6, + ), + ), + ), + downloadClient = FakeDownloadClient( + fileSystem = fileSystem, + contentByUrl = mapOf( + "https://example.invalid/model.gguf" to "binary".encodeUtf8(), + ), + ), + ) + + val progress = mutableListOf() + val record = installer.installModel(model) { progress += it } + val indexed = FileSystemInstalledModelIndex(fileSystem, installRoot).refreshInstalledModels() + 
+ assertEquals(ModelInstallState.INSTALLED, record.state) + assertTrue(progress.any { it.state == ModelInstallState.INSTALLING }) + assertEquals(ModelInstallState.INSTALLED, progress.last().state) + assertEquals(model.id, indexed.single().modelId) + assertTrue(fileSystem.exists(installRoot / model.id.value / "model.gguf")) + assertFalse(fileSystem.exists(installRoot / ".tmp" / model.id.value)) + } + + @Test + fun installerMarksFailuresAndCleansTempDirectory() = runTest { + val fileSystem = FakeFileSystem() + val installRoot = "/models".toPath() + val model = whisperModel() + val installer = FileSystemModelInstaller( + fileSystem = fileSystem, + installRoot = installRoot, + repository = FakeRepository( + artifacts = listOf( + RemoteModelArtifact( + fileName = "model.gguf", + downloadUrl = "https://example.invalid/model.gguf", + ), + ), + ), + downloadClient = object : DownloadClientPort { + override suspend fun download( + artifact: RemoteModelArtifact, + destination: okio.Path, + onProgress: (bytesDownloaded: Long, totalBytes: Long?) 
-> Unit, + ) { + onProgress(0, artifact.sizeBytes) + error("network failed") + } + }, + ) + + runCatching { + installer.installModel(model) { } + } + + val records = FileSystemInstalledModelIndex(fileSystem, installRoot).refreshInstalledModels() + val failed = records.single() + assertEquals(ModelInstallState.FAILED, failed.state) + assertNotNull(failed.lastError) + assertFalse(fileSystem.exists(installRoot / ".tmp" / model.id.value)) + } + + @Test + fun installerDeletesInstalledModelDirectory() = runTest { + val fileSystem = FakeFileSystem() + val installRoot = "/models".toPath() + val model = whisperModel() + val installer = FileSystemModelInstaller( + fileSystem = fileSystem, + installRoot = installRoot, + repository = FakeRepository(emptyList()), + downloadClient = FakeDownloadClient(fileSystem, emptyMap()), + ) + + fileSystem.createDirectories(installRoot / model.id.value) + fileSystem.write((installRoot / model.id.value / ".installed")) { write("ok".encodeUtf8()) } + + installer.deleteModel(model) + + assertFalse(fileSystem.exists(installRoot / model.id.value)) + } + + private fun whisperModel(): LocalModelDescriptor { + return LocalModelDescriptor( + id = TranscriptionModelId("openai_whisper-base"), + family = LocalModelFamily.WHISPER, + provider = LocalModelProvider.WCPP, + supportedBackends = setOf(LocalBackendId.WHISPER_CPP), + displayName = "Whisper Base", + languageSupport = ModelLanguageSupport.FULL_MULTILINGUAL, + sizeInMb = 145, + description = "Test model", + speedRating = 9.0, + accuracyRating = 7.0, + availability = ModelAvailability.AVAILABLE, + ) + } +} + +private class FakeRepository( + private val artifacts: List, +) : RemoteModelRepositoryPort { + override fun artifactsFor(model: LocalModelDescriptor): List = artifacts +} + +private class FakeDownloadClient( + private val fileSystem: FakeFileSystem, + private val contentByUrl: Map, +) : DownloadClientPort { + override suspend fun download( + artifact: RemoteModelArtifact, + destination: 
okio.Path, + onProgress: (bytesDownloaded: Long, totalBytes: Long?) -> Unit, + ) { + val content = contentByUrl.getValue(artifact.downloadUrl) + fileSystem.write(destination) { + write(content) + } + onProgress(content.size.toLong(), artifact.sizeBytes ?: content.size.toLong()) + } +} From 55f021a47d1840a3c87a668259551e72338539b7 Mon Sep 17 00:00:00 2001 From: Chris Watson Date: Sat, 28 Mar 2026 23:32:18 -0600 Subject: [PATCH 3/5] Add shared whisper.cpp transcription runtime - Wire Ktor download client into shared model installer - Add whisper.cpp backend, curated model repo, and factory - Cover install, load, transcribe, and cleanup flows in tests --- shared/runtime-transcription/build.gradle.kts | 2 + .../transcription/KtorDownloadClient.kt | 61 ++++++ .../LocalTranscriptionCatalog.kt | 11 +- .../WhisperCppModelRepository.kt | 62 ++++++ .../transcription/WhisperCppRuntime.kt | 78 +++++++ .../transcription/WhisperCppRuntimeFactory.kt | 39 ++++ .../transcription/KtorDownloadClientTest.kt | 77 +++++++ .../LocalTranscriptionRuntimeTest.kt | 3 + .../transcription/WhisperCppRuntimeTest.kt | 192 ++++++++++++++++++ 9 files changed, 521 insertions(+), 4 deletions(-) create mode 100644 shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt create mode 100644 shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppModelRepository.kt create mode 100644 shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntime.kt create mode 100644 shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeFactory.kt create mode 100644 shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClientTest.kt create mode 100644 
shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeTest.kt diff --git a/shared/runtime-transcription/build.gradle.kts b/shared/runtime-transcription/build.gradle.kts index 77682ae..3dec804 100644 --- a/shared/runtime-transcription/build.gradle.kts +++ b/shared/runtime-transcription/build.gradle.kts @@ -34,11 +34,13 @@ kotlin { commonMain.dependencies { api(project(":core")) implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.10.2") + implementation("io.ktor:ktor-client-core:3.4.1") implementation("com.squareup.okio:okio:3.9.0") } commonTest.dependencies { implementation(kotlin("test")) implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.10.2") + implementation("io.ktor:ktor-client-mock:3.4.1") implementation("com.squareup.okio:okio-fakefilesystem:3.9.0") } } diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt new file mode 100644 index 0000000..e282394 --- /dev/null +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt @@ -0,0 +1,61 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import io.ktor.client.call.body +import io.ktor.client.HttpClient +import io.ktor.client.request.prepareGet +import io.ktor.client.statement.HttpResponse +import io.ktor.http.HttpHeaders +import io.ktor.http.isSuccess +import okio.FileSystem +import okio.Path +import okio.buffer +import okio.use + +class KtorDownloadClient( + private val httpClient: HttpClient, + private val fileSystem: FileSystem, +) : DownloadClientPort { + override suspend fun download( + artifact: RemoteModelArtifact, + destination: Path, + onProgress: (bytesDownloaded: Long, totalBytes: Long?) 
-> Unit, + ) { + destination.parent?.let(fileSystem::createDirectories) + + runCatching { + httpClient.prepareGet(artifact.downloadUrl).execute { response -> + check(response.status.isSuccess()) { + "Download failed for ${artifact.fileName}: HTTP ${response.status.value}" + } + writeResponseBody( + response = response, + destination = destination, + fallbackTotalBytes = artifact.sizeBytes, + onProgress = onProgress, + ) + } + }.getOrElse { error -> + if (fileSystem.exists(destination)) { + fileSystem.delete(destination) + } + throw error + } + } + + private suspend fun writeResponseBody( + response: HttpResponse, + destination: Path, + fallbackTotalBytes: Long?, + onProgress: (bytesDownloaded: Long, totalBytes: Long?) -> Unit, + ) { + val totalBytes = response.headers[HttpHeaders.ContentLength]?.toLongOrNull() ?: fallbackTotalBytes + val bytes = response.body() + + onProgress(0L, totalBytes) + + fileSystem.sink(destination).buffer().use { sink -> + sink.write(bytes) + } + onProgress(bytes.size.toLong(), totalBytes ?: bytes.size.toLong()) + } +} diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt index 1675685..f8485bc 100644 --- a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionCatalog.kt @@ -59,10 +59,13 @@ object LocalTranscriptionCatalog { return localModels.map { descriptor -> val preferredBackend = preferredBackendFor(platform, descriptor.family) val provider = providerFor(preferredBackend) - val availability = if (preferredBackend in descriptor.supportedBackends) { - descriptor.availability - } else { - ModelAvailability.COMING_SOON + val 
availability = when { + preferredBackend !in descriptor.supportedBackends -> ModelAvailability.COMING_SOON + platform != LocalPlatformId.MACOS && + preferredBackend == LocalBackendId.WHISPER_CPP && + descriptor.family == LocalModelFamily.WHISPER && + descriptor.id !in WhisperCppRemoteModelRepository.curatedModelIds -> ModelAvailability.REQUIRES_SETUP + else -> descriptor.availability } descriptor.copy( diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppModelRepository.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppModelRepository.kt new file mode 100644 index 0000000..8dd3180 --- /dev/null +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppModelRepository.kt @@ -0,0 +1,62 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import tech.watzon.pindrop.shared.core.TranscriptionModelId + +class WhisperCppRemoteModelRepository : RemoteModelRepositoryPort { + override fun artifactsFor(model: LocalModelDescriptor): List { + if (model.family != LocalModelFamily.WHISPER) { + return emptyList() + } + + val curatedModel = curatedModelsById[model.id] ?: return emptyList() + return listOf( + RemoteModelArtifact( + fileName = curatedModel.fileName, + downloadUrl = curatedModel.downloadUrl, + ), + ) + } + + companion object { + private val curatedModelsById = listOf( + curated("openai_whisper-tiny", "tiny"), + curated("openai_whisper-tiny.en", "tiny.en"), + curated("openai_whisper-base", "base"), + curated("openai_whisper-base.en", "base.en"), + curated("openai_whisper-small", "small"), + curated("openai_whisper-small.en", "small.en"), + curated("openai_whisper-medium", "medium"), + curated("openai_whisper-medium.en", "medium.en"), + curated("openai_whisper-large-v2", "large-v2"), + curated("openai_whisper-large-v3", "large-v3"), + 
curated("openai_whisper-large-v3_turbo", "large-v3-turbo"), + curated("openai_whisper-small_216MB", "small-q5_1"), + curated("openai_whisper-small.en_217MB", "small.en-q5_1"), + curated("openai_whisper-large-v3_turbo_954MB", "large-v3-turbo-q8_0"), + ).associateBy { it.modelId } + + val curatedModelIds: Set = curatedModelsById.keys + + private fun curated( + localModelId: String, + whisperCppModelName: String, + repositoryBaseUrl: String = DEFAULT_MODEL_REPOSITORY_BASE_URL, + ): CuratedWhisperCppModel { + val fileName = "ggml-$whisperCppModelName.bin" + return CuratedWhisperCppModel( + modelId = TranscriptionModelId(localModelId), + fileName = fileName, + downloadUrl = "$repositoryBaseUrl/$fileName", + ) + } + + private const val DEFAULT_MODEL_REPOSITORY_BASE_URL = + "https://huggingface.co/ggerganov/whisper.cpp/resolve/main" + } +} + +private data class CuratedWhisperCppModel( + val modelId: TranscriptionModelId, + val fileName: String, + val downloadUrl: String, +) diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntime.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntime.kt new file mode 100644 index 0000000..78f421f --- /dev/null +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntime.kt @@ -0,0 +1,78 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import tech.watzon.pindrop.shared.core.TranscriptionRequest +import tech.watzon.pindrop.shared.core.TranscriptionResult + +interface WhisperCppBridgePort { + suspend fun loadModel(modelPath: String) + suspend fun transcribe(request: TranscriptionRequest): TranscriptionResult + suspend fun unloadModel() +} + +class WhisperCppBackend( + private val bridge: WhisperCppBridgePort, +) : LocalInferenceBackendPort { + override val backendId: LocalBackendId = LocalBackendId.WHISPER_CPP + override val 
supportedFamilies: Set = setOf(LocalModelFamily.WHISPER) + override val supportsPathLoading: Boolean = true + + override suspend fun loadModel( + model: LocalModelDescriptor, + installedRecord: InstalledModelRecord?, + ) { + require(model.family in supportedFamilies) { + "Backend $backendId does not support model family ${model.family}" + } + + val modelPath = installedRecord?.storage?.modelPath + ?: error("Installed model path is missing for ${model.id.value}") + bridge.loadModel(modelPath) + } + + override suspend fun loadModelFromPath(path: String) { + bridge.loadModel(path) + } + + override suspend fun transcribe(request: TranscriptionRequest): TranscriptionResult { + return bridge.transcribe(request) + } + + override suspend fun unloadModel() { + bridge.unloadModel() + } +} + +class DefaultBackendRegistry( + private val platform: LocalPlatformId, + backends: Collection, +) : BackendRegistryPort { + private val backendsById = backends.associateBy { it.backendId } + + override fun preferredBackend(model: LocalModelDescriptor): LocalBackendId? { + return preferredBackendIds(model).firstOrNull { backendId -> + backendId in model.supportedBackends && backendId in backendsById + } + } + + override fun backend(id: LocalBackendId): LocalInferenceBackendPort? 
= backendsById[id] + + private fun preferredBackendIds(model: LocalModelDescriptor): List { + return when (model.family) { + LocalModelFamily.WHISPER -> { + if (platform == LocalPlatformId.MACOS) { + listOf(LocalBackendId.WHISPER_KIT, LocalBackendId.WHISPER_CPP) + } else { + listOf(LocalBackendId.WHISPER_CPP) + } + } + + LocalModelFamily.PARAKEET -> { + if (platform == LocalPlatformId.MACOS) { + listOf(LocalBackendId.PARAKEET_APPLE, LocalBackendId.PARAKEET_NATIVE) + } else { + listOf(LocalBackendId.PARAKEET_NATIVE) + } + } + } + } +} diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeFactory.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeFactory.kt new file mode 100644 index 0000000..5c66c64 --- /dev/null +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeFactory.kt @@ -0,0 +1,39 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import io.ktor.client.HttpClient +import okio.FileSystem +import okio.Path + +object WhisperCppRuntimeFactory { + fun create( + platform: LocalPlatformId, + fileSystem: FileSystem, + installRoot: Path, + httpClient: HttpClient, + bridge: WhisperCppBridgePort, + observer: RuntimeObserver? 
= null, + ): LocalTranscriptionRuntime { + val repository = WhisperCppRemoteModelRepository() + return LocalTranscriptionRuntime( + platform = platform, + installedModelIndex = FileSystemInstalledModelIndex( + fileSystem = fileSystem, + installRoot = installRoot, + ), + modelInstaller = FileSystemModelInstaller( + fileSystem = fileSystem, + installRoot = installRoot, + repository = repository, + downloadClient = KtorDownloadClient( + httpClient = httpClient, + fileSystem = fileSystem, + ), + ), + backendRegistry = DefaultBackendRegistry( + platform = platform, + backends = listOf(WhisperCppBackend(bridge)), + ), + observer = observer, + ) + } +} diff --git a/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClientTest.kt b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClientTest.kt new file mode 100644 index 0000000..daeb66b --- /dev/null +++ b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClientTest.kt @@ -0,0 +1,77 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.headersOf +import io.ktor.utils.io.ByteReadChannel +import kotlinx.coroutines.test.runTest +import okio.Path.Companion.toPath +import okio.fakefilesystem.FakeFileSystem +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class KtorDownloadClientTest { + @Test + fun downloadStreamsBodyToDiskAndReportsProgress() = runTest { + val fileSystem = FakeFileSystem() + val destination = "/models/base/ggml-base.bin".toPath() + val client = HttpClient( + MockEngine { + respond( + content = ByteReadChannel("binary"), + status = 
HttpStatusCode.OK, + headers = headersOf(HttpHeaders.ContentLength, "6"), + ) + }, + ) + val downloadClient = KtorDownloadClient(client, fileSystem) + val progress = mutableListOf>() + + downloadClient.download( + artifact = RemoteModelArtifact( + fileName = "ggml-base.bin", + downloadUrl = "https://example.invalid/ggml-base.bin", + ), + destination = destination, + ) { downloaded, total -> + progress += downloaded to total + } + + val content = fileSystem.read(destination) { readUtf8() } + assertEquals("binary", content) + assertTrue(progress.first() == (0L to 6L)) + assertTrue(progress.last() == (6L to 6L)) + } + + @Test + fun downloadDeletesPartialFileOnFailure() = runTest { + val fileSystem = FakeFileSystem() + val destination = "/models/base/ggml-base.bin".toPath() + val client = HttpClient( + MockEngine { + respond( + content = ByteReadChannel("boom"), + status = HttpStatusCode.BadGateway, + ) + }, + ) + val downloadClient = KtorDownloadClient(client, fileSystem) + + runCatching { + downloadClient.download( + artifact = RemoteModelArtifact( + fileName = "ggml-base.bin", + downloadUrl = "https://example.invalid/ggml-base.bin", + ), + destination = destination, + ) { _, _ -> } + } + + assertFalse(fileSystem.exists(destination)) + } +} diff --git a/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntimeTest.kt b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntimeTest.kt index 5356718..ac78242 100644 --- a/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntimeTest.kt +++ b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/LocalTranscriptionRuntimeTest.kt @@ -1,6 +1,7 @@ package tech.watzon.pindrop.shared.runtime.transcription import kotlinx.coroutines.test.runTest +import 
tech.watzon.pindrop.shared.core.ModelAvailability import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -37,10 +38,12 @@ class LocalTranscriptionRuntimeTest { val macWhisper = LocalTranscriptionCatalog.model(LocalPlatformId.MACOS, TranscriptionModelId("openai_whisper-base")) val linuxWhisper = LocalTranscriptionCatalog.model(LocalPlatformId.LINUX, TranscriptionModelId("openai_whisper-base")) val windowsParakeet = LocalTranscriptionCatalog.model(LocalPlatformId.WINDOWS, TranscriptionModelId("parakeet-tdt-0.6b-v3")) + val linuxManualSetup = LocalTranscriptionCatalog.model(LocalPlatformId.LINUX, TranscriptionModelId("openai_whisper-large-v3-v20240930")) assertEquals(LocalModelProvider.WHISPER_KIT, macWhisper?.provider) assertEquals(LocalModelProvider.WCPP, linuxWhisper?.provider) assertEquals(LocalModelProvider.PARAKEET_NATIVE, windowsParakeet?.provider) + assertEquals(ModelAvailability.REQUIRES_SETUP, linuxManualSetup?.availability) } @Test diff --git a/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeTest.kt b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeTest.kt new file mode 100644 index 0000000..f9db479 --- /dev/null +++ b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeTest.kt @@ -0,0 +1,192 @@ +package tech.watzon.pindrop.shared.runtime.transcription + +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.HttpStatusCode +import kotlinx.coroutines.test.runTest +import okio.ByteString.Companion.encodeUtf8 +import okio.Path.Companion.toPath +import okio.fakefilesystem.FakeFileSystem +import tech.watzon.pindrop.shared.core.TranscriptionLanguage +import tech.watzon.pindrop.shared.core.TranscriptionModelId +import 
tech.watzon.pindrop.shared.core.TranscriptionRequest +import tech.watzon.pindrop.shared.core.TranscriptionResult +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class WhisperCppRuntimeTest { + @Test + fun repositoryReturnsCuratedArtifactForSupportedModel() { + val repository = WhisperCppRemoteModelRepository() + val model = requireNotNull( + LocalTranscriptionCatalog.model( + platform = LocalPlatformId.WINDOWS, + modelId = TranscriptionModelId("openai_whisper-base.en"), + ), + ) + + val artifacts = repository.artifactsFor(model) + + assertEquals(1, artifacts.size) + assertEquals("ggml-base.en.bin", artifacts.single().fileName) + assertTrue(artifacts.single().downloadUrl.endsWith("/ggml-base.en.bin")) + } + + @Test + fun unsupportedWhisperModelsRequireManualSetupOffMacos() { + val curated = requireNotNull( + LocalTranscriptionCatalog.model( + platform = LocalPlatformId.LINUX, + modelId = TranscriptionModelId("openai_whisper-base.en"), + ), + ) + val manualSetup = requireNotNull( + LocalTranscriptionCatalog.model( + platform = LocalPlatformId.LINUX, + modelId = TranscriptionModelId("openai_whisper-large-v3-v20240930"), + ), + ) + val recommendedEnglish = LocalTranscriptionCatalog.recommendedModels( + platform = LocalPlatformId.LINUX, + language = TranscriptionLanguage.ENGLISH, + ) + + assertEquals(LocalModelProvider.WCPP, curated.provider) + assertEquals(tech.watzon.pindrop.shared.core.ModelAvailability.AVAILABLE, curated.availability) + assertEquals(tech.watzon.pindrop.shared.core.ModelAvailability.REQUIRES_SETUP, manualSetup.availability) + assertFalse(recommendedEnglish.any { it.id == manualSetup.id }) + } + + @Test + fun runtimeInstallsLoadsTranscribesAndDeletesWithWhisperCppBackend() = runTest { + val fileSystem = FakeFileSystem() + val installRoot = "/models".toPath() + val repository = WhisperCppRemoteModelRepository() + val model = requireNotNull( + 
LocalTranscriptionCatalog.model( + platform = LocalPlatformId.WINDOWS, + modelId = TranscriptionModelId("openai_whisper-base.en"), + ), + ) + val artifact = repository.artifactsFor(model).single() + val bridge = FakeWhisperCppBridge() + val runtime = LocalTranscriptionRuntime( + platform = LocalPlatformId.WINDOWS, + installedModelIndex = FileSystemInstalledModelIndex(fileSystem, installRoot), + modelInstaller = FileSystemModelInstaller( + fileSystem = fileSystem, + installRoot = installRoot, + repository = repository, + downloadClient = FakeArtifactDownloadClient( + fileSystem = fileSystem, + contentByUrl = mapOf(artifact.downloadUrl to "binary".encodeUtf8()), + ), + ), + backendRegistry = DefaultBackendRegistry( + platform = LocalPlatformId.WINDOWS, + backends = listOf(WhisperCppBackend(bridge)), + ), + ) + + runtime.refreshInstalledModels() + runtime.installModel(model.id) + runtime.loadModel(model.id) + val result = runtime.transcribe(TranscriptionRequest(audioData = byteArrayOf(1, 2, 3))) + runtime.deleteModel(model.id) + + assertEquals("transcribed", result.text) + assertEquals(listOf("/models/openai_whisper-base.en/ggml-base.en.bin"), bridge.loadedPaths) + assertEquals(1, bridge.transcribeCalls) + assertFalse(fileSystem.exists(installRoot / model.id.value)) + assertEquals(LocalRuntimeState.UNLOADED, runtime.state) + } + + @Test + fun factoryBuildsWorkingWhisperCppRuntimeFromSharedPieces() = runTest { + val fileSystem = FakeFileSystem() + val installRoot = "/models".toPath() + val bridge = FakeWhisperCppBridge() + val runtime = WhisperCppRuntimeFactory.create( + platform = LocalPlatformId.WINDOWS, + fileSystem = fileSystem, + installRoot = installRoot, + httpClient = HttpClient( + MockEngine { request -> + val fileName = request.url.encodedPath.substringAfterLast('/') + respond( + content = fileName.encodeUtf8().toByteArray(), + status = HttpStatusCode.OK, + ) + }, + ), + bridge = bridge, + ) + val modelId = TranscriptionModelId("openai_whisper-base.en") + + 
runtime.refreshInstalledModels() + runtime.installModel(modelId) + runtime.loadModel(modelId) + val result = runtime.transcribe(TranscriptionRequest(audioData = byteArrayOf(4, 5, 6))) + + assertEquals("transcribed", result.text) + assertEquals( + listOf("/models/openai_whisper-base.en/ggml-base.en.bin"), + bridge.loadedPaths, + ) + } + + @Test + fun defaultBackendRegistryPrefersRegisteredWhisperCppBackendOffMacos() { + val backend = WhisperCppBackend(FakeWhisperCppBridge()) + val registry = DefaultBackendRegistry( + platform = LocalPlatformId.LINUX, + backends = listOf(backend), + ) + val model = requireNotNull( + LocalTranscriptionCatalog.model( + platform = LocalPlatformId.LINUX, + modelId = TranscriptionModelId("openai_whisper-base.en"), + ), + ) + + assertEquals(LocalBackendId.WHISPER_CPP, registry.preferredBackend(model)) + assertNotNull(registry.backend(LocalBackendId.WHISPER_CPP)) + } +} + +private class FakeWhisperCppBridge : WhisperCppBridgePort { + val loadedPaths = mutableListOf() + var transcribeCalls: Int = 0 + + override suspend fun loadModel(modelPath: String) { + loadedPaths += modelPath + } + + override suspend fun transcribe(request: TranscriptionRequest): TranscriptionResult { + transcribeCalls += 1 + return TranscriptionResult(text = "transcribed") + } + + override suspend fun unloadModel() = Unit +} + +private class FakeArtifactDownloadClient( + private val fileSystem: FakeFileSystem, + private val contentByUrl: Map, +) : DownloadClientPort { + override suspend fun download( + artifact: RemoteModelArtifact, + destination: okio.Path, + onProgress: (bytesDownloaded: Long, totalBytes: Long?) 
-> Unit, + ) { + val content = contentByUrl.getValue(artifact.downloadUrl) + fileSystem.write(destination) { + write(content) + } + onProgress(content.size.toLong(), content.size.toLong()) + } +} From 55656add9852d0daf84455d7f3d3d8ba6496880a Mon Sep 17 00:00:00 2001 From: Chris Watson Date: Sat, 28 Mar 2026 23:48:21 -0600 Subject: [PATCH 4/5] Implement shared whisper.cpp runtime plumbing --- .gitignore | 1 + shared/gradle.properties | 1 + shared/runtime-transcription/build.gradle.kts | 12 ++++++++- .../transcription/WhisperCppRuntimeFactory.kt | 10 +++---- .../transcription/WhisperCppRuntimeTest.kt | 26 ++++++++++--------- .../transcription/KtorDownloadClient.kt | 4 +-- 6 files changed, 32 insertions(+), 22 deletions(-) rename shared/runtime-transcription/src/{commonMain => desktopMain}/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt (98%) diff --git a/.gitignore b/.gitignore index 9f99187..665043d 100644 --- a/.gitignore +++ b/.gitignore @@ -48,6 +48,7 @@ Carthage/Build/ # Temporary scripts and build artifacts *.py package-lock.json +*.hprof DerivedDataTests/ diff --git a/shared/gradle.properties b/shared/gradle.properties index c7e3b25..2f9934f 100644 --- a/shared/gradle.properties +++ b/shared/gradle.properties @@ -1,2 +1,3 @@ kotlin.code.style=official kotlin.mpp.enableCInteropCommonization=true +kotlin.mpp.applyDefaultHierarchyTemplate=false diff --git a/shared/runtime-transcription/build.gradle.kts b/shared/runtime-transcription/build.gradle.kts index 3dec804..ecfcad5 100644 --- a/shared/runtime-transcription/build.gradle.kts +++ b/shared/runtime-transcription/build.gradle.kts @@ -31,10 +31,16 @@ kotlin { } sourceSets { + val desktopMain by creating { + dependsOn(commonMain.get()) + dependencies { + implementation("io.ktor:ktor-client-core:3.4.1") + } + } + commonMain.dependencies { api(project(":core")) implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.10.2") - implementation("io.ktor:ktor-client-core:3.4.1") 
implementation("com.squareup.okio:okio:3.9.0") } commonTest.dependencies { @@ -43,5 +49,9 @@ kotlin { implementation("io.ktor:ktor-client-mock:3.4.1") implementation("com.squareup.okio:okio-fakefilesystem:3.9.0") } + + jvmMain.get().dependsOn(desktopMain) + linuxX64Main.get().dependsOn(desktopMain) + mingwX64Main.get().dependsOn(desktopMain) } } diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeFactory.kt b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeFactory.kt index 5c66c64..2872e7b 100644 --- a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeFactory.kt +++ b/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeFactory.kt @@ -1,6 +1,5 @@ package tech.watzon.pindrop.shared.runtime.transcription -import io.ktor.client.HttpClient import okio.FileSystem import okio.Path @@ -9,11 +8,11 @@ object WhisperCppRuntimeFactory { platform: LocalPlatformId, fileSystem: FileSystem, installRoot: Path, - httpClient: HttpClient, + downloadClient: DownloadClientPort, bridge: WhisperCppBridgePort, observer: RuntimeObserver? 
= null, + repository: RemoteModelRepositoryPort = WhisperCppRemoteModelRepository(), ): LocalTranscriptionRuntime { - val repository = WhisperCppRemoteModelRepository() return LocalTranscriptionRuntime( platform = platform, installedModelIndex = FileSystemInstalledModelIndex( @@ -24,10 +23,7 @@ object WhisperCppRuntimeFactory { fileSystem = fileSystem, installRoot = installRoot, repository = repository, - downloadClient = KtorDownloadClient( - httpClient = httpClient, - fileSystem = fileSystem, - ), + downloadClient = downloadClient, ), backendRegistry = DefaultBackendRegistry( platform = platform, diff --git a/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeTest.kt b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeTest.kt index f9db479..d8cb132 100644 --- a/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeTest.kt +++ b/shared/runtime-transcription/src/commonTest/kotlin/tech/watzon/pindrop/shared/runtime/transcription/WhisperCppRuntimeTest.kt @@ -1,9 +1,5 @@ package tech.watzon.pindrop.shared.runtime.transcription -import io.ktor.client.HttpClient -import io.ktor.client.engine.mock.MockEngine -import io.ktor.client.engine.mock.respond -import io.ktor.http.HttpStatusCode import kotlinx.coroutines.test.runTest import okio.ByteString.Companion.encodeUtf8 import okio.Path.Companion.toPath @@ -110,20 +106,26 @@ class WhisperCppRuntimeTest { val fileSystem = FakeFileSystem() val installRoot = "/models".toPath() val bridge = FakeWhisperCppBridge() + val repository = WhisperCppRemoteModelRepository() + val model = requireNotNull( + LocalTranscriptionCatalog.model( + platform = LocalPlatformId.WINDOWS, + modelId = TranscriptionModelId("openai_whisper-base.en"), + ), + ) + val artifact = repository.artifactsFor(model).single() val runtime = WhisperCppRuntimeFactory.create( 
platform = LocalPlatformId.WINDOWS, fileSystem = fileSystem, installRoot = installRoot, - httpClient = HttpClient( - MockEngine { request -> - val fileName = request.url.encodedPath.substringAfterLast('/') - respond( - content = fileName.encodeUtf8().toByteArray(), - status = HttpStatusCode.OK, - ) - }, + downloadClient = FakeArtifactDownloadClient( + fileSystem = fileSystem, + contentByUrl = mapOf( + artifact.downloadUrl to "ggml-base.en.bin".encodeUtf8(), + ), ), bridge = bridge, + repository = repository, ) val modelId = TranscriptionModelId("openai_whisper-base.en") diff --git a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt b/shared/runtime-transcription/src/desktopMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt similarity index 98% rename from shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt rename to shared/runtime-transcription/src/desktopMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt index e282394..6b40285 100644 --- a/shared/runtime-transcription/src/commonMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt +++ b/shared/runtime-transcription/src/desktopMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt @@ -1,7 +1,7 @@ package tech.watzon.pindrop.shared.runtime.transcription -import io.ktor.client.call.body import io.ktor.client.HttpClient +import io.ktor.client.call.body import io.ktor.client.request.prepareGet import io.ktor.client.statement.HttpResponse import io.ktor.http.HttpHeaders @@ -11,7 +11,7 @@ import okio.Path import okio.buffer import okio.use -class KtorDownloadClient( +internal class KtorDownloadClient( private val httpClient: HttpClient, private val fileSystem: FileSystem, ) : DownloadClientPort { From d06f290589edeec407919cfbea0b61e4d425e65a Mon Sep 17 00:00:00 
2001 From: Chris Watson Date: Sat, 28 Mar 2026 23:50:48 -0600 Subject: [PATCH 5/5] Address PR review feedback --- .../NativeTranscriptionAdapters.swift | 24 +++++++++++++++++ .../transcription/KtorDownloadClient.kt | 26 ++++++++++++++++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/Pindrop/Services/Transcription/NativeTranscriptionAdapters.swift b/Pindrop/Services/Transcription/NativeTranscriptionAdapters.swift index a597782..c9ed72d 100644 --- a/Pindrop/Services/Transcription/NativeTranscriptionAdapters.swift +++ b/Pindrop/Services/Transcription/NativeTranscriptionAdapters.swift @@ -251,6 +251,7 @@ final class KMPTranscriptionRuntimeBridge { } } } + try validateRuntimeLoadSucceeded(modelName: modelName) guard let engine = backendRegistry.engine(for: backendProvider) else { throw TranscriptionService.TranscriptionError.modelLoadFailed( @@ -270,6 +271,7 @@ final class KMPTranscriptionRuntimeBridge { } } } + try validateRuntimePathLoadSucceeded(path: path) guard let engine = backendRegistry.engine(for: .whisperKit) else { throw TranscriptionService.TranscriptionError.modelLoadFailed( @@ -338,6 +340,28 @@ final class KMPTranscriptionRuntimeBridge { } } } + + private func validateRuntimeLoadSucceeded(modelName: String) throws { + if runtime.state == .ready, runtime.activeModel?.descriptor.id.value == modelName { + return + } + + throw TranscriptionService.TranscriptionError.modelLoadFailed( + runtime.lastErrorMessage ?? + "Runtime failed to load model '\(modelName)' (\(runtime.lastErrorCode?.name ?? "unknown_error"))" + ) + } + + private func validateRuntimePathLoadSucceeded(path: String) throws { + if runtime.state == .ready { + return + } + + throw TranscriptionService.TranscriptionError.modelLoadFailed( + runtime.lastErrorMessage ?? + "Runtime failed to load model at path '\(path)' (\(runtime.lastErrorCode?.name ?? 
"unknown_error"))" + ) + } } private final class MacOSInstalledModelIndexAdapter: NSObject, InstalledModelIndexPort { diff --git a/shared/runtime-transcription/src/desktopMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt b/shared/runtime-transcription/src/desktopMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt index 6b40285..1c9d2ae 100644 --- a/shared/runtime-transcription/src/desktopMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt +++ b/shared/runtime-transcription/src/desktopMain/kotlin/tech/watzon/pindrop/shared/runtime/transcription/KtorDownloadClient.kt @@ -1,11 +1,12 @@ package tech.watzon.pindrop.shared.runtime.transcription import io.ktor.client.HttpClient -import io.ktor.client.call.body import io.ktor.client.request.prepareGet import io.ktor.client.statement.HttpResponse +import io.ktor.client.statement.bodyAsChannel import io.ktor.http.HttpHeaders import io.ktor.http.isSuccess +import io.ktor.utils.io.readAvailable import okio.FileSystem import okio.Path import okio.buffer @@ -49,13 +50,30 @@ internal class KtorDownloadClient( onProgress: (bytesDownloaded: Long, totalBytes: Long?) -> Unit, ) { val totalBytes = response.headers[HttpHeaders.ContentLength]?.toLongOrNull() ?: fallbackTotalBytes - val bytes = response.body() + val channel = response.bodyAsChannel() + val buffer = ByteArray(BUFFER_SIZE_BYTES) + var downloadedBytes = 0L onProgress(0L, totalBytes) fileSystem.sink(destination).buffer().use { sink -> - sink.write(bytes) + while (true) { + val readCount = channel.readAvailable(buffer, 0, buffer.size) + if (readCount == -1) { + break + } + if (readCount == 0) { + continue + } + + sink.write(buffer, 0, readCount) + downloadedBytes += readCount + onProgress(downloadedBytes, totalBytes) + } } - onProgress(bytes.size.toLong(), totalBytes ?: bytes.size.toLong()) + } + + private companion object { + const val BUFFER_SIZE_BYTES = 64 * 1024 } }