-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[CIR] Upstream global initialization for VectorType #137511
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CIR] Upstream global initialization for VectorType #137511
Conversation
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) ChangesThis change adds global initialization for VectorType Issue #136487 Full diff: https://github.com/llvm/llvm-project/pull/137511.diff 7 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
index fb3f7b1632436..624a82762ab18 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
@@ -204,7 +204,7 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
}]>
];
- // Printing and parsing available in CIRDialect.cpp
+ // Printing and parsing available in CIRAttrs.cpp
let hasCustomAssemblyFormat = 1;
// Enable verifier.
@@ -215,6 +215,37 @@ def ConstArrayAttr : CIR_Attr<"ConstArray", "const_array", [TypedAttrInterface]>
}];
}
+//===----------------------------------------------------------------------===//
+// ConstVectorAttr
+//===----------------------------------------------------------------------===//
+
+def ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector",
+ [TypedAttrInterface]> {
+ let summary = "A constant vector from ArrayAttr";
+ let description = [{
+ A CIR vector attribute is an array of literals of the specified attribute
+ types.
+ }];
+
+ let parameters = (ins AttributeSelfTypeParameter<"">:$type,
+ "mlir::ArrayAttr":$elts);
+
+ // Define a custom builder for the type; that removes the need to pass in an
+ // MLIRContext instance, as it can be inferred from the `type`.
+ let builders = [
+ AttrBuilderWithInferredContext<(ins "cir::VectorType":$type,
+ "mlir::ArrayAttr":$elts), [{
+ return $_get(type.getContext(), type, elts);
+ }]>
+ ];
+
+ // Printing and parsing available in CIRAttrs.cpp
+ let hasCustomAssemblyFormat = 1;
+
+ // Enable verifier.
+ let genVerifyDecl = 1;
+}
+
//===----------------------------------------------------------------------===//
// ConstPtrAttr
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
index b9a74e90a5960..6e5c7b8fb51f8 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
@@ -373,8 +373,27 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
elements, typedFiller);
}
case APValue::Vector: {
- cgm.errorNYI("ConstExprEmitter::tryEmitPrivate vector");
- return {};
+ const QualType elementType =
+ destType->castAs<VectorType>()->getElementType();
+ const unsigned numElements = value.getVectorLength();
+
+ SmallVector<mlir::Attribute, 16> elements;
+ elements.reserve(numElements);
+
+ for (unsigned i = 0; i < numElements; ++i) {
+ const mlir::Attribute element =
+ tryEmitPrivateForMemory(value.getVectorElt(i), elementType);
+ if (!element)
+ return {};
+ elements.push_back(element);
+ }
+
+ const auto desiredVecTy =
+ mlir::cast<cir::VectorType>(cgm.convertType(destType));
+
+ return cir::ConstVectorAttr::get(
+ desiredVecTy,
+ mlir::ArrayAttr::get(cgm.getBuilder().getContext(), elements));
}
case APValue::MemberPointer: {
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate member pointer");
diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
index a8d9f6a0e6e9b..b9b27f33207b8 100644
--- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
@@ -299,6 +299,94 @@ void ConstArrayAttr::print(AsmPrinter &printer) const {
printer << ">";
}
+//===----------------------------------------------------------------------===//
+// CIR ConstVectorAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult cir::ConstVectorAttr::verify(
+ function_ref<::mlir::InFlightDiagnostic()> emitError, Type type,
+ ArrayAttr elts) {
+
+ if (!mlir::isa<cir::VectorType>(type)) {
+ return emitError() << "type of cir::ConstVectorAttr is not a "
+ "cir::VectorType: "
+ << type;
+ }
+
+ const auto vecType = mlir::cast<cir::VectorType>(type);
+
+ if (vecType.getSize() != elts.size()) {
+ return emitError()
+ << "number of constant elements should match vector size";
+ }
+
+ // Check if the types of the elements match
+ LogicalResult elementTypeCheck = success();
+ elts.walkImmediateSubElements(
+ [&](Attribute element) {
+ if (elementTypeCheck.failed()) {
+ // An earlier element didn't match
+ return;
+ }
+ auto typedElement = mlir::dyn_cast<TypedAttr>(element);
+ if (!typedElement ||
+ typedElement.getType() != vecType.getElementType()) {
+ elementTypeCheck = failure();
+ emitError() << "constant type should match vector element type";
+ }
+ },
+ [&](Type) {});
+
+ return elementTypeCheck;
+}
+
+Attribute cir::ConstVectorAttr::parse(AsmParser &parser, Type type) {
+ FailureOr<Type> resultType;
+ FailureOr<ArrayAttr> resultValue;
+
+ const SMLoc loc = parser.getCurrentLocation();
+
+ // Parse literal '<'
+ if (parser.parseLess()) {
+ return {};
+ }
+
+ // Parse variable 'value'
+ resultValue = FieldParser<ArrayAttr>::parse(parser);
+ if (failed(resultValue)) {
+ parser.emitError(parser.getCurrentLocation(),
+ "failed to parse ConstVectorAttr parameter 'value' as "
+ "an attribute");
+ return {};
+ }
+
+ if (parser.parseOptionalColon().failed()) {
+ resultType = type;
+ } else {
+ resultType = ::mlir::FieldParser<Type>::parse(parser);
+ if (failed(resultType)) {
+ parser.emitError(parser.getCurrentLocation(),
+ "failed to parse ConstVectorAttr parameter 'type' as "
+ "an MLIR type");
+ return {};
+ }
+ }
+
+ // Parse literal '>'
+ if (parser.parseGreater()) {
+ return {};
+ }
+
+ return parser.getChecked<ConstVectorAttr>(
+ loc, parser.getContext(), resultType.value(), resultValue.value());
+}
+
+void cir::ConstVectorAttr::print(AsmPrinter &printer) const {
+ printer << "<";
+ printer.printStrippedAttrOrType(getElts());
+ printer << ">";
+}
+
//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 939802a3af680..07847d62feadd 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -242,7 +242,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
return success();
}
- if (mlir::isa<cir::ConstArrayAttr>(attrType))
+ if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType))
return success();
assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 102438c2ded02..db331691154e6 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -188,8 +188,9 @@ class CIRAttrToValue {
mlir::Value visit(mlir::Attribute attr) {
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
- .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, cir::ConstPtrAttr,
- cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
+ .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr,
+ cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
+ [&](auto attrT) { return visitCirAttr(attrT); })
.Default([&](auto attrT) { return mlir::Value(); });
}
@@ -197,6 +198,7 @@ class CIRAttrToValue {
mlir::Value visitCirAttr(cir::FPAttr fltAttr);
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
+ mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
mlir::Value visitCirAttr(cir::ZeroAttr attr);
private:
@@ -275,6 +277,33 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
return result;
}
+/// ConstVectorAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
+ const mlir::Type llvmTy = converter->convertType(attr.getType());
+ const mlir::Location loc = parentOp->getLoc();
+
+ SmallVector<mlir::Attribute> mlirValues;
+ for (const mlir::Attribute elementAttr : attr.getElts()) {
+ mlir::Attribute mlirAttr;
+ if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) {
+ mlirAttr = rewriter.getIntegerAttr(
+ converter->convertType(intAttr.getType()), intAttr.getValue());
+ } else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) {
+ mlirAttr = rewriter.getFloatAttr(
+ converter->convertType(floatAttr.getType()), floatAttr.getValue());
+ } else {
+ llvm_unreachable(
+ "vector constant with an element that is neither an int nor a float");
+ }
+ mlirValues.push_back(mlirAttr);
+ }
+
+ return rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, llvmTy,
+ mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy),
+ mlirValues));
+}
+
/// ZeroAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) {
mlir::Location loc = parentOp->getLoc();
@@ -888,7 +917,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
cir::GlobalOp op, mlir::Attribute init,
mlir::ConversionPatternRewriter &rewriter) const {
// TODO: Generalize this handling when more types are needed here.
- assert((isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(init)));
+ assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
+ cir::ZeroAttr>(init)));
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
@@ -941,8 +971,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
op.emitError() << "unsupported initializer '" << init.value() << "'";
return mlir::failure();
}
- } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
- init.value())) {
+ } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
+ cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) {
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
// to the appropriate value.
diff --git a/clang/test/CIR/CodeGen/vector-ext.cpp b/clang/test/CIR/CodeGen/vector-ext.cpp
index 13726edf3d259..7759a32fc1378 100644
--- a/clang/test/CIR/CodeGen/vector-ext.cpp
+++ b/clang/test/CIR/CodeGen/vector-ext.cpp
@@ -31,7 +31,7 @@ vi2 vec_c;
// OGCG: @[[VEC_C:.*]] = global <2 x i32> zeroinitializer
-vd2 d;
+vd2 vec_d;
// CIR: cir.global external @[[VEC_D:.*]] = #cir.zero : !cir.vector<2 x !cir.double>
@@ -39,6 +39,15 @@ vd2 d;
// OGCG: @[[VEC_D:.*]] = global <2 x double> zeroinitializer
+vi4 vec_e = { 1, 2, 3, 4 };
+
+// CIR: cir.global external @[[VEC_E:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
+// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+
+// LLVM: @[[VEC_E:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
+// OGCG: @[[VEC_E:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
void foo() {
vi4 a;
vi3 b;
diff --git a/clang/test/CIR/CodeGen/vector.cpp b/clang/test/CIR/CodeGen/vector.cpp
index 8f9e98fb6b3c0..4c1850141a21c 100644
--- a/clang/test/CIR/CodeGen/vector.cpp
+++ b/clang/test/CIR/CodeGen/vector.cpp
@@ -30,6 +30,15 @@ vll2 c;
// OGCG: @[[VEC_C:.*]] = global <2 x i64> zeroinitializer
+vi4 d = { 1, 2, 3, 4 };
+
+// CIR: cir.global external @[[VEC_D:.*]] = #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> :
+// CIR-SAME: !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
+
+// LLVM: @[[VEC_D:.*]] = dso_local global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
+// OGCG: @[[VEC_D:.*]] = global <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+
void vec_int_test() {
vi4 a;
vd2 b;
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems reasonable to me, but Andy/et-al should do a review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me once the parsing test is added.
const mlir::Attribute element = | ||
tryEmitPrivateForMemory(value.getVectorElt(i), elementType); | ||
if (!element) | ||
return {}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens to any elements that are already in the elements
vector? Do they get cleaned up and deleted properly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As far as I understood, yes the destructor will be called when we early return and will clean the vector
return elementTypeCheck; | ||
} | ||
|
||
Attribute cir::ConstVectorAttr::parse(AsmParser &parser, Type type) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should add a test in clang/test/CIR/IR to verify this parser.
cd4f061
to
68a3d41
Compare
68a3d41
to
a1372b6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/59/builds/16971 Here is the relevant piece of the build log for the reference
|
This change adds global initialization for VectorType Issue llvm#136487
This change adds global initialization for VectorType Issue llvm#136487
This change adds global initialization for VectorType Issue llvm#136487
This change adds global initialization for VectorType Issue llvm#136487
This change adds global initialization for VectorType
Issue #136487