Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 6 additions & 40 deletions pyg_lib/csrc/ops/cuda/matmul_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,13 @@
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
#include <cutlass/util/host_tensor.h>
#include <torch/library.h>
#include <torch/nn/functional/padding.h>

#include "pyg_lib/csrc/utils/convert.h"

namespace pyg {
namespace ops {

namespace {
namespace F = torch::nn::functional;

// Returns how many elements must be appended to a dimension of length `size`
// so that its length becomes a multiple of `multiple` (0 if it already is).
// Pure integer arithmetic: no floating-point rounding and no int64 -> int
// narrowing, unlike a `ceil(size / 4.0) * 4 - size` formulation.
int64_t round_up_padding(int64_t size, int64_t multiple = 4) {
  return (multiple - size % multiple) % multiple;
}

// Zero-pads `input` at the end of dimension `dim` so that `input.size(dim)`
// becomes a multiple of 4. `dim` may be negative (counted from the last
// dimension). The returned tensor is produced by F::pad even when no padding
// is needed.
// NOTE(review): presumably used to meet the 4-element alignment ("Granularity")
// the CUTLASS grouped GEMM was configured with — confirm against the kernel.
at::Tensor pad_dim(const at::Tensor& input, int dim) {
  // size() wraps a negative `dim` and raises on an out-of-range value, so
  // `dim` is known to be valid after this call.
  const int64_t to_pad = round_up_padding(input.size(dim));
  const int64_t ndim = input.dim();
  const int64_t wrapped = dim < 0 ? dim + ndim : dim;

  // F::pad consumes (left, right) pairs starting from the LAST dimension.
  // A list covering dimensions [wrapped, ndim) ends with the pair belonging to
  // `wrapped`, so only its final (right-hand) entry is non-zero.
  std::vector<int64_t> pad(static_cast<size_t>(2 * (ndim - wrapped)), 0);
  pad.back() = to_pad;
  return F::pad(input,
                F::PadFuncOptions(std::move(pad)).mode(torch::kConstant));
}

// Zero-pads the last two dimensions of `input` (at their ends) so that both
// become multiples of 4. Requires `input` to have at least two dimensions.
at::Tensor pad_both(const at::Tensor& input) {
  const int64_t rows_pad = round_up_padding(input.size(-2));
  const int64_t cols_pad = round_up_padding(input.size(-1));
  // Pad order is {last_left, last_right, second_last_left, second_last_right}.
  return F::pad(
      input,
      F::PadFuncOptions({0, cols_pad, 0, rows_pad}).mode(torch::kConstant));
}

void grouped_matmul_out_kernel(const std::vector<at::Tensor>& input,
const std::vector<at::Tensor>& other,
Expand All @@ -48,11 +26,11 @@ void grouped_matmul_out_kernel(const std::vector<at::Tensor>& input,
float, // Element A
cutlass::layout::RowMajor, // Layout A
cutlass::ComplexTransform::kNone, //
4, // Granularity A
1, // Granularity A
float, // Element B
cutlass::layout::RowMajor, // Layout B
cutlass::ComplexTransform::kNone, //
4, // Granularity B
1, // Granularity B
float, // Element C&D
cutlass::layout::RowMajor, // Layout C&D
float, // Element Accumulator
Expand All @@ -62,7 +40,7 @@ void grouped_matmul_out_kernel(const std::vector<at::Tensor>& input,
cutlass::gemm::GemmShape<64, 64, 32>, // Warp-level Tile
cutlass::gemm::GemmShape<16, 8, 8>, // Warp-level Tile
cutlass::epilogue::thread::LinearCombination< // Epilogue
float, 4, float, float>, //
float, 1, float, float>, //
cutlass::gemm::threadblock:: // Swizzling Operator
GemmIdentityThreadblockSwizzle<8>, //
3, // Stages
Expand All @@ -74,25 +52,13 @@ void grouped_matmul_out_kernel(const std::vector<at::Tensor>& input,
std::vector<float*> ptr_C_host(num_matrices);

for (size_t i = 0; i < num_matrices; ++i) {
if (input[i].size(-1) % 4 != 0) {
new_input.push_back(pad_dim(input[i], -1).contiguous());
} else {
new_input.push_back(input[i].contiguous());
}
new_input.push_back(input[i].contiguous());
ptr_A_host[i] = new_input[i].data_ptr<float>();

if (other[i].size(-1) % 4 != 0 || other[i].size(-2) % 4 != 0) {
new_other.push_back(pad_both(other[i]).contiguous());
} else {
new_other.push_back(other[i].contiguous());
}
new_other.push_back(other[i].contiguous());
ptr_B_host[i] = new_other[i].data_ptr<float>();

if (out[i].size(-1) % 4 != 0) {
new_out.push_back(pad_dim(out[i], -1).contiguous());
} else {
new_out.push_back(out[i].contiguous());
}
new_out.push_back(out[i].contiguous());
ptr_C_host[i] = new_out[i].data_ptr<float>();
}

Expand Down