Skip to content

Commit a28a718

Browse files
authored
Merge pull request #20814 from aschackmull/guards/wrapper-perf
Guards: Improve join-order for wrapper guards
2 parents bfa3562 + b31dfdd commit a28a718

File tree

1 file changed

+83
-22
lines changed

1 file changed

+83
-22
lines changed

shared/controlflow/codeql/controlflow/Guards.qll

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,8 @@ module Make<
10081008
* wrappers.
10091009
*/
10101010
private module WrapperGuard {
1011+
private import codeql.util.DenseRank
1012+
10111013
final private class FinalExpr = Expr;
10121014

10131015
class ReturnExpr extends FinalExpr {
@@ -1019,21 +1021,58 @@ module Make<
10191021
BasicBlock getBasicBlock() { result = super.getBasicBlock() }
10201022
}
10211023

1024+
private module DenseRankInput implements DenseRankInputSig1 {
1025+
class C = NonOverridableMethod;
1026+
1027+
class Ranked = ReturnExpr;
1028+
1029+
int getRank(NonOverridableMethod m, ReturnExpr ret) {
1030+
m.getAReturnExpr() = ret and
1031+
result = ret.getLocation().getStartLine()
1032+
}
1033+
}
1034+
1035+
private module ReturnExprRank = DenseRank1<DenseRankInput>;
1036+
1037+
private predicate rankedReturnExpr = ReturnExprRank::denseRank/2;
1038+
1039+
private int maxRank(NonOverridableMethod m) {
1040+
result = max(int rnk | exists(rankedReturnExpr(m, rnk)))
1041+
}
1042+
10221043
private predicate relevantCallValue(NonOverridableMethodCall call, GuardValue val) {
10231044
BranchImplies::guardControls(call, val, _, _) or
10241045
ReturnImplies::guardControls(call, val, _, _)
10251046
}
10261047

1027-
predicate relevantReturnValue(NonOverridableMethod m, GuardValue val) {
1048+
/**
1049+
* Holds if a call to `m` having a return value of `retval` is reachable
1050+
* by a chain of implications.
1051+
*/
1052+
predicate relevantReturnValue(NonOverridableMethod m, GuardValue retval) {
10281053
exists(NonOverridableMethodCall call |
1029-
relevantCallValue(call, val) and
1054+
relevantCallValue(call, retval) and
10301055
call.getMethod() = m and
1031-
not val instanceof TException
1056+
not retval instanceof TException
1057+
)
1058+
}
1059+
1060+
/**
1061+
* Holds if a call to `m` having a return value of `retval` is reachable
1062+
* by a chain of implications, and `ret` is a return expression in `m`
1063+
* that could possibly have the value `retval`.
1064+
*/
1065+
predicate relevantReturnExprValue(NonOverridableMethod m, ReturnExpr ret, GuardValue retval) {
1066+
relevantReturnValue(m, retval) and
1067+
ret = m.getAReturnExpr() and
1068+
not exists(GuardValue notRetval |
1069+
exprHasValue(ret, notRetval) and
1070+
disjointValues(notRetval, retval)
10321071
)
10331072
}
10341073

10351074
private predicate returnGuard(Guard guard, GuardValue val) {
1036-
relevantReturnValue(guard.(ReturnExpr).getMethod(), val)
1075+
relevantReturnExprValue(_, guard, val)
10371076
}
10381077

10391078
module ReturnImplies = ImpliesTC<returnGuard/2>;
@@ -1043,28 +1082,58 @@ module Make<
10431082
guard.directlyValueControls(ret.getBasicBlock(), val)
10441083
}
10451084

1085+
private predicate parameterControlsReturnExpr(
1086+
SsaParameterInit param, GuardValue val, ReturnExpr ret
1087+
) {
1088+
exists(Guard g0, GuardValue v0 |
1089+
directlyControlsReturn(g0, v0, ret) and
1090+
BranchImplies::ssaControls(param, val, g0, v0)
1091+
)
1092+
}
1093+
10461094
/**
10471095
* Holds if `ret` is a return expression in a non-overridable method that
10481096
* on a return value of `retval` allows the conclusion that the `ppos`th
10491097
* parameter has the value `val`.
10501098
*/
10511099
private predicate validReturnInCustomGuard(
1052-
ReturnExpr ret, ParameterPosition ppos, GuardValue retval, GuardValue val
1100+
ReturnExpr ret, int rnk, NonOverridableMethod m, ParameterPosition ppos, GuardValue retval,
1101+
GuardValue val
10531102
) {
1054-
exists(NonOverridableMethod m, SsaParameterInit param |
1055-
m.getAReturnExpr() = ret and
1103+
exists(SsaParameterInit param |
1104+
ret = rankedReturnExpr(m, rnk) and
10561105
param.getParameter() = m.getParameter(ppos)
10571106
|
1058-
exists(Guard g0, GuardValue v0 |
1059-
directlyControlsReturn(g0, v0, ret) and
1060-
BranchImplies::ssaControls(param, val, g0, v0) and
1061-
relevantReturnValue(m, retval)
1062-
)
1107+
parameterControlsReturnExpr(param, val, ret) and
1108+
relevantReturnExprValue(m, ret, retval)
10631109
or
10641110
ReturnImplies::ssaControls(param, val, ret, retval)
10651111
)
10661112
}
10671113

1114+
private predicate validReturnInCustomGuardToRank(
1115+
int rnk, NonOverridableMethod m, ParameterPosition ppos, GuardValue retval, GuardValue val
1116+
) {
1117+
// The forall-range has been pushed all the way into
1118+
// `relevantReturnExprValue` and `validReturnInCustomGuard`. This means
1119+
// that this base case ensures that at least one return expression
1120+
// non-vacuously satisfies that it's a valid implication from return
1121+
// value to parameter value.
1122+
validReturnInCustomGuard(_, _, m, ppos, retval, val) and rnk = 0
1123+
or
1124+
validReturnInCustomGuardToRank(rnk - 1, m, ppos, retval, val) and
1125+
rnk <= maxRank(m) and
1126+
forall(ReturnExpr ret |
1127+
ret = rankedReturnExpr(m, rnk) and
1128+
not exists(GuardValue notRetval |
1129+
exprHasValue(ret, notRetval) and
1130+
disjointValues(notRetval, retval)
1131+
)
1132+
|
1133+
validReturnInCustomGuard(ret, rnk, m, ppos, retval, val)
1134+
)
1135+
}
1136+
10681137
private predicate guardDirectlyControlsExit(Guard guard, GuardValue val) {
10691138
exists(BasicBlock bb |
10701139
guard.directlyValueControls(bb, val) and
@@ -1080,15 +1149,7 @@ module Make<
10801149
private NonOverridableMethod wrapperGuard(
10811150
ParameterPosition ppos, GuardValue retval, GuardValue val
10821151
) {
1083-
forex(ReturnExpr ret |
1084-
result.getAReturnExpr() = ret and
1085-
not exists(GuardValue notRetval |
1086-
exprHasValue(ret, notRetval) and
1087-
disjointValues(notRetval, retval)
1088-
)
1089-
|
1090-
validReturnInCustomGuard(ret, ppos, retval, val)
1091-
)
1152+
validReturnInCustomGuardToRank(maxRank(result), result, ppos, retval, val)
10921153
or
10931154
exists(SsaParameterInit param, Guard g0, GuardValue v0 |
10941155
param.getParameter() = result.getParameter(ppos) and
@@ -1166,7 +1227,7 @@ module Make<
11661227
guardChecksDef(guard, param, val, state)
11671228
|
11681229
guard.valueControls(ret.getBasicBlock(), val) and
1169-
relevantReturnValue(m, retval)
1230+
relevantReturnExprValue(m, ret, retval)
11701231
or
11711232
ReturnImplies::guardControls(guard, val, ret, retval)
11721233
)

0 commit comments

Comments
 (0)