From 0ce91bf39950114aab720008580f92ed19f64efb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rom=C3=A1n=20C=C3=A1rdenas?= Date: Mon, 31 Mar 2025 19:11:36 +0200 Subject: [PATCH 1/2] axum-like router --- Cargo.toml | 57 ++- examples/client.rs | 2 +- examples/router_client.rs | 61 ++++ examples/router_server.rs | 67 ++++ src/lib.rs | 2 + src/router/extract/body.rs | 68 ++++ src/router/extract/mod.rs | 53 +++ src/router/extract/path/de.rs | 619 +++++++++++++++++++++++++++++++++ src/router/extract/path/mod.rs | 182 ++++++++++ src/router/extract/query.rs | 65 ++++ src/router/extract/state.rs | 38 ++ src/router/extract/tuple.rs | 46 +++ src/router/handler.rs | 137 ++++++++ src/router/macros.rs | 105 ++++++ src/router/mod.rs | 98 ++++++ src/router/request.rs | 57 +++ src/router/response.rs | 116 ++++++ src/router/route.rs | 192 ++++++++++ src/router/util.rs | 30 ++ src/server.rs | 23 +- 20 files changed, 1999 insertions(+), 19 deletions(-) create mode 100644 examples/router_client.rs create mode 100644 examples/router_server.rs create mode 100644 src/router/extract/body.rs create mode 100644 src/router/extract/mod.rs create mode 100644 src/router/extract/path/de.rs create mode 100644 src/router/extract/path/mod.rs create mode 100644 src/router/extract/query.rs create mode 100644 src/router/extract/state.rs create mode 100644 src/router/extract/tuple.rs create mode 100644 src/router/handler.rs create mode 100644 src/router/macros.rs create mode 100644 src/router/mod.rs create mode 100644 src/router/request.rs create mode 100644 src/router/response.rs create mode 100644 src/router/route.rs create mode 100644 src/router/util.rs diff --git a/Cargo.toml b/Cargo.toml index 6c064d6a6..1bb0565e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,30 +14,55 @@ edition = "2021" url = "^2.5" log = "^0.4" regex = "^1.5" -tokio = {version = "^1.32", features = ["full"]} -tokio-util = {version = "^0.6", features = ["codec","net"]} -tokio-stream = {version = "^0.1", features = ["time"]} +tokio = { version = "^1.32", features = ["full"] } +tokio-util = { version = "^0.7", features = ["codec", "net"] } +tokio-stream = { version = "^0.1", features = ["time"] } futures = "^0.3" -coap-lite = "0.11.3" +coap-lite = "0.13.1" async-trait = "0.1.74" # dependencies for dtls -webrtc-dtls = {version = "0.8.0", optional = true} -webrtc-util = {version = "0.8.0", optional = true} -rustls = {version = "^0.21.1", optional = true} -rustls-pemfile = {version = "2.0.0", optional = true} -rcgen = {version = "^0.11.0", optional = true} -pkcs8 = {version = "0.10.2", optional = true} -sec1 = { version = "0.7.3", features = ["pem", "pkcs8", "std"], optional = true} -rand = "0.8.5" +webrtc-dtls = { version = "0.8.0", optional = true } +webrtc-util = { version = "0.8.0", optional = true } +rustls = { version = "^0.21.1", optional = true } +rustls-pemfile = { version = "2.0.0", optional = true } +rcgen = { version = "^0.11.0", optional = true } +pkcs8 = { version = "0.10.2", optional = true } +sec1 = { version = "0.7.3", features = [ + "pem", + "pkcs8", + "std", +], optional = true } +rand = "0.9.0" + +# dependencies for extractor +percent-encoding = { version = "2.1", optional = true} +serde = { version = "1.0", optional = true } +serde_html_form = { version = "0.2.7", optional = true } +serde_path_to_error = { version = "0.1.9", optional = true } +serde_json = { version = "1.0", optional = true } [features] -default = ["dtls"] -dtls = ["dep:webrtc-dtls", "dep:webrtc-util", "dep:rustls", "dep:rustls-pemfile", "dep:rcgen", "dep:pkcs8", "dep:sec1"] +default = ["dtls", "router"] +dtls = [ + "dep:webrtc-dtls", + "dep:webrtc-util", + "dep:rustls", + "dep:rustls-pemfile", + "dep:rcgen", + "dep:pkcs8", + "dep:sec1", +] +router = [ + "dep:percent-encoding", + "dep:serde", + "dep:serde_html_form", + "dep:serde_path_to_error", + "dep:serde_json", +] [dev-dependencies] -quickcheck = "0.8.2" +quickcheck = "1.0" socket2 = "0.5" tokio-test = "0.4.4" - diff --git a/examples/client.rs b/examples/client.rs index 8860e58ac..20ec94ce9 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -24,7 +24,7 @@ async fn main() { } async fn example_get() { - let url = "coap://127.0.0.1:5683/hello/get"; + let url = "coap://127.0.0.1:5683/temperature/kitchen"; println!("Client request: {}", url); match UdpCoAPClient::get(url).await { diff --git a/examples/router_client.rs b/examples/router_client.rs new file mode 100644 index 000000000..c5c636e6a --- /dev/null +++ b/examples/router_client.rs @@ -0,0 +1,61 @@ +extern crate coap; + +// use coap::client::ObserveMessage; +use coap::UdpCoAPClient; +// use std::io; +use std::io::ErrorKind; + +#[tokio::main] +async fn main() { + println!("GET url:"); + example_get().await; + + println!("POST data to url:"); + example_post().await; + + println!("GET url again:"); + example_get().await; +} + +async fn example_get() { + let url = "coap://127.0.0.1:5683/temperature?room=kitchen"; + println!("Client request: {}", url); + + match UdpCoAPClient::get(url).await { + Ok(response) => { + println!( + "Server reply: {}", + String::from_utf8(response.message.payload).unwrap() + ); + } + Err(e) => { + match e.kind() { + ErrorKind::WouldBlock => println!("Request timeout"), // Unix + ErrorKind::TimedOut => println!("Request timeout"), // Windows + _ => println!("Request error: {:?}", e), + } + } + } +} + +async fn example_post() { + let url = "coap://127.0.0.1:5683/temperature/kitchen"; + let data = b"21.0".to_vec(); + println!("Client request: {}", url); + + match UdpCoAPClient::post(url, data).await { + Ok(response) => { + println!( + "Server reply: {}", + String::from_utf8(response.message.payload).unwrap() + ); + } + Err(e) => { + match e.kind() { + ErrorKind::WouldBlock => println!("Request timeout"), // Unix + ErrorKind::TimedOut => println!("Request timeout"), // Windows + _ => println!("Request error: {:?}", e), + } + } + } +} diff --git a/examples/router_server.rs b/examples/router_server.rs new file mode 100644 index 000000000..95fdef4e7 --- /dev/null +++ b/examples/router_server.rs @@ -0,0 +1,67 @@ +extern crate coap; + +use coap::{ + router::{ + extract::{Body, Path, Query, State}, + Router, + }, + Server, +}; +use serde::Deserialize; +use std::{collections::HashMap, sync::Arc}; +use tokio::sync::Mutex; + +pub struct RoomState { + rooms: HashMap, +} + +pub type RoomMutex = Arc>; + +#[derive(Debug, Deserialize)] +pub struct QueryArgs { + room: String, +} + +async fn get_temperature( + Query(QueryArgs { room }): Query, + state: State, +) -> String { + println!("get_temperature: {room}"); + let state = state.lock().await; + + if let Some(temp) = state.rooms.get(&room) { + format!("Temperature in {room}: {temp}") + } else { + format!("Room {} not found", room) + } +} + +async fn set_temperature( + Path(room): Path, + Body(temp): Body, + State(state): State, +) -> String { + println!("set_temperature: {:?}", room); + let mut state = state.lock().await; + + state.rooms.insert(room, temp); + "OK".to_string() +} + +#[tokio::main] +async fn main() { + let addr = "127.0.0.1:5683"; + + let state = Arc::new(Mutex::new(RoomState { + rooms: HashMap::new(), + })); + + let router = Router::new() + .get("/temperature", get_temperature) + .post("/temperature/{room}", set_temperature); + + let server = Server::new_udp(addr).unwrap(); + println!("Server up on {addr}"); + + server.serve(router, state).await; +} diff --git a/src/lib.rs b/src/lib.rs index c8b98748a..643b8a628 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,4 +98,6 @@ pub mod client; pub mod dtls; mod observer; pub mod request; +#[cfg(feature = "router")] +pub mod router; pub mod server; diff --git a/src/router/extract/body.rs b/src/router/extract/body.rs new file mode 100644 index 000000000..7086d7a10 --- /dev/null +++ b/src/router/extract/body.rs @@ -0,0 +1,68 @@ +use crate::router::{ + extract::FromRequest, + request::Request, + response::{IntoResponse, Response, Status}, +}; +use serde::de::DeserializeOwned; +use std::ops::{Deref, DerefMut}; + +#[derive(Debug, Clone, Copy)] +pub enum BodyRejection { + InvalidBody, + DeserializationError, + InvalidUtf8, +} + +impl std::fmt::Display for BodyRejection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BodyRejection::InvalidBody => write!(f, "Invalid body string"), + BodyRejection::DeserializationError => write!(f, "Failed to deserialize JSON body"), + BodyRejection::InvalidUtf8 => write!(f, "Invalid UTF-8 in body"), + } + } +} + +impl IntoResponse for BodyRejection { + fn into_response(self) -> Response { + let error_message = self.to_string(); + Response::new() + .set_response_type(Status::BadRequest) + .set_payload(error_message.into_bytes()) + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Body(pub T); + +impl FromRequest for Body +where + S: Sync, + T: DeserializeOwned, +{ + type Rejection = BodyRejection; + + async fn from_request(req: &Request, _state: &S) -> Result { + // Convert payload bytes to UTF-8 string + let body_str = String::from_utf8(req.payload()).map_err(|_| BodyRejection::InvalidUtf8)?; + + // Deserialize JSON + serde_json::from_str::(&body_str) + .map(Body) + .map_err(|_| BodyRejection::DeserializationError) + } +} + +impl Deref for Body { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Body { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/src/router/extract/mod.rs b/src/router/extract/mod.rs new file mode 100644 index 000000000..d26c2e9d5 --- /dev/null +++ b/src/router/extract/mod.rs @@ -0,0 +1,53 @@ +use crate::router::{request::Request, response::IntoResponse}; +use std::future::Future; + +mod body; +mod path; +mod query; +mod state; +mod tuple; + +pub trait FromRequest: Sized { + type Rejection: IntoResponse; + + fn from_request( + req: &Request, + state: &S, + ) -> impl Future> + Send; +} + +pub trait FromRef { + fn from_ref(input: &T) -> Self; +} + +impl FromRef for T { + fn from_ref(input: &T) -> Self { + input.clone() + } +} + +pub trait OptionalFromRequest: Sized { + type Rejection: IntoResponse; + + fn from_request( + req: &Request, + state: &S, + ) -> impl Future, Self::Rejection>> + Send; +} + +impl FromRequest for Option +where + T: OptionalFromRequest, + S: Send + Sync, +{ + type Rejection = T::Rejection; + + async fn from_request(req: &Request, state: &S) -> Result, Self::Rejection> { + T::from_request(req, state).await + } +} + +pub use body::Body; +pub use path::Path; +pub use query::Query; +pub use state::State; diff --git a/src/router/extract/path/de.rs b/src/router/extract/path/de.rs new file mode 100644 index 000000000..be9a8a33d --- /dev/null +++ b/src/router/extract/path/de.rs @@ -0,0 +1,619 @@ +use crate::router::{extract::path::PathDeserializationError, util::PercentDecodedStr}; +use serde::{ + de::{self, DeserializeSeed, EnumAccess, Error, MapAccess, SeqAccess, VariantAccess, Visitor}, + forward_to_deserialize_any, Deserializer, +}; +use std::{any::type_name, sync::Arc}; + +macro_rules! unsupported_type { + ($trait_fn:ident) => { + fn $trait_fn(self, _: V) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializationError::UnsupportedType { + name: type_name::(), + }) + } + }; +} + +macro_rules! parse_single_value { + ($trait_fn:ident, $visit_fn:ident, $ty:literal) => { + fn $trait_fn(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.url_params.len() != 1 { + return Err(PathDeserializationError::WrongNumberOfParameters { + got: self.url_params.len(), + expected: 1, + }); + } + + let value = + self.url_params[0] + .1 + .parse() + .map_err(|_| PathDeserializationError::ParseError { + value: self.url_params[0].1.as_str().to_owned(), + expected_type: $ty, + })?; + visitor.$visit_fn(value) + } + }; +} + +pub(crate) struct PathDeserializer<'de> { + url_params: &'de [(Arc, PercentDecodedStr)], +} + +impl<'de> PathDeserializer<'de> { + #[inline] + pub(crate) fn new(url_params: &'de [(Arc, PercentDecodedStr)]) -> Self { + PathDeserializer { url_params } + } +} + +impl<'de> Deserializer<'de> for PathDeserializer<'de> { + type Error = PathDeserializationError; + + unsupported_type!(deserialize_bytes); + unsupported_type!(deserialize_option); + unsupported_type!(deserialize_identifier); + unsupported_type!(deserialize_ignored_any); + + parse_single_value!(deserialize_bool, visit_bool, "bool"); + parse_single_value!(deserialize_i8, visit_i8, "i8"); + parse_single_value!(deserialize_i16, visit_i16, "i16"); + parse_single_value!(deserialize_i32, visit_i32, "i32"); + parse_single_value!(deserialize_i64, visit_i64, "i64"); + parse_single_value!(deserialize_i128, visit_i128, "i128"); + parse_single_value!(deserialize_u8, visit_u8, "u8"); + parse_single_value!(deserialize_u16, visit_u16, "u16"); + parse_single_value!(deserialize_u32, visit_u32, "u32"); + parse_single_value!(deserialize_u64, visit_u64, "u64"); + parse_single_value!(deserialize_u128, visit_u128, "u128"); + parse_single_value!(deserialize_f32, visit_f32, "f32"); + parse_single_value!(deserialize_f64, visit_f64, "f64"); + parse_single_value!(deserialize_string, visit_string, "String"); + parse_single_value!(deserialize_byte_buf, visit_string, "String"); + parse_single_value!(deserialize_char, visit_char, "char"); + + fn deserialize_any(self, v: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_str(v) + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.url_params.len() != 1 { + return Err(PathDeserializationError::WrongNumberOfParameters { + got: self.url_params.len(), + expected: 1, + }); + } + let value = &self.url_params[0].1; + visitor.visit_borrowed_str(value) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_seq(SeqDeserializer { + params: self.url_params, + idx: 0, + }) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + if self.url_params.len() != len { + return Err(PathDeserializationError::WrongNumberOfParameters { + got: self.url_params.len(), + expected: len, + }); + } + visitor.visit_seq(SeqDeserializer { + params: self.url_params, + idx: 0, + }) + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if self.url_params.len() != len { + return Err(PathDeserializationError::WrongNumberOfParameters { + got: self.url_params.len(), + expected: len, + }); + } + visitor.visit_seq(SeqDeserializer { + params: self.url_params, + idx: 0, + }) + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_map(MapDeserializer { + params: self.url_params, + value: None, + key: None, + }) + } + + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_map(visitor) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + if self.url_params.len() != 1 { + return Err(PathDeserializationError::WrongNumberOfParameters { + got: self.url_params.len(), + expected: 1, + }); + } + + visitor.visit_enum(EnumDeserializer { + value: &self.url_params[0].1, + }) + } +} + +struct MapDeserializer<'de> { + params: &'de [(Arc, PercentDecodedStr)], + key: Option>, + value: Option<&'de PercentDecodedStr>, +} + +impl<'de> MapAccess<'de> for MapDeserializer<'de> { + type Error = PathDeserializationError; + + fn next_key_seed(&mut self, seed: K) -> Result, Self::Error> + where + K: DeserializeSeed<'de>, + { + match self.params.split_first() { + Some(((key, value), tail)) => { + self.value = Some(value); + self.params = tail; + self.key = Some(KeyOrIdx::Key(key)); + seed.deserialize(KeyDeserializer { key }).map(Some) + } + None => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'de>, + { + match self.value.take() { + Some(value) => seed.deserialize(ValueDeserializer { + key: self.key.take(), + value, + }), + None => Err(PathDeserializationError::custom("value is missing")), + } + } +} + +struct KeyDeserializer<'de> { + key: &'de str, +} + +macro_rules! parse_key { + ($trait_fn:ident) => { + fn $trait_fn(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_str(&self.key) + } + }; +} + +impl<'de> Deserializer<'de> for KeyDeserializer<'de> { + type Error = PathDeserializationError; + + parse_key!(deserialize_identifier); + parse_key!(deserialize_str); + parse_key!(deserialize_string); + + fn deserialize_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializationError::custom("Unexpected key type")) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char bytes + byte_buf option unit unit_struct seq tuple + tuple_struct map newtype_struct struct enum ignored_any + } +} + +macro_rules! parse_value { + ($trait_fn:ident, $visit_fn:ident, $ty:literal) => { + fn $trait_fn(mut self, visitor: V) -> Result + where + V: Visitor<'de>, + { + let v = self.value.parse().map_err(|_| { + if let Some(key) = self.key.take() { + match key { + KeyOrIdx::Key(key) => PathDeserializationError::ParseErrorAtKey { + key: key.to_owned(), + value: self.value.as_str().to_owned(), + expected_type: $ty, + }, + KeyOrIdx::Idx { idx: index, key: _ } => { + PathDeserializationError::ParseErrorAtIndex { + index, + value: self.value.as_str().to_owned(), + expected_type: $ty, + } + } + } + } else { + PathDeserializationError::ParseError { + value: self.value.as_str().to_owned(), + expected_type: $ty, + } + } + })?; + visitor.$visit_fn(v) + } + }; +} + +#[derive(Debug)] +struct ValueDeserializer<'de> { + key: Option>, + value: &'de PercentDecodedStr, +} + +impl<'de> Deserializer<'de> for ValueDeserializer<'de> { + type Error = PathDeserializationError; + + unsupported_type!(deserialize_map); + unsupported_type!(deserialize_identifier); + + parse_value!(deserialize_bool, visit_bool, "bool"); + parse_value!(deserialize_i8, visit_i8, "i8"); + parse_value!(deserialize_i16, visit_i16, "i16"); + parse_value!(deserialize_i32, visit_i32, "i32"); + parse_value!(deserialize_i64, visit_i64, "i64"); + parse_value!(deserialize_i128, visit_i128, "i128"); + parse_value!(deserialize_u8, visit_u8, "u8"); + parse_value!(deserialize_u16, visit_u16, "u16"); + parse_value!(deserialize_u32, visit_u32, "u32"); + parse_value!(deserialize_u64, visit_u64, "u64"); + parse_value!(deserialize_u128, visit_u128, "u128"); + parse_value!(deserialize_f32, visit_f32, "f32"); + parse_value!(deserialize_f64, visit_f64, "f64"); + parse_value!(deserialize_string, visit_string, "String"); + parse_value!(deserialize_byte_buf, visit_string, "String"); + parse_value!(deserialize_char, visit_char, "char"); + + fn deserialize_any(self, v: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_str(v) + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_str(self.value) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_bytes(self.value.as_bytes()) + } + + fn deserialize_option(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_some(self) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + struct PairDeserializer<'de> { + key: Option>, + value: Option<&'de PercentDecodedStr>, + } + + impl<'de> SeqAccess<'de> for PairDeserializer<'de> { + type Error = PathDeserializationError; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + match self.key.take() { + Some(KeyOrIdx::Idx { idx: _, key }) => { + return seed.deserialize(KeyDeserializer { key }).map(Some); + } + Some(KeyOrIdx::Key(_)) => { + return Err(PathDeserializationError::custom( + "array types are not supported", + )); + } + None => {} + }; + + self.value + .take() + .map(|value| seed.deserialize(ValueDeserializer { key: None, value })) + .transpose() + } + } + + if len == 2 { + match self.key { + Some(key) => visitor.visit_seq(PairDeserializer { + key: Some(key), + value: Some(self.value), + }), + // `self.key` is only `None` when deserializing maps so `deserialize_seq` + // wouldn't be called for that + None => unreachable!(), + } + } else { + Err(PathDeserializationError::UnsupportedType { + name: type_name::(), + }) + } + } + + fn deserialize_seq(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializationError::UnsupportedType { + name: type_name::(), + }) + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializationError::UnsupportedType { + name: type_name::(), + }) + } + + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializationError::UnsupportedType { + name: type_name::(), + }) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_enum(EnumDeserializer { value: self.value }) + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } +} + +struct EnumDeserializer<'de> { + value: &'de str, +} + +impl<'de> EnumAccess<'de> for EnumDeserializer<'de> { + type Error = PathDeserializationError; + type Variant = UnitVariant; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> + where + V: de::DeserializeSeed<'de>, + { + Ok(( + seed.deserialize(KeyDeserializer { key: self.value })?, + UnitVariant, + )) + } +} + +struct UnitVariant; + +impl<'de> VariantAccess<'de> for UnitVariant { + type Error = PathDeserializationError; + + fn unit_variant(self) -> Result<(), Self::Error> { + Ok(()) + } + + fn newtype_variant_seed(self, _seed: T) -> Result + where + T: DeserializeSeed<'de>, + { + Err(PathDeserializationError::UnsupportedType { + name: "newtype enum variant", + }) + } + + fn tuple_variant(self, _len: usize, _visitor: V) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializationError::UnsupportedType { + name: "tuple enum variant", + }) + } + + fn struct_variant( + self, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + Err(PathDeserializationError::UnsupportedType { + name: "struct enum variant", + }) + } +} + +struct SeqDeserializer<'de> { + params: &'de [(Arc, PercentDecodedStr)], + idx: usize, +} + +impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { + type Error = PathDeserializationError; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: DeserializeSeed<'de>, + { + match self.params.split_first() { + Some(((key, value), tail)) => { + self.params = tail; + let idx = self.idx; + self.idx += 1; + Ok(Some(seed.deserialize(ValueDeserializer { + key: Some(KeyOrIdx::Idx { idx, key }), + value, + })?)) + } + None => Ok(None), + } + } +} + +#[derive(Debug, Clone)] +enum KeyOrIdx<'de> { + Key(&'de str), + Idx { idx: usize, key: &'de str }, +} diff --git a/src/router/extract/path/mod.rs b/src/router/extract/path/mod.rs new file mode 100644 index 000000000..3e910e9a9 --- /dev/null +++ b/src/router/extract/path/mod.rs @@ -0,0 +1,182 @@ +use crate::router::{ + extract::FromRequest, + request::Request, + response::{IntoResponse, Response, Status}, +}; +use serde::de::DeserializeOwned; +use std::ops::{Deref, DerefMut}; + +mod de; + +#[derive(Debug)] +pub struct Path(pub T); + +impl Deref for Path { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Path { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl FromRequest for Path +where + S: Sync, + T: DeserializeOwned, +{ + type Rejection = PathDeserializationError; + + async fn from_request(req: &Request, _state: &S) -> Result { + match T::deserialize(de::PathDeserializer::new(&req.path)) { + Ok(value) => Ok(Path(value)), + Err(e) => Err(e), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PathDeserializationError { + /// The URI contained the wrong number of parameters. + WrongNumberOfParameters { + /// The number of actual parameters in the URI. + got: usize, + /// The number of expected parameters. + expected: usize, + }, + + /// Failed to parse the value at a specific key into the expected type. + /// + /// This variant is used when deserializing into types that have named fields, such as structs. + ParseErrorAtKey { + /// The key at which the value was located. + key: String, + /// The value from the URI. + value: String, + /// The expected type of the value. + expected_type: &'static str, + }, + + /// Failed to parse the value at a specific index into the expected type. + /// + /// This variant is used when deserializing into sequence types, such as tuples. + ParseErrorAtIndex { + /// The index at which the value was located. + index: usize, + /// The value from the URI. + value: String, + /// The expected type of the value. + expected_type: &'static str, + }, + + /// Failed to parse a value into the expected type. + /// + /// This variant is used when deserializing into a primitive type (such as `String` and `u32`). + ParseError { + /// The value from the URI. + value: String, + /// The expected type of the value. + expected_type: &'static str, + }, + + /// A parameter contained text that, once percent decoded, wasn't valid UTF-8. + InvalidUtf8InPathParam { + /// The key at which the invalid value was located. + key: String, + }, + + /// Tried to serialize into an unsupported type such as nested maps. + /// + /// This error kind is caused by programmer errors and thus gets converted into a `500 Internal + /// Server Error` response. + UnsupportedType { + /// The name of the unsupported type. + name: &'static str, + }, + + /// Failed to deserialize the value with a custom deserialization error. + DeserializeError { + /// The key at which the invalid value was located. + key: String, + /// The value that failed to deserialize. + value: String, + /// The deserializaation failure message. + message: String, + }, + + /// Catch-all variant for errors that don't fit any other variant. + Message(String), +} + +impl std::fmt::Display for PathDeserializationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Message(error) => error.fmt(f), + Self::InvalidUtf8InPathParam { key } => write!(f, "Invalid UTF-8 in `{key}`"), + Self::WrongNumberOfParameters { got, expected } => { + write!( + f, + "Wrong number of path arguments for `Path`. Expected {expected} but got {got}" + )?; + + if *expected == 1 { + write!(f, ". Note that multiple parameters must be extracted with a tuple `Path<(_, _)>` or a struct `Path`")?; + } + + Ok(()) + } + Self::UnsupportedType { name } => write!(f, "Unsupported type `{name}`"), + Self::ParseErrorAtKey { + key, + value, + expected_type, + } => write!( + f, + "Cannot parse `{key}` with value `{value}` to a `{expected_type}`" + ), + Self::ParseError { + value, + expected_type, + } => write!(f, "Cannot parse `{value}` to a `{expected_type}`"), + Self::ParseErrorAtIndex { + index, + value, + expected_type, + } => write!( + f, + "Cannot parse value at index {index} with value `{value}` to a `{expected_type}`" + ), + Self::DeserializeError { + key, + value, + message, + } => write!(f, "Cannot parse `{key}` with value `{value}`: {message}"), + } + } +} + +impl IntoResponse for PathDeserializationError { + fn into_response(self) -> Response { + let error_message = self.to_string(); + Response::new() + .set_response_type(Status::BadRequest) + .set_payload(error_message.into_bytes()) + } +} + +impl serde::de::Error for PathDeserializationError { + #[inline] + fn custom(msg: T) -> Self + where + T: std::fmt::Display, + { + Self::Message(msg.to_string()) + } +} + +impl std::error::Error for PathDeserializationError {} diff --git a/src/router/extract/query.rs b/src/router/extract/query.rs new file mode 100644 index 000000000..85e39b411 --- /dev/null +++ b/src/router/extract/query.rs @@ -0,0 +1,65 @@ +use crate::router::{ + extract::FromRequest, + request::Request, + response::{IntoResponse, Response, Status}, +}; +use serde::de::DeserializeOwned; +use std::ops::{Deref, DerefMut}; + +#[derive(Debug, Clone, Copy)] +pub enum QueryRejection { + InvalidQuery, + InvalidUtf8, +} + +impl std::fmt::Display for QueryRejection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + QueryRejection::InvalidQuery => write!(f, "Invalid query string"), + QueryRejection::InvalidUtf8 => write!(f, "Invalid UTF-8 in query"), + } + } +} + +impl IntoResponse for QueryRejection { + fn into_response(self) -> Response { + let error_message = self.to_string(); + Response::new() + .set_response_type(Status::BadRequest) + .set_payload(error_message.into_bytes()) + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct Query(pub T); + +impl FromRequest for Query +where + S: Sync, + T: DeserializeOwned, +{ + type Rejection = QueryRejection; + + async fn from_request(req: &Request, _state: &S) -> Result { + let query = req.query(); + let deserializer = + serde_html_form::Deserializer::new(url::form_urlencoded::parse(query.as_bytes())); + let value = serde_path_to_error::deserialize(deserializer) + .map_err(|_| QueryRejection::InvalidQuery)?; + Ok(Query(value)) + } +} + +impl Deref for Query { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Query { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/src/router/extract/state.rs b/src/router/extract/state.rs new file mode 100644 index 000000000..ec6cdbf9f --- /dev/null +++ b/src/router/extract/state.rs @@ -0,0 +1,38 @@ +use crate::router::{ + extract::{FromRef, FromRequest}, + request::Request, +}; +use std::{ + convert::Infallible, + ops::{Deref, DerefMut}, +}; + +#[derive(Debug, Default, Clone, Copy)] +pub struct State(pub S); + +impl FromRequest for State +where + InnerState: FromRef, + OuterState: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request(_req: &Request, state: &OuterState) -> Result { + let inner_state = InnerState::from_ref(state); + Ok(State(inner_state)) + } +} + +impl Deref for State { + type Target = S; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for State { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/src/router/extract/tuple.rs b/src/router/extract/tuple.rs new file mode 100644 index 000000000..8f183284c --- /dev/null +++ b/src/router/extract/tuple.rs @@ -0,0 +1,46 @@ +use crate::router::{ + extract::FromRequest, + macros::all_the_tuples, + request::Request, + response::{IntoResponse, Response}, +}; +use std::convert::Infallible; + +impl FromRequest for () { + type Rejection = Infallible; + + async fn from_request(_req: &Request, _state: &S) -> Result { + Ok::(()) + } +} + +macro_rules! impl_from_request { + ( + [$($ty:ident),*], $last:ident + ) => { + #[allow(non_snake_case, unused_mut, unused_variables)] + impl FromRequest for ($($ty,)* $last,) + where + S: Send + Sync, + $( $ty: FromRequest + Send, )* + $last: FromRequest + Send, + { + type Rejection = Response; + + async fn from_request(req: &Request, state: &S) -> Result { + $( + let $ty = $ty::from_request(req, state) + .await + .map_err(|err| err.into_response())?; + )* + let $last = $last::from_request(req, state) + .await + .map_err(|err| err.into_response())?; + + Ok(($($ty,)* $last,)) + } + } + }; +} + +all_the_tuples!(impl_from_request); diff --git a/src/router/handler.rs b/src/router/handler.rs new file mode 100644 index 000000000..2c2565d7a --- /dev/null +++ b/src/router/handler.rs @@ -0,0 +1,137 @@ +use crate::router::{ + extract::FromRequest, macros::all_the_tuples, request::Request, response::IntoResponse, +}; +use std::{future::Future, pin::Pin}; + +// Type-erased handler trait (no T parameter) +pub trait BoxableHandler: Send + Sync + 'static { + fn call(&self, req: Request, state: S) -> Pin + Send>>; + fn clone_box(&self) -> Box>; +} + +// Enable cloning of boxed handlers +impl Clone for Box> { + fn clone(&self) -> Self { + self.clone_box() + } +} + +impl BoxableHandler for HandlerWrapper +where + T: Sync + Send + 'static, + S: Send + Sync + 'static, + H: Handler + Clone + Send + Sync + 'static, +{ + fn call(&self, req: Request, state: S) -> Pin + Send>> { + let clone = self.clone(); + Box::pin(H::call(clone.handler, req, state)) + } + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } +} + +// Updated BoxedHandler type +pub type BoxedHandler = Box>; + +pub struct HandlerWrapper> { + handler: H, + _marker: std::marker::PhantomData<(T, S)>, +} + +impl HandlerWrapper +where + H: Handler + 'static, + T: 'static, +{ + pub fn new(handler: H) -> Self { + Self { + handler, + _marker: std::marker::PhantomData, + } + } +} + +impl Clone for HandlerWrapper +where + H: Handler + Clone, +{ + fn clone(&self) -> Self { + Self { + handler: self.handler.clone(), + _marker: std::marker::PhantomData, + } + } +} + +pub trait Handler: Clone + Send + Sync + 'static { + /// The type of future calling this handler returns. + type Future: Future + Send + 'static; + + /// Call the handler with the given request. + fn call(self, req: Request, state: S) -> Self::Future; +} + +impl Handler<(), S> for F +where + F: FnOnce() -> Fut + Clone + Send + Sync + 'static, + Fut: Future + Send, + Res: IntoResponse + 'static, +{ + type Future = Pin + Send>>; + + fn call(self, mut req: Request, _state: S) -> Self::Future { + Box::pin(async move { + let result = self().await.into_response(); + result.fill_response(&mut req); + req + }) + } +} + +macro_rules! impl_handler { + ( + [$($ty:ident),*], $last:ident + ) => { + #[allow(non_snake_case, unused_mut)] + impl<$($ty,)* $last, S, F, Fut, Res> Handler<($($ty,)* $last,), S> for F + where + S: Send + Sync + 'static, + F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + Sync + 'static, + Fut: Future + Send, + Res: IntoResponse, + $( $ty: FromRequest + Send, )* + $last: FromRequest + Send, + { + type Future = Pin + Send>>; + + fn call(self, mut req: Request, state: S) -> Self::Future { + Box::pin(async move { + $( + let $ty = match $ty::from_request(&req, &state).await { + Ok(value) => value, + Err(rejection) => { + rejection.into_response().fill_response(&mut req); + return req; + } + }; + )* + + let $last = match $last::from_request( &req, &state).await { + Ok(value) => value, + Err(rejection) => { + rejection.into_response().fill_response(&mut req); + return req; + } + }; + + let response = self($($ty,)* $last,).await; + response.into_response().fill_response(&mut req); + req + }) + } + } + }; +} + +all_the_tuples!(impl_handler); diff --git a/src/router/macros.rs b/src/router/macros.rs new file mode 100644 index 000000000..2195ffe33 --- /dev/null +++ b/src/router/macros.rs @@ -0,0 +1,105 @@ +/// Private API. +#[doc(hidden)] +#[macro_export] +macro_rules! composite_rejection { + ( + $(#[$m:meta])* + pub enum $name:ident { + $($variant:ident),+ + $(,)? + } + ) => { + $(#[$m])* + #[derive(Debug)] + #[non_exhaustive] + pub enum $name { + $( + #[allow(missing_docs)] + $variant($variant) + ),+ + } + + impl $crate::extractor::response::IntoResponse for $name { + fn into_response(self) -> $crate::extractor::response::Response { + match self { + $( + Self::$variant(inner) => inner.into_response(), + )+ + } + } + } + + impl $name { + /// Get the response body text used for this rejection. + pub fn body_text(&self) -> String { + match self { + $( + Self::$variant(inner) => inner.body_text(), + )+ + } + } + + /// Get the status code used for this rejection. + pub fn status(&self) -> http::StatusCode { + match self { + $( + Self::$variant(inner) => inner.status(), + )+ + } + } + } + + $( + impl From<$variant> for $name { + fn from(inner: $variant) -> Self { + Self::$variant(inner) + } + } + )+ + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + $( + Self::$variant(inner) => write!(f, "{inner}"), + )+ + } + } + } + + impl std::error::Error for $name { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + $( + Self::$variant(inner) => inner.source(), + )+ + } + } + } + }; +} + +#[macro_export] +#[rustfmt::skip] +macro_rules! all_the_tuples { + ($name:ident) => { + $name!([], T1); + $name!([T1], T2); + $name!([T1, T2], T3); + $name!([T1, T2, T3], T4); + $name!([T1, T2, T3, T4], T5); + $name!([T1, T2, T3, T4, T5], T6); + $name!([T1, T2, T3, T4, T5, T6], T7); + $name!([T1, T2, T3, T4, T5, T6, T7], T8); + $name!([T1, T2, T3, T4, T5, T6, T7, T8], T9); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13], T14); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14], T15); + $name!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15], T16); + }; +} + +pub use {all_the_tuples, composite_rejection}; diff --git a/src/router/mod.rs b/src/router/mod.rs new file mode 100644 index 000000000..d88548b5c --- /dev/null +++ b/src/router/mod.rs @@ -0,0 +1,98 @@ +pub mod extract; +pub mod handler; +pub mod macros; +pub mod request; +pub mod response; +pub mod route; +pub mod util; + +pub use coap_lite::RequestType as Method; +use handler::{BoxedHandler, Handler, HandlerWrapper}; +pub use request::Request; +use response::IntoResponse; +use route::{Route, RouteError}; + +#[derive(Default)] +pub struct Router { + routes: Vec<(Route, BoxedHandler)>, + fallback: Option>, +} + +impl Router { + pub fn new() -> Self { + Self { + routes: Vec::new(), + fallback: None, + } + } + + pub fn route(mut self, method: Method, path: impl ToString, handler: BoxedHandler) -> Self { + let route = Route::new(method, path); + self.routes.push((route, handler)); + self + } + + pub fn get(self, path: impl ToString, handler: H) -> Self + where + T: Send + Sync + 'static, + H: Handler, + { + let handler = Box::new(HandlerWrapper::new(handler)); + self.route(Method::Get, path, handler) + } + + pub fn post(self, path: impl ToString, handler: H) -> Self + where + T: Send + Sync + 'static, + H: Handler, + { + let handler = Box::new(HandlerWrapper::new(handler)); + self.route(Method::Post, path, handler) + } + + pub fn put(self, path: impl ToString, handler: H) -> Self + where + T: Send + Sync + 'static, + H: Handler, + { + let handler = Box::new(HandlerWrapper::new(handler)); + self.route(Method::Put, path, handler) + } + + pub fn delete(self, path: impl ToString, handler: H) -> Self + where + T: Send + Sync + 'static, + H: Handler, + { + let handler = Box::new(HandlerWrapper::new(handler)); + self.route(Method::Delete, path, handler) + } + + pub fn fallback(mut self, handler: H) -> Self + where + T: Send + Sync + 'static, + H: Handler, + { + let handler = Box::new(HandlerWrapper::new(handler)); + self.fallback = Some(handler); + self + } + + pub(crate) async fn handle(&self, mut req: Request, state: S) -> Request { + // routes are explored in order of registration + for (route, handler) in &self.routes { + if let Ok(path) = route.match_request(&mut req) { + req.path = path; + return handler.call(req, state).await; + } + } + // No route matched, use fallback or return not found + match self.fallback { + Some(ref fallback) => fallback.call(req, state).await, + None => { + RouteError::NotFound.into_response().fill_response(&mut req); + req + } + } + } +} diff --git a/src/router/request.rs b/src/router/request.rs new file mode 100644 index 000000000..45d7e498e --- /dev/null +++ b/src/router/request.rs @@ -0,0 +1,57 @@ +use crate::router::util::PercentDecodedStr; +use coap_lite::{CoapOption, CoapRequest}; +use std::{net::SocketAddr, sync::Arc}; + +pub struct Request { + pub req: Box>, + pub(crate) path: Vec<(Arc, PercentDecodedStr)>, +} + +impl Request { + pub fn new(req: Box>) -> Self { + Self { + req, + path: Vec::new(), + } + } + + pub fn response(&self) -> Option<&coap_lite::CoapResponse> { + self.req.response.as_ref() + } + + pub fn response_mut(&mut self) -> Option<&mut coap_lite::CoapResponse> { + self.req.response.as_mut() + } + + pub fn method(&self) -> coap_lite::RequestType { + *self.req.get_method() + } + + pub fn path_as_vec(&self) -> Vec { + self.req.get_path_as_vec().unwrap_or_default() + } + + pub fn path(&self) -> String { + self.path_as_vec().join("/") + } + + pub fn query_as_vec(&self) -> Vec { + let mut vec = Vec::new(); + if let Some(options) = self.req.message.get_option(CoapOption::UriQuery) { + for option in options.iter() { + if let Ok(seg) = core::str::from_utf8(option) { + vec.push(seg.to_string()); + } + } + }; + vec + } + + pub fn query(&self) -> String { + self.query_as_vec().join("&") + } + + pub fn payload(&self) -> Vec { + self.req.message.payload.clone() + } +} diff --git a/src/router/response.rs b/src/router/response.rs new file mode 100644 index 000000000..5a96391b2 --- /dev/null +++ b/src/router/response.rs @@ -0,0 +1,116 @@ +use crate::router::request::Request; +pub use coap_lite::ResponseType as Status; +use std::convert::Infallible; + +#[derive(Debug, Clone, Default)] +pub struct Response { + pub response_type: Option, + pub payload: Option>, +} + +impl Response { + pub fn new() -> Self { + Self { + response_type: None, + payload: None, + } + } + + pub fn set_response_type(mut self, response_type: Status) -> Self { + self.response_type = Some(response_type); + self + } + + pub fn set_payload(mut self, payload: Vec) -> Self { + self.payload = Some(payload); + self + } + + pub fn fill_response(self, request: &mut Request) { + if let Some(response) = request.response_mut() { + if let Some(response_type) = self.response_type { + response.set_status(response_type); + } + if let Some(payload) = &self.payload { + response.message.payload = payload.clone(); + } + } + } +} + +/// Trait for generating responses. +/// +/// Types that implement `IntoResponse` can be returned from handlers. +pub trait IntoResponse { + /// Create a response. + fn into_response(self) -> Response; +} + +impl IntoResponse for Response { + fn into_response(self) -> Response { + self + } +} + +impl IntoResponse for () { + fn into_response(self) -> Response { + Response::new() + } +} + +impl IntoResponse for Infallible { + fn into_response(self) -> Response { + match self {} + } +} + +impl IntoResponse for Result { + fn into_response(self) -> Response { + match self { + Ok(value) => value.into_response(), + Err(err) => err.into_response(), + } + } +} + +impl IntoResponse for Vec { + fn into_response(self) -> Response { + Response::new().set_payload(self) + } +} + +impl IntoResponse for [u8; N] { + fn into_response(self) -> Response { + self.to_vec().into_response() + } +} + +impl IntoResponse for &[u8; N] { + fn into_response(self) -> Response { + self.to_vec().into_response() + } +} + +impl IntoResponse for &[u8] { + fn into_response(self) -> Response { + self.to_vec().into_response() + } +} + +impl IntoResponse for String { + fn into_response(self) -> Response { + self.into_bytes().into_response() + } +} + +impl IntoResponse for &str { + fn into_response(self) -> Response { + self.as_bytes().into_response() + } +} + +impl IntoResponse for Box { + fn into_response(self) -> Response { + (*self).into_response() + } +} diff --git a/src/router/route.rs b/src/router/route.rs new file mode 100644 index 000000000..1ba98ab7e --- /dev/null +++ b/src/router/route.rs @@ -0,0 +1,192 @@ +use crate::router::{ + request::Request, + response::{IntoResponse, Response, Status}, + util::PercentDecodedStr, +}; +use coap_lite::RequestType as Method; +use regex::Regex; +use std::{hash::Hash, sync::Arc}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum RouteError { + DifferentMethod, + NoUriPath, + DifferentPath, + NotFound, +} + +impl std::fmt::Display for RouteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RouteError::DifferentMethod => write!(f, "Different method"), + RouteError::NoUriPath => write!(f, "No URI path"), + RouteError::DifferentPath => write!(f, "Different path"), + RouteError::NotFound => write!(f, "Not found"), + } + } +} + +impl IntoResponse for RouteError { + fn into_response(self) -> Response { + let error_message = self.to_string(); + Response::new() + .set_response_type(Status::BadRequest) + .set_payload(error_message.into_bytes()) + } +} + +#[derive(Debug, Clone)] +pub struct Route { + method: Method, + route: String, + regex: Regex, + param_names: Vec, +} + +impl Route { + pub fn new(method: Method, route: S) -> Self { + let route = route.to_string(); + let (regex, param_names) = Self::build_regex(&route); + + Route { + method, + route, + regex, + param_names, + } + } + + /// Build a regex pattern from a route string + fn build_regex(route: &str) -> (Regex, Vec) { + let mut param_names = Vec::new(); + let mut pattern = String::new(); + + // Add start anchor + pattern.push('^'); + + // Process each path segment. We ignore leading and trailing slashes + let segments = route.trim_matches('/').split('/'); + + for (i, segment) in segments.enumerate() { + if i > 0 { + pattern.push('/'); + } + + if segment.is_empty() { + continue; + } + + // Track position as we scan the segment + let mut pos = 0; + let mut in_param = false; + let mut param_start = 0; + + // Process the segment character by character + let chars: Vec = segment.chars().collect(); + for i in 0..chars.len() { + match chars[i] { + '{' => { + if in_param { + // Nested parameter start, which is invalid + panic!("Nested parameters are not allowed in route pattern: {route}"); + } + + // Start of parameter + in_param = true; + param_start = i; + + // Add preceding literal text with proper escaping + if i > pos { + pattern.push_str(®ex::escape(&segment[pos..i])); + } + } + '}' => { + if !in_param { + // Unmatched closing brace, which is invalid + panic!("Unmatched closing brace in route pattern: {route}",); + } + // End of parameter + in_param = false; + + // Extract parameter name + let param_name = &segment[param_start + 1..i]; + param_names.push(param_name.to_string()); + // Add capture group to pattern + pattern.push_str("([^/]+)"); + + // Update position + pos = i + 1; + } + _ => {} // Skip other characters + } + } + // If we ended in an unclosed parameter, that's an error + if in_param { + panic!("Unclosed parameter in route pattern: {route}"); + } + + // Handle any trailing literal text + if pos < segment.len() { + pattern.push_str(®ex::escape(&segment[pos..])); + } + } + + // Add end anchor + pattern.push('$'); + + println!("Regex pattern: {}", pattern); + println!("Parameter names: {:?}", param_names); + + // Compile the regex pattern + let regex = match Regex::new(&pattern) { + Ok(re) => re, + Err(err) => panic!("Invalid route pattern: {} ({})", route, err), + }; + + (regex, param_names) + } + + pub fn route(&self) -> &str { + &self.route + } + + #[inline] + pub fn match_method(&self, method: Method) -> Result<(), RouteError> { + match self.method == method { + true => Ok(()), + false => Err(RouteError::DifferentMethod), + } + } + + #[inline] + pub(crate) fn match_path( + &self, + path: &str, + ) -> Result, PercentDecodedStr)>, RouteError> { + let path = path.trim_matches('/'); + if let Some(captures) = self.regex.captures(path) { + let mut params = Vec::new(); + for (i, name) in self.param_names.iter().enumerate() { + if let Some(matched) = captures.get(i + 1) { + let value = matched.as_str(); + if let Some(decoded) = PercentDecodedStr::new(value) { + params.push((name.clone().into(), decoded)); + } else { + return Err(RouteError::NoUriPath); + } + } + } + Ok(params) + } else { + Err(RouteError::DifferentPath) + } + } + + pub(crate) fn match_request( + &self, + req: &mut Request, + ) -> Result, PercentDecodedStr)>, RouteError> { + self.match_method(req.method())?; + self.match_path(&req.path()) + } +} diff --git a/src/router/util.rs b/src/router/util.rs new file mode 100644 index 000000000..0e9d4095f --- /dev/null +++ b/src/router/util.rs @@ -0,0 +1,30 @@ +use std::{ops::Deref, sync::Arc}; + +/// A wrapper around a percent-decoded string. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub(crate) struct PercentDecodedStr(Arc); + +impl PercentDecodedStr { + /// Creates a new `PercentDecodedStr` from a percent-encoded string. + /// Usually, this is used to decode URI path segments or query parameters. + pub(crate) fn new>(s: S) -> Option { + percent_encoding::percent_decode(s.as_ref().as_bytes()) + .decode_utf8() + .ok() + .map(|decoded| Self(decoded.as_ref().into())) + } + + /// Returns the inner string as a `&str`. + pub(crate) fn as_str(&self) -> &str { + &self.0 + } +} + +impl Deref for PercentDecodedStr { + type Target = str; + + #[inline] + fn deref(&self) -> &Self::Target { + self.as_str() + } +} diff --git a/src/server.rs b/src/server.rs index 38ad2fd34..241e8817b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,3 +1,6 @@ +use crate::observer::Observer; +#[cfg(feature = "router")] +use crate::router::{request::Request, Router}; use async_trait::async_trait; use coap_lite::{BlockHandler, BlockHandlerConfig, CoapRequest, CoapResponse, Packet}; use log::debug; @@ -19,8 +22,6 @@ use tokio::{ task::JoinHandle, }; -use crate::observer::Observer; - #[derive(Debug)] pub enum CoAPServerError { NetworkError, @@ -430,6 +431,24 @@ impl Server { } } } + + #[cfg(feature = "router")] + pub async fn serve(self, router: Router, state: S) + where + S: Clone + Send + Sync + 'static, + { + let router = Arc::new(router); + let handler = { + move |req| { + let r = router.clone(); + let s = state.clone(); + let req = Request::new(req); + async move { r.handle(req, s).await.req } + } + }; + self.run(handler).await.unwrap(); + } + async fn respond_to_request(req: Box>, responder: Arc) { // if we have some reponse to send, send it if let Some(Ok(b)) = req.response.map(|resp| resp.message.to_bytes()) { From 66e495cece34095503ceda07bbf7a388e05c9fdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rom=C3=A1n=20C=C3=A1rdenas=20Rodr=C3=ADguez?= Date: Thu, 3 Apr 2025 13:32:21 +0200 Subject: [PATCH 2/2] leave previous client example as it was --- examples/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/client.rs b/examples/client.rs index 20ec94ce9..8860e58ac 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -24,7 +24,7 @@ async fn main() { } async fn example_get() { - let url = "coap://127.0.0.1:5683/temperature/kitchen"; + let url = "coap://127.0.0.1:5683/hello/get"; println!("Client request: {}", url); match UdpCoAPClient::get(url).await {