Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions src/_utils/__tests__/pagination.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { describe, it, expect } from 'vitest'
import { paginateAll } from '../../_utils/pagination.js'

describe('paginateAll', () => {
it('should return all items from a single page', async () => {
const pages = [{ items: [1, 2, 3], nextToken: undefined as string | undefined }]
let callCount = 0

const result = await paginateAll(
() => Promise.resolve(pages[callCount++]!),
(page) => page.items,
(page) => page.nextToken,
)

expect(result).toEqual([1, 2, 3])
expect(callCount).toBe(1)
})

it('should accumulate items across multiple pages', async () => {
const pages = [
{ items: [1, 2], nextToken: 'tok1' as string | undefined },
{ items: [3, 4], nextToken: 'tok2' as string | undefined },
{ items: [5], nextToken: undefined as string | undefined },
]
let callCount = 0
const tokensReceived: (string | undefined)[] = []

const result = await paginateAll(
(nextToken) => {
tokensReceived.push(nextToken)
return Promise.resolve(pages[callCount++]!)
},
(page) => page.items,
(page) => page.nextToken,
)

expect(result).toEqual([1, 2, 3, 4, 5])
expect(tokensReceived).toEqual([undefined, 'tok1', 'tok2'])
})

it('should return empty array when page has no items', async () => {
const result = await paginateAll(
() => Promise.resolve({ items: undefined as number[] | undefined, nextToken: undefined }),
(page) => page.items,
(page) => page.nextToken,
)

expect(result).toEqual([])
})
})
71 changes: 71 additions & 0 deletions src/_utils/__tests__/polling.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import { describe, it, expect } from 'vitest'
import { pollUntil } from '../../_utils/polling.js'

describe('pollUntil', () => {
it('should return true immediately when condition is met', async () => {
const result = await pollUntil(
() => Promise.resolve(true),
{ maxWaitSeconds: 5, pollIntervalMs: 10 },
)
expect(result).toBe(true)
})

it('should poll until condition becomes true', async () => {
let calls = 0
const result = await pollUntil(
() => Promise.resolve(++calls >= 3),
{ maxWaitSeconds: 5, pollIntervalMs: 10 },
)
expect(result).toBe(true)
expect(calls).toBe(3)
})

it('should return false on timeout when no timeoutErrorMessage', async () => {
const result = await pollUntil(
() => Promise.resolve(false),
{ maxWaitSeconds: 0.05, pollIntervalMs: 10 },
)
expect(result).toBe(false)
})

it('should throw on timeout when timeoutErrorMessage is set', async () => {
await expect(
pollUntil(
() => Promise.resolve(false),
{ maxWaitSeconds: 0.05, pollIntervalMs: 10, timeoutErrorMessage: 'timed out' },
),
).rejects.toThrow('timed out')
})

it('should swallow errors matching shouldSwallowError predicate', async () => {
let calls = 0
const result = await pollUntil(
() => {
calls++
if (calls < 2) throw new Error('transient')
return Promise.resolve(true)
},
{ maxWaitSeconds: 5, pollIntervalMs: 10, shouldSwallowError: () => true },
)
expect(result).toBe(true)
expect(calls).toBe(2)
})

it('should propagate errors not matched by shouldSwallowError', async () => {
await expect(
pollUntil(
() => { throw new Error('fatal') },
{ maxWaitSeconds: 5, pollIntervalMs: 10, shouldSwallowError: (err) => (err as Error).message !== 'fatal' },
),
).rejects.toThrow('fatal')
})

it('should propagate all errors when shouldSwallowError is not provided', async () => {
await expect(
pollUntil(
() => { throw new Error('fatal') },
{ maxWaitSeconds: 5, pollIntervalMs: 10 },
),
).rejects.toThrow('fatal')
})
})
34 changes: 34 additions & 0 deletions src/_utils/pagination.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
export const DEFAULT_PAGE_SIZE = 100

