diff --git a/src/hteapot/mod.rs b/src/hteapot/mod.rs index 588d832..87ef9c2 100644 --- a/src/hteapot/mod.rs +++ b/src/hteapot/mod.rs @@ -188,7 +188,6 @@ impl Hteapot { let _ = socket_data.stream.shutdown(Shutdown::Both); return None; } - // If the request is not yet complete, read data from the stream into a buffer. // This ensures that the server can handle partial or chunked requests. if !status.request.done { @@ -211,7 +210,17 @@ impl Hteapot { return None; } status.ttl = Instant::now(); - let _ = status.request.append(buffer[..m].to_vec()); + let r = status.request.append(buffer[..m].to_vec()); + if r.is_err() { + // Early return response if not valid request is sended + let error_msg = r.err().unwrap(); + let response = + HttpResponse::new(HttpStatus::BadRequest, error_msg, None).to_bytes(); + let _ = socket_data.stream.write(&response); + let _ = socket_data.stream.flush(); + let _ = socket_data.stream.shutdown(Shutdown::Both); + return None; + } } } } @@ -224,10 +233,9 @@ impl Hteapot { let keep_alive = request .headers - .get("Connection") - .map(|v| v == "keep-alive") + .get("connection") //all headers are turn lowercase in the builder + .map(|v| v.to_lowercase() == "keep-alive") .unwrap_or(false); - if !status.write { let mut response = action(request); if keep_alive { diff --git a/src/hteapot/request.rs b/src/hteapot/request.rs index f00262f..4c7f6ef 100644 --- a/src/hteapot/request.rs +++ b/src/hteapot/request.rs @@ -1,6 +1,15 @@ +// Written by Alberto Ruiz 2025-01-01 +// This module handles the request struct and a builder for it +// This implementation has some issues and fixes are required for security +// Refactor is recomended, but for now can work with the fixes + use super::HttpMethod; -use std::{collections::HashMap, net::TcpStream, str}; +use std::{cmp::min, collections::HashMap, net::TcpStream, str}; + +const MAX_HEADER_SIZE: usize = 1024 * 16; +const MAX_HEADER_COUNT: usize = 100; +#[derive(Debug)] pub struct HttpRequest { pub method: HttpMethod, pub path: String, @@ -24,7 +33,7 @@ impl HttpRequest { pub fn default() -> Self { HttpRequest { - method: HttpMethod::GET, + method: HttpMethod::Other(String::new()), path: String::new(), args: HashMap::new(), headers: HashMap::new(), @@ -44,39 +53,6 @@ impl HttpRequest { }; } - // pub fn body(&mut self) -> Option> { - // if self.has_body() { - // let mut stream = self.stream.as_ref().unwrap(); - // let content_length = self.headers.get("Content-Length")?; - // let content_length: usize = content_length.parse().unwrap(); - // if content_length > self.body.len() { - // let _ = stream.flush(); - // let mut total_read = 0; - // self.body.resize(content_length, 0); - // while total_read < content_length { - // match stream.read(&mut self.body[total_read..]) { - // Ok(0) => { - // break; - // } - // Ok(n) => { - // total_read += n; - // } - // Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { - // continue; - // } - // Err(_e) => { - // break; - // } - // } - // } - // } - - // Some(self.body.clone()) - // } else { - // None - // } - // } - pub fn set_stream(&mut self, stream: TcpStream) { self.stream = Some(stream); } @@ -96,6 +72,8 @@ impl HttpRequest { pub struct HttpRequestBuilder { request: HttpRequest, buffer: Vec, + header_done: bool, + header_size: usize, body_size: usize, pub done: bool, } @@ -111,6 +89,8 @@ impl HttpRequestBuilder { body: Vec::new(), stream: None, }, + header_size: 0, + header_done: false, body_size: 0, buffer: Vec::new(), done: false, @@ -125,21 +105,63 @@ impl HttpRequestBuilder { } } - pub fn append(&mut self, buffer: Vec) -> Option { - self.buffer.extend(buffer); - self.buffer.retain(|&b| b != 0); + fn read_body_len(&mut self) -> Option<()> { + let body_left = self.body_size.saturating_sub(self.request.body.len()); + let to_take = min(body_left, self.buffer.len()); + let to_append = &self.buffer[..to_take]; + self.request.body.extend_from_slice(to_append); + self.buffer.drain(..to_take); + + if body_left > 0 { + return None; + } else { + self.done = true; + return Some(()); + } + } + + fn read_body_chunk(&mut self) -> Option<()> { + //TODO: this will support chunked body in the future + todo!() + } + + fn read_body(&mut self) -> Option<()> { + return self.read_body_len(); + } + + pub fn append(&mut self, chunk: Vec) -> Result<(), &'static str> { + if !self.header_done && self.buffer.len() > MAX_HEADER_SIZE { + return Err("Entity Too large"); + } + let chunk_size = chunk.len(); + self.buffer.extend(chunk); + if self.header_done { + self.read_body(); + return Ok(()); + } else { + self.header_size += chunk_size; + if self.header_size > MAX_HEADER_SIZE { + return Err("Entity Too large"); + } + } while let Some(pos) = self.buffer.windows(2).position(|w| w == b"\r\n") { let line = self.buffer.drain(..pos).collect::>(); self.buffer.drain(..2); - let line_str = String::from_utf8_lossy(&line); + let line_str = match str::from_utf8(line.as_slice()) { + Ok(v) => v.to_string(), + Err(_e) => return Err("No utf-8"), + }; if self.request.path.is_empty() { let parts: Vec<&str> = line_str.split_whitespace().collect(); if parts.len() < 2 { - return None; + return Ok(()); } + if parts.len() != 3 { + return Err("Invalid method + path + version request"); + } self.request.method = HttpMethod::from_str(parts[0]); let path_parts: Vec<&str> = parts[1].split('?').collect(); self.request.path = path_parts[0].to_string(); @@ -158,21 +180,41 @@ impl HttpRequestBuilder { .collect(); } } else if !line_str.is_empty() { - if let Some((key, value)) = line_str.split_once(": ") { - if key.to_lowercase() == "content-length" { + if let Some((key, value)) = line_str.split_once(":") { + //Check the number of headers, if the actual headers exceed that number + //drop the connection + if self.request.headers.len() > MAX_HEADER_COUNT { + return Err("Header number exceed allowed"); + } + let key = key.trim().to_lowercase(); + let value = value.trim(); + if key == "content-length" { + if self.request.headers.get("content-length").is_some() + || self + .request + .headers + .get("transfer-encoding") + .map(|te| te == "chunked") + .unwrap_or(false) + { + return Err("Duplicated content-length"); + } self.body_size = value.parse().unwrap_or(0); } self.request .headers .insert(key.to_string(), value.to_string()); } + } else { + self.header_done = true; + self.read_body(); + return Ok(()); } } - self.request.body.append(&mut self.buffer.clone()); - if self.request.body.len() == self.body_size { - self.done = true; - return Some(self.request.clone()); - } - None + Ok(()) } } + +#[cfg(test)] +#[test] +fn basic_request() {}