From d762563a5979b2fdfa1cfd0c88c77e4a2c9c23a5 Mon Sep 17 00:00:00 2001 From: svkeerthy Date: Fri, 12 Sep 2025 22:06:44 +0000 Subject: [PATCH] VocabStorage --- llvm/include/llvm/Analysis/IR2Vec.h | 145 +++++++-- llvm/lib/Analysis/IR2Vec.cpp | 230 +++++++++---- llvm/lib/Analysis/InlineAdvisor.cpp | 2 +- llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 6 +- .../FunctionPropertiesAnalysisTest.cpp | 9 +- llvm/unittests/Analysis/IR2VecTest.cpp | 301 ++++++++++++++++-- 6 files changed, 571 insertions(+), 122 deletions(-) diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 4a6db5d895a62..7d51a7320d194 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -45,6 +45,7 @@ #include "llvm/Support/JSON.h" #include #include +#include namespace llvm { @@ -144,6 +145,73 @@ struct Embedding { using InstEmbeddingsMap = DenseMap; using BBEmbeddingsMap = DenseMap; +/// Generic storage class for section-based vocabularies. +/// VocabStorage provides a generic foundation for storing and accessing +/// embeddings organized into sections. +class VocabStorage { +private: + /// Section-based storage + std::vector> Sections; + + size_t TotalSize = 0; + unsigned Dimension = 0; + +public: + /// Default constructor creates empty storage (invalid state) + VocabStorage() : Sections(), TotalSize(0), Dimension(0) {} + + /// Create a VocabStorage with pre-organized section data + VocabStorage(std::vector> &&SectionData); + + VocabStorage(VocabStorage &&) = default; + VocabStorage &operator=(VocabStorage &&Other); + + VocabStorage(const VocabStorage &) = delete; + VocabStorage &operator=(const VocabStorage &) = delete; + + /// Get total number of entries across all sections + size_t size() const { return TotalSize; } + + /// Get number of sections + unsigned getNumSections() const { + return static_cast(Sections.size()); + } + + /// Section-based access: Storage[sectionId][localIndex] + const std::vector &operator[](unsigned SectionId) const { + assert(SectionId < Sections.size() && "Invalid section ID"); + return Sections[SectionId]; + } + + /// Get vocabulary dimension + unsigned getDimension() const { return Dimension; } + + /// Check if vocabulary is valid (has data) + bool isValid() const { return TotalSize > 0; } + + /// Iterator support for section-based access + class const_iterator { + const VocabStorage *Storage; + unsigned SectionId; + size_t LocalIndex; + + public: + const_iterator(const VocabStorage *Storage, unsigned SectionId, + size_t LocalIndex) + : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {} + + LLVM_ABI const Embedding &operator*() const; + LLVM_ABI const_iterator &operator++(); + LLVM_ABI bool operator==(const const_iterator &Other) const; + LLVM_ABI bool operator!=(const const_iterator &Other) const; + }; + + const_iterator begin() const { return const_iterator(this, 0, 0); } + const_iterator end() const { + return const_iterator(this, getNumSections(), 0); + } +}; + /// Class for storing and accessing the IR2Vec vocabulary. /// The Vocabulary class manages seed embeddings for LLVM IR entities. The /// seed embeddings are the initial learned representations of the entities @@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap; class Vocabulary { friend class llvm::IR2VecVocabAnalysis; - // Vocabulary Slot Layout: + // Vocabulary Layout: // +----------------+------------------------------------------------------+ // | Entity Type | Index Range | // +----------------+------------------------------------------------------+ @@ -175,8 +243,16 @@ class Vocabulary { // Note: "Similar" LLVM Types are grouped/canonicalized together. // Operands include Comparison predicates (ICmp/FCmp). // This can be extended to include other specializations in future. - using VocabVector = std::vector; - VocabVector Vocab; + enum class Section : unsigned { + Opcodes = 0, + CanonicalTypes = 1, + Operands = 2, + Predicates = 3, + MaxSections + }; + + // Use section-based storage for better organization and efficiency + VocabStorage Storage; static constexpr unsigned NumICmpPredicates = static_cast(CmpInst::LAST_ICMP_PREDICATE) - @@ -228,9 +304,18 @@ class Vocabulary { NumICmpPredicates + NumFCmpPredicates; Vocabulary() = default; - LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {} + LLVM_ABI Vocabulary(VocabStorage &&Storage) : Storage(std::move(Storage)) {} + + Vocabulary(const Vocabulary &) = delete; + Vocabulary &operator=(const Vocabulary &) = delete; + + Vocabulary(Vocabulary &&) = default; + Vocabulary &operator=(Vocabulary &&Other); + + LLVM_ABI bool isValid() const { + return Storage.size() == NumCanonicalEntries; + } - LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; }; LLVM_ABI unsigned getDimension() const; /// Total number of entries (opcodes + canonicalized types + operand kinds + /// predicates) @@ -251,12 +336,11 @@ class Vocabulary { /// Function to get vocabulary key for a given predicate LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P); - /// Functions to return the slot index or position of a given Opcode, TypeID, - /// or OperandKind in the vocabulary. - LLVM_ABI static unsigned getSlotIndex(unsigned Opcode); - LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID); - LLVM_ABI static unsigned getSlotIndex(const Value &Op); - LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P); + /// Functions to return flat index + LLVM_ABI static unsigned getIndex(unsigned Opcode); + LLVM_ABI static unsigned getIndex(Type::TypeID TypeID); + LLVM_ABI static unsigned getIndex(const Value &Op); + LLVM_ABI static unsigned getIndex(CmpInst::Predicate P); /// Accessors to get the embedding for a given entity. LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const; @@ -265,26 +349,21 @@ class Vocabulary { LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const; /// Const Iterator type aliases - using const_iterator = VocabVector::const_iterator; + using const_iterator = VocabStorage::const_iterator; + const_iterator begin() const { assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab.begin(); + return Storage.begin(); } - const_iterator cbegin() const { - assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab.cbegin(); - } + const_iterator cbegin() const { return begin(); } const_iterator end() const { assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab.end(); + return Storage.end(); } - const_iterator cend() const { - assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab.cend(); - } + const_iterator cend() const { return end(); } /// Returns the string key for a given index position in the vocabulary. /// This is useful for debugging or printing the vocabulary. Do not use this @@ -292,7 +371,7 @@ class Vocabulary { LLVM_ABI static StringRef getStringKey(unsigned Pos); /// Create a dummy vocabulary for testing purposes. - LLVM_ABI static VocabVector createDummyVocabForTest(unsigned Dim = 1); + LLVM_ABI static VocabStorage createDummyVocabForTest(unsigned Dim = 1); LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const; @@ -301,12 +380,16 @@ class Vocabulary { constexpr static unsigned NumCanonicalEntries = MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds; - // Base offsets for slot layout to simplify index computation + // Base offsets for flat index computation constexpr static unsigned OperandBaseOffset = MaxOpcodes + MaxCanonicalTypeIDs; constexpr static unsigned PredicateBaseOffset = OperandBaseOffset + MaxOperandKinds; + /// Functions for predicate index calculations + static unsigned getPredicateLocalIndex(CmpInst::Predicate P); + static CmpInst::Predicate getPredicateFromLocalIndex(unsigned LocalIndex); + /// String mappings for CanonicalTypeID values static constexpr StringLiteral CanonicalTypeNames[] = { "FloatTy", "VoidTy", "LabelTy", "MetadataTy", @@ -452,22 +535,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder { /// mapping between an entity of the IR (like opcode, type, argument, etc.) and /// its corresponding embedding. class IR2VecVocabAnalysis : public AnalysisInfoMixin { - using VocabVector = std::vector; using VocabMap = std::map; - VocabMap OpcVocab, TypeVocab, ArgVocab; - VocabVector Vocab; + std::optional Vocab; - Error readVocabulary(); + Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab, + VocabMap &ArgVocab); Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim); - void generateNumMappedVocab(); + void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab, + VocabMap &ArgVocab); void emitError(Error Err, LLVMContext &Ctx); public: LLVM_ABI static AnalysisKey Key; IR2VecVocabAnalysis() = default; - LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab); - LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab); + LLVM_ABI explicit IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab) + : Vocab(std::move(Vocab)) {} using Result = ir2vec::Vocabulary; LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM); }; diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index f51f0898cb37e..eeba109eb7dbd 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Module.h" @@ -261,55 +262,121 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { BBVecMap[&BB] = BBVector; } +// ==----------------------------------------------------------------------===// +// VocabStorage +//===----------------------------------------------------------------------===// + +VocabStorage::VocabStorage(std::vector> &&SectionData) + : Sections(std::move(SectionData)) { + TotalSize = 0; + Dimension = 0; + assert(!Sections.empty() && "Vocabulary has no sections"); + assert(!Sections[0].empty() && "First section of vocabulary is empty"); + + // Compute total size across all sections + for (const auto &Section : Sections) + TotalSize += Section.size(); + + // Get dimension from the first embedding in the first section - all + // embeddings must have the same dimension + Dimension = static_cast(Sections[0][0].size()); +} + +VocabStorage &VocabStorage::operator=(VocabStorage &&Other) { + if (this != &Other) { + Sections = std::move(Other.Sections); + TotalSize = Other.TotalSize; + Dimension = Other.Dimension; + Other.TotalSize = 0; + Other.Dimension = 0; + } + return *this; +} + +const Embedding &VocabStorage::const_iterator::operator*() const { + assert(SectionId < Storage->Sections.size() && "Invalid section ID"); + assert(LocalIndex < Storage->Sections[SectionId].size() && + "Local index out of range"); + return Storage->Sections[SectionId][LocalIndex]; +} + +VocabStorage::const_iterator &VocabStorage::const_iterator::operator++() { + ++LocalIndex; + // Check if we need to move to the next section + while (SectionId < Storage->getNumSections() && + LocalIndex >= Storage->Sections[SectionId].size()) { + LocalIndex = 0; + ++SectionId; + } + return *this; +} + +bool VocabStorage::const_iterator::operator==( + const const_iterator &Other) const { + return Storage == Other.Storage && SectionId == Other.SectionId && + LocalIndex == Other.LocalIndex; +} + +bool VocabStorage::const_iterator::operator!=( + const const_iterator &Other) const { + return !(*this == Other); +} + // ==----------------------------------------------------------------------===// // Vocabulary //===----------------------------------------------------------------------===// +Vocabulary &Vocabulary::operator=(Vocabulary &&Other) { + if (this != &Other) + Storage = std::move(Other.Storage); + return *this; +} + unsigned Vocabulary::getDimension() const { assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab[0].size(); + return Storage.getDimension(); } -unsigned Vocabulary::getSlotIndex(unsigned Opcode) { +unsigned Vocabulary::getIndex(unsigned Opcode) { assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); return Opcode - 1; // Convert to zero-based index } -unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) { +unsigned Vocabulary::getIndex(Type::TypeID TypeID) { assert(static_cast(TypeID) < MaxTypeIDs && "Invalid type ID"); return MaxOpcodes + static_cast(getCanonicalTypeID(TypeID)); } -unsigned Vocabulary::getSlotIndex(const Value &Op) { +unsigned Vocabulary::getIndex(const Value &Op) { unsigned Index = static_cast(getOperandKind(&Op)); assert(Index < MaxOperandKinds && "Invalid OperandKind"); return OperandBaseOffset + Index; } -unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) { - unsigned PU = static_cast(P); - unsigned FirstFC = static_cast(CmpInst::FIRST_FCMP_PREDICATE); - unsigned FirstIC = static_cast(CmpInst::FIRST_ICMP_PREDICATE); - - unsigned PredIdx = - (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC); - return PredicateBaseOffset + PredIdx; +unsigned Vocabulary::getIndex(CmpInst::Predicate P) { + return PredicateBaseOffset + getPredicateLocalIndex(P); } const Embedding &Vocabulary::operator[](unsigned Opcode) const { - return Vocab[getSlotIndex(Opcode)]; + assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); + return Storage[static_cast(Section::Opcodes)][Opcode - 1]; } const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const { - return Vocab[getSlotIndex(TypeID)]; + assert(static_cast(TypeID) < MaxTypeIDs && "Invalid type ID"); + unsigned LocalIndex = static_cast(getCanonicalTypeID(TypeID)); + return Storage[static_cast(Section::CanonicalTypes)][LocalIndex]; } const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const { - return Vocab[getSlotIndex(Arg)]; + unsigned LocalIndex = static_cast(getOperandKind(&Arg)); + assert(LocalIndex < MaxOperandKinds && "Invalid OperandKind"); + return Storage[static_cast(Section::Operands)][LocalIndex]; } const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const { - return Vocab[getSlotIndex(P)]; + unsigned LocalIndex = getPredicateLocalIndex(P); + return Storage[static_cast(Section::Predicates)][LocalIndex]; } StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { @@ -359,12 +426,26 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) { assert(Index < MaxPredicateKinds && "Invalid predicate index"); - unsigned PredEnumVal = - (Index < NumFCmpPredicates) - ? (static_cast(CmpInst::FIRST_FCMP_PREDICATE) + Index) - : (static_cast(CmpInst::FIRST_ICMP_PREDICATE) + - (Index - NumFCmpPredicates)); - return static_cast(PredEnumVal); + return getPredicateFromLocalIndex(Index); +} + +unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) { + if (P >= CmpInst::FIRST_FCMP_PREDICATE && P <= CmpInst::LAST_FCMP_PREDICATE) + return P - CmpInst::FIRST_FCMP_PREDICATE; + else + return P - CmpInst::FIRST_ICMP_PREDICATE + + (CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1); +} + +CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) { + unsigned fcmpRange = + CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1; + if (LocalIndex < fcmpRange) + return static_cast(CmpInst::FIRST_FCMP_PREDICATE + + LocalIndex); + else + return static_cast(CmpInst::FIRST_ICMP_PREDICATE + + LocalIndex - fcmpRange); } StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) { @@ -401,17 +482,51 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA, return !(PAC.preservedWhenStateless()); } -Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { - VocabVector DummyVocab; - DummyVocab.reserve(NumCanonicalEntries); +VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) { float DummyVal = 0.1f; - // Create a dummy vocabulary with entries for all opcodes, types, operands - // and predicates - for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) { - DummyVocab.push_back(Embedding(Dim, DummyVal)); + + // Create sections for opcodes, types, operands, and predicates + // Order must match Vocabulary::Section enum + std::vector> Sections; + Sections.reserve(4); + + // Opcodes section + std::vector OpcodeSec; + OpcodeSec.reserve(MaxOpcodes); + for (unsigned I = 0; I < MaxOpcodes; ++I) { + OpcodeSec.emplace_back(Dim, DummyVal); + DummyVal += 0.1f; + } + Sections.push_back(std::move(OpcodeSec)); + + // Types section + std::vector TypeSec; + TypeSec.reserve(MaxCanonicalTypeIDs); + for (unsigned I = 0; I < MaxCanonicalTypeIDs; ++I) { + TypeSec.emplace_back(Dim, DummyVal); + DummyVal += 0.1f; + } + Sections.push_back(std::move(TypeSec)); + + // Operands section + std::vector OperandSec; + OperandSec.reserve(MaxOperandKinds); + for (unsigned I = 0; I < MaxOperandKinds; ++I) { + OperandSec.emplace_back(Dim, DummyVal); DummyVal += 0.1f; } - return DummyVocab; + Sections.push_back(std::move(OperandSec)); + + // Predicates section + std::vector PredicateSec; + PredicateSec.reserve(MaxPredicateKinds); + for (unsigned I = 0; I < MaxPredicateKinds; ++I) { + PredicateSec.emplace_back(Dim, DummyVal); + DummyVal += 0.1f; + } + Sections.push_back(std::move(PredicateSec)); + + return VocabStorage(std::move(Sections)); } // ==----------------------------------------------------------------------===// @@ -457,7 +572,9 @@ Error IR2VecVocabAnalysis::parseVocabSection( // FIXME: Make this optional. We can avoid file reads // by auto-generating a default vocabulary during the build time. -Error IR2VecVocabAnalysis::readVocabulary() { +Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab, + VocabMap &TypeVocab, + VocabMap &ArgVocab) { auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true); if (!BufOrError) return createFileError(VocabFile, BufOrError.getError()); @@ -488,7 +605,9 @@ Error IR2VecVocabAnalysis::readVocabulary() { return Error::success(); } -void IR2VecVocabAnalysis::generateNumMappedVocab() { +void IR2VecVocabAnalysis::generateVocabStorage(VocabMap &OpcVocab, + VocabMap &TypeVocab, + VocabMap &ArgVocab) { // Helper for handling missing entities in the vocabulary. // Currently, we use a zero vector. In the future, we will throw an error to @@ -506,7 +625,6 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Opcodes std::vector NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes, Embedding(Dim)); - NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes); for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) { StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1); auto It = OpcVocab.find(VocabKey.str()); @@ -515,13 +633,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { else handleMissingEntity(VocabKey.str()); } - Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(), - NumericOpcodeEmbeddings.end()); // Handle Types - only canonical types are present in vocabulary std::vector NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs, Embedding(Dim)); - NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs); for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) { StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID( static_cast(CTypeID)); @@ -531,13 +646,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { } handleMissingEntity(VocabKey.str()); } - Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(), - NumericTypeEmbeddings.end()); // Handle Arguments/Operands std::vector NumericArgEmbeddings(Vocabulary::MaxOperandKinds, Embedding(Dim)); - NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds); for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) { Vocabulary::OperandKind Kind = static_cast(OpKind); StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind); @@ -548,14 +660,11 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { } handleMissingEntity(VocabKey.str()); } - Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(), - NumericArgEmbeddings.end()); // Handle Predicates: part of Operands section. We look up predicate keys // in ArgVocab. std::vector NumericPredEmbeddings(Vocabulary::MaxPredicateKinds, Embedding(Dim, 0)); - NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds); for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) { StringRef VocabKey = Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK)); @@ -566,15 +675,22 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { } handleMissingEntity(VocabKey.str()); } - Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(), - NumericPredEmbeddings.end()); -} -IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab) - : Vocab(Vocab) {} + // Create section-based storage instead of flat vocabulary + // Order must match Vocabulary::Section enum + std::vector> Sections(4); + Sections[static_cast(Vocabulary::Section::Opcodes)] = + std::move(NumericOpcodeEmbeddings); // Section::Opcodes + Sections[static_cast(Vocabulary::Section::CanonicalTypes)] = + std::move(NumericTypeEmbeddings); // Section::CanonicalTypes + Sections[static_cast(Vocabulary::Section::Operands)] = + std::move(NumericArgEmbeddings); // Section::Operands + Sections[static_cast(Vocabulary::Section::Predicates)] = + std::move(NumericPredEmbeddings); // Section::Predicates -IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&Vocab) - : Vocab(std::move(Vocab)) {} + // Create VocabStorage from organized sections + Vocab.emplace(std::move(Sections)); +} void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) { handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { @@ -586,8 +702,8 @@ IR2VecVocabAnalysis::Result IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { auto Ctx = &M.getContext(); // If vocabulary is already populated by the constructor, use it. - if (!Vocab.empty()) - return Vocabulary(std::move(Vocab)); + if (Vocab.has_value()) + return Vocabulary(std::move(Vocab.value())); // Otherwise, try to read from the vocabulary file. if (VocabFile.empty()) { @@ -596,7 +712,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { "set it using --ir2vec-vocab-path"); return Vocabulary(); // Return invalid result } - if (auto Err = readVocabulary()) { + + VocabMap OpcVocab, TypeVocab, ArgVocab; + if (auto Err = readVocabulary(OpcVocab, TypeVocab, ArgVocab)) { emitError(std::move(Err), *Ctx); return Vocabulary(); } @@ -611,9 +729,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { scaleVocabSection(ArgVocab, ArgWeight); // Generate the numeric lookup vocabulary - generateNumMappedVocab(); + generateVocabStorage(OpcVocab, TypeVocab, ArgVocab); - return Vocabulary(std::move(Vocab)); + return Vocabulary(std::move(Vocab.value())); } // ==----------------------------------------------------------------------===// @@ -622,7 +740,7 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { PreservedAnalyses IR2VecPrinterPass::run(Module &M, ModuleAnalysisManager &MAM) { - auto Vocabulary = MAM.getResult(M); + auto &Vocabulary = MAM.getResult(M); assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid"); for (Function &F : M) { @@ -664,7 +782,7 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M, PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M, ModuleAnalysisManager &MAM) { - auto IR2VecVocabulary = MAM.getResult(M); + auto &IR2VecVocabulary = MAM.getResult(M); assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid"); // Print each entry diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp index 28b14c2562df1..0fa804f2959e8 100644 --- a/llvm/lib/Analysis/InlineAdvisor.cpp +++ b/llvm/lib/Analysis/InlineAdvisor.cpp @@ -217,7 +217,7 @@ AnalysisKey PluginInlineAdvisorAnalysis::Key; bool InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested( Module &M, ModuleAnalysisManager &MAM) { if (!IR2VecVocabFile.empty()) { - auto IR2VecVocabResult = MAM.getResult(M); + auto &IR2VecVocabResult = MAM.getResult(M); if (!IR2VecVocabResult.isValid()) { M.getContext().emitError("Failed to load IR2Vec vocabulary"); return false; diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index 1c656b8fcf4e7..434449c7c5117 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -162,8 +162,8 @@ class IR2VecTool { for (const BasicBlock &BB : F) { for (const auto &I : BB.instructionsWithoutDebug()) { - unsigned Opcode = Vocabulary::getSlotIndex(I.getOpcode()); - unsigned TypeID = Vocabulary::getSlotIndex(I.getType()->getTypeID()); + unsigned Opcode = Vocabulary::getIndex(I.getOpcode()); + unsigned TypeID = Vocabulary::getIndex(I.getType()->getTypeID()); // Add "Next" relationship with previous instruction if (HasPrevOpcode) { @@ -184,7 +184,7 @@ class IR2VecTool { // Add "Arg" relationships unsigned ArgIndex = 0; for (const Use &U : I.operands()) { - unsigned OperandID = Vocabulary::getSlotIndex(*U.get()); + unsigned OperandID = Vocabulary::getIndex(*U.get()); unsigned RelationID = ArgRelation + ArgIndex; OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n'; diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp index dc6059dcf6827..1db34ba941292 100644 --- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp +++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp @@ -43,8 +43,11 @@ class FunctionPropertiesAnalysisTest : public testing::Test { public: FunctionPropertiesAnalysisTest() { auto VocabVector = ir2vec::Vocabulary::createDummyVocabForTest(1); - MAM.registerPass([&] { return IR2VecVocabAnalysis(VocabVector); }); - IR2VecVocab = ir2vec::Vocabulary(std::move(VocabVector)); + MAM.registerPass([VocabVector = std::move(VocabVector)]() mutable { + return IR2VecVocabAnalysis(std::move(VocabVector)); + }); + IR2VecVocab = + ir2vec::Vocabulary(ir2vec::Vocabulary::createDummyVocabForTest(1)); MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); }); FAM.registerPass([&] { return DominatorTreeAnalysis(); }); @@ -78,7 +81,7 @@ class FunctionPropertiesAnalysisTest : public testing::Test { FunctionPropertiesInfo buildFPI(Function &F) { // FunctionPropertiesInfo assumes IR2VecVocabAnalysis has been run to // use IR2Vec. - auto VocabResult = MAM.getResult(*F.getParent()); + auto &VocabResult = MAM.getResult(*F.getParent()); (void)VocabResult; return FunctionPropertiesInfo::getFunctionPropertiesInfo(F, FAM); } diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 9bc48e45eab5e..d915920eccda0 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -464,7 +464,10 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) { EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + MaxPredicateKinds); - auto ExpectedVocab = VocabVec; + // Collect embeddings for later comparison before moving VocabVec + std::vector ExpectedVocab; + for (const auto &Emb : VocabVec) + ExpectedVocab.push_back(Emb); IR2VecVocabAnalysis VocabAnalysis(std::move(VocabVec)); LLVMContext TestCtx; @@ -482,17 +485,17 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) { } TEST(IR2VecVocabularyTest, SlotIdxMapping) { - // Test getSlotIndex for Opcodes + // Test getIndex for Opcodes #define EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS) \ - EXPECT_EQ(Vocabulary::getSlotIndex(NUM), static_cast(NUM - 1)); + EXPECT_EQ(Vocabulary::getIndex(NUM), static_cast(NUM - 1)); #define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE_SLOT(NUM, OPCODE, CLASS) #include "llvm/IR/Instruction.def" #undef HANDLE_INST #undef EXPECT_OPCODE_SLOT - // Test getSlotIndex for Types + // Test getIndex for Types #define EXPECT_TYPE_SLOT(TypeIDTok, CanonEnum, CanonStr) \ - EXPECT_EQ(Vocabulary::getSlotIndex(Type::TypeIDTok), \ + EXPECT_EQ(Vocabulary::getIndex(Type::TypeIDTok), \ MaxOpcodes + static_cast( \ Vocabulary::CanonicalTypeID::CanonEnum)); @@ -500,7 +503,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) { #undef EXPECT_TYPE_SLOT - // Test getSlotIndex for Value operands + // Test getIndex for Value operands LLVMContext Ctx; Module M("TestM", Ctx); FunctionType *FTy = @@ -510,27 +513,27 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) { #define EXPECTED_VOCAB_OPERAND_SLOT(X) \ MaxOpcodes + MaxCanonicalTypeIDs + static_cast(X) // Test Function operand - EXPECT_EQ(Vocabulary::getSlotIndex(*F), + EXPECT_EQ(Vocabulary::getIndex(*F), EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::FunctionID)); // Test Constant operand Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42); - EXPECT_EQ(Vocabulary::getSlotIndex(*C), + EXPECT_EQ(Vocabulary::getIndex(*C), EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::ConstantID)); // Test Pointer operand BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F); AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB); - EXPECT_EQ(Vocabulary::getSlotIndex(*PtrVal), + EXPECT_EQ(Vocabulary::getIndex(*PtrVal), EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::PointerID)); // Test Variable operand (function argument) Argument *Arg = F->getArg(0); - EXPECT_EQ(Vocabulary::getSlotIndex(*Arg), + EXPECT_EQ(Vocabulary::getIndex(*Arg), EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID)); #undef EXPECTED_VOCAB_OPERAND_SLOT - // Test getSlotIndex for predicates + // Test getIndex for predicates #define EXPECTED_VOCAB_PREDICATE_SLOT(X) \ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + static_cast(X) for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE; @@ -538,7 +541,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) { CmpInst::Predicate Pred = static_cast(P); unsigned ExpectedIdx = EXPECTED_VOCAB_PREDICATE_SLOT((P - CmpInst::FIRST_FCMP_PREDICATE)); - EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx); + EXPECT_EQ(Vocabulary::getIndex(Pred), ExpectedIdx); } auto ICMP_Start = CmpInst::LAST_FCMP_PREDICATE + 1; for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE; @@ -546,7 +549,7 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) { CmpInst::Predicate Pred = static_cast(P); unsigned ExpectedIdx = EXPECTED_VOCAB_PREDICATE_SLOT( ICMP_Start + P - CmpInst::FIRST_ICMP_PREDICATE); - EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx); + EXPECT_EQ(Vocabulary::getIndex(Pred), ExpectedIdx); } #undef EXPECTED_VOCAB_PREDICATE_SLOT } @@ -555,15 +558,14 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) { #ifndef NDEBUG TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) { // Test invalid opcode IDs - EXPECT_DEATH(Vocabulary::getSlotIndex(0u), "Invalid opcode"); - EXPECT_DEATH(Vocabulary::getSlotIndex(MaxOpcodes + 1), "Invalid opcode"); + EXPECT_DEATH(Vocabulary::getIndex(0u), "Invalid opcode"); + EXPECT_DEATH(Vocabulary::getIndex(MaxOpcodes + 1), "Invalid opcode"); // Test invalid type IDs - EXPECT_DEATH(Vocabulary::getSlotIndex(static_cast(MaxTypeIDs)), + EXPECT_DEATH(Vocabulary::getIndex(static_cast(MaxTypeIDs)), + "Invalid type ID"); + EXPECT_DEATH(Vocabulary::getIndex(static_cast(MaxTypeIDs + 10)), "Invalid type ID"); - EXPECT_DEATH( - Vocabulary::getSlotIndex(static_cast(MaxTypeIDs + 10)), - "Invalid type ID"); } #endif // NDEBUG #endif // GTEST_HAS_DEATH_TEST @@ -573,7 +575,7 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) { EXPECT_EQ(Vocabulary::getStringKey(12), "Add"); #define EXPECT_OPCODE(NUM, OPCODE, CLASS) \ - EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getSlotIndex(NUM)), \ + EXPECT_EQ(Vocabulary::getStringKey(Vocabulary::getIndex(NUM)), \ Vocabulary::getVocabKeyForOpcode(NUM)); #define HANDLE_INST(NUM, OPCODE, CLASS) EXPECT_OPCODE(NUM, OPCODE, CLASS) #include "llvm/IR/Instruction.def" @@ -672,10 +674,12 @@ TEST(IR2VecVocabularyTest, InvalidAccess) { #endif // GTEST_HAS_DEATH_TEST TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) { + Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest()); #define EXPECT_TYPE_TO_CANONICAL(TypeIDTok, CanonEnum, CanonStr) \ - EXPECT_EQ( \ - Vocabulary::getStringKey(Vocabulary::getSlotIndex(Type::TypeIDTok)), \ - CanonStr); + do { \ + unsigned FlatIdx = V.getIndex(Type::TypeIDTok); \ + EXPECT_EQ(Vocabulary::getStringKey(FlatIdx), CanonStr); \ + } while (0); IR2VEC_HANDLE_TYPE_BIMAP(EXPECT_TYPE_TO_CANONICAL) @@ -683,14 +687,20 @@ TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) { } TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) { - std::vector InvalidVocab; - InvalidVocab.push_back(Embedding(2, 1.0)); - InvalidVocab.push_back(Embedding(2, 2.0)); - - Vocabulary V(std::move(InvalidVocab)); + // Test 1: Create invalid VocabStorage with insufficient sections + std::vector> InvalidSectionData; + // Only add one section with 2 embeddings, but the vocabulary needs 4 sections + std::vector Section1; + Section1.push_back(Embedding(2, 1.0)); + Section1.push_back(Embedding(2, 2.0)); + InvalidSectionData.push_back(std::move(Section1)); + + VocabStorage InvalidStorage(std::move(InvalidSectionData)); + Vocabulary V(std::move(InvalidStorage)); EXPECT_FALSE(V.isValid()); { + // Test 2: Default-constructed vocabulary should be invalid Vocabulary InvalidResult; EXPECT_FALSE(InvalidResult.isValid()); #if GTEST_HAS_DEATH_TEST @@ -701,4 +711,239 @@ TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) { } } +TEST(VocabStorageTest, DefaultConstructor) { + VocabStorage storage; + + EXPECT_EQ(storage.size(), 0u); + EXPECT_EQ(storage.getNumSections(), 0u); + EXPECT_EQ(storage.getDimension(), 0u); + EXPECT_FALSE(storage.isValid()); + + // Test iterators on empty storage + EXPECT_EQ(storage.begin(), storage.end()); +} + +TEST(VocabStorageTest, BasicConstruction) { + // Create test data with 3 sections + std::vector> sectionData; + + // Section 0: 2 embeddings of dimension 3 + std::vector section0; + section0.emplace_back(std::vector{1.0, 2.0, 3.0}); + section0.emplace_back(std::vector{4.0, 5.0, 6.0}); + sectionData.push_back(std::move(section0)); + + // Section 1: 1 embedding of dimension 3 + std::vector section1; + section1.emplace_back(std::vector{7.0, 8.0, 9.0}); + sectionData.push_back(std::move(section1)); + + // Section 2: 3 embeddings of dimension 3 + std::vector section2; + section2.emplace_back(std::vector{10.0, 11.0, 12.0}); + section2.emplace_back(std::vector{13.0, 14.0, 15.0}); + section2.emplace_back(std::vector{16.0, 17.0, 18.0}); + sectionData.push_back(std::move(section2)); + + VocabStorage storage(std::move(sectionData)); + + EXPECT_EQ(storage.size(), 6u); // Total: 2 + 1 + 3 = 6 + EXPECT_EQ(storage.getNumSections(), 3u); + EXPECT_EQ(storage.getDimension(), 3u); + EXPECT_TRUE(storage.isValid()); +} + +TEST(VocabStorageTest, SectionAccess) { + // Create test data + std::vector> sectionData; + + std::vector section0; + section0.emplace_back(std::vector{1.0, 2.0}); + section0.emplace_back(std::vector{3.0, 4.0}); + sectionData.push_back(std::move(section0)); + + std::vector section1; + section1.emplace_back(std::vector{5.0, 6.0}); + sectionData.push_back(std::move(section1)); + + VocabStorage storage(std::move(sectionData)); + + // Test section access + EXPECT_EQ(storage[0].size(), 2u); + EXPECT_EQ(storage[1].size(), 1u); + + // Test embedding values + EXPECT_THAT(storage[0][0].getData(), ElementsAre(1.0, 2.0)); + EXPECT_THAT(storage[0][1].getData(), ElementsAre(3.0, 4.0)); + EXPECT_THAT(storage[1][0].getData(), ElementsAre(5.0, 6.0)); +} + +#if GTEST_HAS_DEATH_TEST +#ifndef NDEBUG +TEST(VocabStorageTest, InvalidSectionAccess) { + std::vector> sectionData; + std::vector section0; + section0.emplace_back(std::vector{1.0, 2.0}); + sectionData.push_back(std::move(section0)); + + VocabStorage storage(std::move(sectionData)); + + EXPECT_DEATH(storage[1], "Invalid section ID"); + EXPECT_DEATH(storage[10], "Invalid section ID"); +} + +TEST(VocabStorageTest, EmptySection) { + std::vector> sectionData; + std::vector emptySection; // Empty section + sectionData.push_back(std::move(emptySection)); + + std::vector validSection; + validSection.emplace_back(std::vector{1.0}); + sectionData.push_back(std::move(validSection)); + + EXPECT_DEATH(VocabStorage(std::move(sectionData)), + "First section of vocabulary is empty"); +} + +TEST(VocabStorageTest, NoSections) { + std::vector> sectionData; // No sections + + EXPECT_DEATH(VocabStorage(std::move(sectionData)), + "Vocabulary has no sections"); +} +#endif // NDEBUG +#endif // GTEST_HAS_DEATH_TEST + +TEST(VocabStorageTest, MoveAssignment) { + // Create source storage + std::vector> sectionData1; + std::vector section0; + section0.emplace_back(std::vector{1.0, 2.0}); + sectionData1.push_back(std::move(section0)); + VocabStorage source(std::move(sectionData1)); + + // Create destination storage + std::vector> sectionData2; + std::vector section1; + section1.emplace_back(std::vector{5.0, 6.0, 7.0}); + sectionData2.push_back(std::move(section1)); + VocabStorage dest(std::move(sectionData2)); + + EXPECT_EQ(dest.getDimension(), 3u); // Initially 3D + + // Move assign + dest = std::move(source); + + // Check destination has source's data + EXPECT_EQ(dest.size(), 1u); + EXPECT_EQ(dest.getDimension(), 2u); // Now 2D from source + EXPECT_TRUE(dest.isValid()); + EXPECT_THAT(dest[0][0].getData(), ElementsAre(1.0, 2.0)); +} + +TEST(VocabStorageTest, IteratorBasics) { + std::vector> sectionData; + + std::vector section0; + section0.emplace_back(std::vector{1.0, 2.0}); + section0.emplace_back(std::vector{3.0, 4.0}); + sectionData.push_back(std::move(section0)); + + std::vector section1; + section1.emplace_back(std::vector{5.0, 6.0}); + sectionData.push_back(std::move(section1)); + + VocabStorage storage(std::move(sectionData)); + + // Test iterator basics + auto it = storage.begin(); + auto end = storage.end(); + + EXPECT_NE(it, end); + + // Check first embedding + EXPECT_THAT((*it).getData(), ElementsAre(1.0, 2.0)); + + // Advance to second embedding + ++it; + EXPECT_NE(it, end); + EXPECT_THAT((*it).getData(), ElementsAre(3.0, 4.0)); + + // Advance to third embedding (in section 1) + ++it; + EXPECT_NE(it, end); + EXPECT_THAT((*it).getData(), ElementsAre(5.0, 6.0)); + + // Advance past the end + ++it; + EXPECT_EQ(it, end); +} + +TEST(VocabStorageTest, IteratorTraversal) { + std::vector> sectionData; + + // Section 0: 2 embeddings + std::vector section0; + section0.emplace_back(std::vector{10.0}); + section0.emplace_back(std::vector{20.0}); + sectionData.push_back(std::move(section0)); + + // Section 1: empty section (to test section skipping) + std::vector section1; // Empty + sectionData.push_back(std::move(section1)); + + // Section 2: 3 embeddings + std::vector section2; + section2.emplace_back(std::vector{30.0}); + section2.emplace_back(std::vector{40.0}); + section2.emplace_back(std::vector{50.0}); + sectionData.push_back(std::move(section2)); + + VocabStorage storage(std::move(sectionData)); + + // Collect all values using iterator + std::vector values; + for (const auto &emb : storage) { + EXPECT_EQ(emb.size(), 1u); + values.push_back(emb[0]); + } + + // Should get all embeddings from non-empty sections + EXPECT_THAT(values, ElementsAre(10.0, 20.0, 30.0, 40.0, 50.0)); +} + +TEST(VocabStorageTest, IteratorComparison) { + std::vector> sectionData; + std::vector section0; + section0.emplace_back(std::vector{1.0}); + section0.emplace_back(std::vector{2.0}); + sectionData.push_back(std::move(section0)); + + VocabStorage storage(std::move(sectionData)); + + auto it1 = storage.begin(); + auto it2 = storage.begin(); + auto end = storage.end(); + + // Test equality + EXPECT_EQ(it1, it2); + EXPECT_NE(it1, end); + + // Advance one iterator + ++it1; + EXPECT_NE(it1, it2); + EXPECT_NE(it1, end); + + // Advance second iterator to match + ++it2; + EXPECT_EQ(it1, it2); + + // Advance both to end + ++it1; + ++it2; + EXPECT_EQ(it1, end); + EXPECT_EQ(it2, end); + EXPECT_EQ(it1, it2); +} + } // end anonymous namespace