Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 81a84b2

Browse files
committed
VocabStorage
1 parent 52875ac commit 81a84b2

File tree

6 files changed

+570
-122
lines changed

6 files changed

+570
-122
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 114 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "llvm/Support/JSON.h"
4646
#include <array>
4747
#include <map>
48+
#include <optional>
4849

4950
namespace llvm {
5051

@@ -144,6 +145,73 @@ struct Embedding {
144145
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
145146
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
146147

148+
/// Generic, move-only storage class for section-based vocabularies.
149+
/// VocabStorage provides a generic foundation for storing and accessing
150+
/// embeddings organized into sections, addressed as (SectionId, LocalIndex).
151+
class VocabStorage {
152+
private:
153+
/// Section-based storage: Sections[SectionId][LocalIndex] is one Embedding.
154+
std::vector<std::vector<Embedding>> Sections;
155+
156+
size_t TotalSize = 0; ///< Value reported by size(); 0 means "not valid".
157+
unsigned Dimension = 0; ///< Value reported by getDimension().
158+
159+
public:
160+
/// Default constructor creates empty storage; isValid() is false for it.
161+
VocabStorage() : Sections(), TotalSize(0), Dimension(0) {}
162+
163+
/// Create a VocabStorage from pre-organized section data (defined out of line).
164+
VocabStorage(std::vector<std::vector<Embedding>> &&SectionData);
165+
166+
VocabStorage(VocabStorage &&) = default; // Moving is allowed.
167+
VocabStorage &operator=(VocabStorage &&Other); // Defined out of line.
168+
169+
VocabStorage(const VocabStorage &) = delete; // Copying is disabled.
170+
VocabStorage &operator=(const VocabStorage &) = delete;
171+
172+
/// Get total number of entries across all sections (returns TotalSize).
173+
size_t size() const { return TotalSize; }
174+
175+
/// Get number of sections, i.e. the size of the outer Sections vector.
176+
unsigned getNumSections() const {
177+
return static_cast<unsigned>(Sections.size());
178+
}
179+
180+
/// Section-based access: Storage[SectionId][LocalIndex]. Asserts that
/// SectionId is in range and returns that section's embeddings by reference.
181+
const std::vector<Embedding> &operator[](unsigned SectionId) const {
182+
assert(SectionId < Sections.size() && "Invalid section ID");
183+
return Sections[SectionId];
184+
}
185+
186+
/// Get vocabulary dimension (returns Dimension).
187+
unsigned getDimension() const { return Dimension; }
188+
189+
/// Check if vocabulary is valid, i.e. has data (TotalSize > 0).
190+
bool isValid() const { return TotalSize > 0; }
191+
192+
/// Const iterator over the stored embeddings. A position is the pair
/// (SectionId, LocalIndex) within the VocabStorage it points at.
193+
class const_iterator {
194+
const VocabStorage *Storage;
195+
unsigned SectionId;
196+
size_t LocalIndex;
197+
198+
public:
199+
const_iterator(const VocabStorage *Storage, unsigned SectionId,
200+
size_t LocalIndex)
201+
: Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
202+
203+
// The operators below are only declared here; their definitions are not
// visible in this header.
LLVM_ABI const Embedding &operator*() const;
204+
LLVM_ABI const_iterator &operator++();
205+
LLVM_ABI bool operator==(const const_iterator &Other) const;
206+
LLVM_ABI bool operator!=(const const_iterator &Other) const;
207+
};
208+
209+
// begin() is positioned at section 0, local index 0.
const_iterator begin() const { return const_iterator(this, 0, 0); }
210+
// end() is positioned at section index getNumSections(), local index 0.
const_iterator end() const {
211+
return const_iterator(this, getNumSections(), 0);
212+
}
213+
};
214+
147215
/// Class for storing and accessing the IR2Vec vocabulary.
148216
/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
149217
/// seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
164232
class Vocabulary {
165233
friend class llvm::IR2VecVocabAnalysis;
166234

167-
// Vocabulary Slot Layout:
235+
// Vocabulary Layout:
168236
// +----------------+------------------------------------------------------+
169237
// | Entity Type | Index Range |
170238
// +----------------+------------------------------------------------------+
@@ -175,8 +243,16 @@ class Vocabulary {
175243
// Note: "Similar" LLVM Types are grouped/canonicalized together.
176244
// Operands include Comparison predicates (ICmp/FCmp).
177245
// This can be extended to include other specializations in future.
178-
using VocabVector = std::vector<ir2vec::Embedding>;
179-
VocabVector Vocab;
246+
enum class Section : unsigned { // Section IDs; presumably index VocabStorage sections - confirm against the .cpp.
247+
Opcodes = 0,
248+
CanonicalTypes = 1,
249+
Operands = 2,
250+
Predicates = 3,
251+
MaxSections // Implicitly 4: one past the last section ID listed above.
252+
};
253+
254+
// Use section-based storage for better organization and efficiency
255+
VocabStorage Storage;
180256

181257
static constexpr unsigned NumICmpPredicates =
182258
static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
@@ -228,9 +304,18 @@ class Vocabulary {
228304
NumICmpPredicates + NumFCmpPredicates;
229305

230306
Vocabulary() = default;
231-
LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
307+
LLVM_ABI Vocabulary(VocabStorage &&Storage) : Storage(std::move(Storage)) {}
308+
309+
Vocabulary(const Vocabulary &) = delete;
310+
Vocabulary &operator=(const Vocabulary &) = delete;
311+
312+
Vocabulary(Vocabulary &&) = default;
313+
Vocabulary &operator=(Vocabulary &&Other);
314+
315+
LLVM_ABI bool isValid() const {
316+
return Storage.size() == NumCanonicalEntries;
317+
}
232318

233-
LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; };
234319
LLVM_ABI unsigned getDimension() const;
235320
/// Total number of entries (opcodes + canonicalized types + operand kinds +
236321
/// predicates)
@@ -251,12 +336,11 @@ class Vocabulary {
251336
/// Function to get vocabulary key for a given predicate
252337
LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
253338

254-
/// Functions to return the slot index or position of a given Opcode, TypeID,
255-
/// or OperandKind in the vocabulary.
256-
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
257-
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
258-
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
259-
LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
339+
/// Functions to return the flat index of a given Opcode, TypeID, Operand, or Predicate in the vocabulary.
340+
LLVM_ABI static unsigned getIndex(unsigned Opcode);
341+
LLVM_ABI static unsigned getIndex(Type::TypeID TypeID);
342+
LLVM_ABI static unsigned getIndex(const Value &Op);
343+
LLVM_ABI static unsigned getIndex(CmpInst::Predicate P);
260344

261345
/// Accessors to get the embedding for a given entity.
262346
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
@@ -265,34 +349,29 @@ class Vocabulary {
265349
LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
266350

267351
/// Const Iterator type aliases
268-
using const_iterator = VocabVector::const_iterator;
352+
using const_iterator = VocabStorage::const_iterator;
353+
269354
const_iterator begin() const {
270355
assert(isValid() && "IR2Vec Vocabulary is invalid");
271-
return Vocab.begin();
356+
return Storage.begin();
272357
}
273358

274-
const_iterator cbegin() const {
275-
assert(isValid() && "IR2Vec Vocabulary is invalid");
276-
return Vocab.cbegin();
277-
}
359+
const_iterator cbegin() const { return begin(); }
278360

279361
const_iterator end() const {
280362
assert(isValid() && "IR2Vec Vocabulary is invalid");
281-
return Vocab.end();
363+
return Storage.end();
282364
}
283365

284-
const_iterator cend() const {
285-
assert(isValid() && "IR2Vec Vocabulary is invalid");
286-
return Vocab.cend();
287-
}
366+
const_iterator cend() const { return end(); }
288367

289368
/// Returns the string key for a given index position in the vocabulary.
290369
/// This is useful for debugging or printing the vocabulary. Do not use this
291370
/// for embedding generation as string based lookups are inefficient.
292371
LLVM_ABI static StringRef getStringKey(unsigned Pos);
293372

294373
/// Create a dummy vocabulary for testing purposes.
295-
LLVM_ABI static VocabVector createDummyVocabForTest(unsigned Dim = 1);
374+
LLVM_ABI static VocabStorage createDummyVocabForTest(unsigned Dim = 1);
296375

297376
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
298377
ModuleAnalysisManager::Invalidator &Inv) const;
@@ -301,12 +380,16 @@ class Vocabulary {
301380
constexpr static unsigned NumCanonicalEntries =
302381
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
303382

304-
// Base offsets for slot layout to simplify index computation
383+
// Base offsets for flat index computation
305384
constexpr static unsigned OperandBaseOffset =
306385
MaxOpcodes + MaxCanonicalTypeIDs;
307386
constexpr static unsigned PredicateBaseOffset =
308387
OperandBaseOffset + MaxOperandKinds;
309388

389+
/// Functions for predicate index calculations
390+
static unsigned getPredicateLocalIndex(CmpInst::Predicate P);
391+
static CmpInst::Predicate getPredicateFromLocalIndex(unsigned LocalIndex);
392+
310393
/// String mappings for CanonicalTypeID values
311394
static constexpr StringLiteral CanonicalTypeNames[] = {
312395
"FloatTy", "VoidTy", "LabelTy", "MetadataTy",
@@ -452,22 +535,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
452535
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
453536
/// its corresponding embedding.
454537
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
455-
using VocabVector = std::vector<ir2vec::Embedding>;
456538
using VocabMap = std::map<std::string, ir2vec::Embedding>;
457-
VocabMap OpcVocab, TypeVocab, ArgVocab;
458-
VocabVector Vocab;
539+
std::optional<ir2vec::VocabStorage> Vocab;
459540

460-
Error readVocabulary();
541+
Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
542+
VocabMap &ArgVocab);
461543
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
462544
VocabMap &TargetVocab, unsigned &Dim);
463-
void generateNumMappedVocab();
545+
void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
546+
VocabMap &ArgVocab);
464547
void emitError(Error Err, LLVMContext &Ctx);
465548

466549
public:
467550
LLVM_ABI static AnalysisKey Key;
468551
IR2VecVocabAnalysis() = default;
469-
LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab);
470-
LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab);
552+
LLVM_ABI explicit IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)
553+
: Vocab(std::move(Vocab)) {}
471554
using Result = ir2vec::Vocabulary;
472555
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM);
473556
};

0 commit comments

Comments
 (0)