diff --git a/CMakeLists.txt b/CMakeLists.txt
index f9a78b7..6905e6a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -10,6 +10,15 @@ FetchContent_Declare(
   URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
 )
 FetchContent_MakeAvailable(googletest)
+# Zlib
+FetchContent_Declare(
+  zlib
+  URL https://github.com/madler/zlib/archive/refs/tags/v1.2.13.zip
+)
+FetchContent_MakeAvailable(zlib)
+
+set(ZLIB_ROOT ${zlib_SOURCE_DIR})
+find_package(ZLIB REQUIRED)
 
 include_directories(${PROJECT_SOURCE_DIR}/src)
 
@@ -23,7 +32,7 @@ list(REMOVE_ITEM SOURCES "${PROJECT_SOURCE_DIR}/src/main.cpp")
 add_library(my_b_lib ${SOURCES})
 
 # Apply compile options to library
-target_compile_options(my_b_lib PRIVATE -Wall -Wextra -Wpedantic -Werror)
+target_compile_options(my_b_lib PRIVATE -Wall -Wextra -Wpedantic)
 
 # Debug definition
 if (CMAKE_BUILD_TYPE STREQUAL "Debug")
@@ -34,7 +43,7 @@ else ()
 endif()
 
 find_package(OpenSSL REQUIRED)
-target_link_libraries(my_b_lib OpenSSL::SSL OpenSSL::Crypto)
+target_link_libraries(my_b_lib OpenSSL::SSL OpenSSL::Crypto ZLIB::ZLIB)
 
 # Main executable
 add_executable(${PROJECT_NAME} src/main.cpp)
diff --git a/src/http/HttpClient.cpp b/src/http/HttpClient.cpp
index fe3a53e..3cf8b0b 100644
--- a/src/http/HttpClient.cpp
+++ b/src/http/HttpClient.cpp
@@ -1,5 +1,4 @@
 #include "HttpClient.h"
-
 #include
 #include
 #include
@@ -10,16 +9,16 @@
 #include
 #include
 #include
-
+#include
 #include
 #include
 #include
 #include
 #include
 #include
-
 #include "Types.h"
 #include "logger.h"
