Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deploy/src/db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export class ConfigDB extends KeysDbD1 {
providers: providersWithKeys,
routingGroups,
otelSettings: user?.otel ?? project.otel,
cacheEnabled: keyInfo.cacheEnabled,
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion deploy/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

import { env } from 'cloudflare:workers'
import { type GatewayOptions, gatewayFetch, LimitDbD1 } from '@pydantic/ai-gateway'
import { type GatewayOptions, gatewayFetch, KVCacheStorage, LimitDbD1 } from '@pydantic/ai-gateway'
import { instrument } from '@pydantic/logfire-cf-workers'
import logfire from 'logfire'
import { config } from './config'
Expand All @@ -40,6 +40,7 @@ const handler = {
kv: env.KV,
kvVersion: await hash(JSON.stringify(config)),
subFetch: fetch,
cache: { storage: new KVCacheStorage(env.KV) },
}
try {
return await gatewayFetch(request, url, ctx, gatewayEnv)
Expand Down
1 change: 1 addition & 0 deletions deploy/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ export interface ApiKey<ProviderKey extends string> {
spendingLimitMonthly?: number
spendingLimitTotal?: number
providers: ProviderKey[] | '__all__'
cacheEnabled?: boolean
}
8 changes: 7 additions & 1 deletion gateway/src/gateway.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { type GatewayOptions, noopLimiter } from '.'
import { apiKeyAuth, setApiKeyCache } from './auth'
import { currentScopeIntervals, type ExceededScope, endOfMonth, endOfWeek, type SpendScope } from './db'
import { type HandlerResponse, RequestHandler } from './handler'
import { CacheMiddleware } from './middleware/cache'
import { OtelTrace } from './otel'
import { genAiOtelAttributes } from './otel/attributes'
import type { ApiKeyInfo, ProviderProxy } from './types'
Expand Down Expand Up @@ -174,6 +175,11 @@ export async function gatewayWithLimiter(

const otel = new OtelTrace(request, apiKeyInfo.otelSettings, options)

const middlewares = options.proxyMiddlewares ?? []
if (options.cache) {
middlewares.push(new CacheMiddleware({ storage: options.cache.storage }))
}

let result: HandlerResponse | null = null

for (const providerProxy of providerProxies) {
Expand All @@ -187,7 +193,7 @@ export async function gatewayWithLimiter(
apiKeyInfo,
restOfPath,
otelSpan,
middlewares: options.proxyMiddlewares,
middlewares,
})

try {
Expand Down
5 changes: 5 additions & 0 deletions gateway/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import logfire from 'logfire'
import type { KeysDb, LimitDb } from './db'
import { gateway } from './gateway'
import type { Middleware, Next } from './handler'
import type { CacheStorage as GatewayCacheStorage } from './middleware/storage'
import type { RateLimiter } from './rateLimiter'
import { refreshGenaiPrices } from './refreshGenaiPrices'
import type { SubFetch } from './types'
Expand All @@ -27,6 +28,8 @@ export { changeProjectState as setProjectState, deleteApiKeyCache, setApiKeyCach
export type { Middleware, Next }
export * from './db'
export type { RequestHandler } from './handler'
export { CacheMiddleware, type CacheOptions } from './middleware/cache'
export { type CachedResponse, type CacheStorage, KVCacheStorage } from './middleware/storage'
export * from './rateLimiter'
export * from './types'

Expand All @@ -42,6 +45,8 @@ export interface GatewayOptions {
proxyPrefixLength?: number
/** proxyMiddlewares: perform actions before and after the request is made to the providers */
proxyMiddlewares?: Middleware[]
/** Cache configuration */
cache?: { storage: GatewayCacheStorage }
}

export async function gatewayFetch(
Expand Down
158 changes: 158 additions & 0 deletions gateway/src/middleware/cache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import logfire from 'logfire'
import type { HandlerResponse, Middleware, Next, RequestHandler } from '../handler'
import type { CachedResponse, CacheStorage as GatewayCacheStorage } from './storage'

/** Configuration for {@link CacheMiddleware}. */
export interface CacheOptions {
  /** Backing store used to persist and look up cached responses. */
  storage: GatewayCacheStorage
}

/**
 * Proxy middleware that serves repeated identical requests from a cache.
 *
 * A cache entry is keyed by a SHA-256 hash of (API key id, method, rewritten
 * URL, request body), so entries are never shared across API keys. Only
 * buffered (non-streaming) success responses are stored. The request's
 * `Cache-Control` directives are honoured: `no-cache`/`no-store` bypass the
 * lookup, and `no-store` additionally suppresses storing the fresh response.
 * Every response passing through gains an `X-Cache-Status` header
 * (HIT / MISS / BYPASS).
 */
export class CacheMiddleware implements Middleware {
  private options: CacheOptions

  constructor(options: CacheOptions) {
    this.options = options
  }

  /**
   * Wrap `next` with cache lookup / store logic.
   *
   * Caching is skipped entirely (no lookup, no store, no cache headers) when
   * the API key does not have `cacheEnabled` set.
   */
  dispatch(next: Next): Next {
    return async (handler: RequestHandler) => {
      if (!handler.apiKeyInfo.cacheEnabled) {
        return await next(handler)
      }

      const { method, url, headers } = handler.request
      // Clone the request so reading the body does not consume the original stream
      const requestBody = await handler.request.clone().text()
      // Key on the provider-facing path (restOfPath), not the incoming gateway path
      const requestUrl = new URL(url)
      requestUrl.pathname = handler.restOfPath
      const path = requestUrl.toString()

      const apiKeyId = handler.apiKeyInfo.id
      const hash = await this.calculateHash(method, path, requestBody, apiKeyId)

      const shouldBypassCache = this.shouldBypassCache(headers)

      if (!shouldBypassCache) {
        const cached = await this.getCachedResponse(hash)

        if (cached) {
          logfire.info('Cache hit', { hash, apiKeyId: handler.apiKeyInfo.id })
          return this.toCachedHandlerResponse(requestBody, cached)
        }
      }

      const result = await next(handler)

      // `no-cache` only bypassed the lookup above; the fresh response may
      // still be stored here. Only `no-store` suppresses storage.
      const shouldStoreCache = this.shouldStoreCache(handler.request, result)
      if (shouldStoreCache) {
        // NOTE(review): storeCachedResponse is invoked eagerly, so its header
        // snapshot runs synchronously (before its first await) — i.e. before
        // addCacheHeaders below mutates the same Headers object. That keeps
        // X-Cache-Status out of the stored entry, but it relies on async
        // execution ordering; confirm this is intentional.
        handler.runAfter('cache-store', this.storeCachedResponse(hash, result))
      }

      return this.addCacheHeaders(result, shouldBypassCache ? 'BYPASS' : 'MISS')
    }
  }

  /** True when the request's Cache-Control directives forbid serving from cache. */
  private shouldBypassCache(requestHeaders: Headers): boolean {
    const cacheControl = requestHeaders.get('cache-control')
    // trailing `|| false` collapses the `undefined` from optional chaining to a boolean
    return cacheControl?.includes('no-cache') || cacheControl?.includes('no-store') || false
  }

  /**
   * Decide whether `result` may be written to the cache: the request must not
   * say `no-store`, and the result must be a buffered (non-streaming) success
   * — every error/passthrough variant of HandlerResponse is excluded.
   */
  private shouldStoreCache(request: Request, result: HandlerResponse): boolean {
    const cacheControl = request.headers.get('cache-control')

    if (cacheControl?.includes('no-store')) {
      return false
    }

    // Streaming bodies cannot be buffered into a cache entry
    if ('responseStream' in result) {
      return false
    }

    // Error, unexpected-status, raw-passthrough, and model-not-found variants are never cached
    if ('error' in result || 'unexpectedStatus' in result || 'response' in result || 'modelNotFound' in result) {
      return false
    }

    return true
  }

  /**
   * SHA-256 over `"apiKeyId:method:url:body"`, hex-encoded.
   * Including the API key id scopes cache entries to a single key.
   */
  private async calculateHash(method: string, url: string, body: string, apiKeyId: number): Promise<string> {
    const data = `${apiKeyId}:${method}:${url}:${body}`
    const encoder = new TextEncoder()
    const dataBuffer = encoder.encode(data)
    const hashBuffer = await crypto.subtle.digest('SHA-256', dataBuffer)
    const hashArray = Array.from(new Uint8Array(hashBuffer))
    const hashHex = hashArray.map((b) => b.toString(16).padStart(2, '0')).join('')

    return hashHex
  }

  /** Read from storage, treating any backend failure as a cache miss. */
  private async getCachedResponse(hash: string): Promise<CachedResponse | null> {
    try {
      return await this.options.storage.get(hash)
    } catch (error) {
      logfire.reportError('Error getting cached response', error as Error, { hash })
      return null
    }
  }

  /**
   * Serialize a buffered success response and write it to storage.
   * Failures are logged and swallowed — caching is best-effort and must never
   * fail the request. The leading guard re-checks the variant so this method
   * is safe to call with any HandlerResponse.
   */
  private async storeCachedResponse(hash: string, result: HandlerResponse): Promise<void> {
    if (!('successStatus' in result) || 'responseStream' in result) {
      return
    }

    try {
      const { successStatus, responseHeaders, responseBody, requestModel, responseModel } = result

      // Flatten Headers into a plain object so the entry is JSON-serializable
      const headers: Record<string, string> = {}
      responseHeaders.forEach((value, key) => {
        headers[key] = value
      })

      const cached: CachedResponse = {
        status: successStatus,
        headers,
        body: responseBody,
        timestamp: Date.now(),
        requestModel,
        responseModel,
      }

      await this.options.storage.set(hash, cached)

      const sizeBytes = new TextEncoder().encode(responseBody).length

      logfire.info('Response cached', { hash, sizeBytes })
    } catch (error) {
      logfire.reportError('Error storing cached response', error as Error, { hash })
    }
  }

  /**
   * Rebuild a success HandlerResponse from a cache entry, adding `Age`
   * (seconds since stored) and `X-Cache-Status: HIT` headers. Usage and cost
   * are reported as zero because no provider call was made.
   */
  private toCachedHandlerResponse(
    requestBody: string,
    cached: CachedResponse,
  ): Extract<HandlerResponse, { successStatus: number }> {
    const responseHeaders = new Headers(cached.headers)
    const age = Math.floor((Date.now() - cached.timestamp) / 1000)

    responseHeaders.set('Age', age.toString())
    responseHeaders.set('X-Cache-Status', 'HIT')

    return {
      successStatus: cached.status,
      responseHeaders,
      responseBody: cached.body,
      requestBody,
      requestModel: cached.requestModel,
      responseModel: cached.responseModel ?? 'unknown',
      usage: { input_tokens: 0, output_tokens: 0 },
      cost: 0,
    }
  }

  /** Tag the outgoing response with its cache disposition (mutates result's headers in place). */
  private addCacheHeaders(result: HandlerResponse, status: 'HIT' | 'MISS' | 'BYPASS'): HandlerResponse {
    if ('responseHeaders' in result) {
      result.responseHeaders.set('X-Cache-Status', status)
    }

    return result
  }
}
37 changes: 37 additions & 0 deletions gateway/src/middleware/storage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/** JSON-serializable snapshot of a successful upstream response. */
export interface CachedResponse {
  /** HTTP status code of the cached response. */
  status: number
  /** Response headers flattened to a plain object for JSON storage. */
  headers: Record<string, string>
  /** Response body text. */
  body: string
  /** Epoch milliseconds when the entry was stored; used to derive the `Age` header. */
  timestamp: number
  /** Model named in the original request, if known. */
  requestModel?: string
  /** Model reported in the provider's response, if known. */
  responseModel?: string
}

/** Minimal key/value contract for response-cache backends (e.g. KVCacheStorage). */
export interface CacheStorage {
  /** Resolve the cached response stored under `hash`, or null on a miss. */
  get(hash: string): Promise<CachedResponse | null>
  /** Persist `response` under `hash`; the backend decides TTL / eviction. */
  set(hash: string, response: CachedResponse): Promise<void>
}

/**
 * CacheStorage backed by a Cloudflare Workers KV namespace.
 *
 * Entries are stored as JSON under `<namespace>:<hash>` keys and expire via
 * KV's `expirationTtl` (default one day).
 */
export class KVCacheStorage implements CacheStorage {
  constructor(
    private readonly kv: KVNamespace,
    private readonly namespace: string = 'response',
    private readonly ttl: number = 86400,
  ) {}

  /** Look up and JSON-decode the entry for `hash`; resolves null when absent. */
  async get(hash: string): Promise<CachedResponse | null> {
    return await this.kv.get<CachedResponse>(cacheKey(this.namespace, hash), 'json')
  }

  /** JSON-encode `response` and store it under the namespaced key with the configured TTL. */
  async set(hash: string, response: CachedResponse): Promise<void> {
    await this.kv.put(cacheKey(this.namespace, hash), JSON.stringify(response), { expirationTtl: this.ttl })
  }
}

const cacheKey = (namespace: string, hash: string): string => `${namespace}:${hash}`
1 change: 1 addition & 0 deletions gateway/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export interface ApiKeyInfo<ProviderKey extends string = string> {
// among values with same priority, use weight for randomized load balancing; if missing, treat as 1
routingGroups: Record<string, { key: ProviderKey; priority?: number; weight?: number }[]>
otelSettings?: OtelSettings
cacheEnabled?: boolean
}

export type ProviderID =
Expand Down
Loading