diff --git a/CMakeLists.txt b/CMakeLists.txt index 32cd1bd1..00d5586d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -120,10 +120,12 @@ set(SPARROW_IPC_HEADERS ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/arrow_interface/arrow_schema/private_data.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_output_stream.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/chunk_memory_serializer.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/compression.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/config.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/config/sparrow_ipc_version.hpp - ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/compression.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_array_impl.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_fixedsizebinary_array.hpp + ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_interval_array.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_primitive_array.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_utils.hpp ${SPARROW_IPC_INCLUDE_DIR}/sparrow_ipc/deserialize_variable_size_binary_array.hpp diff --git a/include/sparrow_ipc/deserialize_array_impl.hpp b/include/sparrow_ipc/deserialize_array_impl.hpp new file mode 100644 index 00000000..ee781223 --- /dev/null +++ b/include/sparrow_ipc/deserialize_array_impl.hpp @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include + +#include + +#include "Message_generated.h" +#include "sparrow_ipc/arrow_interface/arrow_array.hpp" +#include "sparrow_ipc/arrow_interface/arrow_schema.hpp" +#include "sparrow_ipc/deserialize_utils.hpp" + +namespace sparrow_ipc::detail +{ + /** + * @brief Generic implementation for deserializing non-owning arrays with simple layout. + * + * This function provides the common deserialization logic for array types that have + * a validity buffer and a single data buffer (e.g., primitive_array, interval_array). + * + * @tparam ArrayType The array type template (e.g., sparrow::primitive_array) + * @tparam T The element type + * + * @param record_batch The FlatBuffer RecordBatch containing metadata + * @param body The raw buffer data + * @param name The array column name + * @param metadata Optional metadata pairs + * @param nullable Whether the array is nullable + * @param buffer_index The current buffer index (incremented by this function) + * + * @return The deserialized array of type ArrayType + */ + template class ArrayType, typename T> + [[nodiscard]] ArrayType deserialize_non_owning_simple_array( + const org::apache::arrow::flatbuf::RecordBatch& record_batch, + std::span body, + std::string_view name, + const std::optional>& metadata, + bool nullable, + size_t& buffer_index + ) + { + const std::string_view format = data_type_to_format( + sparrow::detail::get_data_type_from_array>::get() + ); + + // Set up flags based on nullable + std::optional> flags; + if (nullable) + { + flags = std::unordered_set{sparrow::ArrowFlag::NULLABLE}; + } + + ArrowSchema schema = make_non_owning_arrow_schema( + format, + name.data(), + metadata, + flags, + 0, + nullptr, + nullptr + ); + + const auto compression = record_batch.compression(); + std::vector buffers; + + auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index); + auto data_buffer_span = utils::get_buffer(record_batch, body, buffer_index); + + if (compression) + { + buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression)); + buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression)); + } + else + { + buffers.emplace_back(validity_buffer_span); + buffers.emplace_back(data_buffer_span); + } + + // TODO bitmap_ptr is not used anymore... Leave it for now, and remove later if no need confirmed + const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length()); + + ArrowArray array = make_arrow_array( + record_batch.length(), + null_count, + 0, + 0, + nullptr, + nullptr, + std::move(buffers) + ); + + sparrow::arrow_proxy ap{std::move(array), std::move(schema)}; + return ArrayType{std::move(ap)}; + } +} diff --git a/include/sparrow_ipc/deserialize_interval_array.hpp b/include/sparrow_ipc/deserialize_interval_array.hpp new file mode 100644 index 00000000..6cf6b23c --- /dev/null +++ b/include/sparrow_ipc/deserialize_interval_array.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#include "Message_generated.h" +#include "sparrow_ipc/deserialize_array_impl.hpp" + +namespace sparrow_ipc +{ + template + [[nodiscard]] sparrow::interval_array deserialize_non_owning_interval_array( + const org::apache::arrow::flatbuf::RecordBatch& record_batch, + std::span body, + std::string_view name, + const std::optional>& metadata, + bool nullable, + size_t& buffer_index + ) + { + return detail::deserialize_non_owning_simple_array( + record_batch, + body, + name, + metadata, + nullable, + buffer_index + ); + } +} diff --git a/include/sparrow_ipc/deserialize_primitive_array.hpp b/include/sparrow_ipc/deserialize_primitive_array.hpp index 01949ac8..6904f036 100644 --- a/include/sparrow_ipc/deserialize_primitive_array.hpp +++ b/include/sparrow_ipc/deserialize_primitive_array.hpp @@ -1,16 +1,13 @@ #pragma once #include -#include #include #include #include #include "Message_generated.h" -#include "sparrow_ipc/arrow_interface/arrow_array.hpp" -#include "sparrow_ipc/arrow_interface/arrow_schema.hpp" -#include "sparrow_ipc/deserialize_utils.hpp" +#include "sparrow_ipc/deserialize_array_impl.hpp" namespace sparrow_ipc { @@ -24,58 +21,13 @@ namespace sparrow_ipc size_t& buffer_index ) { - const std::string_view format = data_type_to_format( - sparrow::detail::get_data_type_from_array>::get() - ); - - // Set up flags based on nullable - std::optional> flags; - if (nullable) - { - flags = std::unordered_set{sparrow::ArrowFlag::NULLABLE}; - } - - ArrowSchema schema = make_non_owning_arrow_schema( - format, - name.data(), + return detail::deserialize_non_owning_simple_array( + record_batch, + body, + name, metadata, - flags, - 0, - nullptr, - nullptr - ); - - const auto compression = record_batch.compression(); - std::vector buffers; - - auto validity_buffer_span = utils::get_buffer(record_batch, body, buffer_index); - auto data_buffer_span = utils::get_buffer(record_batch, body, buffer_index); - - if (compression) - { - buffers.push_back(utils::get_decompressed_buffer(validity_buffer_span, compression)); - buffers.push_back(utils::get_decompressed_buffer(data_buffer_span, compression)); - } - else - { - buffers.push_back(validity_buffer_span); - buffers.push_back(data_buffer_span); - } - - // TODO bitmap_ptr is not used anymore... Leave it for now, and remove later if no need confirmed - const auto [bitmap_ptr, null_count] = utils::get_bitmap_pointer_and_null_count(validity_buffer_span, record_batch.length()); - - ArrowArray array = make_arrow_array( - record_batch.length(), - null_count, - 0, - 0, - nullptr, - nullptr, - std::move(buffers) + nullable, + buffer_index ); - - sparrow::arrow_proxy ap{std::move(array), std::move(schema)}; - return sparrow::primitive_array{std::move(ap)}; } } diff --git a/src/deserialize.cpp b/src/deserialize.cpp index 92063de1..603cece3 100644 --- a/src/deserialize.cpp +++ b/src/deserialize.cpp @@ -3,6 +3,7 @@ #include #include "sparrow_ipc/deserialize_fixedsizebinary_array.hpp" +#include "sparrow_ipc/deserialize_interval_array.hpp" #include "sparrow_ipc/deserialize_primitive_array.hpp" #include "sparrow_ipc/deserialize_variable_size_binary_array.hpp" #include "sparrow_ipc/encapsulated_message.hpp" @@ -11,11 +12,23 @@ namespace sparrow_ipc { + namespace + { + // Integer bit width constants + constexpr int32_t BIT_WIDTH_8 = 8; + constexpr int32_t BIT_WIDTH_16 = 16; + constexpr int32_t BIT_WIDTH_32 = 32; + constexpr int32_t BIT_WIDTH_64 = 64; + + // End-of-stream marker size in bytes + constexpr size_t END_OF_STREAM_MARKER_SIZE = 8; + } const org::apache::arrow::flatbuf::RecordBatch* deserialize_record_batch_message(std::span data, size_t& current_offset) { current_offset += sizeof(uint32_t); - const auto batch_message = org::apache::arrow::flatbuf::GetMessage(data.data() + current_offset); + const auto message_data = data.subspan(current_offset); + const auto* batch_message = org::apache::arrow::flatbuf::GetMessage(message_data.data()); if (batch_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::RecordBatch) { throw std::runtime_error("Expected RecordBatch message, but got a different type."); @@ -28,21 +41,21 @@ namespace sparrow_ipc * * This function processes each field in the schema and deserializes the corresponding * data from the RecordBatch into sparrow::array objects. It handles various Arrow data - * types including primitive types (bool, integers, floating point), binary data, and - * string data with their respective size variants. + * types including primitive types (bool, integers, floating point), binary data, string + * data, fixed-size binary data, and interval types. * * @param record_batch The Apache Arrow FlatBuffer RecordBatch containing the serialized data * @param schema The Apache Arrow FlatBuffer Schema defining the structure and types of the data * @param encapsulated_message The message containing the binary data buffers - * @param field_metadata Metadata for each field + * @param field_metadata Metadata associated with each field in the schema * * @return std::vector A vector of deserialized arrays, one for each field in the schema * - * @throws std::runtime_error If an unsupported data type, integer bit width, or floating point precision - * is encountered + * @throws std::runtime_error If an unsupported data type, integer bit width, floating point precision, + * or interval unit is encountered * - * The function maintains a buffer index that is incremented as it processes each field - * to correctly map data buffers to their corresponding arrays. + * @note The function maintains a buffer index that is incremented as it processes each field + * to correctly map data buffers to their corresponding arrays. */ std::vector get_arrays_from_record_batch( const org::apache::arrow::flatbuf::RecordBatch& record_batch, @@ -90,7 +103,7 @@ namespace sparrow_ipc break; case org::apache::arrow::flatbuf::Type::Int: { - const auto int_type = field->type_as_Int(); + const auto* int_type = field->type_as_Int(); const auto bit_width = int_type->bitWidth(); const bool is_signed = int_type->is_signed(); @@ -99,11 +112,11 @@ namespace sparrow_ipc switch (bit_width) { // clang-format off - case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - default: throw std::runtime_error("Unsupported integer bit width."); + case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width)); // clang-format on } } @@ -112,11 +125,11 @@ namespace sparrow_ipc switch (bit_width) { // clang-format off - case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; - default: throw std::runtime_error("Unsupported integer bit width."); + case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; + default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width)); // clang-format on } } @@ -124,7 +137,7 @@ namespace sparrow_ipc break; case org::apache::arrow::flatbuf::Type::FloatingPoint: { - const auto float_type = field->type_as_FloatingPoint(); + const auto* float_type = field->type_as_FloatingPoint(); switch (float_type->precision()) { // clang-format off @@ -138,14 +151,17 @@ namespace sparrow_ipc arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()()); break; default: - throw std::runtime_error("Unsupported floating point precision."); + throw std::runtime_error( + "Unsupported floating point precision: " + + std::to_string(static_cast(float_type->precision())) + ); // clang-format on } break; } case org::apache::arrow::flatbuf::Type::FixedSizeBinary: { - const auto fixed_size_binary_field = field->type_as_FixedSizeBinary(); + const auto* fixed_size_binary_field = field->type_as_FixedSizeBinary(); arrays.emplace_back(deserialize_non_owning_fixedwidthbinary( record_batch, encapsulated_message.body(), @@ -205,8 +221,61 @@ namespace sparrow_ipc ) ); break; + case org::apache::arrow::flatbuf::Type::Interval: + { + const auto* interval_type = field->type_as_Interval(); + const org::apache::arrow::flatbuf::IntervalUnit interval_unit = interval_type->unit(); + switch (interval_unit) + { + case org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH: + arrays.emplace_back( + deserialize_non_owning_interval_array( + record_batch, + encapsulated_message.body(), + name, + metadata, + nullable, + buffer_index + ) + ); + break; + case org::apache::arrow::flatbuf::IntervalUnit::DAY_TIME: + arrays.emplace_back( + deserialize_non_owning_interval_array( + record_batch, + encapsulated_message.body(), + name, + metadata, + nullable, + buffer_index + ) + ); + break; + case org::apache::arrow::flatbuf::IntervalUnit::MONTH_DAY_NANO: + arrays.emplace_back( + deserialize_non_owning_interval_array( + record_batch, + encapsulated_message.body(), + name, + metadata, + nullable, + buffer_index + ) + ); + break; + default: + throw std::runtime_error( + "Unsupported interval unit: " + + std::to_string(static_cast(interval_unit)) + ); + } + } + break; default: - throw std::runtime_error("Unsupported type."); + throw std::runtime_error( + "Unsupported field type: " + std::to_string(static_cast(field_type)) + + " for field '" + name + "'" + ); } } return arrays; @@ -220,10 +289,12 @@ namespace sparrow_ipc std::vector fields_nullable; std::vector field_types; std::vector>> fields_metadata; - do + + while (!data.empty()) { - // Check for end-of-stream marker here as data could contain only that (if no record batches present/written) - if (data.size() >= 8 && is_end_of_stream(data.subspan(0, 8))) + // Check for end-of-stream marker + if (data.size() >= END_OF_STREAM_MARKER_SIZE + && is_end_of_stream(data.subspan(0, END_OF_STREAM_MARKER_SIZE))) { break; } @@ -276,12 +347,12 @@ namespace sparrow_ipc { if (schema == nullptr) { - throw std::runtime_error("Schema message is missing."); + throw std::runtime_error("RecordBatch encountered before Schema message."); } - const auto record_batch = message->header_as_RecordBatch(); + const auto* record_batch = message->header_as_RecordBatch(); if (record_batch == nullptr) { - throw std::runtime_error("RecordBatch message is missing."); + throw std::runtime_error("RecordBatch message header is null."); } std::vector arrays = get_arrays_from_record_batch( *record_batch, @@ -289,8 +360,7 @@ namespace sparrow_ipc encapsulated_message, fields_metadata ); - auto names_copy = field_names; // TODO: Remove when issue with the to_vector of - // record_batch is fixed + auto names_copy = field_names; sparrow::record_batch sp_record_batch(std::move(names_copy), std::move(arrays)); record_batches.emplace_back(std::move(sp_record_batch)); } @@ -298,12 +368,12 @@ namespace sparrow_ipc case org::apache::arrow::flatbuf::MessageHeader::Tensor: case org::apache::arrow::flatbuf::MessageHeader::DictionaryBatch: case org::apache::arrow::flatbuf::MessageHeader::SparseTensor: - throw std::runtime_error("Not supported"); + throw std::runtime_error("Unsupported message type: Tensor, DictionaryBatch, or SparseTensor"); default: throw std::runtime_error("Unknown message header type."); } data = rest; - } while (!data.empty()); + } return record_batches; } } diff --git a/tests/test_de_serialization_with_files.cpp b/tests/test_de_serialization_with_files.cpp index 8cb74e8f..0805508c 100644 --- a/tests/test_de_serialization_with_files.cpp +++ b/tests/test_de_serialization_with_files.cpp @@ -33,6 +33,7 @@ const std::vector files_paths_to_test = { tests_resources_files_path / "generated_large_binary", tests_resources_files_path / "generated_binary_zerolength", tests_resources_files_path / "generated_binary_no_batches", + tests_resources_files_path / "generated_interval", }; const std::vector files_paths_to_test_with_lz4_compression = {