From b1e0b7cbd6bd7422175efc6ee5378f8dd91469c5 Mon Sep 17 00:00:00 2001
From: Justin King
Date: Fri, 21 Mar 2025 11:27:37 -0700
Subject: [PATCH] Make `.startsWith`, `.endsWith`, and `.contains` cheap again

PiperOrigin-RevId: 739245572
---
// Reviewer notes (added during review; not part of the commit message or of
// any hunk below).
//
// What the patch does, as shown by the hunks:
//   * common/values/string_value.{h,cc}: adds StartsWith, EndsWith and
//     Contains to StringValue, each with three overloads taking
//     absl::string_view, const absl::Cord& and const StringValue&.
//     The string_view/Cord overloads dispatch with value_.Visit over the two
//     visited alternatives (absl::string_view and absl::Cord); the StringValue
//     overload visits the argument's value_ and forwards to the other two.
//   * string_view receiver with a Cord argument: StartsWith/EndsWith first
//     check lhs.size() >= string.size() and then compare the matching
//     substr() against the Cord. Contains uses string.TryFlat() when the Cord
//     is flat; otherwise it wraps lhs in absl::MakeCordFromExternal with a
//     no-op releaser and calls Cord::Contains (see the in-code comment for
//     why std::search is not used).
//   * runtime/standard/string_functions.cc: StringContains, StringEndsWith and
//     StringStartsWith now call the new members instead of running
//     absl::StrContains/EndsWith/StartsWith on value.ToString() and the
//     argument's ToString(); the now-unused absl/strings/match.h include is
//     removed there and added to string_value.cc.
//     NOTE(review): presumably ToString() materializes a std::string copy and
//     that is the cost the subject line refers to -- confirm against
//     StringValue::ToString().
//   * common/values/string_value_test.cc: one test per operation covering the
//     four string/Cord combinations of receiver and argument (positive cases
//     only).
//
// NOTE(review): this copy of the patch appears to have lost angle-bracket
// spans during extraction (e.g. the hunk context shows `ValueMixin`,
// `absl::optional` and `absl::StatusOr` without template arguments) --
// verify against the upstream commit before applying.
 common/values/string_value.cc        | 76 ++++++++++++++++++++++++++++
 common/values/string_value.h         | 12 +++++
 common/values/string_value_test.cc   | 51 +++++++++++++++++++
 runtime/standard/string_functions.cc |  7 ++-
 4 files changed, 142 insertions(+), 4 deletions(-)

diff --git a/common/values/string_value.cc b/common/values/string_value.cc
index d8068545e..8fb4f4a1d 100644
--- a/common/values/string_value.cc
+++ b/common/values/string_value.cc
@@ -22,6 +22,7 @@
 #include "absl/log/absl_check.h"
 #include "absl/status/status.h"
 #include "absl/strings/cord.h"
+#include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "common/value.h"
@@ -204,4 +205,79 @@ int StringValue::Compare(const StringValue& string) const {
       [this](const auto& alternative) -> int { return Compare(alternative); });
 }
 
+bool StringValue::StartsWith(absl::string_view string) const {
+  return value_.Visit(absl::Overload(
+      [&](absl::string_view lhs) -> bool {
+        return absl::StartsWith(lhs, string);
+      },
+      [&](const absl::Cord& lhs) -> bool { return lhs.StartsWith(string); }));
+}
+
+bool StringValue::StartsWith(const absl::Cord& string) const {
+  return value_.Visit(absl::Overload(
+      [&](absl::string_view lhs) -> bool {
+        return lhs.size() >= string.size() &&
+               lhs.substr(0, string.size()) == string;
+      },
+      [&](const absl::Cord& lhs) -> bool { return lhs.StartsWith(string); }));
+}
+
+bool StringValue::StartsWith(const StringValue& string) const {
+  return string.value_.Visit(absl::Overload(
+      [&](absl::string_view rhs) -> bool { return StartsWith(rhs); },
+      [&](const absl::Cord& rhs) -> bool { return StartsWith(rhs); }));
+}
+
+bool StringValue::EndsWith(absl::string_view string) const {
+  return value_.Visit(absl::Overload(
+      [&](absl::string_view lhs) -> bool {
+        return absl::EndsWith(lhs, string);
+      },
+      [&](const absl::Cord& lhs) -> bool { return lhs.EndsWith(string); }));
+}
+
+bool StringValue::EndsWith(const absl::Cord& string) const {
+  return value_.Visit(absl::Overload(
+      [&](absl::string_view lhs) -> bool {
+        return lhs.size() >= string.size() &&
+               lhs.substr(lhs.size() - string.size()) == string;
+      },
+      [&](const absl::Cord& lhs) -> bool { return lhs.EndsWith(string); }));
+}
+
+bool StringValue::EndsWith(const StringValue& string) const {
+  return string.value_.Visit(absl::Overload(
+      [&](absl::string_view rhs) -> bool { return EndsWith(rhs); },
+      [&](const absl::Cord& rhs) -> bool { return EndsWith(rhs); }));
+}
+
+bool StringValue::Contains(absl::string_view string) const {
+  return value_.Visit(absl::Overload(
+      [&](absl::string_view lhs) -> bool {
+        return absl::StrContains(lhs, string);
+      },
+      [&](const absl::Cord& lhs) -> bool { return lhs.Contains(string); }));
+}
+
+bool StringValue::Contains(const absl::Cord& string) const {
+  return value_.Visit(absl::Overload(
+      [&](absl::string_view lhs) -> bool {
+        if (auto flat = string.TryFlat(); flat) {
+          return absl::StrContains(lhs, *flat);
+        }
+        // There is no nice way to do this. We cannot use std::search due to
+        // absl::Cord::CharIterator being an input iterator instead of a forward
+        // iterator. So just make an external cord with a noop releaser. We know
+        // the external cord will not outlive this function.
+        return absl::MakeCordFromExternal(lhs, []() {}).Contains(string);
+      },
+      [&](const absl::Cord& lhs) -> bool { return lhs.Contains(string); }));
+}
+
+bool StringValue::Contains(const StringValue& string) const {
+  return string.value_.Visit(absl::Overload(
+      [&](absl::string_view rhs) -> bool { return Contains(rhs); },
+      [&](const absl::Cord& rhs) -> bool { return Contains(rhs); }));
+}
+
 }  // namespace cel