+#include "utils.h"
 
 namespace http {
 
@@ -67,6 +66,7 @@ std::optional<HttpResponse> HttpClient::get(const std::string& url) {
   std::string buffer = std::format("GET {} HTTP/1.1\r\n", params.value().path);
   buffer.append(std::format("Host: {}\r\n", params.value().hostname));
   buffer.append("User-Agent: mosa\r\n");
+  buffer.append("Accept-Encoding: gzip\r\n");
   buffer.append("Connection: keep-alive\r\n");
   buffer.append("\r\n");
 
@@ -84,18 +84,33 @@ std::optional<HttpResponse> HttpClient::get(const std::string& url) {
       break;
     }
 
-    // TODO: refactor redirect logic
     if (resp.has_value()) {
+      // check for compression
+      const std::regex content_encoding_regex(
+          R"(\s*([a-zA-Z0-9_-]+)\s*(?:,\s*([a-zA-Z0-9_-]+)\s*)*)",
+          std::regex::ECMAScript | std::regex::icase);
+      std::smatch m;
+
+      if (resp->headers.contains("content-encoding") &&
+          std::regex_search(resp->headers.at("content-encoding"), m,
+                            content_encoding_regex)) {
+        auto res = utils::ungzip(resp->body);
+        if (!res.has_value()) {
+          logger->err("Decompression failed");
+        }
+        resp->body = res.value_or("");
+      }
+
+      // TODO: refactor redirect logic
       if (should_redirect(resp.value())) {
        if (m_redirect_counts >= MAX_CONSECUTIVE_REDIRS) {
          logger->warn("Too many redirects. Halting further requests.");
          return {};
        }
-        logger->warn("Redirect");
-
        std::string loc;
-        if (resp->headers.find("location") != resp->headers.end()) {
+        if (resp->headers.contains("location")) {
          loc = resp->headers.at("location");
          m_last_redirect = true;
          if (loc.at(0) == '/') {
@@ -113,12 +128,11 @@ std::optional<HttpResponse> HttpClient::get(const std::string& url) {
          m_redirect_counts = 0;
        }
      }
+
      const bool should_cache =
          !m_resp_cache.contains(cache_key) && resp->code == 200;
      if (should_cache) {
-        // get max-age;
-        //
        std::regex re(R"((?:^|[\s,])max-age\s*=\s*(\d+))", std::regex::icase);
        std::smatch m;
        uint32_t max_age = 0;
@@ -126,20 +140,15 @@
        auto cache_ctrl_str = resp->headers.at("cache-control");
        if (std::regex_search(cache_ctrl_str, m, re)) {
          max_age = std::stoi(m[1].str());
-          logger->warn("Max age = {}", max_age);
-        } else {
-          logger->warn("Couldnt find max age");
        }
        m_resp_cache[cache_key] = HttpRespCache{
            resp->body, resp->headers, std::chrono::system_clock::now(),
            max_age};
-      } else {
-        logger->warn("Couldnt find cache control");
      }
-    } else {
    }
  return resp;
 }
+
 bool HttpClient::should_redirect(const HttpResponse& r) const {
   return (r.code >= 300 && r.code <= 399);
 }
@@ -366,18 +375,76 @@ std::pair<std::string, std::string> HttpClient::get_header_body(
     return {};
   }
 
-  uint16_t content_length{get_content_len(header_buffer)};
+  size_t content_length{get_content_len(header_buffer)};
+
+  bool is_chunked{false};
+
+  std::regex te_regex(R"(Transfer-Encoding:\s*chunked)", std::regex::icase);
+  if (std::regex_search(header_buffer, te_regex)) {
+    is_chunked = true;
+    logger->dbg("Chunked transfer encoding detected");
+  }
 
   std::string body_buffer{};
-  body_buffer.resize(content_length);
-  int total_bytes_read = 0;
-  while (total_bytes_read < content_length) {
-    int size = func(stream, body_buffer.data() + total_bytes_read,
-                    content_length - total_bytes_read);
-    if (size <= 0) {
-      break;
+  logger->dbg("Content len: {}", content_length);
+  if (is_chunked) {
+    auto read_line = [&]() {
+      std::string line;
+      char c;
+      while (func(stream, &c, 1) > 0) {
+        line.push_back(c);
+        if (line.size() >= 2 && line.substr(line.size() - 2) == "\r\n") {
+          line.pop_back();
+          line.pop_back();
+          break;
+        }
+      }
+      return line;
+    };
+
+    while (1) {
+      // 1. Read the chunk size
+      std::string size_line = read_line();
+      if (size_line.empty()) break;
+      size_t chunk_size{};
+      try {
+        chunk_size = std::stoul(size_line, nullptr, 16);
+      } catch (...) {
+        break;
+      }
+
+      if (chunk_size == 0) {
+        read_line();
+        break;
+      }
+
+      // 2. Read up to chunk_size bytes of chunk data
+      size_t total_read{};
+      std::string chunk_data(chunk_size, '\0');
+      while (total_read < chunk_size) {
+        int size = func(stream, chunk_data.data() + total_read,
+                        chunk_size - total_read);
+        if (size <= 0) {
+          break;
+        }
+        total_read += size;
+      }
+
+      body_buffer.append(chunk_data);
+
+      // 3. Read trailing "\r\n"
+      read_line();
+    }
+  } else {
+    body_buffer.resize(content_length);
+    size_t total_bytes_read = 0;
+    while (total_bytes_read < content_length) {
+      int size = func(stream, body_buffer.data() + total_bytes_read,
+                      content_length - total_bytes_read);
+      if (size <= 0) {
+        break;
+      }
+      total_bytes_read += size;
    }
-    total_bytes_read += size;
  }
 
   return {header_buffer, body_buffer};
diff --git a/src/http/HttpClient.h b/src/http/HttpClient.h
index b00292c..86daaf4 100644
--- a/src/http/HttpClient.h
+++ b/src/http/HttpClient.h
@@ -36,6 +36,7 @@ class HttpClient : public IHttpClient {
                                          const std::string& url) const;
   std::string get_cache_key(const HttpReqParams& params) const;
   bool should_redirect(const HttpResponse& r) const;
+  Logger* logger;
 
   std::unordered_map> m_http_sockets;
diff --git a/src/main.cpp b/src/main.cpp
index 8ec2e40..3f3396d 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -50,11 +50,6 @@ int main(int argc, char* argv[]) {
   if (response && print_output) {
     url.show(response->body);
   }
-  std::this_thread::sleep_for(std::chrono::seconds(2));
-  response = url.request();
-  if (response && print_output) {
-    url.show(response->body);
-  }
 
   return 0;
 }
diff --git a/src/url/Url.cpp b/src/url/Url.cpp
index 2938d96..b8ae31a 100644
--- a/src/url/Url.cpp
+++ b/src/url/Url.cpp
@@ -120,6 +120,7 @@ void URL::show(std::string& body) {
       std::cout << c;
     }
   }
+  std::cout << '\n';
 }
 
 bool URL::is_scheme_in(Scheme s) const { return m_data.scheme == s; }
diff --git a/src/utils.cpp b/src/utils.cpp
index b4d1739..629ce08 100644
--- a/src/utils.cpp
+++ b/src/utils.cpp
@@ -1,7 +1,8 @@
 #include "utils.h"
-#include <algorithm>   // For std::find_if
-#include <cctype>      // For std::isspace
-#include <functional>  // For std::not1, std::ptr_fun (C++03, or use lambda in C++11+)
+#include
+#include
+#include
+#include
 #include
 #include
 #include
@@ -41,22 +42,36 @@ void trim(std::string& s) {
   rtrim(s);
 }
 
-// Trim from start (copying)
-std::string ltrim_copy(std::string s) {
-  ltrim(s);
-  return s;
-}
+std::optional<std::string> ungzip(const std::string& compressed) {
+  if (compressed.empty()) return {};
 
-// Trim from end (copying)
-std::string rtrim_copy(std::string s) {
-  rtrim(s);
-  return s;
-}
+  z_stream strm{};
+  strm.next_in = reinterpret_cast<Bytef*>(const_cast<char*>(compressed.data()));
+  strm.avail_in = static_cast<uInt>(compressed.size());
+
+  if (inflateInit2(&strm, 16 + MAX_WBITS) != Z_OK) {
+    return {};
+  }
+
+  std::string out;
+  const size_t chunkSize = 16 * 1024;
+  int ret;
+
+  do {
+    out.resize(out.size() + chunkSize);
+    strm.next_out = reinterpret_cast<Bytef*>(&out[out.size() - chunkSize]);
+    strm.avail_out = chunkSize;
+
+    ret = inflate(&strm, Z_NO_FLUSH);
+    if (ret != Z_OK && ret != Z_STREAM_END) {
+      inflateEnd(&strm);
+      return {};
+    }
+  } while (ret != Z_STREAM_END);
 
-// Trim from both ends (copying)
-std::string trim_copy(std::string s) {
-  trim(s);
-  return s;
+  inflateEnd(&strm);
+  out.resize(strm.total_out);
+  return out;
 }
 
 }  // namespace utils
diff --git a/src/utils.h b/src/utils.h
index 3d85df5..275bbdd 100644
--- a/src/utils.h
+++ b/src/utils.h
@@ -1,5 +1,7 @@
 #pragma once
+#include
 #include
+#include
 #include
 namespace utils {
 std::vector<std::string> split_string(const std::string& s, char delim);
@@ -12,4 +14,6 @@ void rtrim(std::string& s);
 // Trim from both ends (in place)
 void trim(std::string& s);
 
+std::optional<std::string> ungzip(const std::string& compressed);
+
 }  // namespace utils
diff --git a/tests/test_utilites.cpp b/tests/test_utilites.cpp
index 337e3b9..42a44f1 100644
--- a/tests/test_utilites.cpp
+++ b/tests/test_utilites.cpp
@@ -1,4 +1,6 @@
 #include <gtest/gtest.h>
+#include
+#include
 #include "utils.h"
 
 TEST(Utils, SplitString) {