Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56), [#61](https://github.com/pyg-team/pyg-lib/pull/61), [#64](https://github.com/pyg-team/pyg-lib/pull/64), [#69](https://github.com/pyg-team/pyg-lib/pull/69), [#73](https://github.com/pyg-team/pyg-lib/pull/73))
- Added `pyg::sampler::neighbor_sample` implementation ([#54](https://github.com/pyg-team/pyg-lib/pull/54), [#76](https://github.com/pyg-team/pyg-lib/pull/76), [#77](https://github.com/pyg-team/pyg-lib/pull/77), [#78](https://github.com/pyg-team/pyg-lib/pull/78), [#80](https://github.com/pyg-team/pyg-lib/pull/80), [#81](https://github.com/pyg-team/pyg-lib/pull/81), [#85](https://github.com/pyg-team/pyg-lib/pull/85), [#86](https://github.com/pyg-team/pyg-lib/pull/86), [#87](https://github.com/pyg-team/pyg-lib/pull/87), [#89](https://github.com/pyg-team/pyg-lib/pull/89))
- Added `pyg::sampler::Mapper` utility for mapping global to local node indices ([#45](https://github.com/pyg-team/pyg-lib/pull/45), [#83](https://github.com/pyg-team/pyg-lib/pull/83))
- Added benchmark script ([#45](https://github.com/pyg-team/pyg-lib/pull/45), [#79](https://github.com/pyg-team/pyg-lib/pull/79), [#82](https://github.com/pyg-team/pyg-lib/pull/82), [#91](https://github.com/pyg-team/pyg-lib/pull/91), [#93](https://github.com/pyg-team/pyg-lib/pull/93))
- Added benchmark script ([#45](https://github.com/pyg-team/pyg-lib/pull/45), [#79](https://github.com/pyg-team/pyg-lib/pull/79), [#82](https://github.com/pyg-team/pyg-lib/pull/82), [#91](https://github.com/pyg-team/pyg-lib/pull/91), [#93](https://github.com/pyg-team/pyg-lib/pull/93), [#106](https://github.com/pyg-team/pyg-lib/pull/106))
- Added download script for benchmark data ([#44](https://github.com/pyg-team/pyg-lib/pull/44))
- Added `biased sampling` utils ([#38](https://github.com/pyg-team/pyg-lib/pull/38))
- Added `CHANGELOG.md` ([#39](https://github.com/pyg-team/pyg-lib/pull/39))
Expand Down
94 changes: 94 additions & 0 deletions benchmark/sampler/hetero_neighbor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import argparse
import ast
import time

import torch
import torch_sparse # noqa
from tqdm import tqdm

import pyg_lib
from pyg_lib.testing import remap_keys, withDataset, withSeed

# Command-line interface of the benchmark. The parsed namespace is stored in
# the module-level ``args`` and read by ``test_hetero_neighbor``.
argparser = argparse.ArgumentParser('Hetero neighbor sample benchmark')
argparser.add_argument(
    '--batch-sizes',
    nargs='+',
    type=int,
    default=[
        512,
        1024,
        2048,
        4096,
        8192,
    ],
    help='One or more seed batch sizes to benchmark.',
)
# The hyphenated spelling is listed first so that it matches `--batch-sizes`;
# argparse derives the destination from the first long option and converts
# '-' to '_', so the value is still available as `args.num_neighbors`.
# The original `--num_neighbors` spelling is kept as an alias.
argparser.add_argument(
    '--num-neighbors',
    '--num_neighbors',
    type=ast.literal_eval,
    default=[
        [-1],
        [15, 10, 5],
        [20, 15, 10],
    ],
    help=('Python literal holding a list of per-hop neighbor lists, '
          'e.g. "[[15, 10, 5], [20, 15, 10]]".'),
)
# TODO(kgajdamo): Enable sampling with replacement
# argparser.add_argument('--replace', action='store_true')
# TODO (kgajdamo): Support undirected hetero graphs
# argparser.add_argument('--directed', action='store_true')
argparser.add_argument(
    '--shuffle',
    action='store_true',
    help='Draw seed nodes from a random permutation instead of in order.',
)
args = argparser.parse_args()


@withSeed
@withDataset('ogb', 'mag')
def test_hetero_neighbor(dataset, **kwargs):
    """Benchmark hetero neighbor sampling of pyg-lib against torch-sparse.

    For every combination of ``args.num_neighbors`` and ``args.batch_sizes``
    the first 20 seed batches of ``'paper'`` nodes are sampled once with
    ``pyg_lib.sampler.hetero_neighbor_sample`` and once with
    ``torch.ops.torch_sparse.hetero_neighbor_sample``; the wall-clock time of
    each run is printed.

    Args:
        dataset: Pair ``(colptr_dict, row_dict)`` keyed by edge type, as
            injected by the ``withDataset`` decorator.
        **kwargs: Unused by the benchmark body.
    """
    colptr_dict, row_dict = dataset
    # The last entry of an edge-type key is used as the node type whose
    # count is derived from the pointer tensor length.
    num_nodes_dict = {k[-1]: v.size(0) - 1 for k, v in colptr_dict.items()}

    if args.shuffle:
        node_perm = torch.randperm(num_nodes_dict['paper'])
    else:
        node_perm = torch.arange(0, num_nodes_dict['paper'])

    # Inputs of the torch-sparse op that do not depend on the sampling
    # configuration. They are built once, outside of the timed region, so
    # that the torch-sparse measurement covers sampling only (as it does
    # for pyg-lib).
    node_types = list(num_nodes_dict.keys())
    edge_types = list(colptr_dict.keys())
    colptr_dict_sparse = remap_keys(colptr_dict)
    row_dict_sparse = remap_keys(row_dict)

    for num_neighbors in args.num_neighbors:
        num_neighbors_dict = {key: num_neighbors for key in colptr_dict}
        num_neighbors_dict_sparse = remap_keys(num_neighbors_dict)
        # Every edge type shares the same list, so its length is the number
        # of hops.
        num_hops = len(num_neighbors)

        for batch_size in args.batch_sizes:
            print(f'batch_size={batch_size}, num_neighbors={num_neighbors}:')
            # Only the first 20 batches are benchmarked; both libraries
            # receive exactly the same seeds.
            seed_batches = node_perm.split(batch_size)[:20]

            t = time.perf_counter()
            for seed in tqdm(seed_batches):
                seed_dict = {'paper': seed}
                pyg_lib.sampler.hetero_neighbor_sample(
                    colptr_dict,
                    row_dict,
                    seed_dict,
                    num_neighbors_dict,
                    None,  # time_dict
                    True,  # csc
                    False,  # replace
                    True,  # directed
                )
            pyg_lib_duration = time.perf_counter() - t

            t = time.perf_counter()
            for seed in tqdm(seed_batches):
                seed_dict = {'paper': seed}
                torch.ops.torch_sparse.hetero_neighbor_sample(
                    node_types,
                    edge_types,
                    colptr_dict_sparse,
                    row_dict_sparse,
                    seed_dict,
                    num_neighbors_dict_sparse,
                    num_hops,
                    False,  # replace
                    True,  # directed
                )
            torch_sparse_duration = time.perf_counter() - t

            # TODO (kgajdamo): Add dgl hetero sampler.

            print(f'     pyg-lib={pyg_lib_duration:.3f} seconds')
            print(f'torch-sparse={torch_sparse_duration:.3f} seconds')
            print()


# Run the benchmark only when this file is executed as a script.
if __name__ == '__main__':
    test_hetero_neighbor()
2 changes: 1 addition & 1 deletion pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def hetero_neighbor_sample(
for k, v in num_neighbors_dict.items()
}

out = torch.ops.pyg.hetero_neighbor_sample(
out = torch.ops.pyg.hetero_neighbor_sample_cpu(
node_types,
edge_types,
rowptr_dict,
Expand Down
60 changes: 53 additions & 7 deletions pyg_lib/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import os.path as osp
from typing import Callable, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch import Tensor
Expand Down Expand Up @@ -30,12 +30,18 @@ def wrapper(*args, **kwargs):
def withDataset(group: str, name: str) -> Callable:
def decorator(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
dataset = get_sparse_matrix(
group,
name,
dtype=kwargs.get('dtype', torch.long),
device=kwargs.get('device', None),
)
if group == 'ogb' and name == 'mag':
dataset = get_ogb_mag_hetero_sparse_matrix(
dtype=kwargs.get('dtype', torch.long),
device=kwargs.get('device', None),
)
else:
dataset = get_sparse_matrix(
group,
name,
dtype=kwargs.get('dtype', torch.long),
device=kwargs.get('device', None),
)

func(*args, dataset=dataset, **kwargs)

Expand Down Expand Up @@ -89,7 +95,47 @@ def get_sparse_matrix(
return rowptr, col


def get_ogb_mag_hetero_sparse_matrix(
    dtype: torch.dtype = torch.long,
    device: Optional[torch.device] = None,
) -> Tuple[
        Dict[Tuple[str, str, str], Tensor],
        Dict[Tuple[str, str, str], Tensor],
]:
    r"""Returns a heterogeneous graph :obj:`(colptr_dict, row_dict)`
    from the `OGB <https://ogb.stanford.edu/>`_ benchmark suite.

    The :obj:`OGB_MAG` dataset is stored below :obj:`get_home_dir()` in the
    :obj:`'ogb-mag'` folder and is pre-processed with
    :obj:`T.ToUndirected()` followed by :obj:`T.ToSparseTensor()`.

    Args:
        dtype (torch.dtype, optional): The desired data type of returned
            tensors. (default: :obj:`torch.long`)
        device (torch.device, optional): the desired device of returned
            tensors. (default: :obj:`None`)

    Returns:
        (Dict[Tuple[str, str, str], torch.Tensor],
        Dict[Tuple[str, str, str], torch.Tensor]): Two dictionaries keyed by
        edge type. For every edge type, the first one holds the compressed
        pointer tensor and the second one the index tensor obtained from
        :obj:`adj_t.csr()`.
    """
    # Imported lazily so that `torch_geometric` is only required when this
    # dataset is actually requested.
    import torch_geometric.transforms as T
    from torch_geometric.datasets import OGB_MAG

    path = osp.join(get_home_dir(), 'ogb-mag')
    transform = T.Compose([T.ToUndirected(), T.ToSparseTensor()])
    data = OGB_MAG(path, pre_transform=transform)[0]

    colptr_dict, row_dict = {}, {}
    for edge_type in data.edge_types:
        # `csr()` yields three values; the third one is not needed here.
        colptr, row, _ = data[edge_type].adj_t.csr()
        colptr_dict[edge_type] = colptr.to(device, dtype)
        row_dict[edge_type] = row.to(device, dtype)

    return colptr_dict, row_dict


def to_edge_index(rowptr: Tensor, col: Tensor) -> Tensor:
    """Expand a compressed ``(rowptr, col)`` pair into a stacked index tensor.

    Row ``i`` is repeated ``rowptr[i + 1] - rowptr[i]`` times and the result
    is stacked on top of ``col`` along dimension 0. The generated row indices
    use the data type and device of ``col``.
    """
    num_rows = rowptr.size(0) - 1
    entries_per_row = rowptr[1:] - rowptr[:-1]
    row_ids = torch.arange(num_rows, dtype=col.dtype, device=col.device)
    expanded_rows = row_ids.repeat_interleave(entries_per_row)
    return torch.stack([expanded_rows, col], dim=0)


def remap_keys(mapping: Dict[Tuple[str, str, str], Any]) -> Dict[str, Any]:
    """Return a copy of ``mapping`` whose tuple keys are joined with ``'__'``.

    Values are passed through unchanged.
    """
    remapped = {}
    for edge_type, value in mapping.items():
        remapped['__'.join(edge_type)] = value
    return remapped