diff --git a/Cargo.lock b/Cargo.lock index 1f87d148a473..bf4678ef1127 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -553,6 +553,7 @@ dependencies = [ "cln-rpc", "hex", "log", + "paste", "rand 0.9.2", "serde", "serde_json", @@ -1763,6 +1764,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pem" version = "3.0.5" diff --git a/plugins/lsps-plugin/Cargo.toml b/plugins/lsps-plugin/Cargo.toml index 5ef5fa421e9b..7fd074f1c48b 100644 --- a/plugins/lsps-plugin/Cargo.toml +++ b/plugins/lsps-plugin/Cargo.toml @@ -20,8 +20,9 @@ cln-plugin = { workspace = true } cln-rpc = { workspace = true } hex = "0.4" log = "0.4" +paste = "1.0.15" rand = "0.9" serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { version = "1.0", features = ["raw_value"] } thiserror = "2.0" tokio = { version = "1.44", features = ["full"] } diff --git a/plugins/lsps-plugin/src/client.rs b/plugins/lsps-plugin/src/client.rs index a466560f7377..df648ee52b8a 100644 --- a/plugins/lsps-plugin/src/client.rs +++ b/plugins/lsps-plugin/src/client.rs @@ -1,35 +1,43 @@ use anyhow::{anyhow, bail, Context}; use bitcoin::hashes::{hex::FromHex, sha256, Hash}; use chrono::{Duration, Utc}; -use cln_lsps::jsonrpc::client::JsonRpcClient; -use cln_lsps::lsps0::primitives::Msat; -use cln_lsps::lsps0::{ - self, - transport::{Bolt8Transport, CustomMessageHookManager, WithCustomMessageHookManager}, +use cln_lsps::{ + cln_adapters::{ + hooks, + sender::ClnSender, + state::ClientState, + types::{ + HtlcAcceptedRequest, HtlcAcceptedResponse, InvoicePaymentRequest, OpenChannelRequest, + }, + }, + core::{ + client::LspsClient, + features::is_feature_bit_set_reversed, + tlv::{encode_tu64, TLV_FORWARD_AMT, TLV_PAYMENT_SECRET}, + transport::{MultiplexedTransport, PendingRequests}, + }, + proto::{ + 
lsps0::{Msat, LSP_FEATURE_BIT}, + lsps2::{compute_opening_fee, Lsps2BuyResponse, Lsps2GetInfoResponse, OpeningFeeParams}, + }, }; -use cln_lsps::lsps2::cln::tlv::encode_tu64; -use cln_lsps::lsps2::cln::{ - HtlcAcceptedRequest, HtlcAcceptedResponse, InvoicePaymentRequest, OpenChannelRequest, - TLV_FORWARD_AMT, TLV_PAYMENT_SECRET, -}; -use cln_lsps::lsps2::model::{ - compute_opening_fee, Lsps2BuyRequest, Lsps2BuyResponse, Lsps2GetInfoRequest, - Lsps2GetInfoResponse, OpeningFeeParams, -}; -use cln_lsps::util; -use cln_lsps::LSP_FEATURE_BIT; use cln_plugin::options; -use cln_rpc::model::requests::{ - DatastoreMode, DatastoreRequest, DeldatastoreRequest, DelinvoiceRequest, DelinvoiceStatus, - ListdatastoreRequest, ListinvoicesRequest, ListpeersRequest, +use cln_rpc::{ + model::{ + requests::{ + DatastoreMode, DatastoreRequest, DeldatastoreRequest, DelinvoiceRequest, + DelinvoiceStatus, ListdatastoreRequest, ListinvoicesRequest, ListpeersRequest, + }, + responses::InvoiceResponse, + }, + primitives::{Amount, AmountOrAny, PublicKey, ShortChannelId}, + ClnRpc, }; -use cln_rpc::model::responses::InvoiceResponse; -use cln_rpc::primitives::{Amount, AmountOrAny, PublicKey, ShortChannelId}; -use cln_rpc::ClnRpc; use log::{debug, info, warn}; use rand::{CryptoRng, Rng}; use serde::{Deserialize, Serialize}; use std::path::Path; +use std::path::PathBuf; use std::str::FromStr as _; /// An option to enable this service. 
@@ -38,24 +46,43 @@ const OPTION_ENABLED: options::FlagConfigOption = options::ConfigOption::new_fla "Enables an LSPS client on the node.", ); +const DEFAULT_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); + #[derive(Clone)] -struct State { - hook_manager: CustomMessageHookManager, +pub struct State { + sender: ClnSender, + pending: PendingRequests, + timeout: std::time::Duration, } -impl WithCustomMessageHookManager for State { - fn get_custommsg_hook_manager(&self) -> &CustomMessageHookManager { - &self.hook_manager +impl State { + pub fn new(rpc_path: PathBuf, timeout: std::time::Duration) -> Self { + Self { + sender: ClnSender::new(rpc_path), + pending: PendingRequests::new(), + timeout, + } + } + + pub fn client(&self) -> LspsClient> { + LspsClient::new(self.transport()) + } +} + +impl ClientState for State { + fn transport(&self) -> MultiplexedTransport { + MultiplexedTransport::new(self.sender.clone(), self.pending.clone(), self.timeout) + } + + fn pending(&self) -> &PendingRequests { + &self.pending } } #[tokio::main] async fn main() -> Result<(), anyhow::Error> { - let hook_manager = CustomMessageHookManager::new(); - let state = State { hook_manager }; - if let Some(plugin) = cln_plugin::Builder::new(tokio::io::stdin(), tokio::io::stdout()) - .hook("custommsg", CustomMessageHookManager::on_custommsg::) + .hook("custommsg", hooks::client_custommsg_hook) .option(OPTION_ENABLED) .rpcmethod( "lsps-listprotocols", @@ -94,6 +121,10 @@ async fn main() -> Result<(), anyhow::Error> { .await; } + let dir = plugin.configuration().lightning_dir; + let rpc_path = Path::new(&dir).join(&plugin.configuration().rpc_file); + let state = State::new(rpc_path, DEFAULT_REQUEST_TIMEOUT); + let plugin = plugin.start(state).await?; plugin.join().await } else { @@ -113,6 +144,8 @@ async fn on_lsps_lsps2_getinfo( req.lsp_id, req.token ); + let lsp_id = PublicKey::from_str(&req.lsp_id).context("lsp_id is not a valid public key")?; + let dir = 
p.configuration().lightning_dir; let rpc_path = Path::new(&dir).join(&p.configuration().rpc_file); let mut cln_client = cln_rpc::ClnRpc::new(rpc_path.clone()).await?; @@ -131,25 +164,12 @@ async fn on_lsps_lsps2_getinfo( debug!("Peer {} doesn't have the LSP feature bit set.", &req.lsp_id); } - // Create Transport and Client - let transport = Bolt8Transport::new( - &req.lsp_id, - rpc_path.clone(), // Clone path for potential reuse - p.state().hook_manager.clone(), - None, // Use default timeout - ) - .context("Failed to create Bolt8Transport")?; - let client = JsonRpcClient::new(transport); - // 1. Call lsps2.get_info. - let info_req = Lsps2GetInfoRequest { token: req.token }; - let info_res: Lsps2GetInfoResponse = client - .call_typed(info_req) - .await - .context("lsps2.get_info call failed")?; - debug!("received lsps2.get_info response: {:?}", info_res); - - Ok(serde_json::to_value(info_res)?) + let client = p.state().client(); + match client.get_info(&lsp_id, req.token).await?.as_result() { + Ok(i) => Ok(serde_json::to_value(i)?), + Err(e) => Ok(serde_json::to_value(e)?), + } } /// Rpc Method handler for `lsps-lsps2-buy`. 
@@ -164,6 +184,8 @@ async fn on_lsps_lsps2_buy( req.lsp_id, req.opening_fee_params, req.payment_size_msat ); + let lsp_id = PublicKey::from_str(&req.lsp_id).context("lsp_id is not a valid public key")?; + let dir = p.configuration().lightning_dir; let rpc_path = Path::new(&dir).join(&p.configuration().rpc_file); let mut cln_client = cln_rpc::ClnRpc::new(rpc_path.clone()).await?; @@ -182,15 +204,7 @@ async fn on_lsps_lsps2_buy( debug!("Peer {} doesn't have the LSP feature bit set.", &req.lsp_id); } - // Create Transport and Client - let transport = Bolt8Transport::new( - &req.lsp_id, - rpc_path.clone(), // Clone path for potential reuse - p.state().hook_manager.clone(), - None, // Use default timeout - ) - .context("Failed to create Bolt8Transport")?; - let client = JsonRpcClient::new(transport); + let client = p.state().client(); let selected_params = req.opening_fee_params; if let Some(payment_size) = req.payment_size_msat { @@ -236,16 +250,14 @@ async fn on_lsps_lsps2_buy( } debug!("Calling lsps2.buy for peer {}", req.lsp_id); - let buy_req = Lsps2BuyRequest { - opening_fee_params: selected_params, // Pass the chosen params back - payment_size_msat: req.payment_size_msat, - }; - let buy_res: Lsps2BuyResponse = client - .call_typed(buy_req) - .await - .context("lsps2.buy call failed")?; - - Ok(serde_json::to_value(buy_res)?) + match client + .buy(&lsp_id, selected_params, req.payment_size_msat) + .await? 
+ .as_result() + { + Ok(i) => Ok(serde_json::to_value(i)?), + Err(e) => Ok(serde_json::to_value(e)?), + } } async fn on_lsps_lsps2_approve( @@ -703,6 +715,7 @@ async fn on_lsps_listprotocols( let mut cln_client = cln_rpc::ClnRpc::new(rpc_path.clone()).await?; let req: Request = serde_json::from_value(v).context("Failed to parse request JSON")?; + let lsp_id = PublicKey::from_str(&req.lsp_id).context("lsp_id is not a valid public key")?; let lsp_status = check_peer_lsp_status(&mut cln_client, &req.lsp_id).await?; // Fail early: Check that we are connected to the peer. @@ -717,26 +730,14 @@ async fn on_lsps_listprotocols( debug!("Peer {} doesn't have the LSP feature bit set.", &req.lsp_id); } - // Create the transport first and handle potential errors - let transport = Bolt8Transport::new( - &req.lsp_id, - rpc_path, - p.state().hook_manager.clone(), - None, // Use default timeout - ) - .context("Failed to create Bolt8Transport")?; - - // Now create the client using the transport - let client = JsonRpcClient::new(transport); - - let request = lsps0::model::Lsps0listProtocolsRequest {}; - let res: lsps0::model::Lsps0listProtocolsResponse = client - .call_typed(request) - .await - .map_err(|e| anyhow!("lsps0.list_protocols call failed: {}", e))?; - - debug!("Received lsps0.list_protocols response: {:?}", res); - Ok(serde_json::to_value(res)?) + let client = p.state().client(); + match client.list_protocols(&lsp_id).await?.as_result() { + Ok(i) => { + debug!("Received lsps0.list_protocols response: {:?}", i); + Ok(serde_json::to_value(i)?) 
+ } + Err(e) => Ok(serde_json::to_value(e)?), + } } struct PeerLspStatus { @@ -771,7 +772,7 @@ async fn check_peer_lsp_status( let has_lsp_feature = if let Some(f_str) = &peer.features { let feature_bits = hex::decode(f_str) .map_err(|e| anyhow!("Invalid feature bits hex for peer {peer_id}, {f_str}: {e}"))?; - util::is_feature_bit_set_reversed(&feature_bits, LSP_FEATURE_BIT) + is_feature_bit_set_reversed(&feature_bits, LSP_FEATURE_BIT) } else { false }; diff --git a/plugins/lsps-plugin/src/cln_adapters/hooks.rs b/plugins/lsps-plugin/src/cln_adapters/hooks.rs new file mode 100644 index 000000000000..4cb64785265f --- /dev/null +++ b/plugins/lsps-plugin/src/cln_adapters/hooks.rs @@ -0,0 +1,101 @@ +use crate::{ + cln_adapters::state::{ClientState, ServiceState}, + core::{router::RequestContext, transport::MessageSender as _}, + proto::lsps0, +}; +use anyhow::Result; +use bitcoin::secp256k1::PublicKey; +use cln_plugin::Plugin; +use serde::Deserialize; +use serde_json::Value; + +pub async fn client_custommsg_hook(plugin: Plugin, v: Value) -> Result +where + S: Clone + Sync + Send + 'static + ClientState, +{ + let Some(hook) = CustomMsgHook::parse(v) else { + return Ok(serde_json::json!({ + "result": "continue" + })); + }; + + if let Some(id) = extract_message_id(&hook.payload) { + plugin.state().pending().complete(&id, hook.payload).await; + } + + return Ok(serde_json::json!({ + "result": "continue" + })); +} + +pub async fn service_custommsg_hook(plugin: Plugin, v: Value) -> Result +where + S: Clone + Sync + Send + 'static + ServiceState, +{ + let Some(hook) = CustomMsgHook::parse(v) else { + return Ok(serde_json::json!({ + "result": "continue" + })); + }; + let service = plugin.state().service(); + let ctx = RequestContext { + peer_id: hook.peer_id, + }; + let res = service.handle(&ctx, &hook.payload).await; + if let Some(payload) = res { + let sender = plugin.state().sender().clone(); + if let Err(e) = sender.send(&hook.peer_id, &payload).await { + log::error!("Failed 
to send LSPS response to {}: {}", &hook.peer_id, e); + }; + } + + Ok(serde_json::json!({ + "result": "continue" + })) +} + +#[derive(Debug, Deserialize)] +struct CustomMsgHookRaw { + peer_id: PublicKey, + payload: String, +} + +/// Parsed and validated hook data +pub struct CustomMsgHook { + pub peer_id: PublicKey, + pub payload: Vec, +} + +impl CustomMsgHook { + /// Parse and validate everything upfront + pub fn parse(v: Value) -> Option { + let raw: CustomMsgHookRaw = serde_json::from_value(v).ok()?; + let peer_id = raw.peer_id; + let payload = decode_lsps0_frame_hex(&raw.payload)?; + Some(Self { peer_id, payload }) + } +} + +fn decode_lsps0_frame_hex(hex_str: &str) -> Option> { + let frame = match hex::decode(hex_str) { + Ok(f) => f, + Err(e) => { + log::error!( + "Failed to decode hex string payload from custom message: {}", + e + ); + return None; + } + }; + lsps0::decode_frame(&frame).ok().map(|d| d.to_owned()) +} + +fn extract_message_id(payload: &[u8]) -> Option { + #[derive(Deserialize)] + struct IdOnly { + id: Option, + } + + let parsed: IdOnly = serde_json::from_slice(payload).ok()?; + parsed.id +} diff --git a/plugins/lsps-plugin/src/cln_adapters/mod.rs b/plugins/lsps-plugin/src/cln_adapters/mod.rs new file mode 100644 index 000000000000..063690099ca2 --- /dev/null +++ b/plugins/lsps-plugin/src/cln_adapters/mod.rs @@ -0,0 +1,5 @@ +pub mod hooks; +pub mod rpc; +pub mod sender; +pub mod state; +pub mod types; diff --git a/plugins/lsps-plugin/src/cln_adapters/rpc.rs b/plugins/lsps-plugin/src/cln_adapters/rpc.rs new file mode 100644 index 000000000000..07dd1b007a28 --- /dev/null +++ b/plugins/lsps-plugin/src/cln_adapters/rpc.rs @@ -0,0 +1,304 @@ +use crate::{ + core::lsps2::provider::{ + Blockheight, BlockheightProvider, DatastoreProvider, LightningProvider, Lsps2OfferProvider, + }, + proto::{ + lsps0::Msat, + lsps2::{ + DatastoreEntry, Lsps2PolicyGetChannelCapacityRequest, + Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, + 
Lsps2PolicyGetInfoResponse, OpeningFeeParams, + }, + }, +}; +use anyhow::{Context, Result}; +use async_trait::async_trait; +use bitcoin::secp256k1::PublicKey; +use cln_rpc::{ + model::{ + requests::{ + DatastoreMode, DatastoreRequest, DeldatastoreRequest, FundchannelRequest, + GetinfoRequest, ListdatastoreRequest, ListpeerchannelsRequest, + }, + responses::ListdatastoreResponse, + }, + primitives::{Amount, AmountOrAll, ChannelState, Sha256, ShortChannelId}, + ClnRpc, +}; +use core::fmt; +use serde::Serialize; +use std::path::PathBuf; + +pub const DS_MAIN_KEY: &'static str = "lsps"; +pub const DS_SUB_KEY: &'static str = "lsps2"; + +#[derive(Clone)] +pub struct ClnApiRpc { + rpc_path: PathBuf, +} + +impl ClnApiRpc { + pub fn new(rpc_path: PathBuf) -> Self { + Self { rpc_path } + } + + async fn create_rpc(&self) -> Result { + ClnRpc::new(&self.rpc_path).await + } +} + +#[async_trait] +impl LightningProvider for ClnApiRpc { + async fn fund_jit_channel( + &self, + peer_id: &PublicKey, + amount: &Msat, + ) -> Result<(Sha256, String)> { + let mut rpc = self.create_rpc().await?; + let res = rpc + .call_typed(&FundchannelRequest { + announce: Some(false), + close_to: None, + compact_lease: None, + feerate: None, + minconf: None, + mindepth: Some(0), + push_msat: None, + request_amt: None, + reserve: None, + channel_type: Some(vec![12, 46, 50]), + utxos: None, + amount: AmountOrAll::Amount(Amount::from_msat(amount.msat())), + id: peer_id.to_owned(), + }) + .await + .with_context(|| "calling fundchannel")?; + Ok((res.channel_id, res.txid)) + } + + async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Sha256) -> Result { + let mut rpc = self.create_rpc().await?; + let r = rpc + .call_typed(&ListpeerchannelsRequest { + id: Some(peer_id.to_owned()), + short_channel_id: None, + }) + .await + .with_context(|| "calling listpeerchannels")?; + + let chs = r + .channels + .iter() + .find(|&ch| ch.channel_id.is_some_and(|id| id == *channel_id)); + if let Some(ch) = chs { + 
if ch.state == ChannelState::CHANNELD_NORMAL { + return Ok(true); + } + } + + return Ok(false); + } +} + +#[async_trait] +impl DatastoreProvider for ClnApiRpc { + async fn store_buy_request( + &self, + scid: &ShortChannelId, + peer_id: &PublicKey, + opening_fee_params: &OpeningFeeParams, + expected_payment_size: &Option, + ) -> Result { + let mut rpc = self.create_rpc().await?; + #[derive(Serialize)] + struct BorrowedDatastoreEntry<'a> { + peer_id: &'a PublicKey, + opening_fee_params: &'a OpeningFeeParams, + #[serde(borrow)] + expected_payment_size: &'a Option, + } + + let ds = BorrowedDatastoreEntry { + peer_id, + opening_fee_params, + expected_payment_size, + }; + let json_str = serde_json::to_string(&ds)?; + + let ds = DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::MUST_CREATE), + string: Some(json_str), + key: vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + scid.to_string(), + ], + }; + + let _ = rpc + .call_typed(&ds) + .await + .map_err(anyhow::Error::new) + .with_context(|| "calling datastore")?; + + Ok(true) + } + + async fn get_buy_request(&self, scid: &ShortChannelId) -> Result { + let mut rpc = self.create_rpc().await?; + let key = vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + scid.to_string(), + ]; + let res = rpc + .call_typed(&ListdatastoreRequest { + key: Some(key.clone()), + }) + .await + .with_context(|| "calling listdatastore")?; + + let (rec, _) = deserialize_by_key(&res, key)?; + Ok(rec) + } + + async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()> { + let mut rpc = self.create_rpc().await?; + let key = vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + scid.to_string(), + ]; + + let _ = rpc + .call_typed(&DeldatastoreRequest { + generation: None, + key, + }) + .await; + + Ok(()) + } +} + +#[async_trait] +impl Lsps2OfferProvider for ClnApiRpc { + async fn get_offer( + &self, + request: &Lsps2PolicyGetInfoRequest, + ) -> Result { + let mut rpc = 
self.create_rpc().await?; + rpc.call_raw("lsps2-policy-getpolicy", request) + .await + .context("failed to call lsps2-policy-getpolicy") + } + + async fn get_channel_capacity( + &self, + params: &Lsps2PolicyGetChannelCapacityRequest, + ) -> Result { + let mut rpc = self.create_rpc().await?; + rpc.call_raw("lsps2-policy-getchannelcapacity", params) + .await + .map_err(anyhow::Error::new) + .with_context(|| "calling lsps2-policy-getchannelcapacity") + } +} + +#[async_trait] +impl BlockheightProvider for ClnApiRpc { + async fn get_blockheight(&self) -> Result { + let mut rpc = self.create_rpc().await?; + let info = rpc + .call_typed(&GetinfoRequest {}) + .await + .map_err(anyhow::Error::new) + .with_context(|| "calling getinfo")?; + Ok(info.blockheight) + } +} + +#[derive(Debug)] +pub enum DsError { + /// No datastore entry with this exact key. + NotFound { key: Vec }, + /// Entry existed but had neither `string` nor `hex`. + MissingValue { key: Vec }, + /// JSON parse failed (from `string` or decoded `hex`). + JsonParse { + key: Vec, + source: serde_json::Error, + }, + /// Hex decode failed. 
+ HexDecode { + key: Vec, + source: hex::FromHexError, + }, +} + +impl fmt::Display for DsError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DsError::NotFound { key } => write!(f, "no datastore entry for key {:?}", key), + DsError::MissingValue { key } => write!( + f, + "datastore entry had neither `string` nor `hex` for key {:?}", + key + ), + DsError::JsonParse { key, source } => { + write!(f, "failed to parse JSON at key {:?}: {}", key, source) + } + DsError::HexDecode { key, source } => { + write!(f, "failed to decode hex at key {:?}: {}", key, source) + } + } + } +} + +impl std::error::Error for DsError {} + +pub fn deserialize_by_key( + resp: &ListdatastoreResponse, + key: K, +) -> std::result::Result<(DatastoreEntry, Option), DsError> +where + K: AsRef<[String]>, +{ + let wanted: &[String] = key.as_ref(); + + let ds = resp + .datastore + .iter() + .find(|d| d.key.as_slice() == wanted) + .ok_or_else(|| DsError::NotFound { + key: wanted.to_vec(), + })?; + + // Prefer `string`, fall back to `hex` + if let Some(s) = &ds.string { + let value = serde_json::from_str::(s).map_err(|e| DsError::JsonParse { + key: ds.key.clone(), + source: e, + })?; + return Ok((value, ds.generation)); + } + + if let Some(hx) = &ds.hex { + let bytes = hex::decode(hx).map_err(|e| DsError::HexDecode { + key: ds.key.clone(), + source: e, + })?; + let value = + serde_json::from_slice::(&bytes).map_err(|e| DsError::JsonParse { + key: ds.key.clone(), + source: e, + })?; + return Ok((value, ds.generation)); + } + + Err(DsError::MissingValue { + key: ds.key.clone(), + }) +} diff --git a/plugins/lsps-plugin/src/cln_adapters/sender.rs b/plugins/lsps-plugin/src/cln_adapters/sender.rs new file mode 100644 index 000000000000..39a73acd64ca --- /dev/null +++ b/plugins/lsps-plugin/src/cln_adapters/sender.rs @@ -0,0 +1,45 @@ +use crate::{ + core::transport::{Error as TransportError, MessageSender}, + proto::lsps0, +}; +use async_trait::async_trait; +use 
bitcoin::secp256k1::PublicKey; +use cln_rpc::{model::requests::SendcustommsgRequest, ClnRpc}; +use std::path::PathBuf; + +#[derive(Clone)] +pub struct ClnSender { + rpc_path: PathBuf, +} + +impl ClnSender { + pub fn new(rpc_path: PathBuf) -> Self { + Self { rpc_path } + } +} + +#[async_trait] +impl MessageSender for ClnSender { + async fn send(&self, peer_id: &PublicKey, payload: &[u8]) -> Result<(), TransportError> { + let mut rpc = ClnRpc::new(&self.rpc_path) + .await + .map_err(|e| TransportError::Internal(e.to_string()))?; + + // Encode frame for LSPS0 Bolt8 transport. + let msg = encode_lsps0_frame_hex(payload); + + rpc.call_typed(&SendcustommsgRequest { + msg, + node_id: peer_id.to_owned(), + }) + .await + .map_err(|e| TransportError::Internal(e.to_string()))?; + + Ok(()) + } +} + +fn encode_lsps0_frame_hex(payload: &[u8]) -> String { + let frame = lsps0::encode_frame(payload); + hex::encode(&frame) +} diff --git a/plugins/lsps-plugin/src/cln_adapters/state.rs b/plugins/lsps-plugin/src/cln_adapters/state.rs new file mode 100644 index 000000000000..4ba564a46272 --- /dev/null +++ b/plugins/lsps-plugin/src/cln_adapters/state.rs @@ -0,0 +1,19 @@ +use std::sync::Arc; + +use crate::{ + cln_adapters::sender::ClnSender, + core::{ + server::LspsService, + transport::{MultiplexedTransport, PendingRequests}, + }, +}; + +pub trait ClientState { + fn transport(&self) -> MultiplexedTransport; + fn pending(&self) -> &PendingRequests; +} + +pub trait ServiceState { + fn service(&self) -> Arc; + fn sender(&self) -> ClnSender; +} diff --git a/plugins/lsps-plugin/src/cln_adapters/types.rs b/plugins/lsps-plugin/src/cln_adapters/types.rs new file mode 100644 index 000000000000..80ea79273e6b --- /dev/null +++ b/plugins/lsps-plugin/src/cln_adapters/types.rs @@ -0,0 +1,165 @@ +//! Backfill structs for missing or incomplete Core Lightning types. +//! +//! This module provides struct implementations that are not available or +//! 
fully accessible in the core-lightning crate, enabling better compatibility +//! and interoperability with Core Lightning's RPC interface. +use cln_rpc::primitives::{Amount, ShortChannelId}; +use hex::FromHex; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::core::tlv::TlvStream; + +#[derive(Debug, Deserialize)] +#[allow(unused)] +pub struct Onion { + pub forward_msat: Option, + #[serde(deserialize_with = "from_hex")] + pub next_onion: Vec, + pub outgoing_cltv_value: Option, + pub payload: TlvStream, + // pub payload: TlvStream, + #[serde(deserialize_with = "from_hex")] + pub shared_secret: Vec, + pub short_channel_id: Option, + pub total_msat: Option, + #[serde(rename = "type")] + pub type_: Option, +} + +#[derive(Debug, Deserialize)] +#[allow(unused)] +pub struct Htlc { + pub amount_msat: Amount, + pub cltv_expiry: u32, + pub cltv_expiry_relative: u16, + pub id: u64, + #[serde(deserialize_with = "from_hex")] + pub payment_hash: Vec, + pub short_channel_id: ShortChannelId, + pub extra_tlvs: Option, +} + +#[derive(Debug, Serialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum HtlcAcceptedResult { + Continue, + Fail, + Resolve, +} + +impl std::fmt::Display for HtlcAcceptedResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + HtlcAcceptedResult::Continue => "continue", + HtlcAcceptedResult::Fail => "fail", + HtlcAcceptedResult::Resolve => "resolve", + }; + write!(f, "{s}") + } +} + +#[derive(Debug, Deserialize)] +pub struct HtlcAcceptedRequest { + pub htlc: Htlc, + pub onion: Onion, + pub forward_to: Option, +} + +#[derive(Debug, Serialize)] +pub struct HtlcAcceptedResponse { + pub result: HtlcAcceptedResult, + #[serde(skip_serializing_if = "Option::is_none")] + pub payment_key: Option, + #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] + pub payload: Option>, + #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] + pub 
forward_to: Option>, + #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] + pub extra_tlvs: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub failure_message: Option, + #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] + pub failure_onion: Option>, +} + +impl HtlcAcceptedResponse { + pub fn continue_( + payload: Option>, + forward_to: Option>, + extra_tlvs: Option>, + ) -> Self { + Self { + result: HtlcAcceptedResult::Continue, + payment_key: None, + payload, + forward_to, + extra_tlvs, + failure_message: None, + failure_onion: None, + } + } + + pub fn fail(failure_message: Option, failure_onion: Option>) -> Self { + Self { + result: HtlcAcceptedResult::Fail, + payment_key: None, + payload: None, + forward_to: None, + extra_tlvs: None, + failure_message, + failure_onion, + } + } +} + +#[derive(Debug, Deserialize)] +pub struct InvoicePaymentRequest { + pub payment: InvoicePaymentRequestPayment, +} + +#[derive(Debug, Deserialize)] +pub struct InvoicePaymentRequestPayment { + pub label: String, + pub preimage: String, + pub msat: u64, +} + +#[derive(Debug, Deserialize)] +pub struct OpenChannelRequest { + pub openchannel: OpenChannelRequestOpenChannel, +} + +#[derive(Debug, Deserialize)] +pub struct OpenChannelRequestOpenChannel { + pub id: String, + pub funding_msat: u64, + pub push_msat: u64, + pub dust_limit_msat: u64, + pub max_htlc_value_in_flight_msat: u64, + pub channel_reserve_msat: u64, + pub htlc_minimum_msat: u64, + pub feerate_per_kw: u32, + pub to_self_delay: u32, + pub max_accepted_htlcs: u32, + pub channel_flags: u64, +} + +/// Deserializes a lowercase hex string to a `Vec`. 
+pub fn from_hex<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + use serde::de::Error; + String::deserialize(deserializer) + .and_then(|string| Vec::from_hex(string).map_err(|err| Error::custom(err.to_string()))) +} + +pub fn to_hex(bytes: &Option>, serializer: S) -> Result +where + S: Serializer, +{ + match bytes { + Some(data) => serializer.serialize_str(&hex::encode(data)), + None => serializer.serialize_none(), + } +} diff --git a/plugins/lsps-plugin/src/core/client.rs b/plugins/lsps-plugin/src/core/client.rs new file mode 100644 index 000000000000..0c494e9f87ee --- /dev/null +++ b/plugins/lsps-plugin/src/core/client.rs @@ -0,0 +1,66 @@ +use bitcoin::secp256k1::PublicKey; + +use crate::{ + core::transport::{self, Transport}, + proto::{ + jsonrpc::{JsonRpcRequest, JsonRpcResponse}, + lsps0::{Lsps0listProtocolsRequest, Lsps0listProtocolsResponse, Msat}, + lsps2::{ + Lsps2BuyRequest, Lsps2BuyResponse, Lsps2GetInfoRequest, Lsps2GetInfoResponse, + OpeningFeeParams, + }, + }, +}; + +pub struct LspsClient { + transport: T, +} + +impl LspsClient { + pub fn new(transport: T) -> Self { + Self { transport } + } +} + +// LSPS0 Implementation +impl LspsClient { + pub async fn list_protocols( + &self, + peer: &PublicKey, + ) -> Result, transport::Error> { + self.transport + .request(peer, &Lsps0listProtocolsRequest {}.into_request()) + .await + } +} + +// LSPS2 Implementation +impl LspsClient { + pub async fn get_info( + &self, + peer: &PublicKey, + token: Option, + ) -> Result, transport::Error> { + self.transport + .request(peer, &Lsps2GetInfoRequest { token }.into_request()) + .await + } + + pub async fn buy( + &self, + peer: &PublicKey, + opening_fee_params: OpeningFeeParams, + payment_size_msat: Option, + ) -> Result, transport::Error> { + self.transport + .request( + peer, + &Lsps2BuyRequest { + opening_fee_params, + payment_size_msat, + } + .into_request(), + ) + .await + } +} diff --git a/plugins/lsps-plugin/src/util.rs 
b/plugins/lsps-plugin/src/core/features.rs similarity index 62% rename from plugins/lsps-plugin/src/util.rs rename to plugins/lsps-plugin/src/core/features.rs index 06784911dfc2..1d7d4f1d3319 100644 --- a/plugins/lsps-plugin/src/util.rs +++ b/plugins/lsps-plugin/src/core/features.rs @@ -1,9 +1,4 @@ -use anyhow::anyhow; -use anyhow::Result; -use cln_rpc::primitives::PublicKey; use core::fmt; -use serde_json::Value; -use std::str::FromStr; /// Checks whether a feature bit is set in a bitmap interpreted as /// **big-endian across bytes**, while keeping **LSB-first within each byte**. @@ -97,103 +92,10 @@ impl std::error::Error for UnwrapError { } } -/// Wraps a payload with a peer ID for internal LSPS message transmission. -pub fn try_wrap_payload_with_peer_id(payload: &[u8], peer_id: PublicKey) -> Result> { - // We expect the payload to be valid json, so no empty payload allowed, also - // checks that we have curly braces at start and end. - if payload.is_empty() || payload[0] != b'{' || payload[payload.len() - 1] != b'}' { - return Err(anyhow!("payload no valid json")); - } - - let pubkey_hex = peer_id.to_string(); - let mut result = Vec::with_capacity(pubkey_hex.len() + payload.len() + 13); - - result.extend_from_slice(&payload[..payload.len() - 1]); - result.extend_from_slice(b",\"peer_id\":\""); - result.extend_from_slice(pubkey_hex.as_bytes()); - result.extend_from_slice(b"\"}"); - Ok(result) -} - -/// Safely unwraps payload data and a peer ID -pub fn try_unwrap_payload_with_peer_id(data: &[u8]) -> Result<(Vec, PublicKey)> { - let mut json: Value = - serde_json::from_slice(data).map_err(|e| UnwrapError::SerdeFailure(e.to_string()))?; - - if let Value::Object(ref mut map) = json { - if let Some(Value::String(peer_id)) = map.remove("peer_id") { - let modified_json = serde_json::to_string(&json) - .map_err(|e| UnwrapError::SerdeFailure(e.to_string()))?; - return Ok(( - modified_json.into_bytes(), - PublicKey::from_str(&peer_id) - .map_err(|e| 
UnwrapError::InvalidPublicKey(e.to_string()))?, - )); - } - } - Err(UnwrapError::InvalidPublicKey(String::from( - "public key missing", - )))? -} - -/// Unwraps payload data and peer ID, panicking on error -/// -/// This is a convenience function for cases where one knows the data is valid. -pub fn unwrap_payload_with_peer_id(data: &[u8]) -> (Vec, PublicKey) { - try_unwrap_payload_with_peer_id(data).expect("Failed to unwrap payload with peer_id") -} - -/// Wraps payload data and peer ID, panicking on error -/// -/// This is a convenience function for cases where one knows that the payload is -/// valid. -pub fn wrap_payload_with_peer_id(payload: &[u8], peer_id: PublicKey) -> Vec { - try_wrap_payload_with_peer_id(payload, peer_id).expect("Failed to wrap payload with peer_id") -} - #[cfg(test)] mod tests { - use serde_json::json; - use super::*; - // Valid test public key - const PUBKEY: [u8; 33] = [ - 0x02, 0x79, 0xbe, 0x66, 0x7e, 0xf9, 0xdc, 0xbb, 0xac, 0x55, 0xa0, 0x62, 0x95, 0xce, 0x87, - 0x0b, 0x07, 0x02, 0x9b, 0xfc, 0xdb, 0x2d, 0xce, 0x28, 0xd9, 0x59, 0xf2, 0x81, 0x5b, 0x16, - 0xf8, 0x17, 0x98, - ]; - - #[test] - fn test_wrap_and_unwrap_roundtrip() { - let peer_id = PublicKey::from_slice(&PUBKEY).unwrap(); - let payload = - json!({"jsonrpc": "2.0","method": "some-method","params": {},"id": "some-id"}); - let wrapped = wrap_payload_with_peer_id(payload.to_string().as_bytes(), peer_id); - - let (unwrapped_payload, unwrapped_peer_id) = unwrap_payload_with_peer_id(&wrapped); - let value: serde_json::Value = serde_json::from_slice(&unwrapped_payload).unwrap(); - - assert_eq!(value, payload); - assert_eq!(unwrapped_peer_id, peer_id); - } - - #[test] - fn test_invalid_pubkey() { - let mut invalid_data = vec![0u8; 40]; - // Set an invalid public key (all zeros) - invalid_data[0] = 0x02; // Valid prefix - // But rest remains zeros which is invalid - let payload = json!({"jsonrpc": "2.0","method": "some-method","params": {},"id": "some-id","peer_id": 
hex::encode(&invalid_data)}); - - let result = try_unwrap_payload_with_peer_id(payload.to_string().as_bytes()); - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err().downcast_ref::(), - Some(UnwrapError::InvalidPublicKey(_)) - )); - } - #[test] fn test_basic_bit_checks() { // Example bitmap: diff --git a/plugins/lsps-plugin/src/lsps2/handler.rs b/plugins/lsps-plugin/src/core/lsps2/handler.rs similarity index 53% rename from plugins/lsps-plugin/src/lsps2/handler.rs rename to plugins/lsps-plugin/src/core/lsps2/handler.rs index 4e86943f6c7e..88124788a62f 100644 --- a/plugins/lsps-plugin/src/lsps2/handler.rs +++ b/plugins/lsps-plugin/src/core/lsps2/handler.rs @@ -1,23 +1,28 @@ use crate::{ - jsonrpc::{server::RequestHandler, JsonRpcResponse as _, RequestObject, RpcError}, - lsps0::primitives::{Msat, ShortChannelId}, + core::lsps2::service::Lsps2Handler, lsps2::{ cln::{HtlcAcceptedRequest, HtlcAcceptedResponse, TLV_FORWARD_AMT}, - model::{ + DS_MAIN_KEY, DS_SUB_KEY, + }, + proto::{ + jsonrpc::{RpcError, RpcErrorExt as _}, + lsps0::{LSPS0RpcErrorExt, Msat, ShortChannelId}, + lsps2::{ compute_opening_fee, failure_codes::{TEMPORARY_CHANNEL_FAILURE, UNKNOWN_NEXT_PEER}, DatastoreEntry, Lsps2BuyRequest, Lsps2BuyResponse, Lsps2GetInfoRequest, Lsps2GetInfoResponse, Lsps2PolicyGetChannelCapacityRequest, Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, - Lsps2PolicyGetInfoResponse, OpeningFeeParams, Promise, + Lsps2PolicyGetInfoResponse, OpeningFeeParams, PolicyOpeningFeeParams, Promise, }, - DS_MAIN_KEY, DS_SUB_KEY, }, - util::unwrap_payload_with_peer_id, }; use anyhow::{Context, Result as AnyResult}; use async_trait::async_trait; -use bitcoin::hashes::Hash as _; +use bitcoin::{ + hashes::{sha256::Hash as Sha256, Hash as _}, + secp256k1::PublicKey, +}; use chrono::Utc; use cln_rpc::{ model::{ @@ -25,51 +30,15 @@ use cln_rpc::{ DatastoreMode, DatastoreRequest, DeldatastoreRequest, FundchannelRequest, GetinfoRequest, ListdatastoreRequest, 
ListpeerchannelsRequest, }, - responses::{ - DatastoreResponse, DeldatastoreResponse, FundchannelResponse, GetinfoResponse, - ListdatastoreResponse, ListpeerchannelsResponse, - }, + responses::ListdatastoreResponse, }, primitives::{Amount, AmountOrAll, ChannelState}, ClnRpc, }; use log::{debug, warn}; use rand::{rng, Rng as _}; -use std::{fmt, path::PathBuf, time::Duration}; - -#[async_trait] -pub trait ClnApi: Send + Sync { - async fn lsps2_getpolicy( - &self, - params: &Lsps2PolicyGetInfoRequest, - ) -> AnyResult; - - async fn lsps2_getchannelcapacity( - &self, - params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult; - - async fn cln_getinfo(&self, params: &GetinfoRequest) -> AnyResult; - - async fn cln_datastore(&self, params: &DatastoreRequest) -> AnyResult; - - async fn cln_listdatastore( - &self, - params: &ListdatastoreRequest, - ) -> AnyResult; - - async fn cln_deldatastore( - &self, - params: &DeldatastoreRequest, - ) -> AnyResult; - - async fn cln_fundchannel(&self, params: &FundchannelRequest) -> AnyResult; - - async fn cln_listpeerchannels( - &self, - params: &ListpeerchannelsRequest, - ) -> AnyResult; -} +use serde::Serialize; +use std::{fmt, path::PathBuf, sync::Arc, time::Duration}; const DEFAULT_CLTV_EXPIRY_DELTA: u32 = 144; @@ -89,251 +58,336 @@ impl ClnApiRpc { } #[async_trait] -impl ClnApi for ClnApiRpc { - async fn lsps2_getpolicy( +impl LightningProvider for ClnApiRpc { + async fn fund_jit_channel( &self, - params: &Lsps2PolicyGetInfoRequest, - ) -> AnyResult { + peer_id: &PublicKey, + amount: &Msat, + ) -> AnyResult<(Sha256, String)> { let mut rpc = self.create_rpc().await?; - rpc.call_raw("lsps2-policy-getpolicy", params) + let res = rpc + .call_typed(&FundchannelRequest { + announce: Some(false), + close_to: None, + compact_lease: None, + feerate: None, + minconf: None, + mindepth: Some(0), + push_msat: None, + request_amt: None, + reserve: None, + channel_type: Some(vec![12, 46, 50]), + utxos: None, + amount: 
AmountOrAll::Amount(Amount::from_msat(amount.msat())), + id: peer_id.to_owned(), + }) .await - .map_err(anyhow::Error::new) - .with_context(|| "calling lsps2-policy-getpolicy") + .with_context(|| "calling fundchannel")?; + Ok((res.channel_id, res.txid)) } - async fn lsps2_getchannelcapacity( - &self, - params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult { + async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Sha256) -> AnyResult { let mut rpc = self.create_rpc().await?; - rpc.call_raw("lsps2-policy-getchannelcapacity", params) + let r = rpc + .call_typed(&ListpeerchannelsRequest { + id: Some(peer_id.to_owned()), + short_channel_id: None, + }) .await - .map_err(anyhow::Error::new) - .with_context(|| "calling lsps2-policy-getchannelcapacity") + .with_context(|| "calling listpeerchannels")?; + + let chs = r + .channels + .iter() + .find(|&ch| ch.channel_id.is_some_and(|id| id == *channel_id)); + if let Some(ch) = chs { + if ch.state == ChannelState::CHANNELD_NORMAL { + return Ok(true); + } + } + + return Ok(false); } +} - async fn cln_getinfo(&self, params: &GetinfoRequest) -> AnyResult { +#[async_trait] +impl DatastoreProvider for ClnApiRpc { + async fn store_buy_request( + &self, + scid: &ShortChannelId, + peer_id: &PublicKey, + opening_fee_params: &OpeningFeeParams, + expected_payment_size: &Option, + ) -> AnyResult { let mut rpc = self.create_rpc().await?; - rpc.call_typed(params) + #[derive(Serialize)] + struct BorrowedDatastoreEntry<'a> { + peer_id: &'a PublicKey, + opening_fee_params: &'a OpeningFeeParams, + #[serde(borrow)] + expected_payment_size: &'a Option, + } + + let ds = BorrowedDatastoreEntry { + peer_id, + opening_fee_params, + expected_payment_size, + }; + let json_str = serde_json::to_string(&ds)?; + + let ds = DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::MUST_CREATE), + string: Some(json_str), + key: vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + scid.to_string(), + ], + }; 
+ + let _ = rpc + .call_typed(&ds) .await .map_err(anyhow::Error::new) - .with_context(|| "calling getinfo") + .with_context(|| "calling datastore")?; + + Ok(true) } - async fn cln_datastore(&self, params: &DatastoreRequest) -> AnyResult { + async fn get_buy_request(&self, scid: &ShortChannelId) -> AnyResult { let mut rpc = self.create_rpc().await?; - rpc.call_typed(params) + let key = vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + scid.to_string(), + ]; + let res = rpc + .call_typed(&ListdatastoreRequest { + key: Some(key.clone()), + }) .await - .map_err(anyhow::Error::new) - .with_context(|| "calling datastore") + .with_context(|| "calling listdatastore")?; + + let (rec, _) = deserialize_by_key(&res, key)?; + Ok(rec) } - async fn cln_listdatastore( - &self, - params: &ListdatastoreRequest, - ) -> AnyResult { + async fn del_buy_request(&self, scid: &ShortChannelId) -> AnyResult<()> { let mut rpc = self.create_rpc().await?; - rpc.call_typed(params) - .await - .map_err(anyhow::Error::new) - .with_context(|| "calling listdatastore") + let key = vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + scid.to_string(), + ]; + + let _ = rpc + .call_typed(&DeldatastoreRequest { + generation: None, + key, + }) + .await; + + Ok(()) } +} - async fn cln_deldatastore( +#[async_trait] +impl Lsps2OfferProvider for ClnApiRpc { + async fn get_offer( &self, - params: &DeldatastoreRequest, - ) -> AnyResult { + request: &Lsps2PolicyGetInfoRequest, + ) -> AnyResult { let mut rpc = self.create_rpc().await?; - rpc.call_typed(params) + rpc.call_raw("lsps2-policy-getpolicy", request) .await - .map_err(anyhow::Error::new) - .with_context(|| "calling deldatastore") + .context("failed to call lsps2-policy-getpolicy") } - async fn cln_fundchannel(&self, params: &FundchannelRequest) -> AnyResult { + async fn get_channel_capacity( + &self, + params: &Lsps2PolicyGetChannelCapacityRequest, + ) -> AnyResult { let mut rpc = self.create_rpc().await?; - rpc.call_typed(params) + 
rpc.call_raw("lsps2-policy-getchannelcapacity", params) .await .map_err(anyhow::Error::new) - .with_context(|| "calling fundchannel") + .with_context(|| "calling lsps2-policy-getchannelcapacity") } +} - async fn cln_listpeerchannels( - &self, - params: &ListpeerchannelsRequest, - ) -> AnyResult { +#[async_trait] +impl BlockheightProvider for ClnApiRpc { + async fn get_blockheight(&self) -> AnyResult { let mut rpc = self.create_rpc().await?; - rpc.call_typed(params) + let info = rpc + .call_typed(&GetinfoRequest {}) .await .map_err(anyhow::Error::new) - .with_context(|| "calling listpeerchannels") - } -} - -/// Handler for the `lsps2.get_info` method. -pub struct Lsps2GetInfoHandler { - pub api: A, - pub promise_secret: [u8; 32], -} - -impl Lsps2GetInfoHandler { - pub fn new(api: A, promise_secret: [u8; 32]) -> Self { - Self { - api, - promise_secret, - } + .with_context(|| "calling getinfo")?; + Ok(info.blockheight) } } -/// The RequestHandler calls the internal rpc command `lsps2-policy-getinfo`. It -/// expects a plugin has registered this command and manages policies for the -/// LSPS2 service. 
#[async_trait] -impl RequestHandler for Lsps2GetInfoHandler { - async fn handle(&self, payload: &[u8]) -> core::result::Result, RpcError> { - let (payload, _) = unwrap_payload_with_peer_id(payload); +pub trait Lsps2OfferProvider: Send + Sync { + async fn get_offer( + &self, + request: &Lsps2PolicyGetInfoRequest, + ) -> AnyResult; - let req: RequestObject = serde_json::from_slice(&payload) - .map_err(|e| RpcError::parse_error(format!("failed to parse request: {e}")))?; + async fn get_channel_capacity( + &self, + params: &Lsps2PolicyGetChannelCapacityRequest, + ) -> AnyResult; +} - if req.id.is_none() { - // Is a notification we can not reply so we just return - return Ok(vec![]); - } - let params = req - .params - .ok_or(RpcError::invalid_params("expected params but was missing"))?; +type Blockheight = u32; - let policy_params: Lsps2PolicyGetInfoRequest = params.into(); - let res_data: Lsps2PolicyGetInfoResponse = self - .api - .lsps2_getpolicy(&policy_params) - .await - .map_err(|e| RpcError { - code: 200, - message: format!("failed to fetch policy {e:#}"), - data: None, - })?; +#[async_trait] +pub trait BlockheightProvider: Send + Sync { + async fn get_blockheight(&self) -> AnyResult; +} - let opening_fee_params_menu = res_data - .policy_opening_fee_params_menu - .iter() - .map(|v| { - let promise: Promise = v - .get_hmac_hex(&self.promise_secret) - .try_into() - .map_err(|e| RpcError::internal_error(format!("invalid promise: {e}")))?; - Ok(OpeningFeeParams { - min_fee_msat: v.min_fee_msat, - proportional: v.proportional, - valid_until: v.valid_until, - min_lifetime: v.min_lifetime, - max_client_to_self_delay: v.max_client_to_self_delay, - min_payment_size_msat: v.min_payment_size_msat, - max_payment_size_msat: v.max_payment_size_msat, - promise, - }) - }) - .collect::, RpcError>>()?; +#[async_trait] +pub trait DatastoreProvider: Send + Sync { + async fn store_buy_request( + &self, + scid: &ShortChannelId, + peer_id: &PublicKey, + offer: &OpeningFeeParams, + 
expected_payment_size: &Option, + ) -> AnyResult; + + async fn get_buy_request(&self, scid: &ShortChannelId) -> AnyResult; + async fn del_buy_request(&self, scid: &ShortChannelId) -> AnyResult<()>; +} - let res = Lsps2GetInfoResponse { - opening_fee_params_menu, - } - .into_response(req.id.unwrap()); // We checked that we got an id before. +#[async_trait] +pub trait LightningProvider: Send + Sync { + async fn fund_jit_channel( + &self, + peer_id: &PublicKey, + amount: &Msat, + ) -> AnyResult<(Sha256, String)>; - serde_json::to_vec(&res) - .map_err(|e| RpcError::internal_error(format!("Failed to serialize response: {}", e))) - } + async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Sha256) -> AnyResult; } -pub struct Lsps2BuyHandler { - pub api: A, +pub struct Lsps2ServiceHandler { + pub api: Arc, pub promise_secret: [u8; 32], } -impl Lsps2BuyHandler { - pub fn new(api: A, promise_secret: [u8; 32]) -> Self { - Self { +impl Lsps2ServiceHandler { + pub fn new(api: Arc, promise_seret: &[u8; 32]) -> Self { + Lsps2ServiceHandler { api, - promise_secret, + promise_secret: promise_seret.to_owned(), } } } -#[async_trait] -impl RequestHandler for Lsps2BuyHandler { - async fn handle(&self, payload: &[u8]) -> core::result::Result, RpcError> { - let (payload, peer_id) = unwrap_payload_with_peer_id(payload); +async fn get_info_handler( + api: Arc, + secret: &[u8; 32], + request: &Lsps2GetInfoRequest, +) -> std::result::Result { + let res_data = api + .get_offer(&Lsps2PolicyGetInfoRequest { + token: request.token.clone(), + }) + .await + .map_err(|_| RpcError::internal_error("internal error"))?; + + if res_data.client_rejected { + return Err(RpcError::client_rejected("client was rejected")); + }; - let req: RequestObject = serde_json::from_slice(&payload) - .map_err(|e| RpcError::parse_error(format!("Failed to parse request: {}", e)))?; + let opening_fee_params_menu = res_data + .policy_opening_fee_params_menu + .iter() + .map(|v| make_opening_fee_params(v, 
secret)) + .collect::, _>>()?; - if req.id.is_none() { - // Is a notification we can not reply so we just return - return Ok(vec![]); - } + Ok(Lsps2GetInfoResponse { + opening_fee_params_menu, + }) +} - let req_params = req - .params - .ok_or_else(|| RpcError::invalid_request("Missing params field"))?; +fn make_opening_fee_params( + v: &PolicyOpeningFeeParams, + secret: &[u8; 32], +) -> Result { + let promise: Promise = v + .get_hmac_hex(secret) + .try_into() + .map_err(|_| RpcError::internal_error("internal error"))?; + Ok(OpeningFeeParams { + min_fee_msat: v.min_fee_msat, + proportional: v.proportional, + valid_until: v.valid_until, + min_lifetime: v.min_lifetime, + max_client_to_self_delay: v.max_client_to_self_delay, + min_payment_size_msat: v.min_payment_size_msat, + max_payment_size_msat: v.max_payment_size_msat, + promise, + }) +} - let fee_params = req_params.opening_fee_params; +#[async_trait] +impl Lsps2Handler + for Lsps2ServiceHandler +{ + async fn handle_get_info( + &self, + request: Lsps2GetInfoRequest, + ) -> std::result::Result { + get_info_handler(self.api.clone(), &self.promise_secret, &request).await + } + + async fn handle_buy( + &self, + peer_id: PublicKey, + request: Lsps2BuyRequest, + ) -> core::result::Result { + let fee_params = request.opening_fee_params; // FIXME: In the future we should replace the \`None\` with a meaningful // value that reflects the inbound capacity for this node from the // public network for a better pre-condition check on the payment_size. - fee_params.validate(&self.promise_secret, req_params.payment_size_msat, None)?; + fee_params.validate(&self.promise_secret, request.payment_size_msat, None)?; // Generate a tmp scid to identify jit channel request in htlc. 
- let get_info_req = GetinfoRequest {}; - let info = self.api.cln_getinfo(&get_info_req).await.map_err(|e| { - warn!("Failed to call getinfo via rpc {}", e); - RpcError::internal_error("Internal error") - })?; + let blockheight = self + .api + .get_blockheight() + .await + .map_err(|_| RpcError::internal_error("internal error"))?; // FIXME: Future task: Check that we don't conflict with any jit scid we // already handed out -> Check datastore entries. - let jit_scid_u64 = generate_jit_scid(info.blockheight); - let jit_scid = ShortChannelId::from(jit_scid_u64); - let ds_data = DatastoreEntry { - peer_id, - opening_fee_params: fee_params, - expected_payment_size: req_params.payment_size_msat, - }; - let ds_json = serde_json::to_string(&ds_data).map_err(|e| { - warn!("Failed to serialize opening fee params to string {}", e); - RpcError::internal_error("Internal error") - })?; + let jit_scid = ShortChannelId::from(generate_jit_scid(blockheight)); - let ds_req = DatastoreRequest { - generation: None, - hex: None, - mode: Some(DatastoreMode::MUST_CREATE), - string: Some(ds_json), - key: vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - jit_scid.to_string(), - ], - }; + let ok = self + .api + .store_buy_request(&jit_scid, &peer_id, &fee_params, &request.payment_size_msat) + .await + .map_err(|_| RpcError::internal_error("internal error"))?; - let _ds_res = self.api.cln_datastore(&ds_req).await.map_err(|e| { - warn!("Failed to store jit request in ds via rpc {}", e); - RpcError::internal_error("Internal error") - })?; + if !ok { + return Err(RpcError::internal_error("internal error"))?; + } - let res = Lsps2BuyResponse { + Ok(Lsps2BuyResponse { jit_channel_scid: jit_scid, // We can make this configurable if necessary. lsp_cltv_expiry_delta: DEFAULT_CLTV_EXPIRY_DELTA, // We can implement the other mode later on as we might have to do // some additional work on core-lightning to enable this. 
client_trusts_lsp: false, - } - .into_response(req.id.unwrap()); // We checked that we got an id before. - - serde_json::to_vec(&res) - .map_err(|e| RpcError::internal_error(format!("Failed to serialize response: {}", e))) + }) } } @@ -346,13 +400,13 @@ fn generate_jit_scid(best_blockheigt: u32) -> u64 { ((block as u64) << 40) | ((tx_idx as u64) << 16) | (output_idx as u64) } -pub struct HtlcAcceptedHookHandler { +pub struct HtlcAcceptedHookHandler { api: A, htlc_minimum_msat: u64, backoff_listpeerchannels: Duration, } -impl HtlcAcceptedHookHandler { +impl HtlcAcceptedHookHandler { pub fn new(api: A, htlc_minimum_msat: u64) -> Self { Self { api, @@ -360,7 +414,9 @@ impl HtlcAcceptedHookHandler { backoff_listpeerchannels: Duration::from_secs(10), } } +} +impl HtlcAcceptedHookHandler { pub async fn handle(&self, req: HtlcAcceptedRequest) -> AnyResult { let scid = match req.onion.short_channel_id { Some(scid) => scid, @@ -371,28 +427,9 @@ impl HtlcAcceptedHookHandler { }; // A) Is this SCID one that we care about? - let ds_req = ListdatastoreRequest { - key: Some(scid_ds_key(scid)), - }; - let ds_res = self.api.cln_listdatastore(&ds_req).await.map_err(|e| { - warn!("Failed to listpeerchannels via rpc {}", e); - RpcError::internal_error("Internal error") - })?; - - let (ds_rec, ds_gen) = match deserialize_by_key(&ds_res, scid_ds_key(scid)) { - Ok(r) => r, - Err(DsError::NotFound { .. }) => { - // We don't know the scid, continue. - return Ok(HtlcAcceptedResponse::continue_(None, None, None)); - } - Err(e @ DsError::MissingValue { .. }) - | Err(e @ DsError::HexDecode { .. }) - | Err(e @ DsError::JsonParse { .. }) => { - // We have a data issue, log and continue. - // Note: We may want to actually reject the htlc here or throw - // an error alltogether but we will try to fulfill this htlc for - // now. 
- warn!("datastore issue: {}", e); + let ds_rec = match self.api.get_buy_request(&scid).await { + Ok(rec) => rec, + Err(_) => { return Ok(HtlcAcceptedResponse::continue_(None, None, None)); } }; @@ -417,14 +454,7 @@ impl HtlcAcceptedHookHandler { let now = Utc::now(); if now >= ds_rec.opening_fee_params.valid_until { // Not valid anymore, remove from DS and fail HTLC. - let ds_req = DeldatastoreRequest { - generation: ds_gen, - key: scid_ds_key(scid), - }; - match self.api.cln_deldatastore(&ds_req).await { - Ok(_) => debug!("removed datastore for scid: {}, wasn't valid anymore", scid), - Err(e) => warn!("could not remove datastore for scid: {}: {}", scid, e), - }; + let _ = self.api.del_buy_request(&scid).await; return Ok(HtlcAcceptedResponse::fail( Some(TEMPORARY_CHANNEL_FAILURE.to_string()), None, @@ -472,7 +502,7 @@ impl HtlcAcceptedHookHandler { init_payment_size: Msat::from_msat(req.htlc.amount_msat.msat()), scid, }; - let ch_cap_res = match self.api.lsps2_getchannelcapacity(&ch_cap_req).await { + let ch_cap_res = match self.api.get_channel_capacity(&ch_cap_req).await { Ok(r) => r, Err(e) => { warn!("failed to get channel capacity for scid {}: {}", scid, e); @@ -484,7 +514,7 @@ impl HtlcAcceptedHookHandler { }; let cap = match ch_cap_res.channel_capacity_msat { - Some(c) => c, + Some(c) => Msat::from_msat(c), None => { debug!("policy giver does not allow channel for scid {}", scid); return Ok(HtlcAcceptedResponse::fail( @@ -500,27 +530,9 @@ impl HtlcAcceptedHookHandler { // (amount_msat - opening fee) in the future. // Fixme: Make this configurable, maybe return the whole request from // the policy giver? 
- let fund_ch_req = FundchannelRequest { - announce: Some(false), - close_to: None, - compact_lease: None, - feerate: None, - minconf: None, - mindepth: Some(0), - push_msat: None, - request_amt: None, - reserve: None, - channel_type: Some(vec![12, 46, 50]), - utxos: None, - amount: AmountOrAll::Amount(Amount::from_msat(cap)), - id: ds_rec.peer_id, - }; - - let fund_ch_res = match self.api.cln_fundchannel(&fund_ch_req).await { - Ok(r) => r, - Err(e) => { - // Fixme: Retry to fund the channel. - warn!("could not fund jit channel for scid {}: {}", scid, e); + let channel_id = match self.api.fund_jit_channel(&ds_rec.peer_id, &cap).await { + Ok((channel_id, _)) => channel_id, + Err(_) => { return Ok(HtlcAcceptedResponse::fail( Some(UNKNOWN_NEXT_PEER.to_string()), None, @@ -532,31 +544,15 @@ impl HtlcAcceptedHookHandler { // Fixme: Use event to check for channel ready, // Fixme: Check for htlc timeout if peer refuses to send "ready". // Fixme: handle unexpected channel states. - let mut is_active = false; - while !is_active { - let ls_ch_req = ListpeerchannelsRequest { - id: Some(ds_rec.peer_id), - short_channel_id: None, - }; - let ls_ch_res = match self.api.cln_listpeerchannels(&ls_ch_req).await { - Ok(r) => r, - Err(e) => { - warn!("failed to fetch peer channels for scid {}: {}", scid, e); - tokio::time::sleep(self.backoff_listpeerchannels).await; - continue; - } + loop { + match self + .api + .is_channel_ready(&ds_rec.peer_id, &channel_id) + .await + { + Ok(true) => break, + Ok(false) | Err(_) => tokio::time::sleep(self.backoff_listpeerchannels).await, }; - let chs = ls_ch_res - .channels - .iter() - .find(|&ch| ch.channel_id.is_some_and(|id| id == fund_ch_res.channel_id)); - if let Some(ch) = chs { - debug!("jit channel for scid {} has state {:?}", scid, ch.state); - if ch.state == ChannelState::CHANNELD_NORMAL { - is_active = true; - } - } - tokio::time::sleep(self.backoff_listpeerchannels).await; } // G) We got a working channel, deduct fee and forward htlc. 
@@ -575,7 +571,7 @@ impl HtlcAcceptedHookHandler { Ok(HtlcAcceptedResponse::continue_( Some(payload_bytes), - Some(fund_ch_res.channel_id.as_byte_array().to_vec()), + Some(channel_id.as_byte_array().to_vec()), Some(extra_tlvs_bytes), )) } @@ -620,14 +616,6 @@ impl fmt::Display for DsError { impl std::error::Error for DsError {} -fn scid_ds_key(scid: ShortChannelId) -> Vec { - vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - scid.to_string(), - ] -} - pub fn deserialize_by_key( resp: &ListdatastoreResponse, key: K, @@ -674,25 +662,20 @@ where #[cfg(test)] mod tests { - use std::sync::{Arc, Mutex}; - use super::*; use crate::{ - jsonrpc::{JsonRpcRequest, ResponseObject}, - lsps0::primitives::{Msat, Ppm}, - lsps2::{ - cln::{tlv::TlvStream, HtlcAcceptedResult}, - model::PolicyOpeningFeeParams, + lsps2::cln::{tlv::TlvStream, HtlcAcceptedResult}, + proto::{ + jsonrpc, + lsps0::Ppm, + lsps2::{Lsps2PolicyGetInfoResponse, PolicyOpeningFeeParams}, }, - util::wrap_payload_with_peer_id, }; + use anyhow::bail; use chrono::{TimeZone, Utc}; - use cln_rpc::{model::responses::ListdatastoreDatastore, RpcError as ClnRpcError}; - use cln_rpc::{ - model::responses::ListpeerchannelsChannels, - primitives::{Amount, PublicKey, Sha256}, - }; - use serde::Serialize; + use cln_rpc::primitives::{Amount, PublicKey}; + use cln_rpc::RpcError as ClnRpcError; + use std::sync::{Arc, Mutex}; const PUBKEY: [u8; 33] = [ 0x02, 0x79, 0xbe, 0x66, 0x7e, 0xf9, 0xdc, 0xbb, 0xac, 0x55, 0xa0, 0x62, 0x95, 0xce, 0x87, @@ -704,11 +687,6 @@ mod tests { PublicKey::from_slice(&PUBKEY).expect("Valid pubkey") } - fn create_wrapped_request(request: &RequestObject) -> Vec { - let payload = serde_json::to_vec(request).expect("Failed to serialize request"); - wrap_payload_with_peer_id(&payload, create_peer_id()) - } - /// Build a pair: policy params + buy params with a Promise derived from `secret` fn params_with_promise(secret: &[u8; 32]) -> (PolicyOpeningFeeParams, OpeningFeeParams) { let policy = 
PolicyOpeningFeeParams { @@ -739,39 +717,37 @@ mod tests { struct FakeCln { lsps2_getpolicy_response: Arc>>, lsps2_getpolicy_error: Arc>>, - cln_getinfo_response: Arc>>, - cln_getinfo_error: Arc>>, - cln_datastore_response: Arc>>, - cln_datastore_error: Arc>>, - cln_listdatastore_response: Arc>>, - cln_listdatastore_error: Arc>>, - cln_deldatastore_response: Arc>>, - cln_deldatastore_error: Arc>>, - cln_fundchannel_response: Arc>>, - cln_fundchannel_error: Arc>>, - cln_listpeerchannels_response: Arc>>, - cln_listpeerchannels_error: Arc>>, + blockheight_response: Option, + blockheight_error: Arc>>, + store_buy_request_response: bool, + get_buy_request_response: Arc>>, + get_buy_request_error: Arc>>, + fund_channel_error: Arc>>, + fund_channel_response: Arc>>, lsps2_getchannelcapacity_response: Arc>>, lsps2_getchannelcapacity_error: Arc>>, } #[async_trait] - impl ClnApi for FakeCln { - async fn lsps2_getpolicy( + impl Lsps2OfferProvider for FakeCln { + async fn get_offer( &self, - _params: &Lsps2PolicyGetInfoRequest, - ) -> Result { + _request: &Lsps2PolicyGetInfoRequest, + ) -> AnyResult { if let Some(err) = self.lsps2_getpolicy_error.lock().unwrap().take() { return Err(anyhow::Error::new(err).context("from fake api")); }; if let Some(res) = self.lsps2_getpolicy_response.lock().unwrap().take() { - return Ok(res); + return Ok(Lsps2PolicyGetInfoResponse { + policy_opening_fee_params_menu: res.policy_opening_fee_params_menu, + client_rejected: false, + }); }; panic!("No lsps2 response defined"); } - async fn lsps2_getchannelcapacity( + async fn get_channel_capacity( &self, _params: &Lsps2PolicyGetChannelCapacityRequest, ) -> AnyResult { @@ -788,151 +764,72 @@ mod tests { } panic!("No lsps2 getchannelcapacity response defined"); } + } - async fn cln_getinfo( - &self, - _params: &GetinfoRequest, - ) -> Result { - if let Some(err) = self.cln_getinfo_error.lock().unwrap().take() { - return Err(anyhow::Error::new(err).context("from fake api")); + #[async_trait] + impl 
BlockheightProvider for FakeCln { + async fn get_blockheight(&self) -> AnyResult { + if let Some(err) = self.blockheight_error.lock().unwrap().take() { + return Err(err); }; - if let Some(res) = self.cln_getinfo_response.lock().unwrap().take() { - return Ok(res); + if let Some(blockheight) = self.blockheight_response { + return Ok(blockheight); }; panic!("No cln getinfo response defined"); } + } - async fn cln_datastore( + #[async_trait] + impl DatastoreProvider for FakeCln { + async fn store_buy_request( &self, - _params: &DatastoreRequest, - ) -> Result { - if let Some(err) = self.cln_datastore_error.lock().unwrap().take() { - return Err(anyhow::Error::new(err).context("from fake api")); - }; - if let Some(res) = self.cln_datastore_response.lock().unwrap().take() { - return Ok(res); - }; - panic!("No cln datastore response defined"); + _scid: &ShortChannelId, + _peer_id: &PublicKey, + _offer: &OpeningFeeParams, + _payment_size_msat: &Option, + ) -> AnyResult { + Ok(self.store_buy_request_response) } - async fn cln_listdatastore( - &self, - _params: &ListdatastoreRequest, - ) -> AnyResult { - if let Some(err) = self.cln_listdatastore_error.lock().unwrap().take() { - return Err(anyhow::Error::new(err).context("from fake api")); + async fn get_buy_request(&self, _scid: &ShortChannelId) -> AnyResult { + if let Some(err) = self.get_buy_request_error.lock().unwrap().take() { + return Err(err); } - if let Some(res) = self.cln_listdatastore_response.lock().unwrap().take() { + if let Some(res) = self.get_buy_request_response.lock().unwrap().take() { return Ok(res); + } else { + bail!("request not found") } - panic!("No cln listdatastore response defined"); } - async fn cln_deldatastore( - &self, - _params: &DeldatastoreRequest, - ) -> AnyResult { - if let Some(err) = self.cln_deldatastore_error.lock().unwrap().take() { - return Err(anyhow::Error::new(err).context("from fake api")); - } - if let Some(res) = self.cln_deldatastore_response.lock().unwrap().take() { - return 
Ok(res); - } - panic!("No cln deldatastore response defined"); + async fn del_buy_request(&self, _scid: &ShortChannelId) -> AnyResult<()> { + Ok(()) } + } - async fn cln_fundchannel( + #[async_trait] + impl LightningProvider for FakeCln { + async fn fund_jit_channel( &self, - _params: &FundchannelRequest, - ) -> AnyResult { - if let Some(err) = self.cln_fundchannel_error.lock().unwrap().take() { - return Err(anyhow::Error::new(err).context("from fake api")); + _peer_id: &PublicKey, + _amount: &Msat, + ) -> AnyResult<(Sha256, String)> { + if let Some(err) = self.fund_channel_error.lock().unwrap().take() { + return Err(err); } - if let Some(res) = self.cln_fundchannel_response.lock().unwrap().take() { + if let Some(res) = self.fund_channel_response.lock().unwrap().take() { return Ok(res); + } else { + bail!("request not found") } - panic!("No cln fundchannel response defined"); } - async fn cln_listpeerchannels( + async fn is_channel_ready( &self, - _params: &ListpeerchannelsRequest, - ) -> AnyResult { - if let Some(err) = self.cln_listpeerchannels_error.lock().unwrap().take() { - return Err(anyhow::Error::new(err).context("from fake api")); - } - - if let Some(res) = self.cln_listpeerchannels_response.lock().unwrap().take() { - return Ok(res); - } - - // Default: return a ready channel - let channel = ListpeerchannelsChannels { - channel_id: Some(*Sha256::from_bytes_ref(&[1u8; 32])), - state: ChannelState::CHANNELD_NORMAL, - peer_id: create_peer_id(), - peer_connected: true, - alias: None, - closer: None, - funding: None, - funding_outnum: None, - funding_txid: None, - htlcs: None, - in_offered_msat: None, - initial_feerate: None, - last_feerate: None, - last_stable_connection: None, - last_tx_fee_msat: None, - lost_state: None, - max_accepted_htlcs: None, - minimum_htlc_in_msat: None, - next_feerate: None, - next_fee_step: None, - out_fulfilled_msat: None, - out_offered_msat: None, - owner: None, - private: None, - receivable_msat: None, - reestablished: None, - 
scratch_txid: None, - short_channel_id: None, - spendable_msat: None, - status: None, - their_reserve_msat: None, - to_us_msat: None, - total_msat: None, - close_to: None, - close_to_addr: None, - direction: None, - dust_limit_msat: None, - fee_base_msat: None, - fee_proportional_millionths: None, - feerate: None, - ignore_fee_limits: None, - in_fulfilled_msat: None, - in_payments_fulfilled: None, - in_payments_offered: None, - max_to_us_msat: None, - maximum_htlc_out_msat: None, - min_to_us_msat: None, - minimum_htlc_out_msat: None, - our_max_htlc_value_in_flight_msat: None, - our_reserve_msat: None, - our_to_self_delay: None, - out_payments_fulfilled: None, - out_payments_offered: None, - their_max_htlc_value_in_flight_msat: None, - their_to_self_delay: None, - updates: None, - inflight: None, - #[allow(deprecated)] - max_total_htlc_in_msat: None, - opener: cln_rpc::primitives::ChannelSide::LOCAL, - }; - - Ok(ListpeerchannelsResponse { - channels: vec![channel], - }) + _peer_id: &PublicKey, + _channel_id: &Sha256, + ) -> AnyResult { + Ok(true) } } @@ -978,32 +875,15 @@ mod tests { } } - fn minimal_getinfo(height: u32) -> GetinfoResponse { - GetinfoResponse { - lightning_dir: String::default(), - alias: None, - our_features: None, - warning_bitcoind_sync: None, - warning_lightningd_sync: None, - address: None, - binding: None, - blockheight: height, - color: String::default(), - fees_collected_msat: Amount::from_msat(0), - id: PublicKey::from_slice(&PUBKEY).expect("pubkey from slice"), - network: String::default(), - num_active_channels: u32::default(), - num_inactive_channels: u32::default(), - num_peers: u32::default(), - num_pending_channels: u32::default(), - version: String::default(), - } + fn test_promise_secret() -> [u8; 32] { + [0x42; 32] } #[tokio::test] async fn test_successful_get_info() { - let promise_secret = [0u8; 32]; + let promise_secret = test_promise_secret(); let params = Lsps2PolicyGetInfoResponse { + client_rejected: false, 
policy_opening_fee_params_menu: vec![PolicyOpeningFeeParams { min_fee_msat: Msat(2000), proportional: Ppm(10000), @@ -1017,114 +897,114 @@ mod tests { let promise = params.policy_opening_fee_params_menu[0].get_hmac_hex(&promise_secret); let fake = FakeCln::default(); *fake.lsps2_getpolicy_response.lock().unwrap() = Some(params); - let handler = Lsps2GetInfoHandler::new(fake, promise_secret); - let request = Lsps2GetInfoRequest { token: None }.into_request(Some("test-id".to_string())); - let payload = create_wrapped_request(&request); + let handler = Lsps2ServiceHandler { + api: Arc::new(fake), + promise_secret, + }; - let result = handler.handle(&payload).await.unwrap(); - let response: ResponseObject = - serde_json::from_slice(&result).unwrap(); - let response = response.into_inner().unwrap(); + let request = Lsps2GetInfoRequest { token: None }; + let result = handler.handle_get_info(request).await.unwrap(); assert_eq!( - response.opening_fee_params_menu[0].min_payment_size_msat, + result.opening_fee_params_menu[0].min_payment_size_msat, Msat(1000000) ); assert_eq!( - response.opening_fee_params_menu[0].max_payment_size_msat, + result.opening_fee_params_menu[0].max_payment_size_msat, Msat(100000000) ); assert_eq!( - response.opening_fee_params_menu[0].promise, + result.opening_fee_params_menu[0].promise, promise.try_into().unwrap() ); } #[tokio::test] async fn test_get_info_rpc_error_handling() { + let promise_secret = test_promise_secret(); let fake = FakeCln::default(); *fake.lsps2_getpolicy_error.lock().unwrap() = Some(ClnRpcError { code: Some(-1), message: "not found".to_string(), data: None, }); - let handler = Lsps2GetInfoHandler::new(fake, [0; 32]); - let request = Lsps2GetInfoRequest { token: None }.into_request(Some("test-id".to_string())); - let payload = create_wrapped_request(&request); - let result = handler.handle(&payload).await; + let handler = Lsps2ServiceHandler { + api: Arc::new(fake), + promise_secret, + }; + + let request = Lsps2GetInfoRequest 
{ token: None }; + let result = handler.handle_get_info(request).await; assert!(result.is_err()); let error = result.unwrap_err(); - assert_eq!(error.code, 200); - assert!(error.message.contains("failed to fetch policy")); + assert_eq!(error.code, jsonrpc::INTERNAL_ERROR); + assert!(error.message.contains("internal error")); } #[tokio::test] async fn buy_ok_fixed_amount() { - let secret = [0u8; 32]; - let fake = FakeCln::default(); - *fake.cln_getinfo_response.lock().unwrap() = Some(minimal_getinfo(900_000)); - *fake.cln_datastore_response.lock().unwrap() = Some(DatastoreResponse { - generation: Some(0), - hex: None, - string: None, - key: vec![], - }); + let promise_secret = test_promise_secret(); + let mut fake = FakeCln::default(); + fake.blockheight_response = Some(900_000); + fake.store_buy_request_response = true; + + let handler = Lsps2ServiceHandler { + api: Arc::new(fake), + promise_secret, + }; - let handler = Lsps2BuyHandler::new(fake, secret); - let (_policy, buy) = params_with_promise(&secret); + let (_policy, buy) = params_with_promise(&promise_secret); // Set payment_size_msat => "MPP+fixed-invoice" mode. 
- let req = Lsps2BuyRequest { + let request = Lsps2BuyRequest { opening_fee_params: buy, payment_size_msat: Some(Msat(2_000_000)), - } - .into_request(Some("ok-fixed".into())); - let payload = create_wrapped_request(&req); + }; + let peer_id = create_peer_id(); - let out = handler.handle(&payload).await.unwrap(); - let resp: ResponseObject = serde_json::from_slice(&out).unwrap(); - let resp = resp.into_inner().unwrap(); + let result = handler.handle_buy(peer_id, request).await.unwrap(); - assert_eq!(resp.lsp_cltv_expiry_delta, DEFAULT_CLTV_EXPIRY_DELTA); - assert!(!resp.client_trusts_lsp); - assert!(resp.jit_channel_scid.to_u64() > 0); + assert_eq!(result.lsp_cltv_expiry_delta, DEFAULT_CLTV_EXPIRY_DELTA); + assert!(!result.client_trusts_lsp); + assert!(result.jit_channel_scid.to_u64() > 0); } #[tokio::test] async fn buy_ok_variable_amount_no_payment_size() { - let secret = [2u8; 32]; - let fake = FakeCln::default(); - *fake.cln_getinfo_response.lock().unwrap() = Some(minimal_getinfo(900_100)); - *fake.cln_datastore_response.lock().unwrap() = Some(DatastoreResponse { - generation: Some(0), - hex: None, - string: None, - key: vec![], - }); + let promise_secret = test_promise_secret(); + let mut fake = FakeCln::default(); + fake.blockheight_response = Some(900_100); + fake.store_buy_request_response = true; + + let handler = Lsps2ServiceHandler { + api: Arc::new(fake), + promise_secret, + }; - let handler = Lsps2BuyHandler::new(fake, secret); - let (_policy, buy) = params_with_promise(&secret); + let (_policy, buy) = params_with_promise(&promise_secret); // No payment_size_msat => "no-MPP+var-invoice" mode. 
- let req = Lsps2BuyRequest { + let request = Lsps2BuyRequest { opening_fee_params: buy, payment_size_msat: None, - } - .into_request(Some("ok-var".into())); - let payload = create_wrapped_request(&req); + }; + let peer_id = create_peer_id(); - let out = handler.handle(&payload).await.unwrap(); - let resp: ResponseObject = serde_json::from_slice(&out).unwrap(); - assert!(resp.into_inner().is_ok()); + let result = handler.handle_buy(peer_id, request).await; + + assert!(result.is_ok()); } #[tokio::test] async fn buy_rejects_invalid_promise_or_past_valid_until_with_201() { - let secret = [3u8; 32]; - let handler = Lsps2BuyHandler::new(FakeCln::default(), secret); + let promise_secret = test_promise_secret(); + let handler = Lsps2ServiceHandler { + api: Arc::new(FakeCln::default()), + promise_secret, + }; // Case A: wrong promise (derive with different secret) let (_policy_wrong, mut buy_wrong) = params_with_promise(&[9u8; 32]); @@ -1132,33 +1012,30 @@ mod tests { let req_wrong = Lsps2BuyRequest { opening_fee_params: buy_wrong, payment_size_msat: Some(Msat(2_000_000)), - } - .into_request(Some("bad-promise".into())); - let err1 = handler - .handle(&create_wrapped_request(&req_wrong)) - .await - .unwrap_err(); + }; + let peer_id = create_peer_id(); + + let err1 = handler.handle_buy(peer_id, req_wrong).await.unwrap_err(); assert_eq!(err1.code, 201); // Case B: past valid_until - let (_policy, mut buy_past) = params_with_promise(&secret); + let (_policy, mut buy_past) = params_with_promise(&promise_secret); buy_past.valid_until = Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(); // past let req_past = Lsps2BuyRequest { opening_fee_params: buy_past, payment_size_msat: Some(Msat(2_000_000)), - } - .into_request(Some("past-valid".into())); - let err2 = handler - .handle(&create_wrapped_request(&req_past)) - .await - .unwrap_err(); + }; + let err2 = handler.handle_buy(peer_id, req_past).await.unwrap_err(); assert_eq!(err2.code, 201); } #[tokio::test] async fn 
buy_rejects_when_opening_fee_ge_payment_size_with_202() { - let secret = [4u8; 32]; - let handler = Lsps2BuyHandler::new(FakeCln::default(), secret); + let promise_secret = test_promise_secret(); + let handler = Lsps2ServiceHandler { + api: Arc::new(FakeCln::default()), + promise_secret, + }; // Make min_fee already >= payment_size to trigger 202 let policy = PolicyOpeningFeeParams { @@ -1170,7 +1047,7 @@ mod tests { min_payment_size_msat: Msat(1), max_payment_size_msat: Msat(u64::MAX / 2), }; - let hex = policy.get_hmac_hex(&secret); + let hex = policy.get_hmac_hex(&promise_secret); let promise: Promise = hex.try_into().unwrap(); let buy = OpeningFeeParams { min_fee_msat: policy.min_fee_msat, @@ -1183,23 +1060,23 @@ mod tests { promise, }; - let req = Lsps2BuyRequest { + let request = Lsps2BuyRequest { opening_fee_params: buy, payment_size_msat: Some(Msat(9_999)), // strictly less than min_fee => opening_fee >= payment_size - } - .into_request(Some("too-small".into())); + }; + let peer_id = create_peer_id(); + let err = handler.handle_buy(peer_id, request).await.unwrap_err(); - let err = handler - .handle(&create_wrapped_request(&req)) - .await - .unwrap_err(); assert_eq!(err.code, 202); } #[tokio::test] async fn buy_rejects_on_fee_overflow_with_203() { - let secret = [5u8; 32]; - let handler = Lsps2BuyHandler::new(FakeCln::default(), secret); + let promise_secret = test_promise_secret(); + let handler = Lsps2ServiceHandler { + api: Arc::new(FakeCln::default()), + promise_secret, + }; // Choose values likely to overflow if multiplication isn't checked: // opening_fee = min_fee + payment_size * proportional / 1_000_000 @@ -1212,7 +1089,7 @@ mod tests { min_payment_size_msat: Msat(1), max_payment_size_msat: Msat(u64::MAX), }; - let hex = policy.get_hmac_hex(&secret); + let hex = policy.get_hmac_hex(&promise_secret); let promise: Promise = hex.try_into().unwrap(); let buy = OpeningFeeParams { min_fee_msat: policy.min_fee_msat, @@ -1225,16 +1102,13 @@ mod tests { 
promise, }; - let req = Lsps2BuyRequest { + let request = Lsps2BuyRequest { opening_fee_params: buy, payment_size_msat: Some(Msat(u64::MAX / 2)), - } - .into_request(Some("overflow".into())); + }; + let peer_id = create_peer_id(); + let err = handler.handle_buy(peer_id, request).await.unwrap_err(); - let err = handler - .handle(&create_wrapped_request(&req)) - .await - .unwrap_err(); assert_eq!(err.code, 203); } #[tokio::test] @@ -1252,13 +1126,10 @@ mod tests { #[tokio::test] async fn test_htlc_unknown_scid_continues() { let fake = FakeCln::default(); + let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); let scid = ShortChannelId::from(123456789u64); - // Return empty datastore response (SCID not found) - *fake.cln_listdatastore_response.lock().unwrap() = - Some(ListdatastoreResponse { datastore: vec![] }); - let req = create_test_htlc_request(Some(scid), 1000000); let result = handler.handle(req).await.unwrap(); @@ -1277,23 +1148,7 @@ mod tests { ds_entry.opening_fee_params.valid_until = Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(); // expired - let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); - *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { - datastore: vec![ListdatastoreDatastore { - key: scid_ds_key(scid), - generation: Some(1), - hex: None, - string: Some(ds_entry_json), - }], - }); - - // Mock successful deletion - *fake.cln_deldatastore_response.lock().unwrap() = Some(DeldatastoreResponse { - generation: Some(1), - hex: None, - string: None, - key: scid_ds_key(scid), - }); + *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); let req = create_test_htlc_request(Some(scid), 1000000); @@ -1313,16 +1168,7 @@ mod tests { let scid = ShortChannelId::from(123456789u64); let ds_entry = create_test_datastore_entry(peer_id, None); - let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); - - *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { - datastore: 
vec![ListdatastoreDatastore { - key: scid_ds_key(scid), - generation: Some(1), - hex: None, - string: Some(ds_entry_json), - }], - }); + *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); // HTLC amount below minimum let req = create_test_htlc_request(Some(scid), 100); @@ -1343,16 +1189,7 @@ mod tests { let scid = ShortChannelId::from(123456789u64); let ds_entry = create_test_datastore_entry(peer_id, None); - let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); - - *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { - datastore: vec![ListdatastoreDatastore { - key: scid_ds_key(scid), - generation: Some(1), - hex: None, - string: Some(ds_entry_json), - }], - }); + *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); // HTLC amount above maximum let req = create_test_htlc_request(Some(scid), 200_000_000); @@ -1373,16 +1210,7 @@ mod tests { let scid = ShortChannelId::from(123456789u64); let ds_entry = create_test_datastore_entry(peer_id, None); - let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); - - *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { - datastore: vec![ListdatastoreDatastore { - key: scid_ds_key(scid), - generation: Some(1), - hex: None, - string: Some(ds_entry_json), - }], - }); + *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); // HTLC amount just barely covers minimum fee but not minimum HTLC let req = create_test_htlc_request(Some(scid), 2500); // min_fee is 2000, htlc_minimum is 1000 @@ -1403,16 +1231,7 @@ mod tests { let scid = ShortChannelId::from(123456789u64); let ds_entry = create_test_datastore_entry(peer_id, None); - let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); - - *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { - datastore: vec![ListdatastoreDatastore { - key: scid_ds_key(scid), - generation: Some(1), - hex: None, - string: Some(ds_entry_json), - }], - }); + 
*fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); *fake.lsps2_getchannelcapacity_error.lock().unwrap() = Some(ClnRpcError { code: Some(-1), @@ -1438,16 +1257,7 @@ mod tests { let scid = ShortChannelId::from(123456789u64); let ds_entry = create_test_datastore_entry(peer_id, None); - let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); - - *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { - datastore: vec![ListdatastoreDatastore { - key: scid_ds_key(scid), - generation: Some(1), - hex: None, - string: Some(ds_entry_json), - }], - }); + *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); // Policy response with no channel capacity (denied) *fake.lsps2_getchannelcapacity_response.lock().unwrap() = @@ -1473,27 +1283,14 @@ mod tests { let scid = ShortChannelId::from(123456789u64); let ds_entry = create_test_datastore_entry(peer_id, None); - let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); - - *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { - datastore: vec![ListdatastoreDatastore { - key: scid_ds_key(scid), - generation: Some(1), - hex: None, - string: Some(ds_entry_json), - }], - }); + *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); *fake.lsps2_getchannelcapacity_response.lock().unwrap() = Some(Lsps2PolicyGetChannelCapacityResponse { channel_capacity_msat: Some(50_000_000), }); - *fake.cln_fundchannel_error.lock().unwrap() = Some(ClnRpcError { - code: Some(-1), - message: "insufficient funds".to_string(), - data: None, - }); + *fake.fund_channel_error.lock().unwrap() = Some(anyhow::anyhow!("insufficient funds")); let req = create_test_htlc_request(Some(scid), 10_000_000); @@ -1517,31 +1314,15 @@ mod tests { let scid = ShortChannelId::from(123456789u64); let ds_entry = create_test_datastore_entry(peer_id, None); - let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); - - *fake.cln_listdatastore_response.lock().unwrap() = 
Some(ListdatastoreResponse { - datastore: vec![ListdatastoreDatastore { - key: scid_ds_key(scid), - generation: Some(1), - hex: None, - string: Some(ds_entry_json), - }], - }); + *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); *fake.lsps2_getchannelcapacity_response.lock().unwrap() = Some(Lsps2PolicyGetChannelCapacityResponse { channel_capacity_msat: Some(50_000_000), }); - *fake.cln_fundchannel_response.lock().unwrap() = Some(FundchannelResponse { - channel_id: *Sha256::from_bytes_ref(&[1u8; 32]), - outnum: 0, - txid: String::default(), - channel_type: None, - close_to: None, - mindepth: None, - tx: String::default(), - }); + *fake.fund_channel_response.lock().unwrap() = + Some((*Sha256::from_bytes_ref(&[1u8; 32]), String::default())); let req = create_test_htlc_request(Some(scid), 10_000_000); @@ -1572,16 +1353,7 @@ mod tests { // Create entry with expected_payment_size (MPP mode) let mut ds_entry = create_test_datastore_entry(peer_id, None); ds_entry.expected_payment_size = Some(Msat::from_msat(1000000)); - let ds_entry_json = serde_json::to_string(&ds_entry).unwrap(); - - *fake.cln_listdatastore_response.lock().unwrap() = Some(ListdatastoreResponse { - datastore: vec![ListdatastoreDatastore { - key: scid_ds_key(scid), - generation: Some(1), - hex: None, - string: Some(ds_entry_json), - }], - }); + *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); let req = create_test_htlc_request(Some(scid), 10_000_000); diff --git a/plugins/lsps-plugin/src/core/lsps2/htlc.rs b/plugins/lsps-plugin/src/core/lsps2/htlc.rs new file mode 100644 index 000000000000..6e39cc07cf51 --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/htlc.rs @@ -0,0 +1,802 @@ +use crate::{ + core::{ + lsps2::provider::{DatastoreProvider, LightningProvider, Lsps2OfferProvider}, + tlv::{TlvStream, TLV_FORWARD_AMT}, + }, + proto::{ + lsps0::{Msat, ShortChannelId}, + lsps2::{ + compute_opening_fee, + failure_codes::{TEMPORARY_CHANNEL_FAILURE, UNKNOWN_NEXT_PEER}, + 
Lsps2PolicyGetChannelCapacityRequest, + }, + }, +}; +use bitcoin::hashes::sha256::Hash; +use chrono::Utc; +use std::time::Duration; +use thiserror::Error; + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum HtlcDecision { + NotOurs, + Forward { + payload: TlvStream, + forward_to: Hash, + extra_tlvs: TlvStream, + }, + + Reject { + reason: RejectReason, + }, +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum RejectReason { + OfferExpired { valid_until: chrono::DateTime }, + AmountBelowMinimum { minimum: Msat }, + AmountAboveMaximum { maximum: Msat }, + InsufficientForFee { fee: Msat }, + FeeOverflow, + PolicyDenied, + FundingFailed, + + // temporarily + MppNotSupported, +} + +impl RejectReason { + pub fn failure_code(&self) -> &'static str { + match self { + Self::OfferExpired { .. } => TEMPORARY_CHANNEL_FAILURE, + _ => UNKNOWN_NEXT_PEER, + } + } +} + +#[derive(Debug, Error)] +pub enum HtlcError { + #[error("failed to query channel capacity: {0}")] + CapacityQuery(#[source] anyhow::Error), + #[error("failed to fund channel: {0}")] + FundChannel(#[source] anyhow::Error), + #[error("channel ready check failed: {0}")] + ChannelReadyCheck(#[source] anyhow::Error), +} + +#[derive(Debug, Clone)] +pub struct Htlc { + pub amount_msat: Msat, + pub extra_tlvs: TlvStream, +} +impl Htlc { + pub fn new(amount_msat: Msat, tlvs: TlvStream) -> Self { + Self { + amount_msat, + extra_tlvs: tlvs, + } + } +} + +#[derive(Debug, Clone)] +pub struct Onion { + pub short_channel_id: ShortChannelId, + pub payload: TlvStream, +} + +pub struct HtlcAcceptedHookHandler { + api: A, + htlc_minimum_msat: u64, + backoff_listpeerchannels: Duration, +} + +impl HtlcAcceptedHookHandler { + pub fn new(api: A, htlc_minimum_msat: u64) -> Self { + Self { + api, + htlc_minimum_msat, + backoff_listpeerchannels: Duration::from_secs(10), + } + } +} +impl HtlcAcceptedHookHandler { + pub async fn handle(&self, htlc: &Htlc, onion: &Onion) -> Result { + // A) Is this SCID one 
that we care about? + let ds_rec = match self.api.get_buy_request(&onion.short_channel_id).await { + Ok(rec) => rec, + Err(_) => return Ok(HtlcDecision::NotOurs), + }; + + // Fixme: Check that we don't have a channel yet with the peer that we await to + // become READY to use. + // --- + + // Fixme: We only accept no-mpp for now, mpp and other flows will be added later on + // Fixme: We continue mpp for now to let the test mock handle the htlc, as we need + // to test the client implementation for mpp payments. + if ds_rec.expected_payment_size.is_some() { + return Ok(HtlcDecision::Reject { + reason: RejectReason::MppNotSupported, + }); + } + + // B) Is the fee option menu still valid? + if Utc::now() >= ds_rec.opening_fee_params.valid_until { + // Not valid anymore, remove from DS and fail HTLC. + let _ = self.api.del_buy_request(&onion.short_channel_id).await; + return Ok(HtlcDecision::Reject { + reason: RejectReason::OfferExpired { + valid_until: ds_rec.opening_fee_params.valid_until, + }, + }); + } + + // C) Is the amount in the boundaries of the fee menu? 
+ if htlc.amount_msat.msat() < ds_rec.opening_fee_params.min_fee_msat.msat() { + return Ok(HtlcDecision::Reject { + reason: RejectReason::AmountBelowMinimum { + minimum: ds_rec.opening_fee_params.min_fee_msat, + }, + }); + } + + if htlc.amount_msat.msat() > ds_rec.opening_fee_params.max_payment_size_msat.msat() { + return Ok(HtlcDecision::Reject { + reason: RejectReason::AmountAboveMaximum { + maximum: ds_rec.opening_fee_params.max_payment_size_msat, + }, + }); + } + + // D) Check that the amount_msat covers the opening fee (only for non-mpp right now) + let opening_fee = match compute_opening_fee( + htlc.amount_msat.msat(), + ds_rec.opening_fee_params.min_fee_msat.msat(), + ds_rec.opening_fee_params.proportional.ppm() as u64, + ) { + Some(fee) if fee + self.htlc_minimum_msat < htlc.amount_msat.msat() => fee, + Some(fee) => { + return Ok(HtlcDecision::Reject { + reason: RejectReason::InsufficientForFee { + fee: Msat::from_msat(fee), + }, + }) + } + None => { + return Ok(HtlcDecision::Reject { + reason: RejectReason::FeeOverflow, + }) + } + }; + + // E) We made it, open a channel to the peer. + let ch_cap_req = Lsps2PolicyGetChannelCapacityRequest { + opening_fee_params: ds_rec.opening_fee_params, + init_payment_size: htlc.amount_msat, + scid: onion.short_channel_id, + }; + let ch_cap_res = self + .api + .get_channel_capacity(&ch_cap_req) + .await + .map_err(HtlcError::CapacityQuery)?; + + let cap = match ch_cap_res.channel_capacity_msat { + Some(c) => Msat::from_msat(c), + None => { + return Ok(HtlcDecision::Reject { + reason: RejectReason::PolicyDenied, + }) + } + }; + + // We take the policy-giver seriously, if the capacity is too low, we + // still try to open the channel. + // Fixme: We may check that the capacity is ge than the + // (amount_msat - opening fee) in the future. + // Fixme: Make this configurable, maybe return the whole request from + // the policy giver? 
+ let (channel_id, _) = self + .api + .fund_jit_channel(&ds_rec.peer_id, &cap) + .await + .map_err(HtlcError::FundChannel)?; + + // F) Wait for the peer to send `channel_ready`. + // Fixme: Use event to check for channel ready, + // Fixme: Check for htlc timeout if peer refuses to send "ready". + // Fixme: handle unexpected channel states. + loop { + match self + .api + .is_channel_ready(&ds_rec.peer_id, &channel_id) + .await + { + Ok(true) => break, + Ok(false) => tokio::time::sleep(self.backoff_listpeerchannels).await, + Err(e) => return Err(HtlcError::ChannelReadyCheck(e)), + }; + } + + // G) We got a working channel, deduct fee and forward htlc. + let deducted_amt_msat = htlc.amount_msat.msat() - opening_fee; + let mut payload = onion.payload.clone(); + payload.set_tu64(TLV_FORWARD_AMT, deducted_amt_msat); + + let mut extra_tlvs = htlc.extra_tlvs.clone(); + extra_tlvs.set_u64(65537, opening_fee); + + Ok(HtlcDecision::Forward { + payload, + forward_to: channel_id, + extra_tlvs, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::tlv::TlvStream; + use crate::proto::lsps0::{Msat, Ppm, ShortChannelId}; + use crate::proto::lsps2::{ + DatastoreEntry, Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, + Lsps2PolicyGetInfoResponse, OpeningFeeParams, Promise, + }; + use anyhow::{anyhow, Result as AnyResult}; + use async_trait::async_trait; + use bitcoin::hashes::{sha256::Hash as Sha256, Hash}; + use bitcoin::secp256k1::PublicKey; + use chrono::{TimeZone, Utc}; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::{Arc, Mutex}; + use std::time::Duration; + use std::u64; + + fn test_peer_id() -> PublicKey { + "0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" + .parse() + .unwrap() + } + + fn test_scid() -> ShortChannelId { + ShortChannelId::from(123456789u64) + } + + fn test_channel_id() -> Sha256 { + Sha256::from_byte_array([1u8; 32]) + } + + fn valid_opening_fee_params() -> OpeningFeeParams { + 
OpeningFeeParams { + min_fee_msat: Msat(2_000), + proportional: Ppm(10_000), // 1% + valid_until: Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(), + min_lifetime: 1000, + max_client_to_self_delay: 2016, + min_payment_size_msat: Msat(1_000_000), + max_payment_size_msat: Msat(100_000_000), + promise: Promise::try_from("test").unwrap(), + } + } + + fn expired_opening_fee_params() -> OpeningFeeParams { + OpeningFeeParams { + valid_until: Utc.with_ymd_and_hms(2000, 1, 1, 0, 0, 0).unwrap(), + ..valid_opening_fee_params() + } + } + + fn test_datastore_entry(expected_payment_size: Option) -> DatastoreEntry { + DatastoreEntry { + peer_id: test_peer_id(), + opening_fee_params: valid_opening_fee_params(), + expected_payment_size, + } + } + + fn test_onion(scid: ShortChannelId, payload: TlvStream) -> Onion { + Onion { + short_channel_id: scid, + payload, + } + } + + fn test_htlc(amount_msat: u64, extra_tlvs: TlvStream) -> Htlc { + Htlc { + amount_msat: Msat::from_msat(amount_msat), + extra_tlvs, + } + } + + #[derive(Default, Clone)] + struct MockApi { + // Datastore + buy_request: Arc>>, + buy_request_error: Arc>, + del_called: Arc, + + // Policy + channel_capacity: Arc>>>, // Some(Some(cap)), Some(None) = denied, None = error + channel_capacity_error: Arc>, + + // Lightning + fund_result: Arc>>, + fund_error: Arc>, + channel_ready: Arc>, + channel_ready_checks: Arc, + } + + impl MockApi { + fn new() -> Self { + Self::default() + } + + fn with_buy_request(self, entry: DatastoreEntry) -> Self { + *self.buy_request.lock().unwrap() = Some(entry); + self + } + + fn with_no_buy_request(self) -> Self { + *self.buy_request_error.lock().unwrap() = true; + self + } + + fn with_channel_capacity(self, capacity_msat: u64) -> Self { + *self.channel_capacity.lock().unwrap() = Some(Some(capacity_msat)); + self + } + + fn with_channel_denied(self) -> Self { + *self.channel_capacity.lock().unwrap() = Some(None); + self + } + + fn with_channel_capacity_error(self) -> Self { + 
*self.channel_capacity_error.lock().unwrap() = true; + self + } + + fn with_fund_result(self, channel_id: Sha256, txid: &str) -> Self { + *self.fund_result.lock().unwrap() = Some((channel_id, txid.to_string())); + self + } + + fn with_fund_error(self) -> Self { + *self.fund_error.lock().unwrap() = true; + self + } + + fn with_channel_ready(self, ready: bool) -> Self { + *self.channel_ready.lock().unwrap() = ready; + self + } + + fn del_call_count(&self) -> usize { + self.del_called.load(Ordering::SeqCst) + } + + fn channel_ready_check_count(&self) -> usize { + self.channel_ready_checks.load(Ordering::SeqCst) + } + } + + #[async_trait] + impl DatastoreProvider for MockApi { + async fn store_buy_request( + &self, + _scid: &ShortChannelId, + _peer_id: &PublicKey, + _fee_params: &OpeningFeeParams, + _payment_size: &Option, + ) -> AnyResult { + unimplemented!("not needed for HTLC tests") + } + + async fn get_buy_request(&self, _scid: &ShortChannelId) -> AnyResult { + if *self.buy_request_error.lock().unwrap() { + return Err(anyhow!("not found")); + } + self.buy_request + .lock() + .unwrap() + .clone() + .ok_or_else(|| anyhow!("not found")) + } + + async fn del_buy_request(&self, _scid: &ShortChannelId) -> AnyResult<()> { + self.del_called.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + } + + #[async_trait] + impl Lsps2OfferProvider for MockApi { + async fn get_offer( + &self, + _request: &Lsps2PolicyGetInfoRequest, + ) -> AnyResult { + unimplemented!("not needed for HTLC tests") + } + + async fn get_channel_capacity( + &self, + _params: &Lsps2PolicyGetChannelCapacityRequest, + ) -> AnyResult { + if *self.channel_capacity_error.lock().unwrap() { + return Err(anyhow!("capacity error")); + } + let cap = self + .channel_capacity + .lock() + .unwrap() + .ok_or_else(|| anyhow!("no capacity set"))?; + Ok(Lsps2PolicyGetChannelCapacityResponse { + channel_capacity_msat: cap, + }) + } + } + + #[async_trait] + impl LightningProvider for MockApi { + async fn fund_jit_channel( + 
&self, + _peer_id: &PublicKey, + _amount: &Msat, + ) -> AnyResult<(Sha256, String)> { + if *self.fund_error.lock().unwrap() { + return Err(anyhow!("fund error")); + } + self.fund_result + .lock() + .unwrap() + .clone() + .ok_or_else(|| anyhow!("no fund result set")) + } + + async fn is_channel_ready( + &self, + _peer_id: &PublicKey, + _channel_id: &Sha256, + ) -> AnyResult { + self.channel_ready_checks.fetch_add(1, Ordering::SeqCst); + Ok(*self.channel_ready.lock().unwrap()) + } + } + + fn handler(api: MockApi) -> HtlcAcceptedHookHandler { + HtlcAcceptedHookHandler { + api, + htlc_minimum_msat: 1_000, + backoff_listpeerchannels: Duration::from_millis(1), // Fast for tests + } + } + + #[tokio::test] + async fn continues_when_scid_not_found() { + let api = MockApi::new().with_no_buy_request(); + let h = handler(api); + + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(10_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + assert_eq!(result, HtlcDecision::NotOurs); + } + + #[tokio::test] + async fn continues_when_mpp_payment() { + let entry = test_datastore_entry(Some(Msat(50_000_000))); // MPP = has expected size + let api = MockApi::new().with_buy_request(entry); + let h = handler(api); + + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(10_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + assert_eq!( + result, + HtlcDecision::Reject { + reason: RejectReason::MppNotSupported + } + ); + } + + #[tokio::test] + async fn fails_when_offer_expired() { + let mut entry = test_datastore_entry(None); + entry.opening_fee_params = expired_opening_fee_params(); + let api = MockApi::new().with_buy_request(entry); + let h = handler(api.clone()); + + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(10_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + 
assert!(matches!( + result, + HtlcDecision::Reject { + reason: RejectReason::OfferExpired { .. } + } + )); + assert_eq!(api.del_call_count(), 1); // Should delete expired entry + } + + #[tokio::test] + async fn fails_when_amount_below_min_fee() { + let entry = test_datastore_entry(None); + let api = MockApi::new().with_buy_request(entry); + let h = handler(api); + + // min_fee_msat is 2_000 + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(1_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + assert!(matches!( + result, + HtlcDecision::Reject { + reason: RejectReason::AmountBelowMinimum { .. } + } + )); + } + + #[tokio::test] + async fn fails_when_amount_above_max() { + let entry = test_datastore_entry(None); + let api = MockApi::new().with_buy_request(entry); + let h = handler(api); + + // max_payment_size_msat is 100_000_000 + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(200_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + assert!(matches!( + result, + HtlcDecision::Reject { + reason: RejectReason::AmountAboveMaximum { .. } + } + )); + } + + #[tokio::test] + async fn fails_when_amount_doesnt_cover_fee_plus_minimum() { + let entry = test_datastore_entry(None); + let api = MockApi::new().with_buy_request(entry); + let h = handler(api); + + // min_fee = 2_000, htlc_minimum = 1_000 + // Amount must be > fee + htlc_minimum + // At 3_000: fee ~= 2_000 + (3_000 * 10_000 / 1_000_000) = 2_030 + // 2_030 + 1_000 = 3_030 > 3_000, so should fail + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(3_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + assert!(matches!( + result, + HtlcDecision::Reject { + reason: RejectReason::InsufficientForFee { .. 
} + } + )); + } + + #[tokio::test] + async fn fails_when_fee_computation_overflows() { + let mut entry = test_datastore_entry(None); + entry.opening_fee_params.min_fee_msat = Msat(u64::MAX / 2); + entry.opening_fee_params.proportional = Ppm(u32::MAX); + entry.opening_fee_params.min_payment_size_msat = Msat(1); + entry.opening_fee_params.max_payment_size_msat = Msat(u64::MAX); + + let api = MockApi::new().with_buy_request(entry); + let h = handler(api); + + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(u64::MAX / 2, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + assert!(matches!( + result, + HtlcDecision::Reject { + reason: RejectReason::FeeOverflow, + } + )); + } + + #[tokio::test] + async fn fails_when_channel_capacity_errors() { + let entry = test_datastore_entry(None); + let api = MockApi::new() + .with_buy_request(entry) + .with_channel_capacity_error(); + let h = handler(api); + + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(10_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.expect_err("should fail"); + + assert!(matches!(result, HtlcError::CapacityQuery(_))); + } + + #[tokio::test] + async fn fails_when_policy_denies_channel() { + let entry = test_datastore_entry(None); + let api = MockApi::new().with_buy_request(entry).with_channel_denied(); + let h = handler(api); + + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(10_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + assert!(matches!( + result, + HtlcDecision::Reject { + reason: RejectReason::PolicyDenied, + } + )); + } + + #[tokio::test] + async fn fails_when_fund_channel_errors() { + let entry = test_datastore_entry(None); + let api = MockApi::new() + .with_buy_request(entry) + .with_channel_capacity(50_000_000) + .with_fund_error(); + let h = handler(api); + + let onion = test_onion(test_scid(), 
TlvStream::default()); + let htlc = test_htlc(10_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.expect_err("should fail"); + + assert!(matches!(result, HtlcError::FundChannel(_))); + } + + #[tokio::test] + async fn success_flow_continues_with_modified_payload() { + let entry = test_datastore_entry(None); + let api = MockApi::new() + .with_buy_request(entry) + .with_channel_capacity(50_000_000) + .with_fund_result(test_channel_id(), "txid123") + .with_channel_ready(true); + let h = handler(api.clone()); + + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(10_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + let HtlcDecision::Forward { + payload, + forward_to, + extra_tlvs, + } = result + else { + panic!("expected forward, got {:?}", result) + }; + + assert_eq!(forward_to, test_channel_id()); + assert!(!payload.0.is_empty()); + assert!(!extra_tlvs.0.is_empty()); + } + + #[tokio::test] + async fn polls_until_channel_ready() { + let entry = test_datastore_entry(None); + let api = MockApi::new() + .with_buy_request(entry) + .with_channel_capacity(50_000_000) + .with_fund_result(test_channel_id(), "txid123") + .with_channel_ready(false); + + let h = handler(api.clone()); + + // Spawn handler, will block on channel ready + let handle = tokio::spawn(async move { + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(10_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + result + }); + + // Let it poll a few times + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(api.channel_ready_check_count() > 1); + + // Now make channel ready + *api.channel_ready.lock().unwrap() = true; + + let result = handle.await.unwrap(); + assert!(matches!(result, HtlcDecision::Forward { .. 
})); + } + + #[tokio::test] + async fn deducts_fee_from_forward_amount() { + let entry = test_datastore_entry(None); + let api = MockApi::new() + .with_buy_request(entry) + .with_channel_capacity(50_000_000) + .with_fund_result(test_channel_id(), "txid123") + .with_channel_ready(true); + let h = handler(api); + + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(10_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + let HtlcDecision::Forward { payload, .. } = result else { + panic!("expected forward, got {:?}", result) + }; + + // Verify payload contains deducted amount + // fee = max(min_fee, amount * proportional / 1_000_000) + // fee = max(2_000, 10_000_000 * 10_000 / 1_000_000) = max(2_000, 100_000) = 100_000 + // deducted = 10_000_000 - 100_000 = 9_900_000 + let forward_amt = payload.get_tu64(TLV_FORWARD_AMT).unwrap(); + assert_eq!(forward_amt, Some(9_900_000)); + } + + #[tokio::test] + async fn extra_tlvs_contain_opening_fee() { + let entry = test_datastore_entry(None); + let api = MockApi::new() + .with_buy_request(entry) + .with_channel_capacity(50_000_000) + .with_fund_result(test_channel_id(), "txid123") + .with_channel_ready(true); + let h = handler(api); + + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(10_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + let HtlcDecision::Forward { extra_tlvs, .. 
} = result else { + panic!("expected forward, got {:?}", result) + }; + + // Opening fee should be in TLV 65537 + let opening_fee = extra_tlvs.get_u64(65537).unwrap(); + assert_eq!(opening_fee, Some(100_000)); // Same fee calculation as above + } + + #[tokio::test] + async fn handles_minimum_valid_amount() { + let entry = test_datastore_entry(None); + let api = MockApi::new() + .with_buy_request(entry) + .with_channel_capacity(50_000_000) + .with_fund_result(test_channel_id(), "txid123") + .with_channel_ready(true); + let h = handler(api); + + // Just enough to cover fee + htlc_minimum + // fee at 1_000_000 = max(2_000, 1_000_000 * 10_000 / 1_000_000) = max(2_000, 10_000) = 10_000 + // Need: fee + htlc_minimum < amount + // 10_000 + 1_000 = 11_000 < 1_000_000 ✓ + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(1_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + assert!(matches!(result, HtlcDecision::Forward { .. })); + } + + #[tokio::test] + async fn handles_maximum_valid_amount() { + let entry = test_datastore_entry(None); + let api = MockApi::new() + .with_buy_request(entry) + .with_channel_capacity(200_000_000) + .with_fund_result(test_channel_id(), "txid123") + .with_channel_ready(true); + let h = handler(api); + + // max_payment_size_msat is 100_000_000 + let onion = test_onion(test_scid(), TlvStream::default()); + let htlc = test_htlc(100_000_000, TlvStream::default()); + let result = h.handle(&htlc, &onion).await.unwrap(); + + assert!(matches!(result, HtlcDecision::Forward { .. 
})); + } +} diff --git a/plugins/lsps-plugin/src/core/lsps2/mod.rs b/plugins/lsps-plugin/src/core/lsps2/mod.rs new file mode 100644 index 000000000000..18bf1cb51ce1 --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/mod.rs @@ -0,0 +1,3 @@ +pub mod htlc; +pub mod provider; +pub mod service; diff --git a/plugins/lsps-plugin/src/core/lsps2/provider.rs b/plugins/lsps-plugin/src/core/lsps2/provider.rs new file mode 100644 index 000000000000..6466630a4748 --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/provider.rs @@ -0,0 +1,53 @@ +use anyhow::Result; +use async_trait::async_trait; +use bitcoin::hashes::sha256::Hash; +use bitcoin::secp256k1::PublicKey; + +use crate::proto::{ + lsps0::{Msat, ShortChannelId}, + lsps2::{ + DatastoreEntry, Lsps2PolicyGetChannelCapacityRequest, + Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, + Lsps2PolicyGetInfoResponse, OpeningFeeParams, + }, +}; + +pub type Blockheight = u32; + +#[async_trait] +pub trait BlockheightProvider: Send + Sync { + async fn get_blockheight(&self) -> Result; +} + +#[async_trait] +pub trait DatastoreProvider: Send + Sync { + async fn store_buy_request( + &self, + scid: &ShortChannelId, + peer_id: &PublicKey, + offer: &OpeningFeeParams, + expected_payment_size: &Option, + ) -> Result; + + async fn get_buy_request(&self, scid: &ShortChannelId) -> Result; + async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()>; +} + +#[async_trait] +pub trait LightningProvider: Send + Sync { + async fn fund_jit_channel(&self, peer_id: &PublicKey, amount: &Msat) -> Result<(Hash, String)>; + async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Hash) -> Result; +} + +#[async_trait] +pub trait Lsps2OfferProvider: Send + Sync { + async fn get_offer( + &self, + request: &Lsps2PolicyGetInfoRequest, + ) -> Result; + + async fn get_channel_capacity( + &self, + params: &Lsps2PolicyGetChannelCapacityRequest, + ) -> Result; +} diff --git a/plugins/lsps-plugin/src/core/lsps2/service.rs 
b/plugins/lsps-plugin/src/core/lsps2/service.rs new file mode 100644 index 000000000000..a3ab32406e71 --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/service.rs @@ -0,0 +1,605 @@ +use crate::{ + core::{ + lsps2::provider::{BlockheightProvider, DatastoreProvider, Lsps2OfferProvider}, + router::JsonRpcRouterBuilder, + server::LspsProtocol, + }, + proto::{ + jsonrpc::{RpcError, RpcErrorExt as _}, + lsps0::{LSPS0RpcErrorExt as _, ShortChannelId}, + lsps2::{ + Lsps2BuyRequest, Lsps2BuyResponse, Lsps2GetInfoRequest, Lsps2GetInfoResponse, + Lsps2PolicyGetInfoRequest, OpeningFeeParams, ShortChannelIdJITExt, + }, + }, + register_handler, +}; +use async_trait::async_trait; +use bitcoin::secp256k1::PublicKey; +use std::sync::Arc; + +const DEFAULT_CLTV_EXPIRY_DELTA: u32 = 144; + +#[async_trait] +pub trait Lsps2Handler: Send + Sync + 'static { + async fn handle_get_info( + &self, + request: Lsps2GetInfoRequest, + ) -> std::result::Result; + + async fn handle_buy( + &self, + peer_id: PublicKey, + request: Lsps2BuyRequest, + ) -> Result; +} + +impl LspsProtocol for Arc +where + H: Lsps2Handler + Send + Sync + 'static, +{ + fn register_handler(&self, router: &mut JsonRpcRouterBuilder) { + register_handler!(router, self, "lsps2.get_info", handle_get_info); + register_handler!(router, self, "lsps2.buy", handle_buy, with_peer); + } + + fn protocol(&self) -> u8 { + 2 + } +} + +pub struct Lsps2ServiceHandler { + pub api: Arc, + pub promise_secret: [u8; 32], +} + +impl Lsps2ServiceHandler { + pub fn new(api: Arc, promise_seret: &[u8; 32]) -> Self { + Lsps2ServiceHandler { + api, + promise_secret: promise_seret.to_owned(), + } + } +} + +#[async_trait] +impl Lsps2Handler + for Lsps2ServiceHandler +{ + async fn handle_get_info( + &self, + request: Lsps2GetInfoRequest, + ) -> std::result::Result { + let res_data = self + .api + .get_offer(&Lsps2PolicyGetInfoRequest { + token: request.token.clone(), + }) + .await + .map_err(|_| RpcError::internal_error("internal error"))?; + + if 
res_data.client_rejected { + return Err(RpcError::client_rejected("client was rejected")); + }; + + let opening_fee_params_menu = res_data + .policy_opening_fee_params_menu + .iter() + .map(|v| v.with_promise(&self.promise_secret)) + .collect::>(); + + Ok(Lsps2GetInfoResponse { + opening_fee_params_menu, + }) + } + + async fn handle_buy( + &self, + peer_id: PublicKey, + request: Lsps2BuyRequest, + ) -> core::result::Result { + let fee_params = request.opening_fee_params; + + // FIXME: In the future we should replace the \`None\` with a meaningful + // value that reflects the inbound capacity for this node from the + // public network for a better pre-condition check on the payment_size. + fee_params.validate(&self.promise_secret, request.payment_size_msat, None)?; + + // Generate a tmp scid to identify jit channel request in htlc. + let blockheight = self + .api + .get_blockheight() + .await + .map_err(|_| RpcError::internal_error("internal error"))?; + + // FIXME: Future task: Check that we don't conflict with any jit scid we + // already handed out -> Check datastore entries. + let jit_scid = ShortChannelId::generate_jit(blockheight, 12); // Approximately 2 hours in the future. + + let ok = self + .api + .store_buy_request(&jit_scid, &peer_id, &fee_params, &request.payment_size_msat) + .await + .map_err(|_| RpcError::internal_error("internal error"))?; + + if !ok { + return Err(RpcError::internal_error("internal error"))?; + } + + Ok(Lsps2BuyResponse { + jit_channel_scid: jit_scid, + // We can make this configurable if necessary. + lsp_cltv_expiry_delta: DEFAULT_CLTV_EXPIRY_DELTA, + // We can implement the other mode later on as we might have to do + // some additional work on core-lightning to enable this. 
+ client_trusts_lsp: false, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::lsps0::{Msat, Ppm}; + use crate::proto::lsps2::{ + DatastoreEntry, Lsps2PolicyGetChannelCapacityRequest, + Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoResponse, OpeningFeeParams, + PolicyOpeningFeeParams, Promise, + }; + use anyhow::{anyhow, Result as AnyResult}; + use chrono::{TimeZone, Utc}; + use std::sync::{Arc, Mutex}; + + fn test_peer_id() -> PublicKey { + "0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" + .parse() + .unwrap() + } + + fn test_secret() -> [u8; 32] { + [0x42; 32] + } + + fn test_policy_params() -> PolicyOpeningFeeParams { + PolicyOpeningFeeParams { + min_fee_msat: Msat(2_000), + proportional: Ppm(10_000), + valid_until: Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(), + min_lifetime: 1000, + max_client_to_self_delay: 2016, + min_payment_size_msat: Msat(1_000_000), + max_payment_size_msat: Msat(100_000_000), + } + } + + fn test_opening_fee_params(secret: &[u8; 32]) -> OpeningFeeParams { + test_policy_params().with_promise(secret) + } + + fn expired_opening_fee_params(secret: &[u8; 32]) -> OpeningFeeParams { + let mut policy = test_policy_params(); + policy.valid_until = Utc.with_ymd_and_hms(2000, 1, 1, 0, 0, 0).unwrap(); + policy.with_promise(secret) + } + + #[derive(Default, Clone)] + struct MockApi { + // Responses + offer_response: Arc>>, + blockheight: Arc>>, + store_result: Arc>>, + + // Errors + offer_error: Arc>, + blockheight_error: Arc>, + store_error: Arc>, + + // Capture calls + stored_requests: Arc>>, + } + + #[derive(Clone, Debug)] + struct StoredBuyRequest { + peer_id: PublicKey, + payment_size: Option, + } + + impl MockApi { + fn new() -> Self { + Self::default() + } + + fn with_offer(self, response: Lsps2PolicyGetInfoResponse) -> Self { + *self.offer_response.lock().unwrap() = Some(response); + self + } + + fn with_offer_menu(self, menu: Vec) -> Self { + 
self.with_offer(Lsps2PolicyGetInfoResponse { + policy_opening_fee_params_menu: menu, + client_rejected: false, + }) + } + + fn with_client_rejected(self) -> Self { + *self.offer_response.lock().unwrap() = Some(Lsps2PolicyGetInfoResponse { + policy_opening_fee_params_menu: vec![], + client_rejected: true, + }); + self + } + + fn with_blockheight(self, height: u32) -> Self { + *self.blockheight.lock().unwrap() = Some(height); + self + } + + fn with_store_result(self, ok: bool) -> Self { + *self.store_result.lock().unwrap() = Some(ok); + self + } + + fn with_offer_error(self) -> Self { + *self.offer_error.lock().unwrap() = true; + self + } + + fn with_blockheight_error(self) -> Self { + *self.blockheight_error.lock().unwrap() = true; + self + } + + fn with_store_error(self) -> Self { + *self.store_error.lock().unwrap() = true; + self + } + + fn stored_requests(&self) -> Vec { + self.stored_requests.lock().unwrap().clone() + } + } + + #[async_trait] + impl Lsps2OfferProvider for MockApi { + async fn get_offer( + &self, + _request: &Lsps2PolicyGetInfoRequest, + ) -> AnyResult { + if *self.offer_error.lock().unwrap() { + return Err(anyhow!("offer error")); + } + self.offer_response + .lock() + .unwrap() + .clone() + .ok_or_else(|| anyhow!("no offer response set")) + } + + async fn get_channel_capacity( + &self, + _params: &Lsps2PolicyGetChannelCapacityRequest, + ) -> AnyResult { + unimplemented!("not needed for service tests") + } + } + + #[async_trait] + impl BlockheightProvider for MockApi { + async fn get_blockheight(&self) -> AnyResult { + if *self.blockheight_error.lock().unwrap() { + return Err(anyhow!("blockheight error")); + } + self.blockheight + .lock() + .unwrap() + .ok_or_else(|| anyhow!("no blockheight set")) + } + } + + #[async_trait] + impl DatastoreProvider for MockApi { + async fn store_buy_request( + &self, + _scid: &ShortChannelId, + peer_id: &PublicKey, + _fee_params: &OpeningFeeParams, + payment_size: &Option, + ) -> AnyResult { + if 
*self.store_error.lock().unwrap() { + return Err(anyhow!("store error")); + } + + self.stored_requests.lock().unwrap().push(StoredBuyRequest { + peer_id: *peer_id, + payment_size: *payment_size, + }); + + Ok(self.store_result.lock().unwrap().unwrap_or(true)) + } + + async fn get_buy_request(&self, _scid: &ShortChannelId) -> AnyResult { + unimplemented!("not needed for service tests") + } + + async fn del_buy_request(&self, _scid: &ShortChannelId) -> AnyResult<()> { + unimplemented!("not needed for service tests") + } + } + + fn handler(api: MockApi) -> Lsps2ServiceHandler { + Lsps2ServiceHandler::new(Arc::new(api), &test_secret()) + } + + #[tokio::test] + async fn get_info_returns_fee_params_with_promise() { + let api = MockApi::new().with_offer_menu(vec![test_policy_params()]); + let h = handler(api); + + let result = h.handle_get_info(Lsps2GetInfoRequest { token: None }).await; + + let response = result.unwrap(); + assert_eq!(response.opening_fee_params_menu.len(), 1); + + let params = &response.opening_fee_params_menu[0]; + assert_eq!(params.min_fee_msat, Msat(2_000)); + assert_eq!(params.proportional, Ppm(10_000)); + assert!(!params.promise.0.is_empty()); + } + + #[tokio::test] + async fn get_info_returns_multiple_fee_params() { + let mut params1 = test_policy_params(); + params1.min_fee_msat = Msat(1_000); + + let mut params2 = test_policy_params(); + params2.min_fee_msat = Msat(2_000); + + let api = MockApi::new().with_offer_menu(vec![params1, params2]); + let h = handler(api); + + let result = h.handle_get_info(Lsps2GetInfoRequest { token: None }).await; + + let response = result.unwrap(); + assert_eq!(response.opening_fee_params_menu.len(), 2); + assert_eq!( + response.opening_fee_params_menu[0].min_fee_msat, + Msat(1_000) + ); + assert_eq!( + response.opening_fee_params_menu[1].min_fee_msat, + Msat(2_000) + ); + } + + #[tokio::test] + async fn get_info_returns_empty_menu() { + let api = MockApi::new().with_offer_menu(vec![]); + let h = handler(api); + + 
let result = h.handle_get_info(Lsps2GetInfoRequest { token: None }).await; + + let response = result.unwrap(); + assert!(response.opening_fee_params_menu.is_empty()); + } + + #[tokio::test] + async fn get_info_rejects_client() { + let api = MockApi::new().with_client_rejected(); + let h = handler(api); + + let result = h.handle_get_info(Lsps2GetInfoRequest { token: None }).await; + + let err = result.unwrap_err(); + assert_eq!(err.code, 001); // client_rejected code + } + + #[tokio::test] + async fn get_info_handles_api_error() { + let api = MockApi::new().with_offer_error(); + let h = handler(api); + + let result = h.handle_get_info(Lsps2GetInfoRequest { token: None }).await; + + let err = result.unwrap_err(); + assert_eq!(err.code, -32603); // internal error + } + + #[tokio::test] + async fn buy_success_with_payment_size() { + let api = MockApi::new() + .with_blockheight(800_000) + .with_store_result(true); + let h = handler(api.clone()); + + let request = Lsps2BuyRequest { + opening_fee_params: test_opening_fee_params(&test_secret()), + payment_size_msat: Some(Msat(50_000_000)), + }; + + let result = h.handle_buy(test_peer_id(), request).await; + + let response = result.unwrap(); + assert!(response.jit_channel_scid.to_u64() > 0); + assert_eq!(response.lsp_cltv_expiry_delta, DEFAULT_CLTV_EXPIRY_DELTA); + assert!(!response.client_trusts_lsp); + + // Verify stored + let stored = api.stored_requests(); + assert_eq!(stored.len(), 1); + assert_eq!(stored[0].peer_id, test_peer_id()); + assert_eq!(stored[0].payment_size, Some(Msat(50_000_000))); + } + + #[tokio::test] + async fn buy_success_without_payment_size() { + let api = MockApi::new() + .with_blockheight(800_000) + .with_store_result(true); + let h = handler(api.clone()); + + let request = Lsps2BuyRequest { + opening_fee_params: test_opening_fee_params(&test_secret()), + payment_size_msat: None, + }; + + let result = h.handle_buy(test_peer_id(), request).await; + + assert!(result.is_ok()); + 
assert_eq!(api.stored_requests()[0].payment_size, None); + } + + #[tokio::test] + async fn buy_rejects_invalid_promise() { + let api = MockApi::new(); + let h = handler(api); + + let mut fee_params = test_opening_fee_params(&test_secret()); + fee_params.promise = Promise::try_from("invalid").unwrap(); + + let request = Lsps2BuyRequest { + opening_fee_params: fee_params, + payment_size_msat: Some(Msat(50_000_000)), + }; + + let result = h.handle_buy(test_peer_id(), request).await; + + let err = result.unwrap_err(); + assert_eq!(err.code, 201); // invalid/unrecognized params + } + + #[tokio::test] + async fn buy_rejects_expired_offer() { + let api = MockApi::new(); + let h = handler(api); + + let request = Lsps2BuyRequest { + opening_fee_params: expired_opening_fee_params(&test_secret()), + payment_size_msat: Some(Msat(50_000_000)), + }; + + let result = h.handle_buy(test_peer_id(), request).await; + + let err = result.unwrap_err(); + assert_eq!(err.code, 201); + } + + #[tokio::test] + async fn buy_rejects_payment_below_min() { + let api = MockApi::new(); + let h = handler(api); + + let request = Lsps2BuyRequest { + opening_fee_params: test_opening_fee_params(&test_secret()), + payment_size_msat: Some(Msat(100)), // Below min_payment_size_msat + }; + + let result = h.handle_buy(test_peer_id(), request).await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn buy_rejects_payment_above_max() { + let api = MockApi::new(); + let h = handler(api); + + let request = Lsps2BuyRequest { + opening_fee_params: test_opening_fee_params(&test_secret()), + payment_size_msat: Some(Msat(999_999_999_999)), // Above max + }; + + let result = h.handle_buy(test_peer_id(), request).await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn buy_rejects_when_fee_exceeds_payment() { + let api = MockApi::new(); + let h = handler(api); + + // Payment size barely above min_fee, but fee calculation might exceed it + let mut fee_params = test_policy_params(); + 
fee_params.min_fee_msat = Msat(10_000); + fee_params.min_payment_size_msat = Msat(1); + let fee_params = fee_params.with_promise(&test_secret()); + + let request = Lsps2BuyRequest { + opening_fee_params: fee_params, + payment_size_msat: Some(Msat(5_000)), // Less than min_fee + }; + + let result = h.handle_buy(test_peer_id(), request).await; + + let err = result.unwrap_err(); + assert_eq!(err.code, 202); // fee exceeds payment + } + + #[tokio::test] + async fn buy_handles_blockheight_error() { + let api = MockApi::new().with_blockheight_error(); + let h = handler(api); + + let request = Lsps2BuyRequest { + opening_fee_params: test_opening_fee_params(&test_secret()), + payment_size_msat: Some(Msat(50_000_000)), + }; + + let result = h.handle_buy(test_peer_id(), request).await; + + let err = result.unwrap_err(); + assert_eq!(err.code, -32603); + } + + #[tokio::test] + async fn buy_handles_store_error() { + let api = MockApi::new().with_blockheight(800_000).with_store_error(); + let h = handler(api); + + let request = Lsps2BuyRequest { + opening_fee_params: test_opening_fee_params(&test_secret()), + payment_size_msat: Some(Msat(50_000_000)), + }; + + let result = h.handle_buy(test_peer_id(), request).await; + + let err = result.unwrap_err(); + assert_eq!(err.code, -32603); + } + + #[tokio::test] + async fn buy_handles_store_returns_false() { + let api = MockApi::new() + .with_blockheight(800_000) + .with_store_result(false); + let h = handler(api); + + let request = Lsps2BuyRequest { + opening_fee_params: test_opening_fee_params(&test_secret()), + payment_size_msat: Some(Msat(50_000_000)), + }; + + let result = h.handle_buy(test_peer_id(), request).await; + + let err = result.unwrap_err(); + assert_eq!(err.code, -32603); + } + + #[tokio::test] + async fn buy_generates_unique_scids() { + let api = MockApi::new() + .with_blockheight(800_000) + .with_store_result(true); + let h = handler(api); + + let request = Lsps2BuyRequest { + opening_fee_params: 
test_opening_fee_params(&test_secret()), + payment_size_msat: None, + }; + + let r1 = h.handle_buy(test_peer_id(), request.clone()).await.unwrap(); + let r2 = h.handle_buy(test_peer_id(), request).await.unwrap(); + + assert_ne!(r1.jit_channel_scid, r2.jit_channel_scid); + } +} diff --git a/plugins/lsps-plugin/src/core/mod.rs b/plugins/lsps-plugin/src/core/mod.rs new file mode 100644 index 000000000000..19c7e1671a1c --- /dev/null +++ b/plugins/lsps-plugin/src/core/mod.rs @@ -0,0 +1,7 @@ +pub mod client; +pub mod features; +pub mod lsps2; +pub mod router; +pub mod server; +pub mod tlv; +pub mod transport; diff --git a/plugins/lsps-plugin/src/core/router.rs b/plugins/lsps-plugin/src/core/router.rs new file mode 100644 index 000000000000..63b84b680e65 --- /dev/null +++ b/plugins/lsps-plugin/src/core/router.rs @@ -0,0 +1,385 @@ +use crate::proto::jsonrpc::{RpcError, RpcErrorExt}; +use bitcoin::secp256k1::PublicKey; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::value::RawValue; +use std::{collections::HashMap, future::Future, pin::Pin}; + +pub type BoxedHandler = Box< + dyn Fn( + &RequestContext, + &[u8], + ) -> Pin> + Send>> + + Send + + Sync, +>; + +/// Convenience macro to register a handler at the JsonRpcRouterBuilder. +#[macro_export] +macro_rules! 
register_handler { + ($builder:expr, $handler:expr, $method:literal, $fn:ident) => {{ + let h = $handler.clone(); + $crate::core::router::JsonRpcRouterBuilder::register($builder, $method, move |p| { + let h = h.clone(); + async move { h.$fn(p).await } + }); + }}; + + // With context (peer_id) + ($builder:expr, $handler:expr, $method:literal, $fn:ident, with_peer) => {{ + let h = $handler.clone(); + $crate::core::router::JsonRpcRouterBuilder::register_with_context( + $builder, + $method, + move |ctx, p| { + let h = h.clone(); + async move { h.$fn(ctx.peer_id, p).await } + }, + ); + }}; +} + +#[derive(Clone)] +pub struct RequestContext { + pub peer_id: PublicKey, +} + +/// Builder for a generic JSON-RPC 2.0 router +pub struct JsonRpcRouterBuilder { + handlers: HashMap<&'static str, BoxedHandler>, +} + +impl JsonRpcRouterBuilder { + pub fn new() -> Self { + Self { + handlers: HashMap::new(), + } + } + + pub fn register(&mut self, method: &'static str, handler: F) + where + P: DeserializeOwned + Send + 'static, + R: Serialize + Send + 'static, + F: Fn(P) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + Send + 'static, + { + let boxed: BoxedHandler = Box::new(move |_ctx, params: &[u8]| { + let handler = handler.clone(); + let params: Result = serde_json::from_slice(params); + Box::pin(async move { + let params = params.map_err(|e| RpcError::invalid_params(e))?; + let result = handler(params).await?; + serde_json::to_value(&result).map_err(|_| RpcError::internal_error("parsing error")) + }) + }); + self.handlers.insert(method, boxed); + } + + pub fn register_with_context(&mut self, method: &'static str, handler: F) + where + P: DeserializeOwned + Send + 'static, + R: Serialize + Send + 'static, + F: Fn(RequestContext, P) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + Send + 'static, + { + let boxed: BoxedHandler = Box::new(move |ctx: &RequestContext, params: &[u8]| { + let handler = handler.clone(); + let ctx = ctx.clone(); + let params: Result = 
serde_json::from_slice(params); + Box::pin(async move { + let params = params.map_err(|e| RpcError::invalid_params(e))?; + let result = handler(ctx, params).await?; + serde_json::to_value(&result).map_err(|_| RpcError::internal_error("parsing error")) + }) + }); + self.handlers.insert(method, boxed); + } + + pub fn build(self) -> JsonRpcRouter { + JsonRpcRouter { + handlers: self.handlers, + } + } +} + +/// Generic JSON-RPC 2.0 router +pub struct JsonRpcRouter { + handlers: HashMap<&'static str, BoxedHandler>, +} + +impl JsonRpcRouter { + pub async fn handle(&self, ctx: &RequestContext, request: &[u8]) -> Option> { + #[derive(Deserialize)] + struct BorrowedRequest<'a> { + jsonrpc: &'a str, + method: &'a str, + #[serde(borrow)] + id: Option<&'a str>, + #[serde(borrow)] + params: Option<&'a RawValue>, + } + + let req: BorrowedRequest<'_> = match serde_json::from_slice(request) { + Ok(req) => req, + Err(_) => { + return Some(error_response( + None, + RpcError::parse_error("failed to parse request"), + )) + } + }; + + if req.jsonrpc != "2.0" { + return Some(error_response(req.id, RpcError::invalid_request(""))); + } + + let handler = match self.handlers.get(req.method) { + Some(h) => h, + None => return Some(error_response(req.id, RpcError::method_not_found(""))), + }; + + // Notification -> no response + let id = match req.id { + Some(id) => id, + None => return None, + }; + + let param_bytes = match req.params { + Some(raw) => raw.get().as_bytes(), + None => b"{}", + }; + + Some(match handler(ctx, param_bytes).await { + Ok(r) => success_response(id, r), + Err(e) => error_response(Some(id), e), + }) + } + + pub fn methods(&self) -> Vec<&'static str> { + self.handlers.keys().copied().collect() + } +} + +fn success_response(id: &str, result: serde_json::Value) -> Vec { + serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": result + })) + .unwrap() +} + +fn error_response(id: Option<&str>, error: RpcError) -> Vec { + 
serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "error": { + "code": error.code, + "message": error.message, + "data": error.data + } + })) + .unwrap() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::jsonrpc::{INVALID_PARAMS, INVALID_REQUEST, METHOD_NOT_FOUND, PARSE_ERROR}; + use serde::{Deserialize, Serialize}; + use serde_json::{self, json}; + + // Simple types for testing + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] + struct AddParams { + a: i32, + b: i32, + } + + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] + struct AddResult { + sum: i32, + } + + fn test_peer_id() -> PublicKey { + "0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" + .parse() + .unwrap() + } + + fn test_context() -> RequestContext { + RequestContext { + peer_id: test_peer_id(), + } + } + + #[tokio::test] + async fn dispatches_to_registered_handler_and_returns_success() { + let mut builder = JsonRpcRouterBuilder::new(); + builder.register("add", |p: AddParams| async move { + Ok(AddResult { sum: p.a + p.b }) + }); + + let router = builder.build(); + + let req = json!({ + "jsonrpc": "2.0", + "method": "add", + "id": "1", + "params": { "a": 1, "b": 2 }, + }); + + let req_bytes = serde_json::to_vec(&req).unwrap(); + + let resp_bytes = router + .handle(&test_context(), &req_bytes) + .await + .expect("should not be a notification"); + + let resp: serde_json::Value = serde_json::from_slice(&resp_bytes).unwrap(); + + assert_eq!(resp["jsonrpc"], "2.0"); + assert_eq!(resp["id"], "1"); + assert_eq!(resp["result"]["sum"], 3); + assert!(resp.get("error").is_none()); + } + + #[tokio::test] + async fn returns_none_for_notification() { + let mut builder = JsonRpcRouterBuilder::new(); + builder.register("add", |p: AddParams| async move { + Ok(AddResult { sum: p.a + p.b }) + }); + + let router = builder.build(); + + // No `id` → notification + let req = json!({ + "jsonrpc": "2.0", + "method": "add", + "params": { "a": 
10, "b": 20 }, + }); + + let req_bytes = serde_json::to_vec(&req).unwrap(); + let resp = router.handle(&test_context(), &req_bytes).await; + + assert!(resp.is_none(), "notifications must not produce a response"); + } + + #[tokio::test] + async fn unknown_method_returns_method_not_found() { + let builder = JsonRpcRouterBuilder::new(); + let router = builder.build(); + + let req = json!({ + "jsonrpc": "2.0", + "method": "does.not.exist", + "id": "42", + "params": {}, + }); + + let req_bytes = serde_json::to_vec(&req).unwrap(); + let resp_bytes = router + .handle(&test_context(), &req_bytes) + .await + .expect("not a notification"); + + let resp: serde_json::Value = serde_json::from_slice(&resp_bytes).unwrap(); + + assert_eq!(resp["jsonrpc"], "2.0"); + assert_eq!(resp["id"], "42"); + assert_eq!(resp["error"]["code"], METHOD_NOT_FOUND); + assert!(resp.get("result").is_none()); + } + + #[tokio::test] + async fn invalid_json_returns_parse_error_with_null_id() { + let builder = JsonRpcRouterBuilder::new(); + let router = builder.build(); + + // Not valid JSON at all + let resp_bytes = router + .handle(&test_context(), b"this is not json") + .await + .expect("parse error still produces a response"); + + let resp: serde_json::Value = serde_json::from_slice(&resp_bytes).unwrap(); + + assert_eq!(resp["jsonrpc"], "2.0"); + assert_eq!(resp["id"], serde_json::Value::Null); + assert_eq!(resp["error"]["code"], PARSE_ERROR); + } + + #[tokio::test] + async fn wrong_jsonrpc_version_returns_invalid_request() { + let builder = JsonRpcRouterBuilder::new(); + let router = builder.build(); + + let req = json!({ + "jsonrpc": "1.0", // wrong + "method": "add", + "id": "1", + "params": {}, + }); + + let req_bytes = serde_json::to_vec(&req).unwrap(); + let resp_bytes = router + .handle(&test_context(), &req_bytes) + .await + .expect("not a notification"); + + let resp: serde_json::Value = serde_json::from_slice(&resp_bytes).unwrap(); + + assert_eq!(resp["error"]["code"], INVALID_REQUEST); + 
assert_eq!(resp["id"], "1"); + } + + #[tokio::test] + async fn bad_params_return_invalid_params_error() { + let mut builder = JsonRpcRouterBuilder::new(); + builder.register("add", |p: AddParams| async move { + Ok(AddResult { sum: p.a + p.b }) + }); + + let router = builder.build(); + + // `params` is a string, but handler expects AddParams → serde should fail → invalid_params + let req = json!({ + "jsonrpc": "2.0", + "method": "add", + "id": "1", + "params": "not an object", + }); + + let req_bytes = serde_json::to_vec(&req).unwrap(); + let resp_bytes = router + .handle(&test_context(), &req_bytes) + .await + .expect("not a notification"); + + let resp: serde_json::Value = serde_json::from_slice(&resp_bytes).unwrap(); + + assert_eq!(resp["error"]["code"], INVALID_PARAMS); + assert_eq!(resp["id"], "1"); + assert!(resp.get("result").is_none()); + } + + #[test] + fn methods_returns_registered_method_names() { + let mut builder = JsonRpcRouterBuilder::new(); + + builder.register("add", |p: AddParams| async move { + Ok(AddResult { sum: p.a + p.b }) + }); + + builder.register("sub", |p: AddParams| async move { + Ok(AddResult { sum: p.a - p.b }) + }); + + let router = builder.build(); + + let mut methods = router.methods(); + methods.sort(); + + assert_eq!(methods, vec!["add", "sub"]); + } +} diff --git a/plugins/lsps-plugin/src/core/server.rs b/plugins/lsps-plugin/src/core/server.rs new file mode 100644 index 000000000000..936c01fae890 --- /dev/null +++ b/plugins/lsps-plugin/src/core/server.rs @@ -0,0 +1,172 @@ +use crate::core::router::{JsonRpcRouter, JsonRpcRouterBuilder, RequestContext}; +use crate::proto::lsps0::{Lsps0listProtocolsRequest, Lsps0listProtocolsResponse}; + +pub trait LspsProtocol: Send + Sync + 'static { + fn register_handler(&self, router: &mut JsonRpcRouterBuilder); + fn protocol(&self) -> u8; +} + +pub struct LspsService { + router: JsonRpcRouter, + supported_protocols: Vec, +} + +impl LspsService { + pub fn builder() -> LspsServiceBuilder { + 
LspsServiceBuilder::new() + } + + pub async fn handle(&self, ctx: &RequestContext, request: &[u8]) -> Option> { + self.router.handle(ctx, request).await + } + + pub fn protocols(&self) -> &[u8] { + &self.supported_protocols + } +} + +pub struct LspsServiceBuilder { + router_builder: JsonRpcRouterBuilder, + supported_protocols: Vec, +} + +impl LspsServiceBuilder { + pub fn new() -> Self { + Self { + router_builder: JsonRpcRouterBuilder::new(), + supported_protocols: vec![], + } + } + + pub fn with_protocol(mut self, method: M) -> Self + where + M: LspsProtocol, + { + let proto = method.protocol(); + self.supported_protocols.push(proto); + method.register_handler(&mut self.router_builder); + self + } + + pub fn build(mut self) -> LspsService { + self.supported_protocols.sort(); + self.supported_protocols.dedup(); + let supported_protocols: Vec = self + .supported_protocols + .into_iter() + .filter(|&p| p != 0) + .collect(); + + let protocols_for_rpc = supported_protocols.clone(); + self.router_builder.register( + "lsps0.list_protocols", + move |_p: Lsps0listProtocolsRequest| { + let protocols = protocols_for_rpc.clone(); + async move { Ok(Lsps0listProtocolsResponse { protocols }) } + }, + ); + + let router = self.router_builder.build(); + + LspsService { + router, + supported_protocols, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_context() -> RequestContext { + RequestContext { + peer_id: "0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" + .parse() + .unwrap(), + } + } + + // Minimal mock - just tracks protocol number + struct MockProtocol(u8); + + impl LspsProtocol for MockProtocol { + fn register_handler(&self, _router: &mut JsonRpcRouterBuilder) { + // No-op, we just care about protocol number + } + + fn protocol(&self) -> u8 { + self.0 + } + } + + #[test] + fn test_protocols_sorted() { + let service = LspsService::builder() + .with_protocol(MockProtocol(5)) + .with_protocol(MockProtocol(1)) + 
.with_protocol(MockProtocol(2)) + .build(); + + assert_eq!(service.protocols(), &[1, 2, 5]); + } + + #[test] + fn test_protocols_deduped() { + let service = LspsService::builder() + .with_protocol(MockProtocol(2)) + .with_protocol(MockProtocol(2)) + .build(); + + assert_eq!(service.protocols(), &[2]); + } + + #[test] + fn test_protocol_zero_filtered() { + let service = LspsService::builder() + .with_protocol(MockProtocol(0)) + .with_protocol(MockProtocol(2)) + .build(); + + assert_eq!(service.protocols(), &[2]); + } + + #[tokio::test] + async fn test_list_protocols_returns_registered() { + let service = LspsService::builder() + .with_protocol(MockProtocol(2)) + .with_protocol(MockProtocol(1)) + .build(); + + let request = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": "1", + "method": "lsps0.list_protocols", + "params": {} + })) + .unwrap(); + + let response = service.handle(&test_context(), &request).await.unwrap(); + let parsed: serde_json::Value = serde_json::from_slice(&response).unwrap(); + + assert_eq!(parsed["result"]["protocols"], serde_json::json!([1, 2])); + } + + #[tokio::test] + async fn test_list_protocols_empty() { + let service = LspsService::builder().build(); + + let request = serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": "1", + "method": "lsps0.list_protocols", + "params": {} + })) + .unwrap(); + + let response = service.handle(&test_context(), &request).await.unwrap(); + let parsed: serde_json::Value = serde_json::from_slice(&response).unwrap(); + + assert_eq!(parsed["result"]["protocols"], serde_json::json!([])); + } +} diff --git a/plugins/lsps-plugin/src/core/tlv.rs b/plugins/lsps-plugin/src/core/tlv.rs new file mode 100644 index 000000000000..7bf74442d587 --- /dev/null +++ b/plugins/lsps-plugin/src/core/tlv.rs @@ -0,0 +1,647 @@ +use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize, Serializer}; +use std::{convert::TryFrom, fmt}; +use thiserror::Error; + +pub const TLV_FORWARD_AMT: 
u64 = 2; +pub const TLV_OUTGOING_CLTV: u64 = 4; +pub const TLV_SHORT_CHANNEL_ID: u64 = 6; +pub const TLV_PAYMENT_SECRET: u64 = 8; + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct TlvRecord { + pub type_: u64, + pub value: Vec, +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct TlvStream(pub Vec); + +#[derive(Debug, Error)] +pub enum TlvError { + #[error("duplicate tlv type {0}")] + DuplicateType(u64), + #[error("tlv types are not strictly increasing")] + NotSorted, + #[error("length mismatch type {0}: expected {1}, got {2}")] + LengthMismatch(u64, usize, usize), + #[error("truncated input")] + Truncated, + #[error("non-canonical bigsize encoding")] + NonCanonicalBigSize, + #[error("leftover bytes after parsing")] + TrailingBytes, + #[error("")] + Hex(#[from] hex::FromHexError), + #[error("length overflow")] + Overflow, + #[error("tu64 is not minimal, got a leading zero")] + LeadingZero, + #[error("failed to parse bytes to u64")] + BytesToU64, +} + +type Result = std::result::Result; + +impl TlvStream { + pub fn to_bytes(&mut self) -> Result> { + self.0.sort_by_key(|r| r.type_); + for w in self.0.windows(2) { + if w[0].type_ == w[1].type_ { + return Err(TlvError::DuplicateType(w[0].type_).into()); + } + if w[0].type_ > w[1].type_ { + return Err(TlvError::NotSorted.into()); + } + } + let mut out = Vec::new(); + for rec in &self.0 { + out.extend(encode_bigsize(rec.type_)); + out.extend(encode_bigsize(rec.value.len() as u64)); + out.extend(&rec.value); + } + Ok(out) + } + + pub fn from_bytes(mut bytes: &[u8]) -> Result { + let mut recs = Vec::new(); + let mut last_type: Option = None; + + while !bytes.is_empty() { + let (t, n1) = decode_bigsize(bytes)?; + bytes = &bytes[n1..]; + let (len, n2) = decode_bigsize(bytes)?; + bytes = &bytes[n2..]; + + let l = usize::try_from(len).map_err(|_| TlvError::Overflow)?; + if bytes.len() < l { + return Err(TlvError::Truncated.into()); + } + let v = bytes[..l].to_vec(); + bytes = 
&bytes[l..]; + + if let Some(prev) = last_type { + if t == prev { + return Err(TlvError::DuplicateType(t).into()); + } + if t < prev { + return Err(TlvError::NotSorted.into()); + } + } + last_type = Some(t); + recs.push(TlvRecord { type_: t, value: v }); + } + Ok(TlvStream(recs)) + } + + pub fn from_bytes_with_length_prefix(bytes: &[u8]) -> Result { + if bytes.is_empty() { + return Err(TlvError::Truncated.into()); + } + + let (length, length_bytes) = decode_bigsize(bytes)?; + let remaining = &bytes[length_bytes..]; + + let length_usize = usize::try_from(length).map_err(|_| TlvError::Overflow)?; + + if remaining.len() != length_usize { + return Err(TlvError::LengthMismatch(0, length_usize, remaining.len()).into()); + } + + Self::from_bytes(remaining) + } + + /// Attempt to auto-detect whether the input has a length prefix or not + /// First tries to parse as length-prefixed, then falls back to raw TLV + /// parsing. + pub fn from_bytes_auto(bytes: &[u8]) -> Result { + // Try length-prefixed first + if let Ok(stream) = Self::from_bytes_with_length_prefix(bytes) { + return Ok(stream); + } + + // Fall back to raw TLV parsing + Self::from_bytes(bytes) + } + + /// Get a reference to the value of a TLV record by type. + pub fn get(&self, type_: u64) -> Option<&[u8]> { + self.0 + .iter() + .find(|rec| rec.type_ == type_) + .map(|rec| rec.value.as_slice()) + } + + /// Insert a TLV record (replaces if type already exists). + pub fn insert(&mut self, type_: u64, value: Vec) { + // If the type already exists, replace its value. + if let Some(rec) = self.0.iter_mut().find(|rec| rec.type_ == type_) { + rec.value = value; + return; + } + // Otherwise push and re-sort to maintain canonical order. + self.0.push(TlvRecord { type_, value }); + self.0.sort_by_key(|r| r.type_); + } + + /// Remove a record by type. 
+ pub fn remove(&mut self, type_: u64) -> Option> { + if let Some(pos) = self.0.iter().position(|rec| rec.type_ == type_) { + Some(self.0.remove(pos).value) + } else { + None + } + } + + /// Check if a type exists. + pub fn contains(&self, type_: u64) -> bool { + self.0.iter().any(|rec| rec.type_ == type_) + } + + /// Insert or override a `tu64` value for `type_` (keeps canonical TLV order). + pub fn set_tu64(&mut self, type_: u64, value: u64) { + let enc = encode_tu64(value); + if let Some(rec) = self.0.iter_mut().find(|r| r.type_ == type_) { + rec.value = enc; + } else { + self.0.push(TlvRecord { type_, value: enc }); + self.0.sort_by_key(|r| r.type_); + } + } + + /// Read a `tu64` if present, validating minimal encoding. + /// Returns Ok(None) if the type isn't present. + pub fn get_tu64(&self, type_: u64) -> Result> { + if let Some(rec) = self.0.iter().find(|r| r.type_ == type_) { + Ok(Some(decode_tu64(&rec.value)?)) + } else { + Ok(None) + } + } + + /// Insert or override a `u64` value for `type_` (keeps canonical TLV + order). + pub fn set_u64(&mut self, type_: u64, value: u64) { + let enc = value.to_be_bytes().to_vec(); + if let Some(rec) = self.0.iter_mut().find(|r| r.type_ == type_) { + rec.value = enc; + } else { + self.0.push(TlvRecord { type_, value: enc }); + self.0.sort_by_key(|r| r.type_); + } + } + + /// Read a `u64` if present. Returns Ok(None) if the type isn't present. 
+ pub fn get_u64(&self, type_: u64) -> Result> { + if let Some(rec) = self.0.iter().find(|r| r.type_ == type_) { + let value = + u64::from_be_bytes(rec.value[..].try_into().map_err(|_| TlvError::BytesToU64)?); + Ok(Some(value)) + } else { + Ok(None) + } + } +} + +impl Serialize for TlvStream { + fn serialize(&self, serializer: S) -> std::result::Result { + let mut tmp = self.clone(); + let bytes = tmp.to_bytes().map_err(serde::ser::Error::custom)?; + serializer.serialize_str(&hex::encode(bytes)) + } +} + +impl<'de> Deserialize<'de> for TlvStream { + fn deserialize>(deserializer: D) -> std::result::Result { + struct V; + impl<'de> serde::de::Visitor<'de> for V { + type Value = TlvStream; + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "a hex string representing a Lightning TLV stream") + } + fn visit_str(self, s: &str) -> std::result::Result { + let bytes = hex::decode(s).map_err(E::custom)?; + TlvStream::from_bytes_auto(&bytes).map_err(E::custom) + } + } + deserializer.deserialize_str(V) + } +} + +impl TryFrom<&[u8]> for TlvStream { + type Error = TlvError; + fn try_from(value: &[u8]) -> std::result::Result { + TlvStream::from_bytes(value) + } +} + +impl From> for TlvStream { + fn from(v: Vec) -> Self { + TlvStream(v) + } +} + +/// BOLT #1 BigSize encoding +fn encode_bigsize(x: u64) -> Vec { + let mut out = Vec::new(); + if x < 0xfd { + out.push(x as u8); + } else if x <= 0xffff { + out.push(0xfd); + out.extend_from_slice(&(x as u16).to_be_bytes()); + } else if x <= 0xffff_ffff { + out.push(0xfe); + out.extend_from_slice(&(x as u32).to_be_bytes()); + } else { + out.push(0xff); + out.extend_from_slice(&x.to_be_bytes()); + } + out +} + +fn decode_bigsize(input: &[u8]) -> Result<(u64, usize)> { + if input.is_empty() { + return Err(TlvError::Truncated.into()); + } + match input[0] { + n @ 0x00..=0xfc => Ok((n as u64, 1)), + 0xfd => { + if input.len() < 3 { + return Err(TlvError::Truncated.into()); + } + let v = u16::from_be_bytes([input[1], 
input[2]]) as u64; + if v < 0xfd { + return Err(TlvError::NonCanonicalBigSize.into()); + } + Ok((v, 3)) + } + 0xfe => { + if input.len() < 5 { + return Err(TlvError::Truncated.into()); + } + let v = u32::from_be_bytes([input[1], input[2], input[3], input[4]]) as u64; + if v <= 0xffff { + return Err(TlvError::NonCanonicalBigSize.into()); + } + Ok((v, 5)) + } + 0xff => { + if input.len() < 9 { + return Err(TlvError::Truncated.into()); + } + let v = u64::from_be_bytes([ + input[1], input[2], input[3], input[4], input[5], input[6], input[7], input[8], + ]); + if v <= 0xffff_ffff { + return Err(TlvError::NonCanonicalBigSize.into()); + } + Ok((v, 9)) + } + } +} + +/// Encode a BOLT #1 `tu64`: big-endian, minimal length (no leading 0x00). +/// Value 0 is encoded as zero-length. +pub fn encode_tu64(v: u64) -> Vec { + if v == 0 { + return Vec::new(); + } + let bytes = v.to_be_bytes(); + let first = bytes.iter().position(|&b| b != 0).unwrap(); // safe: v != 0 + bytes[first..].to_vec() +} + +/// Decode a BOLT #1 `tu64`, enforcing minimal form. +/// Empty slice -> 0. Leading 0x00 or >8 bytes is invalid. 
+fn decode_tu64(raw: &[u8]) -> Result { + if raw.is_empty() { + return Ok(0); + } + if raw.len() > 8 { + return Err(TlvError::Overflow); + } + if raw[0] == 0 { + return Err(TlvError::LeadingZero); + } + let mut buf = [0u8; 8]; + buf[8 - raw.len()..].copy_from_slice(raw); + Ok(u64::from_be_bytes(buf)) +} + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::Result; + + // Small helpers to keep tests readable + fn rec(type_: u64, value: &[u8]) -> TlvRecord { + TlvRecord { + type_, + value: value.to_vec(), + } + } + + fn build_bytes(type_: u64, value: &[u8]) -> Vec { + let mut v = Vec::new(); + v.extend(super::encode_bigsize(type_)); + v.extend(super::encode_bigsize(value.len() as u64)); + v.extend(value); + v + } + + #[test] + fn encode_then_decode_roundtrip() -> Result<()> { + let mut stream = TlvStream(vec![rec(1, &[0x01, 0x02]), rec(5, &[0xaa])]); + + // Encode + let bytes = stream.to_bytes()?; + // Expect exact TLV sequence: + // type=1 -> 0x01, len=2 -> 0x02, value=0x01 0x02 + // type=5 -> 0x05, len=1 -> 0x01, value=0xaa + assert_eq!(hex::encode(&bytes), "010201020501aa"); + + // Decode back + let decoded = TlvStream::from_bytes(&bytes)?; + assert_eq!(decoded.0.len(), 2); + assert_eq!(decoded.0[0].type_, 1); + assert_eq!(decoded.0[0].value, vec![0x01, 0x02]); + assert_eq!(decoded.0[1].type_, 5); + assert_eq!(decoded.0[1].value, vec![0xaa]); + + Ok(()) + } + + #[test] + fn json_hex_roundtrip() -> Result<()> { + let stream = TlvStream(vec![rec(1, &[0x01, 0x02]), rec(5, &[0xaa])]); + + // Serialize to hex string in JSON + let json = serde_json::to_string(&stream)?; + // It's a quoted hex string; check inner value + let s: String = serde_json::from_str(&json)?; + assert_eq!(s, "010201020501aa"); + + // And back from JSON hex + let back: TlvStream = serde_json::from_str(&json)?; + assert_eq!(back.0.len(), 2); + assert_eq!(back.0[0].type_, 1); + assert_eq!(back.0[0].value, vec![0x01, 0x02]); + assert_eq!(back.0[1].type_, 5); + assert_eq!(back.0[1].value, 
vec![0xaa]); + + Ok(()) + } + + #[test] + fn decode_with_len_prefix() -> Result<()> { + let payload = "1202039896800401760608000073000f2c0007"; + let stream = TlvStream::from_bytes_with_length_prefix(&hex::decode(payload).unwrap())?; + // let stream: TlvStream = serde_json::from_str(payload)?; + println!("TLV {:?}", stream.0); + + Ok(()) + } + + #[test] + fn bigsize_boundary_minimal_encodings() -> Result<()> { + // Types at 0xfc, 0xfd, 0x10000 to exercise size switches + let mut stream = TlvStream(vec![ + rec(0x00fc, &[0x11]), // single-byte bigsize + rec(0x00fd, &[0x22]), // 0xfd prefix + u16 + rec(0x0001_0000, &[0x33]), // 0xfe prefix + u32 + ]); + + let bytes = stream.to_bytes()?; // just ensure it encodes + // Decode back to confirm roundtrip/canonical encodings accepted + let back = TlvStream::from_bytes(&bytes)?; + assert_eq!(back.0[0].type_, 0x00fc); + assert_eq!(back.0[1].type_, 0x00fd); + assert_eq!(back.0[2].type_, 0x0001_0000); + Ok(()) + } + + #[test] + fn decode_rejects_non_canonical_bigsize() { + // (1) Non-canonical: 0xfd 00 fc encodes 0xfc but should be a single byte + let mut bytes = Vec::new(); + bytes.extend([0xfd, 0x00, 0xfc]); // non-canonical type + bytes.extend([0x01]); // len = 1 + bytes.extend([0x00]); // value + let err = TlvStream::from_bytes(&bytes).unwrap_err(); + assert!(format!("{}", err).contains("non-canonical")); + + // (2) Non-canonical: 0xfe 00 00 00 ff encodes 0xff but should be 0xfd-form + let mut bytes = Vec::new(); + bytes.extend([0xfe, 0x00, 0x00, 0x00, 0xff]); + bytes.extend([0x01]); + bytes.extend([0x00]); + let err = TlvStream::from_bytes(&bytes).unwrap_err(); + assert!(format!("{}", err).contains("non-canonical")); + + // (3) Non-canonical: 0xff 00..01 encodes 1, which should be single byte + let mut bytes = Vec::new(); + bytes.extend([0xff, 0, 0, 0, 0, 0, 0, 0, 1]); + bytes.extend([0x01]); + bytes.extend([0x00]); + let err = TlvStream::from_bytes(&bytes).unwrap_err(); + assert!(format!("{}", 
err).contains("non-canonical")); + } + + #[test] + fn decode_rejects_out_of_order_types() { + // Build two TLVs but put type 5 before type 1 + let mut bad = Vec::new(); + bad.extend(build_bytes(5, &[0xaa])); + bad.extend(build_bytes(1, &[0x00])); + + let err = TlvStream::from_bytes(&bad).unwrap_err(); + assert!( + format!("{}", err).contains("increasing") || format!("{}", err).contains("sorted"), + "expected ordering error, got: {err}" + ); + } + + #[test] + fn decode_rejects_duplicate_types() { + // Two records with same type=1 + let mut bad = Vec::new(); + bad.extend(build_bytes(1, &[0x01])); + bad.extend(build_bytes(1, &[0x02])); + let err = TlvStream::from_bytes(&bad).unwrap_err(); + assert!( + format!("{}", err).contains("duplicate"), + "expected duplicate error, got: {err}" + ); + } + + #[test] + fn encode_rejects_duplicate_types() { + // insert duplicate types and expect encode to fail + let mut s = TlvStream(vec![rec(1, &[0x01]), rec(1, &[0x02])]); + let err = s.to_bytes().unwrap_err(); + assert!( + format!("{}", err).contains("duplicate"), + "expected duplicate error, got: {err}" + ); + } + + #[test] + fn decode_truncated_value() { + // type=1, len=2 but only 1 byte of value provided + let mut bytes = Vec::new(); + bytes.extend(encode_bigsize(1)); + bytes.extend(encode_bigsize(2)); + bytes.push(0x00); // missing one more byte + let err = TlvStream::from_bytes(&bytes).unwrap_err(); + assert!( + format!("{}", err).contains("truncated"), + "expected truncated error, got: {err}" + ); + } + + #[test] + fn set_and_get_u64_basic() -> Result<()> { + let mut s = TlvStream::default(); + s.set_u64(42, 123456789); + assert_eq!(s.get_u64(42)?, Some(123456789)); + Ok(()) + } + + #[test] + fn set_u64_overwrite_keeps_order() -> Result<()> { + let mut s = TlvStream(vec![ + TlvRecord { + type_: 1, + value: vec![0xaa], + }, + TlvRecord { + type_: 10, + value: vec![0xbb], + }, + ]); + + // insert between 1 and 10 + s.set_u64(5, 7); + assert_eq!( + s.0.iter().map(|r| 
r.type_).collect::>(), + vec![1, 5, 10] + ); + assert_eq!(s.get_u64(5)?, Some(7)); + + // overwrite existing 5 (no duplicate, order preserved) + s.set_u64(5, 9); + let types: Vec = s.0.iter().map(|r| r.type_).collect(); + assert_eq!(types, vec![1, 5, 10]); + assert_eq!(s.0.iter().filter(|r| r.type_ == 5).count(), 1); + assert_eq!(s.get_u64(5)?, Some(9)); + Ok(()) + } + + #[test] + fn set_and_get_tu64_basic() -> Result<()> { + let mut s = TlvStream::default(); + s.set_tu64(42, 123456789); + assert_eq!(s.get_tu64(42)?, Some(123456789)); + Ok(()) + } + + #[test] + fn get_u64_missing_returns_none() -> Result<()> { + let s = TlvStream::default(); + assert_eq!(s.get_u64(999)?, None); + Ok(()) + } + + #[test] + fn set_tu64_overwrite_keeps_order() -> Result<()> { + let mut s = TlvStream(vec![ + TlvRecord { + type_: 1, + value: vec![0xaa], + }, + TlvRecord { + type_: 10, + value: vec![0xbb], + }, + ]); + + // insert between 1 and 10 + s.set_tu64(5, 7); + assert_eq!( + s.0.iter().map(|r| r.type_).collect::>(), + vec![1, 5, 10] + ); + assert_eq!(s.get_tu64(5)?, Some(7)); + + // overwrite existing 5 (no duplicate, order preserved) + s.set_tu64(5, 9); + let types: Vec = s.0.iter().map(|r| r.type_).collect(); + assert_eq!(types, vec![1, 5, 10]); + assert_eq!(s.0.iter().filter(|r| r.type_ == 5).count(), 1); + assert_eq!(s.get_tu64(5)?, Some(9)); + Ok(()) + } + + #[test] + fn tu64_zero_encodes_empty_and_roundtrips() -> Result<()> { + let mut s = TlvStream::default(); + s.set_tu64(3, 0); + + // stored value is zero-length + let rec = s.0.iter().find(|r| r.type_ == 3).unwrap(); + assert!(rec.value.is_empty()); + + // wire round-trip + let mut sc = s.clone(); + let bytes = sc.to_bytes()?; + let s2 = TlvStream::from_bytes(&bytes)?; + assert_eq!(s2.get_tu64(3)?, Some(0)); + Ok(()) + } + + #[test] + fn get_tu64_missing_returns_none() -> Result<()> { + let s = TlvStream::default(); + assert_eq!(s.get_tu64(999)?, None); + Ok(()) + } + + #[test] + fn 
get_tu64_rejects_non_minimal_and_too_long() { + // non-minimal: leading zero + let mut s = TlvStream::default(); + s.0.push(TlvRecord { + type_: 9, + value: vec![0x00, 0x01], + }); + assert!(s.get_tu64(9).is_err()); + + // too long: 9 bytes + let mut s2 = TlvStream::default(); + s2.0.push(TlvRecord { + type_: 9, + value: vec![0; 9], + }); + assert!(s2.get_tu64(9).is_err()); + } + + #[test] + fn tu64_multi_roundtrip_bytes_and_json() -> Result<()> { + let mut s = TlvStream::default(); + s.set_tu64(42, 0); + s.set_tu64(7, 256); + + // wire roundtrip + let mut sc = s.clone(); + let bytes = sc.to_bytes()?; + let s2 = TlvStream::from_bytes(&bytes)?; + assert_eq!(s2.get_tu64(42)?, Some(0)); + assert_eq!(s2.get_tu64(7)?, Some(256)); + + // json hex roundtrip (custom Serialize/Deserialize) + let json = serde_json::to_string(&s)?; + let s3: TlvStream = serde_json::from_str(&json)?; + assert_eq!(s3.get_tu64(42)?, Some(0)); + assert_eq!(s3.get_tu64(7)?, Some(256)); + Ok(()) + } +} diff --git a/plugins/lsps-plugin/src/core/transport.rs b/plugins/lsps-plugin/src/core/transport.rs new file mode 100644 index 000000000000..e66d1652c79f --- /dev/null +++ b/plugins/lsps-plugin/src/core/transport.rs @@ -0,0 +1,283 @@ +use crate::proto::jsonrpc::{JsonRpcResponse, RequestObject}; +use async_trait::async_trait; +use bitcoin::secp256k1::PublicKey; +use core::fmt::Debug; +use serde::{de::DeserializeOwned, Serialize}; +use std::{collections::HashMap, sync::Arc, time::Duration}; +use thiserror::Error; +use tokio::sync::{oneshot, Mutex}; + +/// Transport-specific errors that may occur when sending or receiving JSON-RPC +/// messages. 
+#[derive(Error, Debug)] +pub enum Error { + #[error("Timeout")] + Timeout, + #[error("Internal error: {0}")] + Internal(String), + #[error("Couldn't parse JSON-RPC request")] + ParseRequest { + #[source] + source: serde_json::Error, + }, + #[error("request is missing id")] + MissingId, +} + +impl From for Error { + fn from(value: serde_json::Error) -> Self { + Self::ParseRequest { source: value } + } +} + +pub type Result = std::result::Result; + +/// Defines the interface for transporting JSON-RPC messages. +/// +/// Implementors of this trait are responsible for actually sending the JSON-RPC +/// request over some transport mechanism (RPC, Bolt8, etc.) +#[async_trait] +pub trait Transport: Send + Sync { + async fn request( + &self, + peer_id: &PublicKey, + request: &RequestObject

, + ) -> Result> + where + P: Serialize + Send + Sync, + R: DeserializeOwned + Send; +} + +#[async_trait] +pub trait MessageSender: Send + Sync + Clone + 'static { + async fn send(&self, peer: &PublicKey, payload: &[u8]) -> Result<()>; +} + +#[derive(Clone, Default)] +pub struct PendingRequests { + inner: Arc>>>>, +} + +impl PendingRequests { + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub async fn insert(&self, id: String) -> oneshot::Receiver> { + let (tx, rx) = oneshot::channel(); + self.inner.lock().await.insert(id, tx); + rx + } + + pub async fn complete(&self, id: &str, data: Vec) -> bool { + if let Some(tx) = self.inner.lock().await.remove(id) { + tx.send(data).is_ok() + } else { + false + } + } + + pub async fn remove(&self, id: &str) { + self.inner.lock().await.remove(id); + } +} + +#[derive(Clone)] +pub struct MultiplexedTransport { + sender: S, + pending: PendingRequests, + timeout: Duration, +} + +impl MultiplexedTransport { + pub fn new(sender: S, pending: PendingRequests, timeout: Duration) -> Self { + Self { + sender, + pending, + timeout, + } + } + + pub fn pending(&self) -> &PendingRequests { + &self.pending + } + + pub fn sender(&self) -> &S { + &self.sender + } +} + +#[async_trait] +impl Transport for MultiplexedTransport { + async fn request( + &self, + peer_id: &PublicKey, + request: &RequestObject

