Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 741bdf5

Browse files
authored
Add multi_tensor_unscale_l2norm_cuda (#1727)
* Add multi_tensor_scale_l2norm * Rename * Add unit test * Fix unit test --------- Co-authored-by: Jaemin Choi <[email protected]>
1 parent 52e18c8 commit 741bdf5

3 files changed

Lines changed: 250 additions & 0 deletions

File tree

csrc/amp_C_frontend.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
4646
float scale,
4747
at::optional<bool> per_tensor_python);
4848

49+
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
50+
int chunk_size,
51+
at::Tensor noop_flag,
52+
std::vector<std::vector<at::Tensor>> tensor_lists,
53+
at::Tensor inv_scale,
54+
at::optional<bool> per_tensor_python);
55+
4956
void multi_tensor_lamb_stage1_cuda(
5057
int chunk_size,
5158
at::Tensor noop_flag,
@@ -184,6 +191,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
184191
"Computes L2 norm for a list of contiguous tensors");
185192
m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda,
186193
"Computes L2 norm for a list of contiguous tensors and does scaling");
194+
m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
195+
"Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only performed for L2 norm computation, and tensors are not updated)");
187196
m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
188197
"Computes update part of LAMB optimizer");
189198
m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,

csrc/multi_tensor_l2norm_kernel.cu

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,91 @@ struct L2NormFunctor
109109
}
110110
};
111111

112+
// Functor invoked through multi_tensor_apply; each thread block handles one
// chunk of one tensor, selected via tl.block_to_tensor / tl.block_to_chunk.
// It accumulates sum((x * inv_scale)^2) over that chunk in float. The input
// tensor is only read, never written.
//
//   chunk_size            - number of elements per chunk
//   noop_gmem             - flag set to 1 when this block's partial sum is not finite
//   tl                    - launch metadata (tensor addresses, sizes, block mapping)
//   inv_scale             - pointer to the float multiplier applied before squaring
//   output                - per-block accumulators, indexed by blockIdx.x
//   output_per_tensor     - per-(tensor, chunk) partial sums; written only if per_tensor
//   per_tensor            - whether per-tensor partial sums are recorded as well
//   max_chunks_per_tensor - row stride used to index output_per_tensor
template<typename x_t>
struct UnscaleL2NormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    const float* inv_scale,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Advance to the start of this block's chunk; n becomes the number of
    // elements remaining in the tensor from that point on.
    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // The multiplier is the same for every element, so read it once instead of
    // dereferencing the pointer inside the inner loops.
    const float inv_scale_val = *inv_scale;

    // Scratch space handed to reduce_block_into_lanes below.
    __shared__ float s_vals[512];

    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
    x_t r_x[ILP];
    for(int i = 0; i < ILP; i++)
    {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
    {
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load ILP elements at once into registers
        load_store(r_x, x, 0 , i_start);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          float unscaled = static_cast<float>(r_x[ii]) * inv_scale_val;
          vals[ii] += unscaled*unscaled;
        }
      }
    }
    else
    {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          // Guard against running past the tensor end or the chunk end.
          if(i < n && i < chunk_size)
          {
            float unscaled = static_cast<float>(x[i]) * inv_scale_val;
            vals[ii] += unscaled*unscaled;
          }
        }
      }
    }

    // Fold the ILP partial sums of this thread into one value.
    float thread_sum = 0.f;
    for(int i = 0; i < ILP; i++)
      thread_sum += vals[i];

    // Block-wide reduction (helper defined elsewhere); thread 0 consumes the result.
    float block_sum = reduce_block_into_lanes(s_vals, thread_sum);

    if(threadIdx.x == 0)
    {
      if(!isfinite(block_sum))
        *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
      output[blockIdx.x] += block_sum;
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = block_sum;
    }
  }
};
196+
112197
// Probably better to template, but since we are not likely to support other norm
113198
template<typename x_t>
114199
struct MaxNormFunctor
@@ -355,6 +440,73 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
355440
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
356441
}
357442

443+
// Host entry point: computes the L2 norm of the tensors in tensor_lists[0]
// after multiplying by inv_scale, without modifying the tensors.
//
//   chunk_size        - number of elements each thread block processes
//   noop_flag         - int tensor; the kernel sets it to 1 on a non-finite partial sum
//   tensor_lists      - tensor_lists[0] holds the input tensors
//   inv_scale         - float tensor holding the multiplier (read on the device)
//   per_tensor_python - when true, per-tensor results are produced as well
//
// Returns (ret, ret_per_tensor): ret has shape {1}; ret_per_tensor has shape
// {ntensors} when per_tensor is requested and shape {0} otherwise.
//
// NOTE(review): tensor_lists[0] is indexed at [0] unconditionally, so callers
// must pass at least one tensor. inv_scale is dereferenced inside the kernel,
// so it is presumably expected to live on the same CUDA device - confirm.
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor inv_scale,
  at::optional<bool> per_tensor_python)
{
  const bool per_tensor = per_tensor_python.value_or(false);

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  // Per-block accumulators written by UnscaleL2NormFunctor (indexed by blockIdx.x).
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  if(per_tensor)
  {
    // The per-tensor scratch buffer is a dense [ntensors x max_chunks] grid, so
    // find the largest chunk count over all tensors first.
    for(int t = 0; t < ntensors; t++)
    {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
      if(max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  }
  else
  {
    ret_per_tensor = at::empty({0}, float_options);
  }

  DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
    multi_tensor_apply<1>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      UnscaleL2NormFunctor<scalar_t_0>(),
      inv_scale.DATA_PTR<float>(),
      output.DATA_PTR<float>(),
      per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
      per_tensor,
      max_chunks_per_tensor);)

  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but will be negligible end to end.
  // I could get rid of it by hacking the functor + multi tensor harness with persistence
  // logic, but keeping it simple for now
  auto ret = at::empty({1}, output.options());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  // cleanup (defined elsewhere in this file) turns the partial sums into the
  // final results; one block per tensor when per-tensor output is requested.
  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
    ret.DATA_PTR<float>(),
    per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
    per_tensor,
    max_chunks_per_tensor);

  // Kernel launches do not report errors by themselves; surface a bad launch
  // configuration here instead of at some unrelated later CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