/**
* Paginate any SDK method that follows the nextToken request/response pattern.
*
* @param fetchPage - Fetches a single page, given an optional nextToken
* @param getItems - Extracts the item array from each page response
* @param getNextToken - Extracts the nextToken from each page response
* @returns All items accumulated across pages
*
* @example
* ```typescript
* const allEvents = await paginateAll(
* (nextToken) => client.listEvents({ memoryId, actorId, sessionId, nextToken }),
* (page) => page.events,
* (page) => page.nextToken,
* );
* ```
*/
export async function paginateAll<TOutput, TItem>(
fetchPage: (nextToken?: string) => Promise<TOutput>,
getItems: (output: TOutput) => TItem[] | undefined,
getNextToken: (output: TOutput) => string | undefined,
): Promise<TItem[]> {
const items: TItem[] = []
let nextToken: string | undefined
do {
const page = await fetchPage(nextToken)
const pageItems = getItems(page)
if (pageItems) items.push(...pageItems)
nextToken = getNextToken(page)
} while (nextToken)
return items
}
29 changes: 29 additions & 0 deletions src/_utils/polling.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* Poll a condition function until it returns `true` or the timeout expires.
*
* @param condition - Async function that returns `true` when done. Throwing is treated as "not done yet" when `swallowErrors` is true.
* @param opts - Polling options
* @returns `true` if the condition was met, `false` if timed out (only when `timeoutErrorMessage` is not set)
* @throws Error with `timeoutErrorMessage` if provided and timeout expires
*/
export async function pollUntil(
condition: () => Promise<boolean>,
opts: {
maxWaitSeconds: number
pollIntervalMs: number
timeoutErrorMessage?: string
shouldSwallowError?: (err: unknown) => boolean
},
): Promise<boolean> {
const deadline = Date.now() + opts.maxWaitSeconds * 1000
while (Date.now() < deadline) {
try {
if (await condition()) return true
} catch (err) {
if (!opts.shouldSwallowError?.(err)) throw err
}
await new Promise((resolve) => globalThis.setTimeout(resolve, opts.pollIntervalMs))
}
if (opts.timeoutErrorMessage) throw new Error(opts.timeoutErrorMessage)
return false
}
172 changes: 172 additions & 0 deletions src/memory/__tests__/client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import { describe, it, expect } from 'vitest'
import type { BedrockAgentCore } from '@aws-sdk/client-bedrock-agentcore'
import type { BedrockAgentCoreControl } from '@aws-sdk/client-bedrock-agentcore-control'
import { MemoryClient } from '../client.js'
import { DATA_PLANE_METHODS, CONTROL_PLANE_METHODS } from '../types.js'

function fakeControlPlane(overrides: Record<string, (input: unknown) => unknown>): Record<string, unknown> {
return new Proxy({} as Record<string, unknown>, {
get: (_, method) => overrides[method as string] ?? (() => Promise.resolve({})),
})
}

function fakeDataPlane(overrides: Record<string, (input: unknown) => unknown>): Record<string, unknown> {
return new Proxy({} as Record<string, unknown>, {
get: (_, method) => overrides[method as string] ?? (() => Promise.resolve({})),
})
}

