Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions include/saltatlas/dnnd/detail/dnnd_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class dnnd_kernel {
if (!priv_check_const_option()) {
return;
}
priv_reset_build_profile_counters();
priv_init_knn_heap_with_random_values();
priv_construct_kernel();
priv_convert(knn_index);
Expand All @@ -132,6 +133,7 @@ class dnnd_kernel {
if (!priv_check_const_option()) {
return;
}
priv_reset_build_profile_counters();
priv_init_knn_heap_with_index(init_knn_index, recheck);
priv_construct_kernel();
priv_convert(knn_index);
Expand All @@ -153,6 +155,7 @@ class dnnd_kernel {
if (!priv_check_const_option()) {
return;
}
priv_reset_build_profile_counters();
priv_init_knn_heap_with_index(init_knn_index, recheck);
priv_construct_kernel();
priv_convert(knn_index);
Expand All @@ -163,6 +166,7 @@ class dnnd_kernel {
if (m_option.verbose) {
m_comm.cout0() << "Rerunning NN-Descent kernel" << std::endl;
}
priv_reset_build_profile_counters();
priv_init_knn_heap_with_index(knn_index, true);
priv_construct_kernel();
priv_convert(knn_index);
Expand Down Expand Up @@ -211,6 +215,26 @@ class dnnd_kernel {
return true;
}

/// Puts the build-time profiling state back to its initial values:
/// the distance-calculation counter, the basic message counters (when
/// SALTATLAS_DNND_SHOW_BASIC_MSG_STATISTICS is enabled), and the
/// feature-message source table (when SALTATLAS_DNND_PROFILE_FEATURE_MSG
/// is enabled).
void priv_reset_build_profile_counters() {
#if SALTATLAS_DNND_PROFILE_FEATURE_MSG
  m_feature_msg_src_count.clear();
#endif

#if SALTATLAS_DNND_SHOW_BASIC_MSG_STATISTICS
  m_num_pruned_distance_msgs     = 0;
  m_num_distance_msgs            = 0;
  m_num_feature_msgs             = 0;
  m_num_neighbor_suggestion_msgs = 0;
#endif

  // Not guarded by any macro.
  m_cnt_dist_cals = 0;
}

/// Returns m_distance_function(point1, point2) and counts the call in
/// m_cnt_dist_cals.
///
/// \param point1 First point handed to the distance function.
/// \param point2 Second point handed to the distance function.
/// \return Whatever the distance function returns for the two points.
inline distance_type priv_compute_distance(const point_type& point1,
                                           const point_type& point2) {
  // The counter is bumped before the distance function is invoked.
  m_cnt_dist_cals += 1;
  return m_distance_function(point1, point2);
}

void priv_init_knn_heap_with_random_values() {
if (m_option.verbose) {
m_comm.cout0() << "\nInitializing the k-NN index with random neighbors."
Expand Down Expand Up @@ -254,12 +278,6 @@ class dnnd_kernel {

void priv_construct_kernel() {
priv_check_construct_parameters();
#if SALTATLAS_DNND_SHOW_BASIC_MSG_STATISTICS
m_num_neighbor_suggestion_msgs = 0;
m_num_feature_msgs = 0;
m_num_distance_msgs = 0;
m_num_pruned_distance_msgs = 0;
#endif
std::size_t epoch_no = 0;
double elapsed_time_sec = 0.0;
while (true) {
Expand Down Expand Up @@ -310,6 +328,8 @@ class dnnd_kernel {
}
m_comm.cf_barrier();
if (m_option.verbose) {
m_comm.cout0() << "\nTotal distance calculations\t"
<< ygm::sum(m_cnt_dist_cals, m_comm) << std::endl;
#if SALTATLAS_DNND_SHOW_BASIC_MSG_STATISTICS
m_comm.cout0() << "\nMessage Statistics" << std::endl;
m_comm.cout0() << "#of sent neighbor suggestions\t"
Expand Down Expand Up @@ -464,7 +484,7 @@ class dnnd_kernel {
const id_type sid, const id_type nid,
const point_type& src_point) {
const auto& nbr_point = local_this->m_point_store[nid];
const auto d = local_this->m_distance_function(src_point, nbr_point);
const auto d = local_this->priv_compute_distance(src_point, nbr_point);
local_this->comm().async(local_this->m_point_partitioner(sid),
distance_calculator{}, local_this, sid, nid, d);
}
Expand Down Expand Up @@ -495,7 +515,7 @@ class dnnd_kernel {
std::advance(pitr, offset);
const auto& nid = pitr->first;
const auto& nbr_point = pitr->second;
const auto d = local_this->m_distance_function(src_point, nbr_point);
const auto d = local_this->priv_compute_distance(src_point, nbr_point);
local_this->comm().async(local_this->m_point_partitioner(sid),
random_neighbor_explorer{}, local_this, sid, nid,
d);
Expand Down Expand Up @@ -884,7 +904,7 @@ class dnnd_kernel {
// Update u2's heap (nearest neighbors list) if 'u1' is closer than the
// current neighbors.
const auto& u2_point = local_this->m_point_store[u2];
const auto d = local_this->m_distance_function(u1_point, u2_point);
const auto d = local_this->priv_compute_distance(u1_point, u2_point);
local_this->m_cnt_new_neighbors += nn_heap.try_add(u1, d, true);

if (d < u1_max_distance) {
Expand Down Expand Up @@ -1062,6 +1082,9 @@ class dnnd_kernel {
std::size_t m_num_points{0}; // Global number of points
std::size_t m_mini_batch_no{0};
std::size_t m_cnt_new_neighbors{0};

// For profiling
std::size_t m_cnt_dist_cals{0};
#if SALTATLAS_DNND_SHOW_BASIC_MSG_STATISTICS
std::size_t m_num_neighbor_suggestion_msgs{0};
std::size_t m_num_feature_msgs{0};
Expand Down
95 changes: 65 additions & 30 deletions include/saltatlas/neo_dnnd/neo_dnnd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ class neo_dnnd {
m_fvs_disp.resize(m_fv_send_batch_size);
m_fvs_block_lengths.resize(m_fv_send_batch_size, 1);
priv_commit_mpi_types();
priv_reset_profile_counters();
}
if (m_verbose) {
priv_show_dram_usage();
Expand Down Expand Up @@ -345,6 +346,7 @@ class neo_dnnd {
void optimize(const double m, knng_type& knng) {
m_comm.cout0() << "Start optimization" << std::endl;
priv_cout0(m_verbose) << "m: " << m << std::endl;
priv_reset_profile_counters();

matrix2d<id_type> r_nids(m_comm.size());
std::vector<std::vector<distance_type>> r_dits(m_comm.size());
Expand Down Expand Up @@ -451,7 +453,20 @@ class neo_dnnd {
m_comm.barrier();
}

/// Resets the time recorder referenced by m_time_recorder.
/// Does nothing when no recorder is currently referenced.
void reset_time_recorder() {
  if (!m_time_recorder.has_value()) {
    return;
  }
  m_time_recorder->get().reset();
}

void print_profile([[maybe_unused]] const bool final) const {
if (final) {
m_comm.cout0() << "#of distance calculations:\t"
<< m_comm.all_reduce_sum(m_num_distance_calculations)
<< std::endl;
}

if (final || m_verbose) {
m_comm.cout0() << "FV processing breakdown:" << std::endl;
m_comm.cout0() << " #of normally sent:\t"
Expand Down Expand Up @@ -643,6 +658,24 @@ class neo_dnnd {
}
}

/// Resets the profiling state held directly by this object: the
/// distance-calculation counter, the counter database, and (when
/// PROFILE_FV is defined) the per-feature-vector receive counts.
void priv_reset_profile_counters() {
  // m_time_recorder is deliberately not reset here because it could be
  // owned externally.
#ifdef PROFILE_FV
  m_fv_count.clear();
#endif
  m_counter_db.clear();
  m_num_distance_calculations = 0;
}

/// Calls m_distance_func on two feature vectors given as raw pointers
/// and counts the call in m_num_distance_calculations.
///
/// \param p1   Pointer to the first element of the first feature vector.
/// \param dim1 Number of elements readable from p1.
/// \param p2   Pointer to the first element of the second feature vector.
/// \param dim2 Number of elements readable from p2.
/// \return Whatever the distance function returns for the two vectors.
inline distance_type priv_get_distance(const fe_type* const p1,
                                       const std::size_t    dim1,
                                       const fe_type* const p2,
                                       const std::size_t    dim2) {
  m_num_distance_calculations += 1;
  // Constness is cast away only to build spans over non-const fe_type.
  // NOTE(review): presumably required by the distance function's
  // parameter type -- confirm it never writes through the spans.
  return m_distance_func(
      std::span<fe_type>(const_cast<fe_type*>(p1), dim1),
      std::span<fe_type>(const_cast<fe_type*>(p2), dim2));
}

static std::string priv_get_pstore_name(const int rank) {
std::string name = k_shm_dir;
name += "/";
Expand Down Expand Up @@ -894,9 +927,9 @@ class neo_dnnd {
const auto nid = elem.second;

assert(priv_owner(sid) == m_comm.rank());
const auto dist = m_distance_func(
std::span(const_cast<fe_type*>(point_store[sid]), num_dims()),
std::span(&feature_recv_buf[buf_i * num_dims()], num_dims()));
const auto dist = priv_get_distance(
point_store[sid], num_dims(), &feature_recv_buf[buf_i * num_dims()],
num_dims());
assert(m_graph.count(sid) > 0);
m_graph.at(sid).try_add(nid, dist, true); // push as a new neighbor
++buf_i;
Expand Down Expand Up @@ -1240,9 +1273,8 @@ class neo_dnnd {
assert(priv_owner(nb) == pair_rank);
const auto* src_fv = my_pstore[src];
const auto* nb_fv = pair_pstore[nb];
const auto dist = m_distance_func(
std::span(const_cast<fe_type*>(src_fv), num_dims()),
std::span(const_cast<fe_type*>(nb_fv), num_dims()));
const auto dist =
priv_get_distance(src_fv, num_dims(), nb_fv, num_dims());
distances[c] = dist;
}
m_time_recorder->get().stop();
Expand Down Expand Up @@ -1445,10 +1477,8 @@ class neo_dnnd {
assert(priv_owner(pid) == m_comm.rank());
const auto fv_idx = (indices.size() > 0) ? indices[i] : i;
const auto sent_fv_pos = fv_idx * num_dims();
const auto dist = m_distance_func(
std::span(const_cast<fe_type*>(point_store[pid]), num_dims()),
std::span(const_cast<fe_type*>(&features.at(sent_fv_pos)),
num_dims()));
const auto dist = priv_get_distance(
point_store[pid], num_dims(), &features.at(sent_fv_pos), num_dims());
out_distances[i] = dist;
}
}
Expand Down Expand Up @@ -1559,10 +1589,8 @@ class neo_dnnd {
const auto nfv_bank_no = priv_node_local_rank(priv_owner(nid));
const auto* const nfv = m_pop_fv_store->get(nfv_bank_no, nid);
assert(nfv);
const auto dist =
m_distance_func(std::span(const_cast<fe_type*>(sfv), num_dims()),
std::span(const_cast<fe_type*>(nfv), num_dims()));
distances[i] = dist;
const auto dist = priv_get_distance(sfv, num_dims(), nfv, num_dims());
distances[i] = dist;
}
m_time_recorder->get().stop();

Expand Down Expand Up @@ -1666,21 +1694,29 @@ class neo_dnnd {

// Variables initialized in the constructor
distance_function m_distance_func;
time_recorder m_default_time_recorder;
std::optional<std::reference_wrapper<time_recorder>> m_time_recorder;
mpi::communicator& m_comm;
std::mt19937_64 m_rng; // Must be initialized after m_comm as it uses rank
bool m_verbose{false};

bool m_read_nlocal_pstores_directly{false};
bool m_remove_duplicate_fvs{false};
std::size_t m_k{0};
double m_rho{1.0};
double m_delta{0.001};
std::size_t m_num_total_points{0};
dndetail::counter_db m_counter_db;
// Options
bool m_read_nlocal_pstores_directly{false};
bool m_remove_duplicate_fvs{false};
std::size_t m_k{0};
double m_rho{1.0};
double m_delta{0.001};
std::size_t m_num_total_points{0};

// Core data structures
knn_heap_adj_list_t m_graph{};
std::vector<const point_store*> m_point_stores;
std::vector<metall::manager*> m_point_store_managers;
std::unique_ptr<pop_fv_cache_t> m_pop_fv_store;

std::size_t m_num_dims{0};
std::size_t m_super_step_no{0};
std::size_t m_min_cache_id{0};

knn_heap_adj_list_t m_graph{};
// For sending feature vectors
std::size_t m_fv_send_batch_size{0};
::MPI_Datatype m_nb_dist_type{MPI_DATATYPE_NULL};
Expand All @@ -1691,15 +1727,14 @@ class neo_dnnd {
std::vector<::MPI_Datatype> m_fvs_types{};
std::vector<int> m_all_to_all_pairs{};
std::vector<int> m_all_to_all_node_pairs{};

// For profiling
std::size_t m_num_distance_calculations{0};
dndetail::counter_db m_counter_db{};
time_recorder m_default_time_recorder{};
#ifdef PROFILE_FV
// Counts how many times each feature vector was received.
bstuo::unordered_flat_map<id_type, std::size_t> m_fv_count;
bstuo::unordered_flat_map<id_type, std::size_t> m_fv_count{};
#endif
std::size_t m_super_step_no{0};
std::unique_ptr<pop_fv_cache_t> m_pop_fv_store;
std::size_t m_min_cache_id{0};
std::vector<const point_store*> m_point_stores;
std::vector<metall::manager*> m_point_store_managers;
std::size_t m_num_dims{0};
};
} // namespace saltatlas