Skip to content

Commit a5d93e6

Browse files
committed
wip
1 parent 7fcdea0 commit a5d93e6

File tree

1 file changed

+59
-37
lines changed

1 file changed

+59
-37
lines changed

src/deserialize.cpp

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,23 @@
1111

1212
namespace sparrow_ipc
1313
{
14+
namespace
15+
{
16+
// Integer bit width constants
17+
constexpr int32_t BIT_WIDTH_8 = 8;
18+
constexpr int32_t BIT_WIDTH_16 = 16;
19+
constexpr int32_t BIT_WIDTH_32 = 32;
20+
constexpr int32_t BIT_WIDTH_64 = 64;
21+
22+
// End-of-stream marker size in bytes
23+
constexpr size_t END_OF_STREAM_MARKER_SIZE = 8;
24+
}
1425
const org::apache::arrow::flatbuf::RecordBatch*
1526
deserialize_record_batch_message(std::span<const uint8_t> data, size_t& current_offset)
1627
{
1728
current_offset += sizeof(uint32_t);
18-
const auto batch_message = org::apache::arrow::flatbuf::GetMessage(data.data() + current_offset);
29+
const auto message_data = data.subspan(current_offset);
30+
const auto* batch_message = org::apache::arrow::flatbuf::GetMessage(message_data.data());
1931
if (batch_message->header_type() != org::apache::arrow::flatbuf::MessageHeader::RecordBatch)
2032
{
2133
throw std::runtime_error("Expected RecordBatch message, but got a different type.");
@@ -28,20 +40,21 @@ namespace sparrow_ipc
2840
*
2941
* This function processes each field in the schema and deserializes the corresponding
3042
* data from the RecordBatch into sparrow::array objects. It handles various Arrow data
31-
* types including primitive types (bool, integers, floating point), binary data, and
32-
* string data with their respective size variants.
43+
* types including primitive types (bool, integers, floating point), binary data, string
44+
* data, fixed-size binary data, and interval types.
3345
*
3446
* @param record_batch The Apache Arrow FlatBuffer RecordBatch containing the serialized data
3547
* @param schema The Apache Arrow FlatBuffer Schema defining the structure and types of the data
3648
* @param encapsulated_message The message containing the binary data buffers
49+
* @param field_metadata Metadata associated with each field in the schema
3750
*
3851
* @return std::vector<sparrow::array> A vector of deserialized arrays, one for each field in the schema
3952
*
40-
* @throws std::runtime_error If an unsupported data type, integer bit width, or floating point precision
41-
* is encountered
53+
* @throws std::runtime_error If an unsupported data type, integer bit width, floating point precision,
54+
* or interval unit is encountered
4255
*
43-
* The function maintains a buffer index that is incremented as it processes each field
44-
* to correctly map data buffers to their corresponding arrays.
56+
* @note The function maintains a buffer index that is incremented as it processes each field
57+
* to correctly map data buffers to their corresponding arrays.
4558
*/
4659
std::vector<sparrow::array> get_arrays_from_record_batch(
4760
const org::apache::arrow::flatbuf::RecordBatch& record_batch,
@@ -62,7 +75,7 @@ namespace sparrow_ipc
6275
const std::optional<std::vector<sparrow::metadata_pair>>& metadata = field_metadata[field_idx++];
6376
const std::string name = field->name() == nullptr ? "" : field->name()->str();
6477
const auto field_type = field->type_type();
65-
// TODO rename all the deserialize_non_owning... fcts since this is not correct anymore
78+
6679
const auto deserialize_non_owning_primitive_array_lambda = [&]<typename T>()
6780
{
6881
return deserialize_non_owning_primitive_array<T>(
@@ -82,7 +95,7 @@ namespace sparrow_ipc
8295
break;
8396
case org::apache::arrow::flatbuf::Type::Int:
8497
{
85-
const auto int_type = field->type_as_Int();
98+
const auto* int_type = field->type_as_Int();
8699
const auto bit_width = int_type->bitWidth();
87100
const bool is_signed = int_type->is_signed();
88101

@@ -91,11 +104,11 @@ namespace sparrow_ipc
91104
switch (bit_width)
92105
{
93106
// clang-format off
94-
case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int8_t>()); break;
95-
case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int16_t>()); break;
96-
case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int32_t>()); break;
97-
case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int64_t>()); break;
98-
default: throw std::runtime_error("Unsupported integer bit width.");
107+
case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int8_t>()); break;
108+
case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int16_t>()); break;
109+
case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int32_t>()); break;
110+
case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<int64_t>()); break;
111+
default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width));
99112
// clang-format on
100113
}
101114
}
@@ -104,19 +117,19 @@ namespace sparrow_ipc
104117
switch (bit_width)
105118
{
106119
// clang-format off
107-
case 8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint8_t>()); break;
108-
case 16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint16_t>()); break;
109-
case 32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint32_t>()); break;
110-
case 64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint64_t>()); break;
111-
default: throw std::runtime_error("Unsupported integer bit width.");
120+
case BIT_WIDTH_8: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint8_t>()); break;
121+
case BIT_WIDTH_16: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint16_t>()); break;
122+
case BIT_WIDTH_32: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint32_t>()); break;
123+
case BIT_WIDTH_64: arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<uint64_t>()); break;
124+
default: throw std::runtime_error("Unsupported integer bit width: " + std::to_string(bit_width));
112125
// clang-format on
113126
}
114127
}
115128
}
116129
break;
117130
case org::apache::arrow::flatbuf::Type::FloatingPoint:
118131
{
119-
const auto float_type = field->type_as_FloatingPoint();
132+
const auto* float_type = field->type_as_FloatingPoint();
120133
switch (float_type->precision())
121134
{
122135
// clang-format off
@@ -130,14 +143,17 @@ namespace sparrow_ipc
130143
arrays.emplace_back(deserialize_non_owning_primitive_array_lambda.template operator()<double>());
131144
break;
132145
default:
133-
throw std::runtime_error("Unsupported floating point precision.");
146+
throw std::runtime_error(
147+
"Unsupported floating point precision: "
148+
+ std::to_string(static_cast<int>(float_type->precision()))
149+
);
134150
// clang-format on
135151
}
136152
break;
137153
}
138154
case org::apache::arrow::flatbuf::Type::FixedSizeBinary:
139155
{
140-
const auto fixed_size_binary_field = field->type_as_FixedSizeBinary();
156+
const auto* fixed_size_binary_field = field->type_as_FixedSizeBinary();
141157
arrays.emplace_back(deserialize_non_owning_fixedwidthbinary(
142158
record_batch,
143159
encapsulated_message.body(),
@@ -194,8 +210,8 @@ namespace sparrow_ipc
194210
break;
195211
case org::apache::arrow::flatbuf::Type::Interval:
196212
{
197-
const auto interval_type = field->type_as_Interval();
198-
org::apache::arrow::flatbuf::IntervalUnit interval_unit = interval_type->unit();
213+
const auto* interval_type = field->type_as_Interval();
214+
const org::apache::arrow::flatbuf::IntervalUnit interval_unit = interval_type->unit();
199215
switch (interval_unit)
200216
{
201217
case org::apache::arrow::flatbuf::IntervalUnit::YEAR_MONTH:
@@ -232,12 +248,18 @@ namespace sparrow_ipc
232248
);
233249
break;
234250
default:
235-
throw std::runtime_error("Unsupported interval unit.");
251+
throw std::runtime_error(
252+
"Unsupported interval unit: "
253+
+ std::to_string(static_cast<int>(interval_unit))
254+
);
236255
}
237256
}
238257
break;
239258
default:
240-
throw std::runtime_error("Unsupported type.");
259+
throw std::runtime_error(
260+
"Unsupported field type: " + std::to_string(static_cast<int>(field_type))
261+
+ " for field '" + name + "'"
262+
);
241263
}
242264
}
243265
return arrays;
@@ -251,11 +273,12 @@ namespace sparrow_ipc
251273
std::vector<bool> fields_nullable;
252274
std::vector<sparrow::data_type> field_types;
253275
std::vector<std::optional<std::vector<sparrow::metadata_pair>>> fields_metadata;
254-
do
276+
277+
while (!data.empty())
255278
{
256-
// Check for end-of-stream marker here as data could contain only that (if no record batches
257-
// present/written)
258-
if (data.size() >= 8 && is_end_of_stream(data.subspan(0, 8)))
279+
// Check for end-of-stream marker
280+
if (data.size() >= END_OF_STREAM_MARKER_SIZE
281+
&& is_end_of_stream(data.subspan(0, END_OF_STREAM_MARKER_SIZE)))
259282
{
260283
break;
261284
}
@@ -303,34 +326,33 @@ namespace sparrow_ipc
303326
{
304327
if (schema == nullptr)
305328
{
306-
throw std::runtime_error("Schema message is missing.");
329+
throw std::runtime_error("RecordBatch encountered before Schema message.");
307330
}
308-
const auto record_batch = message->header_as_RecordBatch();
331+
const auto* record_batch = message->header_as_RecordBatch();
309332
if (record_batch == nullptr)
310333
{
311-
throw std::runtime_error("RecordBatch message is missing.");
334+
throw std::runtime_error("RecordBatch message header is null.");
312335
}
313336
std::vector<sparrow::array> arrays = get_arrays_from_record_batch(
314337
*record_batch,
315338
*schema,
316339
encapsulated_message,
317340
fields_metadata
318341
);
319-
auto names_copy = field_names; // TODO: Remove when issue with the to_vector of
320-
// record_batch is fixed
342+
auto names_copy = field_names;
321343
sparrow::record_batch sp_record_batch(std::move(names_copy), std::move(arrays));
322344
record_batches.emplace_back(std::move(sp_record_batch));
323345
}
324346
break;
325347
case org::apache::arrow::flatbuf::MessageHeader::Tensor:
326348
case org::apache::arrow::flatbuf::MessageHeader::DictionaryBatch:
327349
case org::apache::arrow::flatbuf::MessageHeader::SparseTensor:
328-
throw std::runtime_error("Not supported");
350+
throw std::runtime_error("Unsupported message type: Tensor, DictionaryBatch, or SparseTensor");
329351
default:
330352
throw std::runtime_error("Unknown message header type.");
331353
}
332354
data = rest;
333-
} while (!data.empty());
355+
}
334356
return record_batches;
335357
}
336358
}

0 commit comments

Comments
 (0)