diff --git a/common/values/string_value.h b/common/values/string_value.h
index 8763fa3d0..7f322f152 100644
--- a/common/values/string_value.h
+++ b/common/values/string_value.h
@@ -196,6 +196,18 @@ class StringValue final : private common_internal::ValueMixin {
   int Compare(const absl::Cord& string) const;
   int Compare(const StringValue& string) const;
 
+  bool StartsWith(absl::string_view string) const;
+  bool StartsWith(const absl::Cord& string) const;
+  bool StartsWith(const StringValue& string) const;
+
+  bool EndsWith(absl::string_view string) const;
+  bool EndsWith(const absl::Cord& string) const;
+  bool EndsWith(const StringValue& string) const;
+
+  bool Contains(absl::string_view string) const;
+  bool Contains(const absl::Cord& string) const;
+  bool Contains(const StringValue& string) const;
+
   absl::optional TryFlat() const
       ABSL_ATTRIBUTE_LIFETIME_BOUND {
     return value_.TryFlat();
diff --git a/common/values/string_value_test.cc b/common/values/string_value_test.cc
index 79a55c900..244fd3f7e 100644
--- a/common/values/string_value_test.cc
+++ b/common/values/string_value_test.cc
@@ -157,5 +157,56 @@ TEST_F(StringValueTest, LessThan) {
   EXPECT_LT(absl::Cord("bar"), StringValue("foo"));
 }
 
+TEST_F(StringValueTest, StartsWith) {
+  EXPECT_TRUE(
+      StringValue("This string is large enough to not be stored inline!")
+          .StartsWith(StringValue("This string is large enough")));
+  EXPECT_TRUE(
+      StringValue("This string is large enough to not be stored inline!")
+          .StartsWith(StringValue(absl::Cord("This string is large enough"))));
+  EXPECT_TRUE(
+      StringValue(
+          absl::Cord("This string is large enough to not be stored inline!"))
+          .StartsWith(StringValue("This string is large enough")));
+  EXPECT_TRUE(
+      StringValue(
+          absl::Cord("This string is large enough to not be stored inline!"))
+          .StartsWith(StringValue(absl::Cord("This string is large enough"))));
+}
+
+TEST_F(StringValueTest, EndsWith) {
+  EXPECT_TRUE(
+      StringValue("This string is large enough to not be stored inline!")
+          .EndsWith(StringValue("to not be stored inline!")));
+  EXPECT_TRUE(
+      StringValue("This string is large enough to not be stored inline!")
+          .EndsWith(StringValue(absl::Cord("to not be stored inline!"))));
+  EXPECT_TRUE(
+      StringValue(
+          absl::Cord("This string is large enough to not be stored inline!"))
+          .EndsWith(StringValue("to not be stored inline!")));
+  EXPECT_TRUE(
+      StringValue(
+          absl::Cord("This string is large enough to not be stored inline!"))
+          .EndsWith(StringValue(absl::Cord("to not be stored inline!"))));
+}
+
+TEST_F(StringValueTest, Contains) {
+  EXPECT_TRUE(
+      StringValue("This string is large enough to not be stored inline!")
+          .Contains(StringValue("string is large enough")));
+  EXPECT_TRUE(
+      StringValue("This string is large enough to not be stored inline!")
+          .Contains(StringValue(absl::Cord("string is large enough"))));
+  EXPECT_TRUE(
+      StringValue(
+          absl::Cord("This string is large enough to not be stored inline!"))
+          .Contains(StringValue("string is large enough")));
+  EXPECT_TRUE(
+      StringValue(
+          absl::Cord("This string is large enough to not be stored inline!"))
+          .Contains(StringValue(absl::Cord("string is large enough"))));
+}
+
 }  // namespace
 }  // namespace cel
diff --git a/runtime/standard/string_functions.cc b/runtime/standard/string_functions.cc
index db7ef1b91..e6b60c618 100644
--- a/runtime/standard/string_functions.cc
+++ b/runtime/standard/string_functions.cc
@@ -19,7 +19,6 @@
 #include "absl/base/nullability.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
-#include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "base/builtins.h"
@@ -60,15 +59,15 @@ absl::StatusOr ConcatBytes(
 }
 
 bool StringContains(const StringValue& value, const StringValue& substr) {
-  return absl::StrContains(value.ToString(), substr.ToString());
+  return value.Contains(substr);
 }
 
 bool StringEndsWith(const StringValue& value, const StringValue& suffix) {
-  return absl::EndsWith(value.ToString(), suffix.ToString());
+  return value.EndsWith(suffix);
 }
 
 bool StringStartsWith(const StringValue& value, const StringValue& prefix) {
-  return absl::StartsWith(value.ToString(), prefix.ToString());
+  return value.StartsWith(prefix);
 }
 
 absl::Status RegisterSizeFunctions(FunctionRegistry& registry) {