describe('MemoryClient', () => {
describe('passthrough', () => {
const client = new MemoryClient({ region: 'us-west-2' })

it('exposes every data plane method as a function', () => {
for (const method of DATA_PLANE_METHODS) {
expect(typeof client[method]).toBe('function')
}
})

it('exposes every control plane method as a function', () => {
for (const method of CONTROL_PLANE_METHODS) {
expect(typeof client[method]).toBe('function')
}
})

it('does not expose arbitrary properties', () => {
expect((client as unknown as Record<string, unknown>)['nonExistentMethod']).toBeUndefined()
})

it('constructs without config using defaults', () => {
expect(new MemoryClient()).toBeDefined()
})
})

describe('memory() scoping', () => {
const client = new MemoryClient({ region: 'us-west-2' })
const mem = client.memory('mem-1')

it('exposes scoped methods as functions', () => {
expect(typeof mem.createEvent).toBe('function')
expect(typeof mem.retrieveMemoryRecords).toBe('function')
expect(typeof mem.listEvents).toBe('function')
expect(typeof mem.listActors).toBe('function')
})
})

describe('createOrGetMemory()', () => {
it('returns existing memory when create throws already-exists', async () => {
const client = new MemoryClient({
controlPlaneClient: fakeControlPlane({
createMemory: () => {
throw Object.assign(new Error('already exists'), { name: 'ValidationException' })
},
getMemory: () => Promise.resolve({ memory: { id: 'mem-1' }, $metadata: {} }),
}) as unknown as BedrockAgentCoreControl,
})

const result = await client.createOrGetMemory({ name: 'test', eventExpiryDuration: 30 })
expect(result.memory).toMatchObject({ id: 'mem-1' })
})

it('propagates non-conflict errors', async () => {
const client = new MemoryClient({
controlPlaneClient: fakeControlPlane({
createMemory: () => {
throw Object.assign(new Error('throttled'), { name: 'ThrottlingException' })
},
}) as unknown as BedrockAgentCoreControl,
})

await expect(client.createOrGetMemory({ name: 'test', eventExpiryDuration: 30 })).rejects.toThrow('throttled')
})
})

describe('deleteMemoryAndWait()', () => {
it('resolves when resource is not found after delete', async () => {
const client = new MemoryClient({
controlPlaneClient: fakeControlPlane({
deleteMemory: () => Promise.resolve({}),
getMemory: () => {
throw Object.assign(new Error(), { name: 'ResourceNotFoundException' })
},
}) as unknown as BedrockAgentCoreControl,
})

await expect(
client.deleteMemoryAndWait('mem-1', { maxWaitSeconds: 1, pollIntervalMs: 10 })
).resolves.toBeUndefined()
})
})

describe('getLastKTurns()', () => {
it('groups messages into turns at USER boundaries', async () => {
const client = new MemoryClient({
dataPlaneClient: fakeDataPlane({
listEvents: () =>
Promise.resolve({
events: [
{ eventId: 'e1', payload: [{ conversational: { role: 'USER', content: { text: 'hi' } } }] },
{ eventId: 'e2', payload: [{ conversational: { role: 'ASSISTANT', content: { text: 'hello' } } }] },
{ eventId: 'e3', payload: [{ conversational: { role: 'USER', content: { text: 'bye' } } }] },
{ eventId: 'e4', payload: [{ conversational: { role: 'ASSISTANT', content: { text: 'goodbye' } } }] },
],
}),
}) as unknown as BedrockAgentCore,
})

const turns = await client.memory('mem-1').getLastKTurns({ actorId: 'a1', sessionId: 's1', k: 1 })
expect(turns).toMatchObject([[{ role: 'USER' }, { role: 'ASSISTANT' }]])
})

it('returns all turns when k exceeds total', async () => {
const client = new MemoryClient({
dataPlaneClient: fakeDataPlane({
listEvents: () =>
Promise.resolve({
events: [
{ eventId: 'e1', payload: [{ conversational: { role: 'USER', content: { text: 'hi' } } }] },
{ eventId: 'e2', payload: [{ conversational: { role: 'ASSISTANT', content: { text: 'hello' } } }] },
],
}),
}) as unknown as BedrockAgentCore,
})

const turns = await client.memory('mem-1').getLastKTurns({ actorId: 'a1', sessionId: 's1', k: 5 })
expect(turns).toMatchObject([[{ role: 'USER' }, { role: 'ASSISTANT' }]])
})

it('returns empty array for empty session', async () => {
const client = new MemoryClient({
dataPlaneClient: fakeDataPlane({
listEvents: () => Promise.resolve({ events: [] }),
}) as unknown as BedrockAgentCore,
})

const turns = await client.memory('mem-1').getLastKTurns({ actorId: 'a1', sessionId: 's1', k: 5 })
expect(turns).toEqual([])
})
})

describe('listBranches()', () => {
it('aggregates branch info from events', async () => {
const client = new MemoryClient({
dataPlaneClient: fakeDataPlane({
listEvents: () =>
Promise.resolve({
events: [
{ eventId: 'e1' },
{ eventId: 'e2' },
{ eventId: 'e3', branch: { name: 'alt', rootEventId: 'e1' } },
],
}),
}) as unknown as BedrockAgentCore,
})

const branches = await client.memory('mem-1').listBranches({ actorId: 'a1', sessionId: 's1' })
expect(branches).toMatchObject([
{ name: 'main', eventCount: 2 },
{ name: 'alt', eventCount: 1, rootEventId: 'e1' },
])
})
})
})
Loading
Loading