@@ -67,13 +67,21 @@ impl State {
6767 maybe_tolerance : Option < FixedPoint < U256 > > ,
6868 maybe_max_iterations : Option < usize > ,
6969 ) -> Result < FixedPoint < U256 > > {
70- // SGA parameters.
71- let mut learning_rate = fixed ! ( 1e17 ) ;
70+ // SGD parameters.
71+ let mut learning_rate = fixed ! ( 1e18 ) ;
7272 let max_iterations = maybe_max_iterations. unwrap_or ( 1_000 ) ;
73- let tolerance = maybe_tolerance. unwrap_or ( fixed ! ( 1e9 ) ) ;
73+ let tolerance = maybe_tolerance. unwrap_or ( fixed ! ( 1e5 ) ) ;
7474 // Start with a conservative estimate of the bonds shorted & base paid.
7575 let conservative_price = self . calculate_conservative_short_price ( base_deposit_amount) ?;
7676 // let close_vault_share_price = open_vault_share_price.max(self.vault_share_price());
77+ // We ignore the short principal so that we have a closed-form inversion
78+ // of the short deposit equation.
79+ // let mut last_good_bond_amount = base_deposit_amount
80+ // / (((fixed!(1e18) / self.vault_share_price())
81+ // * (close_vault_share_price / open_vault_share_price)
82+ // + self.flat_fee())
83+ // + self.curve_fee() * (fixed!(1e18) - conservative_price));
84+ let mut last_good_bond_amount = self . minimum_transaction_amount ( ) ;
7785 let mut last_good_base_amount =
7886 self . calculate_open_short ( last_good_bond_amount, open_vault_share_price) ?;
7987 println ! ( "base_deposit_amount = {:#?}" , base_deposit_amount) ;
@@ -97,49 +105,69 @@ impl State {
97105 // Within tolerance, but bond amount must be >= 0.
98106 return Ok ( last_good_bond_amount. max ( fixed ! ( 0 ) ) . change_type :: < U256 > ( ) ?) ;
99107 }
100- // Run Stochastic Gradient Ascent to iteratively adjust the bond amount.
101- // We know the conservative bond amount is less than the target, so we can assert gradient ascent.
102- for _ in 0 ..max_iterations {
108+ // Run Stochastic Gradient Descent to adjust the bond amount.
109+ for iter in 0 ..max_iterations {
110+ println ! ( "\n ----\n iter {:#?}" , iter) ;
111+ println ! ( "learning_rate {:#?}" , learning_rate) ;
103112 // Calculate the current deposit.
104113 let base_amount =
105114 self . calculate_open_short ( last_good_bond_amount, open_vault_share_price) ?;
106115 // Calculate the current gradient.
107116 let base_amount_derivative = self . calculate_open_short_derivative (
108117 last_good_bond_amount,
109118 open_vault_share_price,
110- None ,
119+ Some ( self . calculate_spot_price ( ) ? ) ,
111120 ) ?;
112- // If we overshot here, we set the error to zero and then the new_bond_amount = last_good_bond_amount.
113- let error = if base_amount > base_deposit_amount {
114- base_amount - base_deposit_amount
121+ // If we overshot here, we set the error to zero and then the
122+ // new_bond_amount = last_good_bond_amount.
123+ println ! ( "last_good_bond_amount={:#?}" , last_good_bond_amount) ;
124+ println ! ( "base_amount={:#?}" , base_amount) ;
125+ println ! ( "base_deposit_amount={:#?}" , base_deposit_amount) ;
126+ let error = if base_amount < base_deposit_amount {
127+ base_deposit_amount - base_amount
115128 } else {
116129 fixed ! ( 0 )
117130 } ;
131+ println ! ( "error={:#?}" , error) ;
118132 // Calculate the new bond amount.
133+ // The update rule is: x_1 = x_0 - \eta * L(y,y_t) * dL(y,y_t)/dx,
134+ // where \eta is the learning rate, L is the error, y is
135+ // open_short(x), and y_t is the target deposit. The derivative of
136+ // L(y,y_t) wrt x is -base_amount_derivative. So we add here instead
137+ // of subtracting a negative.
119138 let new_bond_amount =
120139 last_good_bond_amount + learning_rate * error * base_amount_derivative;
121140 // If we overshot, lower the learning rate and try again.
122- // Otherwise, check convergence to either return or continue iterating .
141+ // Otherwise, check convergence to either return or continue.
123142 match self . calculate_open_short ( new_bond_amount, open_vault_share_price) {
124143 Ok ( new_base_amount) => {
125- println ! ( "new_base_amount {:#?}" , new_base_amount) ;
144+ println ! (
145+ "new_base_amount={:#?}; base_deposit_amount={:#?}" ,
146+ new_base_amount, base_deposit_amount
147+ ) ;
126148 if new_base_amount > base_deposit_amount {
127- learning_rate = ( learning_rate / fixed ! ( 100e18 ) ) . max ( fixed ! ( 1 ) ) ;
149+ let error_magnitude = new_base_amount / base_deposit_amount;
150+ println ! ( "error_magnitude={:#?}" , error_magnitude) ;
151+ // If the values are too close then the error magnitude
152+ // will round to 1.0 and the rate will not change.
153+ learning_rate = if error_magnitude <= fixed ! ( 1e18 ) {
154+ ( learning_rate / fixed ! ( 2e18 ) ) . max ( fixed ! ( 1 ) )
155+ } else {
156+ ( learning_rate / error_magnitude) . max ( fixed ! ( 1 ) )
157+ } ;
128158 } else {
129159 last_good_bond_amount = new_bond_amount;
130160 last_good_base_amount = new_base_amount;
131161 // Check for convergence.
132162 if ( base_deposit_amount - last_good_base_amount) <= tolerance {
133163 // Within tolerance, but bond amount must be >= 0.
134- return Ok ( last_good_bond_amount
135- . max ( fixed ! ( 0 ) )
136- . change_type :: < U256 > ( ) ?) ;
164+ return Ok ( last_good_bond_amount. max ( fixed ! ( 0 ) ) ) ;
137165 }
138166 // Amount was good but we did not converge; keep going.
139167 }
140168 }
141169 Err ( _) => {
142- learning_rate = ( learning_rate / fixed ! ( 10 ) ) . max ( fixed ! ( 1 ) ) ;
170+ learning_rate = ( learning_rate / fixed ! ( 10e18 ) ) . max ( fixed ! ( 1 ) ) ;
143171 }
144172 }
145173 }
@@ -692,18 +720,28 @@ mod tests {
692720
693721 #[ tokio:: test]
694722 async fn fuzz_calculate_short_bonds_given_deposit ( ) -> Result < ( ) > {
695- let test_tolerance = fixed ! ( 1e9 ) ;
723+ let test_tolerance = fixed ! ( 1e6 ) ;
724+ let max_iterations = 10_000 ;
696725 let mut rng = thread_rng ( ) ;
697- for _ in 0 ..* FUZZ_RUNS {
726+ for _ in 0 ..* SLOW_FUZZ_RUNS {
698727 let state = rng. gen :: < State > ( ) ;
699728 let open_vault_share_price = rng. gen_range ( fixed ! ( 0 ) ..=state. vault_share_price ( ) ) ;
729+ let checkpoint_exposure = {
730+ let value = rng. gen_range ( fixed ! ( 0 ) ..=FixedPoint :: from ( U256 :: from ( U128 :: MAX ) ) ) ;
731+ if rng. gen ( ) {
732+ -I256 :: try_from ( value) ?
733+ } else {
734+ I256 :: try_from ( value) ?
735+ }
736+ } ;
737+ let max_short_trade = get_max_short ( state. clone ( ) , checkpoint_exposure, None ) ?;
700738 let target_base_amount =
701- rng. gen_range ( state. minimum_transaction_amount ( ) ..=fixed ! ( 10_000e18 ) ) ;
739+ rng. gen_range ( state. minimum_transaction_amount ( ) ..=max_short_trade ) ;
702740 let bond_amount = state. calculate_short_bonds_given_deposit (
703741 target_base_amount,
704742 open_vault_share_price,
705743 Some ( test_tolerance) ,
706- None ,
744+ Some ( max_iterations ) ,
707745 ) ?;
708746 let computed_base_amount =
709747 state. calculate_open_short ( bond_amount, open_vault_share_price) ?;
0 commit comments