Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import com.embabel.common.ai.prompt.PromptContributor
import com.embabel.common.ai.prompt.PromptElement
import com.embabel.common.util.loggerFor
import org.jetbrains.annotations.ApiStatus
import java.util.function.Predicate
import kotlin.reflect.KProperty

/**
* Define a handoff to a subagent.
Expand Down Expand Up @@ -297,6 +299,33 @@ interface PromptRunner : LlmUse, PromptRunnerOperations {
*/
fun withGenerateExamples(generateExamples: Boolean): PromptRunner

/**
* Adds a filter that determines which properties are to be included when creating an object.
*
* Note that each predicate is applied *in addition to* previously registered predicates, including
* [withProperties] and [withoutProperties].
* @param filter the property predicate to be added
*/
fun withPropertyFilter(filter: Predicate<String>): PromptRunner

/**
* Includes the given properties when creating an object.
*
* Note that each predicate is applied *in addition to* previously registered predicates, including
* [withPropertyFilter] and [withoutProperties].
* @param properties the properties that are to be included
*/
fun withProperties(vararg properties: String): PromptRunner = withPropertyFilter { properties.contains(it) }

/**
* Excludes the given properties when creating an object.
*
* Note that each predicate is applied *in addition to* previously registered predicates, including
* [withPropertyFilter] and [withProperties].
* @param properties the properties that are to be included
*/
fun withoutProperties(vararg properties: String): PromptRunner = withPropertyFilter { !properties.contains(it) }

/**
* Create an object creator for the given output class.
* Allows setting strongly typed examples.
Expand Down Expand Up @@ -325,3 +354,13 @@ inline fun <reified T> TemplateOperations.createObject(
model: Map<String, Any>,
): T =
createObject(outputClass = T::class.java, model = model)

fun PromptRunner.withProperties(
vararg properties: KProperty<Any>
): PromptRunner =
withProperties(*properties.map { it.name }.toTypedArray())

fun PromptRunner.withoutProperties(
vararg properties: KProperty<Any>
): PromptRunner =
withoutProperties(*properties.map { it.name }.toTypedArray())
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import com.embabel.common.ai.prompt.PromptContributor
import com.embabel.common.core.types.ZeroToOne
import com.embabel.common.util.loggerFor
import org.springframework.ai.tool.ToolCallback
import java.util.function.Predicate

/**
* Uses the platform's LlmOperations to execute the prompt.
Expand All @@ -48,6 +49,7 @@ internal data class OperationContextPromptRunner(
override val promptContributors: List<PromptContributor>,
private val contextualPromptContributors: List<ContextualPromptElement>,
override val generateExamples: Boolean?,
override val propertyFilter: Predicate<String> = Predicate { true },
private val otherToolCallbacks: List<ToolCallback> = emptyList(),
) : PromptRunner {

Expand Down Expand Up @@ -83,6 +85,7 @@ internal data class OperationContextPromptRunner(
},
id = interactionId ?: idForPrompt(messages, outputClass),
generateExamples = generateExamples,
propertyFilter = propertyFilter,
),
outputClass = outputClass,
agentProcess = context.processContext.agentProcess,
Expand All @@ -107,6 +110,7 @@ internal data class OperationContextPromptRunner(
},
id = interactionId ?: idForPrompt(messages, outputClass),
generateExamples = generateExamples,
propertyFilter = propertyFilter,
),
outputClass = outputClass,
agentProcess = context.processContext.agentProcess,
Expand Down Expand Up @@ -225,6 +229,9 @@ internal data class OperationContextPromptRunner(
override fun withGenerateExamples(generateExamples: Boolean): PromptRunner =
copy(generateExamples = generateExamples)

override fun withPropertyFilter(filter: Predicate<String>): PromptRunner =
copy(propertyFilter = this.propertyFilter.and(filter))

override fun <T> creating(outputClass: Class<T>): ObjectCreator<T> {
return PromptRunnerObjectCreator(
promptRunner = this,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import com.embabel.common.core.types.HasInfoString
import com.embabel.common.util.indent
import jakarta.validation.ConstraintViolation
import org.springframework.ai.tool.ToolCallback
import java.util.function.Predicate

/**
* Spec for calling an LLM. Optional LlmOptions,
Expand All @@ -43,6 +44,11 @@ interface LlmUse : PromptContributorConsumer, ToolGroupConsumer {
*/
val generateExamples: Boolean?

