@@ -116,7 +116,7 @@ def decode_silent_payment_address(address: str, hrp: str = "tsp") -> Tuple[ECPub
116116 return B_scan , B_spend
117117
118118
119- def create_outputs (input_priv_keys : List [Tuple [ECKey , bool ]], outpoints : List [COutPoint ], recipients : List [str ], hrp = "tsp" ) -> List [str ]:
119+ def create_outputs (input_priv_keys : List [Tuple [ECKey , bool ]], outpoints : List [COutPoint ], recipients : List [str ], expected : Dict [ str , any ] = None , hrp = "tsp" ) -> List [str ]:
120120 G = ECKey ().set (1 ).get_pubkey ()
121121 negated_keys = []
122122 for key , is_xonly in input_priv_keys :
@@ -129,10 +129,16 @@ def create_outputs(input_priv_keys: List[Tuple[ECKey, bool]], outpoints: List[CO
129129 if not a_sum .valid :
130130 # Input privkeys sum is zero -> fail
131131 return []
132+ assert ECKey ().set (bytes .fromhex (expected .get ("input_private_key_sum" ))) == a_sum , "a_sum did not match expected input_private_key_sum"
132133 input_hash = get_input_hash (outpoints , a_sum * G )
133134 silent_payment_groups : Dict [ECPubKey , List [ECPubKey ]] = {}
134135 for recipient in recipients :
135- B_scan , B_m = decode_silent_payment_address (recipient , hrp = hrp )
136+ B_scan , B_m = decode_silent_payment_address (recipient ["address" ], hrp = hrp )
137+ # Verify decoded intermediate keys for recipient
138+ expected_B_scan = ECPubKey ().set (bytes .fromhex (recipient ["scan_pub_key" ]))
139+ expected_B_m = ECPubKey ().set (bytes .fromhex (recipient ["spend_pub_key" ]))
140+ assert expected_B_scan == B_scan , "B_scan did not match expected recipient.scan_pub_key"
141+ assert expected_B_m == B_m , "B_m did not match expected recipient.spend_pub_key"
136142 if B_scan in silent_payment_groups :
137143 silent_payment_groups [B_scan ].append (B_m )
138144 else :
@@ -141,6 +147,14 @@ def create_outputs(input_priv_keys: List[Tuple[ECKey, bool]], outpoints: List[CO
141147 outputs = []
142148 for B_scan , B_m_values in silent_payment_groups .items ():
143149 ecdh_shared_secret = input_hash * a_sum * B_scan
150+ expected_shared_secrets = expected .get ("shared_secrets" , {})
151+ # Find the recipient address that corresponds to this B_scan and get its index
152+ for recipient_idx , recipient in enumerate (recipients ):
153+ recipient_B_scan = ECPubKey ().set (bytes .fromhex (recipient ["scan_pub_key" ]))
154+ if recipient_B_scan == B_scan :
155+ expected_shared_secret_hex = expected_shared_secrets [recipient_idx ]
156+ assert ecdh_shared_secret .get_bytes (False ).hex () == expected_shared_secret_hex , f"ecdh_shared_secret did not match expected, recipient { recipient_idx } ({ recipient ['address' ]} ): expected={ expected_shared_secret_hex } "
157+ break
144158 k = 0
145159 for B_m in B_m_values :
146160 t_k = TaggedHash ("BIP0352/SharedSecret" , ecdh_shared_secret .get_bytes (False ) + ser_uint32 (k ))
@@ -151,9 +165,13 @@ def create_outputs(input_priv_keys: List[Tuple[ECKey, bool]], outpoints: List[CO
151165 return list (set (outputs ))
152166
153167
154- def scanning (b_scan : ECKey , B_spend : ECPubKey , A_sum : ECPubKey , input_hash : bytes , outputs_to_check : List [ECPubKey ], labels : Dict [str , str ] = {} ) -> List [Dict [str , str ]]:
168+ def scanning (b_scan : ECKey , B_spend : ECPubKey , A_sum : ECPubKey , input_hash : bytes , outputs_to_check : List [ECPubKey ], labels : Dict [str , str ] = None , expected : Dict [ str , any ] = None ) -> List [Dict [str , str ]]:
155169 G = ECKey ().set (1 ).get_pubkey ()
170+ input_hash_key = ECKey ().set (input_hash )
171+ computed_tweak_point = input_hash_key * A_sum
172+ assert computed_tweak_point .get_bytes (False ).hex () == expected .get ("tweak" ), "tweak did not match expected"
156173 ecdh_shared_secret = input_hash * b_scan * A_sum
174+ assert ecdh_shared_secret .get_bytes (False ).hex () == expected .get ("shared_secret" ), "ecdh_shared_secret did not match expected shared_secret"
157175 k = 0
158176 wallet = []
159177 while True :
@@ -236,11 +254,12 @@ def scanning(b_scan: ECKey, B_spend: ECPubKey, A_sum: ECPubKey, input_hash: byte
236254 is_p2tr (vin .prevout ),
237255 ))
238256 input_pub_keys .append (pubkey )
257+ assert [pk .get_bytes (False ).hex () for pk in input_pub_keys ] == expected .get ("input_pub_keys" ), "input_pub_keys did not match expected"
239258
240259 sending_outputs = []
241260 if (len (input_pub_keys ) > 0 ):
242261 outpoints = [vin .outpoint for vin in vins ]
243- sending_outputs = create_outputs (input_priv_keys , outpoints , given ["recipients" ], hrp = "sp" )
262+ sending_outputs = create_outputs (input_priv_keys , outpoints , given ["recipients" ], expected = expected , hrp = "sp" )
244263
245264 # Note: order doesn't matter for creating/finding the outputs. However, different orderings of the recipient addresses
246265 # will produce different generated outputs if sending to multiple silent payment addresses belonging to the
@@ -303,6 +322,7 @@ def scanning(b_scan: ECKey, B_spend: ECPubKey, A_sum: ECPubKey, input_hash: byte
303322 # Input pubkeys sum is point at infinity -> skip tx
304323 assert expected ["outputs" ] == []
305324 continue
325+ assert A_sum .get_bytes (False ).hex () == expected .get ("input_pub_key_sum" ), "A_sum did not match expected input_pub_key_sum"
306326 input_hash = get_input_hash ([vin .outpoint for vin in vins ], A_sum )
307327 pre_computed_labels = {
308328 (generate_label (b_scan , label ) * G ).get_bytes (False ).hex (): generate_label (b_scan , label ).hex ()
@@ -315,6 +335,7 @@ def scanning(b_scan: ECKey, B_spend: ECPubKey, A_sum: ECPubKey, input_hash: byte
315335 input_hash = input_hash ,
316336 outputs_to_check = outputs_to_check ,
317337 labels = pre_computed_labels ,
338+ expected = expected ,
318339 )
319340
320341 # Check that the private key is correct for the found output public key
0 commit comments