Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Conversation

svkeerthy
Copy link
Contributor

@svkeerthy svkeerthy commented Aug 25, 2025

Refactor IR2Vec vocabulary to use canonical type IDs, improving the embedding representation for LLVM IR types.

The previous implementation used raw Type::TypeID values directly in the vocabulary, which led to redundant entries (e.g., all float variants mapped to "FloatTy" but had separate slots). This change improves the vocabulary by:

  1. Making the type representation more consistent by properly canonicalizing types
  2. Reducing vocabulary size by eliminating redundant entries
  3. Improving the embedding quality by ensuring similar types share the same representation

(Tracking issue - #141817)

@svkeerthy svkeerthy changed the title Canonicalized type [IR2Vec] Refactor vocabulary to use canonical type IDs Aug 25, 2025
@svkeerthy svkeerthy marked this pull request as ready for review August 25, 2025 23:03
@llvmbot llvmbot added mlgo llvm:analysis Includes value tracking, cost tables and constant folding labels Aug 25, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 25, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

Changes

Refactor IR2Vec vocabulary to use canonical type IDs, improving the embedding representation for LLVM IR types.

The previous implementation used raw Type::TypeID values directly in the vocabulary, which led to redundant entries (e.g., all float variants mapped to "FloatTy" but had separate slots). This change improves the vocabulary by:

  1. Making the type representation more consistent by properly canonicalizing types
  2. Reducing vocabulary size by eliminating redundant entries
  3. Improving the embedding quality by ensuring similar types share the same representation

Patch is 41.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155323.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+101-17)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+65-91)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt (+1-10)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt (+1-10)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt (+1-10)
  • (modified) llvm/test/tools/llvm-ir2vec/entities.ll (+16-25)
  • (modified) llvm/test/tools/llvm-ir2vec/triplets.ll (+29-29)
  • (modified) llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp (+4-10)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+117-95)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 7ace83ba1d053..eb3860ea1e488 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/ErrorOr.h"
 #include "llvm/Support/JSON.h"
