Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 65 additions & 13 deletions src/agent/router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
* Router - Compares APYs and triggers rebalancing when conditions are met
*/

import { Decimal } from '@prisma/client/runtime/library';
import { PrismaClient } from '@prisma/client';
import { logger } from '../utils/logger';
import { ProtocolComparison, RebalanceDetails, RebalanceThresholds } from './types';
import { scanAllProtocols, getCurrentOnChainApy } from './scanner';
import { triggerRebalance as submitRebalance } from '../stellar/contract';

const prisma = new PrismaClient();

Expand All @@ -15,6 +15,14 @@ const DEFAULT_THRESHOLDS: RebalanceThresholds = {
maxGasPercent: 0.1,
};

function toApyBasisPoints(apyPercent: number): number {
if (!Number.isFinite(apyPercent) || apyPercent < 0) {
throw new Error('APY must be a non-negative number');
}

return Math.round(apyPercent * 100);
}

/**
* Estimate transaction costs for a rebalance
* Accounts for gas fees and potential DEX slippage
Expand Down Expand Up @@ -119,33 +127,76 @@ export async function compareProtocols(
export async function triggerRebalance(
fromProtocol: string,
toProtocol: string,
amount: string
amount: string,
positionIds: string[] = [],
): Promise<RebalanceDetails | null> {
const startTime = Date.now();

try {
const comparison = await compareProtocols(fromProtocol, amount);
if (!comparison) {
throw new Error(`Unable to compare protocols for ${fromProtocol}`);
}

const expectedApyBasisPoints = toApyBasisPoints(comparison.best.apy);

logger.info('Rebalance triggered', {
fromProtocol,
toProtocol,
amount,
expectedApyBasisPoints,
});

// TODO: Call actual smart contract to execute rebalance
// This would interact with the Stellar Soroban vault contract
// const txHash = await executeRebalanceOnChain(fromProtocol, toProtocol, amount);
const onChainTransaction = await submitRebalance(
toProtocol,
expectedApyBasisPoints,
);

const mockTxHash = `mock_tx_${Date.now()}`;
if (positionIds.length > 0) {
const representativePosition = await prisma.position.findFirst({
where: {
id: { in: positionIds },
},
include: {
user: {
select: {
network: true,
},
},
},
});

const comparison = await compareProtocols(fromProtocol);
const improvement = comparison ? comparison.improvement : 0;
if (representativePosition) {
await prisma.transaction.create({
data: {
userId: representativePosition.userId,
positionId: representativePosition.id,
txHash: onChainTransaction.hash,
type: 'REBALANCE',
status: 'PENDING',
assetSymbol: representativePosition.assetSymbol,
amount,
network: representativePosition.user.network,
protocolName: toProtocol,
memo: `Agent rebalance from ${fromProtocol} to ${toProtocol}`,
} as any,
});
} else {
logger.warn('No position found to persist rebalance transaction', {
fromProtocol,
toProtocol,
positionIds,
});
}
}

const rebalanceDetail: RebalanceDetails = {
fromProtocol,
toProtocol,
amount,
txHash: mockTxHash,
txHash: onChainTransaction.hash,
timestamp: new Date(),
improvedBy: improvement,
improvedBy: comparison.improvement,
};

const duration = Date.now() - startTime;
Expand All @@ -156,9 +207,9 @@ export async function triggerRebalance(
});

logger.info('Rebalance successful', {
txHash: mockTxHash,
txHash: onChainTransaction.hash,
duration,
improvedBy: improvement.toFixed(2),
improvedBy: comparison.improvement.toFixed(2),
});

return rebalanceDetail;
Expand Down Expand Up @@ -217,7 +268,8 @@ export async function executeRebalanceIfNeeded(
return await triggerRebalance(
currentProtocol,
comparison.best.name,
totalAmount
totalAmount,
userPositions.map(pos => pos.id),
);
} catch (error) {
logger.error('Rebalance execution check failed', {
Expand Down
3 changes: 3 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { logger } from './utils/logger'
import { startAgentLoop } from './agent/loop'
import { connectDb } from './db'
import { scheduleSessionCleanup } from './jobs/sessionCleanup'
import { startEventListener } from './stellar/events'
import healthRouter from './routes/health'
import agentRouter from './routes/agent'
import authRouter from './routes/auth'
Expand Down Expand Up @@ -57,6 +58,8 @@ async function main() {
logger.info(`Network: ${config.stellar.network}`)

try {
await startEventListener()
logger.info('Vault event listener started')
await startAgentLoop()
} catch (error) {
logger.error('Failed to start agent loop', {
Expand Down
77 changes: 75 additions & 2 deletions tests/unit/agent/router.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,38 @@ jest.mock('../../../src/agent/scanner', () => ({
getBestProtocol: jest.fn(),
}));

const mockFindFirst = jest.fn();
const mockTransactionCreate = jest.fn();
const mockAgentLogCreate = jest.fn().mockResolvedValue({});
const mockContractTriggerRebalance = jest.fn();

// Mock Prisma used by logAgentAction
jest.mock('@prisma/client', () => ({
PrismaClient: jest.fn().mockImplementation(() => ({
user: {
findMany: jest.fn().mockResolvedValue([{ id: 'test-user-id' }]),
},
position: {
findFirst: (...args: unknown[]) => mockFindFirst(...args),
},
transaction: {
create: (...args: unknown[]) => mockTransactionCreate(...args),
},
agentLog: {
create: jest.fn().mockResolvedValue({}),
create: (...args: unknown[]) => mockAgentLogCreate(...args),
},
})),
}));

jest.mock('../../../src/stellar/contract', () => ({
triggerRebalance: (...args: unknown[]) => mockContractTriggerRebalance(...args),
}));

import {
compareProtocols,
executeRebalanceIfNeeded,
getThresholds,
triggerRebalance,
} from '../../../src/agent/router';
import {
scanAllProtocols,
Expand Down Expand Up @@ -63,6 +79,19 @@ const marginalProtocol = {
describe('Agent Router', () => {
beforeEach(() => {
jest.clearAllMocks();
mockFindFirst.mockResolvedValue({
id: 'pos-1',
userId: 'test-user-id',
assetSymbol: 'USDC',
user: { network: 'TESTNET' },
});
mockTransactionCreate.mockResolvedValue({});
mockAgentLogCreate.mockResolvedValue({});
mockContractTriggerRebalance.mockResolvedValue({
hash: 'real-rebalance-hash-001',
status: 'success',
ledger: 77,
});
});

// ── compareProtocols ──────────────────────────────────────────────────────
Expand Down Expand Up @@ -187,7 +216,18 @@ describe('Agent Router', () => {
expect(result).not.toBeNull();
expect(result!.fromProtocol).toBe('Stellar DEX');
expect(result!.toProtocol).toBe('Blend');
expect(result!.txHash).toBeDefined();
expect(result!.txHash).toBe('real-rebalance-hash-001');
expect(mockContractTriggerRebalance).toHaveBeenCalledWith('Blend', 800);
expect(mockTransactionCreate).toHaveBeenCalledWith(
expect.objectContaining({
data: expect.objectContaining({
type: 'REBALANCE',
status: 'PENDING',
txHash: 'real-rebalance-hash-001',
protocolName: 'Blend',
}),
}),
);
});

it('sums amounts across multiple positions before cost calculation', async () => {
Expand All @@ -209,4 +249,37 @@ describe('Agent Router', () => {
expect(result).toBeNull();
});
});

describe('triggerRebalance()', () => {
it('converts APY percent to basis points before invoking the vault contract', async () => {
mockApy.mockResolvedValue(2.0);
mockScan.mockResolvedValue([blendProtocol]);

const result = await triggerRebalance(
'Stellar DEX',
'Blend',
'100000000000000000000000',
['pos-1'],
);

expect(result?.txHash).toBe('real-rebalance-hash-001');
expect(mockContractTriggerRebalance).toHaveBeenCalledWith('Blend', 800);
});

it('skips transaction persistence when no representative position is found', async () => {
mockApy.mockResolvedValue(2.0);
mockScan.mockResolvedValue([blendProtocol]);
mockFindFirst.mockResolvedValueOnce(null);

const result = await triggerRebalance(
'Stellar DEX',
'Blend',
'100000000000000000000000',
['missing-pos'],
);

expect(result?.txHash).toBe('real-rebalance-hash-001');
expect(mockTransactionCreate).not.toHaveBeenCalled();
});
});
});
Loading