Skip to content

Commit 419dd85

Browse files
committed
Rust: Handle chained let expressions
1 parent f3bdf7d commit 419dd85

File tree

28 files changed

+505
-191
lines changed

28 files changed

+505
-191
lines changed

rust/ql/lib/codeql/rust/controlflow/internal/CfgNodes.qll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ private import codeql.rust.controlflow.CfgNodes
77
private import codeql.rust.internal.CachedStages
88

99
private predicate isPostOrder(AstNode n) {
10-
n instanceof Expr and
11-
not n instanceof LetExpr
10+
n instanceof Expr
1211
or
1312
n instanceof OrPat
1413
or

rust/ql/lib/codeql/rust/controlflow/internal/ControlFlowGraphImpl.qll

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,7 @@ class TypeReprTree extends LeafTree instanceof TypeRepr { }
200200
/**
201201
* Provides `ControlFlowTree`s for expressions.
202202
*
203-
* Since expressions construct values, they are modeled in post-order, except for
204-
* `LetExpr`s.
203+
* Since expressions construct values, they are modeled in post-order.
205204
*/
206205
module ExprTrees {
207206
class ArrayExprTree extends StandardPostOrderTree, ArrayExpr {
@@ -341,21 +340,15 @@ module ExprTrees {
341340
child = [super.getCondition(), super.getABranch()]
342341
}
343342

344-
private ConditionalCompletion conditionCompletion(Completion c) {
345-
if super.getCondition() instanceof LetExpr
346-
then result = c.(MatchCompletion)
347-
else result = c.(BooleanCompletion)
348-
}
349-
350343
override predicate succ(AstNode pred, AstNode succ, Completion c) {
351344
// Edges from the condition to the branches
352345
last(super.getCondition(), pred, c) and
353346
(
354-
first(super.getThen(), succ) and this.conditionCompletion(c).succeeded()
347+
first(super.getThen(), succ) and c.(ConditionalCompletion).succeeded()
355348
or
356-
first(super.getElse(), succ) and this.conditionCompletion(c).failed()
349+
first(super.getElse(), succ) and c.(ConditionalCompletion).failed()
357350
or
358-
not super.hasElse() and succ = this and this.conditionCompletion(c).failed()
351+
not super.hasElse() and succ = this and c.(ConditionalCompletion).failed()
359352
)
360353
or
361354
// An edge from the then branch to the last node
@@ -401,10 +394,7 @@ module ExprTrees {
401394
}
402395
}
403396

404-
// `LetExpr` is a pre-order tree such that the pattern itself ends up
405-
// dominating successors in the graph in the same way that patterns do in
406-
// `match` expressions.
407-
class LetExprTree extends StandardPreOrderTree, LetExpr {
397+
class LetExprTree extends StandardPostOrderTree, LetExpr {
408398
override AstNode getChildNode(int i) {
409399
i = 0 and
410400
result = this.getScrutinee()
@@ -456,21 +446,15 @@ module ExprTrees {
456446

457447
override predicate first(AstNode node) { first(super.getCondition(), node) }
458448

459-
private ConditionalCompletion conditionCompletion(Completion c) {
460-
if super.getCondition() instanceof LetExpr
461-
then result = c.(MatchCompletion)
462-
else result = c.(BooleanCompletion)
463-
}
464-
465449
override predicate succ(AstNode pred, AstNode succ, Completion c) {
466450
super.succ(pred, succ, c)
467451
or
468452
last(super.getCondition(), pred, c) and
469-
this.conditionCompletion(c).succeeded() and
453+
c.(ConditionalCompletion).succeeded() and
470454
first(this.getLoopBody(), succ)
471455
or
472456
last(super.getCondition(), pred, c) and
473-
this.conditionCompletion(c).failed() and
457+
c.(ConditionalCompletion).failed() and
474458
succ = this
475459
}
476460
}

rust/ql/lib/codeql/rust/controlflow/internal/Splitting.qll

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,7 @@ module ConditionalCompletionSplitting {
7171
child = parent.(LogicalNotExpr).getExpr() and
7272
childCompletion.getDual() = parentCompletion
7373
or
74-
(
75-
childCompletion = parentCompletion
76-
or
77-
// needed for `let` expressions
78-
childCompletion.(MatchCompletion).getValue() =
79-
parentCompletion.(BooleanCompletion).getValue()
80-
) and
74+
childCompletion = parentCompletion and
8175
(
8276
child = parent.(BinaryLogicalOperation).getAnOperand()
8377
or
@@ -92,6 +86,9 @@ module ConditionalCompletionSplitting {
9286
or
9387
child = parent.(PatternTrees::PostOrderPatTree).getPat(_)
9488
)
89+
or
90+
child = parent.(LetExpr).getPat() and
91+
childCompletion.(MatchCompletion).getValue() = parentCompletion.(BooleanCompletion).getValue()
9592
}
9693
}
9794

rust/ql/lib/codeql/rust/dataflow/Ssa.qll

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,16 @@ module Ssa {
194194
ae.getRhs() = value
195195
)
196196
or
197-
exists(LetStmtCfgNode ls |
198-
ls.getPat().(IdentPatCfgNode).getName() = write and
199-
ls.getInitializer() = value
197+
exists(IdentPatCfgNode pat | pat.getName() = write |
198+
exists(LetStmtCfgNode ls |
199+
pat = ls.getPat() and
200+
ls.getInitializer() = value
201+
)
202+
or
203+
exists(LetExprCfgNode le |
204+
pat = le.getPat() and
205+
le.getScrutinee() = value
206+
)
200207
)
201208
}
202209

rust/ql/lib/codeql/rust/dataflow/internal/DataFlowImpl.qll

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,12 @@ module LocalFlow {
241241
nodeTo.getCfgNode() = s.getPat()
242242
)
243243
or
244+
// An edge from the right-hand side of a let expression to the left-hand side.
245+
exists(LetExprCfgNode e |
246+
nodeFrom.getCfgNode() = e.getScrutinee() and
247+
nodeTo.getCfgNode() = e.getPat()
248+
)
249+
or
244250
exists(IdentPatCfgNode p |
245251
not p.isRef() and
246252
nodeFrom.getCfgNode() = p and
@@ -379,6 +385,8 @@ module RustDataFlow implements InputSig<Location> {
379385
predicate neverSkipInPathGraph(Node node) {
380386
node.(Node::Node).getCfgNode() = any(LetStmtCfgNode s).getPat()
381387
or
388+
node.(Node::Node).getCfgNode() = any(LetExprCfgNode e).getPat()
389+
or
382390
node.(Node::Node).getCfgNode() = any(AssignmentExprCfgNode a).getLhs()
383391
or
384392
exists(MatchExprCfgNode match |
@@ -899,6 +907,12 @@ module VariableCapture {
899907
v.getPat() = ls.getPat().getPat() and
900908
ls.getInitializer() = source
901909
)
910+
or
911+
exists(LetExprCfgNode le |
912+
this = le and
913+
v.getPat() = le.getPat().getPat() and
914+
le.getScrutinee() = source
915+
)
902916
}
903917

904918
CapturedVariable getVariable() { result = v }

rust/ql/lib/codeql/rust/elements/internal/VariableImpl.qll

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ module Impl {
3636
ClosureBodyScope() { this = any(ClosureExpr ce).getBody() }
3737
}
3838

39+
class IfExprScope extends VariableScope, IfExpr { }
40+
41+
class WhileExprScope extends VariableScope, WhileExpr { }
42+
3943
private Pat getAPatAncestor(Pat p) {
4044
(p instanceof IdentPat or p instanceof OrPat) and
4145
exists(Pat p0 | result = p0.getParentPat() |
@@ -152,8 +156,14 @@ module Impl {
152156
/** Gets the `let` statement that introduces this variable, if any. */
153157
LetStmt getLetStmt() { this.getPat() = result.getPat() }
154158

159+
/** Gets the `let` expression that introduces this variable, if any. */
160+
LetExpr getLetExpr() { this.getPat() = result.getPat() }
161+
155162
/** Gets the initial value of this variable, if any. */
156-
Expr getInitializer() { result = this.getLetStmt().getInitializer() }
163+
Expr getInitializer() {
164+
result = this.getLetStmt().getInitializer() or
165+
result = this.getLetExpr().getScrutinee()
166+
}
157167

158168
/** Holds if this variable is captured. */
159169
predicate isCaptured() { this.getAnAccess().isCapture() }
@@ -193,15 +203,60 @@ module Impl {
193203
string getName() { result = name_ }
194204
}
195205

206+
private AstNode getElseBranch(
207+
AstNode elseParentParent, int index, AstNode elseParent, int elseIndex
208+
) {
209+
elseParent = getImmediateChild(elseParentParent, index) and
210+
result = getImmediateChild(elseParent, elseIndex) and
211+
(
212+
result = elseParent.(LetStmt).getLetElse()
213+
or
214+
result = elseParent.(IfExpr).getElse()
215+
)
216+
}
217+
218+
private AstNode getLoopBody(LoopingExpr loop) { result = loop.getLoopBody() }
219+
220+
pragma[nomagic]
221+
private Element getImmediateChildAdj(Element e, int preOrd, int index, int postOrd) {
222+
result = getImmediateChild(e, index) and
223+
preOrd = 0 and
224+
postOrd = 0 and
225+
not result = getElseBranch(_, _, e, index) and
226+
not result = getLoopBody(e)
227+
or
228+
result = getElseBranch(e, index, _, _) and
229+
preOrd = 0 and
230+
postOrd = -1
231+
or
232+
result = getLoopBody(e) and
233+
index = 0 and
234+
preOrd = 1 and
235+
postOrd = 0
236+
}
237+
238+
pragma[nomagic]
239+
private Element getImmediateChildAdj(Element e, int index) {
240+
result =
241+
rank[index + 1](Element res, int i, int preOrd, int postOrd |
242+
res = getImmediateChildAdj(e, preOrd, i, postOrd)
243+
|
244+
res order by preOrd, i, postOrd
245+
)
246+
}
247+
248+
private Element getImmediateParentAdj(Element e) { e = getImmediateChildAdj(result, _) }
249+
196250
private AstNode getAnAncestorInVariableScope(AstNode n) {
197251
(
198252
n instanceof Pat or
199253
n instanceof VariableAccessCand or
200254
n instanceof LetStmt or
255+
n instanceof LetExpr or
201256
n instanceof VariableScope
202257
) and
203258
exists(AstNode n0 |
204-
result = getImmediateParent(n0) or
259+
result = getImmediateParentAdj(n0) or
205260
result = n0.(FormatTemplateVariableAccess).getArgument().getParent().getParent()
206261
|
207262
n0 = n
@@ -243,31 +298,32 @@ module Impl {
243298
this instanceof VariableScope or
244299
this instanceof VariableAccessCand or
245300
this instanceof LetStmt or
246-
getImmediateChild(this, _) instanceof RelevantElement
301+
this instanceof LetExpr or
302+
getImmediateChildAdj(this, _) instanceof RelevantElement
247303
}
248304

249305
pragma[nomagic]
250-
private RelevantElement getChild(int index) { result = getImmediateChild(this, index) }
306+
private RelevantElement getChild(int index) { result = getImmediateChildAdj(this, index) }
251307

252308
pragma[nomagic]
253-
private RelevantElement getImmediateChildMin(int index) {
309+
private RelevantElement getImmediateChildAdjMin(int index) {
254310
// A child may have multiple positions for different accessors,
255311
// so always use the first
256312
result = this.getChild(index) and
257313
index = min(int i | result = this.getChild(i) | i)
258314
}
259315

260316
pragma[nomagic]
261-
RelevantElement getImmediateChild(int index) {
317+
RelevantElement getImmediateChildAdj(int index) {
262318
result =
263-
rank[index + 1](Element res, int i | res = this.getImmediateChildMin(i) | res order by i)
319+
rank[index + 1](Element res, int i | res = this.getImmediateChildAdjMin(i) | res order by i)
264320
}
265321

266322
pragma[nomagic]
267323
RelevantElement getImmediateLastChild() {
268324
exists(int last |
269-
result = this.getImmediateChild(last) and
270-
not exists(this.getImmediateChild(last + 1))
325+
result = this.getImmediateChildAdj(last) and
326+
not exists(this.getImmediateChildAdj(last + 1))
271327
)
272328
}
273329
}
@@ -288,13 +344,13 @@ module Impl {
288344
|
289345
// first child of a previously numbered node
290346
result = getPreOrderNumbering(scope, parent) + 1 and
291-
n = parent.getImmediateChild(0)
347+
n = parent.getImmediateChildAdj(0)
292348
or
293349
// non-first child of a previously numbered node
294350
exists(RelevantElement child, int i |
295351
result = getLastPreOrderNumbering(scope, child) + 1 and
296-
child = parent.getImmediateChild(i) and
297-
n = parent.getImmediateChild(i + 1)
352+
child = parent.getImmediateChildAdj(i) and
353+
n = parent.getImmediateChildAdj(i + 1)
298354
)
299355
)
300356
}
@@ -309,7 +365,7 @@ module Impl {
309365
result = getPreOrderNumbering(scope, leaf) and
310366
leaf != scope and
311367
(
312-
not exists(leaf.getImmediateChild(_))
368+
not exists(leaf.getImmediateChildAdj(_))
313369
or
314370
leaf instanceof VariableScope
315371
)
@@ -331,7 +387,7 @@ module Impl {
331387
/**
332388
* Holds if `v` is named `name` and is declared inside variable scope
333389
* `scope`. The pre-order numbering of the binding site of `v`, amongst
334-
* all nodes nester under `scope`, is `ord`.
390+
* all nodes nested under `scope`, is `ord`.
335391
*/
336392
private predicate variableDeclInScope(Variable v, VariableScope scope, string name, int ord) {
337393
name = v.getText() and
@@ -354,25 +410,19 @@ module Impl {
354410
ord = getLastPreOrderNumbering(scope, let) + 1
355411
)
356412
or
357-
exists(IfExpr ie, LetExpr let |
413+
exists(LetExpr let |
358414
let.getPat() = pat and
359-
ie.getCondition() = let and
360-
scope = ie.getThen() and
361-
ord = getPreOrderNumbering(scope, scope)
415+
scope = getEnclosingScope(let) and
416+
// for `let` expressions, variables are bound _after_ the statement, i.e.
417+
// not in the RHS
418+
ord = getLastPreOrderNumbering(scope, let) + 1
362419
)
363420
or
364421
exists(ForExpr fe |
365422
fe.getPat() = pat and
366423
scope = fe.getLoopBody() and
367424
ord = getPreOrderNumbering(scope, scope)
368425
)
369-
or
370-
exists(WhileExpr we, LetExpr let |
371-
let.getPat() = pat and
372-
we.getCondition() = let and
373-
scope = we.getLoopBody() and
374-
ord = getPreOrderNumbering(scope, scope)
375-
)
376426
)
377427
)
378428
}
@@ -612,7 +662,7 @@ module Impl {
612662
or
613663
exists(Expr mid |
614664
assignmentExprDescendant(mid) and
615-
getImmediateParent(e) = mid and
665+
getImmediateParentAdj(e) = mid and
616666
not mid instanceof DerefExpr and
617667
not mid instanceof FieldExpr and
618668
not mid instanceof IndexExpr

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,11 @@ private module CertainTypeInference {
328328
let.getInitializer() = n2
329329
)
330330
or
331+
exists(LetExpr let |
332+
let.getPat() = n1 and
333+
let.getScrutinee() = n2
334+
)
335+
or
331336
n1 = n2.(ParenExpr).getExpr()
332337
)
333338
or
@@ -466,11 +471,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
466471
or
467472
n1 = n2.(MatchExpr).getAnArm().getExpr()
468473
or
469-
exists(LetExpr let |
470-
n1 = let.getScrutinee() and
471-
n2 = let.getPat()
472-
)
473-
or
474474
exists(MatchExpr me |
475475
n1 = me.getScrutinee() and
476476
n2 = me.getAnArm().getPat()

0 commit comments

Comments
 (0)