+#include <array>
 #include <map>
 
 namespace llvm {
@@ -137,13 +138,48 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
 using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 
 /// Class for storing and accessing the IR2Vec vocabulary.
-/// Encapsulates all vocabulary-related constants, logic, and access methods.
+///
+/// The Vocabulary class manages seed embeddings for LLVM IR entities. It
+/// contains the seed embeddings for three types of entities: instruction
+/// opcodes, types, and operands. Types are grouped/canonicalized for better
+/// learning (e.g., all float variants map to FloatTy). The vocabulary abstracts
+/// away the canonicalization effectively, the exposed APIs handle all the known
+/// LLVM IR opcodes, types and operands.
+///
+/// This class helps populate the seed embeddings in an internal vector-based
+/// ADT. It provides logic to map every IR entity to a specific slot index or
+/// position in this vector, enabling O(1) embedding lookup while avoiding
+/// unnecessary computations involving string based lookups while generating the
+/// embeddings.
 class Vocabulary {
   friend class llvm::IR2VecVocabAnalysis;
   using VocabVector = std::vector<ir2vec::Embedding>;
   VocabVector Vocab;
   bool Valid = false;
 
+public:
+  // Slot layout:
+  // [0 .. MaxOpcodes-1]               => Instruction opcodes
+  // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
+  // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
+
+  /// Canonical type IDs supported by IR2Vec Vocabulary
+  enum class CanonicalTypeID : unsigned {
+    FloatTy,
+    VoidTy,
+    LabelTy,
+    MetadataTy,
+    VectorTy,
+    TokenTy,
+    IntegerTy,
+    FunctionTy,
+    PointerTy,
+    StructTy,
+    ArrayTy,
+    UnknownTy,
+    MaxCanonicalType
+  };
+
   /// Operand kinds supported by IR2Vec Vocabulary
   enum class OperandKind : unsigned {
     FunctionID,
@@ -152,20 +188,15 @@ class Vocabulary {
     VariableID,
     MaxOperandKind
   };
-  /// String mappings for OperandKind values
-  static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
-                                                       "Constant", "Variable"};
-  static_assert(std::size(OperandKindNames) ==
-                    static_cast<unsigned>(OperandKind::MaxOperandKind),
-                "OperandKindNames array size must match MaxOperandKind");
 
-public:
   /// Vocabulary layout constants
 #define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
 #include "llvm/IR/Instruction.def"
 #undef LAST_OTHER_INST
 
   static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
+  static constexpr unsigned MaxCanonicalTypeIDs =
+      static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
   static constexpr unsigned MaxOperandKinds =
       static_cast<unsigned>(OperandKind::MaxOperandKind);
 
@@ -174,11 +205,8 @@ class Vocabulary {
 
   LLVM_ABI bool isValid() const;
   LLVM_ABI unsigned getDimension() const;
-  LLVM_ABI size_t size() const;
-
-  static size_t expectedSize() {
-    return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
-  }
+  /// Total number of entries (opcodes + canonicalized types + operand kinds)
+  static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
 
   /// Helper function to get vocabulary key for a given Opcode
   LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
@@ -192,10 +220,11 @@ class Vocabulary {
   /// Helper function to classify an operand into OperandKind
   LLVM_ABI static OperandKind getOperandKind(const Value *Op);
 
-  /// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
-  LLVM_ABI static unsigned getNumericID(unsigned Opcode);
-  LLVM_ABI static unsigned getNumericID(Type::TypeID TypeID);
-  LLVM_ABI static unsigned getNumericID(const Value *Op);
+  /// Helpers to return the slot index or position of a given Opcode, TypeID, or
+  /// OperandKind in the vocabulary.
+  LLVM_ABI static unsigned getSlotIdx(unsigned Opcode);
+  LLVM_ABI static unsigned getSlotIdx(Type::TypeID TypeID);
+  LLVM_ABI static unsigned getSlotIdx(const Value *Op);
 
   /// Accessors to get the embedding for a given entity.
   LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
@@ -234,6 +263,61 @@ class Vocabulary {
 
   LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
                            ModuleAnalysisManager::Invalidator &Inv) const;
+
+private:
+  constexpr static unsigned NumCanonicalEntries =
+      MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+
+  /// String mappings for CanonicalTypeID values
+  static constexpr StringLiteral CanonicalTypeNames[] = {
+      "FloatTy",   "VoidTy",   "LabelTy",   "MetadataTy",
+      "VectorTy",  "TokenTy",  "IntegerTy", "FunctionTy",
+      "PointerTy", "StructTy", "ArrayTy",   "UnknownTy"};
+  static_assert(std::size(CanonicalTypeNames) ==
+                    static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType),
+                "CanonicalTypeNames array size must match MaxCanonicalType");
+
+  /// String mappings for OperandKind values
+  static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
+                                                       "Constant", "Variable"};
+  static_assert(std::size(OperandKindNames) ==
+                    static_cast<unsigned>(OperandKind::MaxOperandKind),
+                "OperandKindNames array size must match MaxOperandKind");
+
+  /// Every known TypeID defined in llvm/IR/Type.h is expected to have a
+  /// corresponding mapping here in the same order as enum Type::TypeID.
+  static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
+      CanonicalTypeID::FloatTy,    // HalfTyID = 0
+      CanonicalTypeID::FloatTy,    // BFloatTyID
+      CanonicalTypeID::FloatTy,    // FloatTyID
+      CanonicalTypeID::FloatTy,    // DoubleTyID
+      CanonicalTypeID::FloatTy,    // X86_FP80TyID
+      CanonicalTypeID::FloatTy,    // FP128TyID
+      CanonicalTypeID::FloatTy,    // PPC_FP128TyID
+      CanonicalTypeID::VoidTy,     // VoidTyID
+      CanonicalTypeID::LabelTy,    // LabelTyID
+      CanonicalTypeID::MetadataTy, // MetadataTyID
+      CanonicalTypeID::VectorTy,   // X86_AMXTyID
+      CanonicalTypeID::TokenTy,    // TokenTyID
+      CanonicalTypeID::IntegerTy,  // IntegerTyID
+      CanonicalTypeID::FunctionTy, // FunctionTyID
+      CanonicalTypeID::PointerTy,  // PointerTyID
+      CanonicalTypeID::StructTy,   // StructTyID
+      CanonicalTypeID::ArrayTy,    // ArrayTyID
+      CanonicalTypeID::VectorTy,   // FixedVectorTyID
+      CanonicalTypeID::VectorTy,   // ScalableVectorTyID
+      CanonicalTypeID::PointerTy,  // TypedPointerTyID
+      CanonicalTypeID::UnknownTy   // TargetExtTyID
+  }};
+  static_assert(TypeIDMapping.size() == MaxTypeIDs,
+                "TypeIDMapping must cover all Type::TypeID values");
+
+  /// Helper function to get vocabulary key for canonical type by enum
+  LLVM_ABI static StringRef
+  getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
+
+  /// Helper function to convert TypeID to CanonicalTypeID
+  LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
 };
 
 /// Embedder provides the interface to generate embeddings (vector
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index e28938b64bfdb..e92dfca0c4ac6 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -32,7 +32,7 @@ using namespace ir2vec;
 #define DEBUG_TYPE "ir2vec"
 
 STATISTIC(VocabMissCounter,
-          "Number of lookups to entites not present in the vocabulary");
+          "Number of lookups to entities not present in the vocabulary");
 
 namespace llvm {
 namespace ir2vec {
@@ -264,12 +264,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab)
     : Vocab(std::move(Vocab)), Valid(true) {}
 
 bool Vocabulary::isValid() const {
-  return Vocab.size() == Vocabulary::expectedSize() && Valid;
-}
-
-size_t Vocabulary::size() const {
-  assert(Valid && "IR2Vec Vocabulary is invalid");
-  return Vocab.size();
+  return Vocab.size() == NumCanonicalEntries && Valid;
 }
 
 unsigned Vocabulary::getDimension() const {
@@ -277,19 +272,32 @@ unsigned Vocabulary::getDimension() const {
   return Vocab[0].size();
 }
 
-const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+unsigned Vocabulary::getSlotIdx(unsigned Opcode) {
   assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
-  return Vocab[Opcode - 1];
+  return Opcode - 1; // Convert to zero-based index
+}
+
+unsigned Vocabulary::getSlotIdx(Type::TypeID TypeID) {
+  assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+  return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
+}
+
+unsigned Vocabulary::getSlotIdx(const Value *Op) {
+  unsigned Index = static_cast<unsigned>(getOperandKind(Op));
+  assert(Index < MaxOperandKinds && "Invalid OperandKind");
+  return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+}
+
+const Embedding &Vocabulary::operator[](unsigned Opcode) const {
+  return Vocab[getSlotIdx(Opcode)];
 }
 
-const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
-  assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");
-  return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];
+const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const {
+  return Vocab[getSlotIdx(TypeID)];
 }
 
 const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
-  OperandKind ArgKind = getOperandKind(Arg);
-  return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
+  return Vocab[getSlotIdx(Arg)];
 }
 
 StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
