From 5114a6d813cee9b3ebc37d81f5b2a128492cfe64 Mon Sep 17 00:00:00 2001 From: Donald Duncan Date: Mon, 8 Dec 2025 13:54:14 +1300 Subject: [PATCH] Fix LINQ conversion of `.Contains()` --- .../Query/Linq/ExpressionExtensions.cs | 84 +++++++++++++++++++ .../Query/IDatasyncPullQuery_Tests.cs | 2 +- .../Query/IDatasyncQueryable_Tests.cs | 2 +- .../Service/DatasyncServiceClient_Tests.cs | 2 +- .../Service/Integration_Query_Tests.cs | 2 +- 5 files changed, 88 insertions(+), 4 deletions(-) diff --git a/src/CommunityToolkit.Datasync.Client/Query/Linq/ExpressionExtensions.cs b/src/CommunityToolkit.Datasync.Client/Query/Linq/ExpressionExtensions.cs index 1b4acf34..c9a34c7a 100644 --- a/src/CommunityToolkit.Datasync.Client/Query/Linq/ExpressionExtensions.cs +++ b/src/CommunityToolkit.Datasync.Client/Query/Linq/ExpressionExtensions.cs @@ -6,8 +6,10 @@ // a generalized "nullable" option here to allow us to do that. #nullable disable +using System.ComponentModel; using System.Diagnostics.CodeAnalysis; using System.Linq.Expressions; +using System.Reflection; namespace CommunityToolkit.Datasync.Client.Query.Linq; @@ -17,6 +19,30 @@ namespace CommunityToolkit.Datasync.Client.Query.Linq; /// internal static class ExpressionExtensions { + private static readonly MethodInfo Contains; + private static readonly MethodInfo SequenceEqual; + + static ExpressionExtensions() + { + Dictionary> queryableMethodGroups = typeof(Enumerable) + .GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) + .GroupBy(mi => mi.Name) + .ToDictionary(e => e.Key, l => l.ToList()); + + MethodInfo GetMethod(string name, int genericParameterCount, Func parameterGenerator) + => queryableMethodGroups[name].Single(mi => ((genericParameterCount == 0 && !mi.IsGenericMethod) + || (mi.IsGenericMethod && mi.GetGenericArguments().Length == genericParameterCount)) + && mi.GetParameters().Select(e => e.ParameterType).SequenceEqual( + parameterGenerator(mi.IsGenericMethod ? mi.GetGenericArguments() : []))); + + Contains = GetMethod( + nameof(Enumerable.Contains), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), types[0]]); + SequenceEqual = GetMethod( + nameof(Enumerable.SequenceEqual), 1, + types => [typeof(IEnumerable<>).MakeGenericType(types[0]), typeof(IEnumerable<>).MakeGenericType(types[0])]); + } + /// /// Walk the expression and compute all the subtrees that are not dependent on any /// of the expressions parameters. @@ -127,6 +153,7 @@ internal static bool IsValidLambdaExpression(this MethodCallExpression expressio /// The partially evaluated expression internal static Expression PartiallyEvaluate(this Expression expression) { + expression = expression.RemoveSpanImplicitCast(); List subtrees = expression.FindIndependentSubtrees(); return VisitorHelper.VisitAll(expression, (Expression expr, Func recurse) => { @@ -143,6 +170,63 @@ internal static Expression PartiallyEvaluate(this Expression expression) }); } + internal static Expression RemoveSpanImplicitCast(this Expression expression) + { + return VisitorHelper.VisitAll(expression, (Expression expr, Func recurse) => + { + if (expr is MethodCallExpression methodCall) + { + MethodInfo method = methodCall.Method; + + if (method.DeclaringType == typeof(MemoryExtensions)) + { + switch (method.Name) + { + case nameof(MemoryExtensions.Contains) + when methodCall.Arguments is [Expression arg0, Expression arg1] && TryUnwrapSpanImplicitCast(arg0, out Expression unwrappedArg0): + { + Expression unwrappedExpr = Expression.Call( + Contains.MakeGenericMethod(methodCall.Method.GetGenericArguments()[0]), + unwrappedArg0, arg1); + return recurse(unwrappedExpr); + } + + case nameof(MemoryExtensions.SequenceEqual) + when methodCall.Arguments is [Expression arg0, Expression arg1] + && TryUnwrapSpanImplicitCast(arg0, out Expression unwrappedArg0) + && TryUnwrapSpanImplicitCast(arg1, out Expression unwrappedArg1): + { + Expression unwrappedExpr = Expression.Call( + SequenceEqual.MakeGenericMethod(methodCall.Method.GetGenericArguments()[0]), + unwrappedArg0, unwrappedArg1); + return recurse(unwrappedExpr); + } + } + + static bool TryUnwrapSpanImplicitCast(Expression expression, out Expression result) + { + if (expression is MethodCallExpression + { + Method: { Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType }, + Arguments: [Expression unwrapped] + } + && implicitCastDeclaringType.GetGenericTypeDefinition() is Type genericTypeDefinition + && (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>))) + { + result = unwrapped; + return true; + } + + result = null; + return false; + } + } + } + + return recurse(expr); + }); + } + /// /// Remove the quote from quoted expressions. /// diff --git a/tests/CommunityToolkit.Datasync.Client.Test/Query/IDatasyncPullQuery_Tests.cs b/tests/CommunityToolkit.Datasync.Client.Test/Query/IDatasyncPullQuery_Tests.cs index c0f2d470..ae346dfa 100644 --- a/tests/CommunityToolkit.Datasync.Client.Test/Query/IDatasyncPullQuery_Tests.cs +++ b/tests/CommunityToolkit.Datasync.Client.Test/Query/IDatasyncPullQuery_Tests.cs @@ -1177,7 +1177,7 @@ public void Linq_Where_StartsWith_InvariantIgnoreCase() ); } - [Fact(Skip = "OData v8.4 does not allow string.contains")] + [Fact] public void Linq_Where_String_Contains() { string[] ratings = ["A", "B"]; diff --git a/tests/CommunityToolkit.Datasync.Client.Test/Query/IDatasyncQueryable_Tests.cs b/tests/CommunityToolkit.Datasync.Client.Test/Query/IDatasyncQueryable_Tests.cs index f57752d7..88c95b2d 100644 --- a/tests/CommunityToolkit.Datasync.Client.Test/Query/IDatasyncQueryable_Tests.cs +++ b/tests/CommunityToolkit.Datasync.Client.Test/Query/IDatasyncQueryable_Tests.cs @@ -1416,7 +1416,7 @@ public void Linq_Where_StartsWith_InvariantIgnoreCase() ); } - [Fact(Skip = "OData v8.4 does not allow string.contains")] + [Fact] public void Linq_Where_String_Contains() { string[] ratings = ["A", "B"]; diff --git a/tests/CommunityToolkit.Datasync.Client.Test/Service/DatasyncServiceClient_Tests.cs b/tests/CommunityToolkit.Datasync.Client.Test/Service/DatasyncServiceClient_Tests.cs index edad68d3..354b2faf 100644 --- a/tests/CommunityToolkit.Datasync.Client.Test/Service/DatasyncServiceClient_Tests.cs +++ b/tests/CommunityToolkit.Datasync.Client.Test/Service/DatasyncServiceClient_Tests.cs @@ -3547,7 +3547,7 @@ public void Linq_Where_StartsWith_InvariantIgnoreCase() ); } - [Fact(Skip = "OData v8.4 does not allow string.contains")] + [Fact] public void Linq_Where_String_Contains() { string[] ratings = ["A", "B"]; diff --git a/tests/CommunityToolkit.Datasync.Client.Test/Service/Integration_Query_Tests.cs b/tests/CommunityToolkit.Datasync.Client.Test/Service/Integration_Query_Tests.cs index de43fd19..bb20aee7 100644 --- a/tests/CommunityToolkit.Datasync.Client.Test/Service/Integration_Query_Tests.cs +++ b/tests/CommunityToolkit.Datasync.Client.Test/Service/Integration_Query_Tests.cs @@ -540,7 +540,7 @@ await KitchenSinkQueryTest( // ); //} - [Fact(Skip = "OData v8.4 does not allow string.contains")] + [Fact] public async Task KitchenSinkQueryTest_020() { SeedKitchenSinkWithCountryData();