diff --git a/src/agent/router.ts b/src/agent/router.ts index c162e0d..f6c7975 100644 --- a/src/agent/router.ts +++ b/src/agent/router.ts @@ -2,11 +2,11 @@ * Router - Compares APYs and triggers rebalancing when conditions are met */ -import { Decimal } from '@prisma/client/runtime/library'; import { PrismaClient } from '@prisma/client'; import { logger } from '../utils/logger'; import { ProtocolComparison, RebalanceDetails, RebalanceThresholds } from './types'; import { scanAllProtocols, getCurrentOnChainApy } from './scanner'; +import { triggerRebalance as submitRebalance } from '../stellar/contract'; const prisma = new PrismaClient(); @@ -15,6 +15,14 @@ const DEFAULT_THRESHOLDS: RebalanceThresholds = { maxGasPercent: 0.1, }; +function toApyBasisPoints(apyPercent: number): number { + if (!Number.isFinite(apyPercent) || apyPercent < 0) { + throw new Error('APY must be a non-negative number'); + } + + return Math.round(apyPercent * 100); +} + /** * Estimate transaction costs for a rebalance * Accounts for gas fees and potential DEX slippage @@ -119,33 +127,76 @@ export async function compareProtocols( export async function triggerRebalance( fromProtocol: string, toProtocol: string, - amount: string + amount: string, + positionIds: string[] = [], ): Promise { const startTime = Date.now(); try { + const comparison = await compareProtocols(fromProtocol, amount); + if (!comparison) { + throw new Error(`Unable to compare protocols for ${fromProtocol}`); + } + + const expectedApyBasisPoints = toApyBasisPoints(comparison.best.apy); + logger.info('Rebalance triggered', { fromProtocol, toProtocol, amount, + expectedApyBasisPoints, }); - // TODO: Call actual smart contract to execute rebalance - // This would interact with the Stellar Soroban vault contract - // const txHash = await executeRebalanceOnChain(fromProtocol, toProtocol, amount); + const onChainTransaction = await submitRebalance( + toProtocol, + expectedApyBasisPoints, + ); - const mockTxHash = `mock_tx_${Date.now()}`; + if (positionIds.length > 0) { + const representativePosition = await prisma.position.findFirst({ + where: { + id: { in: positionIds }, + }, + include: { + user: { + select: { + network: true, + }, + }, + }, + }); - const comparison = await compareProtocols(fromProtocol); - const improvement = comparison ? comparison.improvement : 0; + if (representativePosition) { + await prisma.transaction.create({ + data: { + userId: representativePosition.userId, + positionId: representativePosition.id, + txHash: onChainTransaction.hash, + type: 'REBALANCE', + status: 'PENDING', + assetSymbol: representativePosition.assetSymbol, + amount, + network: representativePosition.user.network, + protocolName: toProtocol, + memo: `Agent rebalance from ${fromProtocol} to ${toProtocol}`, + } as any, + }); + } else { + logger.warn('No position found to persist rebalance transaction', { + fromProtocol, + toProtocol, + positionIds, + }); + } + } const rebalanceDetail: RebalanceDetails = { fromProtocol, toProtocol, amount, - txHash: mockTxHash, + txHash: onChainTransaction.hash, timestamp: new Date(), - improvedBy: improvement, + improvedBy: comparison.improvement, }; const duration = Date.now() - startTime; @@ -156,9 +207,9 @@ export async function triggerRebalance( }); logger.info('Rebalance successful', { - txHash: mockTxHash, + txHash: onChainTransaction.hash, duration, - improvedBy: improvement.toFixed(2), + improvedBy: comparison.improvement.toFixed(2), }); return rebalanceDetail; @@ -217,7 +268,8 @@ export async function executeRebalanceIfNeeded( return await triggerRebalance( currentProtocol, comparison.best.name, - totalAmount + totalAmount, + userPositions.map(pos => pos.id), ); } catch (error) { logger.error('Rebalance execution check failed', { diff --git a/src/index.ts b/src/index.ts index fd94e87..ba0c573 100644 --- a/src/index.ts +++ b/src/index.ts @@ -9,6 +9,7 @@ import { logger } from './utils/logger' import { startAgentLoop } from './agent/loop' import { connectDb } from './db' import { scheduleSessionCleanup } from './jobs/sessionCleanup' +import { startEventListener } from './stellar/events' import healthRouter from './routes/health' import agentRouter from './routes/agent' import authRouter from './routes/auth' @@ -57,6 +58,8 @@ async function main() { logger.info(`Network: ${config.stellar.network}`) try { + await startEventListener() + logger.info('Vault event listener started') await startAgentLoop() } catch (error) { logger.error('Failed to start agent loop', { diff --git a/tests/unit/agent/router.test.ts b/tests/unit/agent/router.test.ts index 9ab8e30..c45181f 100644 --- a/tests/unit/agent/router.test.ts +++ b/tests/unit/agent/router.test.ts @@ -19,22 +19,38 @@ jest.mock('../../../src/agent/scanner', () => ({ getBestProtocol: jest.fn(), })); +const mockFindFirst = jest.fn(); +const mockTransactionCreate = jest.fn(); +const mockAgentLogCreate = jest.fn().mockResolvedValue({}); +const mockContractTriggerRebalance = jest.fn(); + // Mock Prisma used by logAgentAction jest.mock('@prisma/client', () => ({ PrismaClient: jest.fn().mockImplementation(() => ({ user: { findMany: jest.fn().mockResolvedValue([{ id: 'test-user-id' }]), }, + position: { + findFirst: (...args: unknown[]) => mockFindFirst(...args), + }, + transaction: { + create: (...args: unknown[]) => mockTransactionCreate(...args), + }, agentLog: { - create: jest.fn().mockResolvedValue({}), + create: (...args: unknown[]) => mockAgentLogCreate(...args), }, })), })); +jest.mock('../../../src/stellar/contract', () => ({ + triggerRebalance: (...args: unknown[]) => mockContractTriggerRebalance(...args), +})); + import { compareProtocols, executeRebalanceIfNeeded, getThresholds, + triggerRebalance, } from '../../../src/agent/router'; import { scanAllProtocols, @@ -63,6 +79,19 @@ const marginalProtocol = { describe('Agent Router', () => { beforeEach(() => { jest.clearAllMocks(); + mockFindFirst.mockResolvedValue({ + id: 'pos-1', + userId: 'test-user-id', + assetSymbol: 'USDC', + user: { network: 'TESTNET' }, + }); + mockTransactionCreate.mockResolvedValue({}); + mockAgentLogCreate.mockResolvedValue({}); + mockContractTriggerRebalance.mockResolvedValue({ + hash: 'real-rebalance-hash-001', + status: 'success', + ledger: 77, + }); }); // ── compareProtocols ────────────────────────────────────────────────────── @@ -187,7 +216,18 @@ describe('Agent Router', () => { expect(result).not.toBeNull(); expect(result!.fromProtocol).toBe('Stellar DEX'); expect(result!.toProtocol).toBe('Blend'); - expect(result!.txHash).toBeDefined(); + expect(result!.txHash).toBe('real-rebalance-hash-001'); + expect(mockContractTriggerRebalance).toHaveBeenCalledWith('Blend', 800); + expect(mockTransactionCreate).toHaveBeenCalledWith( + expect.objectContaining({ + data: expect.objectContaining({ + type: 'REBALANCE', + status: 'PENDING', + txHash: 'real-rebalance-hash-001', + protocolName: 'Blend', + }), + }), + ); }); it('sums amounts across multiple positions before cost calculation', async () => { @@ -209,4 +249,37 @@ describe('Agent Router', () => { expect(result).toBeNull(); }); }); + + describe('triggerRebalance()', () => { + it('converts APY percent to basis points before invoking the vault contract', async () => { + mockApy.mockResolvedValue(2.0); + mockScan.mockResolvedValue([blendProtocol]); + + const result = await triggerRebalance( + 'Stellar DEX', + 'Blend', + '100000000000000000000000', + ['pos-1'], + ); + + expect(result?.txHash).toBe('real-rebalance-hash-001'); + expect(mockContractTriggerRebalance).toHaveBeenCalledWith('Blend', 800); + }); + + it('skips transaction persistence when no representative position is found', async () => { + mockApy.mockResolvedValue(2.0); + mockScan.mockResolvedValue([blendProtocol]); + mockFindFirst.mockResolvedValueOnce(null); + + const result = await triggerRebalance( + 'Stellar DEX', + 'Blend', + '100000000000000000000000', + ['missing-pos'], + ); + + expect(result?.txHash).toBe('real-rebalance-hash-001'); + expect(mockTransactionCreate).not.toHaveBeenCalled(); + }); + }); });