Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"version": 1,
"currency": "USD",
"publishedAt": "2026-03-02",
"entries": [
{
"providerId": "openai",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ data class AmpereConfig(
val workspace: String? = null,
val databasePath: String? = null,
val onEscalation: ((Escalated) -> Unit)? = null,
val pricingOverrides: PricingOverrides = PricingOverrides(),
) {
class Builder {
private var providerConfig: ProviderConfig? = null
private var workspace: String? = null
private var databasePath: String? = null
private var escalationHandler: ((Escalated) -> Unit)? = null
private val pricingOverridesBuilder = PricingOverridesBuilder()

/**
* Set the AI provider configuration.
Expand Down Expand Up @@ -75,6 +77,31 @@ data class AmpereConfig(
escalationHandler = handler
}

/**
* Override bundled pricing data or add private model pricing.
*
* ```
* pricing {
* model("openai", "gpt-4.1") {
* tier(
* inputUsdPerMillionTokens = 1.5,
* outputUsdPerMillionTokens = 6.0,
* )
* }
*
* model("self-hosted", "mixtral-enterprise") {
* tier(
* inputUsdPerMillionTokens = 0.0,
* outputUsdPerMillionTokens = 0.0,
* )
* }
* }
* ```
*/
fun pricing(configure: PricingOverridesBuilder.() -> Unit) {
pricingOverridesBuilder.apply(configure)
}

fun build(): AmpereConfig {
val provider = requireNotNull(providerConfig) {
"Provider is required. Use provider(AnthropicConfig()) or similar."
Expand All @@ -84,6 +111,7 @@ data class AmpereConfig(
workspace = workspace,
databasePath = databasePath,
onEscalation = escalationHandler,
pricingOverrides = pricingOverridesBuilder.build(),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import link.socket.ampere.api.service.AgentService
import link.socket.ampere.api.service.EventService
import link.socket.ampere.api.service.KnowledgeService
import link.socket.ampere.api.service.OutcomeService
import link.socket.ampere.api.service.PricingService
import link.socket.ampere.api.service.StatusService
import link.socket.ampere.api.service.ThreadService
import link.socket.ampere.api.service.TicketService
Expand Down Expand Up @@ -42,6 +43,9 @@ interface AmpereInstance : AutoCloseable {
/** Execution history and outcome tracking */
val outcomes: OutcomeService

/** Bundled model pricing, overrides, and cost estimation */
val pricing: PricingService

/** Persistent knowledge and memory */
val knowledge: KnowledgeService

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package link.socket.ampere.api

import kotlinx.serialization.Serializable
import link.socket.ampere.api.model.ModelPricing
import link.socket.ampere.api.model.PricingTier

/**
* Consumer-provided pricing entries that override bundled model rates.
*/
@AmpereStableApi
@Serializable
data class PricingOverrides(
val models: List<ModelPricing> = emptyList(),
)

/**
* Builder for [PricingOverrides].
*/
@AmpereStableApi
class PricingOverridesBuilder {
private val modelsByKey = linkedMapOf<PricingModelKey, ModelPricing>()

/**
* Add or replace pricing for a provider/model pair.
*/
fun model(pricing: ModelPricing) {
validateModelPricing(pricing)
modelsByKey[pricingModelKey(pricing.providerId, pricing.modelId)] = pricing
}

/**
* Add or replace pricing for a provider/model pair using the DSL.
*/
fun model(
providerId: String,
modelId: String,
configure: ModelPricingBuilder.() -> Unit,
) {
model(ModelPricingBuilder(providerId = providerId, modelId = modelId).apply(configure).build())
}

internal fun build(): PricingOverrides = PricingOverrides(models = modelsByKey.values.toList())
}

/**
* Builder for a single [ModelPricing] entry.
*/
@AmpereStableApi
class ModelPricingBuilder internal constructor(
private val providerId: String,
private val modelId: String,
) {
private val tiers = mutableListOf<PricingTier>()

fun tier(
maxInputTokens: Int? = null,
inputUsdPerMillionTokens: Double,
outputUsdPerMillionTokens: Double,
) {
val tier = PricingTier(
maxInputTokens = maxInputTokens,
inputUsdPerMillionTokens = inputUsdPerMillionTokens,
outputUsdPerMillionTokens = outputUsdPerMillionTokens,
)
validatePricingTier(tier)
tiers += tier
}

internal fun build(): ModelPricing {
val pricing = ModelPricing(
providerId = providerId,
modelId = modelId,
tiers = tiers.toList(),
)
validateModelPricing(pricing)
return pricing
}
}

internal data class PricingModelKey(
val providerId: String,
val modelId: String,
)

internal fun pricingModelKey(providerId: String, modelId: String): PricingModelKey = PricingModelKey(
providerId = providerId.trim().lowercase(),
modelId = modelId.trim().lowercase(),
)

internal fun validateModelPricing(pricing: ModelPricing) {
require(pricing.providerId.isNotBlank()) { "Pricing providerId cannot be blank." }
require(pricing.modelId.isNotBlank()) { "Pricing modelId cannot be blank." }
require(pricing.tiers.isNotEmpty()) {
"Pricing entry ${pricing.providerId}/${pricing.modelId} must include at least one tier."
}
pricing.tiers.forEach(::validatePricingTier)
}

internal fun validatePricingTier(tier: PricingTier) {
require(tier.maxInputTokens == null || tier.maxInputTokens > 0) {
"Pricing tier maxInputTokens must be positive when provided."
}
require(tier.inputUsdPerMillionTokens >= 0.0) {
"Pricing tier inputUsdPerMillionTokens cannot be negative."
}
require(tier.outputUsdPerMillionTokens >= 0.0) {
"Pricing tier outputUsdPerMillionTokens cannot be negative."
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package link.socket.ampere.api.internal

import link.socket.ampere.api.PricingModelKey
import link.socket.ampere.api.PricingOverrides
import link.socket.ampere.api.model.ModelPricing
import link.socket.ampere.api.model.PricingDataVersion
import link.socket.ampere.api.model.PricingEstimateRequest
import link.socket.ampere.api.model.PricingEstimateResult
import link.socket.ampere.api.model.PricingTier
import link.socket.ampere.api.pricingModelKey
import link.socket.ampere.api.service.PricingService
import link.socket.ampere.api.validateModelPricing
import link.socket.ampere.domain.ai.pricing.BundledProviderPricingCatalog
import link.socket.ampere.domain.ai.pricing.ProviderModelPricing
import link.socket.ampere.domain.ai.pricing.ProviderPricingCalculator
import link.socket.ampere.domain.ai.pricing.ProviderPricingCatalog
import link.socket.ampere.domain.ai.pricing.TokenPricingTier

internal class DefaultPricingService(
private val overrides: PricingOverrides = PricingOverrides(),
private val bundledCatalogLoader: suspend () -> ProviderPricingCatalog = { BundledProviderPricingCatalog.load() },
) : PricingService {
private var cachedCatalog: EffectivePricingCatalog? = null

override suspend fun get(providerId: String, modelId: String): Result<ModelPricing?> = runCatching {
effectiveCatalog().entriesByKey[pricingModelKey(providerId, modelId)]
}

override suspend fun list(): Result<List<ModelPricing>> = runCatching {
effectiveCatalog().entriesByKey.values.toList()
}

override suspend fun version(): Result<PricingDataVersion> = runCatching {
effectiveCatalog().version
}

override suspend fun estimate(request: PricingEstimateRequest): Result<PricingEstimateResult?> = runCatching {
val catalog = effectiveCatalog()
val pricing = catalog.entriesByKey[
pricingModelKey(request.providerId, request.modelId),
] ?: return@runCatching null
val inputTokens = request.usage.inputTokens ?: return@runCatching null
val outputTokens = request.usage.outputTokens ?: return@runCatching null
if (inputTokens < 0 || outputTokens < 0) return@runCatching null

val appliedTier = pricing.tiers.firstOrNull { tier ->
tier.maxInputTokens == null || inputTokens <= tier.maxInputTokens
} ?: return@runCatching null

val estimatedCost = ProviderPricingCalculator.estimateUsd(
pricing = pricing.toDomainPricing(),
inputTokens = inputTokens,
outputTokens = outputTokens,
) ?: return@runCatching null

PricingEstimateResult(
providerId = pricing.providerId,
modelId = pricing.modelId,
usage = request.usage.copy(estimatedCost = estimatedCost),
pricing = pricing,
appliedTier = appliedTier,
version = catalog.version,
)
}

private suspend fun effectiveCatalog(): EffectivePricingCatalog {
cachedCatalog?.let { return it }

overrides.models.forEach(::validateModelPricing)
val bundledCatalog = bundledCatalogLoader()
return bundledCatalog.toEffectiveCatalog(overrides).also { cachedCatalog = it }
}
}

private data class EffectivePricingCatalog(
val version: PricingDataVersion,
val entriesByKey: LinkedHashMap<PricingModelKey, ModelPricing>,
)

private fun ProviderPricingCatalog.toEffectiveCatalog(overrides: PricingOverrides): EffectivePricingCatalog {
val entriesByKey = linkedMapOf<PricingModelKey, ModelPricing>()

entries.forEach { pricing ->
val apiPricing = pricing.toApiPricing()
entriesByKey[pricingModelKey(apiPricing.providerId, apiPricing.modelId)] = apiPricing
}
overrides.models.forEach { pricing ->
entriesByKey[pricingModelKey(pricing.providerId, pricing.modelId)] = pricing
}

return EffectivePricingCatalog(
version = PricingDataVersion(
version = version,
currency = currency,
publishedAt = publishedAt,
overridesApplied = overrides.models.size,
),
entriesByKey = LinkedHashMap(entriesByKey),
)
}

private fun ProviderModelPricing.toApiPricing(): ModelPricing = ModelPricing(
providerId = providerId,
modelId = modelId,
tiers = tiers.map(TokenPricingTier::toApiTier),
)

private fun TokenPricingTier.toApiTier(): PricingTier = PricingTier(
maxInputTokens = maxInputTokens,
inputUsdPerMillionTokens = inputUsdPerMillionTokens,
outputUsdPerMillionTokens = outputUsdPerMillionTokens,
)

private fun ModelPricing.toDomainPricing(): ProviderModelPricing = ProviderModelPricing(
providerId = providerId,
modelId = modelId,
tiers = tiers.map(PricingTier::toDomainTier),
)

private fun PricingTier.toDomainTier(): TokenPricingTier = TokenPricingTier(
maxInputTokens = maxInputTokens,
inputUsdPerMillionTokens = inputUsdPerMillionTokens,
outputUsdPerMillionTokens = outputUsdPerMillionTokens,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package link.socket.ampere.api.model

import kotlinx.serialization.Serializable

/**
* Effective token pricing for a provider/model pair.
*/
@link.socket.ampere.api.AmpereStableApi
@Serializable
data class ModelPricing(
val providerId: String,
val modelId: String,
val tiers: List<PricingTier>,
)

/**
* Token price tier expressed in USD per million tokens.
*/
@link.socket.ampere.api.AmpereStableApi
@Serializable
data class PricingTier(
val maxInputTokens: Int? = null,
val inputUsdPerMillionTokens: Double,
val outputUsdPerMillionTokens: Double,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package link.socket.ampere.api.model

import kotlinx.serialization.Serializable

/**
* Version metadata for bundled pricing data plus any consumer overrides.
*/
@link.socket.ampere.api.AmpereStableApi
@Serializable
data class PricingDataVersion(
val version: Int,
val currency: String,
val publishedAt: String? = null,
val overridesApplied: Int = 0,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package link.socket.ampere.api.model

import kotlinx.serialization.Serializable

/**
* Inputs for pricing estimation.
*/
@link.socket.ampere.api.AmpereStableApi
@Serializable
data class PricingEstimateRequest(
val providerId: String,
val modelId: String,
val usage: TokenUsage,
)

/**
* Estimated cost plus the pricing data used to compute it.
*/
@link.socket.ampere.api.AmpereStableApi
@Serializable
data class PricingEstimateResult(
val providerId: String,
val modelId: String,
val usage: TokenUsage,
val pricing: ModelPricing,
val appliedTier: PricingTier,
val version: PricingDataVersion,
)
Loading