Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c02c6c8

Browse files
authored
[contrib] Add group_norm_v2 (#1887)
* Add group_norm_v2 * Refine the coding style * fix LB_SM_COUNT * coding style * add tests * add comments * comply with c++17 * fix data race
1 parent 8c23de0 commit c02c6c8

26 files changed

Lines changed: 2291 additions & 18 deletions
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pathlib
2+
3+
4+
hw_c_list = [
5+
(8 * 8, 1280),
6+
(8 * 8, 2560),
7+
(16 * 16, 640),
8+
(16 * 16, 1280),
9+
(16 * 16, 1920),
10+
(16 * 16, 2560),
11+
(32 * 32, 320),
12+
(32 * 32, 640),
13+
(32 * 32, 960),
14+
(32 * 32, 1280),
15+
(32 * 32, 1920),
16+
(64 * 64, 320),
17+
(64 * 64, 640),
18+
(64 * 64, 960),
19+
]
20+
21+
22+
def run():
23+
src_path = pathlib.Path(__file__).parent.absolute()
24+
25+
for f in src_path.glob("gn_cuda_inst_*.cu"):
26+
f.unlink()
27+
28+
for hw, c in hw_c_list:
29+
print(f"GN_CUDA_INST_DEFINE({hw}, {c})")
30+
with open(src_path / f"gn_cuda_inst_{hw}_{c}.cu", "w") as f:
31+
f.write(f"#include \"gn_cuda_host_template.cuh\"\n")
32+
f.write(f"\n")
33+
f.write(f"\n")
34+
f.write(f"namespace group_norm_v2 {{\n")
35+
f.write(f"\n")
36+
f.write(f"GN_CUDA_INST_DEFINE({hw}, {c})\n")
37+
f.write(f"\n")
38+
f.write(f"}} // namespace group_norm_v2\n")
39+
40+
with open(src_path / "gn_dispatch_hw_c.hpp", "w") as f:
41+
f.write(f"#pragma once\n")
42+
f.write(f"\n")
43+
f.write(f"#define DISPATCH_HW_C(hw, c, HW, C, ...) [&] {{ \\\n")
44+
for hw, c in hw_c_list:
45+
f.write(f" if (hw == {hw} && c == {c}) {{ constexpr int HW = {hw}, C = {c}; return __VA_ARGS__(); }} \\\n")
46+
f.write(f" throw std::invalid_argument(\"DISPATCH_HW_C \" + std::to_string(hw) + \" \" + std::to_string(c)); \\\n")
47+
f.write(f" }}()\n")
48+
49+
50+
if __name__ == "__main__":
51+
run()
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#include <torch/extension.h>
2+
#include <ATen/cuda/CUDAContext.h>
3+
4+
#include "gn.hpp"
5+
6+
7+
namespace group_norm_v2 {
8+
9+
torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, bool silu, int num_groups, std::optional<torch::Tensor> mean_var_out, int sm_margin) {
10+
if (w.dtype() != b.dtype() || (mean_var_out.has_value() && mean_var_out->dtype() != torch::kFloat32)) {
11+
throw std::invalid_argument("gn dtype mismatch");
12+
}
13+
torch::Tensor out = torch::empty_like(x);
14+
float *ptr_mean_var_out = mean_var_out.has_value() ? mean_var_out->data_ptr<float>() : nullptr;
15+
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
16+
int device_id = at::cuda::getCurrentCUDAStream().device().index();
17+
group_norm_v2::Meta meta;
18+
if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) {
19+
group_norm_v2::gn_cuda(
20+
(half *)out.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(),
21+
eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out,
22+
nullptr, nullptr, sm_margin, stream, device_id, &meta, true);
23+
} else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) {
24+
group_norm_v2::gn_cuda(
25+
(__nv_bfloat16 *)out.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(),
26+
eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out,
27+
nullptr, nullptr, sm_margin, stream, device_id, &meta, true);
28+
} else {
29+
throw std::invalid_argument("gn only supports half or bfloat16 input and weight");
30+
}
31+
torch::Tensor red_buffer = torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
32+
thread_local torch::Tensor barrier;
33+
if (barrier.size(0) < meta.barrier_size) {
34+
barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA));
35+
}
36+
if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) {
37+
group_norm_v2::gn_cuda(
38+
(half *)out.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(),
39+
eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out,
40+
red_buffer.data_ptr<float>(), barrier.data_ptr<unsigned>(), sm_margin, stream, device_id, nullptr, false);
41+
} else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) {
42+
group_norm_v2::gn_cuda(
43+
(__nv_bfloat16 *)out.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(),
44+
eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out,
45+
red_buffer.data_ptr<float>(), barrier.data_ptr<unsigned>(), sm_margin, stream, device_id, nullptr, false);
46+
} else {
47+
throw std::invalid_argument("gn only supports half or bfloat16 input and weight");
48+
}
49+
return out;
50+
}
51+
52+
auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch::Tensor b, torch::Tensor mean_var, float eps, bool silu, int num_groups, int sm_margin) {
53+
if (w.dtype() != b.dtype() || x.dtype() != grad_output.dtype() || mean_var.dtype() != torch::kFloat32) {
54+
throw std::invalid_argument("gn_bwd dtype mismatch");
55+
}
56+
torch::Tensor grad_input = torch::empty_like(x);
57+
torch::Tensor grad_weight = torch::empty_like(w);
58+
torch::Tensor grad_bias = torch::empty_like(w);
59+
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
60+
int device_id = at::cuda::getCurrentCUDAStream().device().index();
61+
group_norm_v2::Meta meta;
62+
if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) {
63+
group_norm_v2::gn_bwd_cuda(
64+
(half *)grad_input.data_ptr(), (half *)grad_weight.data_ptr(), (half *)grad_bias.data_ptr(),
65+
(half *)grad_output.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), mean_var.data_ptr<float>(),
66+
eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups,
67+
nullptr, nullptr, sm_margin, stream, device_id, &meta, true);
68+
} else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) {
69+
group_norm_v2::gn_bwd_cuda(
70+
(__nv_bfloat16 *)grad_input.data_ptr(), (__nv_bfloat16 *)grad_weight.data_ptr(), (__nv_bfloat16 *)grad_bias.data_ptr(),
71+
(__nv_bfloat16 *)grad_output.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), mean_var.data_ptr<float>(),
72+
eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups,
73+
nullptr, nullptr, sm_margin, stream, device_id, &meta, true);
74+
} else {
75+
throw std::invalid_argument("gn only supports half or bfloat16 input and weight");
76+
}
77+
torch::Tensor red_buffer = torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
78+
thread_local torch::Tensor barrier;
79+
if (barrier.size(0) < meta.barrier_size) {
80+
barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA));
81+
}
82+
if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) {
83+
group_norm_v2::gn_bwd_cuda(
84+
(half *)grad_input.data_ptr(), (half *)grad_weight.data_ptr(), (half *)grad_bias.data_ptr(),
85+
(half *)grad_output.data_ptr(), (half *)x.data_ptr(), (half *)w.data_ptr(), (half *)b.data_ptr(), mean_var.data_ptr<float>(),
86+
eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups,
87+
red_buffer.data_ptr<float>(), barrier.data_ptr<unsigned>(), sm_margin, stream, device_id, nullptr, false);
88+
} else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) {
89+
group_norm_v2::gn_bwd_cuda(
90+
(__nv_bfloat16 *)grad_input.data_ptr(), (__nv_bfloat16 *)grad_weight.data_ptr(), (__nv_bfloat16 *)grad_bias.data_ptr(),
91+
(__nv_bfloat16 *)grad_output.data_ptr(), (__nv_bfloat16 *)x.data_ptr(), (__nv_bfloat16 *)w.data_ptr(), (__nv_bfloat16 *)b.data_ptr(), mean_var.data_ptr<float>(),
92+
eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups,
93+
red_buffer.data_ptr<float>(), barrier.data_ptr<unsigned>(), sm_margin, stream, device_id, nullptr, false);
94+
} else {
95+
throw std::invalid_argument("gn only supports half or bfloat16 input and weight");
96+
}
97+
return std::make_tuple(grad_input, grad_weight, grad_bias);
98+
}
99+
100+
} // namespace group_norm_v2
101+
102+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
103+
m.def("gn", &group_norm_v2::gn, py::arg("x"), py::arg("w"), py::arg("b"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("mean_var_out") = py::none(), py::arg("sm_margin") = 0, "");
104+
m.def("gn_bwd", &group_norm_v2::gn_bwd, py::arg("grad_output"), py::arg("x"), py::arg("w"), py::arg("b"), py::arg("mean_var"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("sm_margin") = 0, "");
105+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#pragma once
2+
3+
#include <cstdint>
4+
#include <cuda_runtime.h>
5+
6+
7+
namespace group_norm_v2 {
8+
9+
struct Meta {
10+
int64_t red_buffer_size;
11+
int64_t barrier_size;
12+
int BLOCK_DIM_X;
13+
int C_PER_BLOCK;
14+
int ROWS_PER_BLOCK;
15+
int VEC_ELEMS;
16+
bool LOAD_TWICE;
17+
int BLOCKS_PER_SM;
18+
bool HARDWARE_CLUSTER;
19+
int wgrad_sync_method;
20+
};
21+
22+
template<typename T>
23+
void gn_cuda(T *out, T *x, T *w, T *b, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *mean_var_out, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only);
24+
25+
template<typename T>
26+
void gn_bwd_cuda(T *grad_input, T *grad_weight, T *grad_bias, T *grad_output, T *x, T *w, T *b, float *mean_var, float eps, bool silu, int64_t n, int64_t hw, int num_groups, int channels_per_group, float *red_buffer, unsigned *barrier, int sm_margin, cudaStream_t stream, int device_id, Meta *meta_ptr, bool meta_only);
27+
28+
} // namespace group_norm_v2
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include "gn.hpp"
2+
3+
#include <cstdio>
4+
#include <mutex>
5+
#include <stdexcept>
6+
7+
#include <cuda_runtime.h>
8+
#include <cuda_fp16.h>
9+
#include <cuda_bf16.h>
10+
11+
#include "gn_utils.hpp"
12+
#include "gn_dispatch_hw_c.hpp"
13+
14+
15+
#define DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, NUM_GROUPS, SILU, ...) [&] { \
16+
if (num_groups == 16 && silu == true) { constexpr int NUM_GROUPS = 16; constexpr bool SILU = true; return __VA_ARGS__(); } \
17+
if (num_groups == 32 && silu == false) { constexpr int NUM_GROUPS = 32; constexpr bool SILU = false; return __VA_ARGS__(); } \
18+
throw std::invalid_argument("DISPATCH_NUM_GROUPS_AND_SILU " + std::to_string(num_groups) + " " + std::to_string(silu)); \
19+
}()
20+
21+
namespace group_norm_v2 {
22+
23+
template<typename T, int HW, int C, int G, bool SILU>
24+
void gn_cuda_single_shape(GN_CUDA_HOST_PARAMS(T));
25+
26+
template<typename T, int HW, int C, int G, bool SILU>
27+
void gn_bwd_cuda_single_shape(GN_BWD_CUDA_HOST_PARAMS(T));
28+
29+
template<typename T>
30+
void gn_cuda(GN_CUDA_HOST_PARAMS(T)) {
31+
DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] {
32+
DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, [&] {
33+
return gn_cuda_single_shape<T, HW, C, G, SILU>(GN_CUDA_HOST_ARGS);
34+
});
35+
});
36+
}
37+
38+
template<typename T>
39+
void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(T)) {
40+
DISPATCH_HW_C(hw, num_groups * channels_per_group, HW, C, [&] {
41+
DISPATCH_NUM_GROUPS_AND_SILU(num_groups, silu, G, SILU, [&] {
42+
return gn_bwd_cuda_single_shape<T, HW, C, G, SILU>(GN_BWD_CUDA_HOST_ARGS);
43+
});
44+
});
45+
}
46+
47+
template void gn_cuda(GN_CUDA_HOST_PARAMS(half));
48+
template void gn_cuda(GN_CUDA_HOST_PARAMS(__nv_bfloat16));
49+
50+
template void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(half));
51+
template void gn_bwd_cuda(GN_BWD_CUDA_HOST_PARAMS(__nv_bfloat16));
52+
53+
} // namespace group_norm_v2

0 commit comments

Comments
 (0)