Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 27 additions & 79 deletions Sources/MechMuse/GeneratedCombatantTraits.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
//

import Foundation
import Parsing
import OpenAIClient

public struct GenerateCombatantTraitsRequest {
Expand All @@ -17,92 +16,41 @@ public struct GenerateCombatantTraitsRequest {
}
}

extension GenerateCombatantTraitsRequest: PromptConvertible {
public func prompt() -> [ChatMessage] {
let namesList = combatantNames.map { "\"\($0)\"" }.joined(separator: ", ")

// Prompt notes:
// - Spelling out D&D because it yields slightly longer responses
// - Without quotes around each combatant name, the discriminator could be omitted from the response
// - Added "Limit each value to a single sentence" to subdue the tendency to give a bulleted list when only
// a single combatant was in the request.
extension GenerateCombatantTraitsRequest {
func messages() -> [ChatMessage] {
let namesList = combatantNames.joined(separator: ", ")
return [
.init(role: .system, content: "You are helping a Dungeons & Dragons DM create awesome encounters."),
.init(role: .user, content: """
The encounter has \(combatantNames.count) monster(s): \(namesList). Come up with gritty physical and personality traits that don't interfere with its stats and unique nickname that fits its traits. One trait of each type for each monster, limit each trait to a single sentence.

Format your answer as a correct YAML sequence of maps, with an entry for each monster. Each entry has fields name, physical, personality, nickname.
"""
)
.init(role: .user, content: "Describe gritty physical and personality traits and a unique nickname for each of the following combatants: \(namesList). Respond in JSON with an object that has a `combatants` array of objects each containing `name`, `physical`, `personality`, and `nickname`.")
]
}
}

public enum GenerateCombatantTraitsResponse {
static let parser = Parse(input: Substring.self) {
Whitespace()

OneOf {
Parse(input: Substring.self) {
"```yaml"
Whitespace()
yamlParser
Whitespace()

Skip {
Optionally { "```" }
}
}
func prompt() -> [ChatMessage] { messages() }

yamlParser
}

Whitespace()
}

static let yamlParser = Many(into: [Traits]()) { acc, elem in
acc.append(elem)
} element: {
singleParser
} separator: {
Whitespace()
}

private static let singleParser = Parse(Traits.init(name:physical:personality:nickname:)) {
"- name: "
trimmedString
Whitespace()

StartsWith("Physical:", by: { $0.lowercased() == $1.lowercased() })
Whitespace()
trimmedString
Whitespace()

StartsWith("Personality:", by: { $0.lowercased() == $1.lowercased() })
Whitespace()
trimmedString
Whitespace()

StartsWith("Nickname:", by: { $0.lowercased() == $1.lowercased() })
Whitespace()
trimmedString
}

private static let trimmedString = Prefix<Substring> { !$0.isNewline }.map {
$0.trimmingCharacters(in: CharacterSet(["\"", "'", ".", " "]))
func chatRequest() -> ChatCompletionRequest {
ChatCompletionRequest(
messages: messages(),
responseFormat: .jsonObject,
maxTokens: 150 * max(combatantNames.count, 1),
temperature: 0.9
)
}
}

public struct Traits: Equatable, Hashable {
public let name: String
public let physical: String
public let personality: String
public let nickname: String
public struct GenerateCombatantTraitsResponse: Codable {
public var combatants: [Traits]
}

public init(name: String, physical: String, personality: String, nickname: String) {
self.name = name
self.physical = physical
self.personality = personality
self.nickname = nickname
}
public struct Traits: Codable, Equatable, Hashable {
public let name: String
public let physical: String
public let personality: String
public let nickname: String

public init(name: String, physical: String, personality: String, nickname: String) {
self.name = name
self.physical = physical
self.personality = personality
self.nickname = nickname
}
}
61 changes: 9 additions & 52 deletions Sources/MechMuse/MechMuse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,60 +86,17 @@ public extension MechMuse {
}
},
describeCombatants: { client, request in
assert(!request.combatantNames.isEmpty)
let prompt = request.prompt()

typealias TraitsArray = [GenerateCombatantTraitsResponse.Traits]
let endToken = "[Construct::END]"
do {
return try chain(client.stream(request: ChatCompletionRequest(
messages: prompt,
maxTokens: 150 * max(request.combatantNames.count, 1),
temperature: 0.9
)), [endToken].async)
// reduce the tokens into a growing (partial) response
.reductions(into: "", { acc, elem in
acc += elem
})
// parse every (partial) response
.map { acc -> TraitsArray in
if acc.hasSuffix(endToken) {

do {
let traits = try GenerateCombatantTraitsResponse.parser.parse(String(acc.dropLast(endToken.count)))

if traits.isEmpty {
throw MechMuseError.unspecified // is upgraded to .interpretationFailed below
}

// we're at the end of the response, add a dummy Traits that is removed in the next operator
return traits + [.init(name: "", physical: "", personality: "", nickname: "")]
} catch {
throw MechMuseError.interpretationFailed(
text: String(acc.dropLast(endToken.count)),
error: String(describing: error)
)
}
} else {
do {
return try GenerateCombatantTraitsResponse.parser.parse(acc)
} catch {
return []
}
}
let response = try await client.perform(request: request.chatRequest())
guard
let content = response.choices.first?.message.content,
let data = content.data(using: .utf8)
else {
throw MechMuseError.unspecified
}
// remember all seen traits
.reductions(into: (TraitsArray(), TraitsArray()), { traits, parsed in
// drop the last because it might be incomplete
let new = parsed.dropLast(1).filter { t in
!traits.0.contains(t)
}
traits = (traits.0 + new, new)
})
// emit just the new traits
.flatMap { (all, new) in
return new.async
}.stream

let wrapper = try JSONDecoder().decode(GenerateCombatantTraitsResponse.self, from: data)
return wrapper.combatants.async.stream
} catch let error as OpenAIError {
throw MechMuseError(from: error)
} catch {
Expand Down
5 changes: 5 additions & 0 deletions Sources/OpenAIClient/JSONSchema.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public struct JSONSchema: Codable, Equatable {
public var properties: [String: JSONSchema]?
public var required: [String]?
public var `enum`: [String]?
public var items: JSONSchema?
}

public extension JSONSchema {
Expand All @@ -31,4 +32,8 @@ public extension JSONSchema {
static func boolean(description: String? = nil) -> Self {
JSONSchema(type: "boolean", description: description)
}

static func array(of items: JSONSchema, description: String? = nil) -> Self {
JSONSchema(type: "array", description: description, properties: nil, required: nil, enum: nil, items: items)
}
}
6 changes: 3 additions & 3 deletions Sources/OpenAIClient/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import Foundation

public enum Model: String, Codable {
case gpt4 = "gpt-4"
/// Current standard GPT-4 model
case gpt4o = "gpt-4o"
/// Current standard GPT-3.5 model
case gpt35Turbo = "gpt-3.5-turbo"
case Davinci3 = "text-davinci-003"
case Curie1 = "text-curie-001"
}
16 changes: 15 additions & 1 deletion Sources/OpenAIClient/OpenAIClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,23 @@ public struct ChatCompletionRequest: Encodable, Equatable {
public var messages: [ChatMessage]
public var functions: [Function]?
public var functionCall: String?
public var responseFormat: ResponseFormat?
public let maxTokens: Int?
public let temperature: Float?
var stream: Bool = false

public init(
model: Model = .gpt35Turbo,
model: Model = .gpt4o,
messages: [ChatMessage],
functions: [Function]? = nil,
responseFormat: ResponseFormat? = nil,
maxTokens: Int? = nil,
temperature: Float? = nil
) {
self.model = model
self.messages = messages
self.functions = functions
self.responseFormat = responseFormat
self.maxTokens = maxTokens
self.temperature = temperature
}
Expand Down Expand Up @@ -278,6 +281,17 @@ public struct ChatCompletionRequest: Encodable, Equatable {
var name: String
}
}

public struct ResponseFormat: Encodable, Equatable {
public var type: String

public init(type: String) {
self.type = type
}

public static let jsonObject = Self(type: "json_object")
public static let text = Self(type: "text")
}
}

public struct ChatCompletionChunkResponse: Codable, Equatable {
Expand Down
12 changes: 3 additions & 9 deletions Tests/MechMuseTests/EncounterCombatantsTraitsTest.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//
// EncounterCombatantsDescriptionTest.swift
//
//
//
// Created by Thomas Visser on 30/12/2022.
//
Expand All @@ -15,15 +15,9 @@ final class EncounterCombatantsTraitsTest: XCTestCase {

func testPrompt() {
let request = GenerateCombatantTraitsRequest(combatantNames: ["Goblin 1", "Goblin 2", "Bugbear 1"])
XCTAssertNoDifference(request.prompt(), [
XCTAssertNoDifference(request.messages(), [
.init(role: .system, content: "You are helping a Dungeons & Dragons DM create awesome encounters."),
.init(role: .user, content: """
The encounter has 3 monster(s): "Goblin 1", "Goblin 2", "Bugbear 1". Come up with gritty physical and personality traits that don't interfere with its stats and unique nickname that fits its traits. One trait of each type for each monster, limit each trait to a single sentence.

Format your answer as a correct YAML sequence of maps, with an entry for each monster. Each entry has fields name, physical, personality, nickname.
"""
)
.init(role: .user, content: "Describe gritty physical and personality traits and a unique nickname for each of the following combatants: Goblin 1, Goblin 2, Bugbear 1. Respond in JSON with an object that has a `combatants` array of objects each containing `name`, `physical`, `personality`, and `nickname`.")
])
}

}
18 changes: 9 additions & 9 deletions Tests/OpenAIClientTests/OpenAIClientTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@ final class OpenAIClientTest: XCTestCase {
func testCompletion() async throws {
// fake the response
let responseData = """
{"id":"cmpl-6JmBNcxha3k89zV2L6XzBXZdt5gb0","object":"text_completion","created":1670171465,"model":"text-davinci-003","choices":[{"text":"\\n\\nThis is indeed a test.","index":0,"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":8,"total_tokens":13}}
{"id":"cmpl-6JmBNcxha3k89zV2L6XzBXZdt5gb0","object":"text_completion","created":1670171465,"model":"gpt-4o","choices":[{"text":"\\n\\nThis is indeed a test.","index":0,"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":8,"total_tokens":13}}
""".data(using: .utf8)!
httpClient.dataResponse = (responseData, HTTPURLResponse())

let response = try await sut.perform(request: CompletionRequest(
model: .Davinci3,
model: .gpt4o,
prompt: "Say this is a test"
))

// Assert serialized request
let requestData = """
{"model":"text-davinci-003","stream":false,"prompt":"Say this is a test"}
{"model":"gpt-4o","stream":false,"prompt":"Say this is a test"}
""".data(using: .utf8)!
try XCTAssertNoDifferenceJSONData(requestData, httpClient.dataRequests.last?.httpBody)

Expand All @@ -47,7 +47,7 @@ final class OpenAIClientTest: XCTestCase {
id: "cmpl-6JmBNcxha3k89zV2L6XzBXZdt5gb0",
object: "text_completion",
created: 1670171465,
model: "text-davinci-003",
model: "gpt-4o",
choices: [
.init(text: "\n\nThis is indeed a test.", finishReason: "stop")
]
Expand All @@ -58,18 +58,18 @@ final class OpenAIClientTest: XCTestCase {
// fake the response
httpClient.streamResponse = [
"""
{"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
{"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-4o","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}
""",
"""
{"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"As"},"index":0,"finish_reason":null}]}
{"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-4o","choices":[{"delta":{"content":"As"},"index":0,"finish_reason":null}]}
""",
"""
{"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":" the"},"index":0,"finish_reason":null}]}
{"id":"chatcmpl-6qJTvkR92tgHsu9nvdFSJSdhzk4yO","object":"chat.completion.chunk","created":1677925963,"model":"gpt-4o","choices":[{"delta":{"content":" the"},"index":0,"finish_reason":null}]}
"""
].async.stream

let response = try sut.stream(request: ChatCompletionRequest(
model: .gpt35Turbo,
model: .gpt4o,
messages: [
.init(role: .system, content: "You are a D&D DM"),
.init(role: .user, content: "Narrate the attack of a goblin")
Expand All @@ -78,7 +78,7 @@ final class OpenAIClientTest: XCTestCase {

// Assert serialized request
let requestData = """
{"model":"gpt-3.5-turbo","stream":true,"messages":[{"content":"You are a D&D DM","role":"system"},{"content":"Narrate the attack of a goblin","role":"user"}]}
{"model":"gpt-4o","stream":true,"messages":[{"content":"You are a D&D DM","role":"system"},{"content":"Narrate the attack of a goblin","role":"user"}]}
""".data(using: .utf8)!
try XCTAssertNoDifferenceJSONData(requestData, httpClient.streamRequests.last?.httpBody)

Expand Down
Loading