@@ -303,43 +311,21 @@ StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
   return "UnknownOpcode";
 }
 
+StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
+  unsigned Index = static_cast<unsigned>(CType);
+  assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
+  return CanonicalTypeNames[Index];
+}
+
+Vocabulary::CanonicalTypeID
+Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) {
+  unsigned Index = static_cast<unsigned>(TypeID);
+  assert(Index < MaxTypeIDs && "Invalid TypeID");
+  return TypeIDMapping[Index];
+}
+
 StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
-  switch (TypeID) {
-  case Type::VoidTyID:
-    return "VoidTy";
-  case Type::HalfTyID:
-  case Type::BFloatTyID:
-  case Type::FloatTyID:
-  case Type::DoubleTyID:
-  case Type::X86_FP80TyID:
-  case Type::FP128TyID:
-  case Type::PPC_FP128TyID:
-    return "FloatTy";
-  case Type::IntegerTyID:
-    return "IntegerTy";
-  case Type::FunctionTyID:
-    return "FunctionTy";
-  case Type::StructTyID:
-    return "StructTy";
-  case Type::ArrayTyID:
-    return "ArrayTy";
-  case Type::PointerTyID:
-  case Type::TypedPointerTyID:
-    return "PointerTy";
-  case Type::FixedVectorTyID:
-  case Type::ScalableVectorTyID:
-    return "VectorTy";
-  case Type::LabelTyID:
-    return "LabelTy";
-  case Type::TokenTyID:
-    return "TokenTy";
-  case Type::MetadataTyID:
-    return "MetadataTy";
-  case Type::X86_AMXTyID:
-  case Type::TargetExtTyID:
-    return "UnknownTy";
-  }
-  return "UnknownTy";
+  return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
 }
 
 StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
@@ -348,20 +334,6 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
   return OperandKindNames[Index];
 }
 
-Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
-  VocabVector DummyVocab;
-  float DummyVal = 0.1f;
-  // Create a dummy vocabulary with entries for all opcodes, types, and
-  // operand
-  for ([[maybe_unused]] unsigned _ :
-       seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +
-                   Vocabulary::MaxOperandKinds)) {
-    DummyVocab.push_back(Embedding(Dim, DummyVal));
-    DummyVal += 0.1f;
-  }
-  return DummyVocab;
-}
-
 // Helper function to classify an operand into OperandKind
 Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
   if (isa<Function>(Op))
@@ -373,34 +345,18 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
   return OperandKind::VariableID;
 }
 
-unsigned Vocabulary::getNumericID(unsigned Opcode) {
-  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
-  return Opcode - 1; // Convert to zero-based index
-}
-
-unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
-  assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
-  return MaxOpcodes + static_cast<unsigned>(TypeID);
-}
-
-unsigned Vocabulary::getNumericID(const Value *Op) {
-  unsigned Index = static_cast<unsigned>(getOperandKind(Op));
-  assert(Index < MaxOperandKinds && "Invalid OperandKind");
-  return MaxOpcodes + MaxTypeIDs + Index;
-}
-
 StringRef Vocabulary::getStringKey(unsigned Pos) {
-  assert(Pos < Vocabulary::expectedSize() &&
-         "Position out of bounds in vocabulary");
+  assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
   // Opcode
   if (Pos < MaxOpcodes)
     return getVocabKeyForOpcode(Pos + 1);
   // Type
-  if (Pos < MaxOpcodes + MaxTypeIDs)
-    return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
+  if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+    return getVocabKeyForCanonicalTypeID(
+        static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
   // Operand
   return getVocabKeyForOperandKind(
-      static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs));
+      static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
 }
 
 // For now, assume vocabulary is stable unless explicitly invalidated.
@@ -410,6 +366,21 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
   return !(PAC.preservedWhenStateless());
 }
 
+Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
+  VocabVector DummyVocab;
+  DummyVocab.reserve(NumCanonicalEntries);
+  float DummyVal = 0.1f;
+  // Create a dummy vocabulary with entries for all opcodes, types, and
+  // operands
+  for ([[maybe_unused]] unsigned _ :
+       seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
+                   Vocabulary::MaxOperandKinds)) {
+    DummyVocab.push_back(Embedding(Dim, DummyVal));
+    DummyVal += 0.1f;
+  }
+  return DummyVocab;
+}
+
 // ==----------------------------------------------------------------------===//
 // IR2VecVocabAnalysis
 //===----------------------------------------------------------------------===//
@@ -502,6 +473,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   // Handle Opcodes
   std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
                                                  Embedding(Dim, 0));
+  NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
   for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
     StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
     auto It = OpcVocab.find(VocabKey.str());
@@ -513,14 +485,15 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
                NumericOpcodeEmbeddings.end());
 
-  // Handle Types
-  std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
+  // Handle Types - only canonical types are present in vocabulary
+  std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
                                                Embedding(Dim, 0));
-  for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
-    StringRef VocabKey =
-        Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID));
+  NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
+  for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
+    StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
+        static_cast<Vocabulary::CanonicalTypeID>(CTypeID));
     if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
-      NumericTypeEmbeddings[TypeID] = It->second;
+      NumericTypeEmbeddings[CTypeID] = It->second;
       continue;
     }
     handleMissingEntity(VocabKey.str());
@@ -531,6 +504,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   // Handle Arguments/Operands
   std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
                                               Embedding(Dim, 0));
+  NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
   for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
     Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
     StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index 1b9b3c2acd8a5..df7769c9c6a65 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue:  [ 129.00  130.00 ]
 Key: LandingPad:  [ 131.00  132.00 ]
 Key: Freeze:  [ 133.00  134.00 ]
 Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
 Key: VoidTy:  [ 1.50  2.00 ]
 Key: LabelTy:  [ 2.50  3.00 ]
 Key: MetadataTy:  [ 3.50  4.00 ]
-Key: UnknownTy:  [ 4.50  5.00 ]
+Key: VectorTy:  [ 11.50  12.00 ]
 Key: TokenTy:  [ 5.50  6.00 ]
 Key: IntegerTy:  [ 6.50  7.00 ]
 Key: FunctionTy:  [ 7.50  8.00 ]
 Key: PointerTy:  [ 8.50  9.00 ]
 Key: StructTy:  [ 9.50  10.00 ]
 Key: ArrayTy:  [ 10.50  11.00 ]
