1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `bias` term to `grouped_matmul` and `segment_matmul` ([#161](https://github.com/pyg-team/pyg-lib/pull/161))
- Added `sampled_op` implementation ([#156](https://github.com/pyg-team/pyg-lib/pull/156), [#159](https://github.com/pyg-team/pyg-lib/pull/159), [#160](https://github.com/pyg-team/pyg-lib/pull/160))
### Changed
+- Improved `[segment|grouped]_matmul` GPU implementation by reducing launch overheads ([#213](https://github.com/pyg-team/pyg-lib/pull/213))
- Sample the nodes with the same timestamp as seed nodes ([#187](https://github.com/pyg-team/pyg-lib/pull/187))
- Added `write-csv` (saves benchmark results as csv file) and `libraries` (determines which libraries will be used in benchmark) parameters ([#167](https://github.com/pyg-team/pyg-lib/pull/167))
- Enable benchmarking of neighbor sampler on temporal graphs ([#165](https://github.com/pyg-team/pyg-lib/pull/165))
121 changes: 57 additions & 64 deletions pyg_lib/csrc/ops/cuda/matmul_kernel.cu
@@ -20,78 +20,78 @@ template <typename GemmKernel>
void run_grouped_gemm(const at::TensorList input,
                      const at::TensorList other,
                      const at::TensorList out) {
-  const auto num_matrices = input.size();
-  std::vector<at::Tensor> new_input, new_other, new_out;
-  std::vector<float*> ptr_A_host(num_matrices);
-  std::vector<float*> ptr_B_host(num_matrices);
-  std::vector<float*> ptr_C_host(num_matrices);

-  for (size_t i = 0; i < num_matrices; ++i) {
-    new_input.push_back(input[i].contiguous());
-    ptr_A_host[i] = new_input[i].data_ptr<float>();
+  using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;

-    new_other.push_back(other[i].contiguous());
-    ptr_B_host[i] = new_other[i].data_ptr<float>();
+  const int64_t num_matrices = input.size();
+  const int64_t gemm_coord_size =
+      num_matrices * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
+  // Number of gemm args not including *problem_sizes
+  at::Tensor gemm_args =
+      at::empty({num_matrices * 6 + gemm_coord_size},
+                at::TensorOptions().dtype(at::kLong).pinned_memory(true));

-    new_out.push_back(out[i].contiguous());
-    ptr_C_host[i] = new_out[i].data_ptr<float>();
-  }
+  // Obtain pointers for each argument (on host)
+  int64_t* ld_A_data = gemm_args.data_ptr<int64_t>();  // Base pointer
+  int64_t* ld_B_data = ld_A_data + num_matrices;
+  int64_t* ld_C_data = ld_A_data + 2 * num_matrices;
+  int64_t* ptr_A_data = ld_A_data + 3 * num_matrices;
+  int64_t* ptr_B_data = ld_A_data + 4 * num_matrices;
+  int64_t* ptr_C_data = ld_A_data + 5 * num_matrices;
+  cutlass::gemm::GemmCoord* problem_sizes_data =
+      reinterpret_cast<cutlass::gemm::GemmCoord*>(ld_A_data + 6 * num_matrices);

-  cutlass::DeviceAllocation<float*> ptr_A;
-  ptr_A.reset(num_matrices);
-  ptr_A.copy_from_host(ptr_A_host.data());
+  // Set arguments into gemm_args from input args
+  for (size_t i = 0; i < num_matrices; ++i) {
+    auto new_in = input[i].contiguous();
+    auto new_other = other[i].contiguous();
+    auto new_out = out[i].contiguous();
+    auto m = new_in.size(0), k = new_other.size(1), n = new_out.size(1);

-  cutlass::DeviceAllocation<float*> ptr_B;
-  ptr_B.reset(num_matrices);
-  ptr_B.copy_from_host(ptr_B_host.data());
+    problem_sizes_data[i] = cutlass::gemm::GemmCoord(m, n, k);

-  cutlass::DeviceAllocation<float*> ptr_C;
-  ptr_C.reset(num_matrices);
-  ptr_C.copy_from_host(ptr_C_host.data());
+    ld_A_data[i] = GemmKernel::LayoutA::packed({m, k}).stride(0);
+    ld_B_data[i] = GemmKernel::LayoutB::packed({k, n}).stride(0);
+    ld_C_data[i] = GemmKernel::LayoutC::packed({m, n}).stride(0);

-  std::vector<cutlass::gemm::GemmCoord> all_problems(num_matrices);
-  std::vector<int64_t> ld_A_host(num_matrices);
-  std::vector<int64_t> ld_B_host(num_matrices);
-  std::vector<int64_t> ld_C_host(num_matrices);
-  for (size_t i = 0; i < num_matrices; ++i) {
-    auto m = new_input[i].size(0), k = new_input[i].size(1),
-         n = new_out[i].size(1);
-    all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
-    ld_A_host[i] = GemmKernel::LayoutA::packed({m, k}).stride(0);
-    ld_B_host[i] = GemmKernel::LayoutB::packed({k, n}).stride(0);
-    ld_C_host[i] = GemmKernel::LayoutC::packed({m, n}).stride(0);
+    ptr_A_data[i] = reinterpret_cast<int64_t>(new_in.data_ptr<float>());
+    ptr_B_data[i] = reinterpret_cast<int64_t>(new_other.data_ptr<float>());
+    ptr_C_data[i] = reinterpret_cast<int64_t>(new_out.data_ptr<float>());
  }

-  cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> all_problems_device;
-  all_problems_device.reset(num_matrices);
-  all_problems_device.copy_from_host(all_problems.data());

-  cutlass::DeviceAllocation<int64_t> ld_A;
-  ld_A.reset(num_matrices);
-  ld_A.copy_from_host(ld_A_host.data());
+  // Transfer arguments to GPU
+  gemm_args = gemm_args.to(out[0].device(), true);

-  cutlass::DeviceAllocation<int64_t> ld_B;
-  ld_B.reset(num_matrices);
-  ld_B.copy_from_host(ld_B_host.data());

-  cutlass::DeviceAllocation<int64_t> ld_C;
-  ld_C.reset(num_matrices);
-  ld_C.copy_from_host(ld_C_host.data());
+  // Obtain pointers for each of arguments (on GPU)
+  ld_A_data = gemm_args.data_ptr<int64_t>();  // Base pointer
+  ld_B_data = ld_A_data + num_matrices;
+  ld_C_data = ld_A_data + 2 * num_matrices;
+  ptr_A_data = ld_A_data + 3 * num_matrices;
+  ptr_B_data = ld_A_data + 4 * num_matrices;
+  ptr_C_data = ld_A_data + 5 * num_matrices;
+  problem_sizes_data =
+      reinterpret_cast<cutlass::gemm::GemmCoord*>(ld_A_data + 6 * num_matrices);

  using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
  typename EpilogueOutputOp::Params epilogue_op(1.0, 0.0);

-  using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
-  typename GemmGrouped::Arguments args(
-      all_problems_device.get(), num_matrices, /*threadblock_count=*/1024,
-      epilogue_op, ptr_A.get(), ptr_B.get(), ptr_C.get(), ptr_C.get(),
-      ld_A.get(), ld_B.get(), ld_C.get(), ld_C.get());
+  // Create GemmGrouped::Arguments using the arguments prepared above
+  typename GemmGrouped::Arguments args(problem_sizes_data, num_matrices,
+                                       /*threadblock_count=*/1024, epilogue_op,
+                                       reinterpret_cast<float**>(ptr_A_data),
+                                       reinterpret_cast<float**>(ptr_B_data),
+                                       reinterpret_cast<float**>(ptr_C_data),
+                                       reinterpret_cast<float**>(ptr_C_data),
+                                       ld_A_data, ld_B_data, ld_C_data,
+                                       ld_C_data);

  GemmGrouped gemm;
-  auto status = gemm.initialize(args);
+  auto status =
+      gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "GroupedGEMM init failed");
-  status = gemm.run();
+  status = gemm.run(at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "GroupedGEMM run failed");

+  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Returns the amount of shared memory required per threadblock in
@@ -106,15 +106,8 @@ int shared_memory_for_kernel() {
cudaDeviceProp get_dev_prop() {
  cudaDeviceProp properties;
  int device_idx;
-  cudaError_t result = cudaGetDevice(&device_idx);
-  if (result != cudaSuccess) {
-    throw std::runtime_error(cudaGetErrorString(result));
-  }

-  result = cudaGetDeviceProperties(&properties, device_idx);
-  if (result != cudaSuccess) {
-    throw std::runtime_error(cudaGetErrorString(result));
-  }
+  C10_CUDA_CHECK(cudaGetDevice(&device_idx));
+  C10_CUDA_CHECK(cudaGetDeviceProperties(&properties, device_idx));
  return properties;
}
cudaDeviceProp props;
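Not part of the diff: the rewrite above replaces many small `cutlass::DeviceAllocation` host-to-device copies with a single pinned buffer that holds every per-matrix argument (leading dimensions, data pointers, problem sizes) and is moved to the GPU in one asynchronous transfer. The following is a minimal, self-contained sketch of that packing idea in plain CUDA; all names, sizes, and placeholder values here are illustrative and do not come from pyg-lib, which builds the buffer with a pinned `at::Tensor` instead.

// Pack all per-matrix arguments into one pinned host buffer and copy it
// to the device with a single asynchronous transfer, instead of enqueuing
// one copy per argument array.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_matrices = 4;
  // Buffer layout, mirroring the kernel above:
  // [ld_A | ld_B | ld_C | ptr_A | ptr_B | ptr_C], each region num_matrices long.
  const int64_t slots = 6 * num_matrices;

  int64_t* host_args = nullptr;
  cudaMallocHost(reinterpret_cast<void**>(&host_args), slots * sizeof(int64_t));

  int64_t* ld_A = host_args + 0 * num_matrices;
  int64_t* ld_B = host_args + 1 * num_matrices;
  int64_t* ld_C = host_args + 2 * num_matrices;
  int64_t* ptr_A = host_args + 3 * num_matrices;
  int64_t* ptr_B = host_args + 4 * num_matrices;
  int64_t* ptr_C = host_args + 5 * num_matrices;
  for (int64_t i = 0; i < num_matrices; ++i) {
    ld_A[i] = ld_B[i] = ld_C[i] = 16;    // placeholder leading dimensions
    ptr_A[i] = ptr_B[i] = ptr_C[i] = 0;  // placeholder device addresses
  }

  int64_t* dev_args = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&dev_args), slots * sizeof(int64_t));

  // One async H2D transfer carries every argument for the whole group,
  // so only a single copy is issued per grouped-GEMM call.
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaMemcpyAsync(dev_args, host_args, slots * sizeof(int64_t),
                  cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);

  printf("packed %lld arguments into one transfer\n",
         static_cast<long long>(slots));

  cudaStreamDestroy(stream);
  cudaFree(dev_args);
  cudaFreeHost(host_args);
  return 0;
}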
2 changes: 2 additions & 0 deletions pyg_lib/csrc/ops/cuda/sampled_kernel.cu
@@ -99,6 +99,8 @@ at::Tensor sampled_op_kernel(const at::Tensor& left,
        left_data, right_data, out_data, left_index_data, right_index_data,
        to_fn_type.at(fn), left_index.has_value(), right_index.has_value(),
        num_feats, numel);

+    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
  return out;
}
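Not part of the diff: here and in random_walk_kernel.cu below, the only change is a post-launch check. CUDA kernel launches are asynchronous and do not return an error to the caller, so launch-configuration failures only surface through cudaGetLastError(), which is roughly what PyTorch's C10_CUDA_KERNEL_LAUNCH_CHECK() inspects before raising. A minimal sketch of the pattern in plain CUDA, with a hypothetical noop_kernel standing in for the real kernels:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void noop_kernel() {}  // stand-in for sampled_op / random_walk kernels

int main() {
  // Deliberately invalid launch configuration (0 threads per block): the launch
  // statement itself reports nothing, so the status must be queried explicitly.
  noop_kernel<<<1, 0>>>();

  cudaError_t err = cudaGetLastError();  // what a post-launch check looks at
  if (err != cudaSuccess) {
    printf("launch failed: %s\n", cudaGetErrorString(err));
  } else {
    printf("launch ok\n");
  }
  return 0;
}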
2 changes: 2 additions & 0 deletions pyg_lib/csrc/sampler/cuda/random_walk_kernel.cu
@@ -77,6 +77,8 @@ at::Tensor random_walk_kernel(const at::Tensor& rowptr,
    random_walk_kernel_impl<<<blocks(seed.size(0)), threads(), 0, stream>>>(
        rowptr_data, col_data, seed_data, rand_data, out_data, seed.size(0),
        walk_length);

+    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  return out.t().contiguous();