, + ) -> Result> + where + P: Serialize + Send + Sync, + R: DeserializeOwned + Send, + { + let id = request.id.as_ref().ok_or(Error::MissingId)?; + let payload = serde_json::to_vec(request)?; + + // Register pending before sending + let rx = self.pending().insert(id.clone()).await; + + // Send via backend + if let Err(e) = self.sender.send(peer_id, &payload).await { + self.pending.remove(id).await; + return Err(e); + }; + + let response_bytes = tokio::time::timeout(self.timeout, rx) + .await + .map_err(|_| { + let pending = self.pending.clone(); + let id = id.clone(); + tokio::spawn(async move { pending.remove(&id).await }); + Error::Timeout + })? + .map_err(|_| Error::Internal("channel closed unexpectedly".into()))?; + + Ok(serde_json::from_slice(&response_bytes)?) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + + fn test_peer() -> PublicKey { + "0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" + .parse() + .unwrap() + } + + // Mock sender that captures calls + #[derive(Clone, Default)] + struct MockSender { + call_count: Arc, + should_fail: bool, + } + + #[async_trait] + impl MessageSender for MockSender { + async fn send(&self, _peer: &PublicKey, _payload: &[u8]) -> Result<()> { + self.call_count.fetch_add(1, Ordering::SeqCst); + if self.should_fail { + Err(Error::Internal("mock failure".into())) + } else { + Ok(()) + } + } + } + + #[tokio::test] + async fn test_pending_requests_roundtrip() { + let pending = PendingRequests::new(); + + let rx = pending.insert("test-id".to_string()).await; + + // Simulate response arriving + let completed = pending.complete("test-id", b"response".to_vec()).await; + assert!(completed); + + let result = rx.await.unwrap(); + assert_eq!(result, b"response"); + } + + #[tokio::test] + async fn test_pending_requests_unknown_id() { + let pending = PendingRequests::new(); + + let completed = pending.complete("unknown", b"data".to_vec()).await; + 
assert!(!completed); + } + + #[tokio::test] + async fn test_pending_requests_remove() { + let pending = PendingRequests::new(); + + let _rx = pending.insert("test-id".to_string()).await; + pending.remove("test-id").await; + + let completed = pending.complete("test-id", b"data".to_vec()).await; + assert!(!completed); + } + + #[tokio::test] + async fn test_transport_sends_via_sender() { + let sender = MockSender::default(); + let call_count = sender.call_count.clone(); + + let pending = PendingRequests::new(); + let transport = MultiplexedTransport::new(sender, pending, Duration::from_secs(1)); + + // Start request (will timeout since no response) + let request = RequestObject { + jsonrpc: "2.0".into(), + method: "test".into(), + params: Some(serde_json::json!({})), + id: Some("1".into()), + }; + + let result: Result> = + transport.request(&test_peer(), &request).await; + + // Should have sent + assert_eq!(call_count.load(Ordering::SeqCst), 1); + // Should timeout (no response) + assert!(matches!(result, Err(Error::Timeout))); + } + + #[tokio::test] + async fn test_transport_send_failure_cleans_up() { + let sender = MockSender { + should_fail: true, + ..Default::default() + }; + + let pending = PendingRequests::new(); + let transport = MultiplexedTransport::new(sender, pending, Duration::from_secs(1)); + + let request = RequestObject { + jsonrpc: "2.0".into(), + method: "test".into(), + params: Some(serde_json::json!({})), + id: Some("1".into()), + }; + + let result: Result> = + transport.request(&test_peer(), &request).await; + + assert!(matches!(result, Err(Error::Internal(_)))); + + // Pending should be cleaned up + let completed = transport.pending().complete("1", b"data".to_vec()).await; + assert!(!completed); + } + + #[tokio::test] + async fn test_transport_missing_id() { + let sender = MockSender::default(); + + let pending = PendingRequests::new(); + let transport = MultiplexedTransport::new(sender, pending, Duration::from_secs(1)); + + let request = 
RequestObject::<()> { + jsonrpc: "2.0".into(), + method: "test".into(), + params: None, + id: None, // Missing! + }; + + let result: Result> = + transport.request(&test_peer(), &request).await; + + assert!(matches!(result, Err(Error::MissingId))); + } +} diff --git a/plugins/lsps-plugin/src/jsonrpc/client.rs b/plugins/lsps-plugin/src/jsonrpc/client.rs deleted file mode 100644 index 5e0fe167d66c..000000000000 --- a/plugins/lsps-plugin/src/jsonrpc/client.rs +++ /dev/null @@ -1,351 +0,0 @@ -use async_trait::async_trait; -use core::fmt::Debug; -use log::{debug, error}; -use rand::rngs::OsRng; -use rand::TryRngCore; -use serde::{de::DeserializeOwned, Serialize}; -use serde_json::Value; -use std::sync::Arc; - -use crate::jsonrpc::{ - Error, JsonRpcRequest, JsonRpcResponse, RequestObject, ResponseObject, Result, -}; - -/// Defines the interface for transporting JSON-RPC messages. -/// -/// Implementors of this trait are responsible for actually sending the JSON-RPC -/// request over some transport mechanism (RPC, Bolt8, etc.) -#[async_trait] -pub trait Transport { - async fn send(&self, request: String) -> core::result::Result; - async fn notify(&self, request: String) -> core::result::Result<(), Error>; -} - -/// A typed JSON-RPC client that works with any transport implementation. -/// -/// This client handles the JSON-RPC protocol details including message -/// formatting, request ID generation, and response parsing. -#[derive(Clone)] -pub struct JsonRpcClient { - transport: Arc, -} - -impl JsonRpcClient { - pub fn new(transport: T) -> Self { - Self { - transport: Arc::new(transport), - } - } - - /// Makes a JSON-RPC method call with raw JSON parameters and returns a raw - /// JSON result. 
- pub async fn call_raw(&self, method: &str, params: Option) -> Result { - let id = generate_random_id(); - - debug!("Preparing request: method={}, id={}", method, id); - let request = RequestObject { - jsonrpc: "2.0".into(), - method: method.into(), - params, - id: Some(id.clone().into()), - }; - let res_obj = self.send_request(method, &request, id).await?; - Value::from_response(res_obj) - } - - /// Makes a typed JSON-RPC method call with a request object and returns a - /// typed response. - /// - /// This method provides type safety by using request and response types - /// that implement the necessary traits. - pub async fn call_typed(&self, request: RQ) -> Result - where - RQ: JsonRpcRequest + Serialize + Send + Sync, - RS: DeserializeOwned + Serialize + Debug + Send + Sync, - { - let method = RQ::METHOD; - let id = generate_random_id(); - - debug!("Preparing request: method={}, id={}", method, id); - let request = request.into_request(Some(id.clone().into())); - let res_obj = self.send_request(method, &request, id).await?; - RS::from_response(res_obj) - } - - /// Sends a notification with raw JSON parameters (no response expected). - pub async fn notify_raw(&self, method: &str, params: Option) -> Result<()> { - debug!("Preparing notification: method={}", method); - let request = RequestObject { - jsonrpc: "2.0".into(), - method: method.into(), - params, - id: None, - }; - Ok(self.send_notification(method, &request).await?) - } - - /// Sends a typed notification (no response expected). - pub async fn notify_typed(&self, request: RQ) -> Result<()> - where - RQ: JsonRpcRequest + Serialize + Send + Sync, - { - let method = RQ::METHOD; - - debug!("Preparing notification: method={}", method); - let request = request.into_request(None); - Ok(self.send_notification(method, &request).await?) 
- } - - async fn send_request( - &self, - method: &str, - payload: &RP, - id: String, - ) -> Result> - where - RP: Serialize + Send + Sync, - RS: DeserializeOwned + Serialize + Debug + Send + Sync, - { - let request_json = serde_json::to_string(&payload)?; - debug!( - "Sending request: method={}, id={}, request={:?}", - method, id, &request_json - ); - let start = tokio::time::Instant::now(); - let res_str = self.transport.send(request_json).await?; - let elapsed = start.elapsed(); - debug!( - "Received response: method={}, id={}, response={}, elapsed={}ms", - method, - id, - &res_str, - elapsed.as_millis() - ); - Ok(serde_json::from_str(&res_str)?) - } - - async fn send_notification(&self, method: &str, payload: &RP) -> Result<()> - where - RP: Serialize + Send + Sync, - { - let request_json = serde_json::to_string(&payload)?; - debug!("Sending notification: method={}", method); - let start = tokio::time::Instant::now(); - self.transport.notify(request_json).await?; - let elapsed = start.elapsed(); - debug!( - "Sent notification: method={}, elapsed={}ms", - method, - elapsed.as_millis() - ); - Ok(()) - } -} - -/// Generates a random ID for JSON-RPC requests. -/// -/// Uses a secure random number generator to create a hex-encoded ID. Falls back -/// to a timestamp-based ID if random generation fails. 
-fn generate_random_id() -> String { - let mut bytes = [0u8; 10]; - match OsRng.try_fill_bytes(&mut bytes) { - Ok(_) => hex::encode(bytes), - Err(e) => { - // Fallback to a timestamp-based ID if random generation fails - error!( - "Failed to generate random ID: {}, falling back to timestamp", - e - ); - let timestamp = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_nanos(); - format!("fallback-{}", timestamp) - } - } -} - -#[cfg(test)] - -mod test_json_rpc { - use serde::Deserialize; - use tokio::sync::OnceCell; - - use super::*; - use crate::jsonrpc::{self, RpcError}; - - #[derive(Clone)] - struct TestTransport { - req: Arc>, - res: Arc>, - err: Arc>, - } - - impl TestTransport { - // Get the last request as parsed JSON - fn last_request_json(&self) -> Option { - self.req - .get() - .and_then(|req_str| serde_json::from_str(req_str).ok()) - } - } - - #[async_trait] - impl Transport for TestTransport { - async fn send(&self, req: String) -> core::result::Result { - // Store the request - let _ = self.req.set(req); - - // Check for error first - if let Some(err) = &*self.err { - return Err(Error::Transport(jsonrpc::TransportError::Other(err.into()))); - } - - // Then check for response - if let Some(res) = &*self.res { - return Ok(res.clone()); - } - - panic!("TestTransport: neither result nor error is set."); - } - - async fn notify(&self, req: String) -> core::result::Result<(), Error> { - // Store the request - let _ = self.req.set(req); - - // Check for error - if let Some(err) = &*self.err { - return Err(Error::Transport(jsonrpc::TransportError::Other(err.into()))); - } - - Ok(()) - } - } - - #[derive(Default, Clone, Serialize, Deserialize, Debug)] - struct DummyCall { - foo: String, - bar: i32, - } - - impl JsonRpcRequest for DummyCall { - const METHOD: &'static str = "dummy_call"; - } - - #[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)] - struct DummyResponse { - foo: String, - bar: i32, - 
} - - #[tokio::test] - async fn test_typed_call_w_response() { - let req = DummyCall { - foo: String::from("hello world!"), - bar: 13, - }; - - let expected_res = DummyResponse { - foo: String::from("hello client!"), - bar: 10, - }; - - let res_obj = expected_res - .clone() - .into_response(String::from("unique-id-123")); - let res_str = serde_json::to_string(&res_obj).unwrap(); - - let transport = TestTransport { - req: Arc::new(OnceCell::const_new()), - res: Arc::new(Some(res_str)), - err: Arc::new(None), - }; - - let client_1 = JsonRpcClient::new(transport.clone()); - let res = client_1 - .call_typed::<_, DummyResponse>(req.clone()) - .await - .expect("Should have an OK result"); - assert_eq!(res, expected_res); - let transport_req = transport - .last_request_json() - .expect("Transport should have gotten a request"); - assert_eq!( - transport_req - .get("jsonrpc") - .and_then(|v| v.as_str()) - .unwrap(), - "2.0" - ); - assert_eq!( - transport_req - .get("params") - .and_then(|v| v.as_object()) - .unwrap(), - serde_json::to_value(&req).unwrap().as_object().unwrap() - ); - } - - #[tokio::test] - async fn test_typed_call_w_rpc_error() { - let req = DummyCall { - foo: "hello world!".into(), - bar: 13, - }; - - let err_res = RpcError::custom_error_with_data( - -32099, - "got a custom error", - serde_json::json!({"got": "some"}), - ); - - let res_obj = err_res.clone().into_response("unique-id-123".into()); - let res_str = serde_json::to_string(&res_obj).unwrap(); - - let transport = TestTransport { - req: Arc::new(OnceCell::const_new()), - res: Arc::new(Some(res_str)), - err: Arc::new(None), - }; - - let client_1 = JsonRpcClient::new(transport); - let res = client_1 - .call_typed::<_, DummyResponse>(req) - .await - .expect_err("Expected error response"); - assert!(match res { - Error::Rpc(rpc_error) => { - assert_eq!(rpc_error, err_res); - true - } - _ => false, - }); - } - - #[tokio::test] - async fn test_typed_call_w_transport_error() { - let req = DummyCall { - 
foo: "hello world!".into(), - bar: 13, - }; - - let transport = TestTransport { - req: Arc::new(OnceCell::const_new()), - res: Arc::new(None), - err: Arc::new(Some(String::from("transport error"))), - }; - - let client_1 = JsonRpcClient::new(transport); - let res = client_1 - .call_typed::<_, DummyResponse>(req) - .await - .expect_err("Expected error response"); - assert!(match res { - Error::Transport(err) => { - assert_eq!(err.to_string(), "Other error: transport error"); - true - } - _ => false, - }); - } -} diff --git a/plugins/lsps-plugin/src/jsonrpc/mod.rs b/plugins/lsps-plugin/src/jsonrpc/mod.rs deleted file mode 100644 index b7c871ee0932..000000000000 --- a/plugins/lsps-plugin/src/jsonrpc/mod.rs +++ /dev/null @@ -1,509 +0,0 @@ -pub mod client; -use log::debug; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::{self, Value}; -pub mod server; -use std::fmt; -use thiserror::Error; - -// Constants for JSON-RPC error codes -const PARSE_ERROR: i64 = -32700; -const INVALID_REQUEST: i64 = -32600; -const METHOD_NOT_FOUND: i64 = -32601; -const INVALID_PARAMS: i64 = -32602; -const INTERNAL_ERROR: i64 = -32603; - -/// Error type for JSON-RPC related operations. -/// -/// Encapsulates various error conditions that may occur during JSON-RPC -/// operations, including serialization errors, transport issues, and -/// protocol-specific errors. -#[derive(Error, Debug)] -pub enum Error { - #[error("JSON error: {0}")] - Json(#[from] serde_json::Error), - #[error("RPC error: {0}")] - Rpc(#[from] RpcError), - #[error("Transport error: {0}")] - Transport(#[from] TransportError), - #[error("Other error: {0}")] - Other(String), -} - -impl Error { - pub fn other(v: T) -> Self { - return Self::Other(v.to_string()); - } -} - -/// Transport-specific errors that may occur when sending or receiving JSON-RPC -/// messages. 
-#[derive(Error, Debug)] -pub enum TransportError { - #[error("Timeout")] - Timeout, - #[error("Other error: {0}")] - Other(String), -} - -/// Convenience type alias for Result with the JSON-RPC Error type. -pub type Result = std::result::Result; - -/// Trait for types that can be converted into JSON-RPC request objects. -/// -/// Implementing this trait allows a struct to be used as a typed JSON-RPC -/// request, with an associated method name and automatic conversion to the -/// request format. -pub trait JsonRpcRequest: Serialize { - const METHOD: &'static str; - fn into_request(self, id: Option) -> RequestObject - where - Self: Sized, - { - RequestObject { - jsonrpc: "2.0".into(), - method: Self::METHOD.into(), - params: Some(self), - id, - } - } -} - -/// Trait for types that can be converted from JSON-RPC response objects. -/// -/// This trait provides methods for converting between typed response objects -/// and JSON-RPC protocol response envelopes. -pub trait JsonRpcResponse -where - T: DeserializeOwned, -{ - fn into_response(self, id: String) -> ResponseObject - where - Self: Sized + DeserializeOwned, - { - ResponseObject { - jsonrpc: "2.0".into(), - id: id.into(), - result: Some(self), - error: None, - } - } - - fn from_response(resp: ResponseObject) -> Result - where - T: core::fmt::Debug, - { - match (resp.result, resp.error) { - (Some(result), None) => Ok(result), - (None, Some(error)) => Err(Error::Rpc(error)), - _ => { - debug!( - "Invalid JSON-RPC response - missing both result and error fields, or both set: id={}", - resp.id - ); - Err(Error::Rpc(RpcError::internal_error( - "not a valid json respone", - ))) - } - } - } -} - -/// Automatically implements the `JsonRpcResponse` trait for all types that -/// implement `DeserializeOwned`. This simplifies creating JSON-RPC services, -/// as you only need to define data structures that can be deserialized. 
-impl JsonRpcResponse for T where T: DeserializeOwned {} - -/// # RequestObject -/// -/// Represents a JSON-RPC 2.0 Request object, as defined in section 4 of the -/// specification. This structure encapsulates all necessary information for -/// a remote procedure call. -/// -/// # Type Parameters -/// -/// * `T`: The type of the `params` field. This *MUST* implement `Serialize` -/// to allow it to be encoded as JSON. Typically this will be a struct -/// implementing the `JsonRpcRequest` trait. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct RequestObject -where - T: Serialize, -{ - /// **REQUIRED**. MUST be `"2.0"`. - pub jsonrpc: String, - /// **REQUIRED**. The method to be invoked. - pub method: String, - /// A struct containing the method parameters. - #[serde(skip_serializing_if = "is_none_or_null")] - pub params: Option, - /// An identifier established by the Client that MUST contain a String. - /// # Note: this is special to LSPS0, might change to match the more general - /// JSON-RPC 2.0 sepec if needed. - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, -} - -impl RequestObject -where - T: Serialize, -{ - /// Returns the inner data object contained by params for handling or future - /// processing. - pub fn into_inner(self) -> Option { - self.params - } -} - -/// Helper function to check if params is None or would serialize to null. -fn is_none_or_null(opt: &Option) -> bool { - match opt { - None => true, - Some(val) => match serde_json::to_value(&val) { - Ok(Value::Null) => true, - _ => false, - }, - } -} - -/// # ResponseObject -/// -/// Represents a JSON-RPC 2.0 Response object, as defined in section 5.0 of the -/// specification. This structure encapsulates either a successful result or -/// an error. -/// -/// # Type Parameters -/// -/// * `T`: The type of the `result` field, which will be returned upon a -/// succesful execution of the procedure. 
*MUST* implement both `Serialize` -/// (to allow construction of responses) and `DeserializeOwned` (to allow -/// receipt and parsing of responses). -#[derive(Debug, Serialize, Deserialize)] -#[serde(bound = "T: Serialize + DeserializeOwned")] -pub struct ResponseObject -where - T: DeserializeOwned, -{ - /// **REQUIRED**. MUST be `"2.0"`. - jsonrpc: String, - /// **REQUIRED**. The identifier of the original request this is a response. - id: String, - /// **REQUIRED on success**. The data if there is a request and non-errored. - /// MUST NOT exist if there was an error triggered during invocation. - #[serde(skip_serializing_if = "Option::is_none")] - result: Option, - /// **REQUIRED on error** An error type if there was a failure. - error: Option, -} - -impl ResponseObject -where - T: DeserializeOwned + Serialize + core::fmt::Debug, -{ - /// Returns a potential data (result) if the code execution passed else it - /// returns with RPC error, data (error details) if there was - pub fn into_inner(self) -> Result { - T::from_response(self) - } -} - -/// # RpcError -/// -/// Represents an error object in a JSON-RPC 2.0 Response object (section 5.1). -/// Provides structured information about an error that occurred during the -/// method invocation. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct RpcError { - /// **REQUIRED**. An integer indicating the type of error. - pub code: i64, - /// **REQUIRED**. A string containing a short description of the error. - pub message: String, - /// A primitive that can be either Primitive or Structured type if there - /// were. - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -impl RpcError { - pub fn into_response(self, id: String) -> ResponseObject { - ResponseObject { - jsonrpc: "2.0".into(), - id: id.into(), - result: None, - error: Some(self), - } - } -} - -impl RpcError { - /// Reserved for implementation-defined server-errors. 
- pub fn custom_error(code: i64, message: T) -> Self { - RpcError { - code, - message: message.to_string(), - data: None, - } - } - - /// Reserved for implementation-defined server-errors. - pub fn custom_error_with_data( - code: i64, - message: T, - data: serde_json::Value, - ) -> Self { - RpcError { - code, - message: message.to_string(), - data: Some(data), - } - } - - /// Invalid JSON was received by the server. - /// An error occurred on the server while parsing the JSON text. - pub fn parse_error(message: T) -> Self { - Self::custom_error(PARSE_ERROR, message) - } - - /// Invalid JSON was received by the server. - /// An error occurred on the server while parsing the JSON text. - pub fn parse_error_with_data( - message: T, - data: serde_json::Value, - ) -> Self { - Self::custom_error_with_data(PARSE_ERROR, message, data) - } - - /// The JSON sent is not a valid Request object. - pub fn invalid_request(message: T) -> Self { - Self::custom_error(INVALID_REQUEST, message) - } - - /// The JSON sent is not a valid Request object. - pub fn invalid_request_with_data( - message: T, - data: serde_json::Value, - ) -> Self { - Self::custom_error_with_data(INVALID_REQUEST, message, data) - } - - /// The method does not exist / is not available. - pub fn method_not_found(message: T) -> Self { - Self::custom_error(METHOD_NOT_FOUND, message) - } - - /// The method does not exist / is not available. - pub fn method_not_found_with_data( - message: T, - data: serde_json::Value, - ) -> Self { - Self::custom_error_with_data(METHOD_NOT_FOUND, message, data) - } - - /// Invalid method parameter(s). - pub fn invalid_params(message: T) -> Self { - Self::custom_error(INVALID_PARAMS, message) - } - - /// Invalid method parameter(s). - pub fn invalid_params_with_data( - message: T, - data: serde_json::Value, - ) -> Self { - Self::custom_error_with_data(INVALID_PARAMS, message, data) - } - - /// Internal JSON-RPC error. 
- pub fn internal_error(message: T) -> Self { - Self::custom_error(INTERNAL_ERROR, message) - } - - /// Internal JSON-RPC error. - pub fn internal_error_with_data( - message: T, - data: serde_json::Value, - ) -> Self { - Self::custom_error_with_data(INTERNAL_ERROR, message, data) - } -} - -impl fmt::Display for RpcError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "JSON-RPC Error (code: {}, message: {}, data: {:?})", - self.code, self.message, self.data - ) - } -} - -impl std::error::Error for RpcError {} - -#[cfg(test)] -mod test_message_serialization { - use super::*; - use serde_json::json; - - #[test] - fn test_empty_params_serialization() { - // Empty params should serialize to `"params":{}` instead of - // `"params":null`. - #[derive(Debug, Serialize, Deserialize)] - pub struct SayHelloRequest; - impl JsonRpcRequest for SayHelloRequest { - const METHOD: &'static str = "say_hello"; - } - let rpc_request = SayHelloRequest.into_request(Some("unique-id-123".into())); - assert!(!serde_json::to_string(&rpc_request) - .expect("could not convert to json") - .contains("\"params\"")); - } - - #[test] - fn test_request_serialization_and_deserialization() { - // Ensure that we correctly serialize to a valid JSON-RPC 2.0 request. 
- #[derive(Default, Debug, Serialize, Deserialize)] - pub struct SayNameRequest { - name: String, - age: i32, - } - impl JsonRpcRequest for SayNameRequest { - const METHOD: &'static str = "say_name"; - } - let rpc_request = SayNameRequest { - name: "Satoshi".to_string(), - age: 99, - } - .into_request(Some("unique-id-123".into())); - - let json_value: serde_json::Value = serde_json::to_value(&rpc_request).unwrap(); - let expected_value: serde_json::Value = serde_json::json!({ - "jsonrpc": "2.0", - "method": "say_name", - "params": { - "name": "Satoshi", - "age": 99 - }, - "id": "unique-id-123" - }); - assert_eq!(json_value, expected_value); - - let request: RequestObject = serde_json::from_value(json_value).unwrap(); - assert_eq!(request.method, "say_name"); - assert_eq!(request.jsonrpc, "2.0"); - - let request: RequestObject = - serde_json::from_value(expected_value).unwrap(); - let inner = request.into_inner(); - assert_eq!(inner.unwrap().name, rpc_request.params.unwrap().name); - } - - #[test] - fn test_response_deserialization() { - // Check that we can convert a JSON-RPC response into a typed result. - #[derive(Debug, Serialize, Deserialize, PartialEq)] - pub struct SayNameResponse { - name: String, - age: i32, - message: String, - } - - let json_response = r#" - { - "jsonrpc": "2.0", - "result": { - "age": 99, - "message": "Hello Satoshi!", - "name": "Satoshi" - }, - "id": "unique-id-123" - }"#; - - let response_object: ResponseObject = - serde_json::from_str(json_response).unwrap(); - - let response: SayNameResponse = response_object.into_inner().unwrap(); - let expected_response = SayNameResponse { - name: "Satoshi".into(), - age: 99, - message: "Hello Satoshi!".into(), - }; - - assert_eq!(response, expected_response); - } - - #[test] - fn test_empty_result() { - // Check that we correctly deserialize an empty result. 
- #[derive(Debug, Serialize, Deserialize, PartialEq)] - pub struct DummyResponse {} - - let json_response = r#" - { - "jsonrpc": "2.0", - "result": {}, - "id": "unique-id-123" - }"#; - - let response_object: ResponseObject = - serde_json::from_str(json_response).unwrap(); - - let response: DummyResponse = response_object.into_inner().unwrap(); - let expected_response = DummyResponse {}; - - assert_eq!(response, expected_response); - } - #[test] - fn test_error_deserialization() { - // Check that we deserialize an error if we got one. - #[derive(Debug, Serialize, Deserialize, PartialEq)] - pub struct DummyResponse {} - - let json_response = r#" - { - "jsonrpc": "2.0", - "id": "unique-id-123", - "error": { - "code": -32099, - "message": "something bad happened", - "data": { - "f1": "v1", - "f2": 2 - } - } - }"#; - - let response_object: ResponseObject = - serde_json::from_str(json_response).unwrap(); - - let response = response_object.into_inner(); - let err = response.unwrap_err(); - match err { - Error::Rpc(err) => { - assert_eq!(err.code, -32099); - assert_eq!(err.message, "something bad happened"); - assert_eq!( - err.data, - serde_json::from_str("{\"f1\":\"v1\",\"f2\":2}").unwrap() - ); - } - _ => assert!(false), - } - } - - #[test] - fn test_error_serialization() { - let error = RpcError::invalid_request("Invalid request"); - let serialized = serde_json::to_string(&error).unwrap(); - assert_eq!(serialized, r#"{"code":-32600,"message":"Invalid request"}"#); - - let error_with_data = RpcError::internal_error_with_data( - "Internal server error", - json!({"details": "Something went wrong"}), - ); - let serialized_with_data = serde_json::to_string(&error_with_data).unwrap(); - assert_eq!( - serialized_with_data, - r#"{"code":-32603,"message":"Internal server error","data":{"details":"Something went wrong"}}"# - ); - } -} diff --git a/plugins/lsps-plugin/src/jsonrpc/server.rs b/plugins/lsps-plugin/src/jsonrpc/server.rs deleted file mode 100644 index 
f9e3334b0059..000000000000 --- a/plugins/lsps-plugin/src/jsonrpc/server.rs +++ /dev/null @@ -1,302 +0,0 @@ -use crate::jsonrpc::{Result, RpcError}; -use async_trait::async_trait; -use log::{debug, trace}; -use std::{collections::HashMap, sync::Arc}; - -/// Responsible for writing JSON-RPC responses back to clients. -/// -/// This trait abstracts the mechanism for sending responses back to the client, -/// allowing handlers to remain transport-agnostic. Implementations of this -/// trait handle the actual transmission of response data over the underlying -/// transport. -#[async_trait] -pub trait JsonRpcResponseWriter: Send + 'static { - /// Writes the provided payload as a response. - async fn write(&mut self, payload: &[u8]) -> Result<()>; -} - -/// Processes JSON-RPC requests and produces responses. -/// -/// This trait defines the interface for handling specific JSON-RPC methods. -/// Each method supported by the server should have a corresponding handler -/// that implements this trait. -#[async_trait] -pub trait RequestHandler: Send + Sync + 'static { - /// Handles a JSON-RPC request. - async fn handle(&self, payload: &[u8]) -> core::result::Result, RpcError>; -} - -/// Builder for creating JSON-RPC servers. -pub struct JsonRpcServerBuilder { - handlers: HashMap>, -} - -impl JsonRpcServerBuilder { - pub fn new() -> Self { - Self { - handlers: HashMap::new(), - } - } - - /// Registers a handler for a specific JSON-RPC method. - pub fn with_handler(mut self, method: String, handler: Arc) -> Self { - self.handlers.insert(method, handler); - self - } - - /// Builds a JSON-RPC server with the configured handlers. - pub fn build(self) -> JsonRpcServer { - JsonRpcServer { - handlers: Arc::new(self.handlers), - } - } -} - -/// Server for handling JSON-RPC 2.0 requests. -/// -/// Dispatches incoming JSON-RPC requests to the appropriate handlers based on -/// the method name, and manages the response lifecycle. 
-#[derive(Clone)] -pub struct JsonRpcServer { - handlers: Arc>>, -} - -impl JsonRpcServer { - pub fn builder() -> JsonRpcServerBuilder { - JsonRpcServerBuilder::new() - } - - // Processes a JSON-RPC message and writes the response. - /// - /// This is the main entry point for handling JSON-RPC requests. It: - /// 1. Parses and validates the incoming request - /// 2. Routes the request to the appropriate handler - /// 3. Writes the response back to the client (if needed) - pub async fn handle_message( - &self, - payload: &[u8], - writer: &mut dyn JsonRpcResponseWriter, - ) -> Result<()> { - trace!("Handle request with payload: {:?}", payload); - let value: serde_json::Value = serde_json::from_slice(payload)?; - let id = value.get("id").and_then(|id| id.as_str()); - let method = value.get("method").and_then(|method| method.as_str()); - let jsonrpc = value.get("jsonrpc").and_then(|jrpc| jrpc.as_str()); - - trace!( - "Validate request: id={:?}, method={:?}, jsonrpc={:?}", - id, - method, - jsonrpc - ); - let method = match (jsonrpc, method) { - (Some(jrpc), Some(method)) if jrpc == "2.0" => method, - (_, _) => { - debug!("Got invalid request {}", value); - let err = RpcError { - code: -32600, - message: "Invalid request".into(), - data: None, - }; - return self.maybe_write_error(id, err, writer).await; - } - }; - - trace!("Get handler for id={:?}, method={:?}", id, method); - if let Some(handler) = self.handlers.get(method) { - trace!( - "Call handler for id={:?}, method={:?}, with payload={:?}", - id, - method, - payload - ); - match handler.handle(payload).await { - Ok(res) => return self.maybe_write(id, &res, writer).await, - Err(e) => { - debug!("Handler returned with error: {}", e); - return self.maybe_write_error(id, e, writer).await; - } - }; - } else { - debug!("No handler found for method: {}", method); - let err = RpcError { - code: -32601, - message: "Method not found".into(), - data: None, - }; - return self.maybe_write_error(id, err, writer).await; - } - } 
- - /// Writes a response if the request has an ID. - /// - /// For notifications (requests without an ID), no response is written. - async fn maybe_write( - &self, - id: Option<&str>, - payload: &[u8], - writer: &mut dyn JsonRpcResponseWriter, - ) -> Result<()> { - // No need to respond when we don't have an id - it's a notification - if id.is_some() { - return writer.write(payload).await; - } - Ok(()) - } - - /// Writes an error response if the request has an ID. - /// - /// For notifications (requests without an ID), no response is written. - async fn maybe_write_error( - &self, - id: Option<&str>, - err: RpcError, - writer: &mut dyn JsonRpcResponseWriter, - ) -> Result<()> { - // No need to respond when we don't have an id - it's a notification - if let Some(id) = id { - let err_res = err.clone().into_response(id.into()); - let err_vec = serde_json::to_vec(&err_res).map_err(|e| RpcError::internal_error(e))?; - return writer.write(&err_vec).await; - } - Ok(()) - } -} - -#[cfg(test)] -mod test_json_rpc_server { - use super::*; - - #[derive(Default)] - struct MockWriter { - log_content: String, - } - - #[async_trait] - impl JsonRpcResponseWriter for MockWriter { - async fn write(&mut self, payload: &[u8]) -> Result<()> { - println!("Write payload={:?}", &payload); - let byte_str = String::from_utf8(payload.to_vec()).unwrap(); - self.log_content = byte_str; - Ok(()) - } - } - - // Echo handler - pub struct Echo; - - #[async_trait] - impl RequestHandler for Echo { - async fn handle(&self, payload: &[u8]) -> core::result::Result, RpcError> { - println!("Called handler with payload: {:?}", &payload); - Ok(payload.to_vec()) - } - } - - #[tokio::test] - async fn test_notification() { - // A notification should not respond to the client so there is no need - // to write payload to the writer; - let server = JsonRpcServer::builder() - .with_handler("echo".to_string(), Arc::new(Echo)) - .build(); - - let mut writer = MockWriter { - log_content: String::default(), - }; - - 
let msg = r#"{"jsonrpc":"2.0","method":"echo","params":{"age":99,"name":"Satoshi"}}"#; // No id signals a notification. - let res = server.handle_message(msg.as_bytes(), &mut writer).await; - assert!(res.is_ok()); - assert!(writer.log_content.is_empty()); // Was a notification we don't expect a response; - } - - #[tokio::test] - async fn missing_method_field() { - // We verify the request data, check that we return an error when we - // don't understand the request. - let server = JsonRpcServer::builder() - .with_handler("echo".to_string(), Arc::new(Echo)) - .build(); - - let mut writer = MockWriter { - log_content: String::default(), - }; - - let msg = r#"{"jsonrpc":"2.0","params":{"age":99,"name":"Satoshi"},"id":"unique-id-123"}"#; - let res = server.handle_message(msg.as_bytes(), &mut writer).await; - assert!(res.is_ok()); - let expected = r#"{"jsonrpc":"2.0","id":"unique-id-123","error":{"code":-32600,"message":"Invalid request"}}"#; // Unknown method say_hello - assert_eq!(writer.log_content, expected); - } - - #[tokio::test] - async fn wrong_version() { - // We only accept requests that have jsonrpc version 2.0. - let server = JsonRpcServer::builder() - .with_handler("echo".to_string(), Arc::new(Echo)) - .build(); - - let mut writer = MockWriter { - log_content: String::default(), - }; - - let msg = r#"{"jsonrpc":"1.0","method":"echo","params":{"age":99,"name":"Satoshi"},"id":"unique-id-123"}"#; - let res = server.handle_message(msg.as_bytes(), &mut writer).await; - assert!(res.is_ok()); - let expected = r#"{"jsonrpc":"2.0","id":"unique-id-123","error":{"code":-32600,"message":"Invalid request"}}"#; // Unknown method say_hello - assert_eq!(writer.log_content, expected); - } - - #[tokio::test] - async fn propper_request() { - // Check that we call the handler and write back to the writer when - // processing a well-formed request. 
- let server = JsonRpcServer::builder() - .with_handler("echo".to_string(), Arc::new(Echo)) - .build(); - - let mut writer = MockWriter { - log_content: String::default(), - }; - - let msg = r#"{"jsonrpc":"2.0","method":"echo","params":{"age":99,"name":"Satoshi"},"id":"unique-id-123"}"#; - let res = server.handle_message(msg.as_bytes(), &mut writer).await; - assert!(res.is_ok()); - assert_eq!(writer.log_content, msg.to_string()); - } - - #[tokio::test] - async fn unknown_method() { - // We don't know the method and need to send back an error to the client. - let server = JsonRpcServer::builder() - .with_handler("echo".to_string(), Arc::new(Echo)) - .build(); - - let mut writer = MockWriter { - log_content: String::default(), - }; - - let msg = r#"{"jsonrpc":"2.0","method":"say_hello","params":{"age":99,"name":"Satoshi"},"id":"unique-id-123"}"#; // Unknown method say_hello - let res = server.handle_message(msg.as_bytes(), &mut writer).await; - assert!(res.is_ok()); - let expected = r#"{"jsonrpc":"2.0","id":"unique-id-123","error":{"code":-32601,"message":"Method not found"}}"#; // Unknown method say_hello - assert_eq!(writer.log_content, expected); - } - - #[tokio::test] - async fn test_handler() { - let server = JsonRpcServer::builder() - .with_handler("echo".to_string(), Arc::new(Echo)) - .build(); - - let mut writer = MockWriter { - log_content: String::default(), - }; - - let msg = r#"{"jsonrpc":"2.0","method":"echo","params":{"age":99,"name":"Satoshi"},"id":"unique-id-123"}"#; - let res = server.handle_message(msg.as_bytes(), &mut writer).await; - assert!(res.is_ok()); - assert_eq!(writer.log_content, msg.to_string()); - } -} diff --git a/plugins/lsps-plugin/src/lib.rs b/plugins/lsps-plugin/src/lib.rs index f14b96c7de90..e1f5e07f4303 100644 --- a/plugins/lsps-plugin/src/lib.rs +++ b/plugins/lsps-plugin/src/lib.rs @@ -1,6 +1,3 @@ -pub mod jsonrpc; -pub mod lsps0; -pub mod lsps2; -pub mod util; - -pub const LSP_FEATURE_BIT: usize = 729; +pub mod cln_adapters; 
+pub mod core; +pub mod proto; diff --git a/plugins/lsps-plugin/src/lsps0/handler.rs b/plugins/lsps-plugin/src/lsps0/handler.rs deleted file mode 100644 index 6b552f477cd9..000000000000 --- a/plugins/lsps-plugin/src/lsps0/handler.rs +++ /dev/null @@ -1,90 +0,0 @@ -use crate::{ - jsonrpc::{server::RequestHandler, JsonRpcResponse, RequestObject, RpcError}, - lsps0::model::{Lsps0listProtocolsRequest, Lsps0listProtocolsResponse}, - util::unwrap_payload_with_peer_id, -}; -use async_trait::async_trait; - -pub struct Lsps0ListProtocolsHandler { - pub lsps2_enabled: bool, -} - -#[async_trait] -impl RequestHandler for Lsps0ListProtocolsHandler { - async fn handle(&self, payload: &[u8]) -> core::result::Result, RpcError> { - let (payload, _) = unwrap_payload_with_peer_id(payload); - - let req: RequestObject = - serde_json::from_slice(&payload).unwrap(); - if let Some(id) = req.id { - let mut protocols = vec![]; - if self.lsps2_enabled { - protocols.push(2); - } - let res = Lsps0listProtocolsResponse { protocols }.into_response(id); - let res_vec = serde_json::to_vec(&res).unwrap(); - return Ok(res_vec); - } - // If request has no ID (notification), return empty Ok result. 
- Ok(vec![]) - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::{ - jsonrpc::{JsonRpcRequest, ResponseObject}, - util::wrap_payload_with_peer_id, - }; - use cln_rpc::primitives::PublicKey; - - const PUBKEY: [u8; 33] = [ - 0x02, 0x79, 0xbe, 0x66, 0x7e, 0xf9, 0xdc, 0xbb, 0xac, 0x55, 0xa0, 0x62, 0x95, 0xce, 0x87, - 0x0b, 0x07, 0x02, 0x9b, 0xfc, 0xdb, 0x2d, 0xce, 0x28, 0xd9, 0x59, 0xf2, 0x81, 0x5b, 0x16, - 0xf8, 0x17, 0x98, - ]; - - fn create_peer_id() -> PublicKey { - PublicKey::from_slice(&PUBKEY).expect("Valid pubkey") - } - - fn create_wrapped_request(request: &RequestObject) -> Vec { - let payload = serde_json::to_vec(request).expect("Failed to serialize request"); - wrap_payload_with_peer_id(&payload, create_peer_id()) - } - - #[tokio::test] - async fn test_lsps2_disabled_returns_empty_protocols() { - let handler = Lsps0ListProtocolsHandler { - lsps2_enabled: false, - }; - - let request = Lsps0listProtocolsRequest {}.into_request(Some("test-id".to_string())); - let payload = create_wrapped_request(&request); - - let result = handler.handle(&payload).await.unwrap(); - let response: ResponseObject = - serde_json::from_slice(&result).unwrap(); - - let data = response.into_inner().expect("Should have result data"); - assert!(data.protocols.is_empty()); - } - - #[tokio::test] - async fn test_lsps2_enabled_returns_protocol_2() { - let handler = Lsps0ListProtocolsHandler { - lsps2_enabled: true, - }; - - let request = Lsps0listProtocolsRequest {}.into_request(Some("test-id".to_string())); - let payload = create_wrapped_request(&request); - - let result = handler.handle(&payload).await.unwrap(); - let response: ResponseObject = - serde_json::from_slice(&result).unwrap(); - - let data = response.into_inner().expect("Should have result data"); - assert_eq!(data.protocols, vec![2]); - } -} diff --git a/plugins/lsps-plugin/src/lsps0/mod.rs b/plugins/lsps-plugin/src/lsps0/mod.rs deleted file mode 100644 index f32b0a55819b..000000000000 --- 
a/plugins/lsps-plugin/src/lsps0/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod handler; -pub mod model; -pub mod primitives; -pub mod transport; diff --git a/plugins/lsps-plugin/src/lsps0/model.rs b/plugins/lsps-plugin/src/lsps0/model.rs deleted file mode 100644 index 0327120e08f9..000000000000 --- a/plugins/lsps-plugin/src/lsps0/model.rs +++ /dev/null @@ -1,15 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::jsonrpc::JsonRpcRequest; - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct Lsps0listProtocolsRequest {} - -impl JsonRpcRequest for Lsps0listProtocolsRequest { - const METHOD: &'static str = "lsps0.list_protocols"; -} - -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] -pub struct Lsps0listProtocolsResponse { - pub protocols: Vec, -} diff --git a/plugins/lsps-plugin/src/lsps0/transport.rs b/plugins/lsps-plugin/src/lsps0/transport.rs deleted file mode 100644 index b19974e7769d..000000000000 --- a/plugins/lsps-plugin/src/lsps0/transport.rs +++ /dev/null @@ -1,533 +0,0 @@ -use crate::jsonrpc::{client::Transport, Error, TransportError}; -use async_trait::async_trait; -use cln_plugin::Plugin; -use cln_rpc::{primitives::PublicKey, ClnRpc}; -use log::{debug, error, trace}; -use serde::{de::Visitor, Deserialize, Serialize}; -use std::{ - collections::HashMap, - path::PathBuf, - str::FromStr, - sync::{Arc, Weak}, -}; -use tokio::{ - sync::{mpsc, RwLock}, - time::Duration, -}; - -pub const LSPS0_MESSAGE_TYPE: u16 = 37913; -const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60); - -/// Trait that must be implemented by plugin state to access the custom message hook manager. -/// -/// This trait allows the hook handler to access the custom message hook manager -/// from the plugin state, enabling proper message routing. -pub trait WithCustomMessageHookManager { - fn get_custommsg_hook_manager(&self) -> &CustomMessageHookManager; -} - -// Manages subscriptions for the custom message hook. 
-/// -/// The `CustomMessageHookManager` is responsible for: -/// 1. Maintaining a registry of message ID to receiver mappings -/// 2. Processing incoming LSPS0 messages and routing them to subscribers -/// 3. Cleaning up expired subscriptions -/// -/// It uses weak references to avoid memory leaks when timeouts occ -#[derive(Clone)] -pub struct CustomMessageHookManager { - /// Maps message IDs to weak references of response channels - subs: Arc>>>>, -} - -impl CustomMessageHookManager { - /// Creates a new CustomMessageHookManager. - pub fn new() -> Self { - Self { - subs: Arc::new(RwLock::new(HashMap::new())), - } - } - - /// Subscribes to receive a message with a specific ID. - /// - /// Registers a weak reference to a channel that will receive the message - /// when it arrives. Using weak references allows for automatic cleanup if - /// the receiver is dropped due to timeout. - async fn subscribe_hook_once>( - &self, - id: I, - channel: Weak>, - ) { - let id = id.into(); - trace!("Subscribe to custom message hook for message id={}", id); - let mut sub_lock = self.subs.write().await; - sub_lock.insert(id, channel); - } - - /// Processes an incoming LSP message. - /// - /// Extracts the message ID from the payload, finds the corresponding - /// subscriber, and forwards the message to them if found. 
- async fn process_lsp_message(&self, payload: CustomMsg, peer_id: &str) -> bool { - // Convert the binary payload to a string - let lsps_msg_string = match String::from_utf8(payload.payload.clone()) { - Ok(v) => v, - Err(e) => { - error!("Failed to deserialize custommsg payload from {peer_id}: {e}"); - return false; - } - }; - - let id = match extract_message_id(&lsps_msg_string) { - Ok(v) => v, - Err(e) => { - error!("Failed to get id from lsps message from {peer_id}: {e}"); - return false; - } - }; - - let mut subs_lock = self.subs.write().await; - // Clean up any expired subscriptions - subs_lock.retain(|_, v| Weak::strong_count(v) > 0); - subs_lock.shrink_to_fit(); - - // Find send to, and remove the subscriber for this message ID - if let Some(tx) = subs_lock.remove(&id).map(|v| v.upgrade()).flatten() { - if let Err(e) = tx.send(payload).await { - error!("Failed to send custommsg to subscriber for id={}: {e}", id); - return false; - } - return true; - } - - debug!( - "No subscriber found for message with id={} from {peer_id}", - id - ); - false - } - - /// Handles the custommsg hook from Core Lightning. - /// - /// This method should be registered as a hook handler in a Core Lightning - /// plugin. It processes incoming custom messages and routes LSPS0 messages - /// to the appropriate subscribers. - pub async fn on_custommsg( - p: Plugin, - v: serde_json::Value, - ) -> Result - where - S: Clone + Sync + Send + 'static + WithCustomMessageHookManager, - { - // Default response is to continue processing. - let continue_response = Ok(serde_json::json!({ - "result": "continue" - })); - - // Parse the custom message hook return value. - let custommsg: CustomMsgHookReturn = match serde_json::from_value(v) { - Ok(v) => v, - Err(e) => { - error!("Failed to deserialize custommsg: {e}"); - return continue_response; - } - }; - - // Only process LSPS0 message types. 
- if custommsg.payload.message_type != LSPS0_MESSAGE_TYPE { - debug!( - "Custommsg is not of type LSPS0 (got {}), skipping", - custommsg.payload.message_type - ); - return continue_response; - } - - // Route the message to the appropriate handler. - // Can be moved into a separate task via tokio::spawn if needed; - let hook_watcher = p.state().get_custommsg_hook_manager(); - hook_watcher - .process_lsp_message(custommsg.payload, &custommsg.peer_id) - .await; - return continue_response; - } -} - -/// Transport implementation for JSON-RPC over Lightning Network using BOLT8 -/// and BOLT1 custom messages. -/// -/// The `Bolt8Transport` allows JSON-RPC requests to be transmitted as custom -/// messages between Lightning Network nodes. It uses Core Lightning's -/// `sendcustommsg` RPC call to send messages and the `custommsg` hook to -/// receive responses. -#[derive(Clone)] -pub struct Bolt8Transport { - /// The node ID of the destination node. - endpoint: cln_rpc::primitives::PublicKey, - /// Path to the Core Lightning RPC socket. - rpc_path: PathBuf, - /// Timeout for requests. - request_timeout: Duration, - /// Hook manager for routing messages. - hook_watcher: CustomMessageHookManager, -} - -impl Bolt8Transport { - /// Creates a new Bolt8Transport. 
- /// - /// # Arguments - /// * `endpoint` - Node ID of the destination node as a hex string - /// * `rpc_path` - Path to the Core Lightning socket - /// * `hook_watcher` - Hook manager to use for message routing - /// * `timeout` - Optional timeout for requests (defaults to DEFAULT_TIMEOUT) - /// - /// # Returns - /// A new `Bolt8Transport` instance or an error if the node ID is invalid - pub fn new( - endpoint: &str, - rpc_path: PathBuf, - hook_watcher: CustomMessageHookManager, - timeout: Option, - ) -> Result { - let endpoint = cln_rpc::primitives::PublicKey::from_str(endpoint) - .map_err(|e| TransportError::Other(e.to_string()))?; - let timeout = timeout.unwrap_or(DEFAULT_TIMEOUT); - Ok(Self { - endpoint, - rpc_path, - request_timeout: timeout, - hook_watcher, - }) - } - - /// Connects to the Core Lightning node. - async fn connect_to_node(&self) -> Result { - ClnRpc::new(&self.rpc_path) - .await - .map_err(|e| Error::Transport(TransportError::Other(e.to_string()))) - } - - /// Sends a custom message to the destination node. - async fn send_custom_msg(&self, client: &mut ClnRpc, payload: Vec) -> Result<(), Error> { - send_custommsg(client, payload, self.endpoint).await - } - - /// Waits for a response with timeout. - async fn wait_for_response( - &self, - mut rx: mpsc::Receiver, - ) -> Result { - tokio::time::timeout(self.request_timeout, rx.recv()) - .await - .map_err(|_| Error::Transport(TransportError::Timeout))? - .ok_or(Error::Transport(TransportError::Other(String::from( - "Channel unexpectedly closed", - )))) - } -} - -/// Sends a custom message to the destination node. 
-pub async fn send_custommsg( - client: &mut ClnRpc, - payload: Vec, - peer: PublicKey, -) -> Result<(), Error> { - let msg = CustomMsg { - message_type: LSPS0_MESSAGE_TYPE, - payload, - }; - - let request = cln_rpc::model::requests::SendcustommsgRequest { - msg: msg.to_string(), - node_id: peer, - }; - - client - .call_typed(&request) - .await - .map_err(|e| { - Error::Transport(TransportError::Other(format!( - "Failed to send custom msg: {e}" - ))) - }) - .map(|r| { - trace!("Successfully queued custom msg: {}", r.status); - () - }) -} - -#[async_trait] -impl Transport for Bolt8Transport { - /// Sends a JSON-RPC request and waits for a response. - async fn send(&self, request: String) -> core::result::Result { - let id = extract_message_id(&request)?; - let mut client = self.connect_to_node().await?; - - let (tx, rx) = mpsc::channel(1); - trace!( - "Subscribing to custom msg hook manager for request id={}", - id - ); - - // Create a strong reference that will be dropped after timeout. - let tx_arc = Arc::new(tx); - - self.hook_watcher - .subscribe_hook_once(id, Arc::downgrade(&tx_arc)) - .await; - self.send_custom_msg(&mut client, request.into_bytes()) - .await?; - - let res = self.wait_for_response(rx).await?; - - if res.message_type != LSPS0_MESSAGE_TYPE { - return Err(Error::Transport(TransportError::Other(format!( - "unexpected response message type: expected {}, got {}", - LSPS0_MESSAGE_TYPE, res.message_type - )))); - } - - core::str::from_utf8(&res.payload) - .map_err(|e| { - Error::Transport(TransportError::Other(format!( - "failed to decode msg payload {:?}: {}", - res.payload, e - ))) - }) - .map(|s| s.into()) - } - - /// Sends a notification without waiting for a response. - async fn notify(&self, request: String) -> core::result::Result<(), Error> { - let mut client = self.connect_to_node().await?; - self.send_custom_msg(&mut client, request.into_bytes()) - .await - } -} - -// Extracts the message ID from a JSON-RPC message. 
-fn extract_message_id(msg: &str) -> core::result::Result { - let id_only: IdOnly = serde_json::from_str(msg)?; - Ok(id_only.id) -} - -/// Represents a custom message with type and payload. -#[derive(Clone, Debug, PartialEq)] -pub struct CustomMsg { - pub message_type: u16, - pub payload: Vec, -} - -impl core::fmt::Display for CustomMsg { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut bytes = Vec::with_capacity(2 + self.payload.len()); - bytes.extend_from_slice(&self.message_type.to_be_bytes()); - bytes.extend_from_slice(&self.payload); - write!(f, "{}", hex::encode(bytes)) - } -} - -impl FromStr for CustomMsg { - type Err = Error; - - fn from_str(s: &str) -> Result { - let bytes = hex::decode(s).map_err(Error::other)?; - - if bytes.len() < 2 { - return Err(Error::other( - "hex string too short to contain a valid message_type", - )); - } - - let message_type_bytes: [u8; 2] = bytes[..2].try_into().map_err(Error::other)?; - let message_type = u16::from_be_bytes(message_type_bytes); - let payload = bytes[2..].to_owned(); - Ok(CustomMsg { - message_type, - payload, - }) - } -} - -impl Serialize for CustomMsg { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str(&self.to_string()) - } -} - -/// Visitor for deserializing CustomMsg from strings. -struct CustomMsgVisitor; - -impl<'de> Visitor<'de> for CustomMsgVisitor { - type Value = CustomMsg; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a hex string representing a CustomMsg") - } - - fn visit_str(self, v: &str) -> Result - where - E: serde::de::Error, - { - CustomMsg::from_str(v).map_err(E::custom) - } -} - -impl<'de> Deserialize<'de> for CustomMsg { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_str(CustomMsgVisitor) - } -} - -/// Struct to extract just the ID from a JSON-RPC message. 
-#[derive(Clone, Debug, Serialize, Deserialize)] -struct IdOnly { - id: String, -} - -/// Return type from custommsg hook. -#[derive(Clone, Debug, Serialize, Deserialize)] -struct CustomMsgHookReturn { - peer_id: String, - payload: CustomMsg, -} - -#[cfg(test)] -mod test_transport { - use super::*; - use serde_json::json; - - // Helper to create a test JSON-RPC request - fn create_test_request(id: &str) -> String { - serde_json::to_string(&json!({ - "jsonrpc": "2.0", - "method": "test_method", - "params": {"test": "value"}, - "id": id - })) - .unwrap() - } - - #[tokio::test] - async fn test_deserialize_custommsg() { - let hex_str = r#"94197b226a736f6e727063223a22322e30222c226d6574686f64223a226c737073302e6c6973745f70726f746f636f6c73222c22706172616d73223a7b7d2c226964223a226135633665613536366333383038313936346263227d"#; - let msg = CustomMsg::from_str(hex_str).unwrap(); - assert_eq!(msg.message_type, LSPS0_MESSAGE_TYPE); - } - - #[tokio::test] - async fn test_extract_message_id() { - // Test with string ID - let request = create_test_request("test-id-123"); - let id = extract_message_id(&request).unwrap(); - assert_eq!(id, "test-id-123"); - } - - #[tokio::test] - async fn custom_msg_serialization() { - let original = CustomMsg { - message_type: 0x1234, - payload: b"test payload".to_vec(), - }; - - // Test to_string and parsing from that string - let serialized = original.to_string(); - - // Convert hex to bytes - let bytes = hex::decode(&serialized).unwrap(); - - // Verify structure - assert_eq!(bytes[0], 0x12); - assert_eq!(bytes[1], 0x34); - assert_eq!(&bytes[2..], b"test payload"); - - // Test deserialization - let deserialized: CustomMsg = - serde_json::from_str(&serde_json::to_string(&serialized).unwrap()).unwrap(); - - assert_eq!(deserialized.message_type, original.message_type); - assert_eq!(deserialized.payload, original.payload); - } - - #[tokio::test] - async fn hook_manager_subscribe_and_process() { - let hook_manager = CustomMessageHookManager::new(); - - 
// Create test message - let test_id = "test-id-456"; - let test_request = create_test_request(test_id); - let test_msg = CustomMsg { - message_type: LSPS0_MESSAGE_TYPE, - payload: test_request.as_bytes().to_vec(), - }; - - // Set up a subscription - let (tx, mut rx) = mpsc::channel(1); - let tx_arc = Arc::new(tx); - hook_manager - .subscribe_hook_once(test_id, Arc::downgrade(&tx_arc)) - .await; - - // Process the message - let processed = hook_manager - .process_lsp_message(test_msg.clone(), "peer123") - .await; - assert!(processed); - - // Verify the received message - let received_msg = rx.recv().await.unwrap(); - assert_eq!(received_msg.message_type, LSPS0_MESSAGE_TYPE); - assert_eq!(received_msg.payload, test_request.as_bytes()); - } - - #[tokio::test] - async fn hook_manager_no_subscriber() { - let hook_manager = CustomMessageHookManager::new(); - - // Create test message with ID that has no subscriber - let test_request = create_test_request("unknown-id"); - let test_msg = CustomMsg { - message_type: LSPS0_MESSAGE_TYPE, - payload: test_request.as_bytes().to_vec(), - }; - - // Process the message - let processed = hook_manager.process_lsp_message(test_msg, "peer123").await; - assert!(!processed); - } - - #[tokio::test] - async fn hook_manager_clean_up_after_timeout() { - let hook_manager = CustomMessageHookManager::new(); - - // Create test message - let test_id = "test-id-456"; - let test_request = create_test_request(test_id); - let test_msg = CustomMsg { - message_type: LSPS0_MESSAGE_TYPE, - payload: test_request.as_bytes().to_vec(), - }; - - // Set up a subscription - let (tx, _rx) = mpsc::channel(1); - let tx_arc = Arc::new(tx); - hook_manager - .subscribe_hook_once(test_id, Arc::downgrade(&tx_arc)) - .await; - - // drop the reference pointer here to simulate a timeout. - drop(tx_arc); - - // Should not process as the reference has been dropped. 
- let processed = hook_manager - .process_lsp_message(test_msg.clone(), "peer123") - .await; - assert!(!processed); - assert!(hook_manager.subs.read().await.is_empty()); - } -} diff --git a/plugins/lsps-plugin/src/lsps2/cln.rs b/plugins/lsps-plugin/src/lsps2/cln.rs deleted file mode 100644 index c2789c256fc0..000000000000 --- a/plugins/lsps-plugin/src/lsps2/cln.rs +++ /dev/null @@ -1,830 +0,0 @@ -//! Backfill structs for missing or incomplete Core Lightning types. -//! -//! This module provides struct implementations that are not available or -//! fully accessible in the core-lightning crate, enabling better compatibility -//! and interoperability with Core Lightning's RPC interface. -use cln_rpc::primitives::{Amount, ShortChannelId}; -use hex::FromHex; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -use crate::lsps2::cln::tlv::TlvStream; - -pub const TLV_FORWARD_AMT: u64 = 2; -pub const TLV_OUTGOING_CLTV: u64 = 4; -pub const TLV_SHORT_CHANNEL_ID: u64 = 6; -pub const TLV_PAYMENT_SECRET: u64 = 8; - -#[derive(Debug, Deserialize)] -#[allow(unused)] -pub struct Onion { - pub forward_msat: Option, - #[serde(deserialize_with = "from_hex")] - pub next_onion: Vec, - pub outgoing_cltv_value: Option, - pub payload: TlvStream, - // pub payload: TlvStream, - #[serde(deserialize_with = "from_hex")] - pub shared_secret: Vec, - pub short_channel_id: Option, - pub total_msat: Option, - #[serde(rename = "type")] - pub type_: Option, -} - -#[derive(Debug, Deserialize)] -#[allow(unused)] -pub struct Htlc { - pub amount_msat: Amount, - pub cltv_expiry: u32, - pub cltv_expiry_relative: u16, - pub id: u64, - #[serde(deserialize_with = "from_hex")] - pub payment_hash: Vec, - pub short_channel_id: ShortChannelId, - pub extra_tlvs: Option, -} - -#[derive(Debug, Serialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum HtlcAcceptedResult { - Continue, - Fail, - Resolve, -} - -impl std::fmt::Display for HtlcAcceptedResult { - fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { - let s = match self { - HtlcAcceptedResult::Continue => "continue", - HtlcAcceptedResult::Fail => "fail", - HtlcAcceptedResult::Resolve => "resolve", - }; - write!(f, "{s}") - } -} - -#[derive(Debug, Deserialize)] -pub struct HtlcAcceptedRequest { - pub htlc: Htlc, - pub onion: Onion, - pub forward_to: Option, -} - -#[derive(Debug, Serialize)] -pub struct HtlcAcceptedResponse { - pub result: HtlcAcceptedResult, - #[serde(skip_serializing_if = "Option::is_none")] - pub payment_key: Option, - #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] - pub payload: Option>, - #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] - pub forward_to: Option>, - #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] - pub extra_tlvs: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub failure_message: Option, - #[serde(skip_serializing_if = "Option::is_none", serialize_with = "to_hex")] - pub failure_onion: Option>, -} - -impl HtlcAcceptedResponse { - pub fn continue_( - payload: Option>, - forward_to: Option>, - extra_tlvs: Option>, - ) -> Self { - Self { - result: HtlcAcceptedResult::Continue, - payment_key: None, - payload, - forward_to, - extra_tlvs, - failure_message: None, - failure_onion: None, - } - } - - pub fn fail(failure_message: Option, failure_onion: Option>) -> Self { - Self { - result: HtlcAcceptedResult::Fail, - payment_key: None, - payload: None, - forward_to: None, - extra_tlvs: None, - failure_message, - failure_onion, - } - } -} - -#[derive(Debug, Deserialize)] -pub struct InvoicePaymentRequest { - pub payment: InvoicePaymentRequestPayment, -} - -#[derive(Debug, Deserialize)] -pub struct InvoicePaymentRequestPayment { - pub label: String, - pub preimage: String, - pub msat: u64, -} - -#[derive(Debug, Deserialize)] -pub struct OpenChannelRequest { - pub openchannel: OpenChannelRequestOpenChannel, -} - -#[derive(Debug, 
Deserialize)] -pub struct OpenChannelRequestOpenChannel { - pub id: String, - pub funding_msat: u64, - pub push_msat: u64, - pub dust_limit_msat: u64, - pub max_htlc_value_in_flight_msat: u64, - pub channel_reserve_msat: u64, - pub htlc_minimum_msat: u64, - pub feerate_per_kw: u32, - pub to_self_delay: u32, - pub max_accepted_htlcs: u32, - pub channel_flags: u64, -} - -/// Deserializes a lowercase hex string to a `Vec`. -pub fn from_hex<'de, D>(deserializer: D) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - use serde::de::Error; - String::deserialize(deserializer) - .and_then(|string| Vec::from_hex(string).map_err(|err| Error::custom(err.to_string()))) -} - -pub fn to_hex(bytes: &Option>, serializer: S) -> Result -where - S: Serializer, -{ - match bytes { - Some(data) => serializer.serialize_str(&hex::encode(data)), - None => serializer.serialize_none(), - } -} - -pub mod tlv { - use anyhow::Result; - use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize, Serializer}; - use std::{convert::TryFrom, fmt}; - - #[derive(Clone, Debug)] - pub struct TlvRecord { - pub type_: u64, - pub value: Vec, - } - - #[derive(Clone, Debug, Default)] - pub struct TlvStream(pub Vec); - - #[derive(Debug)] - pub enum TlvError { - DuplicateType(u64), - NotSorted, - LengthMismatch(u64, usize, usize), - Truncated, - NonCanonicalBigSize, - TrailingBytes, - Hex(hex::FromHexError), - Other(String), - } - - impl fmt::Display for TlvError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - TlvError::DuplicateType(t) => write!(f, "duplicate tlv type {}", t), - TlvError::NotSorted => write!(f, "tlv types must be strictly increasing"), - TlvError::LengthMismatch(t, e, g) => { - write!(f, "length mismatch type {}: expected {}, got {}", t, e, g) - } - TlvError::Truncated => write!(f, "truncated input"), - TlvError::NonCanonicalBigSize => write!(f, "non-canonical bigsize encoding"), - TlvError::TrailingBytes => write!(f, "leftover bytes after 
parsing"), - TlvError::Hex(e) => write!(f, "hex error: {}", e), - TlvError::Other(s) => write!(f, "{}", s), - } - } - } - - impl std::error::Error for TlvError {} - impl From for TlvError { - fn from(e: hex::FromHexError) -> Self { - TlvError::Hex(e) - } - } - - impl TlvStream { - pub fn to_bytes(&mut self) -> Result> { - self.0.sort_by_key(|r| r.type_); - for w in self.0.windows(2) { - if w[0].type_ == w[1].type_ { - return Err(TlvError::DuplicateType(w[0].type_).into()); - } - if w[0].type_ > w[1].type_ { - return Err(TlvError::NotSorted.into()); - } - } - let mut out = Vec::new(); - for rec in &self.0 { - out.extend(encode_bigsize(rec.type_)); - out.extend(encode_bigsize(rec.value.len() as u64)); - out.extend(&rec.value); - } - Ok(out) - } - - pub fn from_bytes(mut bytes: &[u8]) -> Result { - let mut recs = Vec::new(); - let mut last_type: Option = None; - - while !bytes.is_empty() { - let (t, n1) = decode_bigsize(bytes)?; - bytes = &bytes[n1..]; - let (len, n2) = decode_bigsize(bytes)?; - bytes = &bytes[n2..]; - - let l = - usize::try_from(len).map_err(|_| TlvError::Other("length too large".into()))?; - if bytes.len() < l { - return Err(TlvError::Truncated.into()); - } - let v = bytes[..l].to_vec(); - bytes = &bytes[l..]; - - if let Some(prev) = last_type { - if t == prev { - return Err(TlvError::DuplicateType(t).into()); - } - if t < prev { - return Err(TlvError::NotSorted.into()); - } - } - last_type = Some(t); - recs.push(TlvRecord { type_: t, value: v }); - } - Ok(TlvStream(recs)) - } - - pub fn from_bytes_with_length_prefix(bytes: &[u8]) -> Result { - if bytes.is_empty() { - return Err(TlvError::Truncated.into()); - } - - let (length, length_bytes) = decode_bigsize(bytes)?; - let remaining = &bytes[length_bytes..]; - - let length_usize = usize::try_from(length) - .map_err(|_| TlvError::Other("length prefix too large".into()))?; - - if remaining.len() != length_usize { - return Err(TlvError::LengthMismatch(0, length_usize, remaining.len()).into()); - } - - 
Self::from_bytes(remaining) - } - - /// Attempt to auto-detect whether the input has a length prefix or not - /// First tries to parse as length-prefixed, then falls back to raw TLV - /// parsing. - pub fn from_bytes_auto(bytes: &[u8]) -> Result { - // Try length-prefixed first - if let Ok(stream) = Self::from_bytes_with_length_prefix(bytes) { - return Ok(stream); - } - - // Fall back to raw TLV parsing - Self::from_bytes(bytes) - } - - /// Get a reference to the value of a TLV record by type. - pub fn get(&self, type_: u64) -> Option<&[u8]> { - self.0 - .iter() - .find(|rec| rec.type_ == type_) - .map(|rec| rec.value.as_slice()) - } - - /// Insert a TLV record (replaces if type already exists). - pub fn insert(&mut self, type_: u64, value: Vec) { - // If the type already exists, replace its value. - if let Some(rec) = self.0.iter_mut().find(|rec| rec.type_ == type_) { - rec.value = value; - return; - } - // Otherwise push and re-sort to maintain canonical order. - self.0.push(TlvRecord { type_, value }); - self.0.sort_by_key(|r| r.type_); - } - - /// Remove a record by type. - pub fn remove(&mut self, type_: u64) -> Option> { - if let Some(pos) = self.0.iter().position(|rec| rec.type_ == type_) { - Some(self.0.remove(pos).value) - } else { - None - } - } - - /// Check if a type exists. - pub fn contains(&self, type_: u64) -> bool { - self.0.iter().any(|rec| rec.type_ == type_) - } - - /// Insert or override a `tu64` value for `type_` (keeps canonical TLV order). - pub fn set_tu64(&mut self, type_: u64, value: u64) { - let enc = encode_tu64(value); - if let Some(rec) = self.0.iter_mut().find(|r| r.type_ == type_) { - rec.value = enc; - } else { - self.0.push(TlvRecord { type_, value: enc }); - self.0.sort_by_key(|r| r.type_); - } - } - - /// Read a `tu64` if present, validating minimal encoding. - /// Returns Ok(None) if the type isn't present. 
- pub fn get_tu64(&self, type_: u64) -> Result, TlvError> { - if let Some(rec) = self.0.iter().find(|r| r.type_ == type_) { - Ok(Some(decode_tu64(&rec.value)?)) - } else { - Ok(None) - } - } - - /// Insert or override a `u64` value for `type_` (keeps cannonical TLV - /// order). - pub fn set_u64(&mut self, type_: u64, value: u64) { - let enc = value.to_be_bytes().to_vec(); - if let Some(rec) = self.0.iter_mut().find(|r| r.type_ == type_) { - rec.value = enc; - } else { - self.0.push(TlvRecord { type_, value: enc }); - self.0.sort_by_key(|r| r.type_); - } - } - - /// Read a `u64` if present.Returns Ok(None) if the type isn't present. - pub fn get_u64(&self, type_: u64) -> Result, TlvError> { - if let Some(rec) = self.0.iter().find(|r| r.type_ == type_) { - let value = u64::from_be_bytes( - rec.value[..] - .try_into() - .map_err(|e| TlvError::Other(format!("failed not decode to u64: {e}")))?, - ); - Ok(Some(value)) - } else { - Ok(None) - } - } - } - - impl Serialize for TlvStream { - fn serialize(&self, serializer: S) -> Result { - let mut tmp = self.clone(); - let bytes = tmp.to_bytes().map_err(serde::ser::Error::custom)?; - serializer.serialize_str(&hex::encode(bytes)) - } - } - - impl<'de> Deserialize<'de> for TlvStream { - fn deserialize>(deserializer: D) -> Result { - struct V; - impl<'de> serde::de::Visitor<'de> for V { - type Value = TlvStream; - fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "a hex string representing a Lightning TLV stream") - } - fn visit_str(self, s: &str) -> Result { - let bytes = hex::decode(s).map_err(E::custom)?; - TlvStream::from_bytes_auto(&bytes).map_err(E::custom) - } - } - deserializer.deserialize_str(V) - } - } - - impl TryFrom<&[u8]> for TlvStream { - type Error = anyhow::Error; - fn try_from(value: &[u8]) -> Result { - TlvStream::from_bytes(value) - } - } - - impl From> for TlvStream { - fn from(v: Vec) -> Self { - TlvStream(v) - } - } - - /// BOLT #1 BigSize encoding - fn encode_bigsize(x: u64) -> 
Vec { - let mut out = Vec::new(); - if x < 0xfd { - out.push(x as u8); - } else if x <= 0xffff { - out.push(0xfd); - out.extend_from_slice(&(x as u16).to_be_bytes()); - } else if x <= 0xffff_ffff { - out.push(0xfe); - out.extend_from_slice(&(x as u32).to_be_bytes()); - } else { - out.push(0xff); - out.extend_from_slice(&x.to_be_bytes()); - } - out - } - - fn decode_bigsize(input: &[u8]) -> Result<(u64, usize)> { - if input.is_empty() { - return Err(TlvError::Truncated.into()); - } - match input[0] { - n @ 0x00..=0xfc => Ok((n as u64, 1)), - 0xfd => { - if input.len() < 3 { - return Err(TlvError::Truncated.into()); - } - let v = u16::from_be_bytes([input[1], input[2]]) as u64; - if v < 0xfd { - return Err(TlvError::NonCanonicalBigSize.into()); - } - Ok((v, 3)) - } - 0xfe => { - if input.len() < 5 { - return Err(TlvError::Truncated.into()); - } - let v = u32::from_be_bytes([input[1], input[2], input[3], input[4]]) as u64; - if v <= 0xffff { - return Err(TlvError::NonCanonicalBigSize.into()); - } - Ok((v, 5)) - } - 0xff => { - if input.len() < 9 { - return Err(TlvError::Truncated.into()); - } - let v = u64::from_be_bytes([ - input[1], input[2], input[3], input[4], input[5], input[6], input[7], input[8], - ]); - if v <= 0xffff_ffff { - return Err(TlvError::NonCanonicalBigSize.into()); - } - Ok((v, 9)) - } - } - } - - /// Encode a BOLT #1 `tu64`: big-endian, minimal length (no leading 0x00). - /// Value 0 is encoded as zero-length. - pub fn encode_tu64(v: u64) -> Vec { - if v == 0 { - return Vec::new(); - } - let bytes = v.to_be_bytes(); - let first = bytes.iter().position(|&b| b != 0).unwrap(); // safe: v != 0 - bytes[first..].to_vec() - } - - /// Decode a BOLT #1 `tu64`, enforcing minimal form. - /// Empty slice -> 0. Leading 0x00 or >8 bytes is invalid. 
- fn decode_tu64(raw: &[u8]) -> Result { - if raw.is_empty() { - return Ok(0); - } - if raw.len() > 8 { - return Err(TlvError::Other("tu64 too long".into())); - } - if raw[0] == 0 { - return Err(TlvError::Other("non-minimal tu64 (leading zero)".into())); - } - let mut buf = [0u8; 8]; - buf[8 - raw.len()..].copy_from_slice(raw); - Ok(u64::from_be_bytes(buf)) - } - - #[cfg(test)] - mod tests { - use super::*; - use anyhow::Result; - - // Small helpers to keep tests readable - fn rec(type_: u64, value: &[u8]) -> TlvRecord { - TlvRecord { - type_, - value: value.to_vec(), - } - } - - fn build_bytes(type_: u64, value: &[u8]) -> Vec { - let mut v = Vec::new(); - v.extend(super::encode_bigsize(type_)); - v.extend(super::encode_bigsize(value.len() as u64)); - v.extend(value); - v - } - - #[test] - fn encode_then_decode_roundtrip() -> Result<()> { - let mut stream = TlvStream(vec![rec(1, &[0x01, 0x02]), rec(5, &[0xaa])]); - - // Encode - let bytes = stream.to_bytes()?; - // Expect exact TLV sequence: - // type=1 -> 0x01, len=2 -> 0x02, value=0x01 0x02 - // type=5 -> 0x05, len=1 -> 0x01, value=0xaa - assert_eq!(hex::encode(&bytes), "010201020501aa"); - - // Decode back - let decoded = TlvStream::from_bytes(&bytes)?; - assert_eq!(decoded.0.len(), 2); - assert_eq!(decoded.0[0].type_, 1); - assert_eq!(decoded.0[0].value, vec![0x01, 0x02]); - assert_eq!(decoded.0[1].type_, 5); - assert_eq!(decoded.0[1].value, vec![0xaa]); - - Ok(()) - } - - #[test] - fn json_hex_roundtrip() -> Result<()> { - let stream = TlvStream(vec![rec(1, &[0x01, 0x02]), rec(5, &[0xaa])]); - - // Serialize to hex string in JSON - let json = serde_json::to_string(&stream)?; - // It's a quoted hex string; check inner value - let s: String = serde_json::from_str(&json)?; - assert_eq!(s, "010201020501aa"); - - // And back from JSON hex - let back: TlvStream = serde_json::from_str(&json)?; - assert_eq!(back.0.len(), 2); - assert_eq!(back.0[0].type_, 1); - assert_eq!(back.0[0].value, vec![0x01, 0x02]); - 
assert_eq!(back.0[1].type_, 5); - assert_eq!(back.0[1].value, vec![0xaa]); - - Ok(()) - } - - #[test] - fn decode_with_len_prefix() -> Result<()> { - let payload = "1202039896800401760608000073000f2c0007"; - let stream = TlvStream::from_bytes_with_length_prefix(&hex::decode(payload).unwrap())?; - // let stream: TlvStream = serde_json::from_str(payload)?; - println!("TLV {:?}", stream.0); - - Ok(()) - } - - #[test] - fn bigsize_boundary_minimal_encodings() -> Result<()> { - // Types at 0xfc, 0xfd, 0x10000 to exercise size switches - let mut stream = TlvStream(vec![ - rec(0x00fc, &[0x11]), // single-byte bigsize - rec(0x00fd, &[0x22]), // 0xfd prefix + u16 - rec(0x0001_0000, &[0x33]), // 0xfe prefix + u32 - ]); - - let bytes = stream.to_bytes()?; // just ensure it encodes - // Decode back to confirm roundtrip/canonical encodings accepted - let back = TlvStream::from_bytes(&bytes)?; - assert_eq!(back.0[0].type_, 0x00fc); - assert_eq!(back.0[1].type_, 0x00fd); - assert_eq!(back.0[2].type_, 0x0001_0000); - Ok(()) - } - - #[test] - fn decode_rejects_non_canonical_bigsize() { - // (1) Non-canonical: 0xfd 00 fc encodes 0xfc but should be a single byte - let mut bytes = Vec::new(); - bytes.extend([0xfd, 0x00, 0xfc]); // non-canonical type - bytes.extend([0x01]); // len = 1 - bytes.extend([0x00]); // value - let err = TlvStream::from_bytes(&bytes).unwrap_err(); - assert!(format!("{}", err).contains("non-canonical")); - - // (2) Non-canonical: 0xfe 00 00 00 ff encodes 0xff but should be 0xfd-form - let mut bytes = Vec::new(); - bytes.extend([0xfe, 0x00, 0x00, 0x00, 0xff]); - bytes.extend([0x01]); - bytes.extend([0x00]); - let err = TlvStream::from_bytes(&bytes).unwrap_err(); - assert!(format!("{}", err).contains("non-canonical")); - - // (3) Non-canonical: 0xff 00..01 encodes 1, which should be single byte - let mut bytes = Vec::new(); - bytes.extend([0xff, 0, 0, 0, 0, 0, 0, 0, 1]); - bytes.extend([0x01]); - bytes.extend([0x00]); - let err = 
TlvStream::from_bytes(&bytes).unwrap_err(); - assert!(format!("{}", err).contains("non-canonical")); - } - - #[test] - fn decode_rejects_out_of_order_types() { - // Build two TLVs but put type 5 before type 1 - let mut bad = Vec::new(); - bad.extend(build_bytes(5, &[0xaa])); - bad.extend(build_bytes(1, &[0x00])); - - let err = TlvStream::from_bytes(&bad).unwrap_err(); - assert!( - format!("{}", err).contains("increasing") || format!("{}", err).contains("sorted"), - "expected ordering error, got: {err}" - ); - } - - #[test] - fn decode_rejects_duplicate_types() { - // Two records with same type=1 - let mut bad = Vec::new(); - bad.extend(build_bytes(1, &[0x01])); - bad.extend(build_bytes(1, &[0x02])); - let err = TlvStream::from_bytes(&bad).unwrap_err(); - assert!( - format!("{}", err).contains("duplicate"), - "expected duplicate error, got: {err}" - ); - } - - #[test] - fn encode_rejects_duplicate_types() { - // insert duplicate types and expect encode to fail - let mut s = TlvStream(vec![rec(1, &[0x01]), rec(1, &[0x02])]); - let err = s.to_bytes().unwrap_err(); - assert!( - format!("{}", err).contains("duplicate"), - "expected duplicate error, got: {err}" - ); - } - - #[test] - fn decode_truncated_value() { - // type=1, len=2 but only 1 byte of value provided - let mut bytes = Vec::new(); - bytes.extend(encode_bigsize(1)); - bytes.extend(encode_bigsize(2)); - bytes.push(0x00); // missing one more byte - let err = TlvStream::from_bytes(&bytes).unwrap_err(); - assert!( - format!("{}", err).contains("truncated"), - "expected truncated error, got: {err}" - ); - } - - #[test] - fn set_and_get_u64_basic() -> Result<()> { - let mut s = TlvStream::default(); - s.set_u64(42, 123456789); - assert_eq!(s.get_u64(42)?, Some(123456789)); - Ok(()) - } - - #[test] - fn set_u64_overwrite_keeps_order() -> Result<()> { - let mut s = TlvStream(vec![ - TlvRecord { - type_: 1, - value: vec![0xaa], - }, - TlvRecord { - type_: 10, - value: vec![0xbb], - }, - ]); - - // insert between 1 
and 10 - s.set_u64(5, 7); - assert_eq!( - s.0.iter().map(|r| r.type_).collect::>(), - vec![1, 5, 10] - ); - assert_eq!(s.get_u64(5)?, Some(7)); - - // overwrite existing 5 (no duplicate, order preserved) - s.set_u64(5, 9); - let types: Vec = s.0.iter().map(|r| r.type_).collect(); - assert_eq!(types, vec![1, 5, 10]); - assert_eq!(s.0.iter().filter(|r| r.type_ == 5).count(), 1); - assert_eq!(s.get_u64(5)?, Some(9)); - Ok(()) - } - - #[test] - fn set_and_get_tu64_basic() -> Result<()> { - let mut s = TlvStream::default(); - s.set_tu64(42, 123456789); - assert_eq!(s.get_tu64(42)?, Some(123456789)); - Ok(()) - } - - #[test] - fn get_u64_missing_returns_none() -> Result<()> { - let s = TlvStream::default(); - assert_eq!(s.get_u64(999)?, None); - Ok(()) - } - - #[test] - fn set_tu64_overwrite_keeps_order() -> Result<()> { - let mut s = TlvStream(vec![ - TlvRecord { - type_: 1, - value: vec![0xaa], - }, - TlvRecord { - type_: 10, - value: vec![0xbb], - }, - ]); - - // insert between 1 and 10 - s.set_tu64(5, 7); - assert_eq!( - s.0.iter().map(|r| r.type_).collect::>(), - vec![1, 5, 10] - ); - assert_eq!(s.get_tu64(5)?, Some(7)); - - // overwrite existing 5 (no duplicate, order preserved) - s.set_tu64(5, 9); - let types: Vec = s.0.iter().map(|r| r.type_).collect(); - assert_eq!(types, vec![1, 5, 10]); - assert_eq!(s.0.iter().filter(|r| r.type_ == 5).count(), 1); - assert_eq!(s.get_tu64(5)?, Some(9)); - Ok(()) - } - - #[test] - fn tu64_zero_encodes_empty_and_roundtrips() -> Result<()> { - let mut s = TlvStream::default(); - s.set_tu64(3, 0); - - // stored value is zero-length - let rec = s.0.iter().find(|r| r.type_ == 3).unwrap(); - assert!(rec.value.is_empty()); - - // wire round-trip - let mut sc = s.clone(); - let bytes = sc.to_bytes()?; - let s2 = TlvStream::from_bytes(&bytes)?; - assert_eq!(s2.get_tu64(3)?, Some(0)); - Ok(()) - } - - #[test] - fn get_tu64_missing_returns_none() -> Result<()> { - let s = TlvStream::default(); - assert_eq!(s.get_tu64(999)?, None); - Ok(()) 
- } - - #[test] - fn get_tu64_rejects_non_minimal_and_too_long() { - // non-minimal: leading zero - let mut s = TlvStream::default(); - s.0.push(TlvRecord { - type_: 9, - value: vec![0x00, 0x01], - }); - assert!(s.get_tu64(9).is_err()); - - // too long: 9 bytes - let mut s2 = TlvStream::default(); - s2.0.push(TlvRecord { - type_: 9, - value: vec![0; 9], - }); - assert!(s2.get_tu64(9).is_err()); - } - - #[test] - fn tu64_multi_roundtrip_bytes_and_json() -> Result<()> { - let mut s = TlvStream::default(); - s.set_tu64(42, 0); - s.set_tu64(7, 256); - - // wire roundtrip - let mut sc = s.clone(); - let bytes = sc.to_bytes()?; - let s2 = TlvStream::from_bytes(&bytes)?; - assert_eq!(s2.get_tu64(42)?, Some(0)); - assert_eq!(s2.get_tu64(7)?, Some(256)); - - // json hex roundtrip (custom Serialize/Deserialize) - let json = serde_json::to_string(&s)?; - let s3: TlvStream = serde_json::from_str(&json)?; - assert_eq!(s3.get_tu64(42)?, Some(0)); - assert_eq!(s3.get_tu64(7)?, Some(256)); - Ok(()) - } - } -} diff --git a/plugins/lsps-plugin/src/lsps2/mod.rs b/plugins/lsps-plugin/src/lsps2/mod.rs deleted file mode 100644 index 60d8ebf5f101..000000000000 --- a/plugins/lsps-plugin/src/lsps2/mod.rs +++ /dev/null @@ -1,19 +0,0 @@ -use cln_plugin::options; - -pub mod cln; -pub mod handler; -pub mod model; - -pub const OPTION_ENABLED: options::FlagConfigOption = options::ConfigOption::new_flag( - "experimental-lsps2-service", - "Enables lsps2 for the LSP service", -); - -pub const OPTION_PROMISE_SECRET: options::StringConfigOption = - options::ConfigOption::new_str_no_default( - "experimental-lsps2-promise-secret", - "A 64-character hex string that is the secret for promises", - ); - -pub const DS_MAIN_KEY: &'static str = "lsps"; -pub const DS_SUB_KEY: &'static str = "lsps2"; diff --git a/plugins/lsps-plugin/src/proto/jsonrpc.rs b/plugins/lsps-plugin/src/proto/jsonrpc.rs new file mode 100644 index 000000000000..24b307131d7d --- /dev/null +++ b/plugins/lsps-plugin/src/proto/jsonrpc.rs @@ 
-0,0 +1,564 @@ +use rand::{rngs::OsRng, TryRngCore as _}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::{self, Value}; +use std::fmt; + +// Constants for JSON-RPC error codes. +pub const PARSE_ERROR: i64 = -32700; +pub const INVALID_REQUEST: i64 = -32600; +pub const METHOD_NOT_FOUND: i64 = -32601; +pub const INVALID_PARAMS: i64 = -32602; +pub const INTERNAL_ERROR: i64 = -32603; + +/// Trait for types that can be converted into JSON-RPC request objects. +/// +/// Implementing this trait allows a struct to be used as a typed JSON-RPC +/// request, with an associated method name and automatic conversion to the +/// request format. +pub trait JsonRpcRequest: Serialize { + const METHOD: &'static str; + fn into_request(self) -> RequestObject + where + Self: Sized, + { + Self::into_request_with_id(self, Some(generate_random_id())) + } + + fn into_request_with_id(self, id: Option) -> RequestObject + where + Self: Sized, + { + RequestObject { + jsonrpc: "2.0".into(), + method: Self::METHOD.into(), + params: Some(self), + id, + } + } +} + +/// Generates a random ID for JSON-RPC requests. +/// +/// Uses a secure random number generator to create a hex-encoded ID. Falls back +/// to a timestamp-based ID if random generation fails. +fn generate_random_id() -> String { + let mut bytes = [0u8; 10]; + match OsRng.try_fill_bytes(&mut bytes) { + Ok(_) => hex::encode(bytes), + Err(_) => { + // Fallback to a timestamp-based ID if random generation fails + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + format!("fallback-{}", timestamp) + } + } +} + +/// # RequestObject +/// +/// Represents a JSON-RPC 2.0 Request object, as defined in section 4 of the +/// specification. This structure encapsulates all necessary information for +/// a remote procedure call. +/// +/// # Type Parameters +/// +/// * `T`: The type of the `params` field. 
This *MUST* implement `Serialize` +/// to allow it to be encoded as JSON. Typically this will be a struct +/// implementing the `JsonRpcRequest` trait. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RequestObject { + /// **REQUIRED**. MUST be `"2.0"`. + pub jsonrpc: String, + /// **REQUIRED**. The method to be invoked. + pub method: String, + /// A struct containing the method parameters. + #[serde(skip_serializing_if = "is_none_or_null")] + pub params: Option, + /// An identifier established by the Client that MUST contain a String. + /// # Note: this is special to LSPS0, might change to match the more general + /// JSON-RPC 2.0 spec if needed. + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, +} + +impl RequestObject +where + T: Serialize, +{ + /// Returns the inner data object contained by params for handling or future + /// processing. + pub fn into_inner(self) -> Option { + self.params + } +} + +/// Helper function to check if params is None or would serialize to null. 
+fn is_none_or_null(opt: &Option) -> bool { + match opt { + None => true, + Some(val) => match serde_json::to_value(&val) { + Ok(Value::Null) => true, + _ => false, + }, + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct JsonRpcResponse { + id: String, + body: JsonRpcResponseBody, +} + +impl JsonRpcResponse<()> { + pub fn error>(error: RpcError, id: T) -> Self { + Self { + id: id.into(), + body: JsonRpcResponseBody::Error { error }, + } + } +} + +impl JsonRpcResponse { + pub fn success>(result: R, id: T) -> Self { + Self { + id: id.into(), + body: JsonRpcResponseBody::Success { result }, + } + } + + pub fn into_result(self) -> std::result::Result { + self.body.into_result() + } + + pub fn as_result(&self) -> std::result::Result<&R, &RpcError> { + self.body.as_result() + } + + pub fn is_ok(&self) -> bool { + self.body.is_ok() + } + + pub fn is_err(&self) -> bool { + self.body.is_err() + } + + pub fn map(self, f: F) -> JsonRpcResponse + where + F: FnOnce(R) -> U, + { + JsonRpcResponse { + id: self.id, + body: self.body.map(f), + } + } + + pub fn unwrap(self) -> R { + self.body.unwrap() + } + + pub fn expect(self, msg: &str) -> R { + self.body.expect(msg) + } + + pub fn unwrap_err(self) -> RpcError { + self.body.unwrap_err() + } + + pub fn expect_err(self, msg: &str) -> RpcError { + self.body.expect_err(msg) + } +} + +// Custom Serialize to match JSON-RPC 2.0 wire format +impl Serialize for JsonRpcResponse { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + + let mut state = serializer.serialize_struct("JsonRpcResponse", 3)?; + state.serialize_field("jsonrpc", "2.0")?; + state.serialize_field("id", &self.id)?; + + match &self.body { + JsonRpcResponseBody::Success { result } => { + state.serialize_field("result", result)?; + } + JsonRpcResponseBody::Error { error } => { + state.serialize_field("error", error)?; + } + } + + state.end() + } +} + +// Custom Deserialize from JSON-RPC 
2.0 wire format +impl<'de, R: DeserializeOwned> Deserialize<'de> for JsonRpcResponse { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct RawResponse { + jsonrpc: String, + result: Option, + error: Option, + id: String, + } + + let raw = RawResponse::deserialize(deserializer)?; + + if raw.jsonrpc != "2.0" { + return Err(serde::de::Error::custom(format!( + "Invalid JSON-RPC version: {}", + raw.jsonrpc + ))); + } + + let body = match (raw.result, raw.error) { + (Some(result), None) => JsonRpcResponseBody::Success { result }, + (None, Some(error)) => JsonRpcResponseBody::Error { error }, + (Some(_), Some(_)) => { + return Err(serde::de::Error::custom( + "Response cannot have both result and error", + )) + } + (None, None) => { + return Err(serde::de::Error::custom( + "Response must have either result or error", + )) + } + }; + + Ok(JsonRpcResponse { id: raw.id, body }) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum JsonRpcResponseBody { + Success { result: R }, + Error { error: RpcError }, +} + +impl JsonRpcResponseBody { + pub fn into_result(self) -> std::result::Result { + match self { + Self::Success { result } => Ok(result), + Self::Error { error } => Err(error), + } + } + + pub fn as_result(&self) -> std::result::Result<&R, &RpcError> { + match self { + Self::Success { result } => Ok(result), + Self::Error { error } => Err(error), + } + } + + pub fn is_ok(&self) -> bool { + matches!(self, JsonRpcResponseBody::Success { .. }) + } + + pub fn is_err(&self) -> bool { + matches!(self, JsonRpcResponseBody::Error { .. 
}) + } + + pub fn map(self, f: F) -> JsonRpcResponseBody + where + F: FnOnce(R) -> U, + { + match self { + Self::Success { result } => JsonRpcResponseBody::Success { result: f(result) }, + Self::Error { error } => JsonRpcResponseBody::Error { error }, + } + } + + pub fn unwrap(self) -> R { + match self { + Self::Success { result } => result, + Self::Error { error } => panic!("Called unwrap on RPC Error: {}", error), + } + } + + pub fn expect(self, msg: &str) -> R { + match self { + Self::Success { result } => result, + Self::Error { error } => panic!("{}: {}", msg, error), + } + } + + pub fn unwrap_err(self) -> RpcError { + match self { + JsonRpcResponseBody::Success { .. } => { + panic!("Called unwrap_err on RPC Success") + } + JsonRpcResponseBody::Error { error } => error, + } + } + + pub fn expect_err(self, msg: &str) -> RpcError { + match self { + JsonRpcResponseBody::Success { .. } => panic!("{}", msg), + JsonRpcResponseBody::Error { error } => error, + } + } +} + +/// Macro to generate RpcError helper methods for protocol-specific error codes +/// +/// This generates two methods for each error code: +/// - `method_name(message)` - Creates error without data +/// - `method_name_with_data(message, data)` - Creates error with data +macro_rules! rpc_error_methods { + ($($method:ident => $code:expr),* $(,)?) => { + $( + paste::paste! { + fn $method(message: T) -> $crate::proto::jsonrpc::RpcError { + $crate::proto::jsonrpc::RpcError { + code: $code, + message: message.to_string(), + data: None, + } + } + + fn [<$method _with_data>]( + message: T, + data: serde_json::Value, + ) -> $crate::proto::jsonrpc::RpcError { + $crate::proto::jsonrpc::RpcError { + code: $code, + message: message.to_string(), + data: Some(data), + } + } + } + )* + }; +} + +/// # RpcError +/// +/// Represents an error object in a JSON-RPC 2.0 Response object (section 5.1). +/// Provides structured information about an error that occurred during the +/// method invocation. 
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RpcError { + /// **REQUIRED**. An integer indicating the type of error. + pub code: i64, + /// **REQUIRED**. A string containing a short description of the error. + pub message: String, + /// A Primitive or Structured value that contains additional information + /// about the error, if there was any. + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +impl RpcError { + /// Reserved for implementation-defined server-errors. + pub fn custom_error(code: i64, message: T) -> Self { + RpcError { + code, + message: message.to_string(), + data: None, + } + } + + /// Reserved for implementation-defined server-errors. + pub fn custom_error_with_data( + code: i64, + message: T, + data: serde_json::Value, + ) -> Self { + RpcError { + code, + message: message.to_string(), + data: Some(data), + } + } +} + +pub trait RpcErrorExt { + rpc_error_methods! { + parse_error => PARSE_ERROR, + internal_error => INTERNAL_ERROR, + invalid_params => INVALID_PARAMS, + method_not_found => METHOD_NOT_FOUND, + invalid_request => INVALID_REQUEST, + } +} + +impl RpcErrorExt for RpcError {} + +impl fmt::Display for RpcError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "JSON-RPC Error (code: {}, message: {}, data: {:?})", + self.code, self.message, self.data + ) + } +} + +impl std::error::Error for RpcError {} + +#[cfg(test)] +mod test_message_serialization { + use super::*; + use serde_json::json; + + #[test] + fn test_empty_params_serialization() { + // Empty params should be omitted entirely instead of serializing to + // `"params":null`. 
+ #[derive(Debug, Serialize, Deserialize)] + pub struct SayHelloRequest; + impl JsonRpcRequest for SayHelloRequest { + const METHOD: &'static str = "say_hello"; + } + let rpc_request = SayHelloRequest.into_request(); + assert!(!serde_json::to_string(&rpc_request) + .expect("could not convert to json") + .contains("\"params\"")); + } + + #[test] + fn test_request_serialization_and_deserialization() { + // Ensure that we correctly serialize to a valid JSON-RPC 2.0 request. + #[derive(Default, Debug, Serialize, Deserialize)] + pub struct SayNameRequest { + name: String, + age: i32, + } + impl JsonRpcRequest for SayNameRequest { + const METHOD: &'static str = "say_name"; + } + let rpc_request = SayNameRequest { + name: "Satoshi".to_string(), + age: 99, + } + .into_request_with_id(Some("unique-id-123".into())); + + let json_value: serde_json::Value = serde_json::to_value(&rpc_request).unwrap(); + let expected_value: serde_json::Value = serde_json::json!({ + "jsonrpc": "2.0", + "method": "say_name", + "params": { + "name": "Satoshi", + "age": 99 + }, + "id": "unique-id-123" + }); + assert_eq!(json_value, expected_value); + + let request: RequestObject = serde_json::from_value(json_value).unwrap(); + assert_eq!(request.method, "say_name"); + assert_eq!(request.jsonrpc, "2.0"); + + let request: RequestObject = + serde_json::from_value(expected_value).unwrap(); + let inner = request.into_inner(); + assert_eq!(inner.unwrap().name, rpc_request.params.unwrap().name); + } + + #[test] + fn test_response_deserialization() { + // Check that we can convert a JSON-RPC response into a typed result. 
+ #[derive(Debug, Serialize, Deserialize, PartialEq)] + pub struct SayNameResponse { + name: String, + age: i32, + message: String, + } + + let json_response = r#" + { + "jsonrpc": "2.0", + "result": { + "age": 99, + "message": "Hello Satoshi!", + "name": "Satoshi" + }, + "id": "unique-id-123" + }"#; + + let response: JsonRpcResponse = + serde_json::from_str(json_response).unwrap(); + + let result = response.into_result().unwrap(); + assert_eq!( + result, + SayNameResponse { + name: "Satoshi".into(), + age: 99, + message: "Hello Satoshi!".into(), + } + ); + } + + #[test] + fn test_empty_result() { + // Check that we correctly deserialize an empty result. + #[derive(Debug, Serialize, Deserialize, PartialEq)] + pub struct DummyResponse {} + + let json_response = r#" + { + "jsonrpc": "2.0", + "result": {}, + "id": "unique-id-123" + }"#; + + let response: JsonRpcResponse = serde_json::from_str(json_response).unwrap(); + let result = response.into_result().unwrap(); + assert_eq!(result, DummyResponse {}); + } + #[test] + fn test_error_deserialization() { + // Check that we deserialize an error if we got one. + #[derive(Debug, Serialize, Deserialize, PartialEq)] + pub struct DummyResponse {} + + let json_response = r#" + { + "jsonrpc": "2.0", + "id": "unique-id-123", + "error": { + "code": -32099, + "message": "something bad happened", + "data": { + "f1": "v1", + "f2": 2 + } + } + }"#; + + let response: JsonRpcResponse = serde_json::from_str(json_response).unwrap(); + assert!(response.is_err()); + + let err = response.into_result().unwrap_err(); + assert!(matches!(err, RpcError { .. 
})); + assert_eq!(err.message, "something bad happened"); + assert_eq!(err.data, Some(serde_json::json!({"f1":"v1","f2":2}))); + } + + #[test] + fn test_error_serialization() { + let error = RpcError::invalid_request("Invalid request"); + let serialized = serde_json::to_string(&error).unwrap(); + assert_eq!(serialized, r#"{"code":-32600,"message":"Invalid request"}"#); + + let error_with_data = RpcError::internal_error_with_data( + "Internal server error", + json!({"details": "Something went wrong"}), + ); + let serialized_with_data = serde_json::to_string(&error_with_data).unwrap(); + assert_eq!( + serialized_with_data, + r#"{"code":-32603,"message":"Internal server error","data":{"details":"Something went wrong"}}"# + ); + } +} diff --git a/plugins/lsps-plugin/src/lsps0/primitives.rs b/plugins/lsps-plugin/src/proto/lsps0.rs similarity index 59% rename from plugins/lsps-plugin/src/lsps0/primitives.rs rename to plugins/lsps-plugin/src/proto/lsps0.rs index 3b2ff52d1e8c..2cb72812931f 100644 --- a/plugins/lsps-plugin/src/lsps0/primitives.rs +++ b/plugins/lsps-plugin/src/proto/lsps0.rs @@ -1,11 +1,70 @@ +use crate::proto::jsonrpc::{JsonRpcRequest, RpcError}; use core::fmt; -use serde::{ - de::{self}, - Deserialize, Deserializer, Serialize, Serializer, -}; +use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use thiserror::Error; const MSAT_PER_SAT: u64 = 1_000; +// Optional feature bit to set according to LSPS0. +pub const LSP_FEATURE_BIT: usize = 729; + +// Required message type for BOLT8 transport. +pub const LSPS0_MESSAGE_TYPE: u16 = 37913; + +// Lsps0 specific error codes defined in BLIP-50. +// Are in the range 00000 to 00099. +pub mod error_codes { + pub const CLIENT_REJECTED: i64 = 001; +} + +pub trait LSPS0RpcErrorExt { + rpc_error_methods! 
{ + client_rejected => error_codes::CLIENT_REJECTED, + } +} + +impl LSPS0RpcErrorExt for RpcError {} + +#[derive(Error, Debug)] +pub enum Error { + #[error("invalid message type: got {type_}")] + InvalidMessageType { type_: u16 }, + #[error("Invalid UTF-8 in message payload")] + InvalidUtf8(#[from] std::string::FromUtf8Error), + #[error("message too short: expected at least 2 bytes for type field, got {0}")] + TooShort(usize), +} + +pub type Result = std::result::Result; + +/// Encode raw payload bytes into an LSPS0 frame. +/// +/// Format: +/// [0-1] Message type: 37913 (0x9419) as big-endian u16 +/// [2..] Payload bytes +pub fn encode_frame(payload: &[u8]) -> Vec { + let mut bytes = Vec::with_capacity(2 + payload.len()); + bytes.extend_from_slice(&LSPS0_MESSAGE_TYPE.to_be_bytes()); + bytes.extend_from_slice(payload); + bytes +} + +/// Decode an LSPS0 frame and return the raw payload bytes. +/// +/// Validates that the header matches the LSPS0 message type (37913). +pub fn decode_frame(bytes: &[u8]) -> Result<&[u8]> { + if bytes.len() < 2 { + return Err(Error::TooShort(bytes.len())); + } + let message_type = u16::from_be_bytes([bytes[0], bytes[1]]); + if message_type != LSPS0_MESSAGE_TYPE { + return Err(Error::InvalidMessageType { + type_: message_type, + }); + } + Ok(&bytes[2..]) +} + /// Represents a monetary amount as defined in LSPS0.msat. Is converted to a /// `String` in json messages with a suffix `_msat` or `_sat` and internally /// represented as Millisatoshi `u64`. 
@@ -42,7 +101,7 @@ impl core::fmt::Display for Msat { } impl Serialize for Msat { - fn serialize(&self, serializer: S) -> Result + fn serialize(&self, serializer: S) -> std::result::Result where S: Serializer, { @@ -51,7 +110,7 @@ impl Serialize for Msat { } impl<'de> Deserialize<'de> for Msat { - fn deserialize(deserializer: D) -> Result + fn deserialize(deserializer: D) -> std::result::Result where D: Deserializer<'de>, { @@ -64,7 +123,7 @@ impl<'de> Deserialize<'de> for Msat { formatter.write_str("a string representing a number") } - fn visit_str(self, value: &str) -> Result + fn visit_str(self, value: &str) -> std::result::Result where E: de::Error, { @@ -75,14 +134,14 @@ impl<'de> Deserialize<'de> for Msat { } // Also handle if JSON mistakenly has a number instead of string - fn visit_u64(self, value: u64) -> Result + fn visit_u64(self, value: u64) -> std::result::Result where E: de::Error, { Ok(Msat::from_msat(value)) } - fn visit_i64(self, value: i64) -> Result + fn visit_i64(self, value: i64) -> std::result::Result where E: de::Error, { @@ -105,7 +164,7 @@ impl<'de> Deserialize<'de> for Msat { /// provides more clarity. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] #[serde(transparent)] // Key attribute! Serialize/Deserialize as the inner u32 -pub struct Ppm(pub u32); // u32 is sufficient as 1,000,000 fits easily +pub struct Ppm(pub u32); // u32 is sufficient as 1,000,000 fits easily (spec definition) impl Ppm { /// Constructs a new `Ppm` from a u32. @@ -139,16 +198,71 @@ pub type ShortChannelId = cln_rpc::primitives::ShortChannelId; /// timezone. 
pub type DateTime = chrono::DateTime; +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct Lsps0listProtocolsRequest {} + +impl JsonRpcRequest for Lsps0listProtocolsRequest { + const METHOD: &'static str = "lsps0.list_protocols"; +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct Lsps0listProtocolsResponse { + pub protocols: Vec, +} + #[cfg(test)] mod tests { use super::*; - use serde_json; + + const TEST_JSON: &str = r#"{"jsonrpc":"2.0","method":"test","id":"1"}"#; #[derive(Debug, Serialize, Deserialize)] struct TestMessage { amount: Msat, } + #[test] + fn test_encode_frame() { + let json = TEST_JSON.as_bytes(); + let wire_bytes = encode_frame(json); + + assert_eq!(wire_bytes.len(), 2 + json.len()); + assert_eq!(wire_bytes[0], 0x94); + assert_eq!(wire_bytes[1], 0x19); + assert_eq!(&wire_bytes[2..], json); + } + + #[test] + fn test_encode_decode_frame_roundtrip() { + let json = TEST_JSON.as_bytes(); + let wire_bytes = encode_frame(json); + let decoded = decode_frame(&wire_bytes).expect("should decode the frame"); + + assert_eq!(decoded, json) + } + + #[test] + fn test_decode_empty_frame() { + let result = decode_frame(&[]); + assert!(matches!(result, Err(Error::TooShort(0)))); + } + + #[test] + fn test_decode_single_byte_frame() { + let result = decode_frame(&[0x94]); + assert!(matches!(result, Err(Error::TooShort(1)))); + } + + #[test] + fn test_decode_frame_with_wrong_message_type() { + let bytes = vec![0x00, 0x01, b'{', b'}']; + let result = decode_frame(&bytes); + assert!(matches!( + result, + Err(Error::InvalidMessageType { type_: 1 }) + )); + } + /// Test serialization of a struct containing Msat. 
#[test] fn test_msat_serialization() { diff --git a/plugins/lsps-plugin/src/lsps2/model.rs b/plugins/lsps-plugin/src/proto/lsps2.rs similarity index 90% rename from plugins/lsps-plugin/src/lsps2/model.rs rename to plugins/lsps-plugin/src/proto/lsps2.rs index bcf8de079d05..82767a27ea53 100644 --- a/plugins/lsps-plugin/src/lsps2/model.rs +++ b/plugins/lsps-plugin/src/proto/lsps2.rs @@ -1,6 +1,6 @@ -use crate::{ +use crate::proto::{ jsonrpc::{JsonRpcRequest, RpcError}, - lsps0::primitives::{DateTime, Msat, Ppm, ShortChannelId}, + lsps0::{DateTime, Msat, Ppm, ShortChannelId}, }; use bitcoin::hashes::{sha256, Hash, HashEngine, Hmac, HmacEngine}; use chrono::Utc; @@ -12,55 +12,71 @@ pub mod failure_codes { pub const UNKNOWN_NEXT_PEER: &'static str = "4010"; } +// Lsps2 specific error codes defined in BLIP-52. +// Are in the range 00200 to 00299. +pub mod error_codes { + pub const INVALID_OPENING_FEE_PARAMS: i64 = 201; + pub const PAYMENT_SIZE_TOO_SMALL: i64 = 202; + pub const PAYMENT_SIZE_TOO_LARGE: i64 = 203; +} + #[derive(Clone, Debug, PartialEq)] pub enum Error { InvalidOpeningFeeParams, PaymentSizeTooSmall, PaymentSizeTooLarge, - ClientRejected, } impl core::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let err_str = match self { - Error::InvalidOpeningFeeParams => "invalid opening fee params", - Error::PaymentSizeTooSmall => "payment size too small", - Error::PaymentSizeTooLarge => "payment size too large", - Error::ClientRejected => "client rejected", + Error::InvalidOpeningFeeParams => "Invalid opening fee params", + Error::PaymentSizeTooSmall => "Payment size too small", + Error::PaymentSizeTooLarge => "Payment size too large", }; write!(f, "{}", &err_str) } } impl From for RpcError { - fn from(value: Error) -> Self { - match value { - Error::InvalidOpeningFeeParams => RpcError { - code: 201, - message: "invalid opening fee params".to_string(), - data: None, - }, - Error::PaymentSizeTooSmall => RpcError { - code: 
202, - message: "payment size too small".to_string(), - data: None, - }, - Error::PaymentSizeTooLarge => RpcError { - code: 203, - message: "payment size too large".to_string(), - data: None, - }, - Error::ClientRejected => RpcError { - code: 001, - message: "client rejected".to_string(), - data: None, - }, + fn from(error: Error) -> Self { + match error { + Error::InvalidOpeningFeeParams => RpcError::invalid_opening_fee_params(error), + Error::PaymentSizeTooSmall => RpcError::payment_size_too_small(error), + Error::PaymentSizeTooLarge => RpcError::payment_size_too_large(error), } } } impl core::error::Error for Error {} +pub trait LSPS2RpcErrorExt { + rpc_error_methods! { + invalid_opening_fee_params => error_codes::INVALID_OPENING_FEE_PARAMS, + payment_size_too_small => error_codes::PAYMENT_SIZE_TOO_SMALL, + payment_size_too_large => error_codes::PAYMENT_SIZE_TOO_LARGE + } +} + +impl LSPS2RpcErrorExt for RpcError {} + +pub trait ShortChannelIdJITExt { + fn generate_jit(blockheight: u32, distance: u32) -> Self; +} + +impl ShortChannelIdJITExt for ShortChannelId { + fn generate_jit(blockheight: u32, distance: u32) -> Self { + use rand::{rng, Rng as _}; + + let mut rng = rng(); + let block = blockheight + distance; + let tx_idx: u32 = rng.random_range(0..5000); + let output_idx: u16 = rng.random_range(0..10); + + (((block as u64) << 40) | ((tx_idx as u64) << 16) | (output_idx as u64)).into() + } +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct Lsps2GetInfoRequest { #[serde(skip_serializing_if = "Option::is_none")] @@ -99,7 +115,7 @@ impl core::error::Error for PromiseError {} #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[serde(try_from = "String")] -pub struct Promise(String); +pub struct Promise(pub String); impl Promise { pub const MAX_BYTES: usize = 512; @@ -259,6 +275,8 @@ impl From for Lsps2PolicyGetInfoRequest { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Lsps2PolicyGetInfoResponse { pub 
policy_opening_fee_params_menu: Vec, + #[serde(default)] + pub client_rejected: bool, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -303,6 +321,19 @@ impl PolicyOpeningFeeParams { .collect(); promise } + + pub fn with_promise(&self, secret: &[u8; 32]) -> OpeningFeeParams { + OpeningFeeParams { + min_fee_msat: self.min_fee_msat, + proportional: self.proportional, + valid_until: self.valid_until, + min_lifetime: self.min_lifetime, + max_client_to_self_delay: self.max_client_to_self_delay, + min_payment_size_msat: self.min_payment_size_msat, + max_payment_size_msat: self.max_payment_size_msat, + promise: Promise(self.get_hmac_hex(secret)), + } + } } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] diff --git a/plugins/lsps-plugin/src/proto/mod.rs b/plugins/lsps-plugin/src/proto/mod.rs new file mode 100644 index 000000000000..cd9677939f0e --- /dev/null +++ b/plugins/lsps-plugin/src/proto/mod.rs @@ -0,0 +1,4 @@ +#[macro_use] +pub mod jsonrpc; +pub mod lsps0; +pub mod lsps2; diff --git a/plugins/lsps-plugin/src/service.rs b/plugins/lsps-plugin/src/service.rs index 03f3f7e670f6..bc59ba51dbe2 100644 --- a/plugins/lsps-plugin/src/service.rs +++ b/plugins/lsps-plugin/src/service.rs @@ -1,35 +1,71 @@ -use anyhow::{anyhow, bail}; -use async_trait::async_trait; -use cln_lsps::jsonrpc::server::JsonRpcResponseWriter; -use cln_lsps::jsonrpc::TransportError; -use cln_lsps::jsonrpc::{server::JsonRpcServer, JsonRpcRequest}; -use cln_lsps::lsps0::handler::Lsps0ListProtocolsHandler; -use cln_lsps::lsps0::model::Lsps0listProtocolsRequest; -use cln_lsps::lsps0::transport::{self, CustomMsg}; -use cln_lsps::lsps2::cln::{HtlcAcceptedRequest, HtlcAcceptedResponse}; -use cln_lsps::lsps2::handler::{ClnApiRpc, HtlcAcceptedHookHandler}; -use cln_lsps::lsps2::model::{Lsps2BuyRequest, Lsps2GetInfoRequest}; -use cln_lsps::util::wrap_payload_with_peer_id; -use cln_lsps::{lsps0, lsps2}; -use cln_plugin::Plugin; -use cln_rpc::notifications::CustomMsgNotification; -use 
cln_rpc::primitives::PublicKey; -use log::debug; +use anyhow::bail; +use bitcoin::hashes::Hash; +use cln_lsps::{ + cln_adapters::{ + hooks::service_custommsg_hook, rpc::ClnApiRpc, sender::ClnSender, state::ServiceState, + types::HtlcAcceptedRequest, + }, + core::{ + lsps2::{ + htlc::{Htlc, HtlcAcceptedHookHandler, HtlcDecision, Onion, RejectReason}, + service::Lsps2ServiceHandler, + }, + server::LspsService, + }, + proto::lsps0::Msat, +}; +use cln_plugin::{options, Plugin}; +use log::{debug, error, trace}; use std::path::{Path, PathBuf}; -use std::str::FromStr; use std::sync::Arc; +pub const OPTION_ENABLED: options::FlagConfigOption = options::ConfigOption::new_flag( + "experimental-lsps2-service", + "Enables lsps2 for the LSP service", +); + +pub const OPTION_PROMISE_SECRET: options::StringConfigOption = + options::ConfigOption::new_str_no_default( + "experimental-lsps2-promise-secret", + "A 64-character hex string that is the secret for promises", + ); + #[derive(Clone)] struct State { - lsps_service: JsonRpcServer, + lsps_service: Arc, + sender: ClnSender, lsps2_enabled: bool, } +impl State { + pub fn new(rpc_path: PathBuf, promise_secret: &[u8; 32]) -> Self { + let api = Arc::new(ClnApiRpc::new(rpc_path.clone())); + let sender = ClnSender::new(rpc_path); + let lsps2_handler = Arc::new(Lsps2ServiceHandler::new(api, promise_secret)); + let lsps_service = Arc::new(LspsService::builder().with_protocol(lsps2_handler).build()); + Self { + lsps_service, + sender, + lsps2_enabled: true, + } + } +} + +impl ServiceState for State { + fn service(&self) -> Arc { + self.lsps_service.clone() + } + + fn sender(&self) -> cln_lsps::cln_adapters::sender::ClnSender { + self.sender.clone() + } +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { if let Some(plugin) = cln_plugin::Builder::new(tokio::io::stdin(), tokio::io::stdout()) - .option(lsps2::OPTION_ENABLED) - .option(lsps2::OPTION_PROMISE_SECRET) + .option(OPTION_ENABLED) + .option(OPTION_PROMISE_SECRET) // 
FIXME: Temporarily disabled lsp feature to please test cases, this is // ok as the feature is optional per spec. // We need to ensure that `connectd` only starts after all plugins have @@ -42,7 +78,7 @@ async fn main() -> Result<(), anyhow::Error> { // cln_plugin::FeatureBitsKind::Init, // util::feature_bit_to_hex(LSP_FEATURE_BIT), // ) - .hook("custommsg", on_custommsg) + .hook("custommsg", service_custommsg_hook) .hook("htlc_accepted", on_htlc_accepted) .configure() .await? @@ -50,9 +86,9 @@ async fn main() -> Result<(), anyhow::Error> { let rpc_path = Path::new(&plugin.configuration().lightning_dir).join(&plugin.configuration().rpc_file); - if plugin.option(&lsps2::OPTION_ENABLED)? { + if plugin.option(&OPTION_ENABLED)? { log::debug!("lsps2-service enabled"); - if let Some(secret_hex) = plugin.option(&lsps2::OPTION_PROMISE_SECRET)? { + if let Some(secret_hex) = plugin.option(&OPTION_PROMISE_SECRET)? { let secret_hex = secret_hex.trim().to_lowercase(); let decoded_bytes = match hex::decode(&secret_hex) { @@ -79,30 +115,7 @@ async fn main() -> Result<(), anyhow::Error> { } }; - let mut lsps_builder = JsonRpcServer::builder().with_handler( - Lsps0listProtocolsRequest::METHOD.to_string(), - Arc::new(Lsps0ListProtocolsHandler { - lsps2_enabled: plugin.option(&lsps2::OPTION_ENABLED)?, - }), - ); - - let cln_api_rpc = lsps2::handler::ClnApiRpc::new(rpc_path); - let getinfo_handler = - lsps2::handler::Lsps2GetInfoHandler::new(cln_api_rpc.clone(), secret); - let buy_handler = lsps2::handler::Lsps2BuyHandler::new(cln_api_rpc, secret); - lsps_builder = lsps_builder - .with_handler( - Lsps2GetInfoRequest::METHOD.to_string(), - Arc::new(getinfo_handler), - ) - .with_handler(Lsps2BuyRequest::METHOD.to_string(), Arc::new(buy_handler)); - - let lsps_service = lsps_builder.build(); - - let state = State { - lsps_service, - lsps2_enabled: true, - }; + let state = State::new(rpc_path, &secret); let plugin = plugin.start(state).await?; plugin.join().await } else { @@ -110,7 +123,7 
@@ async fn main() -> Result<(), anyhow::Error> { } } else { return plugin - .disable(&format!("`{}` not enabled", &lsps2::OPTION_ENABLED.name)) + .disable(&format!("`{}` not enabled", &OPTION_ENABLED.name)) .await; } } else { @@ -121,73 +134,128 @@ async fn main() -> Result<(), anyhow::Error> { async fn on_htlc_accepted( p: Plugin, v: serde_json::Value, +) -> Result { + Ok(handle_htlc_safe(&p, v).await) +} + +async fn handle_htlc_safe(p: &Plugin, v: serde_json::Value) -> serde_json::Value { + match handle_htlc_inner(p, v).await { + Ok(response) => response, + Err(e) => { + error!("HTLC hook error (continuing): {:#}", e); + json_continue() + } + } +} + +async fn handle_htlc_inner( + p: &Plugin, + v: serde_json::Value, ) -> Result { if !p.state().lsps2_enabled { - // just continue. - // Fixme: Add forward and extra tlvs from incoming. - let res = serde_json::to_value(&HtlcAcceptedResponse::continue_(None, None, None))?; - return Ok(res); + return Ok(json_continue()); } let req: HtlcAcceptedRequest = serde_json::from_value(v)?; + + let short_channel_id = match req.onion.short_channel_id { + Some(scid) => scid, + None => { + trace!("We are the destination of the HTLC, continue."); + return Ok(json_continue()); + } + }; + let rpc_path = Path::new(&p.configuration().lightning_dir).join(&p.configuration().rpc_file); let api = ClnApiRpc::new(rpc_path); // Fixme: Use real htlc_minimum_amount. let handler = HtlcAcceptedHookHandler::new(api, 1000); - let res = handler.handle(req).await?; - let res_val = serde_json::to_value(&res)?; - Ok(res_val) -} - -async fn on_custommsg( - p: Plugin, - v: serde_json::Value, -) -> Result { - // All of this could be done async if needed. 
- let continue_response = Ok(serde_json::json!({ - "result": "continue" - })); - let msg: CustomMsgNotification = - serde_json::from_value(v).map_err(|e| anyhow!("invalid custommsg: {e}"))?; - - let req = CustomMsg::from_str(&msg.payload).map_err(|e| anyhow!("invalid payload {e}"))?; - if req.message_type != lsps0::transport::LSPS0_MESSAGE_TYPE { - // We don't care if this is not for us! - return continue_response; - } - let dir = p.configuration().lightning_dir; - let rpc_path = Path::new(&dir).join(&p.configuration().rpc_file); - let mut writer = LspsResponseWriter { - peer_id: msg.peer_id, - rpc_path: rpc_path.try_into()?, + let onion = Onion { + short_channel_id, + payload: req.onion.payload, }; - // The payload inside CustomMsg is the actual JSON-RPC - // request/notification, we wrap it to attach the peer_id as well. - let payload = wrap_payload_with_peer_id(&req.payload, msg.peer_id); + let htlc = Htlc { + amount_msat: Msat::from_msat(req.htlc.amount_msat.msat()), + extra_tlvs: req.htlc.extra_tlvs.unwrap_or_default(), + }; - let service = p.state().lsps_service.clone(); - match service.handle_message(&payload, &mut writer).await { - Ok(_) => continue_response, + debug!("Handle potential jit-session HTLC."); + let response = match handler.handle(&htlc, &onion).await { + Ok(dec) => { + log_decision(&dec); + decision_to_response(dec)? + } Err(e) => { - debug!("failed to handle lsps message: {}", e); - continue_response + // Fixme: Should we log **BROKEN** here? + debug!("Htlc handler failed (continuing): {:#}", e); + return Ok(json_continue()); } - } + }; + + Ok(serde_json::to_value(&response)?) 
} -pub struct LspsResponseWriter { - peer_id: PublicKey, - rpc_path: PathBuf, +fn decision_to_response(decision: HtlcDecision) -> Result { + Ok(match decision { + HtlcDecision::NotOurs => json_continue(), + + HtlcDecision::Forward { + mut payload, + forward_to, + mut extra_tlvs, + } => json_continue_forward( + payload.to_bytes()?, + forward_to.as_byte_array().to_vec(), + extra_tlvs.to_bytes()?, + ), + + // Fixme: once we implement MPP-Support we need to remove this. + HtlcDecision::Reject { + reason: RejectReason::MppNotSupported, + } => json_continue(), + HtlcDecision::Reject { reason } => json_fail(reason.failure_code()), + }) +} + +fn json_continue() -> serde_json::Value { + serde_json::json!({"result": "continue"}) +} + +fn json_continue_forward( + payload: Vec, + forward_to: Vec, + extra_tlvs: Vec, +) -> serde_json::Value { + serde_json::json!({ + "result": "continue", + "payload": hex::encode(payload), + "forward_to": hex::encode(forward_to), + "extra_tlvs": hex::encode(extra_tlvs) + }) } -#[async_trait] -impl JsonRpcResponseWriter for LspsResponseWriter { - async fn write(&mut self, payload: &[u8]) -> cln_lsps::jsonrpc::Result<()> { - let mut client = cln_rpc::ClnRpc::new(&self.rpc_path).await.map_err(|e| { - cln_lsps::jsonrpc::Error::Transport(TransportError::Other(e.to_string())) - })?; - transport::send_custommsg(&mut client, payload.to_vec(), self.peer_id).await +fn json_fail(failure_code: &str) -> serde_json::Value { + serde_json::json!({ + "result": "fail", + "failure_message": failure_code + }) +} + +fn log_decision(decision: &HtlcDecision) { + match decision { + HtlcDecision::NotOurs => { + trace!("SCID not ours, continue"); + } + HtlcDecision::Forward { forward_to, .. } => { + debug!( + "Forwarding via JIT channel {}", + hex::encode(forward_to.as_byte_array()) + ); + } + HtlcDecision::Reject { reason } => { + debug!("Rejecting HTLC: {:?}", reason); + } } }