Skip to content

Commit e808376

Browse files
committed
C#: Refactor LINQ logic
Factor `ClauseCall` out into three classes to make it clear when the fields `operand` and `declaration` can be `null`.
1 parent 6411d1c commit e808376

File tree

1 file changed

+102
-87
lines changed
  • csharp/extractor/Semmle.Extraction.CSharp/Entities/Expressions

1 file changed

+102
-87
lines changed

csharp/extractor/Semmle.Extraction.CSharp/Entities/Expressions/Query.cs

Lines changed: 102 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -33,58 +33,37 @@ public QueryCall(Context cx, IMethodSymbol method, SyntaxNode clause, IExpressio
3333
/// <summary>
3434
/// Represents a chain of method calls (the operand being recursive).
3535
/// </summary>
36-
class ClauseCall
36+
abstract class Clause
3737
{
38-
public ClauseCall operand;
39-
public IMethodSymbol method;
40-
public readonly List<ExpressionSyntax> arguments = new List<ExpressionSyntax>();
41-
public SyntaxNode node;
42-
public ISymbol declaration;
43-
public SyntaxToken name;
44-
public ISymbol intoDeclaration;
38+
protected readonly IMethodSymbol method;
39+
protected readonly List<ExpressionSyntax> arguments = new List<ExpressionSyntax>();
40+
protected readonly SyntaxNode node;
4541

46-
public ExpressionSyntax Expr => arguments.First();
47-
48-
public ClauseCall WithClause(IMethodSymbol newMethod, SyntaxNode newNode, SyntaxToken newName = default(SyntaxToken), ISymbol newDeclaration = null)
42+
protected Clause(IMethodSymbol method, SyntaxNode node)
4943
{
50-
return new ClauseCall
51-
{
52-
operand = this,
53-
method = newMethod,
54-
node = newNode,
55-
name = newName,
56-
declaration = newDeclaration
57-
};
44+
this.method = method;
45+
this.node = node;
5846
}
5947

60-
public ClauseCall AddArgument(ExpressionSyntax arg)
61-
{
62-
if (arg != null)
63-
arguments.Add(arg);
64-
return this;
65-
}
48+
public ExpressionSyntax Expr => arguments.First();
6649

67-
public ClauseCall WithInto(ISymbol into)
68-
{
69-
intoDeclaration = into;
70-
return this;
71-
}
50+
public CallClause WithCallClause(IMethodSymbol newMethod, SyntaxNode newNode) =>
51+
new CallClause(this, newMethod, newNode);
7252

73-
Expression DeclareRangeVariable(Context cx, IExpressionParentEntity parent, int child, bool getElement)
74-
{
75-
return DeclareRangeVariable(cx, parent, child, getElement, declaration);
76-
}
53+
public LetClause WithLetClause(IMethodSymbol newMethod, SyntaxNode newNode, ISymbol newDeclaration, SyntaxToken newName) =>
54+
new LetClause(this, newMethod, newNode, newDeclaration, newName);
7755

78-
void DeclareIntoVariable(Context cx, IExpressionParentEntity parent, int intoChild, bool getElement)
56+
public Clause AddArgument(ExpressionSyntax arg)
7957
{
80-
if (intoDeclaration != null)
81-
DeclareRangeVariable(cx, parent, intoChild, getElement, intoDeclaration);
58+
if (arg != null)
59+
arguments.Add(arg);
60+
return this;
8261
}
8362

84-
Expression DeclareRangeVariable(Context cx, IExpressionParentEntity parent, int child, bool getElement, ISymbol variableSymbol)
63+
protected Expression DeclareRangeVariable(Context cx, IExpressionParentEntity parent, int child, bool getElement, ISymbol variableSymbol, SyntaxToken name)
8564
{
8665
var type = Type.Create(cx, cx.GetType(Expr));
87-
Semmle.Extraction.Entities.Location nameLoc;
66+
Extraction.Entities.Location nameLoc;
8867

8968
Type declType;
9069
if (getElement)
@@ -115,42 +94,87 @@ Expression DeclareRangeVariable(Context cx, IExpressionParentEntity parent, int
11594
return decl;
11695
}
11796

118-
void PopulateArguments(Context cx, QueryCall callExpr, int child)
97+
protected void PopulateArguments(Context cx, QueryCall callExpr, int child)
11998
{
12099
foreach (var e in arguments)
121100
{
122101
Expression.Create(cx, e, callExpr, child++);
123102
}
124103
}
125104

126-
public Expression Populate(Context cx, IExpressionParentEntity parent, int child)
105+
public abstract Expression Populate(Context cx, IExpressionParentEntity parent, int child);
106+
}
107+
108+
class RangeClause : Clause
109+
{
110+
readonly ISymbol declaration;
111+
readonly SyntaxToken name;
112+
113+
public RangeClause(IMethodSymbol method, SyntaxNode node, ISymbol declaration, SyntaxToken name) : base(method, node)
127114
{
128-
if (declaration != null) // The first "from" clause, or a "let" clause
129-
{
130-
if (operand == null)
131-
{
132-
return DeclareRangeVariable(cx, parent, child, true);
133-
}
134-
else
135-
{
136-
if (method == null)
137-
cx.ModelError(node, "Unable to determine target of query expression");
138-
139-
var callExpr = new QueryCall(cx, method, node, parent, child);
140-
operand.Populate(cx, callExpr, 0);
141-
DeclareRangeVariable(cx, callExpr, 1, false);
142-
PopulateArguments(cx, callExpr, 2);
143-
DeclareIntoVariable(cx, callExpr, 2 + arguments.Count, false);
144-
return callExpr;
145-
}
146-
}
147-
else
148-
{
149-
var callExpr = new QueryCall(cx, method, node, parent, child);
150-
operand.Populate(cx, callExpr, 0);
151-
PopulateArguments(cx, callExpr, 1);
152-
return callExpr;
153-
}
115+
this.declaration = declaration;
116+
this.name = name;
117+
}
118+
119+
public override Expression Populate(Context cx, IExpressionParentEntity parent, int child) =>
120+
DeclareRangeVariable(cx, parent, child, true, declaration, name);
121+
}
122+
123+
class LetClause : Clause
124+
{
125+
readonly Clause operand;
126+
readonly ISymbol declaration;
127+
readonly SyntaxToken name;
128+
ISymbol intoDeclaration;
129+
130+
public LetClause(Clause operand, IMethodSymbol method, SyntaxNode node, ISymbol declaration, SyntaxToken name) : base(method, node)
131+
{
132+
this.operand = operand;
133+
this.declaration = declaration;
134+
this.name = name;
135+
}
136+
137+
public Clause WithInto(ISymbol into)
138+
{
139+
intoDeclaration = into;
140+
return this;
141+
}
142+
143+
void DeclareIntoVariable(Context cx, IExpressionParentEntity parent, int intoChild, bool getElement)
144+
{
145+
if (intoDeclaration != null)
146+
DeclareRangeVariable(cx, parent, intoChild, getElement, intoDeclaration, name);
147+
}
148+
149+
public override Expression Populate(Context cx, IExpressionParentEntity parent, int child)
150+
{
151+
if (method == null)
152+
cx.ModelError(node, "Unable to determine target of query expression");
153+
154+
var callExpr = new QueryCall(cx, method, node, parent, child);
155+
operand.Populate(cx, callExpr, 0);
156+
DeclareRangeVariable(cx, callExpr, 1, false, declaration, name);
157+
PopulateArguments(cx, callExpr, 2);
158+
DeclareIntoVariable(cx, callExpr, 2 + arguments.Count, false);
159+
return callExpr;
160+
}
161+
}
162+
163+
class CallClause : Clause
164+
{
165+
readonly Clause operand;
166+
167+
public CallClause(Clause operand, IMethodSymbol method, SyntaxNode node) : base(method, node)
168+
{
169+
this.operand = operand;
170+
}
171+
172+
public override Expression Populate(Context cx, IExpressionParentEntity parent, int child)
173+
{
174+
var callExpr = new QueryCall(cx, method, node, parent, child);
175+
operand.Populate(cx, callExpr, 0);
176+
PopulateArguments(cx, callExpr, 1);
177+
return callExpr;
154178
}
155179
}
156180

@@ -161,18 +185,12 @@ public Expression Populate(Context cx, IExpressionParentEntity parent, int child
161185
/// <param name="cx">The extraction context.</param>
162186
/// <param name="node">The query expression.</param>
163187
/// <returns>A "syntax tree" of the query.</returns>
164-
static ClauseCall ConstructQueryExpression(Context cx, QueryExpressionSyntax node)
188+
static Clause ConstructQueryExpression(Context cx, QueryExpressionSyntax node)
165189
{
166190
var info = cx.Model(node).GetQueryClauseInfo(node.FromClause);
167191
var method = info.OperationInfo.Symbol as IMethodSymbol;
168192

169-
ClauseCall clauseExpr = new ClauseCall
170-
{
171-
declaration = cx.Model(node).GetDeclaredSymbol(node.FromClause),
172-
name = node.FromClause.Identifier,
173-
method = method,
174-
node = node.FromClause
175-
}.AddArgument(node.FromClause.Expression);
193+
Clause clauseExpr = new RangeClause(method, node.FromClause, cx.Model(node).GetDeclaredSymbol(node.FromClause), node.FromClause.Identifier).AddArgument(node.FromClause.Expression);
176194

177195
foreach (var qc in node.Body.Clauses)
178196
{
@@ -188,39 +206,39 @@ static ClauseCall ConstructQueryExpression(Context cx, QueryExpressionSyntax nod
188206
{
189207
method = cx.Model(node).GetSymbolInfo(ordering).Symbol as IMethodSymbol;
190208

191-
clauseExpr = clauseExpr.WithClause(method, orderByClause).AddArgument(ordering.Expression);
209+
clauseExpr = clauseExpr.WithCallClause(method, orderByClause).AddArgument(ordering.Expression);
192210

193211
if (method == null)
194212
cx.ModelError(ordering, "Could not determine method call for orderby clause");
195213
}
196214
break;
197215
case SyntaxKind.WhereClause:
198216
var whereClause = (WhereClauseSyntax)qc;
199-
clauseExpr = clauseExpr.WithClause(method, whereClause).AddArgument(whereClause.Condition);
217+
clauseExpr = clauseExpr.WithCallClause(method, whereClause).AddArgument(whereClause.Condition);
200218
break;
201219
case SyntaxKind.FromClause:
202220
var fromClause = (FromClauseSyntax)qc;
203221
clauseExpr = clauseExpr.
204-
WithClause(method, fromClause, fromClause.Identifier, cx.Model(node).GetDeclaredSymbol(fromClause)).
222+
WithLetClause(method, fromClause, cx.Model(node).GetDeclaredSymbol(fromClause), fromClause.Identifier).
205223
AddArgument(fromClause.Expression);
206224
break;
207225
case SyntaxKind.LetClause:
208226
var letClause = (LetClauseSyntax)qc;
209-
clauseExpr = clauseExpr.WithClause(method, letClause, letClause.Identifier, cx.Model(node).GetDeclaredSymbol(letClause)).
227+
clauseExpr = clauseExpr.WithLetClause(method, letClause, cx.Model(node).GetDeclaredSymbol(letClause), letClause.Identifier).
210228
AddArgument(letClause.Expression);
211229
break;
212230
case SyntaxKind.JoinClause:
213231
var joinClause = (JoinClauseSyntax)qc;
214232

215-
clauseExpr = clauseExpr.WithClause(method, joinClause, joinClause.Identifier, cx.Model(node).GetDeclaredSymbol(joinClause)).
233+
clauseExpr = clauseExpr.WithLetClause(method, joinClause, cx.Model(node).GetDeclaredSymbol(joinClause), joinClause.Identifier).
216234
AddArgument(joinClause.InExpression).
217235
AddArgument(joinClause.LeftExpression).
218236
AddArgument(joinClause.RightExpression);
219237

220238
if (joinClause.Into != null)
221239
{
222240
var into = cx.Model(node).GetDeclaredSymbol(joinClause.Into);
223-
clauseExpr.WithInto(into);
241+
((LetClause)clauseExpr).WithInto(into);
224242
}
225243

226244
break;
@@ -231,16 +249,13 @@ static ClauseCall ConstructQueryExpression(Context cx, QueryExpressionSyntax nod
231249

232250
method = cx.Model(node).GetSymbolInfo(node.Body.SelectOrGroup).Symbol as IMethodSymbol;
233251

234-
var selectClause = node.Body.SelectOrGroup as SelectClauseSyntax;
235-
var groupClause = node.Body.SelectOrGroup as GroupClauseSyntax;
236-
237-
clauseExpr = new ClauseCall { operand = clauseExpr, method = method, node = node.Body.SelectOrGroup };
252+
clauseExpr = new CallClause(clauseExpr, method, node.Body.SelectOrGroup);
238253

239-
if (selectClause != null)
254+
if (node.Body.SelectOrGroup is SelectClauseSyntax selectClause)
240255
{
241256
clauseExpr.AddArgument(selectClause.Expression);
242257
}
243-
else if (groupClause != null)
258+
else if (node.Body.SelectOrGroup is GroupClauseSyntax groupClause)
244259
{
245260
clauseExpr.
246261
AddArgument(groupClause.GroupExpression).

0 commit comments

Comments
 (0)