-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][tblgen] Add custom parsing and printing within struct #133939
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Pinging a couple of people that have reviewed recent PRs to these files: @River707 @joker-eph, could you help review this PR? |
3d95ba3
to
8e20e50
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Jorn Tuyls (jtuyls) ChangesThis PR implements utilities to parse and print a comma-separated list of key-value pairs, similar to the From the docs: This enables defining custom struct parsing and printing functions if the Full diff: https://github.com/llvm/llvm-project/pull/133939.diff 6 Files Affected:
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 25c7d15eb8ed5..d77f4147744e1 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -238,6 +238,29 @@ class AsmPrinter {
void printDimensionList(ArrayRef<int64_t> shape);
+ //===----------------------------------------------------------------------===//
+ // Struct Printing
+ //===----------------------------------------------------------------------===//
+
+ /// Print a comma-separated list of key-value pairs using the provided
+ /// `keywords` and corresponding printing functions. This performs similar
+ /// printing as the the assembly format's `struct` directive printer, but
+ /// allows bringing in custom printers for fields.
+ ///
+ /// Example:
+ /// <
+ /// foo = foo_value,
+ /// bar = bar_value,
+ /// ...
+ /// >
+ virtual void
+ printStruct(ArrayRef<StringRef> keywords,
+ ArrayRef<llvm::function_ref<void(AsmPrinter &p)>> printFuncs);
+
+ //===----------------------------------------------------------------------===//
+ // Cyclic Printing
+ //===----------------------------------------------------------------------===//
+
/// Class used to automatically end a cyclic region on destruction.
class CyclicPrintReset {
public:
@@ -1409,6 +1432,26 @@ class AsmParser {
return CyclicParseReset(this);
}
+ //===----------------------------------------------------------------------===//
+ // Struct Parsing
+ //===----------------------------------------------------------------------===//
+
+ /// Parse a comma-separated list of key-value pairs with a specified
+ /// delimiter. This performs similar parsing as the the assembly format
+ /// `struct` directive parser with custom delimiter and/or field parsing. The
+ /// variables are printed in the order they are specified in the argument list
+ /// but can be parsed in any order.
+ ///
+ /// Example:
+ /// <
+ /// foo = something_parsed_by_a_custom_parser,
+ /// bar = something_parsed_by_a_different_custom_parser,
+ /// ...
+ /// >
+ virtual ParseResult
+ parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
+ ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) = 0;
+
protected:
/// Parse a handle to a resource within the assembly format for the given
/// dialect.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 1f8fbfdd93568..cff3f5402dd79 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -570,6 +570,51 @@ class AsmParserImpl : public BaseT {
parser.getState().cyclicParsingStack.pop_back();
}
+ //===----------------------------------------------------------------------===//
+ // Struct Parsing
+ //===----------------------------------------------------------------------===//
+
+ /// Parse a comma-separated list of key-value pairs with a specified
+ /// delimiter.
+ ParseResult
+ parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
+ ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) override {
+ assert(keywords.size() == parseFuncs.size());
+ auto keyError = [&]() -> ParseResult {
+ InFlightDiagnostic parseError =
+ emitError(getCurrentLocation(), "expected one of: ");
+ llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
+ parseError << '`' << kw << '`';
+ });
+ return parseError;
+ };
+ SmallVector<bool> seen(keywords.size(), false);
+ DenseMap<StringRef, size_t> keywordToIndex;
+ for (auto &&[idx, keyword] : llvm::enumerate(keywords))
+ keywordToIndex[keyword] = idx;
+ return parseCommaSeparatedList(
+ delimiter,
+ [&]() -> ParseResult {
+ StringRef keyword;
+ if (failed(parseOptionalKeyword(&keyword)))
+ return keyError();
+ if (!keywordToIndex.contains(keyword))
+ return keyError();
+ size_t idx = keywordToIndex[keyword];
+ if (seen[idx]) {
+ return emitError(getCurrentLocation(), "duplicated `")
+ << keyword << "` entry";
+ }
+ if (failed(parseEqual()))
+ return failure();
+ if (failed(parseFuncs[idx]()))
+ return failure();
+ seen[idx] = true;
+ return success();
+ },
+ "parse struct");
+ }
+
//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5b5ec841917e7..7814d8f2cab18 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3002,6 +3002,28 @@ void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
detail::printDimensionList(getStream(), shape);
}
+//===----------------------------------------------------------------------===//
+// Struct Printing
+//===----------------------------------------------------------------------===//
+
+/// Print a comma-separated list of key-value pairs.
+void AsmPrinter::printStruct(
+ ArrayRef<StringRef> keywords,
+ ArrayRef<llvm::function_ref<void(AsmPrinter &p)>> printFuncs) {
+ DenseMap<StringRef, llvm::function_ref<void(AsmPrinter & p)>> keywordToFunc;
+ for (auto &&[kw, printFunc] : llvm::zip(keywords, printFuncs))
+ keywordToFunc[kw] = printFunc;
+ auto &os = getStream();
+ llvm::interleaveComma(keywords, os, [&](StringRef kw) {
+ os << kw << " = ";
+ keywordToFunc[kw](*this);
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// Cyclic Printing
+//===----------------------------------------------------------------------===//
+
LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
return impl->pushCyclicPrinting(opaquePointer);
}
diff --git a/mlir/test/IR/custom-struct-attr-roundtrip.mlir b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
new file mode 100644
index 0000000000000..68f69f99b86a3
--- /dev/null
+++ b/mlir/test/IR/custom-struct-attr-roundtrip.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
+
+// CHECK-LABEL: @test_struct_attr_roundtrip
+func.func @test_struct_attr_roundtrip() -> () {
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+ "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
+ "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+ "test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
+ // CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
+ "test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
+ return
+}
+
+// -----
+
+// Verify all keywords must be provided. All missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
+
+// -----
+
+// Verify all keywords must be provided. `type_str` missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
+
+// -----
+
+// Verify all keywords must be provided. `value` missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
+
+// -----
+
+// Verify invalid keyword provided.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
+"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
+
+// -----
+
+// Verify duplicated keyword provided.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{duplicated `type_str` entry}}
+"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
+
+// -----
+
+// Verify equals missing.
+
+// expected-error @below {{failed parsing `TestCustomStructAttr`}}
+// expected-error @below {{expected '='}}
+"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index fc2d77af29f12..2dae52ab7449c 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -369,6 +369,15 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
}];
}
+// Test AsmParser::parseStruct and AsmPrinter::printStruct APIs through the custom
+// parser and printer.
+def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
+ let mnemonic = "custom_struct";
+ let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
+ OptionalParameter<"mlir::ArrayAttr">:$opt_value);
+ let hasCustomAssemblyFormat = 1;
+}
+
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
let mnemonic = "nested_polynomial";
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 057d9fb4a215f..89c7c527a2247 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -316,6 +316,43 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
return success();
}
+//===----------------------------------------------------------------------===//
+// TestCustomStructAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestCustomStructAttr::parse(AsmParser &p, Type type) {
+ std::string typeStr;
+ int64_t value;
+ FailureOr<ArrayAttr> optValue;
+ if (failed(p.parseStruct(AsmParser::Delimiter::LessGreater,
+ {"type_str", "value", "opt_value"},
+ {[&]() { return p.parseString(&typeStr); },
+ [&]() { return p.parseInteger(value); },
+ [&]() {
+ optValue = mlir::FieldParser<ArrayAttr>::parse(p);
+ return success(succeeded(optValue));
+ }}))) {
+ p.emitError(p.getCurrentLocation())
+ << "failed parsing `TestCustomStructAttr`";
+ return {};
+ }
+ return get(p.getContext(), StringAttr::get(p.getContext(), typeStr), value,
+ optValue.value_or(ArrayAttr()));
+}
+
+void TestCustomStructAttr::print(AsmPrinter &p) const {
+ p << "<";
+ p.printStruct(
+ {"type_str", "value"},
+ {[&](AsmPrinter &p) { p.printStrippedAttrOrType(getTypeStr()); },
+ [&](AsmPrinter &p) { p.printStrippedAttrOrType(getValue()); }});
+ if (getOptValue() != ArrayAttr()) {
+ p << ", opt_value = ";
+ p.printStrippedAttrOrType(getOptValue());
+ }
+ p << ">";
+}
+
//===----------------------------------------------------------------------===//
// TestOpAsmAttrInterfaceAttr
//===----------------------------------------------------------------------===//
|
Instead of adding more c++ code, did you consider just expanding what the struct directive in tablegen can support? Why not extend that to support using a custom directive for the individual fields? This is generally what we have done when the declarative form is lacking something, try and extend that first. |
Yes, I briefly thought about that and tried something like
The assumption that this is a larger change might be misguided though as I am not too familiar with tablegen (yet). Additionally, I didn't see these two approaches as mutually exclusive as some fully custom printer/parser that can't be represented with the above directives might want to reuse these utilities as well for a part of it. Anyway, if the declarative form is the way to go or if you want both implementations, I am happy to implement the nested directives as well. |
The general thing for the assembly formats is that we should always to try to push for and evolve the tablegen side first, you should only default back to C++ when you really have to. Based on my observation of this MR, I would have expected we could just extend struct directive to support the very limited additional case of: a custom directive with a single parameter (aside from
For one of the cases in the StableHlo link above (though you have quite a few similar ops), you could do something like:
Then in your code, you only have one custom directive (for dim parsing/printing), and your done (your attr parser/printer are now declarative). That feels much better, and for your case would remove a lot of the C++ (in your case, there are a few ops which would be simplified just from that). |
Yes, this would remove a lot of code in most cases. Let me give this a shot. |
8e20e50
to
faa7628
Compare
@River707 I implemented the declarative form. Could you have a look? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on the tablegen side! Added some initial comments
faa7628
to
6d8414d
Compare
57da6a9
to
ee144db
Compare
@River707 Could you have another look? |
Thanks for the ping! LGTM. One last comment: Can you update the documentation for |
534ff00
to
c464edb
Compare
Yes, good point, I updated the documentation now. |
@River707 Could you help merge this? |
c464edb
to
074b02b
Compare
@joker-eph Could you have another look and help merge if this looks good? |
@River707 @joker-eph Could someone please help merge this if it looks good? |
Sorry, I had missed your update last week! |
@jtuyls @joker-eph I've landed 441b683 to fix a warning from this PR. Would you mind checking to see if you actually don't intend to use
|
…3939) This PR extends the `struct` directive in tablegen to support nested `custom` directives. Note that this assumes/verifies that that `custom` directive has a single parameter. This enables defining custom field parsing and printing functions if the `struct` directive doesn't suffice. There is some existing potential downstream usage for it: https://github.com/openxla/stablehlo/blob/a3c7de92425e8035437dae67ab2318a82eca79a1/stablehlo/dialect/StablehloOps.cpp#L3102
…3939) This PR extends the `struct` directive in tablegen to support nested `custom` directives. Note that this assumes/verifies that that `custom` directive has a single parameter. This enables defining custom field parsing and printing functions if the `struct` directive doesn't suffice. There is some existing potential downstream usage for it: https://github.com/openxla/stablehlo/blob/a3c7de92425e8035437dae67ab2318a82eca79a1/stablehlo/dialect/StablehloOps.cpp#L3102
…3939) This PR extends the `struct` directive in tablegen to support nested `custom` directives. Note that this assumes/verifies that that `custom` directive has a single parameter. This enables defining custom field parsing and printing functions if the `struct` directive doesn't suffice. There is some existing potential downstream usage for it: https://github.com/openxla/stablehlo/blob/a3c7de92425e8035437dae67ab2318a82eca79a1/stablehlo/dialect/StablehloOps.cpp#L3102
Thanks for the fix and heads up on the change! I think it's fine as
|
…3939) This PR extends the `struct` directive in tablegen to support nested `custom` directives. Note that this assumes/verifies that that `custom` directive has a single parameter. This enables defining custom field parsing and printing functions if the `struct` directive doesn't suffice. There is some existing potential downstream usage for it: https://github.com/openxla/stablehlo/blob/a3c7de92425e8035437dae67ab2318a82eca79a1/stablehlo/dialect/StablehloOps.cpp#L3102
This PR extends the
struct
directive in tablegen to support nestedcustom
directives. Note that this assumes/verifies that thatcustom
directive has a single parameter.This enables defining custom field parsing and printing functions if the
struct
directive doesn't suffice. There is some existing potential downstream usage for it: https://github.com/openxla/stablehlo/blob/a3c7de92425e8035437dae67ab2318a82eca79a1/stablehlo/dialect/StablehloOps.cpp#L3102