From 7e299eae0db0d7bfc20f7c1e1548bf86cdbfef5e Mon Sep 17 00:00:00 2001 From: ArrayRecord Team Date: Wed, 15 Jan 2025 08:33:36 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 715812031 --- cpp/BUILD | 3 +++ cpp/array_record_writer.cc | 13 +++++++++++- cpp/array_record_writer.h | 4 +++- cpp/array_record_writer_test.cc | 35 +++++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/cpp/BUILD b/cpp/BUILD index 91a5ff6..65e8569 100644 --- a/cpp/BUILD +++ b/cpp/BUILD @@ -115,6 +115,7 @@ cc_library( "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:cord", "@abseil-cpp//absl/synchronization", "@abseil-cpp//absl/types:span", "@protobuf//:protobuf_lite", @@ -235,6 +236,8 @@ cc_test( ":test_utils", ":thread_pool", "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:cord", + "@abseil-cpp//absl/strings:cord_test_helpers", "@googletest//:gtest_main", "@riegeli//riegeli/base:initializer", "@riegeli//riegeli/bytes:string_reader", diff --git a/cpp/array_record_writer.cc b/cpp/array_record_writer.cc index 48691ba..ac05c21 100644 --- a/cpp/array_record_writer.cc +++ b/cpp/array_record_writer.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -39,8 +40,8 @@ limitations under the License. #include "cpp/common.h" #include "cpp/layout.pb.h" #include "cpp/sequenced_chunk_writer.h" -#include "cpp/tri_state_ptr.h" #include "cpp/thread_pool.h" +#include "cpp/tri_state_ptr.h" #include "google/protobuf/message_lite.h" #include "riegeli/base/object.h" #include "riegeli/base/options_parser.h" @@ -418,6 +419,16 @@ bool ArrayRecordWriterBase::WriteRecord(absl::string_view record) { return WriteRecordImpl(std::move(record)); } +bool ArrayRecordWriterBase::WriteRecord(const absl::Cord& record) { + if (auto flat = record.TryFlat(); flat.has_value()) { + return WriteRecord(*flat); + } + + std::string cord_string; + absl::AppendCordToString(record, &cord_string); + return WriteRecord(cord_string); +} + bool ArrayRecordWriterBase::WriteRecord(const void* data, size_t num_bytes) { auto view = absl::string_view(reinterpret_cast(data), num_bytes); return WriteRecordImpl(std::move(view)); diff --git a/cpp/array_record_writer.h b/cpp/array_record_writer.h index 273d576..385313d 100644 --- a/cpp/array_record_writer.h +++ b/cpp/array_record_writer.h @@ -66,12 +66,13 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "cpp/common.h" #include "cpp/sequenced_chunk_writer.h" -#include "cpp/tri_state_ptr.h" #include "cpp/thread_pool.h" +#include "cpp/tri_state_ptr.h" #include "riegeli/base/initializer.h" #include "riegeli/base/object.h" #include "riegeli/bytes/writer.h" @@ -304,6 +305,7 @@ class ArrayRecordWriterBase : public riegeli::Object { // Write records of various types. bool WriteRecord(const google::protobuf::MessageLite& record); bool WriteRecord(absl::string_view record); + bool WriteRecord(const absl::Cord& record); bool WriteRecord(const void* data, size_t num_bytes); template bool WriteRecord(absl::Span record) { diff --git a/cpp/array_record_writer_test.cc b/cpp/array_record_writer_test.cc index a8b1414..1533225 100644 --- a/cpp/array_record_writer_test.cc +++ b/cpp/array_record_writer_test.cc @@ -25,6 +25,8 @@ limitations under the License. #include #include "gtest/gtest.h" +#include "absl/strings/cord.h" +#include "absl/strings/cord_test_helpers.h" #include "absl/strings/string_view.h" #include "cpp/common.h" #include "cpp/layout.pb.h" @@ -116,6 +118,39 @@ TEST_P(ArrayRecordWriterTest, MoveTest) { } } +TEST_P(ArrayRecordWriterTest, CordTest) { + std::string encoded; + ARThreadPool* pool = nullptr; + if (std::get<3>(GetParam())) { + pool = ArrayRecordGlobalPool(); + } + auto options = GetOptions(); + options.set_group_size(2); + auto writer = ArrayRecordWriter( + riegeli::Maker(&encoded), options, pool); + + absl::Cord flat_cord("test"); + // Empty string should not crash the writer. + absl::Cord empty_cord(""); + absl::Cord fragmented_cord = absl::MakeFragmentedCord({"aaa ", "", "c"}); + + EXPECT_TRUE(writer.WriteRecord(flat_cord)); + EXPECT_TRUE(writer.WriteRecord(empty_cord)); + EXPECT_TRUE(writer.WriteRecord(fragmented_cord)); + ASSERT_TRUE(writer.Close()); + + // Empty string should not crash the reader. + std::vector expected_strings{"test", "", "aaa c"}; + + auto reader = + riegeli::RecordReader(riegeli::Maker(encoded)); + for (const auto& expected : expected_strings) { + std::string result; + reader.ReadRecord(result); + EXPECT_EQ(result, expected); + } +} + TEST_P(ArrayRecordWriterTest, RandomDatasetTest) { std::mt19937 bitgen; constexpr uint32_t kGroupSize = 100;