@@ -7,22 +7,53 @@ private import codeql.ruby.controlflow.CfgNodes
77private import codeql.ruby.dataflow.SSA
88private import codeql.ruby.ast.internal.Constant
99private import codeql.ruby.InclusionTests
10+ private import codeql.ruby.ast.internal.Literal
1011
11- private predicate stringConstCompare ( CfgNodes:: ExprCfgNode g , CfgNode e , boolean branch ) {
12+ cached
13+ private predicate stringConstCompare ( CfgNodes:: AstCfgNode guard , CfgNode testedNode , boolean branch ) {
1214 exists ( CfgNodes:: ExprNodes:: ComparisonOperationCfgNode c |
13- c = g and
15+ c = guard and
1416 exists ( CfgNodes:: ExprNodes:: StringLiteralCfgNode strLitNode |
15- c .getExpr ( ) instanceof EqExpr and branch = true
17+ // Only consider strings without any interpolations
18+ not strLitNode .getExpr ( ) .getComponent ( _) instanceof StringInterpolationComponent and
19+ c .getExpr ( ) instanceof EqExpr and
20+ branch = true
1621 or
1722 c .getExpr ( ) instanceof CaseEqExpr and branch = true
1823 or
1924 c .getExpr ( ) instanceof NEExpr and branch = false
2025 |
21- c .getLeftOperand ( ) = strLitNode and c .getRightOperand ( ) = e
26+ c .getLeftOperand ( ) = strLitNode and c .getRightOperand ( ) = testedNode
2227 or
23- c .getLeftOperand ( ) = e and c .getRightOperand ( ) = strLitNode
28+ c .getLeftOperand ( ) = testedNode and c .getRightOperand ( ) = strLitNode
2429 )
2530 )
31+ or
32+ stringConstCaseCompare ( guard , testedNode , branch )
33+ or
34+ exists ( CfgNodes:: ExprNodes:: BinaryOperationCfgNode g |
35+ g = guard and
36+ stringConstCompareOr ( guard , branch ) and
37+ stringConstCompare ( g .getLeftOperand ( ) , testedNode , _)
38+ )
39+ }
40+
41+ /**
42+ * Holds if `guard` is an `or` expression whose operands are string comparison guards.
43+ * For example:
44+ *
45+ * ```rb
46+ * x == "foo" or x == "bar"
47+ * ```
48+ */
49+ private predicate stringConstCompareOr (
50+ CfgNodes:: ExprNodes:: BinaryOperationCfgNode guard , boolean branch
51+ ) {
52+ guard .getExpr ( ) instanceof LogicalOrExpr and
53+ branch = true and
54+ forall ( CfgNode innerGuard | innerGuard = guard .getAnOperand ( ) |
55+ stringConstCompare ( innerGuard , any ( Ssa:: Definition def ) .getARead ( ) , branch )
56+ )
2657}
2758
2859/**
@@ -72,10 +103,13 @@ deprecated class StringConstCompare extends DataFlow::BarrierGuard,
72103 }
73104}
74105
75- private predicate stringConstArrayInclusionCall ( CfgNodes:: ExprCfgNode g , CfgNode e , boolean branch ) {
106+ cached
107+ private predicate stringConstArrayInclusionCall (
108+ CfgNodes:: AstCfgNode guard , CfgNode testedNode , boolean branch
109+ ) {
76110 exists ( InclusionTest t |
77- t .asExpr ( ) = g and
78- e = t .getContainedNode ( ) .asExpr ( ) and
111+ t .asExpr ( ) = guard and
112+ testedNode = t .getContainedNode ( ) .asExpr ( ) and
79113 branch = t .getPolarity ( )
80114 |
81115 exists ( ExprNodes:: ArrayLiteralCfgNode arr |
@@ -132,3 +166,68 @@ deprecated class StringConstArrayInclusionCall extends DataFlow::BarrierGuard,
132166
133167 override predicate checks ( CfgNode expr , boolean branch ) { expr = checkedNode and branch = true }
134168}
169+
170+ /**
171+ * A validation of a value by comparing with a constant string via a `case`
172+ * expression. For example:
173+ *
174+ * ```rb
175+ * name = params[:user_name]
176+ * case name
177+ * when "alice"
178+ * User.find_by("username = #{name}")
179+ * when *["bob", "charlie"]
180+ * User.find_by("username = #{name}")
181+ * when "dave", "eve" # this is not yet recognised as a barrier guard
182+ * User.find_by("username = #{name}")
183+ * end
184+ * ```
185+ */
186+ private predicate stringConstCaseCompare (
187+ CfgNodes:: AstCfgNode guard , CfgNode testedNode , boolean branch
188+ ) {
189+ branch = true and
190+ exists ( CfgNodes:: ExprNodes:: CaseExprCfgNode case |
191+ case .getValue ( ) = testedNode and
192+ (
193+ guard =
194+ any ( CfgNodes:: ExprNodes:: WhenClauseCfgNode branchNode |
195+ branchNode = case .getBranch ( _) and
196+ // For simplicity, consider patterns that contain only string literals or arrays of string literals
197+ forall ( ExprCfgNode pattern | pattern = branchNode .getPattern ( _) |
198+ // when "foo"
199+ // when "foo", "bar"
200+ pattern instanceof ExprNodes:: StringLiteralCfgNode
201+ or
202+ pattern =
203+ any ( CfgNodes:: ExprNodes:: SplatExprCfgNode splat |
204+ // when *["foo", "bar"]
205+ forex ( ExprCfgNode elem |
206+ elem = splat .getOperand ( ) .( ExprNodes:: ArrayLiteralCfgNode ) .getAnArgument ( )
207+ |
208+ elem instanceof ExprNodes:: StringLiteralCfgNode
209+ )
210+ or
211+ // when *some_var
212+ // when *SOME_CONST
213+ exists ( ExprNodes:: ArrayLiteralCfgNode arr |
214+ isArrayConstant ( splat .getOperand ( ) , arr ) and
215+ forall ( ExprCfgNode elem | elem = arr .getAnArgument ( ) |
216+ elem instanceof ExprNodes:: StringLiteralCfgNode
217+ )
218+ )
219+ )
220+ )
221+ )
222+ or
223+ // in "foo"
224+ exists (
225+ CfgNodes:: ExprNodes:: InClauseCfgNode branchNode , ExprNodes:: StringLiteralCfgNode pattern
226+ |
227+ branchNode = case .getBranch ( _) and
228+ pattern = branchNode .getPattern ( ) and
229+ guard = pattern
230+ )
231+ )
232+ )
233+ }
0 commit comments