Skip to content

Commit 4236cf6

Browse files
authored
CSHARP-5628: Add new boolean expression simplifications to PartialEvaluator (#1803)
1 parent a6d42d2 commit 4236cf6

File tree

3 files changed

+318
-34
lines changed

3 files changed

+318
-34
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Misc/PartialEvaluator.cs

Lines changed: 124 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ public static Expression EvaluatePartially(Expression expression)
5050
// nested types
5151
private class SubtreeEvaluator : ExpressionVisitor
5252
{
53+
// private static fields
54+
private static readonly Expression __falseConstantExpression = Expression.Constant(false, typeof(bool));
55+
private static readonly Expression __trueConstantExpression = Expression.Constant(true, typeof(bool));
56+
5357
// private fields
5458
private readonly HashSet<Expression> _candidates;
5559

@@ -75,60 +79,137 @@ public override Expression Visit(Expression expression)
7579

7680
protected override Expression VisitBinary(BinaryExpression node)
7781
{
78-
if (node.NodeType == ExpressionType.AndAlso)
82+
var leftExpression = node.Left;
83+
var rightExpression = node.Right;
84+
85+
if (leftExpression.Type == typeof(bool) && rightExpression.Type == typeof(bool))
7986
{
80-
var leftExpression = Visit(node.Left);
81-
if (leftExpression is ConstantExpression constantLeftExpression )
87+
if (node.NodeType == ExpressionType.AndAlso)
8288
{
83-
var value = (bool)constantLeftExpression.Value;
84-
return value ? Visit(node.Right) : Expression.Constant(false);
89+
leftExpression = Visit(leftExpression);
90+
if (IsConstant<bool>(leftExpression, out var leftValue))
91+
{
92+
// true && Q => Q
93+
// false && Q => false
94+
return leftValue ? Visit(rightExpression) : __falseConstantExpression;
95+
}
96+
97+
rightExpression = Visit(rightExpression);
98+
if (IsConstant<bool>(rightExpression, out var rightValue))
99+
{
100+
// P && true => P
101+
// P && false => false
102+
return rightValue ? leftExpression : __falseConstantExpression;
103+
}
104+
105+
return node.Update(leftExpression, conversion: null, rightExpression);
85106
}
86107

87-
var rightExpression = Visit(node.Right);
88-
if (rightExpression is ConstantExpression constantRightExpression)
108+
if (node.NodeType == ExpressionType.OrElse)
89109
{
90-
var value = (bool)constantRightExpression.Value;
91-
return value ? leftExpression : Expression.Constant(false);
110+
leftExpression = Visit(leftExpression);
111+
if (IsConstant<bool>(leftExpression, out var leftValue))
112+
{
113+
// true || Q => true
114+
// false || Q => Q
115+
return leftValue ? __trueConstantExpression : Visit(rightExpression);
116+
}
117+
118+
rightExpression = Visit(rightExpression);
119+
if (IsConstant<bool>(rightExpression, out var rightValue))
120+
{
121+
// P || true => true
122+
// P || false => P
123+
return rightValue ? __trueConstantExpression : leftExpression;
124+
}
125+
126+
return node.Update(leftExpression, conversion: null, rightExpression);
92127
}
128+
}
129+
130+
return base.VisitBinary(node);
131+
}
93132

94-
return node.Update(leftExpression, conversion: null, rightExpression);
133+
protected override Expression VisitConditional(ConditionalExpression node)
134+
{
135+
var test = Visit(node.Test);
136+
137+
if (IsConstant<bool>(test, out var testValue))
138+
{
139+
// true ? A : B => A
140+
// false ? A : B => B
141+
return testValue ? Visit(node.IfTrue) : Visit(node.IfFalse);
95142
}
96143

97-
if (node.NodeType == ExpressionType.OrElse)
144+
var ifTrue = Visit(node.IfTrue);
145+
var ifFalse = Visit(node.IfFalse);
146+
147+
if (BothAreConstant<bool>(ifTrue, ifFalse, out var ifTrueValue, out var ifFalseValue))
98148
{
99-
var leftExpression = Visit(node.Left);
100-
if (leftExpression is ConstantExpression constantLeftExpression)
149+
return (ifTrueValue, ifFalseValue) switch
101150
{
102-
var value = (bool)constantLeftExpression.Value;
103-
return value ? Expression.Constant(true) : Visit(node.Right);
104-
}
151+
(false, false) => __falseConstantExpression, // T ? false : false => false
152+
(false, true) => Expression.Not(test), // T ? false : true => !T
153+
(true, false) => test, // T ? true : false => T
154+
(true, true) => __trueConstantExpression // T ? true : true => true
155+
};
156+
}
157+
else if (IsConstant<bool>(ifTrue, out ifTrueValue))
158+
{
159+
// T ? true : Q => T || Q
160+
// T ? false : Q => !T && Q
161+
return ifTrueValue
162+
? Visit(Expression.OrElse(test, ifFalse))
163+
: Visit(Expression.AndAlso(Expression.Not(test), ifFalse));
164+
}
165+
else if (IsConstant<bool>(ifFalse, out ifFalseValue))
166+
{
167+
// T ? P : true => !T || P
168+
// T ? P : false => T && P
169+
return ifFalseValue
170+
? Visit(Expression.OrElse(Expression.Not(test), ifTrue))
171+
: Visit(Expression.AndAlso(test, ifTrue));
172+
}
105173

106-
var rightExpression = Visit(node.Right);
107-
if (rightExpression is ConstantExpression constantRightExpression)
174+
return node.Update(test, ifTrue, ifFalse);
175+
}
176+
177+
protected override Expression VisitUnary(UnaryExpression node)
178+
{
179+
var operand = Visit(node.Operand);
180+
181+
if (node.Type == typeof(bool) &&
182+
node.NodeType == ExpressionType.Not)
183+
{
184+
if (operand is UnaryExpression innerUnaryExpressionOperand &&
185+
innerUnaryExpressionOperand.NodeType == ExpressionType.Not)
108186
{
109-
var value = (bool)constantRightExpression.Value;
110-
return value ? Expression.Constant(true) : leftExpression;
187+
// !!P => P
188+
return innerUnaryExpressionOperand.Operand;
111189
}
112-
113-
return node.Update(leftExpression, conversion: null, rightExpression);
114190
}
115191

116-
return base.VisitBinary(node);
192+
return node.Update(operand);
117193
}
118194

119-
protected override Expression VisitConditional(ConditionalExpression node)
195+
// private methods
196+
private bool BothAreConstant<T>(Expression expression1, Expression expression2, out T constantValue1, out T constantValue2)
120197
{
121-
var test = Visit(node.Test);
122-
if (test is ConstantExpression constantTestExpression)
198+
if (expression1 is ConstantExpression constantExpression1 &&
199+
expression2 is ConstantExpression constantExpression2 &&
200+
constantExpression1.Type == typeof(T) &&
201+
constantExpression2.Type == typeof(T))
123202
{
124-
var value = (bool)constantTestExpression.Value;
125-
return value ? Visit(node.IfTrue) : Visit(node.IfFalse);
203+
constantValue1 = (T)constantExpression1.Value;
204+
constantValue2 = (T)constantExpression2.Value;
205+
return true;
126206
}
127207

128-
return node.Update(test, Visit(node.IfTrue), Visit(node.IfFalse));
208+
constantValue1 = default;
209+
constantValue2 = default;
210+
return false;
129211
}
130212

131-
// private methods
132213
private Expression Evaluate(Expression expression)
133214
{
134215
if (expression.NodeType == ExpressionType.Constant)
@@ -139,6 +220,19 @@ private Expression Evaluate(Expression expression)
139220
Delegate fn = lambda.Compile();
140221
return Expression.Constant(fn.DynamicInvoke(null), expression.Type);
141222
}
223+
224+
private bool IsConstant<T>(Expression expression, out T constantValue)
225+
{
226+
if (expression is ConstantExpression constantExpression1 &&
227+
constantExpression1.Type == typeof(T))
228+
{
229+
constantValue = (T)constantExpression1.Value;
230+
return true;
231+
}
232+
233+
constantValue = default;
234+
return false;
235+
}
142236
}
143237

