From 159ed715d159a57ff57015fd36098c62f29054b7 Mon Sep 17 00:00:00 2001 From: Thomas Visser Date: Wed, 4 Jun 2025 07:19:06 +0200 Subject: [PATCH] Use structured output for trait generation --- .../MechMuse/GeneratedCombatantTraits.swift | 106 +++++------------- Sources/MechMuse/MechMuse.swift | 61 ++-------- Sources/OpenAIClient/JSONSchema.swift | 5 + Sources/OpenAIClient/Models.swift | 6 +- Sources/OpenAIClient/OpenAIClient.swift | 16 ++- .../EncounterCombatantsTraitsTest.swift | 12 +- .../OpenAIClientTests/OpenAIClientTest.swift | 18 +-- 7 files changed, 71 insertions(+), 153 deletions(-) diff --git a/Sources/MechMuse/GeneratedCombatantTraits.swift b/Sources/MechMuse/GeneratedCombatantTraits.swift index fcafc7c..1d94bee 100644 --- a/Sources/MechMuse/GeneratedCombatantTraits.swift +++ b/Sources/MechMuse/GeneratedCombatantTraits.swift @@ -6,7 +6,6 @@ // import Foundation -import Parsing import OpenAIClient public struct GenerateCombatantTraitsRequest { @@ -17,92 +16,41 @@ public struct GenerateCombatantTraitsRequest { } } -extension GenerateCombatantTraitsRequest: PromptConvertible { - public func prompt() -> [ChatMessage] { - let namesList = combatantNames.map { "\"\($0)\"" }.joined(separator: ", ") - - // Prompt notes: - // - Spelling out D&D because it yields slightly longer responses - // - Without quotes around each combatant name, the discriminator could be omitted from the response - // - Added "Limit each value to a single sentence" to subdue the tendency to give a bulleted list when only - // a single combatant was in the request. +extension GenerateCombatantTraitsRequest { + func messages() -> [ChatMessage] { + let namesList = combatantNames.joined(separator: ", ") return [ .init(role: .system, content: "You are helping a Dungeons & Dragons DM create awesome encounters."), - .init(role: .user, content: """ - The encounter has \(combatantNames.count) monster(s): \(namesList). Come up with gritty physical and personality traits that don't interfere with its stats and unique nickname that fits its traits. One trait of each type for each monster, limit each trait to a single sentence. - - Format your answer as a correct YAML sequence of maps, with an entry for each monster. Each entry has fields name, physical, personality, nickname. - """ - ) + .init(role: .user, content: "Describe gritty physical and personality traits and a unique nickname for each of the following combatants: \(namesList). Respond in JSON with an object that has a `combatants` array of objects each containing `name`, `physical`, `personality`, and `nickname`.") ] } -} - -public enum GenerateCombatantTraitsResponse { - static let parser = Parse(input: Substring.self) { - Whitespace() - - OneOf { - Parse(input: Substring.self) { - "```yaml" - Whitespace() - yamlParser - Whitespace() - Skip { - Optionally { "```" } - } - } + func prompt() -> [ChatMessage] { messages() } - yamlParser - } - - Whitespace() - } - - static let yamlParser = Many(into: [Traits]()) { acc, elem in - acc.append(elem) - } element: { - singleParser - } separator: { - Whitespace() - } - - private static let singleParser = Parse(Traits.init(name:physical:personality:nickname:)) { - "- name: " - trimmedString - Whitespace() - - StartsWith("Physical:", by: { $0.lowercased() == $1.lowercased() }) - Whitespace() - trimmedString - Whitespace() - - StartsWith("Personality:", by: { $0.lowercased() == $1.lowercased() }) - Whitespace() - trimmedString - Whitespace() - - StartsWith("Nickname:", by: { $0.lowercased() == $1.lowercased() }) - Whitespace() - trimmedString - } - - private static let trimmedString = Prefix { !$0.isNewline }.map { - $0.trimmingCharacters(in: CharacterSet(["\"", "'", ".", " "])) + func chatRequest() -> ChatCompletionRequest { + ChatCompletionRequest( + messages: messages(), + responseFormat: .jsonObject, + maxTokens: 150 * max(combatantNames.count, 1), + temperature: 0.9 + ) } +} - public struct Traits: Equatable, Hashable { - public let name: String - public let physical: String - public let personality: String - public let nickname: String +public struct GenerateCombatantTraitsResponse: Codable { + public var combatants: [Traits] +} - public init(name: String, physical: String, personality: String, nickname: String) { - self.name = name - self.physical = physical - self.personality = personality - self.nickname = nickname - } +public struct Traits: Codable, Equatable, Hashable { + public let name: String + public let physical: String + public let personality: String + public let nickname: String + + public init(name: String, physical: String, personality: String, nickname: String) { + self.name = name + self.physical = physical + self.personality = personality + self.nickname = nickname } } diff --git a/Sources/MechMuse/MechMuse.swift b/Sources/MechMuse/MechMuse.swift index 8803444..8fc4e38 100644 --- a/Sources/MechMuse/MechMuse.swift +++ b/Sources/MechMuse/MechMuse.swift @@ -86,60 +86,17 @@ public extension MechMuse { } }, describeCombatants: { client, request in - assert(!request.combatantNames.isEmpty) - let prompt = request.prompt() - - typealias TraitsArray = [GenerateCombatantTraitsResponse.Traits] - let endToken = "[Construct::END]" do { - return try chain(client.stream(request: ChatCompletionRequest( - messages: prompt, - maxTokens: 150 * max(request.combatantNames.count, 1), - temperature: 0.9 - )), [endToken].async) - // reduce the tokens into a growing (partial) response - .reductions(into: "", { acc, elem in - acc += elem - }) - // parse every (partial) response - .map { acc -> TraitsArray in - if acc.hasSuffix(endToken) { - - do { - let traits = try GenerateCombatantTraitsResponse.parser.parse(String(acc.dropLast(endToken.count))) - - if traits.isEmpty { - throw MechMuseError.unspecified // is upgraded to .interpretationFailed below - } - - // we're at the end of the response, add a dummy Traits that is removed in the next operator - return traits + [.init(name: "", physical: "", personality: "", nickname: "")] - } catch { - throw MechMuseError.interpretationFailed( - text: String(acc.dropLast(endToken.count)), - error: String(describing: error) - ) - } - } else { - do { - return try GenerateCombatantTraitsResponse.parser.parse(acc) - } catch { - return [] - } - } + let response = try await client.perform(request: request.chatRequest()) + guard + let content = response.choices.first?.message.content, + let data = content.data(using: .utf8) + else { + throw MechMuseError.unspecified } - // remember all seen traits - .reductions(into: (TraitsArray(), TraitsArray()), { traits, parsed in - // drop the last because it might be incomplete - let new = parsed.dropLast(1).filter { t in - !traits.0.contains(t) - } - traits = (traits.0 + new, new) - }) - // emit just the new traits - .flatMap { (all, new) in - return new.async - }.stream + + let wrapper = try JSONDecoder().decode(GenerateCombatantTraitsResponse.self, from: data) + return wrapper.combatants.async.stream } catch let error as OpenAIError { throw MechMuseError(from: error) } catch { diff --git a/Sources/OpenAIClient/JSONSchema.swift b/Sources/OpenAIClient/JSONSchema.swift index 02bb19f..8ddb04d 100644 --- a/Sources/OpenAIClient/JSONSchema.swift +++ b/Sources/OpenAIClient/JSONSchema.swift @@ -13,6 +13,7 @@ public struct JSONSchema: Codable, Equatable { public var properties: [String: JSONSchema]? public var required: [String]? public var `enum`: [String]? + public var items: JSONSchema? } public extension JSONSchema { @@ -31,4 +32,8 @@ public extension JSONSchema { static func boolean(description: String? = nil) -> Self { JSONSchema(type: "boolean", description: description) } + + static func array(of items: JSONSchema, description: String? = nil) -> Self { + JSONSchema(type: "array", description: description, properties: nil, required: nil, enum: nil, items: items) + } } diff --git a/Sources/OpenAIClient/Models.swift b/Sources/OpenAIClient/Models.swift index fd5c6db..3403190 100644 --- a/Sources/OpenAIClient/Models.swift +++ b/Sources/OpenAIClient/Models.swift @@ -8,8 +8,8 @@ import Foundation public enum Model: String, Codable { - case gpt4 = "gpt-4" + /// Current standard GPT-4 model + case gpt4o = "gpt-4o" + /// Current standard GPT-3.5 model case gpt35Turbo = "gpt-3.5-turbo" - case Davinci3 = "text-davinci-003" - case Curie1 = "text-curie-001" } diff --git a/Sources/OpenAIClient/OpenAIClient.swift b/Sources/OpenAIClient/OpenAIClient.swift index 4e45055..7240e52 100644 --- a/Sources/OpenAIClient/OpenAIClient.swift +++ b/Sources/OpenAIClient/OpenAIClient.swift @@ -220,20 +220,23 @@ public struct ChatCompletionRequest: Encodable, Equatable { public var messages: [ChatMessage] public var functions: [Function]? public var functionCall: String? + public var responseFormat: ResponseFormat? public let maxTokens: Int? public let temperature: Float? var stream: Bool = false public init( - model: Model = .gpt35Turbo, + model: Model = .gpt4o, messages: [ChatMessage], functions: [Function]? = nil, + responseFormat: ResponseFormat? = nil, maxTokens: Int? = nil, temperature: Float? = nil ) { self.model = model self.messages = messages self.functions = functions + self.responseFormat = responseFormat self.maxTokens = maxTokens self.temperature = temperature } @@ -278,6 +281,17 @@ public struct ChatCompletionRequest: Encodable, Equatable { var name: String } } + + public struct ResponseFormat: Encodable, Equatable { + public var type: String + + public init(type: String) { + self.type = type + } + + public static let jsonObject = Self(type: "json_object") + public static let text = Self(type: "text") + } } public struct ChatCompletionChunkResponse: Codable, Equatable { diff --git a/Tests/MechMuseTests/EncounterCombatantsTraitsTest.swift b/Tests/MechMuseTests/EncounterCombatantsTraitsTest.swift index 5565dfe..ba0d163 100644 --- a/Tests/MechMuseTests/EncounterCombatantsTraitsTest.swift +++ b/Tests/MechMuseTests/EncounterCombatantsTraitsTest.swift @@ -1,6 +1,6 @@ // // EncounterCombatantsDescriptionTest.swift -// +// // // Created by Thomas Visser on 30/12/2022. // @@ -15,15 +15,9 @@ final class EncounterCombatantsTraitsTest: XCTestCase { func testPrompt() { let request = GenerateCombatantTraitsRequest(combatantNames: ["Goblin 1", "Goblin 2", "Bugbear 1"]) - XCTAssertNoDifference(request.prompt(), [ + XCTAssertNoDifference(request.messages(), [ .init(role: .system, content: "You are helping a Dungeons & Dragons DM create awesome encounters."), - .init(role: .user, content: """ - The encounter has 3 monster(s): "Goblin 1", "Goblin 2", "Bugbear 1". Come up with gritty physical and personality traits that don't interfere with its stats and unique nickname that fits its traits. One trait of each type for each monster, limit each trait to a single sentence. - - Format your answer as a correct YAML sequence of maps, with an entry for each monster. Each entry has fields name, physical, personality, nickname. - """ - ) + .init(role: .user, content: "Describe gritty physical and personality traits and a unique nickname for each of the following combatants: Goblin 1, Goblin 2, Bugbear 1. Respond in JSON with an object that has a `combatants` array of objects each containing `name`, `physical`, `personality`, and `nickname`.") ]) } - } diff --git a/Tests/OpenAIClientTests/OpenAIClientTest.swift b/Tests/OpenAIClientTests/OpenAIClientTest.swift index d5592b5..1af1c45 100644 --- a/Tests/OpenAIClientTests/OpenAIClientTest.swift +++ b/Tests/OpenAIClientTests/OpenAIClientTest.swift @@ -27,18 +27,18 @@ final class OpenAIClientTest: XCTestCase { func testCompletion() async throws { // fake the response let responseData = """ - {"id":"cmpl-6JmBNcxha3k89zV2L6XzBXZdt5gb0","object":"text_completion","created":1670171465,"model":"text-davinci-003","choices":[{"text":"\\n\\nThis is indeed a test.","index":0,"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":8,"total_tokens":13}} + {"id":"cmpl-6JmBNcxha3k89zV2L6XzBXZdt5gb0","object":"text_completion","created":1670171465,"model":"gpt-4o","choices":[{"text":"\\n\\nThis is indeed a test.","index":0,"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":8,"total_tokens":13}} """.data(using: .utf8)! httpClient.dataResponse = (responseData, HTTPURLResponse()) let response = try await sut.perform(request: CompletionRequest( - model: .Davinci3, + model: .gpt4o, prompt: "Say this is a test" )) // Assert serialized request let requestData = """ - {"model":"text-davinci-003","stream":false,"prompt":"Say this is a test"} + {"model":"gpt-4o","stream":false,"prompt":"Say this is a test"} """.data(using: .utf8)! try XCTAssertNoDifferenceJSONData(requestData, httpClient.dataRequests.last?.httpBody) @@ -47,7 +47,7 @@ final class OpenAIClientTest: XCTestCase { id: "cmpl-6JmBNcxha3k89zV2L6XzBXZdt5gb0", object: "text_completion", created: 1670171465, - model: "text-davinci-003", + model: "gpt-4o", choices: [ .init(text: "\n\nThis is indeed a test.", finishReason: "stop") ] @@ -58,18 +58,18 @@ final class OpenAIClientTest: XCTestCase { // fake the response httpClient.streamResponse = [ """ - {"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]} + {"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-4o","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]} """, """ - {"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"As"},"index":0,"finish_reason":null}]} + {"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-4o","choices":[{"delta":{"content":"As"},"index":0,"finish_reason":null}]} """, """ - {"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":" the"},"index":0,"finish_reason":null}]} + {"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-4o","choices":[{"delta":{"content":" the"},"index":0,"finish_reason":null}]} """ ].async.stream let response = try sut.stream(request: ChatCompletionRequest( - model: .gpt35Turbo, + model: .gpt4o, messages: [ .init(role: .system, content: "You are a D&D DM"), .init(role: .user, content: "Narrate the attack of a goblin") @@ -78,7 +78,7 @@ final class OpenAIClientTest: XCTestCase { // Assert serialized request let requestData = """ - {"model":"gpt-3.5-turbo","stream":true,"messages":[{"content":"You are a D&D DM","role":"system"},{"content":"Narrate the attack of a goblin","role":"user"}]} + {"model":"gpt-4o","stream":true,"messages":[{"content":"You are a D&D DM","role":"system"},{"content":"Narrate the attack of a goblin","role":"user"}]} """.data(using: .utf8)! try XCTAssertNoDifferenceJSONData(requestData, httpClient.streamRequests.last?.httpBody)