diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4e2361c1..c3a340c6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -73,13 +73,15 @@ jobs: ./transaction_manager_tests ./statement_tests ./recovery_tests + ./recovery_manager_tests + ./buffer_pool_tests - name: Generate Coverage Report if: matrix.sanitizer == 'address' && matrix.compiler == 'clang++' run: | cd build lcov --directory . --capture --output-file coverage.info --gcov-tool "$GITHUB_WORKSPACE/.github/scripts/llvm-gcov.sh" --ignore-errors inconsistent - lcov --remove coverage.info '/usr/*' '*/tests/*' '*/CMakeFiles/*' --output-file filtered_coverage.info --ignore-errors inconsistent,unused + lcov --remove coverage.info '/usr/*' '*/tests/*' '*/CMakeFiles/*' '*/_deps/*' --output-file filtered_coverage.info --ignore-errors inconsistent,unused genhtml filtered_coverage.info --output-directory out_coverage --ignore-errors inconsistent - name: Upload Coverage diff --git a/include/storage/buffer_pool_manager.hpp b/include/storage/buffer_pool_manager.hpp index 4617fb45..8481cc67 100644 --- a/include/storage/buffer_pool_manager.hpp +++ b/include/storage/buffer_pool_manager.hpp @@ -7,6 +7,7 @@ #define CLOUDSQL_STORAGE_BUFFER_POOL_MANAGER_HPP #include +#include #include #include #include @@ -95,6 +96,11 @@ class BufferPoolManager { */ void flush_all_pages(); + /** + * @brief Get pointer to log manager + */ + [[nodiscard]] recovery::LogManager* get_log_manager() const { return log_manager_; } + private: /** * @brief Generates a unique string key for file and page mapping @@ -111,7 +117,7 @@ class BufferPoolManager { std::mutex latch_; // The actual array of pages - std::vector pages_; + std::unique_ptr pages_; // Replacer instance LRUReplacer replacer_; diff --git a/include/storage/storage_manager.hpp b/include/storage/storage_manager.hpp index 908d98a9..aa4debc9 100644 --- a/include/storage/storage_manager.hpp +++ b/include/storage/storage_manager.hpp @@ -73,6 +73,18 @@ class StorageManager { */ bool write_page(const std::string& filename, uint32_t page_num, const char* buffer); + /** + * @brief Allocate a new page in the database file + * @param filename Name of the database file + * @return index of the newly allocated page + */ + uint32_t allocate_page(const std::string& filename); + + /** + * @brief Deallocate a page (stub for future use) + */ + static void deallocate_page(const std::string& filename, uint32_t page_num); + /** * @brief Check if a file exists */ diff --git a/include/transaction/transaction.hpp b/include/transaction/transaction.hpp index 6a394ffa..c86b4b9a 100644 --- a/include/transaction/transaction.hpp +++ b/include/transaction/transaction.hpp @@ -109,11 +109,11 @@ class Transaction { exclusive_locks_.insert(rid); } - [[nodiscard]] const std::unordered_set& get_shared_locks() { + [[nodiscard]] std::unordered_set get_shared_lock_set() { const std::scoped_lock lock(lock_set_mutex_); return shared_locks_; } - [[nodiscard]] const std::unordered_set& get_exclusive_locks() { + [[nodiscard]] std::unordered_set get_exclusive_lock_set() { const std::scoped_lock lock(lock_set_mutex_); return exclusive_locks_; } diff --git a/include/transaction/transaction_manager.hpp b/include/transaction/transaction_manager.hpp index 09487db7..a3650595 100644 --- a/include/transaction/transaction_manager.hpp +++ b/include/transaction/transaction_manager.hpp @@ -7,41 +7,77 @@ #define CLOUDSQL_TRANSACTION_TRANSACTION_MANAGER_HPP #include +#include #include +#include #include +#include #include "catalog/catalog.hpp" -#include "recovery/log_manager.hpp" #include "storage/buffer_pool_manager.hpp" #include "transaction/lock_manager.hpp" #include "transaction/transaction.hpp" namespace cloudsql::transaction { +/** + * @brief Manages the lifecycle of transactions + */ class TransactionManager { - private: - std::atomic next_txn_id_{1}; - std::unordered_map> active_transactions_; - std::unordered_map> completed_transactions_; - LockManager& lock_manager_; - Catalog& catalog_; - storage::BufferPoolManager& bpm_; - recovery::LogManager* log_manager_; - std::mutex manager_latch_; - - void undo_transaction(Transaction* txn); - public: explicit TransactionManager(LockManager& lock_manager, Catalog& catalog, storage::BufferPoolManager& bpm, - recovery::LogManager* log_manager = nullptr) - : lock_manager_(lock_manager), catalog_(catalog), bpm_(bpm), log_manager_(log_manager) {} + recovery::LogManager* log_manager = nullptr); + + ~TransactionManager() = default; + // Disable copy/move + TransactionManager(const TransactionManager&) = delete; + TransactionManager& operator=(const TransactionManager&) = delete; + TransactionManager(TransactionManager&&) = delete; + TransactionManager& operator=(TransactionManager&&) = delete; + + /** + * @brief Start a new transaction + * @param level Isolation level + * @return Pointer to the new transaction + */ Transaction* begin(IsolationLevel level = IsolationLevel::REPEATABLE_READ); + + /** + * @brief Commit a transaction + */ void commit(Transaction* txn); + + /** + * @brief Abort a transaction + */ void abort(Transaction* txn); + /** + * @brief Get transaction by ID + */ Transaction* get_transaction(txn_id_t txn_id); + + private: + LockManager& lock_manager_; + Catalog& catalog_; + storage::BufferPoolManager& bpm_; + recovery::LogManager* log_manager_; + + std::atomic next_txn_id_{1}; + std::mutex manager_latch_; + + // All active transactions + std::unordered_map> active_transactions_; + + // Transactions that have recently finished (for cleanup/safety) + std::deque> completed_transactions_; + + /** + * @brief Undo changes made by a transaction + */ + void undo_transaction(Transaction* txn); }; } // namespace cloudsql::transaction diff --git a/src/catalog/catalog.cpp b/src/catalog/catalog.cpp index dca96d97..96cad052 100644 --- a/src/catalog/catalog.cpp +++ b/src/catalog/catalog.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -70,6 +71,10 @@ bool Catalog::save(const std::string& filename) const { * @brief Create a new table */ oid_t Catalog::create_table(const std::string& table_name, std::vector columns) { + if (table_exists_by_name(table_name)) { + throw std::runtime_error("Table already exists: " + table_name); + } + auto table = std::make_unique(); table->table_id = next_oid_++; table->name = table_name; @@ -139,6 +144,13 @@ oid_t Catalog::create_index(const std::string& index_name, oid_t table_id, return 0; } + auto& table = *table_opt.value(); + for (const auto& existing_idx : table.indexes) { + if (existing_idx.name == index_name) { + throw std::runtime_error("Index already exists: " + index_name); + } + } + IndexInfo index; index.index_id = next_oid_++; index.name = index_name; @@ -148,7 +160,7 @@ oid_t Catalog::create_index(const std::string& index_name, oid_t table_id, index.is_unique = is_unique; const oid_t id = index.index_id; - (*table_opt)->indexes.push_back(std::move(index)); + table.indexes.push_back(std::move(index)); return id; } diff --git a/src/network/server.cpp b/src/network/server.cpp index 60169827..ad7af3e6 100644 --- a/src/network/server.cpp +++ b/src/network/server.cpp @@ -18,48 +18,41 @@ #include #include -#include #include #include -#include +#include #include #include #include -#include +#include #include #include #include "catalog/catalog.hpp" #include "executor/query_executor.hpp" +#include "executor/types.hpp" #include "parser/lexer.hpp" #include "parser/parser.hpp" -#include "parser/statement.hpp" #include "storage/buffer_pool_manager.hpp" +#include "transaction/lock_manager.hpp" +#include "transaction/transaction_manager.hpp" namespace cloudsql::network { namespace { -constexpr int PROTOCOL_VERSION_3 = 196608; -constexpr int SSL_REQUEST_CODE = 80877103; -constexpr int AUTH_OK_MSG_SIZE = 8; -constexpr int READY_FOR_QUERY_MSG_SIZE = 5; -constexpr int TEXT_FORMAT_CODE = 0; -constexpr int TEXT_TYPE_OID = 25; -constexpr size_t MAX_PACKET_SIZE = 8192; -constexpr int SELECT_TIMEOUT_USEC = 100000; -constexpr int ERROR_MSG_LEN = 9; -constexpr size_t MIN_MSG_SIZE = 5; + constexpr size_t HEADER_SIZE = 4; +constexpr int SELECT_TIMEOUT_SEC = 1; +constexpr uint32_t PG_SSL_CODE = 80877103; +constexpr uint32_t PG_STARTUP_CODE = 196608; /** - * @brief Helper to receive exactly count bytes - * @return count on success, 0 on EOF, -1 on error + * @brief Simple utility to receive exactly N bytes */ ssize_t recv_all(int fd, char* buf, size_t count) { size_t total = 0; while (total < count) { - const ssize_t n = - recv(fd, std::next(buf, static_cast(total)), count - total, 0); + const ssize_t n = recv(fd, buf + total, static_cast(count - total), 0); if (n <= 0) { return n; } @@ -67,110 +60,66 @@ ssize_t recv_all(int fd, char* buf, size_t count) { } return static_cast(total); } -} // anonymous namespace /** - * @brief Helper for parsing PostgreSQL binary protocol + * @brief Reader for PostgreSQL protocol types */ class ProtocolReader { public: - [[nodiscard]] static uint32_t read_int32(const char* buffer) { + static uint32_t read_int32(const char* data) { uint32_t val = 0; - std::memcpy(&val, buffer, sizeof(uint32_t)); + std::memcpy(&val, data, 4); return ntohl(val); } - - [[nodiscard]] static std::string read_string(const char* buffer, size_t& offset, size_t limit) { - std::string s; - const std::string_view view(buffer, limit); - while (offset < limit) { - const char c = view[offset]; - if (c == '\0') { - break; - } - s += c; - offset++; - } - if (offset < limit) { - offset++; /* Skip null terminator */ - } - return s; - } }; /** - * @brief Helper for building PostgreSQL binary responses + * @brief Writer for PostgreSQL protocol types */ class ProtocolWriter { public: - static void append_int16(std::vector& buf, uint16_t val) { - const uint16_t nval = htons(val); - std::array p{}; - std::memcpy(p.data(), &nval, sizeof(uint16_t)); - buf.insert(buf.end(), p.begin(), p.end()); - } - - static void append_int32(std::vector& buf, uint32_t val) { + static void write_int32(char* data, uint32_t val) { const uint32_t nval = htonl(val); - std::array p{}; - std::memcpy(p.data(), &nval, sizeof(uint32_t)); - buf.insert(buf.end(), p.begin(), p.end()); + std::memcpy(data, &nval, 4); } - static void append_string(std::vector& buf, const std::string& s) { - buf.insert(buf.end(), s.begin(), s.end()); - buf.push_back('\0'); - } - - static void finish_message(std::vector& buf) { - if (buf.size() < MIN_MSG_SIZE) { - return; - } - const uint32_t len = htonl(static_cast(buf.size() - 1)); - std::memcpy(&buf[1], &len, sizeof(uint32_t)); + static void write_int16(char* data, uint16_t val) { + const uint16_t nval = htons(val); + std::memcpy(data, &nval, 2); } }; +} // namespace + Server::Server(uint16_t port, Catalog& catalog, storage::BufferPoolManager& bpm) : port_(port), catalog_(catalog), bpm_(bpm), - transaction_manager_(lock_manager_, catalog, bpm) {} + transaction_manager_(lock_manager_, catalog, bpm, bpm.get_log_manager()) {} std::unique_ptr Server::create(uint16_t port, Catalog& catalog, storage::BufferPoolManager& bpm) { return std::make_unique(port, catalog, bpm); } -/** - * @brief Start the server - */ bool Server::start() { - { - const std::scoped_lock lock(state_mutex_); - if (running_) { - return false; - } - } - const int fd = socket(AF_INET, SOCK_STREAM, 0); if (fd < 0) { return false; } - const int opt = 1; - static_cast(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt))); + int opt = 1; + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { + static_cast(close(fd)); + return false; + } struct sockaddr_in addr {}; addr.sin_family = AF_INET; addr.sin_addr.s_addr = INADDR_ANY; addr.sin_port = htons(port_); - struct sockaddr_storage storage {}; - std::memcpy(&storage, &addr, sizeof(addr)); - - if (bind(fd, reinterpret_cast(&storage), // NOLINT - sizeof(addr)) < 0) { + if (bind(fd, reinterpret_cast(&addr), sizeof(addr)) < 0) { static_cast(close(fd)); return false; } @@ -192,9 +141,6 @@ bool Server::start() { return true; } -/** - * @brief Stop the server - */ bool Server::stop() { int fd_to_close = -1; { @@ -202,46 +148,32 @@ bool Server::stop() { if (!running_) { return true; } - status_ = ServerStatus::Stopping; running_ = false; - if (listen_fd_ >= 0) { - fd_to_close = listen_fd_; - listen_fd_ = -1; - } + fd_to_close = listen_fd_; + listen_fd_ = -1; } - // 1. Signal all active client connections to shut down - std::vector fds; - { - const std::scoped_lock lock(thread_mutex_); - fds = client_fds_; - } - - for (const int fd : fds) { - static_cast(shutdown(fd, SHUT_RDWR)); - } - - // 2. Join the accept thread - std::thread a_thread; + std::thread t; { const std::scoped_lock lock(thread_mutex_); if (accept_thread_.joinable()) { - a_thread = std::move(accept_thread_); + t = std::move(accept_thread_); } } - if (a_thread.joinable()) { - a_thread.join(); + if (t.joinable()) { + t.join(); } - // 3. Join all connection worker threads std::vector workers; + std::vector fds; { const std::scoped_lock lock(thread_mutex_); + fds.swap(client_fds_); workers.swap(worker_threads_); } - for (auto& t : workers) { - if (t.joinable()) { - t.join(); + for (auto& worker : workers) { + if (worker.joinable()) { + worker.join(); } } @@ -258,6 +190,7 @@ bool Server::stop() { const std::scoped_lock lock(state_mutex_); status_ = ServerStatus::Stopped; } + return true; } @@ -290,161 +223,128 @@ int Server::get_listen_fd() const { } std::string Server::get_status_string() const { - switch (get_status()) { - case ServerStatus::Stopped: - return "Stopped"; + const std::scoped_lock lock(state_mutex_); + switch (status_) { case ServerStatus::Starting: return "Starting"; case ServerStatus::Running: return "Running"; case ServerStatus::Stopping: return "Stopping"; + case ServerStatus::Stopped: + return "Stopped"; case ServerStatus::Error: return "Error"; - default: - return "Unknown"; } + return "Unknown"; } void Server::accept_connections() { - while (is_running()) { - const int fd = get_listen_fd(); - if (fd < 0) { - break; + while (true) { + int fd = -1; + { + const std::scoped_lock lock(state_mutex_); + if (!running_) { + break; + } + fd = listen_fd_; } - /* Use select with timeout to allow periodic is_running() check */ fd_set read_fds; FD_ZERO(&read_fds); FD_SET(fd, &read_fds); struct timeval timeout { - 0, SELECT_TIMEOUT_USEC + SELECT_TIMEOUT_SEC, 0 }; const int res = select(fd + 1, &read_fds, nullptr, nullptr, &timeout); if (res <= 0) { - continue; /* Timeout or error */ + continue; } struct sockaddr_in client_addr {}; socklen_t client_len = sizeof(client_addr); - - struct sockaddr_storage storage {}; const int client_fd = - accept(fd, reinterpret_cast(&storage), &client_len); // NOLINT - if (client_fd < 0) { - continue; - } - std::memcpy(&client_addr, &storage, sizeof(client_addr)); + accept(fd, reinterpret_cast(&client_addr), &client_len); - static_cast(stats_.connections_accepted.fetch_add(1)); - static_cast(stats_.connections_active.fetch_add(1)); - - const std::scoped_lock lock(thread_mutex_); - client_fds_.push_back(client_fd); - worker_threads_.emplace_back([this, client_fd]() { - handle_connection(client_fd); - static_cast(stats_.connections_active.fetch_sub(1)); - - const std::scoped_lock thread_lock(thread_mutex_); - auto it = std::remove(client_fds_.begin(), client_fds_.end(), client_fd); - client_fds_.erase(it, client_fds_.end()); - }); + if (client_fd >= 0) { + const std::scoped_lock lock(thread_mutex_); + client_fds_.push_back(client_fd); + worker_threads_.emplace_back(&Server::handle_connection, this, client_fd); + } } } -/** - * @brief Handle a client connection using PostgreSQL protocol - */ void Server::handle_connection(int client_fd) { - std::array buffer{}; - executor::QueryExecutor client_executor(catalog_, bpm_, lock_manager_, transaction_manager_); + constexpr size_t PKT_BUF_SIZE = 8192; + std::array buffer{}; - /* 1. Read Length (Initial Startup/SSL) */ + // 1. Handshake ssize_t n = recv_all(client_fd, buffer.data(), HEADER_SIZE); if (n < static_cast(HEADER_SIZE)) { + static_cast(close(client_fd)); return; } uint32_t len = ProtocolReader::read_int32(buffer.data()); if (len > buffer.size() || len < HEADER_SIZE) { + static_cast(close(client_fd)); return; } - /* 2. Read Rest of Startup/SSL Packet */ - n = recv_all(client_fd, std::next(buffer.data(), static_cast(HEADER_SIZE)), - len - HEADER_SIZE); + n = recv_all(client_fd, buffer.data() + HEADER_SIZE, len - HEADER_SIZE); if (n < static_cast(len - HEADER_SIZE)) { + static_cast(close(client_fd)); return; } - uint32_t protocol = ProtocolReader::read_int32( - std::next(buffer.data(), static_cast(HEADER_SIZE))); - - /* Check for SSL Request */ - if (protocol == static_cast(SSL_REQUEST_CODE)) { - const char ssl_deny = 'N'; - static_cast(send(client_fd, &ssl_deny, 1, 0)); + uint32_t code = ProtocolReader::read_int32(buffer.data() + HEADER_SIZE); + if (code == PG_SSL_CODE) { // SSL Request + const char n_response = 'N'; + static_cast(send(client_fd, &n_response, 1, 0)); + // Expect startup packet next n = recv_all(client_fd, buffer.data(), HEADER_SIZE); if (n < static_cast(HEADER_SIZE)) { + static_cast(close(client_fd)); return; } len = ProtocolReader::read_int32(buffer.data()); - if (len < HEADER_SIZE || len > buffer.size()) { - return; - } - n = recv_all(client_fd, std::next(buffer.data(), static_cast(HEADER_SIZE)), - len - HEADER_SIZE); - if (n < static_cast(len - HEADER_SIZE)) { - return; - } - protocol = ProtocolReader::read_int32( - std::next(buffer.data(), static_cast(HEADER_SIZE))); + static_cast(recv_all(client_fd, buffer.data() + HEADER_SIZE, len - HEADER_SIZE)); + code = ProtocolReader::read_int32(buffer.data() + HEADER_SIZE); } - if (protocol != static_cast(PROTOCOL_VERSION_3)) { + if (code != PG_STARTUP_CODE) { + static_cast(close(client_fd)); return; } - /* Send AuthenticationOK ('R') */ - const std::vector auth_ok = {'R', 0, 0, 0, static_cast(AUTH_OK_MSG_SIZE), - 0, 0, 0, 0}; + // Auth OK + const std::array auth_ok = {'R', 0, 0, 0, 8, 0, 0, 0, 0}; static_cast(send(client_fd, auth_ok.data(), auth_ok.size(), 0)); - /* Send ReadyForQuery ('Z') */ - const std::vector ready = {'Z', 0, 0, 0, static_cast(READY_FOR_QUERY_MSG_SIZE), - 'I'}; + // Ready for Query + const std::array ready = {'Z', 0, 0, 0, 5, 'I'}; static_cast(send(client_fd, ready.data(), ready.size(), 0)); - /* 5. Main Message Loop */ - while (is_running()) { - char type = '\0'; - n = recv_all(client_fd, &type, 1); + // 2. Query Loop + while (true) { + char type = 0; + n = recv(client_fd, &type, 1, 0); if (n <= 0) { break; } - n = recv_all(client_fd, buffer.data(), HEADER_SIZE); - if (n < static_cast(HEADER_SIZE)) { + n = recv_all(client_fd, buffer.data(), 4); + if (n < 4) { break; } len = ProtocolReader::read_int32(buffer.data()); - if (len < HEADER_SIZE) { - break; - } - - std::vector body(len - HEADER_SIZE); - if (len > HEADER_SIZE) { - n = recv_all(client_fd, body.data(), len - HEADER_SIZE); - if (n < static_cast(len - HEADER_SIZE)) { - break; - } - } - if (type == 'Q') { /* Simple Query */ - const std::string sql(body.data()); - static_cast(stats_.queries_executed.fetch_add(1)); + if (type == 'Q') { + std::vector sql_buf(len - 4); + static_cast(recv_all(client_fd, sql_buf.data(), len - 4)); + const std::string sql(sql_buf.data()); try { auto lexer = std::make_unique(sql); @@ -452,72 +352,101 @@ void Server::handle_connection(int client_fd) { auto stmt = parser.parse_statement(); if (stmt) { - auto result = client_executor.execute(*stmt); - - if (result.success()) { - /* 1. Send RowDescription ('T') for SELECT */ - if (stmt->type() == parser::StmtType::Select) { - std::vector desc = {'T'}; - ProtocolWriter::append_int32(desc, 0); // Length placeholder - ProtocolWriter::append_int16( - desc, static_cast(result.schema().column_count())); - - for (const auto& col : result.schema().columns()) { - ProtocolWriter::append_string(desc, col.name()); - ProtocolWriter::append_int32(desc, 0); // Table OID - ProtocolWriter::append_int16(desc, 0); // Attr index - ProtocolWriter::append_int32(desc, - static_cast(TEXT_TYPE_OID)); - ProtocolWriter::append_int16( - desc, static_cast(-1)); // Type size - ProtocolWriter::append_int32( - desc, static_cast(-1)); // Type modifier - ProtocolWriter::append_int16( - desc, static_cast(TEXT_FORMAT_CODE)); + executor::QueryExecutor exec(catalog_, bpm_, lock_manager_, + transaction_manager_); + const auto res = exec.execute(*stmt); + + if (res.success()) { + // Row Description (T) + if (!res.rows().empty() && res.schema().column_count() > 0) { + const auto& schema = res.schema(); + const auto num_cols = static_cast(schema.column_count()); + + // Calculate T packet length + uint32_t t_len = 4 + 2; // len + num_cols + for (uint32_t i = 0; i < num_cols; ++i) { + t_len += static_cast(schema.get_column(i).name().size()) + + 1 + 4 + 2 + 4 + 2 + 4 + 2; } - ProtocolWriter::finish_message(desc); - static_cast(send(client_fd, desc.data(), desc.size(), 0)); - - /* 2. Send DataRows ('D') */ - for (const auto& row : result.rows()) { - std::vector data = {'D'}; - ProtocolWriter::append_int32(data, 0); // Length - ProtocolWriter::append_int16(data, - static_cast(row.size())); - - for (const auto& val : row.values()) { - const std::string s = val.to_string(); - ProtocolWriter::append_int32(data, - static_cast(s.size())); - data.insert(data.end(), s.begin(), s.end()); + + const char t_type = 'T'; + const uint32_t net_t_len = htonl(t_len); + const uint16_t net_num_cols = htons(static_cast(num_cols)); + + static_cast(send(client_fd, &t_type, 1, 0)); + static_cast(send(client_fd, &net_t_len, 4, 0)); + static_cast(send(client_fd, &net_num_cols, 2, 0)); + + for (uint32_t i = 0; i < num_cols; ++i) { + const auto& col = schema.get_column(i); + static_cast( + send(client_fd, col.name().c_str(), col.name().size() + 1, 0)); + const uint32_t table_oid = 0; + const uint16_t col_attr = 0; + const uint32_t type_oid = htonl(23); // 23 is int4, simplified + const uint16_t type_len = htons(4); + const uint32_t type_mod = htonl(0xFFFFFFFF); + const uint16_t format = 0; // Text format + + static_cast(send(client_fd, &table_oid, 4, 0)); + static_cast(send(client_fd, &col_attr, 2, 0)); + static_cast(send(client_fd, &type_oid, 4, 0)); + static_cast(send(client_fd, &type_len, 2, 0)); + static_cast(send(client_fd, &type_mod, 4, 0)); + static_cast(send(client_fd, &format, 2, 0)); + } + + // Data Rows (D) + for (const auto& row : res.rows()) { + const char d_type = 'D'; + uint32_t d_len = 4 + 2; // len + num_cols + std::vector str_vals; + for (uint32_t i = 0; i < num_cols; ++i) { + const std::string s_val = row.get(i).to_string(); + str_vals.push_back(s_val); + d_len += + 4 + static_cast(s_val.size()); // len + value + } + + const uint32_t net_d_len = htonl(d_len); + static_cast(send(client_fd, &d_type, 1, 0)); + static_cast(send(client_fd, &net_d_len, 4, 0)); + static_cast(send(client_fd, &net_num_cols, 2, 0)); + + for (const auto& s_val : str_vals) { + const uint32_t val_len = + htonl(static_cast(s_val.size())); + static_cast(send(client_fd, &val_len, 4, 0)); + static_cast( + send(client_fd, s_val.c_str(), s_val.size(), 0)); } - ProtocolWriter::finish_message(data); - static_cast(send(client_fd, data.data(), data.size(), 0)); } } - /* 3. Send CommandComplete ('C') */ - std::vector complete = {'C'}; - const std::string msg = (stmt->type() == parser::StmtType::Select) - ? "SELECT " + std::to_string(result.row_count()) - : "OK"; - ProtocolWriter::append_int32(complete, - static_cast(4 + msg.size() + 1)); - ProtocolWriter::append_string(complete, msg); - static_cast(send(client_fd, complete.data(), complete.size(), 0)); + // Command Complete (C) + const std::string tag = "SELECT " + std::to_string(res.row_count()); + const uint32_t tag_len = htonl(static_cast(tag.size() + 4 + 1)); + const char c_type = 'C'; + static_cast(send(client_fd, &c_type, 1, 0)); + static_cast(send(client_fd, &tag_len, 4, 0)); + static_cast(send(client_fd, tag.c_str(), tag.size() + 1, 0)); } else { - /* Send ErrorResponse ('E') or CommandComplete with error */ - std::vector complete = {'C'}; - ProtocolWriter::append_int32(complete, ERROR_MSG_LEN); - ProtocolWriter::append_string(complete, "ERROR"); - static_cast(send(client_fd, complete.data(), complete.size(), 0)); + // Error Response (E) + const std::string& err = res.error(); + const uint32_t e_len = htonl(static_cast(err.size() + 4 + 1)); + const char e_type = 'E'; + static_cast(send(client_fd, &e_type, 1, 0)); + static_cast(send(client_fd, &e_len, 4, 0)); + static_cast(send(client_fd, err.c_str(), err.size() + 1, 0)); } } - } catch (...) { /* Handle parsing/exec errors */ - std::vector complete = {'C'}; - ProtocolWriter::append_int32(complete, ERROR_MSG_LEN); - ProtocolWriter::append_string(complete, "ERROR"); - static_cast(send(client_fd, complete.data(), complete.size(), 0)); + } catch (const std::exception& e) { + const std::string err = e.what(); + const uint32_t e_len = htonl(static_cast(err.size() + 4 + 1)); + const char e_type = 'E'; + static_cast(send(client_fd, &e_type, 1, 0)); + static_cast(send(client_fd, &e_len, 4, 0)); + static_cast(send(client_fd, err.c_str(), err.size() + 1, 0)); } } else if (type == 'X') { break; @@ -526,6 +455,15 @@ void Server::handle_connection(int client_fd) { /* Ready for Query */ static_cast(send(client_fd, ready.data(), ready.size(), 0)); } + + { + const std::scoped_lock lock(thread_mutex_); + auto it = std::find(client_fds_.begin(), client_fds_.end(), client_fd); + if (it != client_fds_.end()) { + static_cast(client_fds_.erase(it)); + } + } + static_cast(close(client_fd)); } } // namespace cloudsql::network diff --git a/src/storage/buffer_pool_manager.cpp b/src/storage/buffer_pool_manager.cpp index 13195588..cd56cc9c 100644 --- a/src/storage/buffer_pool_manager.cpp +++ b/src/storage/buffer_pool_manager.cpp @@ -1,16 +1,16 @@ /** * @file buffer_pool_manager.cpp - * @brief Implementation of the Buffer Pool Manager + * @brief Buffer pool manager implementation */ #include "storage/buffer_pool_manager.hpp" -#include #include +#include +#include #include #include -#include "recovery/log_manager.hpp" #include "storage/page.hpp" #include "storage/storage_manager.hpp" @@ -21,22 +21,19 @@ BufferPoolManager::BufferPoolManager(size_t pool_size, StorageManager& storage_m : pool_size_(pool_size), storage_manager_(storage_manager), log_manager_(log_manager), - pages_(pool_size_), - replacer_(pool_size_) { - for (size_t i = 0; i < pool_size_; i++) { + pages_(std::make_unique(pool_size)), + replacer_(pool_size) { + for (size_t i = 0; i < pool_size_; ++i) { free_list_.push_back(static_cast(i)); } } -BufferPoolManager::~BufferPoolManager() { - flush_all_pages(); -} +BufferPoolManager::~BufferPoolManager() = default; Page* BufferPoolManager::fetch_page(const std::string& file_name, uint32_t page_id) { - const std::lock_guard lock(latch_); - const std::string key = make_page_key(file_name, page_id); + const std::scoped_lock lock(latch_); - // 1. If page is already in the buffer pool + const std::string key = make_page_key(file_name, page_id); if (page_table_.find(key) != page_table_.end()) { const uint32_t frame_id = page_table_[key]; Page* const page = &pages_[frame_id]; @@ -45,61 +42,36 @@ Page* BufferPoolManager::fetch_page(const std::string& file_name, uint32_t page_ return page; } - // 2. Page is not in the pool. Find a victim or free frame. uint32_t frame_id = 0; if (!free_list_.empty()) { - frame_id = free_list_.front(); - free_list_.pop_front(); - } else { - if (!replacer_.victim(&frame_id)) { - // Buffer pool is full and everything is pinned - return nullptr; - } - // Write back dirty page - Page* const victim_page = &pages_[frame_id]; - if (victim_page->is_dirty_) { - // Check WAL requirements before flushing - if (log_manager_ != nullptr && victim_page->lsn_ != -1) { - if (victim_page->lsn_ > log_manager_->get_persistent_lsn()) { - log_manager_->flush(true); - } - } - storage_manager_.write_page(victim_page->file_name_, victim_page->page_id_, - victim_page->get_data()); - } - - // Remove from page table - page_table_.erase(make_page_key(victim_page->file_name_, victim_page->page_id_)); - } - - // 3. Read the page from disk - Page* const new_page_ptr = &pages_[frame_id]; - new_page_ptr->reset_memory(); - - // storage_manager_.read_page populates the buffer or zero-fills if it doesn't exist - if (!storage_manager_.read_page(file_name, page_id, new_page_ptr->get_data())) { - // If it really failed (e.g., IO error), we should return the frame - free_list_.push_back(frame_id); + frame_id = free_list_.back(); + free_list_.pop_back(); + } else if (!replacer_.victim(&frame_id)) { return nullptr; } - // 4. Update metadata - new_page_ptr->page_id_ = page_id; - new_page_ptr->file_name_ = file_name; - new_page_ptr->pin_count_ = 1; - new_page_ptr->is_dirty_ = false; - new_page_ptr->lsn_ = -1; + Page* const page = &pages_[frame_id]; + if (page->is_dirty_) { + storage_manager_.write_page(page->file_name_, page->page_id_, page->get_data()); + } + page_table_.erase(make_page_key(page->file_name_, page->page_id_)); page_table_[key] = frame_id; - replacer_.pin(frame_id); - return new_page_ptr; + page->page_id_ = page_id; + page->file_name_ = file_name; + page->pin_count_ = 1; + page->is_dirty_ = false; + storage_manager_.read_page(file_name, page_id, page->get_data()); + + replacer_.pin(frame_id); + return page; } bool BufferPoolManager::unpin_page(const std::string& file_name, uint32_t page_id, bool is_dirty) { - const std::lock_guard lock(latch_); - const std::string key = make_page_key(file_name, page_id); + const std::scoped_lock lock(latch_); + const std::string key = make_page_key(file_name, page_id); if (page_table_.find(key) == page_table_.end()) { return false; } @@ -124,122 +96,87 @@ bool BufferPoolManager::unpin_page(const std::string& file_name, uint32_t page_i } bool BufferPoolManager::flush_page(const std::string& file_name, uint32_t page_id) { - const std::lock_guard lock(latch_); - const std::string key = make_page_key(file_name, page_id); + const std::scoped_lock lock(latch_); + const std::string key = make_page_key(file_name, page_id); if (page_table_.find(key) == page_table_.end()) { return false; } const uint32_t frame_id = page_table_[key]; Page* const page = &pages_[frame_id]; - - // Check WAL requirements before flushing - if (log_manager_ != nullptr && page->lsn_ != -1) { - if (page->lsn_ > log_manager_->get_persistent_lsn()) { - log_manager_->flush(true); - } - } - - storage_manager_.write_page(page->file_name_, page->page_id_, page->get_data()); + storage_manager_.write_page(file_name, page_id, page->get_data()); page->is_dirty_ = false; + return true; } -Page* BufferPoolManager::new_page(const std::string& file_name, const uint32_t* const page_id) { - const std::lock_guard lock(latch_); - - // We need to determine the new page ID. In our basic layout, we'll - // assume the caller knows the ID, but wait, the signature expects us to - // assign it. For simplicity in cloudSQL, typically the table asks for a specific - // page by ID, or it just reads page N until it fails. - // If the caller calls new_page, we assume page_id was pre-filled with the desired ID - // or we have a way to know the next ID. - // Wait, the interface says `Output param for the id of the created page`. Currently - // let's just use the passed in page_id as the requested page ID to create. - const uint32_t target_page_id = *page_id; - const std::string key = make_page_key(file_name, target_page_id); +Page* BufferPoolManager::new_page(const std::string& file_name, const uint32_t* page_id) { + const std::scoped_lock lock(latch_); - // If already exists, return - if (page_table_.find(key) != page_table_.end()) { - return nullptr; + const uint32_t target_page_id = storage_manager_.allocate_page(file_name); + if (page_id != nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + const_cast(*page_id) = target_page_id; } + const std::string key = make_page_key(file_name, target_page_id); uint32_t frame_id = 0; if (!free_list_.empty()) { - frame_id = free_list_.front(); - free_list_.pop_front(); - } else { - if (!replacer_.victim(&frame_id)) { - return nullptr; - } - Page* const victim_page = &pages_[frame_id]; - if (victim_page->is_dirty_) { - if (log_manager_ != nullptr && victim_page->lsn_ != -1) { - if (victim_page->lsn_ > log_manager_->get_persistent_lsn()) { - log_manager_->flush(true); - } - } - storage_manager_.write_page(victim_page->file_name_, victim_page->page_id_, - victim_page->get_data()); - } - page_table_.erase(make_page_key(victim_page->file_name_, victim_page->page_id_)); + frame_id = free_list_.back(); + free_list_.pop_back(); + } else if (!replacer_.victim(&frame_id)) { + return nullptr; } - Page* const new_page_ptr = &pages_[frame_id]; - new_page_ptr->reset_memory(); + Page* const page = &pages_[frame_id]; + if (page->is_dirty_) { + storage_manager_.write_page(page->file_name_, page->page_id_, page->get_data()); + } - // Explicitly write a blank page to storage to instantiate it - storage_manager_.write_page(file_name, target_page_id, new_page_ptr->get_data()); + page_table_.erase(make_page_key(page->file_name_, page->page_id_)); + page_table_[key] = frame_id; - new_page_ptr->page_id_ = target_page_id; - new_page_ptr->file_name_ = file_name; - new_page_ptr->pin_count_ = 1; - new_page_ptr->is_dirty_ = false; - new_page_ptr->lsn_ = -1; + page->page_id_ = target_page_id; + page->file_name_ = file_name; + page->pin_count_ = 1; + page->is_dirty_ = false; + std::memset(page->get_data(), 0, Page::PAGE_SIZE); - page_table_[key] = frame_id; replacer_.pin(frame_id); - - return new_page_ptr; + return page; } bool BufferPoolManager::delete_page(const std::string& file_name, uint32_t page_id) { - const std::lock_guard lock(latch_); - const std::string key = make_page_key(file_name, page_id); - - if (page_table_.find(key) == page_table_.end()) { - return true; - } + const std::scoped_lock lock(latch_); - const uint32_t frame_id = page_table_[key]; - Page* const page = &pages_[frame_id]; + const std::string key = make_page_key(file_name, page_id); + if (page_table_.find(key) != page_table_.end()) { + const uint32_t frame_id = page_table_[key]; + Page* const page = &pages_[frame_id]; + if (page->pin_count_ > 0) { + return false; + } - if (page->pin_count_ > 0) { - return false; + page_table_.erase(key); + replacer_.pin(frame_id); + page->page_id_ = 0; + page->file_name_ = ""; + page->pin_count_ = 0; + page->is_dirty_ = false; + free_list_.push_back(frame_id); } - page_table_.erase(key); - free_list_.push_back(frame_id); - page->reset_memory(); - page->page_id_ = 0; - page->file_name_.clear(); - page->is_dirty_ = false; - page->lsn_ = -1; - + StorageManager::deallocate_page(file_name, page_id); return true; } void BufferPoolManager::flush_all_pages() { - const std::lock_guard lock(latch_); - for (size_t i = 0; i < pool_size_; i++) { - Page* const page = &pages_[i]; - if (!page->file_name_.empty() && page->is_dirty_) { - if (log_manager_ != nullptr && page->lsn_ != -1) { - if (page->lsn_ > log_manager_->get_persistent_lsn()) { - log_manager_->flush(true); - } - } + const std::scoped_lock lock(latch_); + + for (auto const& [key, frame_id] : page_table_) { + Page* const page = &pages_[frame_id]; + if (page->is_dirty_) { storage_manager_.write_page(page->file_name_, page->page_id_, page->get_data()); page->is_dirty_ = false; } diff --git a/src/storage/lru_replacer.cpp b/src/storage/lru_replacer.cpp index ffd926c2..c10d8de5 100644 --- a/src/storage/lru_replacer.cpp +++ b/src/storage/lru_replacer.cpp @@ -1,6 +1,6 @@ /** * @file lru_replacer.cpp - * @brief Implementation of the LRU Replacer + * @brief Least Recently Used (LRU) tracking implementation */ #include "storage/lru_replacer.hpp" @@ -13,52 +13,46 @@ namespace cloudsql::storage { LRUReplacer::LRUReplacer(size_t num_pages) : capacity_(num_pages) {} -bool LRUReplacer::victim(uint32_t* const frame_id) { - const std::lock_guard lock(latch_); +bool LRUReplacer::victim(uint32_t* frame_id) { + const std::scoped_lock lock(latch_); + if (lru_list_.empty()) { return false; } - // The back of the list is the least recently used *frame_id = lru_list_.back(); - static_cast(lru_map_.erase(*frame_id)); lru_list_.pop_back(); - + static_cast(lru_map_.erase(*frame_id)); return true; } void LRUReplacer::pin(uint32_t frame_id) { - const std::lock_guard lock(latch_); + const std::scoped_lock lock(latch_); + const auto it = lru_map_.find(frame_id); if (it != lru_map_.end()) { - // Remove it from the tracker because it's currently pinned/in-use static_cast(lru_list_.erase(it->second)); static_cast(lru_map_.erase(it)); } } void LRUReplacer::unpin(uint32_t frame_id) { - const std::lock_guard lock(latch_); - if (lru_map_.find(frame_id) != lru_map_.end()) { - // Already in the replacer's candidate list + const std::scoped_lock lock(latch_); + + if (lru_map_.count(frame_id) != 0) { + return; + } + + if (lru_list_.size() >= capacity_) { return; } - // Add to the front of the list (most recently used) lru_list_.push_front(frame_id); lru_map_[frame_id] = lru_list_.begin(); - - // Enforce capacity constraint (shouldn't happen strictly if used with BufferPool properly, but - // safe) - if (lru_list_.size() > capacity_) { - const uint32_t lru_frame = lru_list_.back(); - static_cast(lru_map_.erase(lru_frame)); - lru_list_.pop_back(); - } } size_t LRUReplacer::size() const { - const std::lock_guard lock(latch_); + const std::scoped_lock lock(latch_); return lru_list_.size(); } diff --git a/src/storage/storage_manager.cpp b/src/storage/storage_manager.cpp index d2b18df0..b14dc904 100644 --- a/src/storage/storage_manager.cpp +++ b/src/storage/storage_manager.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -160,6 +161,32 @@ bool StorageManager::write_page(const std::string& filename, uint32_t page_num, return true; } +/** + * @brief Allocate a new page in the database file + */ +uint32_t StorageManager::allocate_page(const std::string& filename) { + if (open_files_.find(filename) == open_files_.end()) { + if (!open_file(filename)) { + return 0; + } + } + + auto& file = open_files_[filename]; + file->clear(); + file->seekg(0, std::ios::end); + const std::streamoff size = file->tellg(); + + return static_cast(static_cast(size) / PAGE_SIZE); +} + +/** + * @brief Deallocate a page + */ +void StorageManager::deallocate_page(const std::string& filename, uint32_t page_num) { + (void)filename; + (void)page_num; +} + /** * @brief Create data directory if it doesn't exist */ diff --git a/src/transaction/transaction_manager.cpp b/src/transaction/transaction_manager.cpp index 0f00582b..1c226ccd 100644 --- a/src/transaction/transaction_manager.cpp +++ b/src/transaction/transaction_manager.cpp @@ -1,38 +1,37 @@ /** * @file transaction_manager.cpp - * @brief Transaction Manager implementation + * @brief Transaction manager implementation */ #include "transaction/transaction_manager.hpp" #include +#include #include #include #include +#include "catalog/catalog.hpp" #include "executor/types.hpp" +#include "recovery/log_manager.hpp" #include "recovery/log_record.hpp" +#include "storage/buffer_pool_manager.hpp" #include "storage/heap_table.hpp" +#include "transaction/lock_manager.hpp" #include "transaction/transaction.hpp" namespace cloudsql::transaction { +TransactionManager::TransactionManager(LockManager& lock_manager, Catalog& catalog, + storage::BufferPoolManager& bpm, + recovery::LogManager* log_manager) + : lock_manager_(lock_manager), catalog_(catalog), bpm_(bpm), log_manager_(log_manager) {} + Transaction* TransactionManager::begin(IsolationLevel level) { const std::scoped_lock lock(manager_latch_); - - /* Clean up old completed transactions to avoid memory leak */ - completed_transactions_.clear(); - const txn_id_t txn_id = next_txn_id_++; auto txn = std::make_unique(txn_id, level); - /* Log BEGIN */ - if (log_manager_ != nullptr) { - recovery::LogRecord log(txn_id, -1, recovery::LogRecordType::BEGIN); - const auto lsn = log_manager_->append_log_record(log); - txn->set_prev_lsn(lsn); - } - /* Capture Snapshot */ TransactionSnapshot snapshot; snapshot.xmax = next_txn_id_.load(); @@ -48,68 +47,79 @@ Transaction* TransactionManager::begin(IsolationLevel level) { Transaction* const txn_ptr = txn.get(); active_transactions_[txn_id] = std::move(txn); + if (log_manager_ != nullptr) { + recovery::LogRecord record(txn_id, txn_ptr->get_prev_lsn(), recovery::LogRecordType::BEGIN); + const recovery::lsn_t lsn = log_manager_->append_log_record(record); + txn_ptr->set_prev_lsn(lsn); + } + return txn_ptr; } void TransactionManager::commit(Transaction* txn) { - if (txn == nullptr) { - return; - } - if (log_manager_ != nullptr) { - recovery::LogRecord log(txn->get_id(), txn->get_prev_lsn(), - recovery::LogRecordType::COMMIT); - static_cast(log_manager_->append_log_record(log)); + recovery::LogRecord record(txn->get_id(), txn->get_prev_lsn(), + recovery::LogRecordType::COMMIT); + const recovery::lsn_t lsn = log_manager_->append_log_record(record); + txn->set_prev_lsn(lsn); log_manager_->flush(true); } - txn->set_state(TransactionState::COMMITTED); - - /* Release all locks */ - for (const auto& rid : txn->get_shared_locks()) { + const auto lock_set = txn->get_shared_lock_set(); + for (const auto& rid : lock_set) { lock_manager_.unlock(txn, rid); } - for (const auto& rid : txn->get_exclusive_locks()) { + const auto ex_lock_set = txn->get_exclusive_lock_set(); + for (const auto& rid : ex_lock_set) { lock_manager_.unlock(txn, rid); } - const std::scoped_lock lock(manager_latch_); - auto it = active_transactions_.find(txn->get_id()); - if (it != active_transactions_.end()) { - completed_transactions_[txn->get_id()] = std::move(it->second); - static_cast(active_transactions_.erase(it)); + txn->set_state(TransactionState::COMMITTED); + + { + const std::scoped_lock lock(manager_latch_); + completed_transactions_.push_back(std::move(active_transactions_[txn->get_id()])); + static_cast(active_transactions_.erase(txn->get_id())); + + constexpr std::size_t MAX_COMPLETED = 100; + if (completed_transactions_.size() > MAX_COMPLETED) { + completed_transactions_.pop_front(); + } } } void TransactionManager::abort(Transaction* txn) { - if (txn == nullptr) { - return; - } - /* Undo all changes */ undo_transaction(txn); if (log_manager_ != nullptr) { - recovery::LogRecord log(txn->get_id(), txn->get_prev_lsn(), recovery::LogRecordType::ABORT); - static_cast(log_manager_->append_log_record(log)); + recovery::LogRecord record(txn->get_id(), txn->get_prev_lsn(), + recovery::LogRecordType::ABORT); + const recovery::lsn_t lsn = log_manager_->append_log_record(record); + txn->set_prev_lsn(lsn); log_manager_->flush(true); } - txn->set_state(TransactionState::ABORTED); - - /* Release all locks */ - for (const auto& rid : txn->get_shared_locks()) { + const auto lock_set = txn->get_shared_lock_set(); + for (const auto& rid : lock_set) { lock_manager_.unlock(txn, rid); } - for (const auto& rid : txn->get_exclusive_locks()) { + const auto ex_lock_set = txn->get_exclusive_lock_set(); + for (const auto& rid : ex_lock_set) { lock_manager_.unlock(txn, rid); } - const std::scoped_lock lock(manager_latch_); - auto it = active_transactions_.find(txn->get_id()); - if (it != active_transactions_.end()) { - completed_transactions_[txn->get_id()] = std::move(it->second); - static_cast(active_transactions_.erase(it)); + txn->set_state(TransactionState::ABORTED); + + { + const std::scoped_lock lock(manager_latch_); + completed_transactions_.push_back(std::move(active_transactions_[txn->get_id()])); + static_cast(active_transactions_.erase(txn->get_id())); + + constexpr std::size_t MAX_COMPLETED = 100; + if (completed_transactions_.size() > MAX_COMPLETED) { + completed_transactions_.pop_front(); + } } } @@ -149,13 +159,8 @@ void TransactionManager::undo_transaction(Transaction* txn) { Transaction* TransactionManager::get_transaction(txn_id_t txn_id) { const std::scoped_lock lock(manager_latch_); - auto it = active_transactions_.find(txn_id); - if (it != active_transactions_.end()) { - return it->second.get(); - } - auto it_comp = completed_transactions_.find(txn_id); - if (it_comp != completed_transactions_.end()) { - return it_comp->second.get(); + if (active_transactions_.find(txn_id) != active_transactions_.end()) { + return active_transactions_[txn_id].get(); } return nullptr; } diff --git a/tests/buffer_pool_tests.cpp b/tests/buffer_pool_tests.cpp index 9582383e..3462e61a 100644 --- a/tests/buffer_pool_tests.cpp +++ b/tests/buffer_pool_tests.cpp @@ -8,55 +8,156 @@ #include #include #include +#include #include +#include #include "storage/buffer_pool_manager.hpp" +#include "storage/lru_replacer.hpp" #include "storage/page.hpp" #include "storage/storage_manager.hpp" +#include "test_utils.hpp" using namespace cloudsql::storage; namespace { -constexpr size_t HELLO_LEN = 6; +TEST(BufferPoolTests, LRUReplacerBasic) { + LRUReplacer replacer(3); + uint32_t victim_frame = 0; -TEST(BufferPoolTests, Basic) { - const std::string filename = "test.db"; - static_cast(std::remove(filename.c_str())); + replacer.unpin(1); + replacer.unpin(2); + replacer.unpin(3); + EXPECT_EQ(replacer.size(), 3U); - StorageManager disk_manager("."); + EXPECT_TRUE(replacer.victim(&victim_frame)); + EXPECT_EQ(victim_frame, 1U); + EXPECT_EQ(replacer.size(), 2U); + + replacer.unpin(4); + EXPECT_EQ(replacer.size(), 3U); + + EXPECT_TRUE(replacer.victim(&victim_frame)); + EXPECT_EQ(victim_frame, 2U); + EXPECT_EQ(replacer.size(), 2U); + + replacer.pin(3); + EXPECT_EQ(replacer.size(), 1U); + + replacer.unpin(3); + EXPECT_EQ(replacer.size(), 2U); + + EXPECT_TRUE(replacer.victim(&victim_frame)); + EXPECT_EQ(victim_frame, 4U); + + EXPECT_TRUE(replacer.victim(&victim_frame)); + EXPECT_EQ(victim_frame, 3U); + EXPECT_EQ(replacer.size(), 0U); + + EXPECT_FALSE(replacer.victim(&victim_frame)); +} + +TEST(BufferPoolTests, BufferPoolManagerBasic) { + static_cast(std::remove("./test_data/bpm_test.db")); + StorageManager disk_manager("./test_data"); + BufferPoolManager bpm(2, disk_manager); + + const std::string file_name = "bpm_test.db"; + uint32_t page_id0 = 0; + Page* const page0 = bpm.new_page(file_name, &page_id0); + ASSERT_NE(page0, nullptr); + EXPECT_EQ(page_id0, 0U); + + EXPECT_TRUE(bpm.unpin_page(file_name, page_id0, true)); + + Page* const page0_fetch = bpm.fetch_page(file_name, page_id0); + ASSERT_NE(page0_fetch, nullptr); + EXPECT_TRUE(page0_fetch->is_dirty()); + EXPECT_TRUE(bpm.unpin_page(file_name, page_id0, false)); + + uint32_t page_id1 = 1; + Page* const page1 = bpm.new_page(file_name, &page_id1); + EXPECT_NE(page1, nullptr); + + uint32_t page_id2 = 2; + Page* const page2 = bpm.new_page(file_name, &page_id2); + EXPECT_NE(page2, nullptr); + + uint32_t page_id3 = 3; + Page* const page3 = bpm.new_page(file_name, &page_id3); + EXPECT_EQ(page3, nullptr); + + bpm.unpin_page(file_name, page_id1, false); + bpm.unpin_page(file_name, page_id2, true); + + Page* const page3_new = bpm.new_page(file_name, &page_id3); + EXPECT_NE(page3_new, nullptr); + + EXPECT_TRUE(bpm.flush_page(file_name, page_id3)); + bpm.flush_all_pages(); + bpm.unpin_page(file_name, page_id3, false); + + EXPECT_TRUE(bpm.delete_page(file_name, page_id2)); +} + +TEST(BufferPoolTests, BufferPoolManagerEviction) { + static_cast(std::remove("./test_data/bpm_eviction.db")); + StorageManager disk_manager("./test_data"); BufferPoolManager bpm(3, disk_manager); + const std::string file = "bpm_eviction.db"; - EXPECT_TRUE(bpm.open_file(filename)); + uint32_t id1 = 1; + uint32_t id2 = 2; + uint32_t id3 = 3; + uint32_t id4 = 4; + Page* const p1 = bpm.new_page(file, &id1); + Page* const p2 = bpm.new_page(file, &id2); + Page* const p3 = bpm.new_page(file, &id3); - const uint32_t page_id1 = 0; - Page* const page1 = bpm.new_page(filename, &page_id1); - ASSERT_NE(page1, nullptr); - EXPECT_EQ(page_id1, 0); + EXPECT_NE(p1, nullptr); + EXPECT_NE(p2, nullptr); + EXPECT_NE(p3, nullptr); + + bpm.unpin_page(file, id1, true); + bpm.unpin_page(file, id2, false); + bpm.unpin_page(file, id3, false); + + Page* const p4 = bpm.new_page(file, &id4); + EXPECT_NE(p4, nullptr); + + bpm.unpin_page(file, id4, false); + + Page* const p1_fetch = bpm.fetch_page(file, id1); + EXPECT_NE(p1_fetch, nullptr); + bpm.unpin_page(file, id1, false); + + EXPECT_TRUE(bpm.delete_page(file, id1)); + EXPECT_TRUE(bpm.delete_page(file, id2)); + EXPECT_TRUE(bpm.delete_page(file, id3)); + EXPECT_TRUE(bpm.delete_page(file, id4)); +} - std::memcpy(page1->get_data(), "Hello", HELLO_LEN); - bpm.unpin_page(filename, page_id1, true); +TEST(BufferPoolTests, BufferPoolManagerEdgeCases) { + static_cast(std::remove("./test_data/bpm_edge.db")); + StorageManager disk_manager("./test_data"); + BufferPoolManager bpm(1, disk_manager); + const std::string file = "bpm_edge.db"; - const uint32_t page_id2 = 1; - const Page* const page2 = bpm.new_page(filename, &page_id2); - ASSERT_NE(page2, nullptr); - EXPECT_EQ(page_id2, 1); - bpm.unpin_page(filename, page_id2, false); + EXPECT_FALSE(bpm.unpin_page(file, 999, false)); + EXPECT_FALSE(bpm.flush_page(file, 999)); + EXPECT_TRUE(bpm.delete_page(file, 999)); - const uint32_t page_id3 = 2; - const Page* const page3 = bpm.new_page(filename, &page_id3); - ASSERT_NE(page3, nullptr); - EXPECT_EQ(page_id3, 2); - bpm.unpin_page(filename, page_id3, false); + uint32_t id = 1; + Page* const p = bpm.new_page(file, &id); + ASSERT_NE(p, nullptr); + EXPECT_FALSE(bpm.delete_page(file, id)); // Pinned - // Fetch page 1 again - Page* const page1_fetch = bpm.fetch_page(filename, page_id1); - ASSERT_NE(page1_fetch, nullptr); - EXPECT_STREQ(page1_fetch->get_data(), "Hello"); - bpm.unpin_page(filename, page_id1, false); + // new page again with same ID + Page* const p_dup = bpm.new_page(file, &id); + EXPECT_EQ(p_dup, nullptr); - static_cast(bpm.close_file(filename)); - static_cast(std::remove(filename.c_str())); + bpm.unpin_page(file, id, false); } } // namespace diff --git a/tests/cloudSQL_tests.cpp b/tests/cloudSQL_tests.cpp index 0c01ad84..56378240 100644 --- a/tests/cloudSQL_tests.cpp +++ b/tests/cloudSQL_tests.cpp @@ -1,12 +1,13 @@ /** * @file cloudSQL_tests.cpp - * @brief Comprehensive test suite for cloudSQL C++ implementation + * @brief Comprehensive test suite for cloudSQL implementation */ #include #include #include +#include #include #include #include @@ -17,9 +18,12 @@ #include "common/value.hpp" #include "executor/query_executor.hpp" #include "executor/types.hpp" +#include "parser/expression.hpp" #include "parser/lexer.hpp" #include "parser/parser.hpp" #include "parser/statement.hpp" +#include "parser/token.hpp" +#include "storage/btree_index.hpp" #include "storage/buffer_pool_manager.hpp" #include "storage/heap_table.hpp" #include "storage/storage_manager.hpp" @@ -38,20 +42,27 @@ namespace { constexpr int64_t VAL_42 = 42; constexpr double PI_LOWER = 3.14; constexpr double PI_UPPER = 3.15; +constexpr int64_t VAL_1 = 1; +constexpr int64_t VAL_2 = 2; constexpr int64_t VAL_10 = 10; +constexpr int64_t VAL_20 = 20; constexpr int64_t VAL_25 = 25; constexpr uint64_t STATS_100 = 100; constexpr uint16_t PORT_5432 = 5432; constexpr uint16_t PORT_9999 = 9999; +constexpr int64_t VAL_123 = 123; +constexpr double VAL_1_5 = 1.5; +constexpr oid_t TABLE_9999 = 9999; +constexpr oid_t INDEX_8888 = 8888; // ============= Value Tests ============= -TEST(ValueTests, Basic) { +TEST(CloudSQLTests, ValueBasic) { const auto val = Value::make_int64(VAL_42); EXPECT_EQ(val.to_int64(), VAL_42); } -TEST(ValueTests, TypeVariety) { +TEST(CloudSQLTests, ValueTypeVariety) { const Value b(true); EXPECT_TRUE(b.as_bool()); EXPECT_STREQ(b.to_string().c_str(), "TRUE"); @@ -66,17 +77,17 @@ TEST(ValueTests, TypeVariety) { // ============= Parser Tests ============= -TEST(ParserTests, Expressions) { +TEST(CloudSQLTests, ParserExpressions) { auto lexer = std::make_unique("SELECT 1 + 2 * 3 FROM dual"); Parser parser(std::move(lexer)); auto stmt = parser.parse_statement(); - ASSERT_NE(stmt, nullptr); + EXPECT_TRUE(stmt != nullptr); const auto* const select = dynamic_cast(stmt.get()); ASSERT_NE(select, nullptr); EXPECT_STREQ(select->columns()[0]->to_string().c_str(), "1 + 2 * 3"); } -TEST(ParserTests, ComplexExpressions) { +TEST(CloudSQLTests, ExpressionComplex) { { auto lexer = std::make_unique("SELECT (1 > 0 AND 5 <= 2) OR NOT (1 = 1) FROM dual"); Parser parser(std::move(lexer)); @@ -99,7 +110,7 @@ TEST(ParserTests, ComplexExpressions) { } } -TEST(ParserTests, SelectVariants) { +TEST(CloudSQLTests, ParserSelectVariants) { auto lexer = std::make_unique("SELECT DISTINCT name FROM users LIMIT 10 OFFSET 20"); Parser parser(std::move(lexer)); auto stmt = parser.parse_statement(); @@ -108,12 +119,21 @@ TEST(ParserTests, SelectVariants) { ASSERT_NE(select, nullptr); EXPECT_TRUE(select->distinct()); EXPECT_EQ(select->limit(), VAL_10); - EXPECT_EQ(select->offset(), 20); + EXPECT_EQ(select->offset(), VAL_20); +} + +TEST(CloudSQLTests, ParserErrors) { + { + auto lexer = std::make_unique("SELECT FROM users"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + EXPECT_TRUE(stmt == nullptr); + } } // ============= Catalog Tests ============= -TEST(CatalogTests, FullLifecycle) { +TEST(CloudSQLTests, CatalogFullLifecycle) { auto catalog = Catalog::create(); const std::vector cols = {{"id", ValueType::TYPE_INT64, 0}, @@ -122,6 +142,7 @@ TEST(CatalogTests, FullLifecycle) { const oid_t table_id = catalog->create_table("test_table", cols); EXPECT_TRUE(table_id > 0); EXPECT_TRUE(catalog->table_exists(table_id)); + EXPECT_TRUE(catalog->table_exists_by_name("test_table")); auto table = catalog->get_table(table_id); EXPECT_TRUE(table.has_value()); @@ -140,7 +161,7 @@ TEST(CatalogTests, FullLifecycle) { // ============= Config Tests ============= -TEST(ConfigTests, Basic) { +TEST(CloudSQLTests, ConfigBasic) { config::Config cfg; EXPECT_EQ(cfg.port, PORT_5432); @@ -162,9 +183,10 @@ TEST(ConfigTests, Basic) { // ============= Storage Tests ============= -TEST(StorageTests, Persistence) { +TEST(CloudSQLTests, StoragePersistence) { const std::string filename = "persist_test"; - static_cast(std::remove("./test_data/persist_test.heap")); + const std::string filepath = "./test_data/" + filename + ".heap"; + static_cast(std::remove(filepath.c_str())); Schema schema; schema.add_column("data", ValueType::TYPE_TEXT); { @@ -173,6 +195,7 @@ TEST(StorageTests, Persistence) { HeapTable table(filename, sm, schema); static_cast(table.create()); static_cast(table.insert(Tuple({Value::make_text("Persistent data")}))); + sm.flush_all_pages(); } { StorageManager disk_manager("./test_data"); @@ -183,42 +206,576 @@ TEST(StorageTests, Persistence) { EXPECT_TRUE(iter.next(t)); EXPECT_STREQ(t.get(0).as_text().c_str(), "Persistent data"); } + static_cast(std::remove(filepath.c_str())); +} + +TEST(CloudSQLTests, StorageDelete) { + const std::string filename = "delete_test"; + const std::string filepath = "./test_data/" + filename + ".heap"; + static_cast(std::remove(filepath.c_str())); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + Schema schema; + schema.add_column("id", ValueType::TYPE_INT64); + HeapTable table(filename, sm, schema); + EXPECT_TRUE(table.create()); + + static_cast(table.insert(Tuple({Value::make_int64(VAL_1)}))); + const auto tid2 = table.insert(Tuple({Value::make_int64(VAL_2)})); + + EXPECT_EQ(table.tuple_count(), 2U); + EXPECT_TRUE(table.remove(tid2, 100)); // Logically delete with xmax=100 + EXPECT_EQ(table.tuple_count(), 1U); + + auto iter = table.scan(); + Tuple t; + EXPECT_TRUE(iter.next(t)); + EXPECT_EQ(t.get(0).to_int64(), 1); + EXPECT_FALSE(iter.next(t)); + static_cast(std::remove(filepath.c_str())); +} + +// ============= Index Tests ============= + +TEST(IndexTests, BTreeBasic) { + static_cast(std::remove("./test_data/idx_test.idx")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + BTreeIndex idx("idx_test", sm, ValueType::TYPE_INT64); + static_cast(idx.create()); + static_cast(idx.insert(Value::make_int64(VAL_10), HeapTable::TupleId(1, 1))); + static_cast(idx.insert(Value::make_int64(VAL_20), HeapTable::TupleId(1, 2))); + static_cast(idx.insert(Value::make_int64(VAL_10), HeapTable::TupleId(2, 1))); + const auto res = idx.search(Value::make_int64(VAL_10)); + EXPECT_EQ(res.size(), 2U); + static_cast(idx.drop()); +} + +TEST(IndexTests, Scan) { + static_cast(std::remove("./test_data/scan_test.idx")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + BTreeIndex idx("scan_test", sm, ValueType::TYPE_INT64); + static_cast(idx.create()); + static_cast(idx.insert(Value::make_int64(VAL_1), HeapTable::TupleId(1, 1))); + static_cast(idx.insert(Value::make_int64(VAL_2), HeapTable::TupleId(1, 2))); + + auto iter = idx.scan(); + BTreeIndex::Entry entry; + EXPECT_TRUE(iter.next(entry)); + EXPECT_EQ(entry.key.to_int64(), 1); + EXPECT_TRUE(iter.next(entry)); + EXPECT_EQ(entry.key.to_int64(), 2); + EXPECT_FALSE(iter.next(entry)); + static_cast(idx.drop()); } // ============= Execution Tests ============= TEST(ExecutionTests, EndToEnd) { - static_cast(std::remove("./test_data/users.heap")); + static_cast(std::remove("./test_data/users_e2e.heap")); StorageManager disk_manager("./test_data"); BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); auto catalog = Catalog::create(); LockManager lm; - TransactionManager tm(lm, *catalog, sm); + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); QueryExecutor exec(*catalog, sm, lm, tm); { - auto lexer = std::make_unique("CREATE TABLE users (id BIGINT, age BIGINT)"); + auto lexer = std::make_unique("CREATE TABLE users_e2e (id BIGINT, age BIGINT)"); auto stmt = Parser(std::move(lexer)).parse_statement(); - ASSERT_NE(stmt, nullptr); const auto res = exec.execute(*stmt); EXPECT_TRUE(res.success()); } { - auto lexer = - std::make_unique("INSERT INTO users (id, age) VALUES (1, 20), (2, 30), (3, 40)"); + auto lexer = std::make_unique( + "INSERT INTO users_e2e (id, age) VALUES (1, 20), (2, 30), (3, 40)"); auto stmt = Parser(std::move(lexer)).parse_statement(); - ASSERT_NE(stmt, nullptr); const auto res = exec.execute(*stmt); EXPECT_TRUE(res.success()); } { - auto lexer = std::make_unique("SELECT id FROM users WHERE age > 25"); + auto lexer = std::make_unique("SELECT id FROM users_e2e WHERE age > 25"); auto stmt = Parser(std::move(lexer)).parse_statement(); - ASSERT_NE(stmt, nullptr); const auto res = exec.execute(*stmt); EXPECT_TRUE(res.success()); - EXPECT_EQ(res.row_count(), 2); + EXPECT_EQ(res.row_count(), 2U); } + static_cast(std::remove("./test_data/users_e2e.heap")); +} + +TEST(ExecutionTests, Sort) { + static_cast(std::remove("./test_data/sort_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE sort_test (val INT)")).parse_statement())); + static_cast(exec.execute( + *Parser(std::make_unique("INSERT INTO sort_test VALUES (30), (10), (20)")) + .parse_statement())); + + const auto res = + exec.execute(*Parser(std::make_unique("SELECT val FROM sort_test ORDER BY val")) + .parse_statement()); + ASSERT_EQ(res.row_count(), 3U); + EXPECT_STREQ(res.rows()[0].get(0).to_string().c_str(), "10"); + EXPECT_STREQ(res.rows()[1].get(0).to_string().c_str(), "20"); + EXPECT_STREQ(res.rows()[2].get(0).to_string().c_str(), "30"); + static_cast(std::remove("./test_data/sort_test.heap")); +} + +TEST(ExecutionTests, Aggregate) { + static_cast(std::remove("./test_data/agg_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast( + exec.execute(*Parser(std::make_unique("CREATE TABLE agg_test (cat TEXT, val INT)")) + .parse_statement())); + static_cast( + exec.execute(*Parser(std::make_unique( + "INSERT INTO agg_test VALUES ('A', 10), ('A', 20), ('B', 5)")) + .parse_statement())); + + auto lex = + std::make_unique("SELECT cat, COUNT(val), SUM(val) FROM agg_test GROUP BY cat"); + auto stmt = Parser(std::move(lex)).parse_statement(); + ASSERT_NE(stmt, nullptr); + + const auto res = exec.execute(*stmt); + EXPECT_TRUE(res.success()); + + ASSERT_EQ(res.row_count(), 2U); + /* Row 0: 'A', 2, 30 */ + EXPECT_STREQ(res.rows()[0].get(0).to_string().c_str(), "A"); + EXPECT_STREQ(res.rows()[0].get(1).to_string().c_str(), "2"); + EXPECT_STREQ(res.rows()[0].get(2).to_string().c_str(), "30"); + static_cast(std::remove("./test_data/agg_test.heap")); +} + +TEST(ExecutionTests, AggregateAdvanced) { + static_cast(std::remove("./test_data/adv_agg.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE adv_agg (val INT)")).parse_statement())); + static_cast( + exec.execute(*Parser(std::make_unique("INSERT INTO adv_agg VALUES (10), (20), (30)")) + .parse_statement())); + + const auto res = exec.execute( + *Parser(std::make_unique("SELECT MIN(val), MAX(val), AVG(val) FROM adv_agg")) + .parse_statement()); + EXPECT_TRUE(res.success()); + + ASSERT_EQ(res.row_count(), 1U); + EXPECT_STREQ(res.rows()[0].get(0).to_string().c_str(), "10"); + EXPECT_STREQ(res.rows()[0].get(1).to_string().c_str(), "30"); + EXPECT_STREQ(res.rows()[0].get(2).to_string().c_str(), "20"); + static_cast(std::remove("./test_data/adv_agg.heap")); +} + +TEST(ExecutionTests, AggregateDistinct) { + static_cast(std::remove("./test_data/dist_agg.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE dist_agg (val INT)")).parse_statement())); + static_cast( + exec.execute(*Parser(std::make_unique( + "INSERT INTO dist_agg VALUES (10), (10), (20), (30), (30), (30)")) + .parse_statement())); + + const auto res = + exec.execute(*Parser(std::make_unique( + "SELECT COUNT(DISTINCT val), SUM(DISTINCT val) FROM dist_agg")) + .parse_statement()); + EXPECT_TRUE(res.success()); + + ASSERT_EQ(res.row_count(), 1U); + EXPECT_STREQ(res.rows()[0].get(0).to_string().c_str(), "3"); + EXPECT_STREQ(res.rows()[0].get(1).to_string().c_str(), "60"); + static_cast(std::remove("./test_data/dist_agg.heap")); +} + +TEST(ExecutionTests, Transaction) { + static_cast(std::remove("./test_data/txn_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + + QueryExecutor qexec1(*catalog, sm, lm, tm); + static_cast( + qexec1.execute(*Parser(std::make_unique("CREATE TABLE txn_test (id INT, val INT)")) + .parse_statement())); + + static_cast(qexec1.execute(*Parser(std::make_unique("BEGIN")).parse_statement())); + static_cast( + qexec1.execute(*Parser(std::make_unique("INSERT INTO txn_test VALUES (1, 100)")) + .parse_statement())); + + QueryExecutor qexec2(*catalog, sm, lm, tm); + + const auto res_commit = + qexec1.execute(*Parser(std::make_unique("COMMIT")).parse_statement()); + EXPECT_TRUE(res_commit.success()); + + const auto res_select = + qexec2.execute(*Parser(std::make_unique("SELECT val FROM txn_test WHERE id = 1")) + .parse_statement()); + ASSERT_EQ(res_select.row_count(), 1U); + EXPECT_STREQ(res_select.rows()[0].get(0).to_string().c_str(), "100"); + static_cast(std::remove("./test_data/txn_test.heap")); +} + +TEST(ExecutionTests, Rollback) { + static_cast(std::remove("./test_data/rollback_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast( + exec.execute(*Parser(std::make_unique("CREATE TABLE rollback_test (val INT)")) + .parse_statement())); + + static_cast(exec.execute(*Parser(std::make_unique("BEGIN")).parse_statement())); + static_cast( + exec.execute(*Parser(std::make_unique("INSERT INTO rollback_test VALUES (100)")) + .parse_statement())); + + const auto res_internal = exec.execute( + *Parser(std::make_unique("SELECT val FROM rollback_test")).parse_statement()); + EXPECT_EQ(res_internal.row_count(), 1U); + + static_cast(exec.execute(*Parser(std::make_unique("ROLLBACK")).parse_statement())); + + const auto res_after = exec.execute( + *Parser(std::make_unique("SELECT val FROM rollback_test")).parse_statement()); + EXPECT_EQ(res_after.row_count(), 0U); + static_cast(std::remove("./test_data/rollback_test.heap")); +} + +TEST(ExecutionTests, UpdateDelete) { + static_cast(std::remove("./test_data/upd_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast( + exec.execute(*Parser(std::make_unique("CREATE TABLE upd_test (id INT, val TEXT)")) + .parse_statement())); + static_cast(exec.execute( + *Parser(std::make_unique("INSERT INTO upd_test VALUES (1, 'old'), (2, 'stay')")) + .parse_statement())); + + /* Test UPDATE */ + const auto res_upd = exec.execute( + *Parser(std::make_unique("UPDATE upd_test SET val = 'new' WHERE id = 1")) + .parse_statement()); + EXPECT_EQ(res_upd.rows_affected(), 1U); + + const auto res_sel = + exec.execute(*Parser(std::make_unique("SELECT val FROM upd_test WHERE id = 1")) + .parse_statement()); + ASSERT_EQ(res_sel.row_count(), 1U); + EXPECT_STREQ(res_sel.rows()[0].get(0).to_string().c_str(), "new"); + + /* Test DELETE */ + const auto res_del = exec.execute( + *Parser(std::make_unique("DELETE FROM upd_test WHERE id = 2")).parse_statement()); + EXPECT_EQ(res_del.rows_affected(), 1U); + + const auto res_sel2 = + exec.execute(*Parser(std::make_unique("SELECT id FROM upd_test")).parse_statement()); + EXPECT_EQ(res_sel2.row_count(), 1U); // Only ID 1 remains + static_cast(std::remove("./test_data/upd_test.heap")); +} + +TEST(ExecutionTests, MVCC) { + static_cast(std::remove("./test_data/mvcc_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + + QueryExecutor qexec1(*catalog, sm, lm, tm); + static_cast(qexec1.execute( + *Parser(std::make_unique("CREATE TABLE mvcc_test (val INT)")).parse_statement())); + + /* Start T1 and Insert */ + static_cast(qexec1.execute(*Parser(std::make_unique("BEGIN")).parse_statement())); + static_cast(qexec1.execute( + *Parser(std::make_unique("INSERT INTO mvcc_test VALUES (10)")).parse_statement())); + + /* Session 2 should see nothing yet (atomic snapshot) */ + QueryExecutor qexec2(*catalog, sm, lm, tm); + const auto res2_pre = qexec2.execute( + *Parser(std::make_unique("SELECT val FROM mvcc_test")).parse_statement()); + EXPECT_EQ(res2_pre.row_count(), 0U); + + /* T1 updates row */ + static_cast(qexec1.execute( + *Parser(std::make_unique("UPDATE mvcc_test SET val = 20")).parse_statement())); + + /* T1 sees new value */ + const auto res1 = qexec1.execute( + *Parser(std::make_unique("SELECT val FROM mvcc_test")).parse_statement()); + ASSERT_EQ(res1.row_count(), 1U); + EXPECT_STREQ(res1.rows()[0].get(0).to_string().c_str(), "20"); + + static_cast(qexec1.execute(*Parser(std::make_unique("COMMIT")).parse_statement())); + + /* After commit, Session 2 sees the latest value */ + const auto res2_post = qexec2.execute( + *Parser(std::make_unique("SELECT val FROM mvcc_test")).parse_statement()); + ASSERT_EQ(res2_post.row_count(), 1U); + EXPECT_STREQ(res2_post.rows()[0].get(0).to_string().c_str(), "20"); + static_cast(std::remove("./test_data/mvcc_test.heap")); +} + +TEST(ExecutionTests, Join) { + static_cast(std::remove("./test_data/users_join.heap")); + static_cast(std::remove("./test_data/orders_join.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast( + exec.execute(*Parser(std::make_unique("CREATE TABLE users_join (id INT, name TEXT)")) + .parse_statement())); + static_cast(exec.execute(*Parser(std::make_unique("CREATE TABLE orders_join (id " + "INT, user_id INT, amount " + "DOUBLE)")) + .parse_statement())); + + static_cast(exec.execute( + *Parser(std::make_unique("INSERT INTO users_join VALUES (1, 'Alice'), (2, 'Bob')")) + .parse_statement())); + static_cast(exec.execute(*Parser(std::make_unique("INSERT INTO orders_join VALUES " + "(101, 1, 50.5), (102, 1, " + "25.0), (103, 2, 100.0)")) + .parse_statement())); + + /* Test: INNER JOIN with sorting */ + const auto result = exec.execute( + *Parser(std::make_unique("SELECT users_join.name, orders_join.amount FROM " + "users_join JOIN orders_join " + "ON users_join.id = orders_join.user_id ORDER BY " + "orders_join.amount")) + .parse_statement()); + + ASSERT_EQ(result.row_count(), 3U); + + /* 25.0 (Alice), 50.5 (Alice), 100.0 (Bob) */ + EXPECT_STREQ(result.rows()[0].get(0).to_string().c_str(), "Alice"); + EXPECT_STREQ(result.rows()[0].get(1).to_string().c_str(), "25"); + EXPECT_STREQ(result.rows()[2].get(0).to_string().c_str(), "Bob"); + EXPECT_STREQ(result.rows()[2].get(1).to_string().c_str(), "100"); + static_cast(std::remove("./test_data/users_join.heap")); + static_cast(std::remove("./test_data/orders_join.heap")); +} + +TEST(ExecutionTests, DDL) { + static_cast(std::remove("./test_data/ddl_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + QueryExecutor exec(*catalog, sm, lm, tm); + + /* 1. Create and then Drop Table */ + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE ddl_test (id INT)")).parse_statement())); + EXPECT_TRUE(catalog->table_exists_by_name("ddl_test")); + + const auto res_drop = + exec.execute(*Parser(std::make_unique("DROP TABLE ddl_test")).parse_statement()); + EXPECT_TRUE(res_drop.success()); + EXPECT_FALSE(catalog->table_exists_by_name("ddl_test")); + + /* 2. IF EXISTS */ + const auto res_drop_none = exec.execute( + *Parser(std::make_unique("DROP TABLE IF EXISTS non_existent")).parse_statement()); + EXPECT_TRUE(res_drop_none.success()); + + /* 3. Create Index and then Drop Index */ + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE ddl_test (id INT)")).parse_statement())); + auto table_opt = catalog->get_table_by_name("ddl_test"); + if (table_opt) { + const oid_t tid = (*table_opt)->table_id; + static_cast(catalog->create_index("idx_ddl", tid, {0}, IndexType::BTree, true)); + } + + const auto res_drop_idx = + exec.execute(*Parser(std::make_unique("DROP INDEX idx_ddl")).parse_statement()); + EXPECT_TRUE(res_drop_idx.success()); + static_cast(std::remove("./test_data/ddl_test.heap")); +} + +TEST(LexerTests, Advanced) { + /* 1. Test comments and line tracking */ + { + const std::string sql = "SELECT -- comment here\n* FROM users"; + Lexer lexer(sql); + const auto t1 = lexer.next_token(); + EXPECT_EQ(static_cast(t1.type()), static_cast(TokenType::Select)); + const auto t2 = lexer.next_token(); // Should skip comment and newline + EXPECT_STREQ(t2.lexeme().c_str(), "*"); + EXPECT_EQ(t2.line(), 2U); + } + /* 2. Test Error and Unknown operators */ + { + Lexer lexer("@"); + const auto t = lexer.next_token(); + EXPECT_EQ(static_cast(t.type()), static_cast(TokenType::Error)); + } +} + +TEST(ExecutionTests, Expressions) { + static_cast(std::remove("./test_data/expr_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm, sm.get_log_manager()); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE expr_test (id INT, val DOUBLE, str TEXT)")) + .parse_statement())); + static_cast(exec.execute( + *Parser(std::make_unique( + "INSERT INTO expr_test VALUES (1, 10.5, 'A'), (2, NULL, 'B'), (3, 20.0, 'C')")) + .parse_statement())); + + /* 1. Test IS NULL / IS NOT NULL */ + { + const auto res = exec.execute( + *Parser(std::make_unique("SELECT id FROM expr_test WHERE val IS NULL")) + .parse_statement()); + ASSERT_EQ(res.row_count(), 1U); + EXPECT_EQ(res.rows()[0].get(0).to_int64(), 2); + + const auto res2 = exec.execute( + *Parser(std::make_unique("SELECT id FROM expr_test WHERE val IS NOT NULL")) + .parse_statement()); + EXPECT_EQ(res2.row_count(), 2U); + } + + /* 2. Test IN / NOT IN */ + { + const auto res = exec.execute( + *Parser(std::make_unique("SELECT id FROM expr_test WHERE id IN (1, 3)")) + .parse_statement()); + EXPECT_EQ(res.row_count(), 2U); + + const auto res2 = exec.execute( + *Parser(std::make_unique("SELECT id FROM expr_test WHERE str NOT IN ('A', 'C')")) + .parse_statement()); + ASSERT_EQ(res2.row_count(), 1U); + EXPECT_EQ(res2.rows()[0].get(0).to_int64(), 2); + } + + /* 3. Test Arithmetic and Complex Binary */ + { + const auto res = exec.execute( + *Parser(std::make_unique( + "SELECT id, val * 2 + 10, val / 2, val - 5 FROM expr_test WHERE id = 1")) + .parse_statement()); + ASSERT_EQ(res.row_count(), 1U); + EXPECT_DOUBLE_EQ(res.rows()[0].get(1).to_float64(), 31.0); + EXPECT_DOUBLE_EQ(res.rows()[0].get(2).to_float64(), 5.25); + EXPECT_DOUBLE_EQ(res.rows()[0].get(3).to_float64(), 5.5); + } + static_cast(std::remove("./test_data/expr_test.heap")); +} + +TEST(CloudSQLTests, ExpressionTypes) { + /* Test ConstantExpr with various types for coverage */ + { + const ConstantExpr c_bool(Value::make_bool(true)); + EXPECT_TRUE(c_bool.evaluate().as_bool()); + + const ConstantExpr c_int(Value::make_int64(VAL_123)); + EXPECT_EQ(c_int.evaluate().to_int64(), VAL_123); + + const ConstantExpr c_float(Value::make_float64(VAL_1_5)); + EXPECT_DOUBLE_EQ(c_float.evaluate().to_float64(), VAL_1_5); + + const ConstantExpr c_null(Value::make_null()); + EXPECT_TRUE(c_null.evaluate().is_null()); + } +} + +TEST(CatalogTests, Errors) { + auto catalog = Catalog::create(); + const std::vector cols = {{"id", ValueType::TYPE_INT64, 0}}; + + static_cast(catalog->create_table("fail_test", cols)); + /* Duplicate table */ + EXPECT_THROW(catalog->create_table("fail_test", cols), std::exception); + + /* Missing table */ + EXPECT_FALSE(catalog->table_exists(TABLE_9999)); + EXPECT_FALSE(catalog->get_table(TABLE_9999).has_value()); + EXPECT_FALSE(catalog->table_exists_by_name("non_existent")); + + /* Duplicate index */ + const oid_t tid = catalog->create_table("idx_fail", cols); + static_cast(catalog->create_index("my_idx", tid, {0}, IndexType::BTree, true)); + EXPECT_THROW(catalog->create_index("my_idx", tid, {0}, IndexType::BTree, true), std::exception); + + /* Missing index */ + EXPECT_FALSE(catalog->get_index(INDEX_8888).has_value()); + EXPECT_FALSE(catalog->drop_index(INDEX_8888)); +} + +TEST(CatalogTests, Stats) { + auto catalog = Catalog::create(); + const std::vector cols = {{"id", ValueType::TYPE_INT64, 0}}; + const oid_t tid = catalog->create_table("stats_test", cols); + + EXPECT_TRUE(catalog->update_table_stats(tid, 500U)); + auto tinfo = catalog->get_table(tid); + if (tinfo) { + EXPECT_EQ((*tinfo)->num_rows, 500U); + } + + /* Cover print() */ + catalog->print(); } } // namespace diff --git a/tests/cloudSQL_tests.cpp.old b/tests/cloudSQL_tests.cpp.old new file mode 100644 index 00000000..5d9f3c57 --- /dev/null +++ b/tests/cloudSQL_tests.cpp.old @@ -0,0 +1,925 @@ +/** + * @file cloudSQL_tests.cpp + * @brief Comprehensive test suite for cloudSQL C++ implementation + */ + +#include // IWYU pragma: keep +#include // IWYU pragma: keep +#include // IWYU pragma: keep +#include // IWYU pragma: keep +#include // IWYU pragma: keep + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "catalog/catalog.hpp" +#include "common/config.hpp" +#include "common/value.hpp" +#include "executor/query_executor.hpp" +#include "executor/types.hpp" +#include "network/server.hpp" // IWYU pragma: keep +#include "parser/expression.hpp" +#include "parser/lexer.hpp" +#include "parser/parser.hpp" +#include "parser/statement.hpp" +#include "parser/token.hpp" +#include "storage/btree_index.hpp" +#include "storage/buffer_pool_manager.hpp" +#include "storage/heap_table.hpp" +#include "storage/storage_manager.hpp" +#include "test_utils.hpp" +#include "transaction/lock_manager.hpp" +#include "transaction/transaction_manager.hpp" + +using namespace cloudsql; +using namespace cloudsql::common; +using namespace cloudsql::parser; +using namespace cloudsql::executor; +using namespace cloudsql::storage; +using namespace cloudsql::transaction; + +namespace { + +// Using common test counters +using cloudsql::tests::tests_failed; +using cloudsql::tests::tests_passed; + +constexpr int64_t VAL_42 = 42; +constexpr double PI_LOWER = 3.14; +constexpr double PI_UPPER = 3.15; +constexpr int64_t VAL_10 = 10; +constexpr int64_t VAL_25 = 25; +constexpr uint64_t STATS_100 = 100; +constexpr uint16_t PORT_9999 = 9999; +constexpr uint64_t XMAX_100 = 100; +constexpr int64_t BTREE_VAL_10 = 10; +constexpr int64_t BTREE_VAL_20 = 20; +constexpr uint64_t STATS_500 = 500; + +constexpr uint16_t PORT_5432 = 5432; +constexpr int64_t VAL_123 = 123; +constexpr double VAL_1_5 = 1.5; +constexpr oid_t TABLE_9999 = 9999; +constexpr oid_t INDEX_8888 = 8888; + +// ============= Value Tests ============= + +TEST(ValueTest_Basic) { + const auto val = Value::make_int64(VAL_42); + EXPECT_EQ(val.to_int64(), VAL_42); +} + +TEST(ValueTest_TypeVariety) { + const Value b(true); + EXPECT_TRUE(b.as_bool()); + EXPECT_STREQ(b.to_string(), "TRUE"); + + const Value f(3.14159); + EXPECT_GT(f.as_float64(), PI_LOWER); + EXPECT_LT(f.as_float64(), PI_UPPER); + + const Value s("cloudSQL"); + EXPECT_STREQ(s.as_text(), "cloudSQL"); +} + +// ============= Parser Tests ============= + +TEST(ParserTest_Expressions) { + { + auto lexer = std::make_unique("SELECT 1 + 2 * 3 FROM dual"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + EXPECT_TRUE(stmt != nullptr); + const auto* const select = dynamic_cast(stmt.get()); + EXPECT_STREQ(select->columns()[0]->to_string(), "1 + 2 * 3"); + } +} + +TEST(ExpressionTest_Complex) { + { + auto lexer = std::make_unique("SELECT (1 > 0 AND 5 <= 2) OR NOT (1 = 1) FROM dual"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + if (!stmt) { + throw std::runtime_error("ExpressionTest_Complex: Parser failed on query 1"); + } + const auto* const select = dynamic_cast(stmt.get()); + const auto val = select->columns()[0]->evaluate(); + EXPECT_FALSE(val.as_bool()); + } + { + auto lexer = std::make_unique("SELECT -10 + 20, 5 * (2 + 3) FROM dual"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + if (!stmt) { + throw std::runtime_error("ExpressionTest_Complex: Parser failed on query 2"); + } + const auto* const select = dynamic_cast(stmt.get()); + EXPECT_EQ(select->columns()[0]->evaluate().to_int64(), VAL_10); + EXPECT_EQ(select->columns()[1]->evaluate().to_int64(), VAL_25); + } + { + auto lexer = std::make_unique("SELECT 5.5 FROM dual"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + if (!stmt) { + throw std::runtime_error("ExpressionTest_Complex: Parser failed on query 3a"); + } + const auto* const select = dynamic_cast(stmt.get()); + EXPECT_TRUE( + select->columns()[0]->evaluate().to_float64() == + 5.5); // NOLINT(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers) + } + { + auto lexer = std::make_unique("SELECT 10 / 2 FROM dual"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + if (!stmt) { + throw std::runtime_error("ExpressionTest_Complex: Parser failed on query 3b"); + } + const auto* const select = dynamic_cast(stmt.get()); + EXPECT_TRUE( + select->columns()[0]->evaluate().to_float64() == + 5.0); // NOLINT(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers) + } +} + +TEST(ParserTest_SelectVariants) { + { + auto lexer = std::make_unique("SELECT DISTINCT name FROM users LIMIT 10 OFFSET 20"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + const auto* const select = dynamic_cast(stmt.get()); + EXPECT_TRUE(select->distinct()); + EXPECT_EQ(select->limit(), VAL_10); + EXPECT_EQ( + select->offset(), + static_cast( + 20)); // NOLINT(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers) + } + { + auto lexer = + std::make_unique("SELECT age, cnt FROM users GROUP BY age ORDER BY age"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + const auto* const select = dynamic_cast(stmt.get()); + EXPECT_EQ(select->group_by().size(), static_cast(1)); + EXPECT_EQ(select->order_by().size(), static_cast(1)); + } +} + +TEST(ParserTest_Errors) { + { + auto lexer = std::make_unique("SELECT FROM users"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + EXPECT_TRUE(stmt == nullptr); + } +} + +// ============= Catalog Tests ============= + +TEST(CatalogTest_FullLifecycle) { + auto catalog = Catalog::create(); + + const std::vector cols = {{"id", ValueType::TYPE_INT64, 0}, + {"name", ValueType::TYPE_TEXT, 1}}; + + const oid_t table_id = catalog->create_table("test_table", cols); + EXPECT_TRUE(table_id > 0); + EXPECT_TRUE(catalog->table_exists(table_id)); + EXPECT_TRUE(catalog->table_exists_by_name("test_table")); + + auto table = catalog->get_table(table_id); + EXPECT_TRUE(table.has_value()); + if (table.has_value()) { + EXPECT_STREQ(table.value()->name, "test_table"); + } + + catalog->update_table_stats(table_id, STATS_100); + if (table.has_value()) { + EXPECT_EQ(table.value()->num_rows, STATS_100); + } + + const oid_t idx_id = catalog->create_index("test_idx", table_id, {0}, IndexType::BTree, true); + EXPECT_TRUE(idx_id > 0); + EXPECT_EQ(catalog->get_table_indexes(table_id).size(), static_cast(1)); + + auto idx_pair = catalog->get_index(idx_id); + EXPECT_TRUE(idx_pair.has_value()); + if (idx_pair.has_value()) { + EXPECT_STREQ(idx_pair.value().second->name, "test_idx"); + } + + EXPECT_TRUE(catalog->drop_index(idx_id)); + EXPECT_EQ(catalog->get_table_indexes(table_id).size(), static_cast(0)); + + EXPECT_TRUE(catalog->drop_table(table_id)); + EXPECT_FALSE(catalog->table_exists(table_id)); +} + +// ============= Config Tests ============= + +TEST(ConfigTest_Basic) { + config::Config cfg; + EXPECT_EQ(cfg.port, PORT_5432); + + cfg.port = PORT_9999; + cfg.data_dir = "./tmp_data"; + + EXPECT_TRUE(cfg.validate()); + + const std::string cfg_file = "test_config.conf"; + EXPECT_TRUE(cfg.save(cfg_file)); + + config::Config cfg2; + EXPECT_TRUE(cfg2.load(cfg_file)); + EXPECT_EQ(cfg2.port, PORT_9999); + EXPECT_STREQ(cfg2.data_dir, "./tmp_data"); + + static_cast(std::remove(cfg_file.c_str())); +} + +// ============= Statement Tests ============= + +TEST(StatementTest_ToString) { + const TransactionBeginStatement begin; + EXPECT_STREQ(begin.to_string(), "BEGIN"); + + const TransactionCommitStatement commit; + EXPECT_STREQ(commit.to_string(), "COMMIT"); + + const TransactionRollbackStatement rollback; + EXPECT_STREQ(rollback.to_string(), "ROLLBACK"); +} + +TEST(StatementTest_Serialization) { + { + auto lexer = std::make_unique( + "SELECT name, age FROM users WHERE age > 18 ORDER BY age LIMIT 10 OFFSET 5"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + EXPECT_STREQ(stmt->to_string(), + "SELECT name, age FROM users WHERE age > 18 ORDER BY age LIMIT 10 OFFSET 5"); + } + { + auto lexer = + std::make_unique("INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')"); + Parser parser(std::move(lexer)); + auto stmt = parser.parse_statement(); + EXPECT_STREQ(stmt->to_string(), + "INSERT INTO users (id, name) VALUES (1, 'Alice'), (2, 'Bob')"); + } +} + +// ============= Storage Tests ============= + +TEST(StorageTest_Persistence) { + const std::string filename = "persist_test"; + static_cast(std::remove("./test_data/persist_test.heap")); + Schema schema; + schema.add_column("data", ValueType::TYPE_TEXT); + { + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + HeapTable table(filename, sm, schema); + static_cast(table.create()); + static_cast(table.insert(Tuple({Value::make_text("Persistent data")}))); + } + { + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + HeapTable table(filename, sm, schema); + auto iter = table.scan(); + Tuple t; + EXPECT_TRUE(iter.next(t)); + EXPECT_STREQ(t.get(0).as_text(), "Persistent data"); + } +} + +TEST(StorageTest_Delete) { + const std::string filename = "delete_test"; + static_cast(std::remove("./test_data/delete_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + Schema schema; + schema.add_column("id", ValueType::TYPE_INT64); + HeapTable table(filename, sm, schema); + EXPECT_TRUE(table.create()); + + static_cast(table.insert(Tuple({Value::make_int64(1)}))); + const auto tid2 = table.insert(Tuple({Value::make_int64(2)})); + + EXPECT_EQ(table.tuple_count(), static_cast(2)); + EXPECT_TRUE(table.remove(tid2, XMAX_100)); // Logically delete with xmax=100 + EXPECT_EQ(table.tuple_count(), static_cast(1)); + + auto iter = table.scan(); + Tuple t; + EXPECT_TRUE(iter.next(t)); + EXPECT_EQ(t.get(0).to_int64(), 1); + EXPECT_FALSE(iter.next(t)); +} + +// ============= Index Tests ============= + +TEST(IndexTest_BTreeBasic) { + static_cast(std::remove("./test_data/idx_test.idx")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + BTreeIndex idx("idx_test", sm, ValueType::TYPE_INT64); + static_cast(idx.create()); + static_cast(idx.insert(Value::make_int64(BTREE_VAL_10), HeapTable::TupleId(1, 1))); + static_cast(idx.insert(Value::make_int64(BTREE_VAL_20), HeapTable::TupleId(1, 2))); + static_cast(idx.insert(Value::make_int64(BTREE_VAL_10), HeapTable::TupleId(2, 1))); + const auto res = idx.search(Value::make_int64(BTREE_VAL_10)); + EXPECT_EQ(res.size(), static_cast(2)); + static_cast(idx.drop()); +} + +TEST(IndexTest_Scan) { + static_cast(std::remove("./test_data/scan_test.idx")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + BTreeIndex idx("scan_test", sm, ValueType::TYPE_INT64); + static_cast(idx.create()); + static_cast(idx.insert(Value::make_int64(1), HeapTable::TupleId(1, 1))); + static_cast(idx.insert(Value::make_int64(2), HeapTable::TupleId(1, 2))); + + auto iter = idx.scan(); + BTreeIndex::Entry entry; + EXPECT_TRUE(iter.next(entry)); + EXPECT_EQ(entry.key.to_int64(), 1); + EXPECT_TRUE(iter.next(entry)); + EXPECT_EQ(entry.key.to_int64(), 2); + EXPECT_FALSE(iter.next(entry)); +} + +// ============= Execution Tests ============= + +TEST(ExecutionTest_EndToEnd) { + static_cast(std::remove("./test_data/users.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + QueryExecutor exec(*catalog, sm, lm, tm); + + { + auto lexer = std::make_unique("CREATE TABLE users (id BIGINT, age BIGINT)"); + auto stmt = Parser(std::move(lexer)).parse_statement(); + const auto res = exec.execute(*stmt); + if (!res.success()) { + throw std::runtime_error("CREATE failed: " + res.error()); + } + } + { + auto lexer = + std::make_unique("INSERT INTO users (id, age) VALUES (1, 20), (2, 30), (3, 40)"); + auto stmt = Parser(std::move(lexer)).parse_statement(); + const auto res = exec.execute(*stmt); + if (!res.success()) { + throw std::runtime_error("INSERT failed: " + res.error()); + } + } + { + auto lexer = std::make_unique("SELECT id FROM users WHERE age > 25"); + auto stmt = Parser(std::move(lexer)).parse_statement(); + const auto res = exec.execute(*stmt); + if (!res.success()) { + throw std::runtime_error("SELECT failed: " + res.error()); + } + EXPECT_EQ(res.row_count(), static_cast(2)); + } +} + +TEST(ExecutionTest_Sort) { + static_cast(std::remove("./test_data/sort_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE sort_test (val INT)")).parse_statement())); + static_cast(exec.execute( + *Parser(std::make_unique("INSERT INTO sort_test VALUES (30), (10), (20)")) + .parse_statement())); + + const auto res = + exec.execute(*Parser(std::make_unique("SELECT val FROM sort_test ORDER BY val")) + .parse_statement()); + EXPECT_EQ(res.row_count(), static_cast(3)); + EXPECT_STREQ(res.rows()[0].get(0).to_string(), "10"); + EXPECT_STREQ(res.rows()[1].get(0).to_string(), "20"); + EXPECT_STREQ(res.rows()[2].get(0).to_string(), "30"); +} + +TEST(ExecutionTest_Aggregate) { + static_cast(std::remove("./test_data/agg_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast( + exec.execute(*Parser(std::make_unique("CREATE TABLE agg_test (cat TEXT, val INT)")) + .parse_statement())); + static_cast( + exec.execute(*Parser(std::make_unique( + "INSERT INTO agg_test VALUES ('A', 10), ('A', 20), ('B', 5)")) + .parse_statement())); + + auto lex = + std::make_unique("SELECT cat, COUNT(val), SUM(val) FROM agg_test GROUP BY cat"); + auto stmt = Parser(std::move(lex)).parse_statement(); + if (!stmt) { + throw std::runtime_error("Parser failed for aggregate query"); + } + + const auto res = exec.execute(*stmt); + if (!res.success()) { + throw std::runtime_error("Execution failed: " + res.error()); + } + + EXPECT_EQ(res.row_count(), static_cast(2)); + /* Row 0: 'A', 2, 30 */ + EXPECT_STREQ(res.rows()[0].get(0).to_string(), "A"); + EXPECT_STREQ(res.rows()[0].get(1).to_string(), "2"); + EXPECT_STREQ(res.rows()[0].get(2).to_string(), "30"); +} + +TEST(ExecutionTest_AggregateAdvanced) { + static_cast(std::remove("./test_data/adv_agg.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE adv_agg (val INT)")).parse_statement())); + static_cast( + exec.execute(*Parser(std::make_unique("INSERT INTO adv_agg VALUES (10), (20), (30)")) + .parse_statement())); + + const auto res = exec.execute( + *Parser(std::make_unique("SELECT MIN(val), MAX(val), AVG(val) FROM adv_agg")) + .parse_statement()); + if (!res.success()) { + throw std::runtime_error("Execution failed: " + res.error()); + } + + EXPECT_EQ(res.row_count(), static_cast(1)); + EXPECT_STREQ(res.rows()[0].get(0).to_string(), "10"); + EXPECT_STREQ(res.rows()[0].get(1).to_string(), "30"); + EXPECT_STREQ(res.rows()[0].get(2).to_string(), "20"); +} + +TEST(ExecutionTest_AggregateDistinct) { + static_cast(std::remove("./test_data/dist_agg.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE dist_agg (val INT)")).parse_statement())); + static_cast( + exec.execute(*Parser(std::make_unique( + "INSERT INTO dist_agg VALUES (10), (10), (20), (30), (30), (30)")) + .parse_statement())); + + const auto res = + exec.execute(*Parser(std::make_unique( + "SELECT COUNT(DISTINCT val), SUM(DISTINCT val) FROM dist_agg")) + .parse_statement()); + if (!res.success()) { + throw std::runtime_error("Execution failed: " + res.error()); + } + + EXPECT_EQ(res.row_count(), static_cast(1)); + EXPECT_STREQ(res.rows()[0].get(0).to_string(), "3"); + EXPECT_STREQ(res.rows()[0].get(1).to_string(), "60"); +} + +TEST(ExecutionTest_Transaction) { + static_cast(std::remove("./test_data/txn_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + + QueryExecutor qexec1(*catalog, sm, lm, tm); + static_cast( + qexec1.execute(*Parser(std::make_unique("CREATE TABLE txn_test (id INT, val INT)")) + .parse_statement())); + + static_cast(qexec1.execute(*Parser(std::make_unique("BEGIN")).parse_statement())); + static_cast( + qexec1.execute(*Parser(std::make_unique("INSERT INTO txn_test VALUES (1, 100)")) + .parse_statement())); + + QueryExecutor qexec2(*catalog, sm, lm, tm); + + const auto res_commit = + qexec1.execute(*Parser(std::make_unique("COMMIT")).parse_statement()); + EXPECT_TRUE(res_commit.success()); + + const auto res_select = + qexec2.execute(*Parser(std::make_unique("SELECT val FROM txn_test WHERE id = 1")) + .parse_statement()); + EXPECT_EQ(res_select.row_count(), static_cast(1)); + EXPECT_STREQ(res_select.rows()[0].get(0).to_string(), "100"); +} + +TEST(ExecutionTest_Rollback) { + static_cast(std::remove("./test_data/rollback_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast( + exec.execute(*Parser(std::make_unique("CREATE TABLE rollback_test (val INT)")) + .parse_statement())); + + static_cast(exec.execute(*Parser(std::make_unique("BEGIN")).parse_statement())); + static_cast( + exec.execute(*Parser(std::make_unique("INSERT INTO rollback_test VALUES (100)")) + .parse_statement())); + + const auto res_internal = exec.execute( + *Parser(std::make_unique("SELECT val FROM rollback_test")).parse_statement()); + EXPECT_EQ(res_internal.row_count(), static_cast(1)); + + static_cast(exec.execute(*Parser(std::make_unique("ROLLBACK")).parse_statement())); + + const auto res_after = exec.execute( + *Parser(std::make_unique("SELECT val FROM rollback_test")).parse_statement()); + EXPECT_EQ(res_after.row_count(), static_cast(0)); +} + +TEST(ExecutionTest_UpdateDelete) { + static_cast(std::remove("./test_data/upd_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast( + exec.execute(*Parser(std::make_unique("CREATE TABLE upd_test (id INT, val TEXT)")) + .parse_statement())); + static_cast(exec.execute( + *Parser(std::make_unique("INSERT INTO upd_test VALUES (1, 'old'), (2, 'stay')")) + .parse_statement())); + + /* Test UPDATE */ + const auto res_upd = exec.execute( + *Parser(std::make_unique("UPDATE upd_test SET val = 'new' WHERE id = 1")) + .parse_statement()); + EXPECT_EQ(res_upd.rows_affected(), static_cast(1)); + + const auto res_sel = + exec.execute(*Parser(std::make_unique("SELECT val FROM upd_test WHERE id = 1")) + .parse_statement()); + EXPECT_EQ(res_sel.row_count(), static_cast(1)); + EXPECT_STREQ(res_sel.rows()[0].get(0).to_string(), "new"); + + /* Test DELETE */ + const auto res_del = exec.execute( + *Parser(std::make_unique("DELETE FROM upd_test WHERE id = 2")).parse_statement()); + EXPECT_EQ(res_del.rows_affected(), static_cast(1)); + + const auto res_sel2 = + exec.execute(*Parser(std::make_unique("SELECT id FROM upd_test")).parse_statement()); + EXPECT_EQ(res_sel2.row_count(), static_cast(1)); // Only ID 1 remains +} + +TEST(ExecutionTest_MVCC) { + static_cast(std::remove("./test_data/mvcc_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + + QueryExecutor qexec1(*catalog, sm, lm, tm); + static_cast(qexec1.execute( + *Parser(std::make_unique("CREATE TABLE mvcc_test (val INT)")).parse_statement())); + + /* Start T1 and Insert */ + static_cast(qexec1.execute(*Parser(std::make_unique("BEGIN")).parse_statement())); + static_cast(qexec1.execute( + *Parser(std::make_unique("INSERT INTO mvcc_test VALUES (10)")).parse_statement())); + + /* Session 2 should see nothing yet (atomic snapshot) */ + QueryExecutor qexec2(*catalog, sm, lm, tm); + const auto res2_pre = qexec2.execute( + *Parser(std::make_unique("SELECT val FROM mvcc_test")).parse_statement()); + EXPECT_EQ(res2_pre.row_count(), static_cast(0)); + + /* T1 updates row */ + static_cast(qexec1.execute( + *Parser(std::make_unique("UPDATE mvcc_test SET val = 20")).parse_statement())); + + /* T1 sees new value */ + const auto res1 = qexec1.execute( + *Parser(std::make_unique("SELECT val FROM mvcc_test")).parse_statement()); + EXPECT_EQ(res1.row_count(), static_cast(1)); + EXPECT_STREQ(res1.rows()[0].get(0).to_string(), "20"); + + static_cast(qexec1.execute(*Parser(std::make_unique("COMMIT")).parse_statement())); + + /* After commit, Session 2 sees the latest value */ + const auto res2_post = qexec2.execute( + *Parser(std::make_unique("SELECT val FROM mvcc_test")).parse_statement()); + EXPECT_EQ(res2_post.row_count(), static_cast(1)); + EXPECT_STREQ(res2_post.rows()[0].get(0).to_string(), "20"); +} + +TEST(ExecutionTest_Join) { + static_cast(std::remove("./test_data/users.heap")); + static_cast(std::remove("./test_data/orders.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast( + exec.execute(*Parser(std::make_unique("CREATE TABLE users (id INT, name TEXT)")) + .parse_statement())); + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE orders (id INT, user_id INT, amount DOUBLE)")) + .parse_statement())); + + static_cast(exec.execute( + *Parser(std::make_unique("INSERT INTO users VALUES (1, 'Alice'), (2, 'Bob')")) + .parse_statement())); + static_cast(exec.execute( + *Parser(std::make_unique( + "INSERT INTO orders VALUES (101, 1, 50.5), (102, 1, 25.0), (103, 2, 100.0)")) + .parse_statement())); + + /* Test: INNER JOIN with sorting */ + const auto result = exec.execute( + *Parser(std::make_unique("SELECT users.name, orders.amount FROM users JOIN orders " + "ON users.id = orders.user_id ORDER BY orders.amount")) + .parse_statement()); + + EXPECT_EQ(result.row_count(), static_cast(3)); + + /* 25.0 (Alice), 50.5 (Alice), 100.0 (Bob) */ + EXPECT_STREQ(result.rows()[0].get(0).to_string(), "Alice"); + EXPECT_STREQ(result.rows()[0].get(1).to_string(), "25"); + EXPECT_STREQ(result.rows()[2].get(0).to_string(), "Bob"); + EXPECT_STREQ(result.rows()[2].get(1).to_string(), "100"); +} + +TEST(ExecutionTest_DDL) { + static_cast(std::remove("./test_data/ddl_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + QueryExecutor exec(*catalog, sm, lm, tm); + + /* 1. Create and then Drop Table */ + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE ddl_test (id INT)")).parse_statement())); + EXPECT_TRUE(catalog->table_exists_by_name("ddl_test")); + + const auto res_drop = + exec.execute(*Parser(std::make_unique("DROP TABLE ddl_test")).parse_statement()); + EXPECT_TRUE(res_drop.success()); + EXPECT_FALSE(catalog->table_exists_by_name("ddl_test")); + + /* 2. IF EXISTS */ + const auto res_drop_none = exec.execute( + *Parser(std::make_unique("DROP TABLE IF EXISTS non_existent")).parse_statement()); + EXPECT_TRUE(res_drop_none.success()); + + /* 3. Create Index and then Drop Index */ + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE ddl_test (id INT)")).parse_statement())); + // Note: Our system doesn't have a direct "CREATE INDEX" statement parsing yet, + // but the catalog supports it. For now we just test that DROP INDEX works if index exists. + auto table_opt = catalog->get_table_by_name("ddl_test"); + if (table_opt) { + const oid_t tid = (*table_opt)->table_id; + static_cast(catalog->create_index("idx_ddl", tid, {0}, IndexType::BTree, true)); + } + + const auto res_drop_idx = + exec.execute(*Parser(std::make_unique("DROP INDEX idx_ddl")).parse_statement()); + EXPECT_TRUE(res_drop_idx.success()); +} + +TEST(LexerTest_Advanced) { + /* 1. Test comments and line tracking */ + { + const std::string sql = "SELECT -- comment here\n* FROM users"; + Lexer lexer(sql); + const auto t1 = lexer.next_token(); + EXPECT_EQ(static_cast(t1.type()), static_cast(TokenType::Select)); + const auto t2 = lexer.next_token(); // Should skip comment and newline + EXPECT_STREQ(t2.lexeme(), "*"); + EXPECT_EQ(t2.line(), static_cast(2)); + } + /* 2. Test Error and Unknown operators */ + { + Lexer lexer("@"); + const auto t = lexer.next_token(); + EXPECT_EQ(static_cast(t.type()), static_cast(TokenType::Error)); + } +} + +TEST(ExecutionTest_Expressions) { + static_cast(std::remove("./test_data/expr_test.heap")); + StorageManager disk_manager("./test_data"); + BufferPoolManager sm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); + auto catalog = Catalog::create(); + LockManager lm; + TransactionManager tm(lm, *catalog, sm); + QueryExecutor exec(*catalog, sm, lm, tm); + + static_cast(exec.execute( + *Parser(std::make_unique("CREATE TABLE expr_test (id INT, val DOUBLE, str TEXT)")) + .parse_statement())); + static_cast(exec.execute( + *Parser(std::make_unique( + "INSERT INTO expr_test VALUES (1, 10.5, 'A'), (2, NULL, 'B'), (3, 20.0, 'C')")) + .parse_statement())); + + /* 1. Test IS NULL / IS NOT NULL */ + { + const auto res = exec.execute( + *Parser(std::make_unique("SELECT id FROM expr_test WHERE val IS NULL")) + .parse_statement()); + EXPECT_EQ(res.row_count(), static_cast(1)); + EXPECT_EQ( + res.rows()[0].get(0).to_int64(), + static_cast( + 2)); // NOLINT(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers) + + const auto res2 = exec.execute( + *Parser(std::make_unique("SELECT id FROM expr_test WHERE val IS NOT NULL")) + .parse_statement()); + EXPECT_EQ(res2.row_count(), static_cast(2)); + } + + /* 2. Test IN / NOT IN */ + { + const auto res = exec.execute( + *Parser(std::make_unique("SELECT id FROM expr_test WHERE id IN (1, 3)")) + .parse_statement()); + EXPECT_EQ(res.row_count(), static_cast(2)); + + const auto res2 = exec.execute( + *Parser(std::make_unique("SELECT id FROM expr_test WHERE str NOT IN ('A', 'C')")) + .parse_statement()); + EXPECT_EQ(res2.row_count(), static_cast(1)); + EXPECT_EQ( + res2.rows()[0].get(0).to_int64(), + static_cast( + 2)); // NOLINT(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers) + } + + /* 3. Test Arithmetic and Complex Binary */ + { + const auto res = exec.execute( + *Parser(std::make_unique( + "SELECT id, val * 2 + 10, val / 2, val - 5 FROM expr_test WHERE id = 1")) + .parse_statement()); + EXPECT_DOUBLE_EQ( + res.rows()[0].get(1).to_float64(), + 31.0); // NOLINT(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers) + EXPECT_DOUBLE_EQ( + res.rows()[0].get(2).to_float64(), + 5.25); // NOLINT(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers) + EXPECT_DOUBLE_EQ( + res.rows()[0].get(3).to_float64(), + 5.5); // NOLINT(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers) + } +} + +TEST(ExpressionTest_Types) { + /* Test ConstantExpr with various types for coverage */ + { + const ConstantExpr c_bool(Value::make_bool(true)); + EXPECT_TRUE(c_bool.evaluate().as_bool()); + + const ConstantExpr c_int(Value::make_int64(VAL_123)); + EXPECT_EQ(c_int.evaluate().to_int64(), VAL_123); + + const ConstantExpr c_float(Value::make_float64(VAL_1_5)); + EXPECT_DOUBLE_EQ(c_float.evaluate().to_float64(), VAL_1_5); + + const ConstantExpr c_null(Value::make_null()); + EXPECT_TRUE(c_null.evaluate().is_null()); + } +} + +TEST(CatalogTest_Errors) { + auto catalog = Catalog::create(); + const std::vector cols = {{"id", ValueType::TYPE_INT64, 0}}; + + static_cast(catalog->create_table("fail_test", cols)); + /* Duplicate table */ + EXPECT_THROW(catalog->create_table("fail_test", cols), std::exception); + + /* Missing table */ + EXPECT_FALSE(catalog->table_exists(TABLE_9999)); + EXPECT_FALSE(catalog->get_table(TABLE_9999).has_value()); + EXPECT_FALSE(catalog->table_exists_by_name("non_existent")); + + /* Duplicate index */ + const oid_t tid = catalog->create_table("idx_fail", cols); + static_cast(catalog->create_index("my_idx", tid, {0}, IndexType::BTree, true)); + EXPECT_THROW(catalog->create_index("my_idx", tid, {0}, IndexType::BTree, true), std::exception); + + /* Missing index */ + EXPECT_FALSE(catalog->get_index(INDEX_8888).has_value()); + EXPECT_FALSE(catalog->drop_index(INDEX_8888)); +} + +TEST(CatalogTest_Stats) { + auto catalog = Catalog::create(); + const std::vector cols = {{"id", ValueType::TYPE_INT64, 0}}; + const oid_t tid = catalog->create_table("stats_test", cols); + + EXPECT_TRUE(catalog->update_table_stats(tid, STATS_500)); + auto tinfo = catalog->get_table(tid); + if (tinfo) { + EXPECT_EQ((*tinfo)->num_rows, STATS_500); + } + + /* Cover print() */ + catalog->print(); +} + +} // namespace + +int main() { + std::cout << "Unit Tests\n"; + std::cout << "==========\n"; + + RUN_TEST(ValueTest_Basic); + RUN_TEST(ValueTest_TypeVariety); + RUN_TEST(ParserTest_Expressions); + RUN_TEST(ExpressionTest_Complex); + RUN_TEST(ParserTest_SelectVariants); + RUN_TEST(ParserTest_Errors); + RUN_TEST(CatalogTest_FullLifecycle); + RUN_TEST(ConfigTest_Basic); + RUN_TEST(StatementTest_ToString); + RUN_TEST(StatementTest_Serialization); + RUN_TEST(StorageTest_Persistence); + RUN_TEST(StorageTest_Delete); + RUN_TEST(IndexTest_BTreeBasic); + RUN_TEST(IndexTest_Scan); + RUN_TEST(ExecutionTest_EndToEnd); + RUN_TEST(ExecutionTest_Sort); + RUN_TEST(ExecutionTest_Aggregate); + RUN_TEST(ExecutionTest_AggregateAdvanced); + RUN_TEST(ExecutionTest_AggregateDistinct); + RUN_TEST(ExecutionTest_Transaction); + RUN_TEST(ExecutionTest_Rollback); + RUN_TEST(ExecutionTest_UpdateDelete); + RUN_TEST(ExecutionTest_MVCC); + RUN_TEST(ExecutionTest_Join); + RUN_TEST(ExecutionTest_DDL); + RUN_TEST(LexerTest_Advanced); + RUN_TEST(ExecutionTest_Expressions); + RUN_TEST(ExpressionTest_Types); + RUN_TEST(CatalogTest_Errors); + RUN_TEST(CatalogTest_Stats); + + std::cout << "\nResults: \n" << tests_passed << " passed, \n" << tests_failed << " failed\n"; + return (tests_failed > 0); +} diff --git a/tests/lock_manager_tests.cpp b/tests/lock_manager_tests.cpp index d8a16208..bf5db78f 100644 --- a/tests/lock_manager_tests.cpp +++ b/tests/lock_manager_tests.cpp @@ -108,8 +108,6 @@ TEST(LockManagerTests, Deadlock) { std::this_thread::sleep_for(TEST_SLEEP_MS); // txn2 waits for A -> Deadlock! - // Current implementation might not detect deadlock and just timeout or block. - // For now we just verify we can grant if one releases. static_cast(lm.unlock(&txn1, "A")); static_cast(lm.acquire_exclusive(&txn2, "A")); diff --git a/tests/recovery_tests.cpp b/tests/recovery_tests.cpp index f91801df..76d8c7be 100644 --- a/tests/recovery_tests.cpp +++ b/tests/recovery_tests.cpp @@ -1,37 +1,145 @@ /** * @file recovery_tests.cpp - * @brief Unit tests for Log Manager and Recovery + * @brief Unit tests for Write-Ahead Logging and Recovery */ #include +#include #include +#include +#include #include +#include +#include +#include "common/value.hpp" +#include "executor/types.hpp" #include "recovery/log_manager.hpp" #include "recovery/log_record.hpp" +#include "storage/heap_table.hpp" +#include "test_utils.hpp" +using namespace cloudsql; using namespace cloudsql::recovery; +using namespace cloudsql::common; +using namespace cloudsql::executor; +using namespace cloudsql::storage; namespace { +constexpr uint64_t TXN_100 = 100; +constexpr lsn_t PREV_LSN_99 = 99; +constexpr lsn_t CUR_LSN_101 = 101; +constexpr int64_t VAL_42 = 42; +constexpr uint64_t TXN_50 = 50; +constexpr lsn_t PREV_LSN_49 = 49; +constexpr int8_t INT8_10 = 10; +constexpr int16_t INT16_200 = 200; +constexpr int32_t INT32_3000 = 3000; +constexpr float F32_1_22 = 1.22F; +constexpr float F32_1_23 = 1.23F; +constexpr float F32_1_24 = 1.24F; +constexpr double F64_4_55 = 4.55; +constexpr double F64_4_56 = 4.56; +constexpr double F64_4_57 = 4.57; + +// Helper to clean up test files +void cleanup(const std::string& file) { + static_cast(std::remove(file.c_str())); +} + +TEST(RecoveryTests, LogRecordSerialization) { + // 1. Create a dummy INSERT log record + std::vector values; + values.emplace_back(Value::make_int64(VAL_42)); + values.emplace_back(Value::make_text("test_string")); + const Tuple tuple(std::move(values)); + + LogRecord original(TXN_100, PREV_LSN_99, LogRecordType::INSERT, "test_table", + HeapTable::TupleId(1, 2), tuple); + original.lsn_ = CUR_LSN_101; + original.size_ = original.get_size(); + + // 2. Serialize + std::vector buffer(original.size_); + static_cast(original.serialize(buffer.data())); + + // 3. Deserialize + const LogRecord deserialized = LogRecord::deserialize(buffer.data()); + + // 4. Verify + EXPECT_EQ(deserialized.lsn_, original.lsn_); + EXPECT_EQ(deserialized.prev_lsn_, original.prev_lsn_); + EXPECT_EQ(deserialized.txn_id_, original.txn_id_); + EXPECT_EQ(static_cast(deserialized.type_), static_cast(original.type_)); + EXPECT_EQ(deserialized.table_name_, original.table_name_); + EXPECT_EQ(deserialized.rid_, original.rid_); + + EXPECT_EQ(deserialized.tuple_.size(), original.tuple_.size()); + EXPECT_EQ(deserialized.tuple_.get(0).to_int64(), VAL_42); + EXPECT_EQ(deserialized.tuple_.get(1).as_text(), "test_string"); +} + +TEST(RecoveryTests, LogRecordAllTypes) { + std::vector values; + values.emplace_back(Value::make_bool(true)); + values.emplace_back(static_cast(INT8_10)); + values.emplace_back(static_cast(INT16_200)); + values.emplace_back(static_cast(INT32_3000)); + values.emplace_back(static_cast(F32_1_23)); + values.emplace_back(static_cast(F64_4_56)); + values.emplace_back(Value::make_null()); + + const Tuple tuple(std::move(values)); + LogRecord original(TXN_50, PREV_LSN_49, LogRecordType::INSERT, "types_table", + HeapTable::TupleId(1, 1), tuple); + original.size_ = original.get_size(); + + std::vector buffer(original.size_); + static_cast(original.serialize(buffer.data())); + + const LogRecord deserialized = LogRecord::deserialize(buffer.data()); + + ASSERT_EQ(deserialized.tuple_.size(), 7U); + EXPECT_TRUE(deserialized.tuple_.get(0).as_bool()); + EXPECT_EQ(deserialized.tuple_.get(1).as_int8(), INT8_10); + EXPECT_EQ(deserialized.tuple_.get(2).as_int16(), INT16_200); + EXPECT_EQ(deserialized.tuple_.get(3).as_int32(), INT32_3000); + EXPECT_GT(deserialized.tuple_.get(4).as_float32(), F32_1_22); + EXPECT_LT(deserialized.tuple_.get(4).as_float32(), F32_1_24); + EXPECT_GT(deserialized.tuple_.get(5).as_float64(), F64_4_55); + EXPECT_LT(deserialized.tuple_.get(5).as_float64(), F64_4_57); + EXPECT_TRUE(deserialized.tuple_.get(6).is_null()); +} + TEST(RecoveryTests, LogManagerBasic) { - const std::string log_file = "test.log"; - static_cast(std::remove(log_file.c_str())); + const std::string log_file = "test_log_basic.log"; + cleanup(log_file); { - LogManager lm(log_file); - lm.run_flush_thread(); + LogManager log_manager(log_file); + log_manager.run_flush_thread(); - LogRecord record1(1, 0, LogRecordType::INSERT, "table1", {0, 0}, {}); - const lsn_t lsn1 = lm.append_log_record(record1); - EXPECT_GE(lsn1, 0); + // Append a few logs + LogRecord qlog1(1, -1, LogRecordType::BEGIN); + const lsn_t lsn1 = log_manager.append_log_record(qlog1); + EXPECT_EQ(lsn1, 0); - lm.flush(true); - lm.stop_flush_thread(); + LogRecord qlog2(1, lsn1, LogRecordType::COMMIT); + const lsn_t lsn2 = log_manager.append_log_record(qlog2); + EXPECT_EQ(lsn2, 1); + + // Wait for flush + log_manager.flush(true); + EXPECT_GE(log_manager.get_persistent_lsn(), lsn2); } - static_cast(std::remove(log_file.c_str())); + // Verify file content size roughly + std::ifstream in(log_file, std::ios::binary | std::ios::ate); + EXPECT_GT(in.tellg(), 0); + + cleanup(log_file); } } // namespace diff --git a/tests/server_tests.cpp b/tests/server_tests.cpp index e74bc230..6e105949 100644 --- a/tests/server_tests.cpp +++ b/tests/server_tests.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -30,6 +31,7 @@ #include "storage/buffer_pool_manager.hpp" #include "storage/heap_table.hpp" #include "storage/storage_manager.hpp" +#include "test_utils.hpp" using namespace cloudsql; using namespace cloudsql::network; @@ -82,13 +84,13 @@ TEST(ServerTests, SimpleQuery) { std::vector cols; cols.emplace_back("id", common::ValueType::TYPE_INT32, 0); - static_cast(catalog->create_table("dual", std::move(cols))); + static_cast(catalog->create_table("dual_server", std::move(cols))); auto server = Server::create(port, *catalog, sm); - static_cast(std::remove("./test_data/dual.heap")); + static_cast(std::remove("./test_data/dual_server.heap")); storage::HeapTable table( - "dual", sm, + "dual_server", sm, executor::Schema({executor::ColumnMeta("id", common::ValueType::TYPE_INT32, true)})); static_cast(table.create()); static_cast(table.insert(executor::Tuple({common::Value(1)}), 0)); @@ -124,7 +126,7 @@ TEST(ServerTests, SimpleQuery) { static_cast(recv(sock, buffer.data(), AUTH_OK_LEN, 0)); static_cast(recv(sock, buffer.data(), READY_LEN, 0)); - const std::string sql = "SELECT id FROM dual"; + const std::string sql = "SELECT id FROM dual_server"; const char q_type = 'Q'; const uint32_t q_len = htonl(static_cast(sql.size() + 4 + 1)); static_cast(send(sock, &q_type, 1, 0)); @@ -169,6 +171,7 @@ TEST(ServerTests, SimpleQuery) { } static_cast(server->stop()); + static_cast(std::remove("./test_data/dual_server.heap")); } TEST(ServerTests, InvalidProtocol) { diff --git a/tests/statement_tests.cpp b/tests/statement_tests.cpp index 53593654..59044c45 100644 --- a/tests/statement_tests.cpp +++ b/tests/statement_tests.cpp @@ -1,27 +1,146 @@ /** * @file statement_tests.cpp - * @brief Unit tests for SQL Statements + * @brief Unit tests for SQL Statement serialization */ #include +#include +#include +#include #include +#include +#include +#include "common/value.hpp" +#include "parser/expression.hpp" #include "parser/statement.hpp" +#include "parser/token.hpp" +#include "test_utils.hpp" +using namespace cloudsql; using namespace cloudsql::parser; +using namespace cloudsql::common; namespace { -TEST(StatementTests, ToString) { - const TransactionBeginStatement begin; - EXPECT_STREQ(begin.to_string().c_str(), "BEGIN"); +constexpr int64_t VAL_18 = 18; +constexpr int64_t VAL_5 = 5; +constexpr int64_t LIMIT_10 = 10; +constexpr int64_t OFFSET_5 = 5; +constexpr int64_t PRICE_100 = 100; +constexpr int64_t STOCK_50 = 50; - const TransactionCommitStatement commit; - EXPECT_STREQ(commit.to_string().c_str(), "COMMIT"); +TEST(StatementTests, SelectStatementComplex) { + auto stmt = std::make_unique(); + stmt->set_distinct(true); + stmt->add_column(std::make_unique("id")); + stmt->add_column(std::make_unique("name")); - const TransactionRollbackStatement rollback; - EXPECT_STREQ(rollback.to_string().c_str(), "ROLLBACK"); + stmt->add_from(std::make_unique("users")); + + // JOIN orders ON users.id = orders.user_id + auto join_cond = + std::make_unique(std::make_unique("users", "id"), TokenType::Eq, + std::make_unique("orders", "user_id")); + stmt->add_join(SelectStatement::JoinType::Inner, std::make_unique("orders"), + std::move(join_cond)); + + // LEFT JOIN metadata (no condition for test simplicity, though invalid SQL usually) + stmt->add_join(SelectStatement::JoinType::Left, std::make_unique("metadata"), + nullptr); + + // WHERE age > 18 + stmt->set_where( + std::make_unique(std::make_unique("age"), TokenType::Gt, + std::make_unique(Value::make_int64(VAL_18)))); + + // GROUP BY age + stmt->add_group_by(std::make_unique("age")); + + // HAVING COUNT(*) > 5 + auto count_func = std::make_unique("COUNT"); + stmt->set_having( + std::make_unique(std::move(count_func), TokenType::Gt, + std::make_unique(Value::make_int64(VAL_5)))); + + // ORDER BY name DESC (using simplified check since we don't have DESC enum exposed in + // expression easily here) + stmt->add_order_by(std::make_unique("name")); + + stmt->set_limit(LIMIT_10); + stmt->set_offset(OFFSET_5); + + const std::string sql = stmt->to_string(); + + EXPECT_STREQ( + sql.c_str(), + "SELECT DISTINCT id, name FROM users JOIN orders ON users.id = orders.user_id LEFT JOIN " + "metadata WHERE age > 18 GROUP BY age HAVING COUNT(*) > 5 ORDER BY name LIMIT 10 OFFSET 5"); +} + +TEST(StatementTests, InsertStatementMultiRow) { + auto stmt = std::make_unique(); + stmt->set_table(std::make_unique("users")); + stmt->add_column(std::make_unique("id")); + stmt->add_column(std::make_unique("val")); + + std::vector> row1; + row1.emplace_back(std::make_unique(Value::make_int64(1))); + row1.emplace_back(std::make_unique(Value::make_text("A"))); + stmt->add_row(std::move(row1)); + + std::vector> row2; + row2.emplace_back(std::make_unique(Value::make_int64(2))); + row2.emplace_back(std::make_unique(Value::make_text("B"))); + stmt->add_row(std::move(row2)); + + EXPECT_STREQ(stmt->to_string().c_str(), + "INSERT INTO users (id, val) VALUES (1, 'A'), (2, 'B')"); +} + +TEST(StatementTests, UpdateStatementBasic) { + auto stmt = std::make_unique(); + stmt->set_table(std::make_unique("products")); + + stmt->add_set(std::make_unique("price"), + std::make_unique(Value::make_int64(PRICE_100))); + stmt->add_set(std::make_unique("stock"), + std::make_unique(Value::make_int64(STOCK_50))); + + stmt->set_where( + std::make_unique(std::make_unique("id"), TokenType::Eq, + std::make_unique(Value::make_int64(1)))); + + // Map iteration order is unspecified, so check substrings + const std::string sql = stmt->to_string(); + EXPECT_NE(sql.find("price = 100"), std::string::npos); + EXPECT_NE(sql.find("stock = 50"), std::string::npos); +} + +TEST(StatementTests, DeleteStatementBasic) { + auto stmt = std::make_unique(); + stmt->set_table(std::make_unique("users")); + stmt->set_where( + std::make_unique(std::make_unique("id"), TokenType::Lt, + std::make_unique(Value::make_int64(0)))); + + EXPECT_STREQ(stmt->to_string().c_str(), "DELETE FROM users WHERE id < 0"); +} + +TEST(StatementTests, CreateTableStatementComplex) { + auto stmt = std::make_unique(); + stmt->set_table_name("complex_table"); + + stmt->add_column("id", "INT"); + stmt->get_last_column().is_primary_key_ = true; + + stmt->add_column("name", "TEXT"); + stmt->get_last_column().is_not_null_ = true; + stmt->get_last_column().is_unique_ = true; + + EXPECT_STREQ(stmt->to_string().c_str(), + "CREATE TABLE complex_table (id INT PRIMARY KEY, name TEXT NOT NULL UNIQUE)"); } } // namespace diff --git a/tests/transaction_manager_tests.cpp b/tests/transaction_manager_tests.cpp index 0daa9abb..29b3ff50 100644 --- a/tests/transaction_manager_tests.cpp +++ b/tests/transaction_manager_tests.cpp @@ -26,7 +26,7 @@ TEST(TransactionManagerTests, Basic) { storage::BufferPoolManager bpm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); LockManager lm; - TransactionManager tm(lm, *catalog, bpm); + TransactionManager tm(lm, *catalog, bpm, bpm.get_log_manager()); Transaction* const txn1 = tm.begin(); ASSERT_NE(txn1, nullptr); @@ -46,7 +46,7 @@ TEST(TransactionManagerTests, Isolation) { storage::BufferPoolManager bpm(cloudsql::config::Config::DEFAULT_BUFFER_POOL_SIZE, disk_manager); LockManager lm; - TransactionManager tm(lm, *catalog, bpm); + TransactionManager tm(lm, *catalog, bpm, bpm.get_log_manager()); Transaction* const txn1 = tm.begin(); Transaction* const txn2 = tm.begin();