diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 3671c1c71ac0b..4a6db5d895a62 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -36,6 +36,7 @@ #define LLVM_ANALYSIS_IR2VEC_H #include "llvm/ADT/DenseMap.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/Support/CommandLine.h" @@ -162,15 +163,29 @@ using BBEmbeddingsMap = DenseMap; /// embeddings. class Vocabulary { friend class llvm::IR2VecVocabAnalysis; + + // Vocabulary Slot Layout: + // +----------------+------------------------------------------------------+ + // | Entity Type | Index Range | + // +----------------+------------------------------------------------------+ + // | Opcodes | [0 .. (MaxOpcodes-1)] | + // | Canonical Types| [MaxOpcodes .. (MaxOpcodes+MaxCanonicalTypeIDs-1)] | + // | Operands | [(MaxOpcodes+MaxCanonicalTypeIDs) .. NumCanEntries] | + // +----------------+------------------------------------------------------+ + // Note: "Similar" LLVM Types are grouped/canonicalized together. + // Operands include Comparison predicates (ICmp/FCmp). + // This can be extended to include other specializations in future. using VocabVector = std::vector; VocabVector Vocab; -public: - // Slot layout: - // [0 .. MaxOpcodes-1] => Instruction opcodes - // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types - // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds + static constexpr unsigned NumICmpPredicates = + static_cast(CmpInst::LAST_ICMP_PREDICATE) - + static_cast(CmpInst::FIRST_ICMP_PREDICATE) + 1; + static constexpr unsigned NumFCmpPredicates = + static_cast(CmpInst::LAST_FCMP_PREDICATE) - + static_cast(CmpInst::FIRST_FCMP_PREDICATE) + 1; +public: /// Canonical type IDs supported by IR2Vec Vocabulary enum class CanonicalTypeID : unsigned { FloatTy, @@ -207,13 +222,18 @@ class Vocabulary { static_cast(CanonicalTypeID::MaxCanonicalType); static constexpr unsigned MaxOperandKinds = static_cast(OperandKind::MaxOperandKind); + // CmpInst::Predicate has gaps. We want the vocabulary to be dense without + // empty slots. + static constexpr unsigned MaxPredicateKinds = + NumICmpPredicates + NumFCmpPredicates; Vocabulary() = default; LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {} LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; }; LLVM_ABI unsigned getDimension() const; - /// Total number of entries (opcodes + canonicalized types + operand kinds) + /// Total number of entries (opcodes + canonicalized types + operand kinds + + /// predicates) static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; } /// Function to get vocabulary key for a given Opcode @@ -228,16 +248,21 @@ class Vocabulary { /// Function to classify an operand into OperandKind LLVM_ABI static OperandKind getOperandKind(const Value *Op); + /// Function to get vocabulary key for a given predicate + LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P); + /// Functions to return the slot index or position of a given Opcode, TypeID, /// or OperandKind in the vocabulary. LLVM_ABI static unsigned getSlotIndex(unsigned Opcode); LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID); LLVM_ABI static unsigned getSlotIndex(const Value &Op); + LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P); /// Accessors to get the embedding for a given entity. LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const; LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const; LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const; + LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const; /// Const Iterator type aliases using const_iterator = VocabVector::const_iterator; @@ -274,7 +299,13 @@ class Vocabulary { private: constexpr static unsigned NumCanonicalEntries = - MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds; + MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds; + + // Base offsets for slot layout to simplify index computation + constexpr static unsigned OperandBaseOffset = + MaxOpcodes + MaxCanonicalTypeIDs; + constexpr static unsigned PredicateBaseOffset = + OperandBaseOffset + MaxOperandKinds; /// String mappings for CanonicalTypeID values static constexpr StringLiteral CanonicalTypeNames[] = { @@ -326,6 +357,9 @@ class Vocabulary { /// Function to convert TypeID to CanonicalTypeID LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID); + + /// Function to get the predicate enum value for a given index + LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index); }; /// Embedder provides the interface to generate embeddings (vector diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 99afc0601d523..f51f0898cb37e 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { ArgEmb += Vocab[*Op]; auto InstVector = Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + if (const auto *IC = dyn_cast(&I)) + InstVector += Vocab[IC->getPredicate()]; InstVecMap[&I] = InstVector; BBVector += InstVector; } @@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { // embeddings auto InstVector = Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + // Add compare predicate embedding as an additional operand if applicable + if (const auto *IC = dyn_cast(&I)) + InstVector += Vocab[IC->getPredicate()]; InstVecMap[&I] = InstVector; BBVector += InstVector; } @@ -278,7 +283,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) { unsigned Vocabulary::getSlotIndex(const Value &Op) { unsigned Index = static_cast(getOperandKind(&Op)); assert(Index < MaxOperandKinds && "Invalid OperandKind"); - return MaxOpcodes + MaxCanonicalTypeIDs + Index; + return OperandBaseOffset + Index; +} + +unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) { + unsigned PU = static_cast(P); + unsigned FirstFC = static_cast(CmpInst::FIRST_FCMP_PREDICATE); + unsigned FirstIC = static_cast(CmpInst::FIRST_ICMP_PREDICATE); + + unsigned PredIdx = + (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC); + return PredicateBaseOffset + PredIdx; } const Embedding &Vocabulary::operator[](unsigned Opcode) const { @@ -293,6 +308,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const { return Vocab[getSlotIndex(Arg)]; } +const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const { + return Vocab[getSlotIndex(P)]; +} + StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); #define HANDLE_INST(NUM, OPCODE, CLASS) \ @@ -338,18 +357,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { return OperandKind::VariableID; } +CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) { + assert(Index < MaxPredicateKinds && "Invalid predicate index"); + unsigned PredEnumVal = + (Index < NumFCmpPredicates) + ? (static_cast(CmpInst::FIRST_FCMP_PREDICATE) + Index) + : (static_cast(CmpInst::FIRST_ICMP_PREDICATE) + + (Index - NumFCmpPredicates)); + return static_cast(PredEnumVal); +} + +StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) { + static SmallString<16> PredNameBuffer; + if (Pred < CmpInst::FIRST_ICMP_PREDICATE) + PredNameBuffer = "FCMP_"; + else + PredNameBuffer = "ICMP_"; + PredNameBuffer += CmpInst::getPredicateName(Pred); + return PredNameBuffer; +} + StringRef Vocabulary::getStringKey(unsigned Pos) { assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary"); // Opcode if (Pos < MaxOpcodes) return getVocabKeyForOpcode(Pos + 1); // Type - if (Pos < MaxOpcodes + MaxCanonicalTypeIDs) + if (Pos < OperandBaseOffset) return getVocabKeyForCanonicalTypeID( static_cast(Pos - MaxOpcodes)); // Operand - return getVocabKeyForOperandKind( - static_cast(Pos - MaxOpcodes - MaxCanonicalTypeIDs)); + if (Pos < PredicateBaseOffset) + return getVocabKeyForOperandKind( + static_cast(Pos - OperandBaseOffset)); + // Predicates + return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset)); } // For now, assume vocabulary is stable unless explicitly invalidated. @@ -363,11 +405,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { VocabVector DummyVocab; DummyVocab.reserve(NumCanonicalEntries); float DummyVal = 0.1f; - // Create a dummy vocabulary with entries for all opcodes, types, and - // operands - for ([[maybe_unused]] unsigned _ : - seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs + - Vocabulary::MaxOperandKinds)) { + // Create a dummy vocabulary with entries for all opcodes, types, operands + // and predicates + for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) { DummyVocab.push_back(Embedding(Dim, DummyVal)); DummyVal += 0.1f; } @@ -510,6 +550,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { } Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(), NumericArgEmbeddings.end()); + + // Handle Predicates: part of Operands section. We look up predicate keys + // in ArgVocab. + std::vector NumericPredEmbeddings(Vocabulary::MaxPredicateKinds, + Embedding(Dim, 0)); + NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds); + for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) { + StringRef VocabKey = + Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK)); + auto It = ArgVocab.find(VocabKey.str()); + if (It != ArgVocab.end()) { + NumericPredEmbeddings[PK] = It->second; + continue; + } + handleMissingEntity(VocabKey.str()); + } + Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(), + NumericPredEmbeddings.end()); } IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab) diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json index 07fde84c1541b..ae36ff54686c5 100644 --- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json +++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json @@ -87,6 +87,32 @@ "Function": [1, 2], "Pointer": [3, 4], "Constant": [5, 6], - "Variable": [7, 8] + "Variable": [7, 8], + "FCMP_false": [9, 10], + "FCMP_oeq": [11, 12], + "FCMP_ogt": [13, 14], + "FCMP_oge": [15, 16], + "FCMP_olt": [17, 18], + "FCMP_ole": [19, 20], + "FCMP_one": [21, 22], + "FCMP_ord": [23, 24], + "FCMP_uno": [25, 26], + "FCMP_ueq": [27, 28], + "FCMP_ugt": [29, 30], + "FCMP_uge": [31, 32], + "FCMP_ult": [33, 34], + "FCMP_ule": [35, 36], + "FCMP_une": [37, 38], + "FCMP_true": [39, 40], + "ICMP_eq": [41, 42], + "ICMP_ne": [43, 44], + "ICMP_ugt": [45, 46], + "ICMP_uge": [47, 48], + "ICMP_ult": [49, 50], + "ICMP_ule": [51, 52], + "ICMP_sgt": [53, 54], + "ICMP_sge": [55, 56], + "ICMP_slt": [57, 58], + "ICMP_sle": [59, 60] } } diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json index 932b3a217b70c..9003dc73954aa 100644 --- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json +++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json @@ -86,6 +86,32 @@ "Function": [1, 2, 3], "Pointer": [4, 5, 6], "Constant": [7, 8, 9], - "Variable": [10, 11, 12] + "Variable": [10, 11, 12], + "FCMP_false": [13, 14, 15], + "FCMP_oeq": [16, 17, 18], + "FCMP_ogt": [19, 20, 21], + "FCMP_oge": [22, 23, 24], + "FCMP_olt": [25, 26, 27], + "FCMP_ole": [28, 29, 30], + "FCMP_one": [31, 32, 33], + "FCMP_ord": [34, 35, 36], + "FCMP_uno": [37, 38, 39], + "FCMP_ueq": [40, 41, 42], + "FCMP_ugt": [43, 44, 45], + "FCMP_uge": [46, 47, 48], + "FCMP_ult": [49, 50, 51], + "FCMP_ule": [52, 53, 54], + "FCMP_une": [55, 56, 57], + "FCMP_true": [58, 59, 60], + "ICMP_eq": [61, 62, 63], + "ICMP_ne": [64, 65, 66], + "ICMP_ugt": [67, 68, 69], + "ICMP_uge": [70, 71, 72], + "ICMP_ult": [73, 74, 75], + "ICMP_ule": [76, 77, 78], + "ICMP_sgt": [79, 80, 81], + "ICMP_sge": [82, 83, 84], + "ICMP_slt": [85, 86, 87], + "ICMP_sle": [88, 89, 90] } } diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json index 19f3efee9f6a1..7ef85490b27df 100644 --- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json +++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json @@ -47,6 +47,7 @@ "FPTrunc": [133, 134, 135], "FPExt": [136, 137, 138], "PtrToInt": [139, 140, 141], + "PtrToAddr": [202, 203, 204], "IntToPtr": [142, 143, 144], "BitCast": [145, 146, 147], "AddrSpaceCast": [148, 149, 150], @@ -86,6 +87,32 @@ "Function": [0, 0, 0], "Pointer": [0, 0, 0], "Constant": [0, 0, 0], - "Variable": [0, 0, 0] + "Variable": [0, 0, 0], + "FCMP_false": [0, 0, 0], + "FCMP_oeq": [0, 0, 0], + "FCMP_ogt": [0, 0, 0], + "FCMP_oge": [0, 0, 0], + "FCMP_olt": [0, 0, 0], + "FCMP_ole": [0, 0, 0], + "FCMP_one": [0, 0, 0], + "FCMP_ord": [0, 0, 0], + "FCMP_uno": [0, 0, 0], + "FCMP_ueq": [0, 0, 0], + "FCMP_ugt": [0, 0, 0], + "FCMP_uge": [0, 0, 0], + "FCMP_ult": [0, 0, 0], + "FCMP_ule": [0, 0, 0], + "FCMP_une": [0, 0, 0], + "FCMP_true": [0, 0, 0], + "ICMP_eq": [0, 0, 0], + "ICMP_ne": [0, 0, 0], + "ICMP_ugt": [0, 0, 0], + "ICMP_uge": [0, 0, 0], + "ICMP_ult": [0, 0, 0], + "ICMP_ule": [0, 0, 0], + "ICMP_sgt": [1, 1, 1], + "ICMP_sge": [0, 0, 0], + "ICMP_slt": [0, 0, 0], + "ICMP_sle": [0, 0, 0] } } diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt index df7769c9c6a65..d62b0dd157b0b 100644 --- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt +++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt @@ -82,3 +82,29 @@ Key: Function: [ 0.20 0.40 ] Key: Pointer: [ 0.60 0.80 ] Key: Constant: [ 1.00 1.20 ] Key: Variable: [ 1.40 1.60 ] +Key: FCMP_false: [ 1.80 2.00 ] +Key: FCMP_oeq: [ 2.20 2.40 ] +Key: FCMP_ogt: [ 2.60 2.80 ] +Key: FCMP_oge: [ 3.00 3.20 ] +Key: FCMP_olt: [ 3.40 3.60 ] +Key: FCMP_ole: [ 3.80 4.00 ] +Key: FCMP_one: [ 4.20 4.40 ] +Key: FCMP_ord: [ 4.60 4.80 ] +Key: FCMP_uno: [ 5.00 5.20 ] +Key: FCMP_ueq: [ 5.40 5.60 ] +Key: FCMP_ugt: [ 5.80 6.00 ] +Key: FCMP_uge: [ 6.20 6.40 ] +Key: FCMP_ult: [ 6.60 6.80 ] +Key: FCMP_ule: [ 7.00 7.20 ] +Key: FCMP_une: [ 7.40 7.60 ] +Key: FCMP_true: [ 7.80 8.00 ] +Key: ICMP_eq: [ 8.20 8.40 ] +Key: ICMP_ne: [ 8.60 8.80 ] +Key: ICMP_ugt: [ 9.00 9.20 ] +Key: ICMP_uge: [ 9.40 9.60 ] +Key: ICMP_ult: [ 9.80 10.00 ] +Key: ICMP_ule: [ 10.20 10.40 ] +Key: ICMP_sgt: [ 10.60 10.80 ] +Key: ICMP_sge: [ 11.00 11.20 ] +Key: ICMP_slt: [ 11.40 11.60 ] +Key: ICMP_sle: [ 11.80 12.00 ] diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt index f3ce809fd2fd2..e443adb17ac78 100644 --- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt +++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt @@ -82,3 +82,29 @@ Key: Function: [ 0.50 1.00 ] Key: Pointer: [ 1.50 2.00 ] Key: Constant: [ 2.50 3.00 ] Key: Variable: [ 3.50 4.00 ] +Key: FCMP_false: [ 4.50 5.00 ] +Key: FCMP_oeq: [ 5.50 6.00 ] +Key: FCMP_ogt: [ 6.50 7.00 ] +Key: FCMP_oge: [ 7.50 8.00 ] +Key: FCMP_olt: [ 8.50 9.00 ] +Key: FCMP_ole: [ 9.50 10.00 ] +Key: FCMP_one: [ 10.50 11.00 ] +Key: FCMP_ord: [ 11.50 12.00 ] +Key: FCMP_uno: [ 12.50 13.00 ] +Key: FCMP_ueq: [ 13.50 14.00 ] +Key: FCMP_ugt: [ 14.50 15.00 ] +Key: FCMP_uge: [ 15.50 16.00 ] +Key: FCMP_ult: [ 16.50 17.00 ] +Key: FCMP_ule: [ 17.50 18.00 ] +Key: FCMP_une: [ 18.50 19.00 ] +Key: FCMP_true: [ 19.50 20.00 ] +Key: ICMP_eq: [ 20.50 21.00 ] +Key: ICMP_ne: [ 21.50 22.00 ] +Key: ICMP_ugt: [ 22.50 23.00 ] +Key: ICMP_uge: [ 23.50 24.00 ] +Key: ICMP_ult: [ 24.50 25.00 ] +Key: ICMP_ule: [ 25.50 26.00 ] +Key: ICMP_sgt: [ 26.50 27.00 ] +Key: ICMP_sge: [ 27.50 28.00 ] +Key: ICMP_slt: [ 28.50 29.00 ] +Key: ICMP_sle: [ 29.50 30.00 ] diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt index 72b25b9bd3d9c..7fb6043552f7b 100644 --- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt +++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt @@ -82,3 +82,29 @@ Key: Function: [ 0.00 0.00 ] Key: Pointer: [ 0.00 0.00 ] Key: Constant: [ 0.00 0.00 ] Key: Variable: [ 0.00 0.00 ] +Key: FCMP_false: [ 0.00 0.00 ] +Key: FCMP_oeq: [ 0.00 0.00 ] +Key: FCMP_ogt: [ 0.00 0.00 ] +Key: FCMP_oge: [ 0.00 0.00 ] +Key: FCMP_olt: [ 0.00 0.00 ] +Key: FCMP_ole: [ 0.00 0.00 ] +Key: FCMP_one: [ 0.00 0.00 ] +Key: FCMP_ord: [ 0.00 0.00 ] +Key: FCMP_uno: [ 0.00 0.00 ] +Key: FCMP_ueq: [ 0.00 0.00 ] +Key: FCMP_ugt: [ 0.00 0.00 ] +Key: FCMP_uge: [ 0.00 0.00 ] +Key: FCMP_ult: [ 0.00 0.00 ] +Key: FCMP_ule: [ 0.00 0.00 ] +Key: FCMP_une: [ 0.00 0.00 ] +Key: FCMP_true: [ 0.00 0.00 ] +Key: ICMP_eq: [ 0.00 0.00 ] +Key: ICMP_ne: [ 0.00 0.00 ] +Key: ICMP_ugt: [ 0.00 0.00 ] +Key: ICMP_uge: [ 0.00 0.00 ] +Key: ICMP_ult: [ 0.00 0.00 ] +Key: ICMP_ule: [ 0.00 0.00 ] +Key: ICMP_sgt: [ 0.00 0.00 ] +Key: ICMP_sge: [ 0.00 0.00 ] +Key: ICMP_slt: [ 0.00 0.00 ] +Key: ICMP_sle: [ 0.00 0.00 ] diff --git a/llvm/test/Analysis/IR2Vec/if-else.ll b/llvm/test/Analysis/IR2Vec/if-else.ll index fe532479086d3..804c1ca5cb6f6 100644 --- a/llvm/test/Analysis/IR2Vec/if-else.ll +++ b/llvm/test/Analysis/IR2Vec/if-else.ll @@ -29,7 +29,7 @@ return: ; preds = %if.else, %if.then ; CHECK: Basic block vectors: ; CHECK-NEXT: Basic block: entry: -; CHECK-NEXT: [ 816.00 825.00 834.00 ] +; CHECK-NEXT: [ 816.20 825.20 834.20 ] ; CHECK-NEXT: Basic block: if.then: ; CHECK-NEXT: [ 195.00 198.00 201.00 ] ; CHECK-NEXT: Basic block: if.else: diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll index b0e3e49978018..9be0ee1c2de7a 100644 --- a/llvm/test/Analysis/IR2Vec/unreachable.ll +++ b/llvm/test/Analysis/IR2Vec/unreachable.ll @@ -33,7 +33,7 @@ return: ; preds = %if.else, %if.then ; CHECK: Basic block vectors: ; CHECK-NEXT: Basic block: entry: -; CHECK-NEXT: [ 816.00 825.00 834.00 ] +; CHECK-NEXT: [ 816.20 825.20 834.20 ] ; CHECK-NEXT: Basic block: if.then: ; CHECK-NEXT: [ 195.00 198.00 201.00 ] ; CHECK-NEXT: Basic block: if.else: diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll index 4b51adf30bf74..8dbce57302f6f 100644 --- a/llvm/test/tools/llvm-ir2vec/entities.ll +++ b/llvm/test/tools/llvm-ir2vec/entities.ll @@ -1,6 +1,6 @@ ; RUN: llvm-ir2vec entities | FileCheck %s -CHECK: 84 +CHECK: 110 CHECK-NEXT: Ret 0 CHECK-NEXT: Br 1 CHECK-NEXT: Switch 2 @@ -85,3 +85,29 @@ CHECK-NEXT: Function 80 CHECK-NEXT: Pointer 81 CHECK-NEXT: Constant 82 CHECK-NEXT: Variable 83 +CHECK-NEXT: FCMP_false 84 +CHECK-NEXT: FCMP_oeq 85 +CHECK-NEXT: FCMP_ogt 86 +CHECK-NEXT: FCMP_oge 87 +CHECK-NEXT: FCMP_olt 88 +CHECK-NEXT: FCMP_ole 89 +CHECK-NEXT: FCMP_one 90 +CHECK-NEXT: FCMP_ord 91 +CHECK-NEXT: FCMP_uno 92 +CHECK-NEXT: FCMP_ueq 93 +CHECK-NEXT: FCMP_ugt 94 +CHECK-NEXT: FCMP_uge 95 +CHECK-NEXT: FCMP_ult 96 +CHECK-NEXT: FCMP_ule 97 +CHECK-NEXT: FCMP_une 98 +CHECK-NEXT: FCMP_true 99 +CHECK-NEXT: ICMP_eq 100 +CHECK-NEXT: ICMP_ne 101 +CHECK-NEXT: ICMP_ugt 102 +CHECK-NEXT: ICMP_uge 103 +CHECK-NEXT: ICMP_ult 104 +CHECK-NEXT: ICMP_ule 105 +CHECK-NEXT: ICMP_sgt 106 +CHECK-NEXT: ICMP_sge 107 +CHECK-NEXT: ICMP_slt 108 +CHECK-NEXT: ICMP_sle 109 diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index aabebf0cc90a9..1c656b8fcf4e7 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -184,7 +184,7 @@ class IR2VecTool { // Add "Arg" relationships unsigned ArgIndex = 0; for (const Use &U : I.operands()) { - unsigned OperandID = Vocabulary::getSlotIndex(*U); + unsigned OperandID = Vocabulary::getSlotIndex(*U.get()); unsigned RelationID = ArgRelation + ArgIndex; OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n'; diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 9f2f6a3496ce0..9bc48e45eab5e 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -435,6 +435,7 @@ static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes; static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs; static constexpr unsigned MaxCanonicalTypeIDs = Vocabulary::MaxCanonicalTypeIDs; static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds; +static constexpr unsigned MaxPredicateKinds = Vocabulary::MaxPredicateKinds; // Mapping between LLVM Type::TypeID tokens and Vocabulary::CanonicalTypeID // names and their canonical string keys. @@ -460,7 +461,8 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) { EXPECT_EQ(Emb.size(), Dim); // Should have the correct total number of embeddings - EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands); + EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + + MaxPredicateKinds); auto ExpectedVocab = VocabVec; @@ -527,6 +529,26 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) { EXPECT_EQ(Vocabulary::getSlotIndex(*Arg), EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID)); #undef EXPECTED_VOCAB_OPERAND_SLOT + + // Test getSlotIndex for predicates +#define EXPECTED_VOCAB_PREDICATE_SLOT(X) \ + MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + static_cast(X) + for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE; + P <= CmpInst::LAST_FCMP_PREDICATE; ++P) { + CmpInst::Predicate Pred = static_cast(P); + unsigned ExpectedIdx = + EXPECTED_VOCAB_PREDICATE_SLOT((P - CmpInst::FIRST_FCMP_PREDICATE)); + EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx); + } + auto ICMP_Start = CmpInst::LAST_FCMP_PREDICATE + 1; + for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE; + P <= CmpInst::LAST_ICMP_PREDICATE; ++P) { + CmpInst::Predicate Pred = static_cast(P); + unsigned ExpectedIdx = EXPECTED_VOCAB_PREDICATE_SLOT( + ICMP_Start + P - CmpInst::FIRST_ICMP_PREDICATE); + EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx); + } +#undef EXPECTED_VOCAB_PREDICATE_SLOT } #if GTEST_HAS_DEATH_TEST @@ -569,6 +591,7 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) { #undef EXPECT_CANONICAL_TYPE_NAME + // Verify OperandKind -> string mapping #define HANDLE_OPERAND_KINDS(X) \ X(FunctionID, "Function") \ X(PointerID, "Pointer") \ @@ -592,6 +615,28 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) { Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 1); EXPECT_EQ(FuncArgKey, "Function"); EXPECT_EQ(PtrArgKey, "Pointer"); + +// Verify PredicateKind -> string mapping +#define EXPECT_PREDICATE_KIND(PredNum, PredPos, PredKind) \ + do { \ + std::string PredStr = \ + std::string(PredKind) + "_" + \ + CmpInst::getPredicateName(static_cast(PredNum)) \ + .str(); \ + unsigned Pos = MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + PredPos; \ + EXPECT_EQ(Vocabulary::getStringKey(Pos), PredStr); \ + } while (0) + + for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE; + P <= CmpInst::LAST_FCMP_PREDICATE; ++P) + EXPECT_PREDICATE_KIND(P, P - CmpInst::FIRST_FCMP_PREDICATE, "FCMP"); + + auto ICMP_Pos = CmpInst::LAST_FCMP_PREDICATE + 1; + for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE; + P <= CmpInst::LAST_ICMP_PREDICATE; ++P) + EXPECT_PREDICATE_KIND(P, ICMP_Pos++, "ICMP"); + +#undef EXPECT_PREDICATE_KIND } TEST(IR2VecVocabularyTest, VocabularyDimensions) {