11use crate :: hex_utils;
2+ use crate :: io_utils:: KVStoreUnpersister ;
23use crate :: Error ;
34
45use lightning:: ln:: { PaymentHash , PaymentPreimage , PaymentSecret } ;
56use lightning:: util:: persist:: KVStorePersister ;
67use lightning:: { impl_writeable_tlv_based, impl_writeable_tlv_based_enum} ;
78
8- use std:: collections:: HashMap ;
9+ use std:: collections:: hash_map;
10+ use std:: collections:: { HashMap , HashSet } ;
911use std:: iter:: FromIterator ;
1012use std:: ops:: Deref ;
11- use std:: sync:: Mutex ;
13+ use std:: sync:: { Mutex , MutexGuard } ;
1214
1315/// Represents a payment.
1416#[ derive( Clone , Debug , PartialEq , Eq ) ]
@@ -70,21 +72,26 @@ impl_writeable_tlv_based_enum!(PaymentStatus,
7072/// The payment information will be persisted under this prefix.
7173pub ( crate ) const PAYMENT_INFO_PERSISTENCE_PREFIX : & str = "payments" ;
7274
73- pub ( crate ) struct PaymentInfoStorage < K : Deref >
75+ pub ( crate ) struct PaymentInfoStorage < K : Deref + Clone >
7476where
75- K :: Target : KVStorePersister ,
77+ K :: Target : KVStorePersister + KVStoreUnpersister ,
7678{
7779 payments : Mutex < HashMap < PaymentHash , PaymentInfo > > ,
7880 persister : K ,
7981}
8082
81- impl < K : Deref > PaymentInfoStorage < K >
83+ impl < K : Deref + Clone > PaymentInfoStorage < K >
8284where
83- K :: Target : KVStorePersister ,
85+ K :: Target : KVStorePersister + KVStoreUnpersister ,
8486{
85- pub ( crate ) fn from_payments ( mut payments : Vec < PaymentInfo > , persister : K ) -> Self {
87+ pub ( crate ) fn new ( persister : K ) -> Self {
88+ let payments = Mutex :: new ( HashMap :: new ( ) ) ;
89+ Self { payments, persister }
90+ }
91+
92+ pub ( crate ) fn from_payments ( payments : Vec < PaymentInfo > , persister : K ) -> Self {
8693 let payments = Mutex :: new ( HashMap :: from_iter (
87- payments. drain ( .. ) . map ( |payment_info| ( payment_info. payment_hash , payment_info) ) ,
94+ payments. into_iter ( ) . map ( |payment_info| ( payment_info. payment_hash , payment_info) ) ,
8895 ) ) ;
8996 Self { payments, persister }
9097 }
@@ -106,9 +113,20 @@ where
106113 return Ok ( ( ) ) ;
107114 }
108115
109- // TODO: Need an `unpersist` method for this?
110- //pub(crate) fn remove_payment(&self, payment_hash: &PaymentHash) -> Result<(), Error> {
111- //}
116+ pub ( crate ) fn lock ( & self ) -> Result < PaymentInfoGuard < K > , ( ) > {
117+ let locked_store = self . payments . lock ( ) . map_err ( |_| ( ) ) ?;
118+ Ok ( PaymentInfoGuard :: new ( locked_store, self . persister . clone ( ) ) )
119+ }
120+
121+ pub ( crate ) fn remove ( & self , payment_hash : & PaymentHash ) -> Result < ( ) , Error > {
122+ let key = format ! (
123+ "{}/{}" ,
124+ PAYMENT_INFO_PERSISTENCE_PREFIX ,
125+ hex_utils:: to_string( & payment_hash. 0 )
126+ ) ;
127+ self . persister . unpersist ( & key) . map_err ( |_| Error :: PersistenceFailed ) ?;
128+ Ok ( ( ) )
129+ }
112130
113131 pub ( crate ) fn get ( & self , payment_hash : & PaymentHash ) -> Option < PaymentInfo > {
114132 self . payments . lock ( ) . unwrap ( ) . get ( payment_hash) . cloned ( )
@@ -136,3 +154,79 @@ where
136154 Ok ( ( ) )
137155 }
138156}
157+
158+ pub ( crate ) struct PaymentInfoGuard < ' a , K : Deref >
159+ where
160+ K :: Target : KVStorePersister + KVStoreUnpersister ,
161+ {
162+ inner : MutexGuard < ' a , HashMap < PaymentHash , PaymentInfo > > ,
163+ touched_keys : HashSet < PaymentHash > ,
164+ persister : K ,
165+ }
166+
167+ impl < ' a , K : Deref > PaymentInfoGuard < ' a , K >
168+ where
169+ K :: Target : KVStorePersister + KVStoreUnpersister ,
170+ {
171+ pub fn new ( inner : MutexGuard < ' a , HashMap < PaymentHash , PaymentInfo > > , persister : K ) -> Self {
172+ let touched_keys = HashSet :: new ( ) ;
173+ Self { inner, touched_keys, persister }
174+ }
175+
176+ pub fn entry (
177+ & mut self , payment_hash : PaymentHash ,
178+ ) -> hash_map:: Entry < PaymentHash , PaymentInfo > {
179+ self . touched_keys . insert ( payment_hash) ;
180+ self . inner . entry ( payment_hash)
181+ }
182+ }
183+
184+ impl < ' a , K : Deref > Drop for PaymentInfoGuard < ' a , K >
185+ where
186+ K :: Target : KVStorePersister + KVStoreUnpersister ,
187+ {
188+ fn drop ( & mut self ) {
189+ for key in self . touched_keys . iter ( ) {
190+ let store_key =
191+ format ! ( "{}/{}" , PAYMENT_INFO_PERSISTENCE_PREFIX , hex_utils:: to_string( & key. 0 ) ) ;
192+
193+ match self . inner . entry ( * key) {
194+ hash_map:: Entry :: Vacant ( _) => {
195+ self . persister . unpersist ( & store_key) . expect ( "Persistence failed" ) ;
196+ }
197+ hash_map:: Entry :: Occupied ( e) => {
198+ self . persister . persist ( & store_key, e. get ( ) ) . expect ( "Persistence failed" ) ;
199+ }
200+ } ;
201+ }
202+ }
203+ }
204+
205+ #[ cfg( test) ]
206+ mod tests {
207+ use super :: * ;
208+ use crate :: tests:: test_utils:: TestPersister ;
209+ use std:: sync:: Arc ;
210+
211+ #[ test]
212+ fn persistence_guard_persists_on_drop ( ) {
213+ let persister = Arc :: new ( TestPersister :: new ( ) ) ;
214+ let payment_info_store = PaymentInfoStorage :: new ( Arc :: clone ( & persister) ) ;
215+
216+ let payment_hash = PaymentHash ( [ 42u8 ; 32 ] ) ;
217+ assert ! ( !payment_info_store. contains( & payment_hash) ) ;
218+
219+ let payment_info = PaymentInfo {
220+ payment_hash,
221+ preimage : None ,
222+ secret : None ,
223+ amount_msat : None ,
224+ direction : PaymentDirection :: Inbound ,
225+ status : PaymentStatus :: Pending ,
226+ } ;
227+
228+ assert ! ( !persister. get_and_clear_did_persist( ) ) ;
229+ payment_info_store. lock ( ) . unwrap ( ) . entry ( payment_hash) . or_insert ( payment_info) ;
230+ assert ! ( persister. get_and_clear_did_persist( ) ) ;
231+ }
232+ }
0 commit comments