Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion amm/src/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ pub fn add_liquidity(
"Vault B was not provided"
);

let token_program_id = vault_a.account.program_owner;
assert_eq!(
user_holding_a.account.program_owner, token_program_id,
"User Token A holding must be owned by the vault's Token Program"
);
assert_eq!(
user_holding_b.account.program_owner, token_program_id,
"User Token B holding must be owned by the vault's Token Program"
);

assert!(
max_amount_to_add_token_a != 0 && max_amount_to_add_token_b != 0,
"Both max-balances must be nonzero"
Expand Down Expand Up @@ -138,7 +148,6 @@ pub fn add_liquidity(
};

pool_post.data = Data::from(&pool_post_definition);
let token_program_id = user_holding_a.account.program_owner;

// Chain call for Token A (UserHoldingA -> Vault_A)
let call_token_a = ChainedCall::new(
Expand Down
12 changes: 10 additions & 2 deletions amm/src/remove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ pub fn remove_liquidity(
"Vault B was not provided"
);

let token_program_id = vault_a.account.program_owner;
assert_eq!(
user_holding_a.account.program_owner, token_program_id,
"User Token A holding must be owned by the vault's Token Program"
);
assert_eq!(
user_holding_b.account.program_owner, token_program_id,
"User Token B holding must be owned by the vault's Token Program"
);

// Vault addresses do not need to be checked with PDA
// calculation for setting authorization since stored
// in the Pool Definition.
Expand Down Expand Up @@ -143,8 +153,6 @@ pub fn remove_liquidity(

pool_post.data = Data::from(&pool_post_definition);

let token_program_id = user_holding_a.account.program_owner;

// Chaincall for Token A withdraw
let call_token_a = ChainedCall::new(
token_program_id,
Expand Down
20 changes: 20 additions & 0 deletions amm/src/swap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ pub fn swap_exact_input(
) -> (Vec<AccountPostState>, Vec<ChainedCall>) {
let pool_def_data = validate_swap_setup(&pool, &vault_a, &vault_b);

let token_program_id = vault_a.account.program_owner;
assert_eq!(
user_holding_a.account.program_owner, token_program_id,
"User Token A holding must be owned by the vault's Token Program"
);
assert_eq!(
user_holding_b.account.program_owner, token_program_id,
"User Token B holding must be owned by the vault's Token Program"
);

let (chained_calls, [deposit_a, withdraw_a], [deposit_b, withdraw_b]) =
if token_in_id == pool_def_data.definition_token_a_id {
let (chained_calls, deposit_a, withdraw_b) = swap_logic(
Expand Down Expand Up @@ -244,6 +254,16 @@ pub fn swap_exact_output(
) -> (Vec<AccountPostState>, Vec<ChainedCall>) {
let pool_def_data = validate_swap_setup(&pool, &vault_a, &vault_b);

let token_program_id = vault_a.account.program_owner;
assert_eq!(
user_holding_a.account.program_owner, token_program_id,
"User Token A holding must be owned by the vault's Token Program"
);
assert_eq!(
user_holding_b.account.program_owner, token_program_id,
"User Token B holding must be owned by the vault's Token Program"
);

let (chained_calls, [deposit_a, withdraw_a], [deposit_b, withdraw_b]) =
if token_in_id == pool_def_data.definition_token_a_id {
let (chained_calls, deposit_a, withdraw_b) = exact_output_swap_logic(
Expand Down
167 changes: 167 additions & 0 deletions amm/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::{

const TOKEN_PROGRAM_ID: ProgramId = [15; 8];
const AMM_PROGRAM_ID: ProgramId = [42; 8];
const MALICIOUS_TOKEN_PROGRAM_ID: ProgramId = [99; 8];

struct BalanceForTests;
struct ChainedCallForTests;
Expand Down Expand Up @@ -1346,6 +1347,38 @@ impl AccountWithMetadataForTests {
}
}

fn user_holding_a_wrong_program() -> AccountWithMetadata {
AccountWithMetadata {
account: Account {
program_owner: MALICIOUS_TOKEN_PROGRAM_ID,
balance: 0u128,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_a_definition_id(),
balance: BalanceForTests::user_token_a_balance(),
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::user_token_a_id(),
}
}

fn user_holding_b_wrong_program() -> AccountWithMetadata {
AccountWithMetadata {
account: Account {
program_owner: MALICIOUS_TOKEN_PROGRAM_ID,
balance: 0u128,
data: Data::from(&TokenHolding::Fungible {
definition_id: IdForTests::token_b_definition_id(),
balance: BalanceForTests::user_token_b_balance(),
}),
nonce: Nonce(0),
},
is_authorized: true,
account_id: IdForTests::user_token_b_id(),
}
}

/// Legacy/corrupted pool state whose reported supply has already been drained down to the
/// permanent lock (liquidity_pool_supply == MINIMUM_LIQUIDITY).
fn pool_definition_at_minimum_liquidity() -> AccountWithMetadata {
Expand Down Expand Up @@ -3303,3 +3336,137 @@ fn test_new_definition_rejects_unsupported_fee_tier() {
AMM_PROGRAM_ID,
);
}

// --- Token program ownership validation tests ---

#[should_panic(expected = "User Token A holding must be owned by the vault's Token Program")]
#[test]
fn test_add_liquidity_rejects_user_holding_a_wrong_program() {
let _ = add_liquidity(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a_wrong_program(),
AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_init(),
NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(),
BalanceForTests::add_max_amount_a(),
BalanceForTests::add_max_amount_b(),
);
}

#[should_panic(expected = "User Token B holding must be owned by the vault's Token Program")]
#[test]
fn test_add_liquidity_rejects_user_holding_b_wrong_program() {
let _ = add_liquidity(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b_wrong_program(),
AccountWithMetadataForTests::user_holding_lp_init(),
NonZero::new(BalanceForTests::add_min_amount_lp()).unwrap(),
BalanceForTests::add_max_amount_a(),
BalanceForTests::add_max_amount_b(),
);
}

#[should_panic(expected = "User Token A holding must be owned by the vault's Token Program")]
#[test]
fn test_remove_liquidity_rejects_user_holding_a_wrong_program() {
let _ = remove_liquidity(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a_wrong_program(),
AccountWithMetadataForTests::user_holding_b(),
AccountWithMetadataForTests::user_holding_lp_with_balance(
BalanceForTests::remove_amount_lp(),
),
NonZero::new(BalanceForTests::remove_amount_lp()).unwrap(),
BalanceForTests::remove_min_amount_a(),
BalanceForTests::remove_min_amount_b_low(),
);
}

#[should_panic(expected = "User Token B holding must be owned by the vault's Token Program")]
#[test]
fn test_remove_liquidity_rejects_user_holding_b_wrong_program() {
let _ = remove_liquidity(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::pool_lp_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b_wrong_program(),
AccountWithMetadataForTests::user_holding_lp_with_balance(
BalanceForTests::remove_amount_lp(),
),
NonZero::new(BalanceForTests::remove_amount_lp()).unwrap(),
BalanceForTests::remove_min_amount_a(),
BalanceForTests::remove_min_amount_b_low(),
);
}

#[should_panic(expected = "User Token A holding must be owned by the vault's Token Program")]
#[test]
fn test_swap_exact_input_rejects_user_holding_a_wrong_program() {
let _ = swap_exact_input(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::user_holding_a_wrong_program(),
AccountWithMetadataForTests::user_holding_b(),
BalanceForTests::add_max_amount_a(),
BalanceForTests::min_amount_out(),
IdForTests::token_a_definition_id(),
);
}

#[should_panic(expected = "User Token B holding must be owned by the vault's Token Program")]
#[test]
fn test_swap_exact_input_rejects_user_holding_b_wrong_program() {
let _ = swap_exact_input(
AccountWithMetadataForTests::pool_definition_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b_wrong_program(),
BalanceForTests::add_max_amount_a(),
BalanceForTests::min_amount_out(),
IdForTests::token_a_definition_id(),
);
}

#[should_panic(expected = "User Token A holding must be owned by the vault's Token Program")]
#[test]
fn test_swap_exact_output_rejects_user_holding_a_wrong_program() {
let _ = swap_exact_output(
AccountWithMetadataForTests::pool_definition_swap_exact_output_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::user_holding_a_wrong_program(),
AccountWithMetadataForTests::user_holding_b(),
166,
BalanceForTests::max_amount_in(),
IdForTests::token_a_definition_id(),
);
}

#[should_panic(expected = "User Token B holding must be owned by the vault's Token Program")]
#[test]
fn test_swap_exact_output_rejects_user_holding_b_wrong_program() {
let _ = swap_exact_output(
AccountWithMetadataForTests::pool_definition_swap_exact_output_init(),
AccountWithMetadataForTests::vault_a_init(),
AccountWithMetadataForTests::vault_b_init(),
AccountWithMetadataForTests::user_holding_a(),
AccountWithMetadataForTests::user_holding_b_wrong_program(),
166,
BalanceForTests::max_amount_in(),
IdForTests::token_a_definition_id(),
);
}
Loading