diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/PromptRunner.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/PromptRunner.kt index ea30c2dba..f7a2d45b6 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/PromptRunner.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/PromptRunner.kt @@ -28,6 +28,8 @@ import com.embabel.common.ai.prompt.PromptContributor import com.embabel.common.ai.prompt.PromptElement import com.embabel.common.util.loggerFor import org.jetbrains.annotations.ApiStatus +import java.util.function.Predicate +import kotlin.reflect.KProperty /** * Define a handoff to a subagent. @@ -297,6 +299,33 @@ interface PromptRunner : LlmUse, PromptRunnerOperations { */ fun withGenerateExamples(generateExamples: Boolean): PromptRunner + /** + * Adds a filter that determines which properties are to be included when creating an object. + * + * Note that each predicate is applied *in addition to* previously registered predicates, including + * [withProperties] and [withoutProperties]. + * @param filter the property predicate to be added + */ + fun withPropertyFilter(filter: Predicate): PromptRunner + + /** + * Includes the given properties when creating an object. + * + * Note that each predicate is applied *in addition to* previously registered predicates, including + * [withPropertyFilter] and [withoutProperties]. + * @param properties the properties that are to be included + */ + fun withProperties(vararg properties: String): PromptRunner = withPropertyFilter { properties.contains(it) } + + /** + * Excludes the given properties when creating an object. + * + * Note that each predicate is applied *in addition to* previously registered predicates, including + * [withPropertyFilter] and [withProperties]. + * @param properties the properties that are to be included + */ + fun withoutProperties(vararg properties: String): PromptRunner = withPropertyFilter { !properties.contains(it) } + /** * Create an object creator for the given output class. * Allows setting strongly typed examples. @@ -325,3 +354,13 @@ inline fun TemplateOperations.createObject( model: Map, ): T = createObject(outputClass = T::class.java, model = model) + +fun PromptRunner.withProperties( + vararg properties: KProperty +): PromptRunner = + withProperties(*properties.map { it.name }.toTypedArray()) + +fun PromptRunner.withoutProperties( + vararg properties: KProperty +): PromptRunner = + withoutProperties(*properties.map { it.name }.toTypedArray()) diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt index daf5d89fa..f2a90e2c5 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/api/common/support/OperationContextPromptRunner.kt @@ -33,6 +33,7 @@ import com.embabel.common.ai.prompt.PromptContributor import com.embabel.common.core.types.ZeroToOne import com.embabel.common.util.loggerFor import org.springframework.ai.tool.ToolCallback +import java.util.function.Predicate /** * Uses the platform's LlmOperations to execute the prompt. @@ -48,6 +49,7 @@ internal data class OperationContextPromptRunner( override val promptContributors: List, private val contextualPromptContributors: List, override val generateExamples: Boolean?, + override val propertyFilter: Predicate = Predicate { true }, private val otherToolCallbacks: List = emptyList(), ) : PromptRunner { @@ -83,6 +85,7 @@ internal data class OperationContextPromptRunner( }, id = interactionId ?: idForPrompt(messages, outputClass), generateExamples = generateExamples, + propertyFilter = propertyFilter, ), outputClass = outputClass, agentProcess = context.processContext.agentProcess, @@ -107,6 +110,7 @@ internal data class OperationContextPromptRunner( }, id = interactionId ?: idForPrompt(messages, outputClass), generateExamples = generateExamples, + propertyFilter = propertyFilter, ), outputClass = outputClass, agentProcess = context.processContext.agentProcess, @@ -225,6 +229,9 @@ internal data class OperationContextPromptRunner( override fun withGenerateExamples(generateExamples: Boolean): PromptRunner = copy(generateExamples = generateExamples) + override fun withPropertyFilter(filter: Predicate): PromptRunner = + copy(propertyFilter = this.propertyFilter.and(filter)) + override fun creating(outputClass: Class): ObjectCreator { return PromptRunnerObjectCreator( promptRunner = this, diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/LlmOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/LlmOperations.kt index 31666b65f..32fb77251 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/LlmOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/LlmOperations.kt @@ -29,6 +29,7 @@ import com.embabel.common.core.types.HasInfoString import com.embabel.common.util.indent import jakarta.validation.ConstraintViolation import org.springframework.ai.tool.ToolCallback +import java.util.function.Predicate /** * Spec for calling an LLM. Optional LlmOptions, @@ -43,6 +44,11 @@ interface LlmUse : PromptContributorConsumer, ToolGroupConsumer { */ val generateExamples: Boolean? + /** + * Filter that determines which properties to include when creating objects. + */ + val propertyFilter: Predicate + } /** @@ -76,6 +82,7 @@ private data class LlmCallImpl( override val promptContributors: List = emptyList(), override val contextualPromptContributors: List = emptyList(), override val generateExamples: Boolean = false, + override val propertyFilter: Predicate = Predicate { true }, ) : LlmCall /** @@ -101,6 +108,7 @@ data class LlmInteraction( override val promptContributors: List = emptyList(), override val contextualPromptContributors: List = emptyList(), override val generateExamples: Boolean? = null, + override val propertyFilter: Predicate = Predicate { true }, ) : LlmCall { override val name: String = id.value diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt index d445c315d..2a6cd13ee 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/ChatClientLlmOperations.kt @@ -45,7 +45,6 @@ import org.springframework.ai.chat.messages.SystemMessage import org.springframework.ai.chat.messages.UserMessage import org.springframework.ai.chat.model.ChatResponse import org.springframework.ai.chat.prompt.Prompt -import org.springframework.ai.converter.BeanOutputConverter import org.springframework.context.ApplicationContext import org.springframework.core.ParameterizedTypeReference import org.springframework.retry.support.RetrySynchronizationManager @@ -208,7 +207,11 @@ internal class ChatClientLlmOperations( expectedType = outputClass, delegate = WithExampleConverter( delegate = SuppressThinkingConverter( - BeanOutputConverter(outputClass, objectMapper) + FilteringJacksonOutputConverter( + clazz = outputClass, + objectMapper = objectMapper, + propertyFilter = interaction.propertyFilter, + ) ), outputClass = outputClass, ifPossible = false, @@ -346,7 +349,11 @@ internal class ChatClientLlmOperations( expectedType = MaybeReturn::class.java, delegate = WithExampleConverter( delegate = SuppressThinkingConverter( - BeanOutputConverter(typeReference, objectMapper) + FilteringJacksonOutputConverter( + typeReference = typeReference, + objectMapper = objectMapper, + propertyFilter = interaction.propertyFilter, + ) ), outputClass = outputClass as Class>, ifPossible = true, diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/FilteringJacksonOutputConverter.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/FilteringJacksonOutputConverter.kt new file mode 100644 index 000000000..428700cb2 --- /dev/null +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/FilteringJacksonOutputConverter.kt @@ -0,0 +1,57 @@ +/* + * Copyright 2024-2025 Embabel Software, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.spi.support.springai + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode +import org.springframework.core.ParameterizedTypeReference +import java.lang.reflect.Type +import java.util.function.Predicate + +/** + * Extension of [JacksonOutputConverter] that allows for filtering of properties of the generated object via a predicate. + */ +class FilteringJacksonOutputConverter private constructor( + type: Type, + objectMapper: ObjectMapper, + private val propertyFilter: Predicate, +) : JacksonOutputConverter(type, objectMapper) { + + constructor( + clazz: Class, + objectMapper: ObjectMapper, + propertyFilter: Predicate, + ) : this(clazz as Type, objectMapper, propertyFilter) + + constructor( + typeReference: ParameterizedTypeReference, + objectMapper: ObjectMapper, + propertyFilter: Predicate, + ) : this(typeReference.type, objectMapper, propertyFilter) + + override fun postProcessSchema(jsonNode: JsonNode) { + val propertiesNode = jsonNode.get("properties") as? ObjectNode ?: return + + val fieldNames = propertiesNode.fieldNames() as MutableIterator + while (fieldNames.hasNext()) { + val fieldName = fieldNames.next() + if (!this.propertyFilter.test(fieldName)) { + fieldNames.remove() + } + } + } +} diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/JacksonOutputConverter.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/JacksonOutputConverter.kt new file mode 100644 index 000000000..16824f94b --- /dev/null +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/spi/support/springai/JacksonOutputConverter.kt @@ -0,0 +1,123 @@ +/* + * Copyright 2024-2025 Embabel Software, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.spi.support.springai + +import com.fasterxml.jackson.core.JsonProcessingException +import com.fasterxml.jackson.core.util.DefaultIndenter +import com.fasterxml.jackson.core.util.DefaultPrettyPrinter +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.ObjectMapper +import com.github.victools.jsonschema.generator.* +import com.github.victools.jsonschema.module.jackson.JacksonModule +import com.github.victools.jsonschema.module.jackson.JacksonOption +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import org.springframework.ai.converter.StructuredOutputConverter +import org.springframework.ai.util.LoggingMarkers +import org.springframework.core.ParameterizedTypeReference +import java.lang.reflect.Type + +/** + * A Kotlin version of [org.springframework.ai.converter.BeanOutputConverter] that allows for customization + * of the used schema via [postProcessSchema] + */ +open class JacksonOutputConverter protected constructor( + private val type: Type, + val objectMapper: ObjectMapper, +) : StructuredOutputConverter { + + constructor( + clazz: Class, + objectMapper: ObjectMapper, + ) : this(clazz as Type, objectMapper) + + constructor( + typeReference: ParameterizedTypeReference, + objectMapper: ObjectMapper, + ) : this(typeReference.type, objectMapper) + + protected val logger: Logger = LoggerFactory.getLogger(javaClass) + + val jsonSchema: String by lazy { + val jacksonModule = JacksonModule( + JacksonOption.RESPECT_JSONPROPERTY_REQUIRED, + JacksonOption.RESPECT_JSONPROPERTY_ORDER + ) + val configBuilder = SchemaGeneratorConfigBuilder( + SchemaVersion.DRAFT_2020_12, + OptionPreset.PLAIN_JSON + ) + .with(jacksonModule) + .with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT) + val config = configBuilder.build() + val generator = SchemaGenerator(config) + val jsonNode: JsonNode = generator.generateSchema(this.type) + postProcessSchema(jsonNode) + val objectWriter = this.objectMapper.writer( + DefaultPrettyPrinter() + .withObjectIndenter(DefaultIndenter().withLinefeed(System.lineSeparator())) + ) + try { + objectWriter.writeValueAsString(jsonNode) + } catch (e: JsonProcessingException) { + logger.error("Could not pretty print json schema for jsonNode: {}", jsonNode) + throw RuntimeException("Could not pretty print json schema for " + this.type, e) + } + } + + /** + * Empty template method that allows for customization of the JSON schema in subclasses. + * @param jsonNode the JSON schema, in the form of a JSON node + */ + protected open fun postProcessSchema(jsonNode: JsonNode) { + } + + override fun convert(text: String): T? { + val unwrapped = unwrapJson(text) + try { + return this.objectMapper.readValue(unwrapped, this.objectMapper.constructType(this.type)) as T? + } catch (e: JsonProcessingException) { + logger.error( + LoggingMarkers.SENSITIVE_DATA_MARKER, + "Could not parse the given text to the desired target type: \"{}\" into {}", unwrapped, this.type + ) + throw RuntimeException(e) + } + } + + private fun unwrapJson(text: String): String { + var result = text.trim() + + if (result.startsWith("```") && result.endsWith("```")) { + result = result.removePrefix("```json") + .removePrefix("```") + .removeSuffix("```") + .trim() + } + + return result + } + + override fun getFormat(): String = + """| + |Your response should be in JSON format. + |Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation. + |Do not include markdown code blocks in your response. + |Remove the ```json markdown from the output. + |Here is the JSON Schema instance your output must adhere to: + |```${jsonSchema}``` + |""".trimMargin() +} diff --git a/embabel-agent-api/src/main/kotlin/com/embabel/agent/testing/unit/FakePromptRunner.kt b/embabel-agent-api/src/main/kotlin/com/embabel/agent/testing/unit/FakePromptRunner.kt index 346e1c844..20e1ca95d 100644 --- a/embabel-agent-api/src/main/kotlin/com/embabel/agent/testing/unit/FakePromptRunner.kt +++ b/embabel-agent-api/src/main/kotlin/com/embabel/agent/testing/unit/FakePromptRunner.kt @@ -30,6 +30,7 @@ import com.embabel.common.core.types.ZeroToOne import com.embabel.common.textio.template.JinjavaTemplateRenderer import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import org.slf4j.LoggerFactory +import java.util.function.Predicate enum class Method { CREATE_OBJECT, @@ -51,6 +52,7 @@ data class FakePromptRunner( override val promptContributors: List, private val contextualPromptContributors: List, override val generateExamples: Boolean?, + override val propertyFilter: Predicate = Predicate { true }, private val context: OperationContext, private val _llmInvocations: MutableList = mutableListOf(), private val responses: MutableList = mutableListOf(), @@ -172,6 +174,10 @@ data class FakePromptRunner( override fun withGenerateExamples(generateExamples: Boolean): PromptRunner = copy(generateExamples = generateExamples) + override fun withPropertyFilter(filter: Predicate): PromptRunner = + copy(propertyFilter = this.propertyFilter.and(filter)) + + private fun createLlmInteraction() = LlmInteraction( llm = llm ?: LlmOptions(), diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerTest.kt index e25bc9601..296e72f84 100644 --- a/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerTest.kt +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/api/common/OperationContextPromptRunnerTest.kt @@ -474,6 +474,119 @@ class OperationContextPromptRunnerTest { } } + + @Nested + inner class WithPropertyFilter { + + @Test + fun `test property filter`() { + val pr = createOperationContextPromptRunnerWithDefaults(FakeOperationContext()) + .withPropertyFilter { it == "name" || it == "age" } + + assertTrue(pr.propertyFilter.test("name")) + assertTrue(pr.propertyFilter.test("age")) + assertFalse(pr.propertyFilter.test("email")) + } + + @Test + fun `test chain multiple property filters`() { + val pr = createOperationContextPromptRunnerWithDefaults(FakeOperationContext()) + .withPropertyFilter { it == "name" || it == "age" || it == "email" } + .withPropertyFilter { it != "email" } + + assertTrue(pr.propertyFilter.test("name")) + assertTrue(pr.propertyFilter.test("age")) + assertFalse(pr.propertyFilter.test("email")) + assertFalse(pr.propertyFilter.test("address")) + } + + @Test + fun `test default filter`() { + val pr = createOperationContextPromptRunnerWithDefaults(FakeOperationContext()) + + assertTrue(pr.propertyFilter.test("name")) + assertTrue(pr.propertyFilter.test("age")) + assertTrue(pr.propertyFilter.test("email")) + assertTrue(pr.propertyFilter.test("anyProperty")) + } + } + + + @Nested + inner class WithProperties { + + @Test + fun `test varargs syntax`() { + val pr = createOperationContextPromptRunnerWithDefaults(FakeOperationContext()) + .withProperties("name", "age") + + assertTrue(pr.propertyFilter.test("name")) + assertTrue(pr.propertyFilter.test("age")) + assertFalse(pr.propertyFilter.test("email")) + assertFalse(pr.propertyFilter.test("address")) + } + + @Test + fun `test KProperty syntax`() { + val pr = createOperationContextPromptRunnerWithDefaults(FakeOperationContext()) + .withProperties(Dog::name) + + assertTrue(pr.propertyFilter.test("name")) + assertFalse(pr.propertyFilter.test("age")) + } + + @Test + fun `test chain with withoutProperties`() { + val pr = createOperationContextPromptRunnerWithDefaults(FakeOperationContext()) + .withProperties("name", "age", "email") + .withoutProperties("email") + + assertTrue(pr.propertyFilter.test("name")) + assertTrue(pr.propertyFilter.test("age")) + assertFalse(pr.propertyFilter.test("email")) + assertFalse(pr.propertyFilter.test("address")) + } + + } + + + @Nested + inner class WithoutProperties { + + @Test + fun `test varargs syntax`() { + val pr = createOperationContextPromptRunnerWithDefaults(FakeOperationContext()) + .withoutProperties("email", "address") + + assertTrue(pr.propertyFilter.test("name")) + assertTrue(pr.propertyFilter.test("age")) + assertFalse(pr.propertyFilter.test("email")) + assertFalse(pr.propertyFilter.test("address")) + } + + @Test + fun `test KProperty syntax`() { + val pr = createOperationContextPromptRunnerWithDefaults(FakeOperationContext()) + .withoutProperties(Dog::name) + + assertFalse(pr.propertyFilter.test("name")) + assertTrue(pr.propertyFilter.test("age")) + } + + @Test + fun `test chain multiple`() { + val pr = createOperationContextPromptRunnerWithDefaults(FakeOperationContext()) + .withoutProperties("email") + .withoutProperties("address") + + assertTrue(pr.propertyFilter.test("name")) + assertTrue(pr.propertyFilter.test("age")) + assertFalse(pr.propertyFilter.test("email")) + assertFalse(pr.propertyFilter.test("address")) + } + } + + } diff --git a/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/FilteringJacksonOutputConverterTest.kt b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/FilteringJacksonOutputConverterTest.kt new file mode 100644 index 000000000..55f135a69 --- /dev/null +++ b/embabel-agent-api/src/test/kotlin/com/embabel/agent/spi/support/springai/FilteringJacksonOutputConverterTest.kt @@ -0,0 +1,66 @@ +/* + * Copyright 2024-2025 Embabel Software, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.embabel.agent.spi.support.springai + +import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test + +class FilteringJacksonOutputConverterTest { + + private val objectMapper = jacksonObjectMapper() + + data class Person( + val name: String, + val age: Int, + val email: String, + val address: String + ) + + @Test + fun `test schema should include only specified properties`() { + val converter = FilteringJacksonOutputConverter( + clazz = Person::class.java, + objectMapper = objectMapper, + propertyFilter = { it == "name" || it == "age" } + ) + + val schema = converter.jsonSchema + + assertTrue(schema.contains("name")) + assertTrue(schema.contains("age")) + assertFalse(schema.contains("email")) + assertFalse(schema.contains("address")) + } + + @Test + fun `test schema should exclude specified properties`() { + val converter = FilteringJacksonOutputConverter( + clazz = Person::class.java, + objectMapper = objectMapper, + propertyFilter = { it != "email" && it != "address" } + ) + + val schema = converter.jsonSchema + + assertTrue(schema.contains("name")) + assertTrue(schema.contains("age")) + assertFalse(schema.contains("email")) + assertFalse(schema.contains("address")) + } + +}