diff --git a/amm/src/add.rs b/amm/src/add.rs index 91f5962..1dff79e 100644 --- a/amm/src/add.rs +++ b/amm/src/add.rs @@ -1,6 +1,9 @@ use std::num::NonZeroU128; -use amm_core::{assert_supported_fee_tier, compute_liquidity_token_pda_seed, PoolDefinition}; +use amm_core::{ + assert_supported_fee_tier, compute_liquidity_token_pda_seed, read_vault_fungible_balances, + PoolDefinition, +}; use nssa_core::{ account::{AccountWithMetadata, Data}, program::{AccountPostState, ChainedCall}, @@ -44,33 +47,9 @@ pub fn add_liquidity( "Both max-balances must be nonzero" ); - // 2. Determine deposit amount - let vault_b_token_holding = token_core::TokenHolding::try_from(&vault_b.account.data) - .expect("Add liquidity: AMM Program expects valid Token Holding Account for Vault B"); - let token_core::TokenHolding::Fungible { - definition_id: _, - balance: vault_b_balance, - } = vault_b_token_holding - else { - panic!( - "Add liquidity: AMM Program expects valid Fungible Token Holding Account for Vault B" - ); - }; - - let vault_a_token_holding = token_core::TokenHolding::try_from(&vault_a.account.data) - .expect("Add liquidity: AMM Program expects valid Token Holding Account for Vault A"); - let token_core::TokenHolding::Fungible { - definition_id: _, - balance: vault_a_balance, - } = vault_a_token_holding - else { - panic!( - "Add liquidity: AMM Program expects valid Fungible Token Holding Account for Vault A" - ); - }; + let (vault_a_balance, vault_b_balance) = + read_vault_fungible_balances("Add liquidity", &vault_a, &vault_b); - assert!(pool_def_data.reserve_a != 0, "Reserves must be nonzero"); - assert!(pool_def_data.reserve_b != 0, "Reserves must be nonzero"); assert!( vault_a_balance >= pool_def_data.reserve_a, "Vaults' balances must be at least the reserve amounts" @@ -80,7 +59,10 @@ pub fn add_liquidity( "Vaults' balances must be at least the reserve amounts" ); - // Calculate actual_amounts + // 2. Determine deposit amount + assert!(pool_def_data.reserve_a != 0, "Reserves must be nonzero"); + assert!(pool_def_data.reserve_b != 0, "Reserves must be nonzero"); + let ideal_a: u128 = pool_def_data .reserve_a .checked_mul(max_amount_to_add_token_b) diff --git a/amm/src/swap.rs b/amm/src/swap.rs index fefcd2c..51c4d80 100644 --- a/amm/src/swap.rs +++ b/amm/src/swap.rs @@ -1,4 +1,6 @@ -use amm_core::{assert_supported_fee_tier, MINIMUM_LIQUIDITY}; +use amm_core::{ + assert_supported_fee_tier, read_vault_fungible_balances, FEE_BPS_DENOMINATOR, MINIMUM_LIQUIDITY, +}; pub use amm_core::{compute_liquidity_token_pda_seed, compute_vault_pda_seed, PoolDefinition}; use nssa_core::{ account::{AccountId, AccountWithMetadata, Data}, @@ -28,31 +30,13 @@ fn validate_swap_setup( "Vault B was not provided" ); - let vault_a_token_holding = token_core::TokenHolding::try_from(&vault_a.account.data) - .expect("AMM Program expects a valid Token Holding Account for Vault A"); - let token_core::TokenHolding::Fungible { - definition_id: _, - balance: vault_a_balance, - } = vault_a_token_holding - else { - panic!("AMM Program expects a valid Fungible Token Holding Account for Vault A"); - }; + let (vault_a_balance, vault_b_balance) = + read_vault_fungible_balances("Validate swap setup", vault_a, vault_b); assert!( vault_a_balance >= pool_def_data.reserve_a, "Reserve for Token A exceeds vault balance" ); - - let vault_b_token_holding = token_core::TokenHolding::try_from(&vault_b.account.data) - .expect("AMM Program expects a valid Token Holding Account for Vault B"); - let token_core::TokenHolding::Fungible { - definition_id: _, - balance: vault_b_balance, - } = vault_b_token_holding - else { - panic!("AMM Program expects a valid Fungible Token Holding Account for Vault B"); - }; - assert!( vault_b_balance >= pool_def_data.reserve_b, "Reserve for Token B exceeds vault balance" @@ -130,6 +114,7 @@ pub fn swap_exact_input( user_holding_b.clone(), swap_amount_in, min_amount_out, + pool_def_data.fees, pool_def_data.reserve_a, pool_def_data.reserve_b, pool.account_id, @@ -144,6 +129,7 @@ pub fn swap_exact_input( user_holding_a.clone(), swap_amount_in, min_amount_out, + pool_def_data.fees, pool_def_data.reserve_b, pool_def_data.reserve_a, pool.account_id, @@ -178,19 +164,29 @@ fn swap_logic( user_withdraw: AccountWithMetadata, swap_amount_in: u128, min_amount_out: u128, + fee_bps: u128, reserve_deposit_vault_amount: u128, reserve_withdraw_vault_amount: u128, pool_id: AccountId, ) -> (Vec, u128, u128) { - // Compute withdraw amount - // Maintains pool constant product - // k = pool_def_data.reserve_a * pool_def_data.reserve_b; + let effective_amount_in = swap_amount_in + .checked_mul(FEE_BPS_DENOMINATOR - fee_bps) + .expect("swap_amount_in * (FEE_BPS_DENOMINATOR - fee_bps) overflows u128") + / FEE_BPS_DENOMINATOR; + assert!( + effective_amount_in != 0, + "Effective swap amount should be nonzero" + ); + // Compute the withdraw amount using the fee-adjusted input for pricing. + // The recorded pool reserves are updated later with the full + // `swap_amount_in`, so LP fees accrue inside `reserve_*` via invariant + // growth rather than as a separate vault balance surplus over `reserve_*`. let withdraw_amount = reserve_withdraw_vault_amount - .checked_mul(swap_amount_in) - .expect("reserve * amount_in overflows u128") + .checked_mul(effective_amount_in) + .expect("reserve * effective_amount_in overflows u128") / reserve_deposit_vault_amount - .checked_add(swap_amount_in) - .expect("reserve + swap_amount_in overflows u128"); + .checked_add(effective_amount_in) + .expect("reserve + effective_amount_in overflows u128"); // Slippage check assert!( @@ -259,6 +255,7 @@ pub fn swap_exact_output( max_amount_in, pool_def_data.reserve_a, pool_def_data.reserve_b, + pool_def_data.fees, pool.account_id, ); @@ -273,6 +270,7 @@ pub fn swap_exact_output( max_amount_in, pool_def_data.reserve_b, pool_def_data.reserve_a, + pool_def_data.fees, pool.account_id, ); @@ -307,6 +305,7 @@ fn exact_output_swap_logic( max_amount_in: u128, reserve_deposit_vault_amount: u128, reserve_withdraw_vault_amount: u128, + fee_bps: u128, pool_id: AccountId, ) -> (Vec, u128, u128) { // Guard: exact_amount_out must be nonzero @@ -318,12 +317,28 @@ fn exact_output_swap_logic( "Exact amount out exceeds reserve" ); - // Compute deposit amount using ceiling division - // Formula: amount_in = ceil(reserve_in * exact_amount_out / (reserve_out - exact_amount_out)) - let deposit_amount = reserve_deposit_vault_amount + // Compute the minimum effective input required to achieve exact_amount_out + // using the same floor-rounded fee application as swap_exact_input. + // + // Solve constant product for effective_in (fee already removed): + // effective_in >= ceil(reserve_in * amount_out / (reserve_out - amount_out)) + let effective_in_numerator = reserve_deposit_vault_amount .checked_mul(exact_amount_out) - .expect("reserve * amount_out overflows u128") - .div_ceil(reserve_withdraw_vault_amount - exact_amount_out); + .expect("reserve * amount_out overflows u128"); + let effective_in_denominator = reserve_withdraw_vault_amount + .checked_sub(exact_amount_out) + .expect("reserve_out - amount_out underflows"); + let effective_in_min = effective_in_numerator.div_ceil(effective_in_denominator); + + // Lift back to gross input so that + // floor(gross_in * (FEE_DENOM - fee) / FEE_DENOM) >= effective_in_min + let fee_multiplier = FEE_BPS_DENOMINATOR + .checked_sub(fee_bps) + .expect("fee_bps exceeds fee denominator"); + let deposit_amount = effective_in_min + .checked_mul(FEE_BPS_DENOMINATOR) + .expect("effective_in * FEE_DENOM overflows u128") + .div_ceil(fee_multiplier); // Slippage check assert!( diff --git a/amm/src/tests.rs b/amm/src/tests.rs index 922d35a..96cb61b 100644 --- a/amm/src/tests.rs +++ b/amm/src/tests.rs @@ -4,8 +4,9 @@ use std::num::NonZero; use amm_core::{ compute_liquidity_token_pda, compute_liquidity_token_pda_seed, compute_lp_lock_holding_pda, - compute_pool_pda, compute_vault_pda, compute_vault_pda_seed, PoolDefinition, FEE_TIER_BPS_1, - FEE_TIER_BPS_100, FEE_TIER_BPS_30, FEE_TIER_BPS_5, MINIMUM_LIQUIDITY, + compute_pool_pda, compute_vault_pda, compute_vault_pda_seed, PoolDefinition, + FEE_BPS_DENOMINATOR, FEE_TIER_BPS_1, FEE_TIER_BPS_100, FEE_TIER_BPS_30, FEE_TIER_BPS_5, + MINIMUM_LIQUIDITY, }; use nssa_core::{ account::{Account, AccountId, AccountWithMetadata, Data, Nonce}, @@ -103,6 +104,16 @@ impl BalanceForTests { 200 } + fn effective_swap_in_a() -> u128 { + BalanceForTests::add_max_amount_a() * (FEE_BPS_DENOMINATOR - BalanceForTests::fee_tier()) + / FEE_BPS_DENOMINATOR + } + + fn effective_swap_in_b() -> u128 { + BalanceForTests::add_max_amount_b() * (FEE_BPS_DENOMINATOR - BalanceForTests::fee_tier()) + / FEE_BPS_DENOMINATOR + } + fn add_max_amount_a_low() -> u128 { 10 } @@ -178,13 +189,13 @@ impl BalanceForTests { } fn swap_amount_out_b() -> u128 { - (BalanceForTests::vault_b_reserve_init() * BalanceForTests::add_max_amount_a()) - / (BalanceForTests::vault_a_reserve_init() + BalanceForTests::add_max_amount_a()) + (BalanceForTests::vault_b_reserve_init() * BalanceForTests::effective_swap_in_a()) + / (BalanceForTests::vault_a_reserve_init() + BalanceForTests::effective_swap_in_a()) } fn swap_amount_out_a() -> u128 { - (BalanceForTests::vault_a_reserve_init() * BalanceForTests::add_max_amount_b()) - / (BalanceForTests::vault_b_reserve_init() + BalanceForTests::add_max_amount_b()) + (BalanceForTests::vault_a_reserve_init() * BalanceForTests::effective_swap_in_b()) + / (BalanceForTests::vault_b_reserve_init() + BalanceForTests::effective_swap_in_b()) } fn add_delta_lp_successful() -> u128 { @@ -276,7 +287,10 @@ impl ChainedCallForTests { } fn cc_swap_exact_output_token_a_test_1() -> ChainedCall { - let swap_amount: u128 = 498; + // reserve_in=1000, amount_out=166, fee=30bps + // required_effective_in = ceil(1000 * 166 / 334) = 498 + // deposit = ceil(498 * 10000 / 9970) = 500 + let swap_amount: u128 = 500; ChainedCall::new( TOKEN_PROGRAM_ID, @@ -329,7 +343,10 @@ impl ChainedCallForTests { } fn cc_swap_exact_output_token_b_test_2() -> ChainedCall { - let swap_amount: u128 = 200; + // reserve_in=500, amount_out=285, fee=30bps + // required_effective_in = ceil(500 * 285 / 715) = 200 + // deposit = ceil(200 * 10000 / 9970) = 201 + let swap_amount: u128 = 201; ChainedCall::new( TOKEN_PROGRAM_ID, @@ -343,6 +360,36 @@ impl ChainedCallForTests { ) } + fn cc_swap_rounding_boundary_token_a_in() -> ChainedCall { + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![ + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::vault_a_init(), + ], + &token_core::Instruction::Transfer { + amount_to_transfer: 3, + }, + ) + } + + fn cc_swap_rounding_boundary_token_b_out() -> ChainedCall { + let mut vault_b_auth = AccountWithMetadataForTests::vault_b_init(); + vault_b_auth.is_authorized = true; + + ChainedCall::new( + TOKEN_PROGRAM_ID, + vec![vault_b_auth, AccountWithMetadataForTests::user_holding_b()], + &token_core::Instruction::Transfer { + amount_to_transfer: 1, + }, + ) + .with_pda_seeds(vec![compute_vault_pda_seed( + IdForTests::pool_definition_id(), + IdForTests::token_b_definition_id(), + )]) + } + fn cc_add_token_a() -> ChainedCall { ChainedCall::new( TOKEN_PROGRAM_ID, @@ -885,6 +932,29 @@ impl AccountWithMetadataForTests { } } + fn pool_definition_swap_rounding_boundary_init() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: MINIMUM_LIQUIDITY, + reserve_a: 1_000, + reserve_b: 1_000, + fees: FEE_TIER_BPS_30, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + fn pool_definition_init_reserve_a_zero() -> AccountWithMetadata { AccountWithMetadata { account: Account { @@ -1024,6 +1094,9 @@ impl AccountWithMetadataForTests { } fn pool_definition_swap_exact_output_test_1() -> AccountWithMetadata { + // swap token_a in for 166 token_b out, fee=30bps + // reserve_a: 1000 + 500 = 1500 (gross deposit, see + // cc_swap_exact_output_token_a_test_1) reserve_b: 500 - 166 = 334 AccountWithMetadata { account: Account { program_owner: ProgramId::default(), @@ -1035,7 +1108,7 @@ impl AccountWithMetadataForTests { vault_b_id: IdForTests::vault_b_id(), liquidity_pool_id: IdForTests::token_lp_definition_id(), liquidity_pool_supply: BalanceForTests::lp_supply_init(), - reserve_a: 1498_u128, + reserve_a: 1500_u128, reserve_b: 334_u128, fees: BalanceForTests::fee_tier(), }), @@ -1059,7 +1132,7 @@ impl AccountWithMetadataForTests { liquidity_pool_id: IdForTests::token_lp_definition_id(), liquidity_pool_supply: BalanceForTests::lp_supply_init(), reserve_a: 715_u128, - reserve_b: 700_u128, + reserve_b: 701_u128, fees: BalanceForTests::fee_tier(), }), nonce: Nonce(0), @@ -1069,6 +1142,29 @@ impl AccountWithMetadataForTests { } } + fn pool_definition_swap_rounding_boundary_post() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0_u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: MINIMUM_LIQUIDITY, + reserve_a: 1003_u128, + reserve_b: 999_u128, + fees: FEE_TIER_BPS_30, + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + fn pool_definition_add_zero_lp() -> AccountWithMetadata { AccountWithMetadata { account: Account { @@ -1115,6 +1211,29 @@ impl AccountWithMetadataForTests { } } + fn pool_definition_init_low_balances() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: ProgramId::default(), + balance: 0u128, + data: Data::from(&PoolDefinition { + definition_token_a_id: IdForTests::token_a_definition_id(), + definition_token_b_id: IdForTests::token_b_definition_id(), + vault_a_id: IdForTests::vault_a_id(), + vault_b_id: IdForTests::vault_b_id(), + liquidity_pool_id: IdForTests::token_lp_definition_id(), + liquidity_pool_supply: MINIMUM_LIQUIDITY, + reserve_a: BalanceForTests::vault_a_reserve_low(), + reserve_b: BalanceForTests::vault_b_reserve_low(), + fees: BalanceForTests::fee_tier(), + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::pool_definition_id(), + } + } + fn pool_definition_remove_successful() -> AccountWithMetadata { AccountWithMetadata { account: Account { @@ -1342,6 +1461,40 @@ fn test_call_add_liquidity_zero_balance_2() { ); } +#[should_panic(expected = "Vaults' balances must be at least the reserve amounts")] +#[test] +fn test_call_add_liquidity_vault_a_balance_below_reserve() { + let _post_states = add_liquidity( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init_low(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_init(), + NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::add_max_amount_b(), + ); +} + +#[should_panic(expected = "Vaults' balances must be at least the reserve amounts")] +#[test] +fn test_call_add_liquidity_vault_b_balance_below_reserve() { + let _post_states = add_liquidity( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init_low(), + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_init(), + NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::add_max_amount_b(), + ); +} + #[should_panic(expected = "Vaults' balances must be at least the reserve amounts")] #[test] fn test_call_add_liquidity_vault_insufficient_balance_1() { @@ -1353,9 +1506,9 @@ fn test_call_add_liquidity_vault_insufficient_balance_1() { AccountWithMetadataForTests::user_holding_a(), AccountWithMetadataForTests::user_holding_b(), AccountWithMetadataForTests::user_holding_lp_init(), - NonZero::new(BalanceForTests::add_max_amount_a()).unwrap(), + NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(), + BalanceForTests::add_max_amount_a(), BalanceForTests::add_max_amount_b(), - BalanceForTests::add_min_amount_lp(), ); } @@ -1370,9 +1523,9 @@ fn test_call_add_liquidity_vault_insufficient_balance_2() { AccountWithMetadataForTests::user_holding_a(), AccountWithMetadataForTests::user_holding_b(), AccountWithMetadataForTests::user_holding_lp_init(), - NonZero::new(BalanceForTests::add_max_amount_a()).unwrap(), + NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(), + BalanceForTests::add_max_amount_a(), BalanceForTests::add_max_amount_b(), - BalanceForTests::add_min_amount_lp(), ); } @@ -2052,6 +2205,84 @@ fn test_call_swap_below_min_out() { ); } +#[should_panic(expected = "Effective swap amount should be nonzero")] +#[test] +fn test_call_swap_effective_amount_zero() { + let _post_states = swap_exact_input( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 1, + 0, + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Withdraw amount should be nonzero")] +#[test] +fn test_call_swap_output_rounds_to_zero() { + let _post_states = swap_exact_input( + AccountWithMetadataForTests::pool_definition_init_low_balances(), + AccountWithMetadataForTests::vault_a_init_low(), + AccountWithMetadataForTests::vault_b_init_low(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 2, + 0, + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "Withdraw amount is less than minimal amount out")] +#[test] +fn test_call_swap_exact_input_rejects_amount_that_rounds_down_below_target_output() { + let _post_states = swap_exact_input( + AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 2, + 1, + IdForTests::token_a_definition_id(), + ); +} + +#[test] +fn test_call_swap_exact_input_accepts_smallest_amount_for_rounded_boundary() { + let (post_states, chained_calls) = swap_exact_input( + AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 3, + 1, + IdForTests::token_a_definition_id(), + ); + + let pool_post = post_states[0].clone(); + + assert_eq!( + AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_post().account, + *pool_post.account() + ); + + let chained_call_a = chained_calls[0].clone(); + let chained_call_b = chained_calls[1].clone(); + + assert_eq!( + chained_call_a, + ChainedCallForTests::cc_swap_rounding_boundary_token_a_in() + ); + assert_eq!( + chained_call_b, + ChainedCallForTests::cc_swap_rounding_boundary_token_b_out() + ); +} + #[test] fn test_call_swap_chained_call_successful_1() { let (post_states, chained_calls) = swap_exact_input( @@ -2317,6 +2548,74 @@ fn call_swap_exact_output_chained_call_successful_2() { ); } +// The minimum effective input for exact_amount_out=166 on the 1000/500 pool is 498. +// After fee rounding, the true minimum gross input is 500, so 499 must be rejected. +#[should_panic(expected = "Required input exceeds maximum amount in")] +#[test] +fn call_swap_exact_output_fee_enforced() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_swap_exact_output_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 166_u128, // exact_amount_out: token_b + 499_u128, // max_amount_in: still one short after fee rounding + IdForTests::token_a_definition_id(), + ); +} + +// On a 1000/1000 pool at 0.3%, exact_amount_out = 1 requires gross input 3. +// max_amount_in = 2 must be rejected because the exact-input path would round +// 2 down to effective_in = 1 and still produce 0 output. +#[should_panic(expected = "Required input exceeds maximum amount in")] +#[test] +fn call_swap_exact_output_rejects_max_in_that_rounds_down_below_target_output() { + let _post_states = swap_exact_output( + AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 1, + 2, + IdForTests::token_a_definition_id(), + ); +} + +#[test] +fn call_swap_exact_output_accepts_smallest_max_in_for_rounded_boundary() { + let (post_states, chained_calls) = swap_exact_output( + AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b(), + 1, + 3, + IdForTests::token_a_definition_id(), + ); + + let pool_post = post_states[0].clone(); + + assert_eq!( + AccountWithMetadataForTests::pool_definition_swap_rounding_boundary_post().account, + *pool_post.account() + ); + + let chained_call_a = chained_calls[0].clone(); + let chained_call_b = chained_calls[1].clone(); + + assert_eq!( + chained_call_a, + ChainedCallForTests::cc_swap_rounding_boundary_token_a_in() + ); + assert_eq!( + chained_call_b, + ChainedCallForTests::cc_swap_rounding_boundary_token_b_out() + ); +} + // Without the fix, `reserve_a * exact_amount_out` silently wraps to 0 in release mode, // making `deposit_amount = 0`. The slippage check `0 <= max_amount_in` always passes, // so an attacker receives `exact_amount_out` tokens while paying nothing. @@ -2869,7 +3168,7 @@ fn remove_liquidity_overflow_protection() { ); } -#[should_panic(expected = "reserve * amount_in overflows u128")] +#[should_panic(expected = "reserve * effective_amount_in overflows u128")] #[test] fn swap_exact_input_overflow_protection() { let large_reserve: u128 = u128::MAX / 2 + 1; @@ -2924,7 +3223,8 @@ fn swap_exact_input_overflow_protection() { account_id: IdForTests::vault_b_id(), }; - // Swap token_a in: withdraw_amount = reserve_b * swap_amount_in / (reserve_a + swap_amount_in) + // Swap token_a in: withdraw_amount = reserve_b * effective_amount_in / (reserve_a + + // effective_amount_in) With fee_bps=30: effective_amount_in = 3 * 9970 / 10000 = 2 // reserve_b is large, so reserve_b * 2 overflows let _result = swap_exact_input( pool, @@ -2932,7 +3232,7 @@ fn swap_exact_input_overflow_protection() { vault_b, AccountWithMetadataForTests::user_holding_a(), AccountWithMetadataForTests::user_holding_b(), - 2, + 3, 1, IdForTests::token_a_definition_id(), ); diff --git a/integration_tests/tests/amm.rs b/integration_tests/tests/amm.rs index a28dfde..9378a8d 100644 --- a/integration_tests/tests/amm.rs +++ b/integration_tests/tests/amm.rs @@ -164,8 +164,16 @@ impl Balances { 200 } + fn reserve_a_swap_1() -> u128 { + 3_575 + } + + fn reserve_b_swap_1() -> u128 { + 3_500 + } + fn vault_a_swap_1() -> u128 { - 3_572 + 3_575 } fn vault_b_swap_1() -> u128 { @@ -173,19 +181,27 @@ impl Balances { } fn user_a_swap_1() -> u128 { - 11_428 + 11_425 } fn user_b_swap_1() -> u128 { 9_000 } + fn reserve_a_swap_2() -> u128 { + 6_000 + } + + fn reserve_b_swap_2() -> u128 { + 2_085 + } + fn vault_a_swap_2() -> u128 { 6_000 } fn vault_b_swap_2() -> u128 { - 2_084 + 2_085 } fn user_a_swap_2() -> u128 { @@ -193,7 +209,7 @@ impl Balances { } fn user_b_swap_2() -> u128 { - 10_416 + 10_415 } fn vault_a_add() -> u128 { @@ -405,8 +421,8 @@ impl Accounts { vault_b_id: Ids::vault_b(), liquidity_pool_id: Ids::token_lp_definition(), liquidity_pool_supply: Balances::pool_lp_supply_init(), - reserve_a: Balances::vault_a_swap_1(), - reserve_b: Balances::vault_b_swap_1(), + reserve_a: Balances::reserve_a_swap_1(), + reserve_b: Balances::reserve_b_swap_1(), fees: Balances::fee_tier(), }), nonce: Nonce(0), @@ -472,8 +488,8 @@ impl Accounts { vault_b_id: Ids::vault_b(), liquidity_pool_id: Ids::token_lp_definition(), liquidity_pool_supply: Balances::pool_lp_supply_init(), - reserve_a: Balances::vault_a_swap_2(), - reserve_b: Balances::vault_b_swap_2(), + reserve_a: Balances::reserve_a_swap_2(), + reserve_b: Balances::reserve_b_swap_2(), fees: Balances::fee_tier(), }), nonce: Nonce(0), @@ -918,6 +934,10 @@ fn state_for_amm_tests_with_new_def() -> V03State { state } +fn current_nonce(state: &V03State, account_id: AccountId) -> Nonce { + state.get_account_by_id(account_id).nonce +} + fn try_execute_new_definition(state: &mut V03State, fees: u128) -> Result<(), NssaError> { let instruction = amm_core::Instruction::NewDefinition { token_a_amount: Balances::vault_a_init(), @@ -938,7 +958,10 @@ fn try_execute_new_definition(state: &mut V03State, fees: u128) -> Result<(), Ns Ids::user_b(), Ids::user_lp(), ], - vec![Nonce(0), Nonce(0)], + vec![ + current_nonce(state, Ids::user_a()), + current_nonce(state, Ids::user_b()), + ], instruction, ) .unwrap(); @@ -954,6 +977,163 @@ fn execute_new_definition(state: &mut V03State, fees: u128) { try_execute_new_definition(state, fees).unwrap(); } +fn execute_swap_a_to_b(state: &mut V03State, swap_amount_in: u128, min_amount_out: u128) { + let instruction = amm_core::Instruction::SwapExactInput { + swap_amount_in, + min_amount_out, + token_definition_id_in: Ids::token_a_definition(), + }; + + let message = public_transaction::Message::try_new( + Ids::amm_program(), + vec![ + Ids::pool_definition(), + Ids::vault_a(), + Ids::vault_b(), + Ids::user_a(), + Ids::user_b(), + ], + vec![current_nonce(state, Ids::user_a())], + instruction, + ) + .unwrap(); + + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_a()]); + + let tx = PublicTransaction::new(message, witness_set); + state.transition_from_public_transaction(&tx, 0).unwrap(); +} + +fn execute_swap_b_to_a(state: &mut V03State, swap_amount_in: u128, min_amount_out: u128) { + let instruction = amm_core::Instruction::SwapExactInput { + swap_amount_in, + min_amount_out, + token_definition_id_in: Ids::token_b_definition(), + }; + + let message = public_transaction::Message::try_new( + Ids::amm_program(), + vec![ + Ids::pool_definition(), + Ids::vault_a(), + Ids::vault_b(), + Ids::user_a(), + Ids::user_b(), + ], + vec![current_nonce(state, Ids::user_b())], + instruction, + ) + .unwrap(); + + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_b()]); + + let tx = PublicTransaction::new(message, witness_set); + state.transition_from_public_transaction(&tx, 0).unwrap(); +} + +fn execute_add_liquidity( + state: &mut V03State, + min_amount_liquidity: u128, + max_amount_to_add_token_a: u128, + max_amount_to_add_token_b: u128, +) { + let instruction = amm_core::Instruction::AddLiquidity { + min_amount_liquidity, + max_amount_to_add_token_a, + max_amount_to_add_token_b, + }; + + let message = public_transaction::Message::try_new( + Ids::amm_program(), + vec![ + Ids::pool_definition(), + Ids::vault_a(), + Ids::vault_b(), + Ids::token_lp_definition(), + Ids::user_a(), + Ids::user_b(), + Ids::user_lp(), + ], + vec![ + current_nonce(state, Ids::user_a()), + current_nonce(state, Ids::user_b()), + ], + instruction, + ) + .unwrap(); + + let witness_set = + public_transaction::WitnessSet::for_message(&message, &[&Keys::user_a(), &Keys::user_b()]); + + let tx = PublicTransaction::new(message, witness_set); + state.transition_from_public_transaction(&tx, 0).unwrap(); +} + +fn execute_remove_liquidity( + state: &mut V03State, + remove_liquidity_amount: u128, + min_amount_to_remove_token_a: u128, + min_amount_to_remove_token_b: u128, +) { + let instruction = amm_core::Instruction::RemoveLiquidity { + remove_liquidity_amount, + min_amount_to_remove_token_a, + min_amount_to_remove_token_b, + }; + + let message = public_transaction::Message::try_new( + Ids::amm_program(), + vec![ + Ids::pool_definition(), + Ids::vault_a(), + Ids::vault_b(), + Ids::token_lp_definition(), + Ids::user_a(), + Ids::user_b(), + Ids::user_lp(), + ], + vec![current_nonce(state, Ids::user_lp())], + instruction, + ) + .unwrap(); + + let witness_set = public_transaction::WitnessSet::for_message(&message, &[&Keys::user_lp()]); + + let tx = PublicTransaction::new(message, witness_set); + state.transition_from_public_transaction(&tx, 0).unwrap(); +} + +fn fungible_balance(account: &Account) -> u128 { + let holding = TokenHolding::try_from(&account.data).expect("expected token holding"); + let TokenHolding::Fungible { + definition_id: _, + balance, + } = holding + else { + panic!("expected fungible token holding") + }; + + balance +} + +fn pool_definition(account: &Account) -> PoolDefinition { + PoolDefinition::try_from(&account.data).expect("expected pool definition") +} + +fn fungible_total_supply(account: &Account) -> u128 { + let definition = TokenDefinition::try_from(&account.data).expect("expected token definition"); + let TokenDefinition::Fungible { + name: _, + total_supply, + metadata_id: _, + } = definition + else { + panic!("expected fungible token definition") + }; + + total_supply +} + #[test] fn amm_remove_liquidity() { let mut state = state_for_amm_tests(); @@ -1322,3 +1502,102 @@ fn amm_swap_a_to_b() { Accounts::user_b_holding_swap_2() ); } + +#[test] +fn amm_fee_accumulates_across_multiple_swaps_and_pays_out_on_remove() { + let mut state = state_for_amm_tests(); + + execute_swap_a_to_b(&mut state, 1_000, 200); + execute_swap_b_to_a(&mut state, 1_000, 200); + + let pool_before_remove = pool_definition(&state.get_account_by_id(Ids::pool_definition())); + assert_eq!(pool_before_remove.reserve_a, 4_060); + assert_eq!(pool_before_remove.reserve_b, 3_085); + assert_eq!(pool_before_remove.fees, Balances::fee_tier()); + + let vault_a_before_remove = fungible_balance(&state.get_account_by_id(Ids::vault_a())); + let vault_b_before_remove = fungible_balance(&state.get_account_by_id(Ids::vault_b())); + assert_eq!(vault_a_before_remove, 4_060); + assert_eq!(vault_b_before_remove, 3_085); + assert_eq!(vault_a_before_remove, pool_before_remove.reserve_a); + assert_eq!(vault_b_before_remove, pool_before_remove.reserve_b); + + execute_remove_liquidity(&mut state, 1_000, 812, 617); + + let pool_after_remove = pool_definition(&state.get_account_by_id(Ids::pool_definition())); + assert_eq!(pool_after_remove.reserve_a, 3_248); + assert_eq!(pool_after_remove.reserve_b, 2_468); + assert_eq!(pool_after_remove.liquidity_pool_supply, 4_000); + + let vault_a_after_remove = fungible_balance(&state.get_account_by_id(Ids::vault_a())); + let vault_b_after_remove = fungible_balance(&state.get_account_by_id(Ids::vault_b())); + assert_eq!(vault_a_after_remove, 3_248); + assert_eq!(vault_b_after_remove, 2_468); + assert_eq!(vault_a_after_remove, pool_after_remove.reserve_a); + assert_eq!(vault_b_after_remove, pool_after_remove.reserve_b); + + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_a())), + 11_752 + ); + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_b())), + 10_032 + ); + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_lp())), + 1_000 + ); + assert_eq!( + fungible_total_supply(&state.get_account_by_id(Ids::token_lp_definition())), + 4_000 + ); +} + +#[test] +fn amm_add_liquidity_after_fee_accrual() { + let mut state = state_for_amm_tests(); + + execute_swap_a_to_b(&mut state, 1_000, 200); + execute_swap_b_to_a(&mut state, 1_000, 200); + execute_swap_a_to_b(&mut state, 1_000, 200); + execute_swap_b_to_a(&mut state, 1_000, 200); + + let pool_before_add = pool_definition(&state.get_account_by_id(Ids::pool_definition())); + let vault_a_before_add = fungible_balance(&state.get_account_by_id(Ids::vault_a())); + let vault_b_before_add = fungible_balance(&state.get_account_by_id(Ids::vault_b())); + + assert_eq!(pool_before_add.reserve_a, 3_608); + assert_eq!(pool_before_add.reserve_b, 3_477); + assert_eq!(vault_a_before_add, pool_before_add.reserve_a); + assert_eq!(vault_b_before_add, pool_before_add.reserve_b); + + execute_add_liquidity(&mut state, 1_436, 2_000, 1_000); + + let pool_after_add = pool_definition(&state.get_account_by_id(Ids::pool_definition())); + let vault_a_after_add = fungible_balance(&state.get_account_by_id(Ids::vault_a())); + let vault_b_after_add = fungible_balance(&state.get_account_by_id(Ids::vault_b())); + + assert_eq!(pool_after_add.reserve_a, 4_645); + assert_eq!(pool_after_add.reserve_b, 4_477); + assert_eq!(pool_after_add.liquidity_pool_supply, 6_437); + assert_eq!(vault_a_after_add, pool_after_add.reserve_a); + assert_eq!(vault_b_after_add, pool_after_add.reserve_b); + + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_a())), + 10_355 + ); + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_b())), + 8_023 + ); + assert_eq!( + fungible_balance(&state.get_account_by_id(Ids::user_lp())), + 3_437 + ); + assert_eq!( + fungible_total_supply(&state.get_account_by_id(Ids::token_lp_definition())), + 6_437 + ); +}