Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ccb8450

Browse files
committed
[Embedding] Implementation of Multi-Level Embedding.
1 parent 00f3980 commit ccb8450

18 files changed

+1013
-749
lines changed

modelzoo/features/pmem/benchmark.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def main(_):
106106
storage_type=config_pb2.StorageType.PMEM_LIBPMEM,
107107
storage_path=FLAGS.ev_storage_path,
108108
storage_size=FLAGS.ev_storage_size_gb * 1024 * 1024 * 1024))
109+
elif FLAGS.ev_storage == "dram_pmem":
110+
ev_option = variables.EmbeddingVariableOption(storage_option=variables.StorageOption(
111+
storage_type=config_pb2.StorageType.DRAM_PMEM,
112+
storage_path=FLAGS.ev_storage_path,
113+
storage_size=FLAGS.ev_storage_size_gb * 1024 * 1024 * 1024))
109114
fm_w = tf.get_embedding_variable(
110115
name='fm_w{}'.format(sidx),
111116
embedding_dim=1,
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
#ifndef TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_CACHE_H_
2+
#define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_CACHE_H_
3+
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <list>
#include <map>
#include <set>
#include <unordered_map>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
11+
12+
namespace tensorflow {
13+
namespace embedding {
14+
15+
// Abstract interface for batch-oriented cache eviction policies used by
// multi-level embedding storage. Implementations track the access rank of
// ids and select eviction victims in batches.
template <class K>
class BatchCache {
 public:
  BatchCache() {}
  // Virtual destructor: concrete caches (LRU/LFU) are deleted through
  // BatchCache<K>*; without this, the derived destructor is never run
  // (undefined behavior and a guaranteed leak of the derived state).
  virtual ~BatchCache() {}
  // Fills `evic_ids` with up to `k_size` victim ids, removing them from the
  // cache. Returns the number of ids actually written (may be < k_size).
  virtual size_t get_evic_ids(K* evic_ids, size_t k_size) = 0;
  // Records one batch of accessed ids, promoting them in the ranking.
  virtual void add_to_rank(const K* batch_ids, size_t batch_size) = 0;
  // Number of ids currently tracked by the cache.
  virtual size_t size() = 0;
};
23+
24+
template <class K>
25+
class LRUCache : public BatchCache<K> {
26+
private:
27+
class LRUNode {
28+
public:
29+
K id;
30+
LRUNode *pre, *next;
31+
LRUNode(K id) : id(id), pre(nullptr), next(nullptr) {}
32+
};
33+
LRUNode *head, *tail;
34+
std::map<K, LRUNode *> mp;
35+
mutex mu_;
36+
37+
public:
38+
LRUCache() {
39+
mp.clear();
40+
head = new LRUNode(0);
41+
tail = new LRUNode(0);
42+
head->next = tail;
43+
tail->pre = head;
44+
}
45+
46+
size_t size() {
47+
mutex_lock l(mu_);
48+
return mp.size();
49+
}
50+
51+
size_t get_evic_ids(K* evic_ids, size_t k_size) {
52+
mutex_lock l(mu_);
53+
size_t true_size = 0;
54+
LRUNode *evic_node = tail->pre;
55+
LRUNode *rm_node = evic_node;
56+
for (size_t i = 0; i < k_size && evic_node != head; ++i) {
57+
evic_ids[i] = evic_node->id;
58+
rm_node = evic_node;
59+
evic_node = evic_node->pre;
60+
mp.erase(rm_node->id);
61+
delete rm_node;
62+
true_size++;
63+
}
64+
evic_node->next = tail;
65+
tail->pre = evic_node;
66+
return true_size;
67+
}
68+
69+
void add_to_rank(const K* batch_ids, size_t batch_size) {
70+
mutex_lock l(mu_);
71+
for (size_t i = 0; i < batch_size; ++i) {
72+
K id = batch_ids[i];
73+
typename std::map<K, LRUNode *>::iterator it = mp.find(id);
74+
if (it != mp.end()) {
75+
LRUNode *node = it->second;
76+
node->pre->next = node->next;
77+
node->next->pre = node->pre;
78+
head->next->pre = node;
79+
node->next = head->next;
80+
head->next = node;
81+
node->pre = head;
82+
} else {
83+
LRUNode *newNode = new LRUNode(id);
84+
head->next->pre = newNode;
85+
newNode->next = head->next;
86+
head->next = newNode;
87+
newNode->pre = head;
88+
mp[id] = newNode;
89+
}
90+
}
91+
}
92+
};
93+
94+
template <class K>
95+
class LFUCache : public BatchCache<K> {
96+
private:
97+
class LFUNode {
98+
public:
99+
K key;
100+
size_t freq;
101+
LFUNode(K key, size_t freq) : key(key), freq(freq) {}
102+
};
103+
size_t min_freq;
104+
size_t max_freq;
105+
std::unordered_map<K, typename std::list<LFUNode>::iterator> key_table;
106+
std::unordered_map<K, typename std::list<LFUNode>> freq_table;
107+
mutex mu_;
108+
109+
public:
110+
LFUCache() {
111+
min_freq = 0;
112+
max_freq = 0;
113+
key_table.clear();
114+
freq_table.clear();
115+
}
116+
117+
size_t size() {
118+
mutex_lock l(mu_);
119+
return key_table.size();
120+
}
121+
122+
size_t get_evic_ids(K *evic_ids, size_t k_size) {
123+
mutex_lock l(mu_);
124+
size_t true_size = 0;
125+
for (size_t i = 0; i < k_size; ++i) {
126+
auto rm_it = freq_table[min_freq].back();
127+
key_table.erase(rm_it.key);
128+
evic_ids[i] = rm_it.key;
129+
++true_size;
130+
freq_table[min_freq].pop_back();
131+
if (freq_table[min_freq].size() == 0) {
132+
freq_table.erase(min_freq);
133+
++min_freq;
134+
while (min_freq <= max_freq) {
135+
auto it = freq_table.find(min_freq);
136+
if (it == freq_table.end() || it->second.size() == 0) {
137+
++min_freq;
138+
} else {
139+
break;
140+
}
141+
}
142+
}
143+
}
144+
return true_size;
145+
}
146+
147+
void add_to_rank(const K *batch_ids, size_t batch_size) {
148+
mutex_lock l(mu_);
149+
for (size_t i = 0; i < batch_size; ++i) {
150+
K id = batch_ids[i];
151+
auto it = key_table.find(id);
152+
if (it == key_table.end()) {
153+
freq_table[1].push_front(LFUNode(id, 1));
154+
key_table[id] = freq_table[1].begin();
155+
min_freq = 1;
156+
} else {
157+
typename std::list<LFUNode>::iterator node = it->second;
158+
size_t freq = node->freq;
159+
freq_table[freq].erase(node);
160+
if (freq_table[freq].size() == 0) {
161+
freq_table.erase(freq);
162+
if (min_freq == freq)
163+
min_freq += 1;
164+
}
165+
max_freq = std::max(max_freq, freq + 1);
166+
freq_table[freq + 1].push_front(LFUNode(id, freq + 1));
167+
key_table[id] = freq_table[freq + 1].begin();
168+
}
169+
}
170+
}
171+
};
172+
173+
} // embedding
174+
} // tensorflow
175+
176+
#endif // TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_CACHE_H_

tensorflow/core/framework/embedding/config.proto

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ enum StorageType {
1111
PMEM_MEMKIND = 2;
1212
PMEM_LIBPMEM = 3;
1313
SSD = 4;
14+
LEVELDB = 5;
1415

15-
LEVELDB = 14;
16-
/*
1716
// two level
1817
DRAM_PMEM = 11;
1918
DRAM_SSD = 12;
2019
HBM_DRAM = 13;
20+
DRAM_LEVELDB = 14;
2121

2222
// three level
2323
DRAM_PMEM_SSD = 101;
2424
HBM_DRAM_SSD = 102;
25-
*/
25+
2626
}

tensorflow/core/framework/embedding/dense_hash_map.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ class DenseHashMap : public KVInterface<K, V> {
3838
hash_map_[i].hash_map.set_empty_key(-1);
3939
hash_map_[i].hash_map.set_deleted_key(-2);
4040
}
41-
KVInterface<K, V>::total_dims_ = 0;
4241
}
4342