/**
* Filter that determines which properties to include when creating objects.
*/
val propertyFilter: Predicate<String>

}

/**
Expand Down Expand Up @@ -76,6 +82,7 @@ private data class LlmCallImpl(
override val promptContributors: List<PromptContributor> = emptyList(),
override val contextualPromptContributors: List<ContextualPromptElement> = emptyList(),
override val generateExamples: Boolean = false,
override val propertyFilter: Predicate<String> = Predicate { true },
) : LlmCall

/**
Expand All @@ -101,6 +108,7 @@ data class LlmInteraction(
override val promptContributors: List<PromptContributor> = emptyList(),
override val contextualPromptContributors: List<ContextualPromptElement> = emptyList(),
override val generateExamples: Boolean? = null,
override val propertyFilter: Predicate<String> = Predicate { true },
) : LlmCall {

override val name: String = id.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ import org.springframework.ai.chat.messages.SystemMessage
import org.springframework.ai.chat.messages.UserMessage
import org.springframework.ai.chat.model.ChatResponse
import org.springframework.ai.chat.prompt.Prompt
import org.springframework.ai.converter.BeanOutputConverter
import org.springframework.context.ApplicationContext
import org.springframework.core.ParameterizedTypeReference
import org.springframework.retry.support.RetrySynchronizationManager
Expand Down Expand Up @@ -208,7 +207,11 @@ internal class ChatClientLlmOperations(
expectedType = outputClass,
delegate = WithExampleConverter(
delegate = SuppressThinkingConverter(
BeanOutputConverter(outputClass, objectMapper)
FilteringJacksonOutputConverter(
clazz = outputClass,
objectMapper = objectMapper,
propertyFilter = interaction.propertyFilter,
)
),
outputClass = outputClass,
ifPossible = false,
Expand Down Expand Up @@ -346,7 +349,11 @@ internal class ChatClientLlmOperations(
expectedType = MaybeReturn::class.java,
delegate = WithExampleConverter(
delegate = SuppressThinkingConverter(
BeanOutputConverter(typeReference, objectMapper)
FilteringJacksonOutputConverter(
typeReference = typeReference,
objectMapper = objectMapper,
propertyFilter = interaction.propertyFilter,
)
),
outputClass = outputClass as Class<MaybeReturn<*>>,
ifPossible = true,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2024-2025 Embabel Software, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.embabel.agent.spi.support.springai

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.node.ObjectNode
import org.springframework.core.ParameterizedTypeReference
import java.lang.reflect.Type
import java.util.function.Predicate

/**
* Extension of [JacksonOutputConverter] that allows for filtering of properties of the generated object via a predicate.
*/
class FilteringJacksonOutputConverter<T> private constructor(
type: Type,
objectMapper: ObjectMapper,
private val propertyFilter: Predicate<String>,
) : JacksonOutputConverter<T>(type, objectMapper) {

constructor(
clazz: Class<T>,
objectMapper: ObjectMapper,
propertyFilter: Predicate<String>,
) : this(clazz as Type, objectMapper, propertyFilter)

constructor(
typeReference: ParameterizedTypeReference<T>,
objectMapper: ObjectMapper,
propertyFilter: Predicate<String>,
) : this(typeReference.type, objectMapper, propertyFilter)

override fun postProcessSchema(jsonNode: JsonNode) {
val propertiesNode = jsonNode.get("properties") as? ObjectNode ?: return

val fieldNames = propertiesNode.fieldNames() as MutableIterator<String>
while (fieldNames.hasNext()) {
val fieldName = fieldNames.next()
if (!this.propertyFilter.test(fieldName)) {
fieldNames.remove()
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright 2024-2025 Embabel Software, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.embabel.agent.spi.support.springai

import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.core.util.DefaultIndenter
import com.fasterxml.jackson.core.util.DefaultPrettyPrinter
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.github.victools.jsonschema.generator.*
import com.github.victools.jsonschema.module.jackson.JacksonModule
import com.github.victools.jsonschema.module.jackson.JacksonOption
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.springframework.ai.converter.StructuredOutputConverter
import org.springframework.ai.util.LoggingMarkers
import org.springframework.core.ParameterizedTypeReference
import java.lang.reflect.Type

/**
* A Kotlin version of [org.springframework.ai.converter.BeanOutputConverter] that allows for customization
* of the used schema via [postProcessSchema]
*/
open class JacksonOutputConverter<T> protected constructor(
private val type: Type,
val objectMapper: ObjectMapper,
) : StructuredOutputConverter<T> {

constructor(
clazz: Class<T>,
objectMapper: ObjectMapper,
) : this(clazz as Type, objectMapper)

constructor(
typeReference: ParameterizedTypeReference<T>,
objectMapper: ObjectMapper,
) : this(typeReference.type, objectMapper)

protected val logger: Logger = LoggerFactory.getLogger(javaClass)

val jsonSchema: String by lazy {
val jacksonModule = JacksonModule(
JacksonOption.RESPECT_JSONPROPERTY_REQUIRED,
JacksonOption.RESPECT_JSONPROPERTY_ORDER
)
val configBuilder = SchemaGeneratorConfigBuilder(
SchemaVersion.DRAFT_2020_12,
OptionPreset.PLAIN_JSON
)
.with(jacksonModule)
.with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT)
val config = configBuilder.build()
val generator = SchemaGenerator(config)
val jsonNode: JsonNode = generator.generateSchema(this.type)
postProcessSchema(jsonNode)
val objectWriter = this.objectMapper.writer(
DefaultPrettyPrinter()
.withObjectIndenter(DefaultIndenter().withLinefeed(System.lineSeparator()))
)
try {
objectWriter.writeValueAsString(jsonNode)
} catch (e: JsonProcessingException) {
logger.error("Could not pretty print json schema for jsonNode: {}", jsonNode)
throw RuntimeException("Could not pretty print json schema for " + this.type, e)
}
}

/**
* Empty template method that allows for customization of the JSON schema in subclasses.
* @param jsonNode the JSON schema, in the form of a JSON node
*/
protected open fun postProcessSchema(jsonNode: JsonNode) {
}

override fun convert(text: String): T? {
val unwrapped = unwrapJson(text)
try {
return this.objectMapper.readValue<Any?>(unwrapped, this.objectMapper.constructType(this.type)) as T?
} catch (e: JsonProcessingException) {
logger.error(
LoggingMarkers.SENSITIVE_DATA_MARKER,
"Could not parse the given text to the desired target type: \"{}\" into {}", unwrapped, this.type
)
throw RuntimeException(e)
}
}

private fun unwrapJson(text: String): String {
var result = text.trim()

if (result.startsWith("```") && result.endsWith("```")) {
result = result.removePrefix("```json")
.removePrefix("```")
.removeSuffix("```")
.trim()
}

return result
}

override fun getFormat(): String =
"""|
|Your response should be in JSON format.
|Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
|Do not include markdown code blocks in your response.
|Remove the ```json markdown from the output.
|Here is the JSON Schema instance your output must adhere to:
|```${jsonSchema}```
|""".trimMargin()
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.embabel.common.core.types.ZeroToOne
import com.embabel.common.textio.template.JinjavaTemplateRenderer
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import org.slf4j.LoggerFactory
import java.util.function.Predicate

enum class Method {
CREATE_OBJECT,
Expand All @@ -51,6 +52,7 @@ data class FakePromptRunner(
override val promptContributors: List<PromptContributor>,
private val contextualPromptContributors: List<ContextualPromptElement>,
override val generateExamples: Boolean?,
override val propertyFilter: Predicate<String> = Predicate { true },
private val context: OperationContext,
private val _llmInvocations: MutableList<LlmInvocation> = mutableListOf(),
private val responses: MutableList<Any?> = mutableListOf(),
Expand Down Expand Up @@ -172,6 +174,10 @@ data class FakePromptRunner(
override fun withGenerateExamples(generateExamples: Boolean): PromptRunner =
copy(generateExamples = generateExamples)

override fun withPropertyFilter(filter: Predicate<String>): PromptRunner =
copy(propertyFilter = this.propertyFilter.and(filter))


private fun createLlmInteraction() =
LlmInteraction(
llm = llm ?: LlmOptions(),
Expand Down
Loading