Skip to content

Commit 35b1978

Browse files
committed
Accept stringref arguments to .extract()
1 parent b3b897b commit 35b1978

File tree

6 files changed

+33
-22
lines changed

6 files changed

+33
-22
lines changed

mlir/lib/Tools/mlir-query/MatchersInternal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,10 @@ class DynMatcher {
7979
DynMatcher *clone() const { return new DynMatcher(*this); }
8080

8181
void setExtract(bool extractFunction) { ExtractFunction = extractFunction; };
82+
void setFunctionName(StringRef functionName) { FunctionName = functionName; };
8283

8384
bool getExtract() const { return ExtractFunction; };
85+
StringRef getFunctionName() const { return FunctionName; };
8486

8587
private:
8688
MLIRNodeKind SupportedKind;
@@ -91,6 +93,7 @@ class DynMatcher {
9193
MLIRNodeKind RestrictKind;
9294
llvm::IntrusiveRefCntPtr<DynMatcherInterface> Implementation;
9395
bool ExtractFunction;
96+
StringRef FunctionName;
9497
};
9598

9699
// Wrapper of a MatcherInterface<T> *

mlir/lib/Tools/mlir-query/Parser.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,11 +324,13 @@ bool Parser::parseMatcherExpressionImpl(VariantValue *Value) {
324324
}
325325

326326
bool extractFunction = false;
327+
StringRef functionName;
327328
if (Tokenizer->peekNextToken().Kind == TokenInfo::TK_Period) {
328329
// Parse ".extract()"
329330
Tokenizer->consumeNextToken(); // consume the period.
330331
const TokenInfo ExtractToken = Tokenizer->consumeNextToken();
331332
const TokenInfo OpenToken = Tokenizer->consumeNextToken();
333+
const TokenInfo NameToken = Tokenizer->consumeNextToken();
332334
const TokenInfo CloseToken = Tokenizer->consumeNextToken();
333335

334336
// TODO: We could use different error codes for each/some to be more
@@ -342,18 +344,25 @@ bool Parser::parseMatcherExpressionImpl(VariantValue *Value) {
342344
Error->addError(OpenToken.Range, Error->ET_ParserMalformedBindExpr);
343345
return false;
344346
}
347+
if (NameToken.Kind != TokenInfo::TK_Literal ||
348+
!NameToken.Value.isString()) {
349+
Error->addError(NameToken.Range, Error->ET_ParserMalformedBindExpr);
350+
return false;
351+
}
345352
if (CloseToken.Kind != TokenInfo::TK_CloseParen) {
346353
Error->addError(CloseToken.Range, Error->ET_ParserMalformedBindExpr);
347354
return false;
348355
}
356+
functionName = NameToken.Value.getString();
349357
extractFunction = true;
358+
LLVM_DEBUG(DBGS() << "Function Name: " << functionName << "\n");
350359
}
351360

352361
// Merge the start and end infos.
353362
SourceRange MatcherRange = NameToken.Range;
354363
MatcherRange.End = EndToken.Range.End;
355-
DynMatcher *Result = S->actOnMatcherExpression(NameToken.Text, MatcherRange,
356-
extractFunction, Args, Error);
364+
DynMatcher *Result = S->actOnMatcherExpression(
365+
NameToken.Text, MatcherRange, extractFunction, functionName, Args, Error);
357366

358367
if (Result == nullptr) {
359368
// TODO: Add appropriate error.
@@ -415,10 +424,11 @@ class RegistrySema : public Parser::Sema {
415424
DynMatcher *actOnMatcherExpression(StringRef MatcherName,
416425
const SourceRange &NameRange,
417426
bool ExtractFunction,
427+
StringRef FunctionName,
418428
ArrayRef<ParserValue> Args,
419429
Diagnostics *Error) override {
420-
return Registry::constructMatcherWrapper(MatcherName, NameRange,
421-
ExtractFunction, Args, Error);
430+
return Registry::constructMatcherWrapper(
431+
MatcherName, NameRange, ExtractFunction, FunctionName, Args, Error);
422432
}
423433
};
424434

mlir/lib/Tools/mlir-query/Parser.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,10 @@ class Parser {
6060
// if an error occurred. In that case, Error will contain a
6161
// description of the error.
6262
// The caller takes ownership of the Matcher object returned.
63-
virtual DynMatcher *actOnMatcherExpression(StringRef MatcherName,
64-
const SourceRange &NameRange,
65-
bool ExtractFunction,
66-
ArrayRef<ParserValue> Args,
67-
Diagnostics *Error) = 0;
63+
virtual DynMatcher *
64+
actOnMatcherExpression(StringRef MatcherName, const SourceRange &NameRange,
65+
bool ExtractFunction, StringRef FunctionName,
66+
ArrayRef<ParserValue> Args, Diagnostics *Error) = 0;
6867
};
6968

7069
// Parse a matcher expression, creating matchers from the registry.

mlir/lib/Tools/mlir-query/Query.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ getMatches(Operation *rootOp, const matcher::DynMatcher *matcher) {
5656

5757
// TODO: Only supports operation node type.
5858
Operation *extractFunction(std::vector<matcher::DynTypedNode> &nodes,
59-
OpBuilder builder) {
59+
OpBuilder builder, StringRef functionName) {
6060
std::vector<Operation *> slice;
6161
std::vector<Value> values;
6262

@@ -81,7 +81,7 @@ Operation *extractFunction(std::vector<matcher::DynTypedNode> &nodes,
8181

8282
auto loc = builder.getUnknownLoc();
8383
func::FuncOp funcOp = func::FuncOp::create(
84-
loc, "extracted",
84+
loc, functionName,
8585
builder.getFunctionType(ValueRange(values), resultType));
8686

8787
loc = funcOp.getLoc();
@@ -119,10 +119,11 @@ bool MatchQuery::run(llvm::raw_ostream &OS, QuerySession &QS) const {
119119
auto matches = getMatches(rootOp, matcher);
120120

121121
if (matcher->getExtract()) {
122+
auto functionName = matcher->getFunctionName();
122123
MLIRContext context;
123124
context.loadDialect<func::FuncDialect>();
124125
OpBuilder builder(&context);
125-
Operation *function = extractFunction(matches, builder);
126+
Operation *function = extractFunction(matches, builder, functionName);
126127
OS << "\n\n" << *function << "\n\n\n";
127128
} else {
128129
unsigned MatchCount = 0;

mlir/lib/Tools/mlir-query/Registry.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,15 @@ DynMatcher *Registry::constructMatcher(StringRef MatcherName,
120120
}
121121

122122
// static
123-
DynMatcher *Registry::constructMatcherWrapper(StringRef MatcherName,
124-
const SourceRange &NameRange,
125-
bool ExtractFunction,
126-
ArrayRef<ParserValue> Args,
127-
Diagnostics *Error) {
123+
DynMatcher *Registry::constructMatcherWrapper(
124+
StringRef MatcherName, const SourceRange &NameRange, bool ExtractFunction,
125+
StringRef FunctionName, ArrayRef<ParserValue> Args, Diagnostics *Error) {
128126

129127
DynMatcher *Out = constructMatcher(MatcherName, NameRange, Args, Error);
130128
if (!Out)
131129
return Out;
132130
Out->setExtract(ExtractFunction);
131+
Out->setFunctionName(FunctionName);
133132
return Out;
134133
}
135134

mlir/lib/Tools/mlir-query/Registry.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,10 @@ class Registry {
4545
const SourceRange &NameRange,
4646
ArrayRef<ParserValue> Args,
4747
Diagnostics *Error);
48-
static DynMatcher *constructMatcherWrapper(StringRef MatcherName,
49-
const SourceRange &NameRange,
50-
bool ExtractFunction,
51-
ArrayRef<ParserValue> Args,
52-
Diagnostics *Error);
48+
static DynMatcher *
49+
constructMatcherWrapper(StringRef MatcherName, const SourceRange &NameRange,
50+
bool ExtractFunction, StringRef FunctionName,
51+
ArrayRef<ParserValue> Args, Diagnostics *Error);
5352
};
5453

5554
} // namespace matcher

0 commit comments

Comments
 (0)