4443
~DenseHashMap() {

tensorflow/core/framework/embedding/embedding_config.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ struct EmbeddingConfig {
2424
int64 storage_size;
2525
int64 default_value_dim;
2626
int normal_fix_flag;
27+
bool is_multi_level;
2728

2829
EmbeddingConfig(int64 emb_index = 0, int64 primary_emb_index = 0,
2930
int64 block_num = 1, int slot_num = 0,
3031
const std::string& name = "", int64 steps_to_live = 0,
3132
int64 filter_freq = 0, int64 max_freq = 999999,
32-
float l2_weight_threshold = -1.0, const std::string& layout = "normal",
33+
float l2_weight_threshold = -1.0, const std::string& layout = "normal_fix",
3334
int64 max_element_size = 0, float false_positive_probability = -1.0,
3435
DataType counter_type = DT_UINT64, embedding::StorageType storage_type = embedding::DRAM,
3536
const std::string& storage_path = "", int64 storage_size = 0,
@@ -48,7 +49,8 @@ struct EmbeddingConfig {
4849
storage_path(storage_path),
4950
storage_size(storage_size),
5051
default_value_dim(default_value_dim),
51-
normal_fix_flag(0) {
52+
normal_fix_flag(0),
53+
is_multi_level(false) {
5254
if ("normal" == layout) {
5355
layout_type = LayoutType::NORMAL;
5456
} else if ("light" == layout) {
@@ -61,16 +63,21 @@ struct EmbeddingConfig {
6163
}
6264
if (max_element_size != 0 && false_positive_probability != -1.0){
6365
kHashFunc = calc_num_hash_func(false_positive_probability);
64-
num_counter = calc_num_counter(max_element_size, false_positive_probability);
66+
num_counter = calc_num_counter(max_element_size, false_positive_probability);
6567
} else {
6668
kHashFunc = 0;
6769
num_counter = 0;
6870
}
6971
if (layout_type == LayoutType::NORMAL_FIX) {
7072
normal_fix_flag = 1;
7173
}
74+
if (storage_type == embedding::PMEM_MEMKIND || storage_type == embedding::PMEM_LIBPMEM ||
75+
storage_type == embedding::DRAM_PMEM || storage_type == embedding::DRAM_SSD ||
76+
storage_type == embedding::HBM_DRAM || storage_type == embedding::DRAM_LEVELDB) {
77+
is_multi_level = true;
78+
}
7279
}
73-
80+
7481
int64 calc_num_counter(int64 max_element_size, float false_positive_probability) {
7582
float loghpp = fabs(log(false_positive_probability));
7683
float factor = log(2) * log(2);

0 commit comments

Comments
 (0)