-Key: VectorTy:  [ 11.50  12.00 ]
-Key: VectorTy:  [ 11.50  12.00 ]
-Key: PointerTy:  [ 8.50  9.00 ]
 Key: UnknownTy:  [ 4.50  5.00 ]
 Key: Function:  [ 0.20  0.40 ]
 Key: Pointer:  [ 0.60  0.80 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index 9673e7f23fa5c..f3ce809fd2fd2 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue:  [ 64.50  65.00 ]
 Key: LandingPad:  [ 65.50  66.00 ]
 Key: Freeze:  [ 66.50  67.00 ]
 Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
-Key: FloatTy:  [ 0.50  1.00 ]
 Key: VoidTy:  [ 1.50  2.00 ]
 Key: LabelTy:  [ 2.50  3.00 ]
 Key: MetadataTy:  [ 3.50  4.00 ]
-Key: UnknownTy:  [ 4.50  5.00 ]
+Key: VectorTy:  [ 11.50  12.00 ]
 Key: TokenTy:  [ 5.50  6.00 ]
 Key: IntegerTy:  [ 6.50  7.00 ]
 Key: FunctionTy:  [ 7.50  8.00 ]
 Key: PointerTy:  [ 8.50  9.00 ]
 Key: StructTy:  [ 9.50  10.00 ]
 Key: ArrayTy:  [ 10.50  11.00 ]
-Key: VectorTy:  [ 11.50  12.00 ]
-Key: VectorTy:  [ 11.50  12.00 ]
-Key: PointerTy:  [ 8.50  9.00 ]
 Key: UnknownTy:  [ 4.50  5.00 ]
 Key: Function:  [ 0.50  1.00 ]
 Key: Pointer:  [ 1.50  2.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 1f575d29092dd..72b25b9bd3d9c 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -67,25 +67,16 @@ Key: InsertValue:  [ 12.90  13.00 ]
 Key: LandingPad:  [ 13.10  13.20 ]
 Key: Freeze:  [ 13.30  13.40 ]
 Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
-Key: FloatTy:  [ 0.00  0.00 ]
 Key: VoidTy:  [ 0.00  0.00 ]
 Key: LabelTy:  [ 0.00  0.00 ]
 Key: MetadataTy:  [ 0.00  0.00 ]
-Key: UnknownTy:  [ ...
[truncated]

class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
bool Valid = false;

public:
// Slot layout:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment is about an internal implementation detail, correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, correct.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you move it to private?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this patch, but curious (and if it makes sense, it can be addressed later): why not 3 vectors rather than indexing within one? Also, IIUC the tools have some understanding that this is the case (there's one vector with slots) and thus need to use getSlotIndex. Is this an artifact of how serialization happens? Maybe that can be captured in the comment here (to be clear, I don't think 3 vectors vs one is better or worse, just a bit of a confusing design choice so a comment may help)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we can do this refactoring. But, the tool needs consecutive indexing of entities while dumping the triplets. So, this logic should either be in the ir2vec::Vocabulary or should be moved to the tool.

@@ -162,8 +162,8 @@ class IR2VecTool {

for (const BasicBlock &BB : F) {
for (const auto &I : BB.instructionsWithoutDebug()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit for later: should the iteration over BBs be a utility in ir2vec somehow - basically the reflex is to do

for (const auto &BB : F)
  for (const auto &I : BB)
     // do stuff to I

meaning that BB.instructionsWithoutDebug() is not the immediately discoverable / the first place one goes to when coding.

One idea (again, later patch): what if the ir2vec APIs like getSlotIndex would return a "null" value for debug info - this can be a configurable option of the vocab or something like that - which, when used through ir2vec, it'd have no effect (like its embedding would be the 0 tensor, for instance)

just a thought. noop here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it makes sense. Will track it separately.

@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-11-_nfc_ir2vec_add_missed_ptrtoaddr_in_vocab_for_tests branch from 8ce3c84 to d6cc948 Compare August 27, 2025 19:40
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-25-canonicalized_type branch 2 times, most recently from 7037453 to 9b851b9 Compare August 27, 2025 20:06
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-11-_nfc_ir2vec_add_missed_ptrtoaddr_in_vocab_for_tests branch from d6cc948 to 3680f38 Compare August 27, 2025 20:06
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-25-canonicalized_type branch from 9b851b9 to fe1463e Compare August 27, 2025 21:04
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-11-_nfc_ir2vec_add_missed_ptrtoaddr_in_vocab_for_tests branch from 3680f38 to 5dc4c48 Compare August 27, 2025 21:04
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-11-_nfc_ir2vec_add_missed_ptrtoaddr_in_vocab_for_tests branch from 5dc4c48 to 3e10c63 Compare August 28, 2025 19:58
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-25-canonicalized_type branch from fe1463e to d70182f Compare August 28, 2025 19:59
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-11-_nfc_ir2vec_add_missed_ptrtoaddr_in_vocab_for_tests branch from 3e10c63 to bd8e431 Compare August 28, 2025 23:04
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-25-canonicalized_type branch from d70182f to 01b9019 Compare August 28, 2025 23:04
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-11-_nfc_ir2vec_add_missed_ptrtoaddr_in_vocab_for_tests branch from bd8e431 to 0d1fb80 Compare August 28, 2025 23:52
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-25-canonicalized_type branch 2 times, most recently from 2fd070b to 2dd49e7 Compare August 29, 2025 00:37
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-11-_nfc_ir2vec_add_missed_ptrtoaddr_in_vocab_for_tests branch 2 times, most recently from 3cdc0d5 to a5fc0c3 Compare August 29, 2025 18:51
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-25-canonicalized_type branch from 2dd49e7 to a73389b Compare August 29, 2025 18:51
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-11-_nfc_ir2vec_add_missed_ptrtoaddr_in_vocab_for_tests branch from a5fc0c3 to 6dea185 Compare August 29, 2025 19:31
Base automatically changed from users/svkeerthy/08-11-_nfc_ir2vec_add_missed_ptrtoaddr_in_vocab_for_tests to main August 29, 2025 20:05
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-25-canonicalized_type branch 2 times, most recently from 3c6e312 to 1f5c4ac Compare August 29, 2025 20:38
Copy link
Contributor Author

svkeerthy commented Aug 29, 2025

Merge activity

  • Aug 29, 8:38 PM UTC: Graphite rebased this pull request as part of a merge.
  • Aug 29, 9:03 PM UTC: Graphite rebased this pull request as part of a merge.
  • Aug 29, 9:06 PM UTC: Graphite rebased this pull request as part of a merge.
  • Aug 29, 9:30 PM UTC: Graphite rebased this pull request as part of a merge.
  • Aug 29, 9:56 PM UTC: @svkeerthy merged this pull request with Graphite.

@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-25-canonicalized_type branch 2 times, most recently from a53c559 to e727d58 Compare August 29, 2025 21:06
@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-25-canonicalized_type branch from e727d58 to cb881b6 Compare August 29, 2025 21:30
@svkeerthy svkeerthy merged commit 45c5498 into main Aug 29, 2025
9 checks passed
@svkeerthy svkeerthy deleted the users/svkeerthy/08-25-canonicalized_type branch August 29, 2025 21:56
@llvm-ci
Copy link
Collaborator

llvm-ci commented Aug 29, 2025

LLVM Buildbot has detected a new failure on builder mlir-nvidia running on mlir-nvidia while building llvm at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/18347

Here is the relevant piece of the build log for the reference
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: Integration/GPU/CUDA/async.mlir' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 1
/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-kernel-outlining  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary="format=fatbin"  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-runner    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_cuda_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_async_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_runner_utils.so    --entry-point-result=void -O0  | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-kernel-outlining
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt '-pass-pipeline=builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary=format=fatbin
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/mlir-runner --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_cuda_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_async_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/lib/libmlir_runner_utils.so --entry-point-result=void -O0
# .---command stderr------------
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventSynchronize(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# `-----------------------------
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# .---command stderr------------
# | /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir:68:12: error: CHECK: expected string not found in input
# |  // CHECK: [84, 84]
# |            ^
# | <stdin>:1:1: note: scanning from here
# | Unranked Memref base@ = 0x589184959af0 rank = 1 offset = 0 sizes = [2] strides = [1] data = 
# | ^
# | <stdin>:2:1: note: possible intended match here
# | [42, 42]
# | ^
# | 
# | Input file: <stdin>
# | Check file: /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# | 
# | -dump-input=help explains the following input dump.
# | 
# | Input was:
# | <<<<<<
# |             1: Unranked Memref base@ = 0x589184959af0 rank = 1 offset = 0 sizes = [2] strides = [1] data =  
# | check:68'0     X~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ error: no match found
# |             2: [42, 42] 
# | check:68'0     ~~~~~~~~~
# | check:68'1     ?         possible intended match
...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding mlgo
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants