-
Notifications
You must be signed in to change notification settings - Fork 15k
[NFC][IR2Vec] Initialize Embedding vectors with zeros by default #155690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-analysis Author: S. VenkataKeerthy (svkeerthy) ChangesInitialize Full diff: https://github.com/llvm/llvm-project/pull/155690.diff 2 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 44932a3385e16..6fb8f736da092 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -92,7 +92,7 @@ struct Embedding {
Embedding(std::vector<double> &&V) : Data(std::move(V)) {}
Embedding(std::initializer_list<double> IL) : Data(IL) {}
- explicit Embedding(size_t Size) : Data(Size) {}
+ explicit Embedding(size_t Size) : Data(Size, 0.0) {}
Embedding(size_t Size, double InitialValue) : Data(Size, InitialValue) {}
size_t size() const { return Data.size(); }
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 565ec2a6287b7..6b90f1aabacfa 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -155,7 +155,7 @@ void Embedding::print(raw_ostream &OS) const {
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
- FuncVector(Embedding(Dimension, 0)) {}
+ FuncVector(Embedding(Dimension)) {}
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocabulary &Vocab) {
@@ -472,7 +472,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
@@ -487,7 +487,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Types - only canonical types are present in vocabulary
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
@@ -503,7 +503,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Arguments/Operands
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
|
@llvm/pr-subscribers-mlgo Author: S. VenkataKeerthy (svkeerthy) ChangesInitialize Full diff: https://github.com/llvm/llvm-project/pull/155690.diff 2 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 44932a3385e16..6fb8f736da092 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -92,7 +92,7 @@ struct Embedding {
Embedding(std::vector<double> &&V) : Data(std::move(V)) {}
Embedding(std::initializer_list<double> IL) : Data(IL) {}
- explicit Embedding(size_t Size) : Data(Size) {}
+ explicit Embedding(size_t Size) : Data(Size, 0.0) {}
Embedding(size_t Size, double InitialValue) : Data(Size, InitialValue) {}
size_t size() const { return Data.size(); }
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 565ec2a6287b7..6b90f1aabacfa 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -155,7 +155,7 @@ void Embedding::print(raw_ostream &OS) const {
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
- FuncVector(Embedding(Dimension, 0)) {}
+ FuncVector(Embedding(Dimension)) {}
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocabulary &Vocab) {
@@ -472,7 +472,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes);
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
@@ -487,7 +487,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Types - only canonical types are present in vocabulary
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs);
for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) {
StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID(
@@ -503,7 +503,7 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Arguments/Operands
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
- Embedding(Dim, 0));
+ Embedding(Dim));
NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds);
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
|
9a18f1c
to
7ddfeaa
Compare
51b1cd4
to
7ec3927
Compare
7ddfeaa
to
c809d9d
Compare
f01119a
to
f82d77f
Compare
c809d9d
to
18675c6
Compare
18675c6
to
374bfa9
Compare
f0b3c0c
to
a9edd27
Compare
97560b9
to
a20fb0e
Compare
a9edd27
to
493f471
Compare
a20fb0e
to
b21b641
Compare
ec2e1e1
to
da83ad8
Compare
b21b641
to
5c658e1
Compare
5c658e1
to
0d74ab7
Compare
da83ad8
to
8c8500c
Compare
8c8500c
to
fd4e1df
Compare
Merge activity
|
Initialize
Embedding
vectors with zeros by default when only size is provided.