45
45
#include " llvm/Support/JSON.h"
46
46
#include < array>
47
47
#include < map>
48
+ #include < optional>
48
49
49
50
namespace llvm {
50
51
@@ -144,6 +145,73 @@ struct Embedding {
144
145
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
145
146
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
146
147
148
+ // / Move-only container of embeddings grouped into numbered sections.
149
+ // / Each section is a std::vector<Embedding>; a section is selected by an
150
+ // / unsigned id and an entry by its local index inside that section.
151
+ class VocabStorage {
152
+ private:
153
+ // / One inner vector per section; Sections[Id][Local] is a single embedding.
154
+ std::vector<std::vector<Embedding>> Sections;
155
+
156
+ size_t TotalSize = 0 ; // Reported by size(); stays 0 for default-constructed storage
157
+ unsigned Dimension = 0 ; // Reported by getDimension(); stays 0 for default-constructed storage
158
+
159
+ public:
160
+ // / Default constructor: no sections, zero size and dimension, so isValid() is false.
161
+ VocabStorage () : Sections(), TotalSize(0 ), Dimension(0 ) {}
162
+
163
+ // / Build storage from pre-organized section data taken by rvalue reference (declared here, defined out of line).
164
+ VocabStorage (std::vector<std::vector<Embedding>> &&SectionData);
165
+
166
+ VocabStorage (VocabStorage &&) = default ; // Move construction is allowed
167
+ VocabStorage &operator =(VocabStorage &&Other); // Move assignment is declared here, defined out of line
168
+
169
+ VocabStorage (const VocabStorage &) = delete ; // Copying is disabled
170
+ VocabStorage &operator =(const VocabStorage &) = delete ;
171
+
172
+ // / Total entry count as recorded in TotalSize (intended to span all sections).
173
+ size_t size () const { return TotalSize; }
174
+
175
+ // / Number of sections, narrowed from size_t to unsigned.
176
+ unsigned getNumSections () const {
177
+ return static_cast <unsigned >(Sections.size ());
178
+ }
179
+
180
+ // / Read-only access to one section: Storage[sectionId][localIndex]. SectionId must be less than the section count; this is only assert-checked.
181
+ const std::vector<Embedding> &operator [](unsigned SectionId) const {
182
+ assert (SectionId < Sections.size () && " Invalid section ID" );
183
+ return Sections[SectionId];
184
+ }
185
+
186
+ // / Embedding dimension as recorded in Dimension.
187
+ unsigned getDimension () const { return Dimension; }
188
+
189
+ // / True when TotalSize is non-zero; a default-constructed object is therefore not valid.
190
+ bool isValid () const { return TotalSize > 0 ; }
191
+
192
+ // / Const iterator identified by a (SectionId, LocalIndex) position; its operators are only declared here.
193
+ class const_iterator {
194
+ const VocabStorage *Storage;
195
+ unsigned SectionId;
196
+ size_t LocalIndex;
197
+
198
+ public:
199
+ const_iterator (const VocabStorage *Storage, unsigned SectionId,
200
+ size_t LocalIndex)
201
+ : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
202
+
203
+ LLVM_ABI const Embedding &operator *() const ;
204
+ LLVM_ABI const_iterator &operator ++();
205
+ LLVM_ABI bool operator ==(const const_iterator &Other) const ;
206
+ LLVM_ABI bool operator !=(const const_iterator &Other) const ;
207
+ };
208
+
209
+ const_iterator begin () const { return const_iterator (this , 0 , 0 ); } // First position: section 0, index 0
210
+ const_iterator end () const {
211
+ return const_iterator (this , getNumSections (), 0 ); // End position: section == getNumSections(), index 0
212
+ }
213
+ };
214
+
147
215
// / Class for storing and accessing the IR2Vec vocabulary.
148
216
// / The Vocabulary class manages seed embeddings for LLVM IR entities. The
149
217
// / seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
164
232
class Vocabulary {
165
233
friend class llvm ::IR2VecVocabAnalysis;
166
234
167
- // Vocabulary Slot Layout:
235
+ // Vocabulary Layout:
168
236
// +----------------+------------------------------------------------------+
169
237
// | Entity Type | Index Range |
170
238
// +----------------+------------------------------------------------------+
@@ -175,8 +243,16 @@ class Vocabulary {
175
243
// Note: "Similar" LLVM Types are grouped/canonicalized together.
176
244
// Operands include Comparison predicates (ICmp/FCmp).
177
245
// This can be extended to include other specializations in future.
178
- using VocabVector = std::vector<ir2vec::Embedding>;
179
- VocabVector Vocab;
246
+ enum class Section : unsigned { // NOTE(review): presumably the section ids used to index VocabStorage; confirm against the .cpp
247
+ Opcodes = 0 ,
248
+ CanonicalTypes = 1 ,
249
+ Operands = 2 ,
250
+ Predicates = 3 ,
251
+ MaxSections // Implicitly 4, i.e. the number of enumerators listed above
252
+ };
253
+
254
+ // Embeddings are held in a section-based VocabStorage rather than in one flat vector.
255
+ VocabStorage Storage;
180
256
181
257
static constexpr unsigned NumICmpPredicates =
182
258
static_cast <unsigned >(CmpInst::LAST_ICMP_PREDICATE) -
@@ -228,9 +304,18 @@ class Vocabulary {
228
304
NumICmpPredicates + NumFCmpPredicates;
229
305
230
306
Vocabulary () = default ;
231
- LLVM_ABI Vocabulary (VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
307
+ LLVM_ABI Vocabulary (VocabStorage &&Storage) : Storage(std::move(Storage)) {} // Takes ownership of the given storage by moving it into the member
308
+
309
+ Vocabulary (const Vocabulary &) = delete ; // Copying is disabled
310
+ Vocabulary &operator =(const Vocabulary &) = delete ;
311
+
312
+ Vocabulary (Vocabulary &&) = default ; // Move construction is allowed
313
+ Vocabulary &operator =(Vocabulary &&Other); // Move assignment is declared here; its definition is not in this view
314
+
315
+ LLVM_ABI bool isValid () const { // Valid only when the storage holds exactly NumCanonicalEntries entries
316
+ return Storage.size () == NumCanonicalEntries;
317
+ }
232
318
233
- LLVM_ABI bool isValid () const { return Vocab.size () == NumCanonicalEntries; };
234
319
LLVM_ABI unsigned getDimension () const ;
235
320
// / Total number of entries (opcodes + canonicalized types + operand kinds +
236
321
// / predicates)
@@ -251,12 +336,11 @@ class Vocabulary {
251
336
// / Function to get vocabulary key for a given predicate
252
337
LLVM_ABI static StringRef getVocabKeyForPredicate (CmpInst::Predicate P);
253
338
254
- // / Functions to return the slot index or position of a given Opcode, TypeID,
255
- // / or OperandKind in the vocabulary.
256
- LLVM_ABI static unsigned getSlotIndex (unsigned Opcode);
257
- LLVM_ABI static unsigned getSlotIndex (Type::TypeID TypeID);
258
- LLVM_ABI static unsigned getSlotIndex (const Value &Op);
259
- LLVM_ABI static unsigned getSlotIndex (CmpInst::Predicate P);
339
+ // / Functions to return the flat vocabulary index of a given Opcode, TypeID, operand Value, or comparison Predicate.
340
+ LLVM_ABI static unsigned getIndex (unsigned Opcode);
341
+ LLVM_ABI static unsigned getIndex (Type::TypeID TypeID);
342
+ LLVM_ABI static unsigned getIndex (const Value &Op);
343
+ LLVM_ABI static unsigned getIndex (CmpInst::Predicate P);
260
344
261
345
// / Accessors to get the embedding for a given entity.
262
346
LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const ;
@@ -265,34 +349,29 @@ class Vocabulary {
265
349
LLVM_ABI const ir2vec::Embedding &operator [](CmpInst::Predicate P) const ;
266
350
267
351
// / Const Iterator type aliases
268
- using const_iterator = VocabVector::const_iterator;
352
+ using const_iterator = VocabStorage::const_iterator;
353
+
269
354
const_iterator begin () const {
270
355
assert (isValid () && " IR2Vec Vocabulary is invalid" );
271
- return Vocab .begin ();
356
+ return Storage .begin (); // Iteration now starts at the first position of the underlying VocabStorage
272
357
}
273
358
274
- const_iterator cbegin () const {
275
- assert (isValid () && " IR2Vec Vocabulary is invalid" );
276
- return Vocab.cbegin ();
277
- }
359
+ const_iterator cbegin () const { return begin (); } // Forwards to begin(), so the validity assert there still applies
278
360
279
361
const_iterator end () const {
280
362
assert (isValid () && " IR2Vec Vocabulary is invalid" );
281
- return Vocab .end ();
363
+ return Storage .end (); // End position comes from the underlying VocabStorage
282
364
}
283
365
284
- const_iterator cend () const {
285
- assert (isValid () && " IR2Vec Vocabulary is invalid" );
286
- return Vocab.cend ();
287
- }
366
+ const_iterator cend () const { return end (); } // Forwards to end(), so the validity assert there still applies
288
367
289
368
// / Returns the string key for a given index position in the vocabulary.
290
369
// / This is useful for debugging or printing the vocabulary. Do not use this
291
370
// / for embedding generation as string based lookups are inefficient.
292
371
LLVM_ABI static StringRef getStringKey (unsigned Pos);
293
372
294
373
// / Create a dummy vocabulary for testing purposes.
295
- LLVM_ABI static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
374
+ LLVM_ABI static VocabStorage createDummyVocabForTest (unsigned Dim = 1 );
296
375
297
376
LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
298
377
ModuleAnalysisManager::Invalidator &Inv) const ;
@@ -301,12 +380,16 @@ class Vocabulary {
301
380
constexpr static unsigned NumCanonicalEntries =
302
381
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
303
382
304
- // Base offsets for slot layout to simplify index computation
383
+ // Base offsets for flat index computation
305
384
constexpr static unsigned OperandBaseOffset =
306
385
MaxOpcodes + MaxCanonicalTypeIDs;
307
386
constexpr static unsigned PredicateBaseOffset =
308
387
OperandBaseOffset + MaxOperandKinds;
309
388
389
+ // / Helpers that map a CmpInst::Predicate to its local index and back.
390
+ static unsigned getPredicateLocalIndex (CmpInst::Predicate P);
391
+ static CmpInst::Predicate getPredicateFromLocalIndex (unsigned LocalIndex);
392
+
310
393
// / String mappings for CanonicalTypeID values
311
394
static constexpr StringLiteral CanonicalTypeNames[] = {
312
395
" FloatTy" , " VoidTy" , " LabelTy" , " MetadataTy" ,
@@ -452,22 +535,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
452
535
// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
453
536
// / its corresponding embedding.
454
537
class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
455
- using VocabVector = std::vector<ir2vec::Embedding>;
456
538
using VocabMap = std::map<std::string, ir2vec::Embedding>; // String key -> embedding
457
- VocabMap OpcVocab, TypeVocab, ArgVocab;
458
- VocabVector Vocab;
539
+ std::optional<ir2vec::VocabStorage> Vocab; // Empty after default construction; set by the VocabStorage constructor below. NOTE(review): run() may also populate it; confirm in the .cpp
459
540
460
- Error readVocabulary ();
541
+ Error readVocabulary (VocabMap &OpcVocab, VocabMap &TypeVocab,
542
+ VocabMap &ArgVocab); // NOTE(review): the three maps look like out-parameters that replace the removed member maps; confirm
461
543
Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
462
544
VocabMap &TargetVocab, unsigned &Dim);
463
- void generateNumMappedVocab ();
545
+ void generateVocabStorage (VocabMap &OpcVocab, VocabMap &TypeVocab,
546
+ VocabMap &ArgVocab); // NOTE(review): presumably builds Vocab from the three maps; confirm
464
547
void emitError (Error Err, LLVMContext &Ctx);
465
548
466
549
public:
467
550
LLVM_ABI static AnalysisKey Key;
468
551
IR2VecVocabAnalysis () = default ;
469
- LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector & Vocab);
470
- LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector && Vocab);
552
+ LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::VocabStorage && Vocab)
553
+ : Vocab(std::move(Vocab)) {} // Takes ownership of a pre-built storage; no copying overload remains
471
554
using Result = ir2vec::Vocabulary;
472
555
LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
473
556
};
0 commit comments