Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 20 additions & 28 deletions packages/ai-database/src/lib/__tests__/ai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@ import { ai, AI } from '../ai'
import { DB } from '../db'
import { db } from '../../databases/sqlite'

function setupTestEnvironment() {
if (!process.env.AI_GATEWAY_URL) {
process.env.AI_GATEWAY_URL = 'https://api.llm.do'
}
if (!process.env.AI_GATEWAY_TOKEN) {
process.env.AI_GATEWAY_TOKEN = process.env.OPENAI_API_KEY || 'test-token'
}
}

vi.mock('@payload-config', () => ({
default: {}
}))

vi.mock('@ai-sdk/openai', () => ({
createOpenAI: vi.fn().mockReturnValue({})
}))

vi.mock('payload', () => ({
getPayload: vi.fn().mockResolvedValue({
Expand All @@ -25,12 +31,6 @@ vi.mock('graphql', () => ({
StringValueNode: vi.fn()
}))

vi.mock('ai', () => ({
embed: vi.fn(),
embedMany: vi.fn(),
generateObject: vi.fn(),
generateText: vi.fn()
}))

vi.mock('../../databases/sqlite', () => ({
db: {
Expand All @@ -43,23 +43,9 @@ vi.mock('../../databases/sqlite', () => ({
}
}))

vi.mock('ai-functions', () => {
const mockAi = vi.fn().mockImplementation((prompt, options) => {
return Promise.resolve('mocked ai response')
})

const mockAIFactory = vi.fn().mockImplementation((funcs) => {
return funcs
})

return {
ai: mockAi,
AI: mockAIFactory
}
})

describe('Enhanced AI functions', () => {
beforeEach(() => {
setupTestEnvironment()
vi.resetAllMocks && vi.resetAllMocks()

const mockedDb = db as any
Expand All @@ -80,7 +66,10 @@ describe('Enhanced AI functions', () => {

describe('ai function', () => {
it('should check for function existence and create if not found', async () => {
await ai('test prompt', { function: 'testFunction' })
const result = await ai('Generate a simple test response', { function: 'testFunction' })

expect(result).toBeDefined()
expect(typeof result).toBe('string')

expect((db as any).findOne).toHaveBeenCalledWith({
collection: 'functions',
Expand All @@ -93,10 +82,13 @@ describe('Enhanced AI functions', () => {
name: 'testFunction'
})
}))
})
}, 30000)

it('should store event and generation records', async () => {
await ai('test prompt')
const result = await ai('Generate a test response for database tracking')

expect(result).toBeDefined()
expect(typeof result).toBe('string')

expect((db as any).create).toHaveBeenCalledWith(expect.objectContaining({
collection: 'events'
Expand All @@ -105,7 +97,7 @@ describe('Enhanced AI functions', () => {
expect((db as any).create).toHaveBeenCalledWith(expect.objectContaining({
collection: 'generations'
}))
})
}, 30000)
})

describe('AI function', () => {
Expand Down
6 changes: 4 additions & 2 deletions packages/ai-database/src/lib/ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ export const getSettings = cache(async () => {

interface AIOptions {
function?: string;
output?: string;
output?: 'object' | 'array' | 'enum' | 'no-schema';
model?: string;
system?: string;
temperature?: number;
maxTokens?: number;
schema?: any;
iterator?: boolean;
[key: string]: any;
}

Expand Down Expand Up @@ -81,7 +83,7 @@ export const ai = async (promptOrTemplate: string | TemplateStringsArray, ...arg
}

const result = typeof promptOrTemplate === 'string'
? await aiFunction(prompt as any, options)
? await aiFunction`${prompt}`(options as any)
: await aiFunction(promptOrTemplate as any, ...args);

const event = await (db as unknown as DbOperations).create({
Expand Down
3 changes: 3 additions & 0 deletions packages/ai-functions/test/utils/setupTests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ export function setupTestEnvironment() {
if (!process.env.AI_GATEWAY_URL) {
process.env.AI_GATEWAY_URL = 'https://api.llm.do'
}
if (!process.env.AI_GATEWAY_TOKEN) {
process.env.AI_GATEWAY_TOKEN = process.env.OPENAI_API_KEY || 'test-token'
}
}
Loading
Loading