[IR2Vec] Refactor vocabulary to use canonical type IDs #155323
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

Changes

Refactor IR2Vec vocabulary to use canonical type IDs, improving the embedding representation for LLVM IR types.

The previous implementation used raw Type::TypeID values directly in the vocabulary, which led to redundant entries (e.g., all float variants mapped to "FloatTy" but had separate slots). This change improves the vocabulary by:
Patch is 41.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155323.diff 9 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 7ace83ba1d053..eb3860ea1e488 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/JSON.h"
+#include <array>
#include <map>
namespace llvm {
@@ -137,13 +138,48 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
/// Class for storing and accessing the IR2Vec vocabulary.
-/// Encapsulates all vocabulary-related constants, logic, and access methods.
+///
+/// The Vocabulary class manages seed embeddings for LLVM IR entities. It
+/// contains the seed embeddings for three types of entities: instruction
+/// opcodes, types, and operands. Types are grouped/canonicalized for better
+/// learning (e.g., all float variants map to FloatTy). The vocabulary abstracts
+/// away the canonicalization effectively, the exposed APIs handle all the known
+/// LLVM IR opcodes, types and operands.
+///
+/// This class helps populate the seed embeddings in an internal vector-based
+/// ADT. It provides logic to map every IR entity to a specific slot index or
+/// position in this vector, enabling O(1) embedding lookup while avoiding
+/// unnecessary computations involving string based lookups while generating the
+/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
bool Valid = false;
+public:
+ // Slot layout:
+ // [0 .. MaxOpcodes-1] => Instruction opcodes
+ // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
+ // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
+
+ /// Canonical type IDs supported by IR2Vec Vocabulary
+ enum class CanonicalTypeID : unsigned {
+ FloatTy,
+ VoidTy,
+ LabelTy,
+ MetadataTy,
+ VectorTy,
+ TokenTy,
+ IntegerTy,
+ FunctionTy,
+ PointerTy,
+ StructTy,
+ ArrayTy,
+ UnknownTy,
+ MaxCanonicalType
+ };
+
/// Operand kinds supported by IR2Vec Vocabulary
enum class OperandKind : unsigned {
FunctionID,
@@ -152,20 +188,15 @@ class Vocabulary {
VariableID,
MaxOperandKind
};
- /// String mappings for OperandKind values
- static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
- "Constant", "Variable"};
- static_assert(std::size(OperandKindNames) ==
- static_cast<unsigned>(OperandKind::MaxOperandKind),
- "OperandKindNames array size must match MaxOperandKind");
-public:
/// Vocabulary layout constants
#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
#include "llvm/IR/Instruction.def"
#undef LAST_OTHER_INST
static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
+ static constexpr unsigned MaxCanonicalTypeIDs =
+ static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);
@@ -174,11 +205,8 @@ class Vocabulary {
LLVM_ABI bool isValid() const;
LLVM_ABI unsigned getDimension() const;
- LLVM_ABI size_t size() const;
-
- static size_t expectedSize() {
- return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
- }
+ /// Total number of entries (opcodes + canonicalized types + operand kinds)
+ static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
/// Helper function to get vocabulary key for a given Opcode
LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
@@ -192,10 +220,11 @@ class Vocabulary {
/// Helper function to classify an operand into OperandKind
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
- /// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
- LLVM_ABI static unsigned getNumericID(unsigned Opcode);
- LLVM_ABI static unsigned getNumericID(Type::TypeID TypeID);
- LLVM_ABI static unsigned getNumericID(const Value *Op);
+ /// Helpers to return the slot index or position of a given Opcode, TypeID, or
+ /// OperandKind in the vocabulary.
+ LLVM_ABI static unsigned getSlotIdx(unsigned Opcode);
+ LLVM_ABI static unsigned getSlotIdx(Type::TypeID TypeID);
+ LLVM_ABI static unsigned getSlotIdx(const Value *Op);
/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
@@ -234,6 +263,61 @@ class Vocabulary {
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv) const;
+
+private:
+ constexpr static unsigned NumCanonicalEntries =
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+
+ /// String mappings for CanonicalTypeID values
+ static constexpr StringLiteral CanonicalTypeNames[] = {
+ "FloatTy", "VoidTy", "LabelTy", "MetadataTy",
+ "VectorTy", "TokenTy", "IntegerTy", "FunctionTy",
+ "PointerTy", "StructTy", "ArrayTy", "UnknownTy"};
+ static_assert(std::size(CanonicalTypeNames) ==
+ static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType),
+ "CanonicalTypeNames array size must match MaxCanonicalType");
+
+ /// String mappings for OperandKind values
+ static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
+ "Constant", "Variable"};
+ static_assert(std::size(OperandKindNames) ==
+ static_cast<unsigned>(OperandKind::MaxOperandKind),
+ "OperandKindNames array size must match MaxOperandKind");
+
+ /// Every known TypeID defined in llvm/IR/Type.h is expected to have a
+ /// corresponding mapping here in the same order as enum Type::TypeID.
+ static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
+ CanonicalTypeID::FloatTy, // HalfTyID = 0
+ CanonicalTypeID::FloatTy, // BFloatTyID
+ CanonicalTypeID::FloatTy, // FloatTyID
+ CanonicalTypeID::FloatTy, // DoubleTyID
+ CanonicalTypeID::FloatTy, // X86_FP80TyID
+ CanonicalTypeID::FloatTy, // FP128TyID
+ CanonicalTypeID::FloatTy, // PPC_FP128TyID
+ CanonicalTypeID::VoidTy, // VoidTyID
+ CanonicalTypeID::LabelTy, // LabelTyID
+ CanonicalTypeID::MetadataTy, // MetadataTyID
+ CanonicalTypeID::VectorTy, // X86_AMXTyID
+ CanonicalTypeID::TokenTy, // TokenTyID
+ CanonicalTypeID::IntegerTy, // IntegerTyID
+ CanonicalTypeID::FunctionTy, // FunctionTyID
+ CanonicalTypeID::PointerTy, // PointerTyID
+ CanonicalTypeID::StructTy, // StructTyID
+ CanonicalTypeID::ArrayTy, // ArrayTyID
+ CanonicalTypeID::VectorTy, // FixedVectorTyID
+ CanonicalTypeID::VectorTy, // ScalableVectorTyID
+ CanonicalTypeID::PointerTy, // TypedPointerTyID
+ CanonicalTypeID::UnknownTy // TargetExtTyID
+ }};
+ static_assert(TypeIDMapping.size() == MaxTypeIDs,
+ "TypeIDMapping must cover all Type::TypeID values");
+
+ /// Helper function to get vocabulary key for canonical type by enum
+ LLVM_ABI static StringRef
+ getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
+
+ /// Helper function to convert TypeID to CanonicalTypeID
+ LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
};
/// Embedder provides the interface to generate embeddings (vector
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index e28938b64bfdb..e92dfca0c4ac6 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -32,7 +32,7 @@ using namespace ir2vec;
#define DEBUG_TYPE "ir2vec"
STATISTIC(VocabMissCounter,
- "Number of lookups to entites not present in the vocabulary");
+ "Number of lookups to entities not present in the vocabulary");
namespace llvm {
namespace ir2vec {
@@ -264,12 +264,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab)
: Vocab(std::move(Vocab)), Valid(true) {}
bool Vocabulary::isValid() const {
- return Vocab.size() == Vocabulary::expectedSize() && Valid;
-}
-
-size_t Vocabulary::size() const {
- assert(Valid && "IR2Vec Vocabulary is invalid");
- return Vocab.size();
+ return Vocab.size() == NumCanonicalEntries && Valid;
}
unsigned Vocabulary::getDimension() const {
@@ -277,19 +272,32 @@ unsigned Vocabulary::getDimension() const {
return Vocab[0].size();
}
-const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+unsigned Vocabulary::getSlotIdx(unsigned Opcode) {
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
- return Vocab[Opcode - 1];
+ return Opcode - 1; // Convert to zero-based index
+}
+
+unsigned Vocabulary::getSlotIdx(Type::TypeID TypeID) {
+ assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+ return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
+}
+
+unsigned Vocabulary::getSlotIdx(const Value *Op) {
+ unsigned Index = static_cast<unsigned>(getOperandKind(Op));
+ assert(Index < MaxOperandKinds && "Invalid OperandKind");
+ return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+}
+
+const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+ return Vocab[getSlotIdx(Opcode)];
}
-const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
- assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");
- return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
+const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const {
+ return Vocab[getSlotIdx(TypeID)];
}
const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
- OperandKind ArgKind = getOperandKind(Arg);
- return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
+ return Vocab[getSlotIdx(Arg)];
}
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
@@ -303,43 +311,21 @@ StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
return "UnknownOpcode";
}
+StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
+ unsigned Index = static_cast<unsigned>(CType);
+ assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
+ return CanonicalTypeNames[Index];
+}
+
+Vocabulary::CanonicalTypeID
+Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) {
+ unsigned Index = static_cast<unsigned>(TypeID);
+ assert(Index < MaxTypeIDs && "Invalid TypeID");
+ return TypeIDMapping[Index];
+}
+
StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
- switch (TypeID) {
- case Type::VoidTyID:
- return "VoidTy";
- case Type::HalfTyID:
- case Type::BFloatTyID:
- case Type::FloatTyID:
- case Type::DoubleTyID:
- case Type::X86_FP80TyID:
- case Type::FP128TyID:
- case Type::PPC_FP128TyID:
- return "FloatTy";
- case Type::IntegerTyID:
- return "IntegerTy";
- case Type::FunctionTyID:
- return "FunctionTy";
- case Type::StructTyID:
- return "StructTy";
- case Type::ArrayTyID:
- return "ArrayTy";
- case Type::PointerTyID:
- case Type::TypedPointerTyID:
- return "PointerTy";
- case Type::FixedVectorTyID:
- case Type::ScalableVectorTyID:
- return "VectorTy";
- case Type::LabelTyID:
- return "LabelTy";
- case Type::TokenTyID:
- return "TokenTy";
- case Type::MetadataTyID:
- return "MetadataTy";
- case Type::X86_AMXTyID:
- case Type::TargetExtTyID:
- return "UnknownTy";
- }
- return "UnknownTy";
+ return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
}
StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
@@ -348,20 +334,6 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
return OperandKindNames[Index];
}
-Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
- VocabVector DummyVocab;
- float DummyVal = 0.1f;
- // Create a dummy vocabulary with entries for all opcodes, types, and
- // operand
- for ([[maybe_unused]] unsigned _ :
- seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +
- Vocabulary::MaxOperandKinds)) {
- DummyVocab.push_back(Embedding(Dim, DummyVal));
- DummyVal += 0.1f;
- }
- return DummyVocab;
-}
-
// Helper function to classify an operand into OperandKind
Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
if (isa<Function>(Op))
@@ -373,34 +345,18 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
-unsigned Vocabulary::getNumericID(unsigned Opcode) {
- assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
- return Opcode - 1; // Convert to zero-based index
-}
-
-unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
- assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
- return MaxOpcodes + static_cast<unsigned>(TypeID);
-}
-
-unsigned Vocabulary::getNumericID(const Value *Op) {
- unsigned Index = static_cast<unsigned>(getOperandKind(Op));
- assert(Index < MaxOperandKinds && "Invalid OperandKind");
- return MaxOpcodes + MaxTypeIDs + Index;
-}
-
StringRef Vocabulary::getStringKey(unsigned Pos) {
- assert(Pos < Vocabulary::expectedSize() &&
- "Position out of bounds in vocabulary");
+ assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
// Opcode
if (Pos < MaxOpcodes)
return getVocabKeyForOpcode(Pos + 1);
// Type
- if (Pos < MaxOpcodes + MaxTypeIDs)
- return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
+ if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+ return getVocabKeyForCanonicalTypeID(
+ static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
// Operand
return getVocabKeyForOperandKind(
- static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs));
+ static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
}
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -410,6 +366,21 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
return !(PAC.preservedWhenStateless());
}
+Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
+ VocabVector DummyVocab;
+ DummyVocab.reserve(NumCanonicalEntries);
+ float DummyVal = 0.1f;
+ // Create a dummy vocabulary with entries for all opcodes, types, and
+ // operands
+ for ([[maybe_unused]] unsigned _ :
+ seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
+ Vocabulary::MaxOperandKinds)) {
+ DummyVocab.push_back(Embedding(Dim, DummyVal));
+ DummyVal += 0.1f;
+ }
+ return DummyVocab;
+}
+
// ==----------------------------------------------------------------------===//
// IR2VecVocabAnalysis
//===----------------------------------------------------------------------===//
@@ -502,6 +473,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
Embedding(Dim, 0));
+ NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
auto It = OpcVocab.find(VocabKey.str());
@@ -513,14 +485,15 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
NumericOpcodeEmbeddings.end());
- // Handle Types
- std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
+ // Handle Types - only canonical types are present in vocabulary
+ std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
Embedding(Dim, 0));
- for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
- StringRef VocabKey =
- Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID));
+ NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
+ for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
+ StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
+ static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
- NumericTypeEmbeddings[TypeID] = It->second;
+ NumericTypeEmbeddings[CTypeID] = It->second;
continue;
}
handleMissingEntity(VocabKey.str());
@@ -531,6 +504,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Arguments/Operands
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
Embedding(Dim, 0));
+ NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index 1b9b3c2acd8a5..df7769c9c6a65 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue: [ 129.00 130.00 ]
Key: LandingPad: [ 131.00 132.00 ]
Key: Freeze: [ 133.00 134.00 ]
Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
Key: VoidTy: [ 1.50 2.00 ]
Key: LabelTy: [ 2.50 3.00 ]
Key: MetadataTy: [ 3.50 4.00 ]
-Key: UnknownTy: [ 4.50 5.00 ]
+Key: VectorTy: [ 11.50 12.00 ]
Key: TokenTy: [ 5.50 6.00 ]
Key: IntegerTy: [ 6.50 7.00 ]
Key: FunctionTy: [ 7.50 8.00 ]
Key: PointerTy: [ 8.50 9.00 ]
Key: StructTy: [ 9.50 10.00 ]
Key: ArrayTy: [ 10.50 11.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: PointerTy: [ 8.50 9.00 ]
Key: UnknownTy: [ 4.50 5.00 ]
Key: Function: [ 0.20 0.40 ]
Key: Pointer: [ 0.60 0.80 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index 9673e7f23fa5c..f3ce809fd2fd2 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue: [ 64.50 65.00 ]
Key: LandingPad: [ 65.50 66.00 ]
Key: Freeze: [ 66.50 67.00 ]
Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
-Key: FloatTy: [ 0.50 1.00 ]
Key: VoidTy: [ 1.50 2.00 ]
Key: LabelTy: [ 2.50 3.00 ]
Key: MetadataTy: [ 3.50 4.00 ]
-Key: UnknownTy: [ 4.50 5.00 ]
+Key: VectorTy: [ 11.50 12.00 ]
Key: TokenTy: [ 5.50 6.00 ]
Key: IntegerTy: [ 6.50 7.00 ]
Key: FunctionTy: [ 7.50 8.00 ]
Key: PointerTy: [ 8.50 9.00 ]
Key: StructTy: [ 9.50 10.00 ]
Key: ArrayTy: [ 10.50 11.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: VectorTy: [ 11.50 12.00 ]
-Key: PointerTy: [ 8.50 9.00 ]
Key: UnknownTy: [ 4.50 5.00 ]
Key: Function: [ 0.50 1.00 ]
Key: Pointer: [ 1.50 2.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 1f575d29092dd..72b25b9bd3d9c 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue: [ 12.90 13.00 ]
Key: LandingPad: [ 13.10 13.20 ]
Key: Freeze: [ 13.30 13.40 ]
Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
-Key: FloatTy: [ 0.00 0.00 ]
Key: VoidTy: [ 0.00 0.00 ]
Key: LabelTy: [ 0.00 0.00 ]
Key: MetadataTy: [ 0.00 0.00 ]
-Key: UnknownTy: [ ...
[truncated]
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
bool Valid = false;

public:
// Slot layout:
this comment is about an internal implementation detail, correct?
Yes, correct.
can you move it to private?
Not for this patch, but curious (and if it makes sense, it can be addressed later): why not 3 vectors rather than indexing within one? Also, IIUC the tools have some understanding that this is the case (there's one vector with slots) and thus need to use getSlotIndex. Is this an artifact of how serialization happens? Maybe that can be captured in the comment here. To be clear, I don't think 3 vectors vs one is better or worse; it is just a slightly confusing design choice, so a comment may help.
Yes, we can do this refactoring. However, the tool needs consecutive indexing of entities while dumping the triplets, so this logic should either stay in ir2vec::Vocabulary or be moved to the tool.
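The single-vector layout discussed in this thread can be sketched as plain index arithmetic. This is a minimal illustration, not the LLVM implementation: the segment sizes and the names `opcodeSlot`, `typeSlot`, `operandSlot`, `decodeSlot`, and `SlotInfo` are hypothetical, chosen only to show how three kinds of entities share one consecutive index range and how a position maps back to its segment.

```cpp
// Hypothetical segment sizes for illustration only; in the patch the real
// values come from Instruction.def and the CanonicalTypeID/OperandKind enums.
constexpr unsigned NumOpcodes = 4;
constexpr unsigned NumTypes = 3;
constexpr unsigned NumOperandKinds = 2;
constexpr unsigned NumEntries = NumOpcodes + NumTypes + NumOperandKinds;

enum class Segment { Opcode, Type, Operand };

// A decoded slot: which segment it belongs to and its index in that segment.
struct SlotInfo {
  Segment Seg;
  unsigned Local;
};

// Opcodes are assumed to be 1-based here, so they shift down by one.
constexpr unsigned opcodeSlot(unsigned Opcode) { return Opcode - 1; }

// Types follow the opcode segment.
constexpr unsigned typeSlot(unsigned CanonicalType) {
  return NumOpcodes + CanonicalType;
}

// Operand kinds follow the type segment.
constexpr unsigned operandSlot(unsigned Kind) {
  return NumOpcodes + NumTypes + Kind;
}

// Inverse direction: map a flat position back to (segment, local index).
constexpr SlotInfo decodeSlot(unsigned Pos) {
  return Pos < NumOpcodes
             ? SlotInfo{Segment::Opcode, Pos}
             : Pos < NumOpcodes + NumTypes
                   ? SlotInfo{Segment::Type, Pos - NumOpcodes}
                   : SlotInfo{Segment::Operand, Pos - NumOpcodes - NumTypes};
}

// The last operand kind lands in the last slot of the flat vector.
static_assert(operandSlot(NumOperandKinds - 1) == NumEntries - 1,
              "segments must tile the flat index range exactly");
```

With these toy sizes, `typeSlot(0)` is 4 and `decodeSlot(5)` reports the type segment with local index 1. A three-vector design would drop the offsets, but any consumer that needs consecutive IDs across all entities would then have to reintroduce the same arithmetic.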
@@ -162,8 +162,8 @@ class IR2VecTool {

for (const BasicBlock &BB : F) {
for (const auto &I : BB.instructionsWithoutDebug()) {
nit for later: should the iteration over BBs be a utility in ir2vec somehow? Basically, the reflex is to do
for (const auto &BB : F)
for (const auto &I : BB)
// do stuff to I
meaning that BB.instructionsWithoutDebug() is not immediately discoverable; it is not the first place one goes to when coding.
One idea (again, later patch): what if the ir2vec APIs like getSlotIndex returned a "null" value for debug info? This could be a configurable option of the vocab or something like that. When used through ir2vec, it would have no effect (its embedding would be the 0 tensor, for instance).
just a thought. noop here.
Yes, it makes sense. Will track it separately.
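The reviewer's "null value" idea can be sketched with toy types. Everything below is hypothetical and does not use any LLVM API: `ToyInst`, `ToyVocab`, `lookup`, and `sumEmbeddings` are illustrative names, and the assumption is that a debug instruction resolves to an all-zero embedding so that a plain loop over every instruction gives the same sum as a loop that skips debug instructions.

```cpp
#include <array>
#include <cstddef>

using Embedding = std::array<float, 2>;

// Toy instruction: a 1-based opcode plus a flag marking debug instructions.
struct ToyInst {
  unsigned Opcode;
  bool IsDebug;
};

// Toy seed embeddings for three opcodes (values are arbitrary).
constexpr std::array<Embedding, 3> ToyVocab = {
    {{{1.0f, 2.0f}}, {{3.0f, 4.0f}}, {{5.0f, 6.0f}}}};
constexpr Embedding ZeroEmbedding = {{0.0f, 0.0f}};

// Debug instructions resolve to the zero embedding, so adding them is a no-op.
constexpr Embedding lookup(const ToyInst &I) {
  return I.IsDebug ? ZeroEmbedding : ToyVocab[I.Opcode - 1];
}

// Sum embeddings over all instructions without filtering debug ones.
template <std::size_t N>
constexpr Embedding sumEmbeddings(const std::array<ToyInst, N> &Insts) {
  float S0 = 0.0f, S1 = 0.0f;
  for (std::size_t Idx = 0; Idx < N; ++Idx) {
    const Embedding E = lookup(Insts[Idx]);
    S0 += E[0];
    S1 += E[1];
  }
  return Embedding{{S0, S1}};
}

// A block with one debug instruction in the middle.
constexpr std::array<ToyInst, 3> ExampleBlock = {
    {{1u, false}, {2u, true}, {3u, false}}};
constexpr Embedding ExampleSum = sumEmbeddings(ExampleBlock);
```

Here the middle instruction contributes nothing, so `ExampleSum` only reflects opcodes 1 and 3. Whether the real vocabulary should behave this way is the open design question raised in the comment; this sketch only shows that the arithmetic works out.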
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/18347 Here is the relevant piece of the build log for the reference
(Tracking issue - #141817)
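The redundancy described in the PR summary can be illustrated with a toy table-driven mapping. This is a simplified sketch, not the patch's code: `RawTypeID`, `CanonicalID`, `RawToCanonical`, and `canonicalize` are hypothetical stand-ins for `Type::TypeID`, `CanonicalTypeID`, `TypeIDMapping`, and `getCanonicalTypeID`. In the diff above, the mapping table has 21 entries and the canonical enum has 12 values; the toy version uses 6 and 3.

```cpp
#include <array>
#include <cstddef>

// Toy raw ID space with several float-like and vector-like variants.
enum class RawTypeID : unsigned {
  Half,
  Float,
  Double,
  Integer,
  FixedVector,
  ScalableVector,
  NumRawTypes
};

// Smaller canonical space: one ID per group of equivalent raw types.
enum class CanonicalID : unsigned { FloatTy, IntegerTy, VectorTy, NumCanonical };

constexpr std::size_t NumRaw = static_cast<std::size_t>(RawTypeID::NumRawTypes);
constexpr std::size_t NumCanonical =
    static_cast<std::size_t>(CanonicalID::NumCanonical);

// One entry per raw ID, listed in the same order as RawTypeID.
constexpr std::array<CanonicalID, NumRaw> RawToCanonical = {{
    CanonicalID::FloatTy,   // Half
    CanonicalID::FloatTy,   // Float
    CanonicalID::FloatTy,   // Double
    CanonicalID::IntegerTy, // Integer
    CanonicalID::VectorTy,  // FixedVector
    CanonicalID::VectorTy   // ScalableVector
}};

// Table lookup replaces a switch over every raw type.
constexpr CanonicalID canonicalize(RawTypeID ID) {
  return RawToCanonical[static_cast<std::size_t>(ID)];
}

// Storing one embedding per canonical ID needs 3 slots instead of 6.
static_assert(NumCanonical < NumRaw, "canonical space must be smaller");
```

Because every raw ID in a group resolves to the same canonical ID, the vocabulary only needs one slot per group, which is the effect visible in the reference vocabulary printouts where repeated FloatTy, VectorTy, and PointerTy rows disappear.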