diff --git a/llvm/include/llvm/IR/FixedMetadataKinds.def b/llvm/include/llvm/IR/FixedMetadataKinds.def
index df572e8791e13..90276eae13e4b 100644
--- a/llvm/include/llvm/IR/FixedMetadataKinds.def
+++ b/llvm/include/llvm/IR/FixedMetadataKinds.def
@@ -53,3 +53,4 @@ LLVM_FIXED_MD_KIND(MD_DIAssignID, "DIAssignID", 38)
 LLVM_FIXED_MD_KIND(MD_coro_outside_frame, "coro.outside.frame", 39)
 LLVM_FIXED_MD_KIND(MD_mmra, "mmra", 40)
 LLVM_FIXED_MD_KIND(MD_noalias_addrspace, "noalias.addrspace", 41)
+LLVM_FIXED_MD_KIND(MD_callee_type, "callee_type", 42)
diff --git a/llvm/include/llvm/IR/Metadata.h b/llvm/include/llvm/IR/Metadata.h
index 22ab59be55eb2..c707555a068cc 100644
--- a/llvm/include/llvm/IR/Metadata.h
+++ b/llvm/include/llvm/IR/Metadata.h
@@ -1252,6 +1252,14 @@ class MDNode : public Metadata {
   bool isReplaceable() const { return isTemporary() || isAlwaysReplaceable(); }
   bool isAlwaysReplaceable() const { return getMetadataID() == DIAssignIDKind; }
 
+  /// Return true if this node has at least two operands and operand 1 is an
+  /// MDString whose text ends with ".generalized".
+  bool hasGeneralizedMDString() const {
+    if (getNumOperands() < 2 || !isa<MDString>(getOperand(1)))
+      return false;
+    return cast<MDString>(getOperand(1))->getString().ends_with(".generalized");
+  }
+
   unsigned getNumTemporaryUses() const {
     assert(isTemporary() && "Only for temporaries");
     return Context.getReplaceableUses()->getNumUses();
@@ -1463,6 +1471,10 @@ class MDNode : public Metadata {
                                        const Instruction *BInstr);
   static MDNode *getMergedMemProfMetadata(MDNode *A, MDNode *B);
   static MDNode *getMergedCallsiteMetadata(MDNode *A, MDNode *B);
+  /// Merge the !callee_type operand lists of \p A and \p B into a single node
+  /// without duplicate entries.
+  static MDNode *getMergedCalleeTypeMetadata(LLVMContext &Ctx, MDNode *A,
+                                             MDNode *B);
 };
 
 /// Tuple of metadata.
diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
index 8e78cd9cc573a..82c1b96f5fb7e 100644
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -1302,6 +1302,30 @@ static void addRange(SmallVectorImpl<ConstantInt *> &EndPoints,
   EndPoints.push_back(High);
 }
 
+/// Merge the !callee_type operand lists of \p A and \p B into one node that
+/// holds every distinct operand once, A's operands first. Returns nullptr
+/// when either input is null so the caller drops the metadata.
+MDNode *MDNode::getMergedCalleeTypeMetadata(LLVMContext &Ctx, MDNode *A,
+                                            MDNode *B) {
+  // A call site without !callee_type makes no statement about its callee's
+  // type, so a list taken from one side only must not survive the merge.
+  if (!A || !B)
+    return nullptr;
+
+  SmallVector<Metadata *> AB;
+  SmallSet<Metadata *, 4> MergedCallees;
+  auto AddUniqueCallees = [&AB, &MergedCallees](const MDNode *N) {
+    for (const MDOperand &Op : N->operands()) {
+      Metadata *MD = Op.get();
+      if (MergedCallees.insert(MD).second)
+        AB.push_back(MD);
+    }
+  };
+  AddUniqueCallees(A);
+  AddUniqueCallees(B);
+  return MDNode::get(Ctx, AB);
+}
+
 MDNode *MDNode::getMostGenericRange(MDNode *A, MDNode *B) {
   // Given two ranges, we want to compute the union of the ranges. This
   // is slightly complicated by having to combine the intervals and merge
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 7d1918e175c0c..9d8ab4742e1b6 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -530,6 +530,7 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
   void visitCallStackMetadata(MDNode *MD);
   void visitMemProfMetadata(Instruction &I, MDNode *MD);
   void visitCallsiteMetadata(Instruction &I, MDNode *MD);
+  void visitCalleeTypeMetadata(Instruction &I, MDNode *MD);
   void visitDIAssignIDMetadata(Instruction &I, MDNode *MD);
   void visitMMRAMetadata(Instruction &I, MDNode *MD);
   void visitAnnotationMetadata(MDNode *Annotation);
@@ -5096,6 +5097,21 @@ void Verifier::visitCallsiteMetadata(Instruction &I, MDNode *MD) {
   visitCallStackMetadata(MD);
 }
 
+/// Verify !callee_type: it may only be attached to CallBase instructions, and
+/// each operand must be an MDNode that passes hasGeneralizedMDString().
+void Verifier::visitCalleeTypeMetadata(Instruction &I, MDNode *MD) {
+  Check(isa<CallBase>(I), "!callee_type metadata should only exist on calls",
+        &I);
+  for (const MDOperand &Op : MD->operands()) {
+    Check(isa<MDNode>(Op.get()),
+          "The callee_type metadata must be a list of type metadata nodes");
+    auto *TypeMD = cast<MDNode>(Op.get());
+    Check(TypeMD->hasGeneralizedMDString(),
+          "Only generalized type metadata can be part of the callee_type "
+          "metadata list");
+  }
+}
+
 void Verifier::visitAnnotationMetadata(MDNode *Annotation) {
   Check(isa<MDTuple>(Annotation), "annotation must be a tuple");
   Check(Annotation->getNumOperands() >= 1,
@@ -5373,6 +5389,9 @@ void Verifier::visitInstruction(Instruction &I) {
   if (MDNode *MD = I.getMetadata(LLVMContext::MD_callsite))
     visitCallsiteMetadata(I, MD);
 
+  if (MDNode *MD = I.getMetadata(LLVMContext::MD_callee_type))
+    visitCalleeTypeMetadata(I, MD);
+
   if (MDNode *MD = I.getMetadata(LLVMContext::MD_DIAssignID))
     visitDIAssignIDMetadata(I, MD);
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 844e18dd7d8c5..04790af1dbd2b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -4161,6 +4161,14 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
           Call, Builder.CreateBitOrPointerCast(ReturnedArg, CallTy));
   }
 
+  // Drop unnecessary callee_type metadata from calls that were converted
+  // into direct calls.
+  if (Call.getMetadata(LLVMContext::MD_callee_type) && !Call.isIndirectCall()) {
+    Call.setMetadata(LLVMContext::MD_callee_type, nullptr);
+    // Report the call as modified so the pass does not claim an unchanged IR.
+    return &Call;
+  }
+
   // Drop unnecessary kcfi operand bundles from calls that were converted
   // into direct calls.
   auto Bundle = Call.getOperandBundle(LLVMContext::OB_kcfi);
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 809a0d7ebeea6..d37fccfaaa301 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -3377,6 +3377,11 @@ static void combineMetadata(Instruction *K, const Instruction *J,
           K->setMetadata(Kind,
             MDNode::getMostGenericAlignmentOrDereferenceable(JMD, KMD));
         break;
+      case LLVMContext::MD_callee_type:
+        if (!AAOnly)
+          K->setMetadata(Kind, MDNode::getMergedCalleeTypeMetadata(
+                                   K->getContext(), KMD, JMD));
+        break;
       case LLVMContext::MD_preserve_access_index:
         // Preserve !preserve.access.index in K.
         break;
diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp
index 5e50536a99206..e4e110183a1b6 100644
--- a/llvm/lib/Transforms/Utils/ValueMapper.cpp
+++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp
@@ -987,6 +987,13 @@ void Mapper::remapInstruction(Instruction *I) {
              "Referenced value not in value map!");
   }
 
+  // Drop callee_type metadata from calls that were remapped
+  // into a direct call from an indirect one.
+  if (auto *CB = dyn_cast<CallBase>(I)) {
+    if (CB->getMetadata(LLVMContext::MD_callee_type) && !CB->isIndirectCall())
+      CB->setMetadata(LLVMContext::MD_callee_type, nullptr);
+  }
+
   // Remap phi nodes' incoming blocks.
   if (PHINode *PN = dyn_cast<PHINode>(I)) {
     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
diff --git a/llvm/test/Transforms/Inline/drop-callee-type-metadata.ll b/llvm/test/Transforms/Inline/drop-callee-type-metadata.ll
new file mode 100644
index 0000000000000..ce7830536c200
--- /dev/null
+++ b/llvm/test/Transforms/Inline/drop-callee-type-metadata.ll
@@ -0,0 +1,26 @@
+;; Test if the callee_type metadata is dropped when an indirect call is
+;; mapped to a direct function call during inlining.
+
+; RUN: opt -passes="inline" -S < %s | FileCheck %s
+
+define i32 @_Z13call_indirectPFicEc(ptr %func, i8 %x) local_unnamed_addr !type !0 {
+entry:
+  %call = call i32 %func(i8 %x), !callee_type !1
+  ret i32 %call
+}
+
+define i32 @_Z3barv() local_unnamed_addr !type !3 {
+entry:
+  ; CHECK-LABEL: define i32 @_Z3barv()
+  ; CHECK-NEXT: entry:
+  ; CHECK-NEXT: %call.i = call i32 @_Z3fooc(i8 97){{$}}
+  ; CHECK-NOT: !callee_type
+  %call = call i32 @_Z13call_indirectPFicEc(ptr nonnull @_Z3fooc, i8 97)
+  ret i32 %call
+}
+declare !type !2 i32 @_Z3fooc(i8 signext)
+
+!0 = !{i64 0, !"_ZTSFiPvcE.generalized"}
+!1 = !{!2}
+!2 = !{i64 0, !"_ZTSFicE.generalized"}
+!3 = !{i64 0, !"_ZTSFivE.generalized"}
diff --git a/llvm/test/Transforms/InstCombine/drop-callee-type-metadata.ll b/llvm/test/Transforms/InstCombine/drop-callee-type-metadata.ll
new file mode 100644
index 0000000000000..13c60f675d66d
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/drop-callee-type-metadata.ll
@@ -0,0 +1,21 @@
+;; Test if the callee_type metadata is dropped when it is attached
+;; to a direct function call during instcombine.
+
+; RUN: opt -passes="instcombine" -S < %s | FileCheck %s
+
+define i32 @_Z3barv() local_unnamed_addr !type !3 {
+entry:
+  ; CHECK-LABEL: define i32 @_Z3barv()
+  ; CHECK-NEXT: entry:
+  ; CHECK-NEXT: %call = call i32 @_Z3fooc(i8 97){{$}}
+  ; CHECK-NOT: !callee_type
+  %call = call i32 @_Z3fooc(i8 97), !callee_type !1
+  ret i32 %call
+}
+
+declare !type !2 i32 @_Z3fooc(i8 signext)
+
+!0 = !{i64 0, !"_ZTSFiPvcE.generalized"}
+!1 = !{!2}
+!2 = !{i64 0, !"_ZTSFicE.generalized"}
+!3 = !{i64 0, !"_ZTSFivE.generalized"}
diff --git a/llvm/test/Verifier/callee-type-metadata.ll b/llvm/test/Verifier/callee-type-metadata.ll
new file mode 100644
index 0000000000000..0107dec27de34
--- /dev/null
+++ b/llvm/test/Verifier/callee-type-metadata.ll
@@ -0,0 +1,30 @@
+;; Test if the callee_type metadata attached to indirect call sites adheres to the expected format.
+
+; RUN: not opt -passes=verify < %s 2>&1 | FileCheck %s
+define i32 @_Z13call_indirectPFicEc(ptr %func, i8 signext %x) !type !0 {
+entry:
+  %func.addr = alloca ptr, align 8
+  %x.addr = alloca i8, align 1
+  store ptr %func, ptr %func.addr, align 8
+  store i8 %x, ptr %x.addr, align 1
+  %fptr = load ptr, ptr %func.addr, align 8
+  %x_val = load i8, ptr %x.addr, align 1
+  ;; No failures expected for this callee_type metadata.
+  %call = call i32 %fptr(i8 signext %x_val), !callee_type !1
+  ;; callee_type metadata is a type metadata instead of a list of type metadata nodes.
+  ; CHECK: The callee_type metadata must be a list of type metadata nodes
+  %call2 = call i32 %fptr(i8 signext %x_val), !callee_type !0
+  ;; callee_type metadata must be a list of "generalized" type metadata.
+  ; CHECK: Only generalized type metadata can be part of the callee_type metadata list
+  %call3 = call i32 %fptr(i8 signext %x_val), !callee_type !4
+  ret i32 %call
+}
+
+declare !type !2 i32 @_Z3barc(i8 signext)
+
+!0 = !{i64 0, !"_ZTSFiPvcE.generalized"}
+!1 = !{!2}
+!2 = !{i64 0, !"_ZTSFicE.generalized"}
+!3 = !{i64 0, !"_ZTSFicE"}
+!4 = !{!3}
+!8 = !{i64 0, !"_ZTSFicE.generalized"}