Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions src/hteapot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ impl Hteapot {
let _ = socket_data.stream.shutdown(Shutdown::Both);
return None;
}

// If the request is not yet complete, read data from the stream into a buffer.
// This ensures that the server can handle partial or chunked requests.
if !status.request.done {
Expand All @@ -211,7 +210,17 @@ impl Hteapot {
return None;
}
status.ttl = Instant::now();
let _ = status.request.append(buffer[..m].to_vec());
let r = status.request.append(buffer[..m].to_vec());
if r.is_err() {
// Early return response if not valid request is sended
let error_msg = r.err().unwrap();
let response =
HttpResponse::new(HttpStatus::BadRequest, error_msg, None).to_bytes();
let _ = socket_data.stream.write(&response);
let _ = socket_data.stream.flush();
let _ = socket_data.stream.shutdown(Shutdown::Both);
return None;
}
}
}
}
Expand All @@ -224,10 +233,9 @@ impl Hteapot {

let keep_alive = request
.headers
.get("Connection")
.map(|v| v == "keep-alive")
.get("connection") //all headers are turn lowercase in the builder
.map(|v| v.to_lowercase() == "keep-alive")
.unwrap_or(false);

if !status.write {
let mut response = action(request);
if keep_alive {
Expand Down
138 changes: 90 additions & 48 deletions src/hteapot/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
// Written by Alberto Ruiz 2025-01-01
// This module handles the request struct and a builder for it
// This implementation has some issues and fixes are required for security
// Refactor is recomended, but for now can work with the fixes

use super::HttpMethod;
use std::{collections::HashMap, net::TcpStream, str};
use std::{cmp::min, collections::HashMap, net::TcpStream, str};

const MAX_HEADER_SIZE: usize = 1024 * 16;
const MAX_HEADER_COUNT: usize = 100;

#[derive(Debug)]
pub struct HttpRequest {
pub method: HttpMethod,
pub path: String,
Expand All @@ -24,7 +33,7 @@ impl HttpRequest {

pub fn default() -> Self {
HttpRequest {
method: HttpMethod::GET,
method: HttpMethod::Other(String::new()),
path: String::new(),
args: HashMap::new(),
headers: HashMap::new(),
Expand All @@ -44,39 +53,6 @@ impl HttpRequest {
};
}

// pub fn body(&mut self) -> Option<Vec<u8>> {
// if self.has_body() {
// let mut stream = self.stream.as_ref().unwrap();
// let content_length = self.headers.get("Content-Length")?;
// let content_length: usize = content_length.parse().unwrap();
// if content_length > self.body.len() {
// let _ = stream.flush();
// let mut total_read = 0;
// self.body.resize(content_length, 0);
// while total_read < content_length {
// match stream.read(&mut self.body[total_read..]) {
// Ok(0) => {
// break;
// }
// Ok(n) => {
// total_read += n;
// }
// Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
// continue;
// }
// Err(_e) => {
// break;
// }
// }
// }
// }

// Some(self.body.clone())
// } else {
// None
// }
// }

pub fn set_stream(&mut self, stream: TcpStream) {
self.stream = Some(stream);
}
Expand All @@ -96,6 +72,8 @@ impl HttpRequest {
pub struct HttpRequestBuilder {
request: HttpRequest,
buffer: Vec<u8>,
header_done: bool,
header_size: usize,
body_size: usize,
pub done: bool,
}
Expand All @@ -111,6 +89,8 @@ impl HttpRequestBuilder {
body: Vec::new(),
stream: None,
},
header_size: 0,
header_done: false,
body_size: 0,
buffer: Vec::new(),
done: false,
Expand All @@ -125,21 +105,63 @@ impl HttpRequestBuilder {
}
}

pub fn append(&mut self, buffer: Vec<u8>) -> Option<HttpRequest> {
self.buffer.extend(buffer);
self.buffer.retain(|&b| b != 0);
fn read_body_len(&mut self) -> Option<()> {
let body_left = self.body_size.saturating_sub(self.request.body.len());
let to_take = min(body_left, self.buffer.len());
let to_append = &self.buffer[..to_take];
self.request.body.extend_from_slice(to_append);
self.buffer.drain(..to_take);

if body_left > 0 {
return None;
} else {
self.done = true;
return Some(());
}
}

fn read_body_chunk(&mut self) -> Option<()> {
//TODO: this will support chunked body in the future
todo!()
}

fn read_body(&mut self) -> Option<()> {
return self.read_body_len();
}

pub fn append(&mut self, chunk: Vec<u8>) -> Result<(), &'static str> {
if !self.header_done && self.buffer.len() > MAX_HEADER_SIZE {
return Err("Entity Too large");
}
let chunk_size = chunk.len();
self.buffer.extend(chunk);
if self.header_done {
self.read_body();
return Ok(());
} else {
self.header_size += chunk_size;
if self.header_size > MAX_HEADER_SIZE {
return Err("Entity Too large");
}
}
while let Some(pos) = self.buffer.windows(2).position(|w| w == b"\r\n") {
let line = self.buffer.drain(..pos).collect::<Vec<u8>>();
self.buffer.drain(..2);

let line_str = String::from_utf8_lossy(&line);
let line_str = match str::from_utf8(line.as_slice()) {
Ok(v) => v.to_string(),
Err(_e) => return Err("No utf-8"),
};

if self.request.path.is_empty() {
let parts: Vec<&str> = line_str.split_whitespace().collect();
if parts.len() < 2 {
return None;
return Ok(());
}

if parts.len() != 3 {
return Err("Invalid method + path + version request");
}
self.request.method = HttpMethod::from_str(parts[0]);
let path_parts: Vec<&str> = parts[1].split('?').collect();
self.request.path = path_parts[0].to_string();
Expand All @@ -158,21 +180,41 @@ impl HttpRequestBuilder {
.collect();
}
} else if !line_str.is_empty() {
if let Some((key, value)) = line_str.split_once(": ") {
if key.to_lowercase() == "content-length" {
if let Some((key, value)) = line_str.split_once(":") {
//Check the number of headers, if the actual headers exceed that number
//drop the connection
if self.request.headers.len() > MAX_HEADER_COUNT {
return Err("Header number exceed allowed");
}
let key = key.trim().to_lowercase();
let value = value.trim();
if key == "content-length" {
if self.request.headers.get("content-length").is_some()
|| self
.request
.headers
.get("transfer-encoding")
.map(|te| te == "chunked")
.unwrap_or(false)
{
return Err("Duplicated content-length");
}
self.body_size = value.parse().unwrap_or(0);
}
self.request
.headers
.insert(key.to_string(), value.to_string());
}
} else {
self.header_done = true;
self.read_body();
return Ok(());
}
}
self.request.body.append(&mut self.buffer.clone());
if self.request.body.len() == self.body_size {
self.done = true;
return Some(self.request.clone());
}
None
Ok(())
}
}

#[cfg(test)]
#[test]
fn basic_request() {}