Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added support for PyTorch 1.12 ([#57](https://github.com/pyg-team/pyg-lib/pull/57), [#58](https://github.com/pyg-team/pyg-lib/pull/58))
- Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56), [#61](https://github.com/pyg-team/pyg-lib/pull/61), [#64](https://github.com/pyg-team/pyg-lib/pull/64), [#69](https://github.com/pyg-team/pyg-lib/pull/69))
- Added `pyg::sampler::neighbor_sample` implementation ([#54](https://github.com/pyg-team/pyg-lib/pull/54), [#76](https://github.com/pyg-team/pyg-lib/pull/76), [#77](https://github.com/pyg-team/pyg-lib/pull/77), [#78](https://github.com/pyg-team/pyg-lib/pull/78), [#80](https://github.com/pyg-team/pyg-lib/pull/80), [#81](https://github.com/pyg-team/pyg-lib/pull/81), [#85](https://github.com/pyg-team/pyg-lib/pull/85), [#86](https://github.com/pyg-team/pyg-lib/pull/86))
- Added `pyg::sampler::neighbor_sample` implementation ([#54](https://github.com/pyg-team/pyg-lib/pull/54), [#76](https://github.com/pyg-team/pyg-lib/pull/76), [#77](https://github.com/pyg-team/pyg-lib/pull/77), [#78](https://github.com/pyg-team/pyg-lib/pull/78), [#80](https://github.com/pyg-team/pyg-lib/pull/80), [#81](https://github.com/pyg-team/pyg-lib/pull/81), [#85](https://github.com/pyg-team/pyg-lib/pull/85), [#86](https://github.com/pyg-team/pyg-lib/pull/86), [#87](https://github.com/pyg-team/pyg-lib/pull/87))
- Added `pyg::sampler::Mapper` utility for mapping global to local node indices ([#45](https://github.com/pyg-team/pyg-lib/pull/45), [#83](https://github.com/pyg-team/pyg-lib/pull/83))
- Added benchmark script ([#45](https://github.com/pyg-team/pyg-lib/pull/45), [#79](https://github.com/pyg-team/pyg-lib/pull/79), [#82](https://github.com/pyg-team/pyg-lib/pull/82))
- Added download script for benchmark data ([#44](https://github.com/pyg-team/pyg-lib/pull/44))
Expand Down
24 changes: 7 additions & 17 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include <ATen/ATen.h>
#include <torch/library.h>

#include "parallel_hashmap/phmap.h"
#include "pyg_lib/csrc/random/cpu/rand_engine.h"
#include "pyg_lib/csrc/sampler/cpu/mapper.h"
#include "pyg_lib/csrc/sampler/subgraph.h"
Expand Down Expand Up @@ -32,9 +31,8 @@ class NeighborSampler {
if (count == 0)
return;

const auto offset = rowptr_offset(rowptr_, global_src_node);
const auto row_start = std::get<0>(offset);
const auto row_end = std::get<1>(offset);
const auto row_start = rowptr_[to_scalar_t(global_src_node)];
const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];
const auto population = row_end - row_start;

if (population == 0)
Expand Down Expand Up @@ -84,9 +82,8 @@ class NeighborSampler {
if (count == 0)
return;

const auto offset = rowptr_offset(rowptr_, global_src_node);
const auto row_start = std::get<0>(offset);
const auto row_end = std::get<1>(offset);
const auto row_start = rowptr_[to_scalar_t(global_src_node)];
const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];
const auto population = row_end - row_start;

if (population == 0)
Expand Down Expand Up @@ -152,21 +149,14 @@ class NeighborSampler {
}

private:
inline std::pair<scalar_t, scalar_t> rowptr_offset(const scalar_t* rowptr,
const scalar_t& node) {
return {rowptr[node], rowptr[node + 1]};
}

inline std::pair<scalar_t, scalar_t> rowptr_offset(
const scalar_t* rowptr,
const std::pair<scalar_t, scalar_t>& node) {
return {rowptr[std::get<1>(node)], rowptr[std::get<1>(node) + 1]};
// Scalar overload: a plain node id is used directly as the index into
// `rowptr_`, so it is returned unchanged.
inline scalar_t to_scalar_t(const scalar_t& node) {
  const scalar_t index = node;
  return index;
}
// Pair overload: only the second component of `node` is used as the
// index into `rowptr_`; the first component is ignored here.
inline scalar_t to_scalar_t(const std::pair<scalar_t, scalar_t>& node) {
  const scalar_t index = node.second;
  return index;
}

// Scalar overload: `ref` only participates in overload selection and is not
// consulted; `node` is returned as-is.
inline scalar_t to_node_t(const scalar_t& node, const scalar_t& ref) {
  static_cast<void>(ref);  // intentionally unused
  const scalar_t result = node;
  return result;
}

inline std::pair<scalar_t, scalar_t> to_node_t(
const scalar_t& node,
const std::pair<scalar_t, scalar_t>& ref) {
Expand Down