509+
358510

359511
// Compute and update grad norm
360512
// Here use a per tensor norm, and blend new norm(n) and old norm(gn) by
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import unittest
2+
3+
import functools as ft
4+
import itertools as it
5+
6+
from apex import amp
7+
import torch
8+
from torch import nn
9+
import torch.nn.functional as F
10+
11+
from utils import common_init, HALF, FLOAT,\
12+
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
13+
14+
# The fused kernels live in the optional amp_C extension; when it (or apex's
# MultiTensorApply wrapper) cannot be imported, record that in `disabled` so the
# tests below are skipped instead of failing at import time.
try:
    import amp_C
    from amp_C import multi_tensor_unscale_l2norm
    from apex.multi_tensor_apply import MultiTensorApply
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultiTensorUnscaleL2Norm.  ImportError was ", err)
    disabled = True
22+
23+
24+
class TestMultiTensorUnscaleL2Norm(unittest.TestCase):
    """Checks multi_tensor_unscale_l2norm against torch's own norm().

    Every input tensor is filled with ``self.val`` and the kernel is given
    ``self.inv_scale``, so the expected result is the norm of a tensor filled
    with ``self.val * self.inv_scale``.
    """

    def setUp(self):
        common_init(self)
        self.val = 4.0
        self.inv_scale = 0.5
        # The kernel takes the scale as a tensor, not as a Python float.
        self.inv_scale_cuda = torch.tensor([self.inv_scale], dtype=torch.float32, device='cuda')
        self.overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')

    def tearDown(self):
        pass

    # The tensor creation here is written for convenience, not speed.
    def unscale_l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type, per_tensor):
        """Run the kernel on ``repeat_tensors`` copies of an (a, b) tensor pair.

        ``a`` and ``b`` have ``sizea`` and ``sizeb`` elements and are cast to
        ``in_type``. ``per_tensor`` is forwarded to the kernel, so both the
        per-tensor and the total-only code paths are exercised.
        """
        self.overflow_buf.zero_()
        a = torch.full([sizea], self.val, dtype=torch.float32, device='cuda')
        b = torch.full([sizeb], self.val, dtype=torch.float32, device='cuda')

        in_list = []
        for _ in range(repeat_tensors):
            in_list += [a.clone().to(in_type), b.clone().to(in_type)]

        # Forward the requested mode; hard-coding True here would leave the
        # per_tensor=False path of the kernel untested.
        norm, norm_per_tensor = applier(
            multi_tensor_unscale_l2norm,
            self.overflow_buf,
            [in_list],
            self.inv_scale_cuda,
            per_tensor)

        reference = torch.full(
            [(sizea + sizeb)*repeat_tensors],
            self.val * self.inv_scale,
            dtype=torch.float32,
            device='cuda').norm()
        self.assertTrue(torch.allclose(norm, reference))

        if per_tensor:
            # Inputs alternate a, b, a, b, ... so view the per-tensor norms as
            # rows of (norm_a, norm_b) and compare every row to the reference.
            normab = torch.cat(((a * self.inv_scale).norm().view(1), (b * self.inv_scale).norm().view(1)))
            norm_per_tensor = norm_per_tensor.view(-1, 2)
            self.assertTrue(torch.allclose(norm_per_tensor, normab))

        # All inputs are finite, so the overflow flag must stay clear.
        self.assertTrue(self.overflow_buf.item() == 0)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        input_size_pairs = (
            (7777*77, 555*555),
            (777, 555),
            (555, 2048*32+1),
            (2048*32+1, 555),
            (555, 2048*32),
            (2048*32, 555),
            (33333, 555),
            (555, 33333))
        appliers = (
            MultiTensorApply(2048*32),
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)
        in_types = (torch.float32, torch.float16)
        per_tensor_modes = (False, True)

        # Full cross product of sizes, chunk sizes, repeats, dtypes and modes.
        for (sizea, sizeb), applier, repeat, in_type, per_tensor in it.product(
                input_size_pairs, appliers, repeat_tensors, in_types, per_tensor_modes):
            with self.subTest(sizea=sizea, sizeb=sizeb, applier=applier, repeat=repeat,
                              in_type=in_type, per_tensor=per_tensor):
                self.unscale_l2norm(sizea, sizeb, applier, repeat, in_type, per_tensor)
85+
86+
87+
88+
if __name__ == '__main__':
89+
unittest.main()

0 commit comments

Comments
 (0)