Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 00f3980

Browse files
lixy9474candyzone
authored and committed
[Embedding] Support LevelDB storage for EV.
1 parent 13ccd28 commit 00f3980

File tree

11 files changed

+589
-179
lines changed

11 files changed

+589
-179
lines changed

tensorflow/core/framework/embedding/embedding_config.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,6 @@ struct EmbeddingConfig {
6666
kHashFunc = 0;
6767
num_counter = 0;
6868
}
69-
if (embedding::LEVELDB == storage_type) {
70-
layout_type = LayoutType::LEVELDB;
71-
}
7269
if (layout_type == LayoutType::NORMAL_FIX) {
7370
normal_fix_flag = 1;
7471
}

tensorflow/core/framework/embedding/embedding_var.h

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ class EmbeddingVar : public ResourceBase {
7474
new_value_ptr_fn = [] (size_t size) { return new LightValuePtr<V>(size); };
7575
} else if (LayoutType::NORMAL == emb_config_.get_layout_type()) {
7676
new_value_ptr_fn = [] (size_t size) { return new NormalValuePtr<V>(size); };
77+
} else if (LayoutType::NORMAL_FIX == emb_config_.get_layout_type()){
78+
new_value_ptr_fn = [] (size_t size) { return new NormalContiguousValuePtr<V>(size); };
7779
} else {
7880
return errors::InvalidArgument(name_, ", Unsupport EmbeddingVariable LayoutType.");
7981
}
@@ -94,6 +96,12 @@ class EmbeddingVar : public ResourceBase {
9496
if (!alloc_) {
9597
return errors::InvalidArgument(name_, ", No registered PMEM_LIBPMEM AllocatorFactory.");
9698
}
99+
} else if (embedding::StorageType::LEVELDB == emb_config_.get_storage_type()) {
100+
alloc_ = ev_allocator();
101+
if (!alloc_) {
102+
return errors::InvalidArgument(name_, ", No registered EV AllocatorFactory.");
103+
}
104+
kv_->SetNewValuePtrFunc(new_value_ptr_fn);
97105
} else {
98106
return errors::InvalidArgument(name_, ", Unsupport EmbeddingVariable StorageType.");
99107
}
@@ -117,28 +125,6 @@ class EmbeddingVar : public ResourceBase {
117125
new_value_ptr_fn = [] (size_t size) { return new LightValuePtr<V>(size); };
118126
} else if (LayoutType::NORMAL == emb_config_.get_layout_type()) {
119127
new_value_ptr_fn = [] (size_t size) { return new NormalValuePtr<V>(size); };
120-
} else if (LayoutType::LEVELDB == emb_config_.get_layout_type()) {
121-
if (emb_config_.is_primary()) {
122-
std::string path = emb_config_.get_storage_path();
123-
Status s = Env::Default()->IsDirectory(path);
124-
if (!s.ok()) {
125-
LOG(WARNING) << "StoragePath=\"" << path << "\" is not Directory, message: " << s.ToString() << ". Try to create dir.";
126-
TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(path));
127-
}
128-
db_name_ = io::JoinPath(path, "level_db_" + std::to_string(Env::Default()->NowMicros()));
129-
leveldb::Status st;
130-
leveldb::Options options;
131-
options.create_if_missing = true;
132-
//options.write_buffer_size = 1024 * 1024 * 1024;
133-
//options.error_if_exists = true;
134-
st = leveldb::DB::Open(options, db_name_.c_str(), &level_db_);
135-
if (!st.ok()) {
136-
LOG(FATAL) << "Fail to open leveldb: " << st.ToString();
137-
} else {
138-
VLOG(1) << "Open DB Success, db_name: " << db_name_;
139-
}
140-
new_value_ptr_fn = [this] (size_t size) { return new DBValuePtr<V>(size, this->level_db_); };
141-
}
142128
} else if (LayoutType::NORMAL_FIX == emb_config_.get_layout_type()){
143129
new_value_ptr_fn = [] (size_t size) { return new NormalContiguousValuePtr<V>(size); };
144130
} else {
@@ -166,6 +152,7 @@ class EmbeddingVar : public ResourceBase {
166152
if (!alloc_) {
167153
return errors::InvalidArgument(name_, ", No registered EV AllocatorFactory.");
168154
}
155+
kv_->SetNewValuePtrFunc(new_value_ptr_fn);
169156
} else {
170157
return errors::InvalidArgument(name_, ", Unsupport EmbeddingVariable StorageType.");
171158
}
@@ -206,6 +193,10 @@ class EmbeddingVar : public ResourceBase {
206193
return s;
207194
}
208195

196+
void BatchCommit(std::vector<K> keys, std::vector<ValuePtr<V>*> value_ptrs) {
197+
Status s = kv_->BatchCommit(keys, value_ptrs);
198+
}
199+
209200
int64 GetVersion(K key) {
210201
ValuePtr<V>* value_ptr = nullptr;
211202
TF_CHECK_OK(LookupOrCreateKey(key, &value_ptr));
@@ -242,8 +233,8 @@ class EmbeddingVar : public ResourceBase {
242233
return typename TTypes<V>::Flat(val, dims);
243234
}
244235

245-
void Commit(ValuePtr<V>* value_ptr, const V* v) {
246-
value_ptr->Commit(value_len_, v, emb_config_.emb_index);
236+
void Commit(const K id, ValuePtr<V>* value_ptr) {
237+
kv_->Commit(id, value_ptr);
247238
}
248239

249240
int64 ValueLen() const {
@@ -304,7 +295,7 @@ class EmbeddingVar : public ResourceBase {
304295
}
305296
}
306297
V* v = LookupOrCreateEmb(value_ptr, value_buff + i * value_len_);
307-
value_ptr->Free(v);
298+
kv_->Commit(key_buff[i], value_ptr);
308299
}
309300
return Status::OK();
310301
}
@@ -329,11 +320,14 @@ class EmbeddingVar : public ResourceBase {
329320
version_list->push_back(dump_version);
330321
}
331322
}
323+
kv_->FreeValuePtr(value_ptr_list[i]);
332324
}
333325
return key_list->size();
334326
}
335327

