diff --git a/amm/src/add.rs b/amm/src/add.rs index 1dff79e..7c18726 100644 --- a/amm/src/add.rs +++ b/amm/src/add.rs @@ -42,6 +42,16 @@ pub fn add_liquidity( "Vault B was not provided" ); + let token_program_id = vault_a.account.program_owner; + assert_eq!( + user_holding_a.account.program_owner, token_program_id, + "User Token A holding must be owned by the vault's Token Program" + ); + assert_eq!( + user_holding_b.account.program_owner, token_program_id, + "User Token B holding must be owned by the vault's Token Program" + ); + assert!( max_amount_to_add_token_a != 0 && max_amount_to_add_token_b != 0, "Both max-balances must be nonzero" @@ -138,7 +148,6 @@ pub fn add_liquidity( }; pool_post.data = Data::from(&pool_post_definition); - let token_program_id = user_holding_a.account.program_owner; // Chain call for Token A (UserHoldingA -> Vault_A) let call_token_a = ChainedCall::new( diff --git a/amm/src/remove.rs b/amm/src/remove.rs index 951c639..67e33b4 100644 --- a/amm/src/remove.rs +++ b/amm/src/remove.rs @@ -46,6 +46,16 @@ pub fn remove_liquidity( "Vault B was not provided" ); + let token_program_id = vault_a.account.program_owner; + assert_eq!( + user_holding_a.account.program_owner, token_program_id, + "User Token A holding must be owned by the vault's Token Program" + ); + assert_eq!( + user_holding_b.account.program_owner, token_program_id, + "User Token B holding must be owned by the vault's Token Program" + ); + // Vault addresses do not need to be checked with PDA // calculation for setting authorization since stored // in the Pool Definition. @@ -143,8 +153,6 @@ pub fn remove_liquidity( pool_post.data = Data::from(&pool_post_definition); - let token_program_id = user_holding_a.account.program_owner; - // Chaincall for Token A withdraw let call_token_a = ChainedCall::new( token_program_id, diff --git a/amm/src/swap.rs b/amm/src/swap.rs index 51c4d80..e8b38a0 100644 --- a/amm/src/swap.rs +++ b/amm/src/swap.rs @@ -105,6 +105,16 @@ pub fn swap_exact_input( ) -> (Vec, Vec) { let pool_def_data = validate_swap_setup(&pool, &vault_a, &vault_b); + let token_program_id = vault_a.account.program_owner; + assert_eq!( + user_holding_a.account.program_owner, token_program_id, + "User Token A holding must be owned by the vault's Token Program" + ); + assert_eq!( + user_holding_b.account.program_owner, token_program_id, + "User Token B holding must be owned by the vault's Token Program" + ); + let (chained_calls, [deposit_a, withdraw_a], [deposit_b, withdraw_b]) = if token_in_id == pool_def_data.definition_token_a_id { let (chained_calls, deposit_a, withdraw_b) = swap_logic( @@ -244,6 +254,16 @@ pub fn swap_exact_output( ) -> (Vec, Vec) { let pool_def_data = validate_swap_setup(&pool, &vault_a, &vault_b); + let token_program_id = vault_a.account.program_owner; + assert_eq!( + user_holding_a.account.program_owner, token_program_id, + "User Token A holding must be owned by the vault's Token Program" + ); + assert_eq!( + user_holding_b.account.program_owner, token_program_id, + "User Token B holding must be owned by the vault's Token Program" + ); + let (chained_calls, [deposit_a, withdraw_a], [deposit_b, withdraw_b]) = if token_in_id == pool_def_data.definition_token_a_id { let (chained_calls, deposit_a, withdraw_b) = exact_output_swap_logic( diff --git a/amm/src/tests.rs b/amm/src/tests.rs index 5f7f433..c6d12a0 100644 --- a/amm/src/tests.rs +++ b/amm/src/tests.rs @@ -24,6 +24,7 @@ use crate::{ const TOKEN_PROGRAM_ID: ProgramId = [15; 8]; const AMM_PROGRAM_ID: ProgramId = [42; 8]; +const MALICIOUS_TOKEN_PROGRAM_ID: ProgramId = [99; 8]; struct BalanceForTests; struct ChainedCallForTests; @@ -1346,6 +1347,38 @@ impl AccountWithMetadataForTests { } } + fn user_holding_a_wrong_program() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: MALICIOUS_TOKEN_PROGRAM_ID, + balance: 0u128, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_a_definition_id(), + balance: BalanceForTests::user_token_a_balance(), + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::user_token_a_id(), + } + } + + fn user_holding_b_wrong_program() -> AccountWithMetadata { + AccountWithMetadata { + account: Account { + program_owner: MALICIOUS_TOKEN_PROGRAM_ID, + balance: 0u128, + data: Data::from(&TokenHolding::Fungible { + definition_id: IdForTests::token_b_definition_id(), + balance: BalanceForTests::user_token_b_balance(), + }), + nonce: Nonce(0), + }, + is_authorized: true, + account_id: IdForTests::user_token_b_id(), + } + } + /// Legacy/corrupted pool state whose reported supply has already been drained down to the /// permanent lock (liquidity_pool_supply == MINIMUM_LIQUIDITY). fn pool_definition_at_minimum_liquidity() -> AccountWithMetadata { @@ -3303,3 +3336,137 @@ fn test_new_definition_rejects_unsupported_fee_tier() { AMM_PROGRAM_ID, ); } + +// --- Token program ownership validation tests --- + +#[should_panic(expected = "User Token A holding must be owned by the vault's Token Program")] +#[test] +fn test_add_liquidity_rejects_user_holding_a_wrong_program() { + let _ = add_liquidity( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a_wrong_program(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_init(), + NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::add_max_amount_b(), + ); +} + +#[should_panic(expected = "User Token B holding must be owned by the vault's Token Program")] +#[test] +fn test_add_liquidity_rejects_user_holding_b_wrong_program() { + let _ = add_liquidity( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b_wrong_program(), + AccountWithMetadataForTests::user_holding_lp_init(), + NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::add_max_amount_b(), + ); +} + +#[should_panic(expected = "User Token A holding must be owned by the vault's Token Program")] +#[test] +fn test_remove_liquidity_rejects_user_holding_a_wrong_program() { + let _ = remove_liquidity( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a_wrong_program(), + AccountWithMetadataForTests::user_holding_b(), + AccountWithMetadataForTests::user_holding_lp_with_balance( + BalanceForTests::remove_amount_lp(), + ), + NonZero::new(BalanceForTests::remove_amount_lp()).unwrap(), + BalanceForTests::remove_min_amount_a(), + BalanceForTests::remove_min_amount_b_low(), + ); +} + +#[should_panic(expected = "User Token B holding must be owned by the vault's Token Program")] +#[test] +fn test_remove_liquidity_rejects_user_holding_b_wrong_program() { + let _ = remove_liquidity( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::pool_lp_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b_wrong_program(), + AccountWithMetadataForTests::user_holding_lp_with_balance( + BalanceForTests::remove_amount_lp(), + ), + NonZero::new(BalanceForTests::remove_amount_lp()).unwrap(), + BalanceForTests::remove_min_amount_a(), + BalanceForTests::remove_min_amount_b_low(), + ); +} + +#[should_panic(expected = "User Token A holding must be owned by the vault's Token Program")] +#[test] +fn test_swap_exact_input_rejects_user_holding_a_wrong_program() { + let _ = swap_exact_input( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a_wrong_program(), + AccountWithMetadataForTests::user_holding_b(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::min_amount_out(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "User Token B holding must be owned by the vault's Token Program")] +#[test] +fn test_swap_exact_input_rejects_user_holding_b_wrong_program() { + let _ = swap_exact_input( + AccountWithMetadataForTests::pool_definition_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b_wrong_program(), + BalanceForTests::add_max_amount_a(), + BalanceForTests::min_amount_out(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "User Token A holding must be owned by the vault's Token Program")] +#[test] +fn test_swap_exact_output_rejects_user_holding_a_wrong_program() { + let _ = swap_exact_output( + AccountWithMetadataForTests::pool_definition_swap_exact_output_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a_wrong_program(), + AccountWithMetadataForTests::user_holding_b(), + 166, + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +} + +#[should_panic(expected = "User Token B holding must be owned by the vault's Token Program")] +#[test] +fn test_swap_exact_output_rejects_user_holding_b_wrong_program() { + let _ = swap_exact_output( + AccountWithMetadataForTests::pool_definition_swap_exact_output_init(), + AccountWithMetadataForTests::vault_a_init(), + AccountWithMetadataForTests::vault_b_init(), + AccountWithMetadataForTests::user_holding_a(), + AccountWithMetadataForTests::user_holding_b_wrong_program(), + 166, + BalanceForTests::max_amount_in(), + IdForTests::token_a_definition_id(), + ); +}