144238
private class Nominator : ExpressionVisitor

tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4337Tests.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ public class CSharp4337Tests : LinqIntegrationTest<CSharp4337Tests.ClassFixture>
3232
{
3333
private static (Expression<Func<C, R<bool>>> Projection, string ExpectedStage, bool[] ExpectedResults)[] __predicate_should_use_correct_representation_test_cases = new (Expression<Func<C, R<bool>>> Projection, string ExpectedStage, bool[] ExpectedResults)[]
3434
{
35-
(d => new R<bool> { N = d.Id, V = d.I1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : ['$I1', 1] }, then : true, else : false } }, _id : 0 } }", new[] { true, false }),
36-
(d => new R<bool> { N = d.Id, V = d.S1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : ['$S1', 'E1'] }, then : true, else : false } }, _id : 0 } }", new[] { true, false }),
37-
(d => new R<bool> { N = d.Id, V = E.E1 == d.I1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : [1, '$I1'] }, then : true, else : false } }, _id : 0 } }", new[] { true, false }),
38-
(d => new R<bool> { N = d.Id, V = E.E1 == d.S1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : ['E1', '$S1'] }, then : true, else : false } }, _id : 0 } }", new[] { true, false })
35+
(d => new R<bool> { N = d.Id, V = d.I1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : ['$I1', 1] }, _id : 0 } }", new[] { true, false }),
36+
(d => new R<bool> { N = d.Id, V = d.S1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : ['$S1', 'E1'] }, _id : 0 } }", new[] { true, false }),
37+
(d => new R<bool> { N = d.Id, V = E.E1 == d.I1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : [1, '$I1'] }, _id : 0 } }", new[] { true, false }),
38+
(d => new R<bool> { N = d.Id, V = E.E1 == d.S1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : ['E1', '$S1'] }, _id : 0 } }", new[] { true, false })
3939
};
4040

4141
public CSharp4337Tests(ClassFixture fixture)

0 commit comments

Comments
 (0)