|
| 1 | +#include <torch/extension.h> |
| 2 | +#include <ATen/cuda/CUDAContext.h> |
| 3 | + |
| 4 | +#include "gn.hpp" |
| 5 | + |
| 6 | + |
| 7 | +namespace group_norm_v2 { |
| 8 | + |
| 9 | +torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, bool silu, int num_groups, std::optional<torch::Tensor> mean_var_out, int sm_margin) { |
| 10 | + if (w.dtype() != b.dtype() || (mean_var_out.has_value() && mean_var_out->dtype() != torch::kFloat32)) { |
| 11 | + throw std::invalid_argument("gn dtype mismatch"); |
| 12 | + } |
| 13 | + torch::Tensor out = torch::empty_like(x); |
| 14 | + float *ptr_mean_var_out = mean_var_out.has_value() ? mean_var_out->data_ptr<float>() : nullptr; |
| 15 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); |
| 16 | + int device_id = at::cuda::getCurrentCUDAStream().device().index(); |
| 17 | + group_norm_v2::Meta meta; |
| 18 | + if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { |
| 19 | + group_norm_v2::gn_cuda( |
| 20 | + (half *)out.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), |
| 21 | + eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, |
| 22 | + nullptr, nullptr, sm_margin, stream, device_id, &meta, true); |
| 23 | + } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { |
| 24 | + group_norm_v2::gn_cuda( |
| 25 | + (__nv_bfloat16 *)out.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), |
| 26 | + eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, |
| 27 | + nullptr, nullptr, sm_margin, stream, device_id, &meta, true); |
| 28 | + } else { |
| 29 | + throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); |
| 30 | + } |
| 31 | + torch::Tensor red_buffer = torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); |
| 32 | + thread_local torch::Tensor barrier; |
| 33 | + if (barrier.size(0) < meta.barrier_size) { |
| 34 | + barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); |
| 35 | + } |
| 36 | + if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { |
| 37 | + group_norm_v2::gn_cuda( |
| 38 | + (half *)out.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), |
| 39 | + eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, |
| 40 | + red_buffer.data_ptr<float>(), barrier.data_ptr<unsigned>(), sm_margin, stream, device_id, nullptr, false); |
| 41 | + } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { |
| 42 | + group_norm_v2::gn_cuda( |
| 43 | + (__nv_bfloat16 *)out.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), |
| 44 | + eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, |
| 45 | + red_buffer.data_ptr<float>(), barrier.data_ptr<unsigned>(), sm_margin, stream, device_id, nullptr, false); |
| 46 | + } else { |
| 47 | + throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); |
| 48 | + } |
| 49 | + return out; |
| 50 | +} |
| 51 | + |
| 52 | +auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch::Tensor b, torch::Tensor mean_var, float eps, bool silu, int num_groups, int sm_margin) { |
| 53 | + if (w.dtype() != b.dtype() || x.dtype() != grad_output.dtype() || mean_var.dtype() != torch::kFloat32) { |
| 54 | + throw std::invalid_argument("gn_bwd dtype mismatch"); |
| 55 | + } |
| 56 | + torch::Tensor grad_input = torch::empty_like(x); |
| 57 | + torch::Tensor grad_weight = torch::empty_like(w); |
| 58 | + torch::Tensor grad_bias = torch::empty_like(w); |
| 59 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); |
| 60 | + int device_id = at::cuda::getCurrentCUDAStream().device().index(); |
| 61 | + group_norm_v2::Meta meta; |
| 62 | + if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { |
| 63 | + group_norm_v2::gn_bwd_cuda( |
| 64 | + (half *)grad_input.data_ptr(), (half *)grad_weight.data_ptr(), (half *)grad_bias.data_ptr(), |
| 65 | + (half *)grad_output.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), mean_var.data_ptr<float>(), |
| 66 | + eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, |
| 67 | + nullptr, nullptr, sm_margin, stream, device_id, &meta, true); |
| 68 | + } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { |
| 69 | + group_norm_v2::gn_bwd_cuda( |
| 70 | + (__nv_bfloat16 *)grad_input.data_ptr(), (__nv_bfloat16 *)grad_weight.data_ptr(), (__nv_bfloat16 *)grad_bias.data_ptr(), |
| 71 | + (__nv_bfloat16 *)grad_output.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), mean_var.data_ptr<float>(), |
| 72 | + eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, |
| 73 | + nullptr, nullptr, sm_margin, stream, device_id, &meta, true); |
| 74 | + } else { |
| 75 | + throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); |
| 76 | + } |
| 77 | + torch::Tensor red_buffer = torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); |
| 78 | + thread_local torch::Tensor barrier; |
| 79 | + if (barrier.size(0) < meta.barrier_size) { |
| 80 | + barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); |
| 81 | + } |
| 82 | + if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { |
| 83 | + group_norm_v2::gn_bwd_cuda( |
| 84 | + (half *)grad_input.data_ptr(), (half *)grad_weight.data_ptr(), (half *)grad_bias.data_ptr(), |
| 85 | + (half *)grad_output.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), mean_var.data_ptr<float>(), |
| 86 | + eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, |
| 87 | + red_buffer.data_ptr<float>(), barrier.data_ptr<unsigned>(), sm_margin, stream, device_id, nullptr, false); |
| 88 | + } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { |
| 89 | + group_norm_v2::gn_bwd_cuda( |
| 90 | + (__nv_bfloat16 *)grad_input.data_ptr(), (__nv_bfloat16 *)grad_weight.data_ptr(), (__nv_bfloat16 *)grad_bias.data_ptr(), |
| 91 | + (__nv_bfloat16 *)grad_output.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), mean_var.data_ptr<float>(), |
| 92 | + eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, |
| 93 | + red_buffer.data_ptr<float>(), barrier.data_ptr<unsigned>(), sm_margin, stream, device_id, nullptr, false); |
| 94 | + } else { |
| 95 | + throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); |
| 96 | + } |
| 97 | + return std::make_tuple(grad_input, grad_weight, grad_bias); |
| 98 | +} |
| 99 | + |
| 100 | +} // namespace group_norm_v2 |
| 101 | + |
| 102 | +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| 103 | + m.def("gn", &group_norm_v2::gn, py::arg("x"), py::arg("w"), py::arg("b"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("mean_var_out") = py::none(), py::arg("sm_margin") = 0, ""); |
| 104 | + m.def("gn_bwd", &group_norm_v2::gn_bwd, py::arg("grad_output"), py::arg("x"), py::arg("w"), py::arg("b"), py::arg("mean_var"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("sm_margin") = 0, ""); |
| 105 | +} |
0 commit comments