336328
Status Destroy(int64 value_len) {
329+
if (embedding::StorageType::LEVELDB == emb_config_.get_storage_type())
330+
return Status::OK();
337331
std::vector<K> key_list;
338332
std::vector<ValuePtr<V>* > value_ptr_list;
339333
kv_->GetSnapshot(&key_list, &value_ptr_list);

tensorflow/core/framework/embedding/kv_factory.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include "tensorflow/core/lib/core/status.h"
2020
#include "tensorflow/core/framework/embedding/dense_hash_map.h"
2121
#include "tensorflow/core/framework/embedding/lockless_hash_map.h"
22+
#include "tensorflow/core/framework/embedding/leveldb_kv.h"
2223

2324
namespace tensorflow {
2425

@@ -30,14 +31,19 @@ class KVFactory {
3031
public:
3132
~KVFactory() {}
3233
static KVInterface<K, V>* CreateKV(const std::string& kv_type,
33-
int partition_num) {
34+
int partition_num,
35+
std::string path) {
3436
if ("dense_hash_map" == kv_type) {
3537
VLOG(2) << "Use dense_hash_map as EV data struct";
3638
return new DenseHashMap<K, V>();
3739
} else if ("lockless_hash_map" == kv_type) {
3840
VLOG(2) << "Use lockless_hash_map as EV data struct";
3941
return new LocklessHashMap<K, V>();
40-
} else {
42+
} else if ("leveldb_kv" == kv_type) {
43+
VLOG(2) << "Use leveldb_kv as EV data struct";
44+
return new LevelDBKV<K, V>(path);
45+
}
46+
else {
4147
VLOG(2) << "Not match any hashtable_type, use default 'lockless_hash_map'";
4248
return new LocklessHashMap<K, V>();
4349
}

tensorflow/core/framework/embedding/kv_interface.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,17 @@ class KVInterface {
5050
"Unimplemented for BatchRemove in KVInterface.");
5151
}
5252

53+
// Writes a batch of key/value pairs to the backing store.
// The base implementation does nothing and reports success; subclasses
// such as LevelDBKV override it.
virtual Status BatchCommit(std::vector<K> keys,
                           std::vector<ValuePtr<V>*> value_ptrs) {
  return Status::OK();
}
54+
5355
// KV Size
5456
virtual int64 Size() const = 0;
5557

58+
// Registers the factory used to allocate ValuePtr objects of a given size.
// The base implementation ignores the factory; subclasses that need to
// allocate ValuePtr objects themselves (e.g. LevelDBKV) override it.
virtual void SetNewValuePtrFunc(
    std::function<ValuePtr<V>*(size_t)> new_value_ptr_fn) {
}
59+
60+
// Releases a ValuePtr previously handed out by this store.
// The base implementation is a no-op and leaves `value_ptr` untouched;
// LevelDBKV overrides it to delete the pointer.
virtual void FreeValuePtr(ValuePtr<V>* value_ptr) {
}
61+
62+
// Writes the value for a single key to the backing store.
// The base implementation does nothing and reports success; LevelDBKV
// overrides it.
virtual Status Commit(K key, const ValuePtr<V>* value_ptr) {
  return Status::OK();
}
63+
5664
virtual Status GetSnapshot(std::vector<K>* key_list,
5765
std::vector<ValuePtr<V>* >* value_ptr_list) = 0;
5866

@@ -97,7 +105,6 @@ class KVInterface {
97105
std::vector<int> slot_dims_;
98106
std::vector<int> slot_offset_;
99107
int total_dims_;
100-
101108
private:
102109
std::atomic_flag flag_ = ATOMIC_FLAG_INIT;
103110
};
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#ifndef TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_LEVELDB_KV_H_
2+
#define TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_LEVELDB_KV_H_
3+
4+
#include "tensorflow/core/lib/io/path.h"
5+
6+
#include "tensorflow/core/framework/embedding/kv_interface.h"
7+
#include "tensorflow/core/lib/core/status.h"
8+
9+
#include "leveldb/db.h"
10+
#include "leveldb/comparator.h"
11+
#include "leveldb/write_batch.h"
12+
13+
#include <sstream>
14+
15+
using leveldb::DB;
16+
using leveldb::Options;
17+
using leveldb::ReadOptions;
18+
using leveldb::WriteBatch;
19+
using leveldb::WriteOptions;
20+
using leveldb::Iterator;
21+
22+
namespace tensorflow {
23+
template <class V>
24+
class ValuePtr;
25+
26+
// Counter split across `num_parts` slots. Each key updates the slot selected
// by `key % num_parts` using the compiler's atomic fetch-and-add/sub
// builtins, so updates to different slots do not touch the same word.
template <class K>
class SizeCounter {
 public:
  // num_parts: number of slots. Values <= 0 are clamped to 1 so the modulo
  // in PartOf() is always well defined and at least one slot exists.
  explicit SizeCounter(int num_parts)
      : counter_(static_cast<size_t>(num_parts > 0 ? num_parts : 1), 0),
        num_parts_(num_parts > 0 ? num_parts : 1) {}

  // Atomically adds `count` to the slot owning `key`.
  void add(K key, int64 count) {
    __sync_fetch_and_add(&counter_[PartOf(key)], count);
  }

  // Atomically subtracts `count` from the slot owning `key`.
  void sub(K key, int64 count) {
    __sync_fetch_and_sub(&counter_[PartOf(key)], count);
  }

  // Returns the sum of all slots. Slots are read with plain loads, so the
  // result is not a consistent snapshot while add()/sub() run concurrently.
  int64 size() const {
    int64 total = 0;
    for (int i = 0; i < num_parts_; ++i) {
      total += counter_[i];
    }
    return total;
  }

 private:
  // Maps `key` to a slot index in [0, num_parts_). In C++ the result of `%`
  // takes the sign of the dividend, so a negative key would otherwise yield
  // a negative index and an out-of-bounds access.
  int PartOf(K key) const {
    int part = static_cast<int>(key % num_parts_);
    if (part < 0) {
      part += num_parts_;
    }
    return part;
  }

  std::vector<int64> counter_;
  int num_parts_;
};
58+
59+
template <class K, class V>
60+
class LevelDBKV : public KVInterface<K, V> {
61+
public:
62+
LevelDBKV(std::string path) {
63+
path_ = io::JoinPath(path, "level_db_" + std::to_string(Env::Default()->NowMicros()));;
64+
options_.create_if_missing = true;
65+
leveldb::Status s = leveldb::DB::Open(options_, path_, &db_);
66+
KVInterface<K, V>::total_dims_ = 0;
67+
counter_ = new SizeCounter<K>(8);
68+
CHECK(s.ok());
69+
}
70+
71+
void SetNewValuePtrFunc(std::function<ValuePtr<V>*(size_t)> new_value_ptr_fn) {
72+
new_value_ptr_fn_ = new_value_ptr_fn;
73+
}
74+
75+
~LevelDBKV() {
76+
delete db_;
77+
}
78+
79+
Status Lookup(K key, ValuePtr<V>** value_ptr) {
80+
std::string val_str;
81+
leveldb::Slice db_key((char*)(&key), sizeof(void*));
82+
ValuePtr<V>* val = new_value_ptr_fn_(KVInterface<K, V>::total_dims_);
83+
leveldb::ReadOptions options;
84+
leveldb::Status s = db_->Get(options, db_key, &val_str);
85+
if (s.IsNotFound()) {
86+
delete val;
87+
return errors::NotFound(
88+
"Unable to find Key: ", key, " in LevelDB.");
89+
} else {
90+
memcpy((int64 *)(val->GetPtr()), &val_str[0], val_str.length());
91+
*value_ptr = val;
92+
return Status::OK();
93+
}
94+
}
95+
96+
Status Insert(K key, const ValuePtr<V>* value_ptr) {
97+
counter_->add(key, 1);
98+
return Status::OK();
99+
}
100+
101+
Status BatchInsert(std::vector<K> keys, std::vector<ValuePtr<V>*> value_ptrs) {
102+
return BatchCommit(keys, value_ptrs);
103+
}
104+
105+
Status BatchCommit(std::vector<K> keys, std::vector<ValuePtr<V>*> value_ptrs) {
106+
WriteBatch batch;
107+
for (int i = 0; i < keys.size(); i++) {
108+
std::string value_res((char*)value_ptrs[i]->GetPtr(), sizeof(FixedLengthHeader) + KVInterface<K, V>::total_dims_ * sizeof(V));
109+
leveldb::Slice db_key((char*)(&keys[i]), sizeof(void*));
110+
batch.Put(db_key, value_res);
111+
delete value_ptrs[i];
112+
}
113+
db_->Write(WriteOptions(),&batch);
114+
return Status::OK();
115+
}
116+
117+
Status Commit(K key, const ValuePtr<V>* value_ptr) {
118+
std::string value_res((char*)value_ptr->GetPtr(), sizeof(FixedLengthHeader) + KVInterface<K, V>::total_dims_ * sizeof(V));
119+
leveldb::Slice db_key((char*)(&key), sizeof(void*));
120+
leveldb::Status s = db_->Put(WriteOptions(), db_key, value_res);
121+
delete value_ptr;
122+
if (!s.ok()){
123+
return errors::AlreadyExists(
124+
"already exists Key: ", key, " in RocksDB.");
125+
} else {
126+
return Status::OK();
127+
}
128+
}
129+
130+
Status Remove(K key) {
131+
counter_->sub(key, 1);
132+
leveldb::Slice db_key((char*)(&key), sizeof(void*));
133+
leveldb::Status s = db_->Delete(WriteOptions(), db_key);
134+
if (s.ok()) {
135+
return Status::OK();
136+
} else {
137+
return errors::NotFound(
138+
"Unable to find Key: ", key, " in RocksDB.");
139+
}
140+
}
141+
142+
Status GetSnapshot(std::vector<K>* key_list, std::vector<ValuePtr<V>* >* value_ptr_list) {
143+
ReadOptions options;
144+
options.snapshot = db_->GetSnapshot();
145+
Iterator* it = db_->NewIterator(options);
146+
for (it->SeekToFirst(); it->Valid(); it->Next()) {
147+
std::string key_str, value_str;
148+
ValuePtr<V>* value_ptr = new_value_ptr_fn_(KVInterface<K, V>::total_dims_);
149+
key_str = it->key().ToString();
150+
value_str = it->value().ToString();
151+
key_list->emplace_back(*((long*)&key_str[0]));
152+
void* ptr = value_ptr->GetPtr();
153+
memcpy(ptr, &value_str[0], KVInterface<K, V>::total_dims_ * sizeof(V) + sizeof(FixedLengthHeader));
154+
value_ptr_list->emplace_back(value_ptr);
155+
}
156+
assert(it->status().ok());
157+
delete it;
158+
db_->ReleaseSnapshot(options.snapshot);
159+
return Status::OK();
160+
}
161+
162+
int64 Size() const {
163+
return counter_->size();
164+
}
165+
166+
void FreeValuePtr(ValuePtr<V>* value_ptr) {
167+
delete value_ptr;
168+
}
169+
170+
std::string DebugString() const {
171+
return "";
172+
}
173+
private:
174+
DB* db_;
175+
SizeCounter<K>* counter_;
176+
Options options_;
177+
std::string path_;
178+
std::function<ValuePtr<V>*(size_t)> new_value_ptr_fn_;
179+
};
180+
181+
} //namespace tensorflow
182+
183+
#endif  // TENSORFLOW_CORE_FRAMEWORK_EMBEDDING_LEVELDB_KV_H_

0 commit comments

Comments
 (0)