diff --git a/CMakeLists.txt b/CMakeLists.txt index 98001ec5..c720e39d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -131,6 +131,12 @@ if(BUILD_TESTS) add_cloudsql_test(multi_raft_tests tests/multi_raft_tests.cpp) add_cloudsql_test(distributed_txn_tests tests/distributed_txn_tests.cpp) add_cloudsql_test(analytics_tests tests/analytics_tests.cpp) + add_cloudsql_test(raft_manager_tests tests/raft_manager_tests.cpp) + add_cloudsql_test(raft_protocol_tests tests/raft_protocol_tests.cpp) + add_cloudsql_test(columnar_table_tests tests/columnar_table_tests.cpp) + add_cloudsql_test(storage_manager_tests tests/storage_manager_tests.cpp) + add_cloudsql_test(rpc_server_tests tests/rpc_server_tests.cpp) + add_cloudsql_test(operator_tests tests/operator_tests.cpp) add_custom_target(run-tests COMMAND ${CMAKE_CTEST_COMMAND} diff --git a/tests/columnar_table_tests.cpp b/tests/columnar_table_tests.cpp new file mode 100644 index 00000000..cee7d2d0 --- /dev/null +++ b/tests/columnar_table_tests.cpp @@ -0,0 +1,284 @@ +/** + * @file columnar_table_tests.cpp + * @brief Unit tests for ColumnarTable - column-oriented storage + */ + +#include + +#include +#include +#include +#include + +#include "common/value.hpp" +#include "executor/types.hpp" +#include "storage/columnar_table.hpp" +#include "storage/storage_manager.hpp" + +using namespace cloudsql; +using namespace cloudsql::storage; +using namespace cloudsql::executor; + +namespace { + +static void cleanup_table(const std::string& name) { + // clang-format off + std::remove(("./test_data/" + name + ".meta.bin").c_str()); + std::remove(("./test_data/" + name + ".col0.nulls.bin").c_str()); + std::remove(("./test_data/" + name + ".col0.data.bin").c_str()); + std::remove(("./test_data/" + name + ".col1.nulls.bin").c_str()); + std::remove(("./test_data/" + name + ".col1.data.bin").c_str()); + // clang-format on +} + +class ColumnarTableTests : public ::testing::Test { + protected: + void SetUp() override { + sm_ = std::make_unique("./test_data"); + sm_->create_dir_if_not_exists(); + } + + void TearDown() override { sm_.reset(); } + + std::unique_ptr sm_; +}; + +TEST_F(ColumnarTableTests, BasicInt64Lifecycle) { + const std::string name = "col_test_int"; + cleanup_table(name); + + Schema schema; + schema.add_column("id", common::ValueType::TYPE_INT64); + schema.add_column("val", common::ValueType::TYPE_INT64); + + ColumnarTable table(name, *sm_, schema); + + ASSERT_TRUE(table.create()); + ASSERT_TRUE(table.open()); + ASSERT_EQ(table.row_count(), 0U); + + // Build batch with 2 rows + auto batch = VectorBatch::create(schema); + batch->get_column(0).append(common::Value::make_int64(1)); + batch->get_column(0).append(common::Value::make_int64(2)); + batch->get_column(1).append(common::Value::make_int64(100)); + batch->get_column(1).append(common::Value::make_int64(200)); + batch->set_row_count(2); + + ASSERT_TRUE(table.append_batch(*batch)); + ASSERT_EQ(table.row_count(), 2U); + + // Reopen and verify + ColumnarTable table2(name, *sm_, schema); + ASSERT_TRUE(table2.open()); + ASSERT_EQ(table2.row_count(), 2U); + + auto read_batch = VectorBatch::create(schema); + ASSERT_TRUE(table2.read_batch(0, 10, *read_batch)); + ASSERT_EQ(read_batch->row_count(), 2U); +} + +TEST_F(ColumnarTableTests, BasicFloat64Lifecycle) { + const std::string name = "col_test_float"; + cleanup_table(name); + + Schema schema; + schema.add_column("x", common::ValueType::TYPE_FLOAT64); + schema.add_column("y", common::ValueType::TYPE_FLOAT64); + + ColumnarTable table(name, *sm_, schema); + ASSERT_TRUE(table.create()); + + auto batch = VectorBatch::create(schema); + batch->get_column(0).append(common::Value::make_float64(1.5)); + batch->get_column(1).append(common::Value::make_float64(2.7)); + batch->set_row_count(1); + ASSERT_TRUE(table.append_batch(*batch)); + ASSERT_EQ(table.row_count(), 1U); + + auto out = VectorBatch::create(schema); + ASSERT_TRUE(table.read_batch(0, 1, *out)); + ASSERT_EQ(out->row_count(), 1U); + + auto& col_x = dynamic_cast&>(out->get_column(0)); + EXPECT_FLOAT_EQ(col_x.get(0).to_float64(), 1.5); +} + +TEST_F(ColumnarTableTests, NullValueHandling) { + const std::string name = "col_test_null"; + cleanup_table(name); + + Schema schema; + schema.add_column("nullable_col", common::ValueType::TYPE_INT64); + + ColumnarTable table(name, *sm_, schema); + ASSERT_TRUE(table.create()); + + auto batch = VectorBatch::create(schema); + batch->get_column(0).append(common::Value::make_null()); + batch->get_column(0).append(common::Value::make_int64(42)); + batch->set_row_count(2); + ASSERT_TRUE(table.append_batch(*batch)); + + auto out = VectorBatch::create(schema); + ASSERT_TRUE(table.read_batch(0, 2, *out)); + ASSERT_EQ(out->row_count(), 2U); + + auto& col = dynamic_cast&>(out->get_column(0)); + EXPECT_TRUE(col.is_null(0)); + EXPECT_FALSE(col.is_null(1)); + EXPECT_EQ(col.get(1).to_int64(), 42); +} + +TEST_F(ColumnarTableTests, MultiBatchAppendRead) { + const std::string name = "col_test_multi"; + cleanup_table(name); + + Schema schema; + schema.add_column("val", common::ValueType::TYPE_INT64); + + ColumnarTable table(name, *sm_, schema); + ASSERT_TRUE(table.create()); + + // Append 3 batches of 100 rows each + for (int batch_num = 0; batch_num < 3; ++batch_num) { + auto batch = VectorBatch::create(schema); + for (int i = 0; i < 100; ++i) { + batch->get_column(0).append(common::Value::make_int64(batch_num * 100 + i)); + } + batch->set_row_count(100); + ASSERT_TRUE(table.append_batch(*batch)); + } + + ASSERT_EQ(table.row_count(), 300U); + + // Read in pages + auto out = VectorBatch::create(schema); + ASSERT_TRUE(table.read_batch(0, 100, *out)); + ASSERT_EQ(out->row_count(), 100U); + + out = VectorBatch::create(schema); + ASSERT_TRUE(table.read_batch(100, 100, *out)); + ASSERT_EQ(out->row_count(), 100U); + + out = VectorBatch::create(schema); + ASSERT_TRUE(table.read_batch(200, 100, *out)); + ASSERT_EQ(out->row_count(), 100U); +} + +TEST_F(ColumnarTableTests, ReadBatchBeyondEnd) { + const std::string name = "col_test_beyond"; + cleanup_table(name); + + Schema schema; + schema.add_column("id", common::ValueType::TYPE_INT64); + + ColumnarTable table(name, *sm_, schema); + ASSERT_TRUE(table.create()); + + auto batch = VectorBatch::create(schema); + batch->get_column(0).append(common::Value::make_int64(1)); + batch->set_row_count(1); + ASSERT_TRUE(table.append_batch(*batch)); + + auto out = VectorBatch::create(schema); + EXPECT_FALSE(table.read_batch(5, 10, *out)); + EXPECT_FALSE(table.read_batch(1, 10, *out)); +} + +TEST_F(ColumnarTableTests, ReadBatchPartial) { + const std::string name = "col_test_partial"; + cleanup_table(name); + + Schema schema; + schema.add_column("id", common::ValueType::TYPE_INT64); + + ColumnarTable table(name, *sm_, schema); + ASSERT_TRUE(table.create()); + + auto batch = VectorBatch::create(schema); + for (int i = 0; i < 10; ++i) { + batch->get_column(0).append(common::Value::make_int64(i)); + } + batch->set_row_count(10); + ASSERT_TRUE(table.append_batch(*batch)); + + // Read starting mid-table with batch_size larger than remaining + auto out = VectorBatch::create(schema); + ASSERT_TRUE(table.read_batch(8, 100, *out)); + ASSERT_EQ(out->row_count(), 2U); +} + +TEST_F(ColumnarTableTests, UnsupportedTypeThrows) { + const std::string name = "col_test_unsupported"; + cleanup_table(name); + + Schema schema; + schema.add_column("text_col", common::ValueType::TYPE_TEXT); + + // VectorBatch::create() throws when it sees TYPE_TEXT (unsupported) + EXPECT_THROW([[maybe_unused]] auto batch = VectorBatch::create(schema), std::runtime_error); +} + +TEST_F(ColumnarTableTests, CreateTwice) { + const std::string name = "col_test_twice"; + cleanup_table(name); + + Schema schema; + schema.add_column("id", common::ValueType::TYPE_INT64); + + ColumnarTable table(name, *sm_, schema); + ASSERT_TRUE(table.create()); + + // Second create() on existing files - behavior depends on ofstream flags + // This is a basic sanity check + ASSERT_TRUE(table.create()); +} + +TEST_F(ColumnarTableTests, OpenWithoutCreate) { + const std::string name = "col_test_missing"; + + Schema schema; + schema.add_column("id", common::ValueType::TYPE_INT64); + + ColumnarTable table(name, *sm_, schema); + ASSERT_FALSE(table.open()); +} + +TEST_F(ColumnarTableTests, EmptyBatch) { + const std::string name = "col_test_empty"; + cleanup_table(name); + + Schema schema; + schema.add_column("id", common::ValueType::TYPE_INT64); + + ColumnarTable table(name, *sm_, schema); + ASSERT_TRUE(table.create()); + + auto batch = VectorBatch::create(schema); + batch->set_row_count(0); + ASSERT_TRUE(table.append_batch(*batch)); + ASSERT_EQ(table.row_count(), 0U); + + auto out = VectorBatch::create(schema); + ASSERT_FALSE(table.read_batch(0, 10, *out)); +} + +TEST_F(ColumnarTableTests, SchemaAccessor) { + const std::string name = "col_test_schema"; + cleanup_table(name); + + Schema schema; + schema.add_column("col1", common::ValueType::TYPE_INT64); + schema.add_column("col2", common::ValueType::TYPE_FLOAT64); + + ColumnarTable table(name, *sm_, schema); + ASSERT_TRUE(table.create()); + + const auto& retrieved_schema = table.schema(); + ASSERT_EQ(retrieved_schema.column_count(), 2U); + ASSERT_EQ(retrieved_schema.get_column(0).name(), "col1"); + ASSERT_EQ(retrieved_schema.get_column(1).type(), common::ValueType::TYPE_FLOAT64); +} + +} // namespace diff --git a/tests/operator_tests.cpp b/tests/operator_tests.cpp new file mode 100644 index 00000000..b17e45e8 --- /dev/null +++ b/tests/operator_tests.cpp @@ -0,0 +1,590 @@ +/** + * @file operator_tests.cpp + * @brief Unit tests for executor/operator.cpp - Volcano-style execution operators + */ + +#include + +#include +#include + +#include "common/value.hpp" +#include "executor/operator.hpp" +#include "executor/types.hpp" +#include "parser/expression.hpp" + +using namespace cloudsql; +using namespace cloudsql::executor; +using namespace cloudsql::parser; + +namespace { + +// Helper to create a simple schema +Schema make_schema(const std::vector>& cols) { + Schema s; + for (const auto& [name, type] : cols) { + s.add_column(name, type); + } + return s; +} + +// Helper to create a tuple +Tuple make_tuple(const std::vector& vals) { + return Tuple(std::pmr::vector(vals.begin(), vals.end())); +} + +// Helper to create ColumnExpr +std::unique_ptr col_expr(const std::string& name) { + return std::make_unique(name); +} + +// Helper to create ConstantExpr +std::unique_ptr const_expr(const common::Value& val) { + return std::make_unique(val); +} + +// Helper to create BinaryExpr (comparison) +std::unique_ptr binary_expr(std::unique_ptr left, TokenType op, + std::unique_ptr right) { + return std::make_unique(std::move(left), op, std::move(right)); +} + +// Helper to create AggregateInfo +AggregateInfo make_agg(AggregateType type, const std::string& name, + std::unique_ptr expr = nullptr) { + AggregateInfo info; + info.type = type; + info.name = name; + info.expr = std::move(expr); + return info; +} + +// Helper: create a BufferScanOperator with test data +std::unique_ptr make_buffer_scan(const std::string& table_name, + const std::vector& data, + const Schema& schema) { + return std::make_unique("context1", table_name, data, schema); +} + +// Helper: create a FilterOperator with a condition +std::unique_ptr make_filter(std::unique_ptr child, + std::unique_ptr condition) { + return std::make_unique(std::move(child), std::move(condition)); +} + +// Helper: create a ProjectOperator +std::unique_ptr make_project(std::unique_ptr child, + std::vector> cols) { + return std::make_unique(std::move(child), std::move(cols)); +} + +// Helper: create a SortOperator +std::unique_ptr make_sort(std::unique_ptr child, + std::vector> keys, + std::vector asc) { + return std::make_unique(std::move(child), std::move(keys), std::move(asc)); +} + +// Helper: create an AggregateOperator +std::unique_ptr make_agg_op(std::unique_ptr child, + std::vector> group_by, + std::vector aggs) { + return std::make_unique(std::move(child), std::move(group_by), + std::move(aggs)); +} + +// Helper: create a LimitOperator +std::unique_ptr make_limit(std::unique_ptr child, int64_t limit, + int64_t offset = 0) { + return std::make_unique(std::move(child), limit, offset); +} + +// Helper: create a HashJoinOperator +std::unique_ptr make_hash_join(std::unique_ptr left, + std::unique_ptr right, + std::unique_ptr left_key, + std::unique_ptr right_key, + JoinType join_type = JoinType::Inner) { + return std::make_unique(std::move(left), std::move(right), + std::move(left_key), std::move(right_key), join_type); +} + +class OperatorTests : public ::testing::Test { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +TEST_F(OperatorTests, BufferScanBasic) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + data.push_back(make_tuple({common::Value::make_int64(1)})); + data.push_back(make_tuple({common::Value::make_int64(2)})); + data.push_back(make_tuple({common::Value::make_int64(3)})); + + auto scan = make_buffer_scan("test_table", data, schema); + ASSERT_TRUE(scan->init()); + ASSERT_TRUE(scan->open()); + + int count = 0; + Tuple tuple; + while (scan->next(tuple)) { + count++; + } + EXPECT_EQ(count, 3); + scan->close(); +} + +TEST_F(OperatorTests, BufferScanEmpty) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + + auto scan = make_buffer_scan("test_table", data, schema); + ASSERT_TRUE(scan->init()); + ASSERT_TRUE(scan->open()); + + Tuple tuple; + EXPECT_FALSE(scan->next(tuple)); + scan->close(); +} + +TEST_F(OperatorTests, BufferScanExhausted) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + data.push_back(make_tuple({common::Value::make_int64(1)})); + + auto scan = make_buffer_scan("test_table", data, schema); + ASSERT_TRUE(scan->init()); + ASSERT_TRUE(scan->open()); + + Tuple tuple; + EXPECT_TRUE(scan->next(tuple)); + EXPECT_FALSE(scan->next(tuple)); + EXPECT_FALSE(scan->next(tuple)); + scan->close(); +} + +TEST_F(OperatorTests, LimitBasic) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + for (int i = 0; i < 5; i++) { + data.push_back(make_tuple({common::Value::make_int64(i)})); + } + + auto scan = make_buffer_scan("test_table", data, schema); + auto limit = make_limit(std::move(scan), 2); + + ASSERT_TRUE(limit->init()); + ASSERT_TRUE(limit->open()); + + int count = 0; + Tuple tuple; + while (limit->next(tuple)) { + count++; + } + EXPECT_EQ(count, 2); + limit->close(); +} + +TEST_F(OperatorTests, LimitWithOffset) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + for (int i = 0; i < 5; i++) { + data.push_back(make_tuple({common::Value::make_int64(i)})); + } + + auto scan = make_buffer_scan("test_table", data, schema); + auto limit = make_limit(std::move(scan), 2, 2); // offset 2, limit 2 + + ASSERT_TRUE(limit->init()); + ASSERT_TRUE(limit->open()); + + int count = 0; + Tuple tuple; + while (limit->next(tuple)) { + count++; + } + EXPECT_EQ(count, 2); + limit->close(); +} + +TEST_F(OperatorTests, LimitZero) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + for (int i = 0; i < 5; i++) { + data.push_back(make_tuple({common::Value::make_int64(i)})); + } + + auto scan = make_buffer_scan("test_table", data, schema); + auto limit = make_limit(std::move(scan), 0); + + ASSERT_TRUE(limit->init()); + ASSERT_TRUE(limit->open()); + + Tuple tuple; + EXPECT_FALSE(limit->next(tuple)); + limit->close(); +} + +TEST_F(OperatorTests, LimitNegative) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + for (int i = 0; i < 5; i++) { + data.push_back(make_tuple({common::Value::make_int64(i)})); + } + + auto scan = make_buffer_scan("test_table", data, schema); + auto limit = make_limit(std::move(scan), -1); // negative = no limit + + ASSERT_TRUE(limit->init()); + ASSERT_TRUE(limit->open()); + + int count = 0; + Tuple tuple; + while (limit->next(tuple)) { + count++; + } + EXPECT_EQ(count, 5); // all tuples returned + limit->close(); +} + +TEST_F(OperatorTests, FilterBasic) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + for (int i = 0; i < 5; i++) { + data.push_back(make_tuple({common::Value::make_int64(i)})); + } + + auto scan = make_buffer_scan("test_table", data, schema); + // Filter: id >= 2 + auto filter = make_filter( + std::move(scan), + binary_expr(col_expr("id"), TokenType::Ge, const_expr(common::Value::make_int64(2)))); + + ASSERT_TRUE(filter->init()); + ASSERT_TRUE(filter->open()); + + int count = 0; + Tuple tuple; + while (filter->next(tuple)) { + count++; + } + EXPECT_EQ(count, 3); // 2, 3, 4 + filter->close(); +} + +TEST_F(OperatorTests, FilterAllFiltered) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + for (int i = 0; i < 5; i++) { + data.push_back(make_tuple({common::Value::make_int64(i)})); + } + + auto scan = make_buffer_scan("test_table", data, schema); + // Filter: id > 100 (filters all) + auto filter = make_filter( + std::move(scan), + binary_expr(col_expr("id"), TokenType::Gt, const_expr(common::Value::make_int64(100)))); + + ASSERT_TRUE(filter->init()); + ASSERT_TRUE(filter->open()); + + Tuple tuple; + EXPECT_FALSE(filter->next(tuple)); + filter->close(); +} + +TEST_F(OperatorTests, ProjectBasic) { + Schema schema = make_schema( + {{"id", common::ValueType::TYPE_INT64}, {"name", common::ValueType::TYPE_TEXT}}); + std::vector data; + data.push_back(make_tuple({common::Value::make_int64(1), common::Value::make_text("alice")})); + data.push_back(make_tuple({common::Value::make_int64(2), common::Value::make_text("bob")})); + + auto scan = make_buffer_scan("test_table", data, schema); + std::vector> cols; + cols.push_back(col_expr("name")); + auto project = make_project(std::move(scan), std::move(cols)); + + ASSERT_TRUE(project->init()); + ASSERT_TRUE(project->open()); + + int count = 0; + Tuple tuple; + while (project->next(tuple)) { + count++; + EXPECT_EQ(tuple.size(), 1U); + } + EXPECT_EQ(count, 2); + project->close(); +} + +TEST_F(OperatorTests, SortBasic) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + data.push_back(make_tuple({common::Value::make_int64(3)})); + data.push_back(make_tuple({common::Value::make_int64(1)})); + data.push_back(make_tuple({common::Value::make_int64(2)})); + + auto scan = make_buffer_scan("test_table", data, schema); + std::vector> keys; + keys.push_back(col_expr("id")); + auto sort = make_sort(std::move(scan), std::move(keys), {true}); // ascending + + ASSERT_TRUE(sort->init()); + ASSERT_TRUE(sort->open()); + + std::vector values; + Tuple tuple; + while (sort->next(tuple)) { + values.push_back(tuple.get(0).to_int64()); + } + ASSERT_EQ(values.size(), 3U); + EXPECT_EQ(values[0], 1); + EXPECT_EQ(values[1], 2); + EXPECT_EQ(values[2], 3); + sort->close(); +} + +TEST_F(OperatorTests, SortDescending) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + data.push_back(make_tuple({common::Value::make_int64(1)})); + data.push_back(make_tuple({common::Value::make_int64(3)})); + data.push_back(make_tuple({common::Value::make_int64(2)})); + + auto scan = make_buffer_scan("test_table", data, schema); + std::vector> keys; + keys.push_back(col_expr("id")); + auto sort = make_sort(std::move(scan), std::move(keys), {false}); // descending + + ASSERT_TRUE(sort->init()); + ASSERT_TRUE(sort->open()); + + std::vector values; + Tuple tuple; + while (sort->next(tuple)) { + values.push_back(tuple.get(0).to_int64()); + } + ASSERT_EQ(values.size(), 3U); + EXPECT_EQ(values[0], 3); + EXPECT_EQ(values[1], 2); + EXPECT_EQ(values[2], 1); + sort->close(); +} + +TEST_F(OperatorTests, AggregateCountAll) { + Schema schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector data; + for (int i = 0; i < 5; i++) { + data.push_back(make_tuple({common::Value::make_int64(i * 10)})); + } + + auto scan = make_buffer_scan("test_table", data, schema); + std::vector aggs; + aggs.push_back(make_agg(AggregateType::Count, "count")); // COUNT(*) + auto agg = make_agg_op(std::move(scan), {}, std::move(aggs)); + + ASSERT_TRUE(agg->init()); + ASSERT_TRUE(agg->open()); + + Tuple tuple; + EXPECT_TRUE(agg->next(tuple)); + EXPECT_EQ(tuple.get(0).to_int64(), 5); + EXPECT_FALSE(agg->next(tuple)); + agg->close(); +} + +TEST_F(OperatorTests, AggregateSum) { + Schema schema = make_schema({{"val", common::ValueType::TYPE_INT64}}); + std::vector data; + data.push_back(make_tuple({common::Value::make_int64(10)})); + data.push_back(make_tuple({common::Value::make_int64(20)})); + data.push_back(make_tuple({common::Value::make_int64(30)})); + + auto scan = make_buffer_scan("test_table", data, schema); + std::vector aggs; + aggs.push_back(make_agg(AggregateType::Sum, "total", col_expr("val"))); + auto agg = make_agg_op(std::move(scan), {}, std::move(aggs)); + + ASSERT_TRUE(agg->init()); + ASSERT_TRUE(agg->open()); + + Tuple tuple; + EXPECT_TRUE(agg->next(tuple)); + EXPECT_EQ(tuple.get(0).to_int64(), 60); + EXPECT_FALSE(agg->next(tuple)); + agg->close(); +} + +TEST_F(OperatorTests, AggregateMinMax) { + Schema schema = make_schema({{"val", common::ValueType::TYPE_INT64}}); + std::vector data; + data.push_back(make_tuple({common::Value::make_int64(30)})); + data.push_back(make_tuple({common::Value::make_int64(10)})); + data.push_back(make_tuple({common::Value::make_int64(20)})); + + auto scan = make_buffer_scan("test_table", data, schema); + std::vector aggs; + aggs.push_back(make_agg(AggregateType::Min, "min_val", col_expr("val"))); + aggs.push_back(make_agg(AggregateType::Max, "max_val", col_expr("val"))); + auto agg = make_agg_op(std::move(scan), {}, std::move(aggs)); + + ASSERT_TRUE(agg->init()); + ASSERT_TRUE(agg->open()); + + Tuple tuple; + EXPECT_TRUE(agg->next(tuple)); + EXPECT_EQ(tuple.get(0).to_int64(), 10); // min + EXPECT_EQ(tuple.get(1).to_int64(), 30); // max + EXPECT_FALSE(agg->next(tuple)); + agg->close(); +} + +TEST_F(OperatorTests, HashJoinInner) { + // Left table: one column with values 1, 2 + Schema left_schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector left_data; + left_data.push_back(make_tuple({common::Value::make_int64(1)})); + left_data.push_back(make_tuple({common::Value::make_int64(2)})); + left_data.push_back(make_tuple({common::Value::make_int64(3)})); // no match + + // Right table: one column with values 2, 3 + Schema right_schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector right_data; + right_data.push_back(make_tuple({common::Value::make_int64(2)})); + right_data.push_back(make_tuple({common::Value::make_int64(3)})); + right_data.push_back(make_tuple({common::Value::make_int64(4)})); // no match + + auto left_scan = make_buffer_scan("left_table", left_data, left_schema); + auto right_scan = make_buffer_scan("right_table", right_data, right_schema); + + auto join = make_hash_join(std::move(left_scan), std::move(right_scan), col_expr("id"), + col_expr("id"), JoinType::Inner); + + ASSERT_TRUE(join->init()); + ASSERT_TRUE(join->open()); + + std::vector> results; + Tuple tuple; + while (join->next(tuple)) { + results.push_back({tuple.get(0).to_int64(), tuple.get(1).to_int64()}); + } + + EXPECT_EQ(results.size(), 2U); + // 2 matches: (1,?) no, (2,2) yes, (3,3) yes + // Left has 1,2,3 - Right has 2,3,4 + // Inner join: (2,2), (3,3) + EXPECT_EQ(results[0].first, 2); + EXPECT_EQ(results[0].second, 2); + EXPECT_EQ(results[1].first, 3); + EXPECT_EQ(results[1].second, 3); + join->close(); +} + +TEST_F(OperatorTests, HashJoinLeft) { + // Left table: values 1, 2 + Schema left_schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector left_data; + left_data.push_back(make_tuple({common::Value::make_int64(1)})); // no match + left_data.push_back(make_tuple({common::Value::make_int64(2)})); // matches + + // Right table: values 2, 3 + Schema right_schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector right_data; + right_data.push_back(make_tuple({common::Value::make_int64(2)})); + right_data.push_back(make_tuple({common::Value::make_int64(3)})); + + auto left_scan = make_buffer_scan("left_table", left_data, left_schema); + auto right_scan = make_buffer_scan("right_table", right_data, right_schema); + + auto join = make_hash_join(std::move(left_scan), std::move(right_scan), col_expr("id"), + col_expr("id"), JoinType::Left); + + ASSERT_TRUE(join->init()); + ASSERT_TRUE(join->open()); + + std::vector results; + Tuple tuple; + while (join->next(tuple)) { + results.push_back(tuple.get(0).to_int64()); + } + + EXPECT_EQ(results.size(), 2U); + EXPECT_EQ(results[0], 1); // left tuple with no match - NULLs + EXPECT_EQ(results[1], 2); // matched + join->close(); +} + +TEST_F(OperatorTests, HashJoinEmpty) { + // Left has data + Schema left_schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector left_data; + left_data.push_back(make_tuple({common::Value::make_int64(1)})); + + // Right is empty + Schema right_schema = make_schema({{"id", common::ValueType::TYPE_INT64}}); + std::vector right_data; + + auto left_scan = make_buffer_scan("left_table", left_data, left_schema); + auto right_scan = make_buffer_scan("right_table", right_data, right_schema); + + auto join = make_hash_join(std::move(left_scan), std::move(right_scan), col_expr("id"), + col_expr("id"), JoinType::Inner); + + ASSERT_TRUE(join->init()); + ASSERT_TRUE(join->open()); + + Tuple tuple; + EXPECT_FALSE(join->next(tuple)); + join->close(); +} + +TEST_F(OperatorTests, PipelineFilterProject) { + // Input: (1,"alice"), (2,"bob"), (3,"charlie") + Schema schema = make_schema( + {{"id", common::ValueType::TYPE_INT64}, {"name", common::ValueType::TYPE_TEXT}}); + std::vector data; + data.push_back(make_tuple({common::Value::make_int64(1), common::Value::make_text("alice")})); + data.push_back(make_tuple({common::Value::make_int64(2), common::Value::make_text("bob")})); + data.push_back(make_tuple({common::Value::make_int64(3), common::Value::make_text("charlie")})); + + auto scan = make_buffer_scan("test_table", data, schema); + + // Filter: id >= 2 + auto filter = make_filter( + std::move(scan), + binary_expr(col_expr("id"), TokenType::Ge, const_expr(common::Value::make_int64(2)))); + + // Project: name column only + std::vector> cols; + cols.push_back(col_expr("name")); + auto project = make_project(std::move(filter), std::move(cols)); + + ASSERT_TRUE(project->init()); + ASSERT_TRUE(project->open()); + + int count = 0; + Tuple tuple; + while (project->next(tuple)) { + count++; + EXPECT_EQ(tuple.size(), 1U); + EXPECT_STREQ(tuple.get(0).as_text().c_str(), count == 1 ? "bob" : "charlie"); + } + EXPECT_EQ(count, 2); + project->close(); +} + +TEST_F(OperatorTests, OperatorTypeEnum) { + EXPECT_EQ(OperatorType::SeqScan, OperatorType::SeqScan); + EXPECT_EQ(OperatorType::IndexScan, OperatorType::IndexScan); + EXPECT_EQ(OperatorType::Filter, OperatorType::Filter); + EXPECT_EQ(OperatorType::Project, OperatorType::Project); + EXPECT_EQ(OperatorType::HashJoin, OperatorType::HashJoin); + EXPECT_EQ(OperatorType::Sort, OperatorType::Sort); + EXPECT_EQ(OperatorType::Aggregate, OperatorType::Aggregate); + EXPECT_EQ(OperatorType::Limit, OperatorType::Limit); + EXPECT_EQ(OperatorType::BufferScan, OperatorType::BufferScan); +} + +} // namespace diff --git a/tests/raft_manager_tests.cpp b/tests/raft_manager_tests.cpp new file mode 100644 index 00000000..57c3643e --- /dev/null +++ b/tests/raft_manager_tests.cpp @@ -0,0 +1,123 @@ +/** + * @file raft_manager_tests.cpp + * @brief Unit tests for RaftManager - Multi-Raft group management + */ + +#include + +#include +#include +#include +#include + +#include "common/cluster_manager.hpp" +#include "common/config.hpp" +#include "distributed/raft_group.hpp" +#include "distributed/raft_manager.hpp" +#include "network/rpc_server.hpp" + +using namespace cloudsql; +using namespace cloudsql::raft; +using namespace cloudsql::cluster; +using namespace cloudsql::network; + +namespace { + +class RaftManagerTests : public ::testing::Test { + protected: + void SetUp() override { + config_.mode = config::RunMode::Coordinator; + constexpr uint16_t TEST_PORT = 6100; + config_.cluster_port = TEST_PORT; + cm_ = std::make_unique(&config_); + rpc_ = std::make_unique(TEST_PORT); + ASSERT_TRUE(rpc_->start()) << "RpcServer failed to start - port may be in use"; + manager_ = std::make_unique("node1", *cm_, *rpc_); + } + + void TearDown() override { + if (manager_) { + manager_->stop(); + } + if (rpc_) { + rpc_->stop(); + } + } + + config::Config config_; + std::unique_ptr cm_; + std::unique_ptr rpc_; + std::unique_ptr manager_; +}; + +TEST_F(RaftManagerTests, GroupCreation) { + auto group = manager_->get_or_create_group(1); + ASSERT_NE(group, nullptr); + EXPECT_EQ(group->group_id(), 1); +} + +TEST_F(RaftManagerTests, GroupCreationReturnsExisting) { + auto group1 = manager_->get_or_create_group(1); + auto group2 = manager_->get_or_create_group(1); + EXPECT_EQ(group1, group2); +} + +TEST_F(RaftManagerTests, GetGroupExisting) { + auto group = manager_->get_or_create_group(42); + ASSERT_NE(group, nullptr); + EXPECT_EQ(group->group_id(), 42); + + auto retrieved = manager_->get_group(42); + ASSERT_NE(retrieved, nullptr); + EXPECT_EQ(retrieved, group); +} + +TEST_F(RaftManagerTests, GetGroupNonExistent) { + auto result = manager_->get_group(999); + EXPECT_EQ(result, nullptr); +} + +TEST_F(RaftManagerTests, GetGroupAfterGetOrCreate) { + auto created = manager_->get_or_create_group(77); + ASSERT_NE(created, nullptr); + + auto retrieved = manager_->get_group(77); + ASSERT_NE(retrieved, nullptr); + EXPECT_EQ(created, retrieved); +} + +TEST_F(RaftManagerTests, MultipleGroups) { + auto group1 = manager_->get_or_create_group(1); + auto group2 = manager_->get_or_create_group(2); + auto group3 = manager_->get_or_create_group(3); + + ASSERT_NE(group1, nullptr); + ASSERT_NE(group2, nullptr); + ASSERT_NE(group3, nullptr); + + EXPECT_NE(group1, group2); + EXPECT_NE(group2, group3); + EXPECT_NE(group1, group3); + + EXPECT_EQ(group1->group_id(), 1); + EXPECT_EQ(group2->group_id(), 2); + EXPECT_EQ(group3->group_id(), 3); +} + +TEST_F(RaftManagerTests, LifecycleStartStop) { + manager_->start(); + manager_->stop(); + manager_->start(); + manager_->stop(); +} + +TEST_F(RaftManagerTests, GetOrCreateGroupAfterStartStop) { + auto group1 = manager_->get_or_create_group(1); + manager_->start(); + manager_->stop(); + + auto group2 = manager_->get_or_create_group(1); + EXPECT_EQ(group1, group2); +} + +} // namespace diff --git a/tests/raft_protocol_tests.cpp b/tests/raft_protocol_tests.cpp new file mode 100644 index 00000000..3f5fd046 --- /dev/null +++ b/tests/raft_protocol_tests.cpp @@ -0,0 +1,135 @@ +/** + * @file raft_protocol_tests.cpp + * @brief Unit tests for RaftGroup protocol implementation + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include "common/cluster_manager.hpp" +#include "common/config.hpp" +#include "distributed/raft_group.hpp" +#include "distributed/raft_types.hpp" +#include "network/rpc_message.hpp" +#include "network/rpc_server.hpp" + +using namespace cloudsql; +using namespace cloudsql::raft; +using namespace cloudsql::cluster; +using namespace cloudsql::network; + +namespace { + +class MockStateMachine : public RaftStateMachine { + public: + void apply(const LogEntry& entry) override { applied_entries_.push_back(entry); } + + std::vector applied_entries_; +}; + +class RaftProtocolTests : public ::testing::Test { + protected: + void SetUp() override { + config_.mode = config::RunMode::Coordinator; + constexpr uint16_t TEST_PORT = 6200; + config_.cluster_port = TEST_PORT; + cm_ = std::make_unique(&config_); + rpc_ = std::make_unique(TEST_PORT); + rpc_->start(); + state_machine_ = std::make_unique(); + } + + void TearDown() override { + if (group_) { + group_->stop(); + } + if (rpc_) { + rpc_->stop(); + } + cm_.reset(); + } + + config::Config config_; + std::unique_ptr cm_; + std::unique_ptr rpc_; + std::unique_ptr state_machine_; + std::unique_ptr group_; +}; + +TEST_F(RaftProtocolTests, ReplicateFailsWhenNotLeader) { + group_ = std::make_unique(1, "node1", *cm_, *rpc_); + group_->set_state_machine(state_machine_.get()); + + std::vector data = {1, 2, 3}; + EXPECT_FALSE(group_->replicate(data)); + EXPECT_FALSE(group_->is_leader()); +} + +TEST_F(RaftProtocolTests, ReplicateAppendsEntry) { + group_ = std::make_unique(1, "node1", *cm_, *rpc_); + group_->set_state_machine(state_machine_.get()); + group_->start(); + + std::this_thread::sleep_for(std::chrono::milliseconds(400)); + + std::vector data = {1, 2, 3}; + if (!group_->is_leader()) { + GTEST_SKIP() << "Could not become leader in test timeout"; + } + + EXPECT_TRUE(group_->replicate(data)); + EXPECT_EQ(state_machine_->applied_entries_.size(), 0); +} + +TEST_F(RaftProtocolTests, StatePersistence) { + const uint16_t group_id = 15000; + + { + auto local_group = std::make_unique(group_id, "node1", *cm_, *rpc_); + local_group->set_state_machine(state_machine_.get()); + local_group->start(); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + local_group->stop(); + } + + auto loaded_group = std::make_unique(group_id, "node1", *cm_, *rpc_); + loaded_group->set_state_machine(state_machine_.get()); +} + +TEST_F(RaftProtocolTests, LoadStateNonExistent) { + group_ = std::make_unique(9999, "nonexistent_node", *cm_, *rpc_); + group_->set_state_machine(state_machine_.get()); +} + +TEST_F(RaftProtocolTests, MultipleGroupsDistinctState) { + auto group1 = std::make_unique(1, "node1", *cm_, *rpc_); + auto group2 = std::make_unique(2, "node1", *cm_, *rpc_); + + EXPECT_NE(group1->group_id(), group2->group_id()); + EXPECT_EQ(group1->group_id(), 1); + EXPECT_EQ(group2->group_id(), 2); +} + +TEST_F(RaftProtocolTests, GetGroupId) { + group_ = std::make_unique(42, "node1", *cm_, *rpc_); + EXPECT_EQ(group_->group_id(), 42); +} + +TEST_F(RaftProtocolTests, StopWithoutStart) { + group_ = std::make_unique(1, "node1", *cm_, *rpc_); + group_->stop(); +} + +TEST_F(RaftProtocolTests, SetStateMachine) { + group_ = std::make_unique(1, "node1", *cm_, *rpc_); + group_->set_state_machine(state_machine_.get()); + group_->set_state_machine(nullptr); +} + +} // namespace diff --git a/tests/rpc_server_tests.cpp b/tests/rpc_server_tests.cpp new file mode 100644 index 00000000..6166dc35 --- /dev/null +++ b/tests/rpc_server_tests.cpp @@ -0,0 +1,286 @@ +/** + * @file rpc_server_tests.cpp + * @brief Unit tests for RpcServer - Internal RPC server for node-to-node communication + */ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "network/rpc_client.hpp" +#include "network/rpc_message.hpp" +#include "network/rpc_server.hpp" + +using namespace cloudsql::network; + +namespace { + +// Ignore SIGPIPE to prevent crashes when writing to closed sockets +struct SigpipeGuard { + SigpipeGuard() { std::signal(SIGPIPE, SIG_IGN); } +}; +SigpipeGuard g_sigpipe; + +class RpcServerTests : public ::testing::Test { + protected: + void SetUp() override { + // Use a unique port for each test to avoid TIME_WAIT issues + port_ = TEST_PORT_BASE_ + next_port_++; + server_ = std::make_unique(port_); + handler_called_ = false; + } + + void TearDown() override { + if (server_) { + server_->stop(); + } + // Small delay to allow socket to settle + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + static constexpr uint16_t TEST_PORT_BASE_ = 6300; + static std::atomic next_port_; + uint16_t port_; + std::unique_ptr server_; + std::atomic handler_called_{false}; +}; + +std::atomic RpcServerTests::next_port_{0}; + +TEST_F(RpcServerTests, LifecycleStartStop) { + ASSERT_TRUE(server_->start()); + server_->stop(); + ASSERT_TRUE(server_->start()); + server_->stop(); +} + +TEST_F(RpcServerTests, DoubleStartReturnsFalse) { + ASSERT_TRUE(server_->start()); + ASSERT_FALSE(server_->start()); + // Force cleanup before next test + server_->stop(); + // Give socket time to release + std::this_thread::sleep_for(std::chrono::milliseconds(50)); +} + +TEST_F(RpcServerTests, SetAndGetHandler) { + auto handler = [](const RpcHeader&, const std::vector&, int) {}; + server_->set_handler(RpcType::Heartbeat, handler); + auto retrieved = server_->get_handler(RpcType::Heartbeat); + ASSERT_NE(retrieved, nullptr); +} + +TEST_F(RpcServerTests, GetHandlerNotSet) { + auto retrieved = server_->get_handler(RpcType::RegisterNode); + EXPECT_EQ(retrieved, nullptr); +} + +TEST_F(RpcServerTests, HandlerOverride) { + int call_count = 0; + auto handler1 = [&](const RpcHeader&, const std::vector&, int) { call_count++; }; + auto handler2 = [&](const RpcHeader&, const std::vector&, int) { call_count += 10; }; + + server_->set_handler(RpcType::Heartbeat, handler1); + server_->set_handler(RpcType::Heartbeat, handler2); + auto retrieved = server_->get_handler(RpcType::Heartbeat); + ASSERT_NE(retrieved, nullptr); +} + +TEST_F(RpcServerTests, ClearHandlersAfterStop) { + auto handler = [](const RpcHeader&, const std::vector&, int) {}; + server_->set_handler(RpcType::Heartbeat, handler); + server_->start(); + server_->stop(); + auto retrieved = server_->get_handler(RpcType::Heartbeat); + EXPECT_EQ(retrieved, nullptr); +} + +TEST_F(RpcServerTests, ZeroPayloadHandler) { + server_->start(); + + bool called = false; + server_->set_handler(RpcType::Heartbeat, + [&called](const RpcHeader& h, const std::vector& p, int fd) { + called = true; + EXPECT_EQ(p.size(), 0U); + }); + + // Connect and send RPC with zero payload + int fd = socket(AF_INET, SOCK_STREAM, 0); + ASSERT_GE(fd, 0); + + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(port_); + inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr); + + ASSERT_EQ(connect(fd, (sockaddr*)&addr, sizeof(addr)), 0); + + // Send header + RpcHeader hdr; + hdr.type = RpcType::Heartbeat; + hdr.payload_len = 0; + char h_buf[RpcHeader::HEADER_SIZE]; + hdr.encode(h_buf); + send(fd, h_buf, RpcHeader::HEADER_SIZE, 0); + + // Give time for the server to process and call the handler + for (int i = 0; i < 10 && !called; ++i) { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + EXPECT_TRUE(called); + + close(fd); +} + +TEST_F(RpcServerTests, MultipleConnections) { + server_->start(); + + int call_count = 0; + server_->set_handler( + RpcType::Heartbeat, + [&call_count](const RpcHeader&, const std::vector&, int) { call_count++; }); + + std::vector fds; + for (int i = 0; i < 5; ++i) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(port_); + inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr); + if (connect(fd, (sockaddr*)&addr, sizeof(addr)) == 0) { + fds.push_back(fd); + } + } + + // Send RPCs + for (int fd : fds) { + RpcHeader hdr; + hdr.type = RpcType::Heartbeat; + hdr.payload_len = 0; + char h_buf[RpcHeader::HEADER_SIZE]; + hdr.encode(h_buf); + send(fd, h_buf, RpcHeader::HEADER_SIZE, 0); + } + + // Give time for the server to process all 5 + for (int i = 0; i < 20 && call_count < 5; ++i) { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + + for (int fd : fds) { + close(fd); + } + + EXPECT_EQ(call_count, 5); +} + +TEST_F(RpcServerTests, ClientDisconnectMidHeader) { + server_->start(); + + server_->set_handler(RpcType::Heartbeat, + [](const RpcHeader&, const std::vector&, int) {}); + + int fd = socket(AF_INET, SOCK_STREAM, 0); + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(port_); + inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr); + connect(fd, (sockaddr*)&addr, sizeof(addr)); + + // Send partial header then disconnect + char partial[6]; + std::memset(partial, 0, 6); + send(fd, partial, 6, 0); + + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + close(fd); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); +} + +TEST_F(RpcServerTests, ClientDisconnectMidPayload) { + server_->start(); + + server_->set_handler(RpcType::Heartbeat, + [](const RpcHeader&, const std::vector&, int) {}); + + int fd = socket(AF_INET, SOCK_STREAM, 0); + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(port_); + inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr); + connect(fd, (sockaddr*)&addr, sizeof(addr)); + + // Send full header indicating payload but don't send payload + RpcHeader hdr; + hdr.type = RpcType::Heartbeat; + hdr.payload_len = 100; // Request 100 bytes but we won't send them + char h_buf[RpcHeader::HEADER_SIZE]; + hdr.encode(h_buf); + send(fd, h_buf, RpcHeader::HEADER_SIZE, 0); + + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + close(fd); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); +} + +TEST_F(RpcServerTests, FullRoundTripWithClient) { + server_->start(); + + server_->set_handler(RpcType::QueryResults, + [](const RpcHeader& h, const std::vector& p, int fd) { + // Echo back the payload + RpcHeader resp_h; + resp_h.type = RpcType::QueryResults; + resp_h.payload_len = static_cast(p.size()); + char h_buf[RpcHeader::HEADER_SIZE]; + resp_h.encode(h_buf); + send(fd, h_buf, RpcHeader::HEADER_SIZE, 0); + if (!p.empty()) { + send(fd, p.data(), p.size(), 0); + } + }); + + RpcClient client("127.0.0.1", port_); + ASSERT_TRUE(client.connect()); + + std::vector payload = {1, 2, 3, 4, 5}; + std::vector response; + ASSERT_TRUE(client.call(RpcType::QueryResults, payload, response, 0)); + + EXPECT_EQ(response.size(), 5U); + EXPECT_EQ(response[0], 1); + EXPECT_EQ(response[4], 5); +} + +TEST_F(RpcServerTests, NoHandlerRegistered) { + server_->start(); + // Don't set any handler + + int fd = socket(AF_INET, SOCK_STREAM, 0); + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(port_); + inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr); + connect(fd, (sockaddr*)&addr, sizeof(addr)); + + // Send an RPC with no handler registered + RpcHeader hdr; + hdr.type = RpcType::Error; + hdr.payload_len = 0; + char h_buf[RpcHeader::HEADER_SIZE]; + hdr.encode(h_buf); + send(fd, h_buf, RpcHeader::HEADER_SIZE, 0); + + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + close(fd); + server_->stop(); +} + +} // namespace diff --git a/tests/storage_manager_tests.cpp b/tests/storage_manager_tests.cpp new file mode 100644 index 00000000..1900308f --- /dev/null +++ b/tests/storage_manager_tests.cpp @@ -0,0 +1,236 @@ +/** + * @file storage_manager_tests.cpp + * @brief Unit tests for StorageManager - low-level disk I/O and page-level access + */ + +#include + +#include +#include +#include +#include +#include + +#include "storage/storage_manager.hpp" + +using namespace cloudsql::storage; + +namespace { + +static void cleanup_file(const std::string& dir, const std::string& name) { + std::remove((dir + "/" + name).c_str()); +} + +class StorageManagerTests : public ::testing::Test { + protected: + void SetUp() override { + sm_ = std::make_unique("./test_data"); + sm_->create_dir_if_not_exists(); + } + + void TearDown() override { sm_.reset(); } + + std::unique_ptr sm_; +}; + +TEST_F(StorageManagerTests, OpenCloseBasic) { + const std::string filename = "open_close_test.db"; + cleanup_file("./test_data", filename); + + ASSERT_TRUE(sm_->open_file(filename)); + ASSERT_TRUE(sm_->close_file(filename)); +} + +TEST_F(StorageManagerTests, OpenNonExistentCreatesFile) { + const std::string filename = "new_file_test.db"; + cleanup_file("./test_data", filename); + + ASSERT_FALSE(sm_->file_exists(filename)); + ASSERT_TRUE(sm_->open_file(filename)); + ASSERT_TRUE(sm_->file_exists(filename)); +} + +TEST_F(StorageManagerTests, OpenTwiceReturnsTrue) { + const std::string filename = "double_open_test.db"; + cleanup_file("./test_data", filename); + + ASSERT_TRUE(sm_->open_file(filename)); + ASSERT_TRUE(sm_->open_file(filename)); + ASSERT_TRUE(sm_->close_file(filename)); +} + +TEST_F(StorageManagerTests, CloseNonExistentReturnsFalse) { + ASSERT_FALSE(sm_->close_file("nonexistent_file.db")); +} + +TEST_F(StorageManagerTests, ReadWritePageBasic) { + const std::string filename = "page_rw_test.db"; + cleanup_file("./test_data", filename); + + ASSERT_TRUE(sm_->open_file(filename)); + + char write_buf[StorageManager::PAGE_SIZE]; + char read_buf[StorageManager::PAGE_SIZE]; + std::memset(write_buf, 0, StorageManager::PAGE_SIZE); + std::memset(read_buf, 0, StorageManager::PAGE_SIZE); + + // Write pattern to page 0 + for (int i = 0; i < 16; ++i) { + write_buf[i * 16] = static_cast(i); + } + ASSERT_TRUE(sm_->write_page(filename, 0, write_buf)); + + // Read back and verify + ASSERT_TRUE(sm_->read_page(filename, 0, read_buf)); + ASSERT_EQ(std::memcmp(write_buf, read_buf, StorageManager::PAGE_SIZE), 0); +} + +TEST_F(StorageManagerTests, ReadBeyondEOFFillsZeros) { + const std::string filename = "beyond_eof_test.db"; + cleanup_file("./test_data", filename); + + ASSERT_TRUE(sm_->open_file(filename)); + + char read_buf[StorageManager::PAGE_SIZE]; + std::memset(read_buf, 0xFF, StorageManager::PAGE_SIZE); // Fill with sentinel + + // Read page 10 from empty file - should zero-fill + ASSERT_TRUE(sm_->read_page(filename, 10, read_buf)); + + // Verify all zeros + for (size_t i = 0; i < StorageManager::PAGE_SIZE; ++i) { + EXPECT_EQ(read_buf[i], 0) << "Byte at index " << i << " was not zero"; + } +} + +TEST_F(StorageManagerTests, PartialReadReturnsFalse) { + const std::string filename = "partial_read_test.db"; + cleanup_file("./test_data", filename); + + ASSERT_TRUE(sm_->open_file(filename)); + + // Write a small amount of data + char write_buf[StorageManager::PAGE_SIZE]; + std::memset(write_buf, 0xAB, StorageManager::PAGE_SIZE); + ASSERT_TRUE(sm_->write_page(filename, 0, write_buf)); + + // Try to read the small write as a full page should succeed (EOF handling fills zeros) + char read_buf[StorageManager::PAGE_SIZE]; + ASSERT_TRUE(sm_->read_page(filename, 0, read_buf)); +} + +TEST_F(StorageManagerTests, AllocatePageOnEmptyFile) { + const std::string filename = "allocate_test.db"; + cleanup_file("./test_data", filename); + + ASSERT_TRUE(sm_->open_file(filename)); + ASSERT_EQ(sm_->allocate_page(filename), 0U); +} + +TEST_F(StorageManagerTests, AllocatePageSequential) { + const std::string filename = "allocate_seq_test.db"; + cleanup_file("./test_data", filename); + + ASSERT_TRUE(sm_->open_file(filename)); + + // allocate_page returns next page index based on file size + // But it does NOT write to file - you need to write_page + ASSERT_EQ(sm_->allocate_page(filename), 0U); + + // Write a page, then allocate should give next index + char buf[StorageManager::PAGE_SIZE]; + std::memset(buf, 0, StorageManager::PAGE_SIZE); + ASSERT_TRUE(sm_->write_page(filename, 0, buf)); + ASSERT_EQ(sm_->allocate_page(filename), 1U); +} + +TEST_F(StorageManagerTests, CreateDirIfNotExistsBasic) { + // Directory should already exist from SetUp + ASSERT_TRUE(sm_->create_dir_if_not_exists()); +} + +TEST_F(StorageManagerTests, CreateDirAlreadyExists) { + // create_dir_if_not_exists should return true even if dir exists + ASSERT_TRUE(sm_->create_dir_if_not_exists()); + ASSERT_TRUE(sm_->create_dir_if_not_exists()); +} + +TEST_F(StorageManagerTests, FileExistsAfterOpen) { + const std::string filename = "exists_test.db"; + cleanup_file("./test_data", filename); + + ASSERT_FALSE(sm_->file_exists(filename)); + ASSERT_TRUE(sm_->open_file(filename)); + ASSERT_TRUE(sm_->file_exists(filename)); +} + +TEST_F(StorageManagerTests, GetFullPath) { + const std::string path = sm_->get_full_path("test.db"); + EXPECT_EQ(path, "./test_data/test.db"); +} + +TEST_F(StorageManagerTests, MultipleFilesOpen) { + const std::string file1 = "multi1.db"; + const std::string file2 = "multi2.db"; + const std::string file3 = "multi3.db"; + cleanup_file("./test_data", file1); + cleanup_file("./test_data", file2); + cleanup_file("./test_data", file3); + + ASSERT_TRUE(sm_->open_file(file1)); + ASSERT_TRUE(sm_->open_file(file2)); + ASSERT_TRUE(sm_->open_file(file3)); + + ASSERT_TRUE(sm_->file_exists(file1)); + ASSERT_TRUE(sm_->file_exists(file2)); + ASSERT_TRUE(sm_->file_exists(file3)); +} + +TEST_F(StorageManagerTests, StatsAccurateAfterOperations) { + const std::string filename = "stats_test.db"; + cleanup_file("./test_data", filename); + + const auto& stats = sm_->get_stats(); + auto initial_pages_read = stats.pages_read.load(); + auto initial_pages_written = stats.pages_written.load(); + + ASSERT_TRUE(sm_->open_file(filename)); + + char buf[StorageManager::PAGE_SIZE]; + std::memset(buf, 0, StorageManager::PAGE_SIZE); + ASSERT_TRUE(sm_->write_page(filename, 0, buf)); + ASSERT_TRUE(sm_->read_page(filename, 0, buf)); + + EXPECT_GT(stats.pages_written.load(), initial_pages_written); +} + +TEST_F(StorageManagerTests, WriteAndReadDifferentPages) { + const std::string filename = "diff_pages_test.db"; + cleanup_file("./test_data", filename); + + ASSERT_TRUE(sm_->open_file(filename)); + + char page0[StorageManager::PAGE_SIZE]; + char page1[StorageManager::PAGE_SIZE]; + char page5[StorageManager::PAGE_SIZE]; + char read_buf[StorageManager::PAGE_SIZE]; + + std::memset(page0, 0xAA, StorageManager::PAGE_SIZE); + std::memset(page1, 0xBB, StorageManager::PAGE_SIZE); + std::memset(page5, 0xCC, StorageManager::PAGE_SIZE); + + ASSERT_TRUE(sm_->write_page(filename, 0, page0)); + ASSERT_TRUE(sm_->write_page(filename, 1, page1)); + ASSERT_TRUE(sm_->write_page(filename, 5, page5)); + + ASSERT_TRUE(sm_->read_page(filename, 0, read_buf)); + EXPECT_EQ(std::memcmp(page0, read_buf, StorageManager::PAGE_SIZE), 0); + + ASSERT_TRUE(sm_->read_page(filename, 1, read_buf)); + EXPECT_EQ(std::memcmp(page1, read_buf, StorageManager::PAGE_SIZE), 0); + + ASSERT_TRUE(sm_->read_page(filename, 5, read_buf)); + EXPECT_EQ(std::memcmp(page5, read_buf, StorageManager::PAGE_SIZE), 0); +} + +} // namespace