Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1a48b26

Browse files
Kernel + sizes stress test
1 parent e57f5d0 commit 1a48b26

3 files changed

Lines changed: 166 additions & 3 deletions

File tree

csrc/multi_tensor_l2norm_kernel.cu

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/AccumulateType.h>
3+
#include <ATen/cuda/CUDAContext.h>
4+
#include <ATen/cuda/Exceptions.h>
5+
// Another possibility:
6+
// #include <torch/all.h>
7+
8+
#include <assert.h>
9+
10+
#include "type_shim.h"
11+
#include "multi_tensor_apply.cuh"
12+
13+
#define BLOCK_SIZE 512
14+
#define ILP 4
15+
16+
template<typename x_t>
17+
struct L2NormFunctor
18+
{
19+
__device__ __forceinline__ void operator()(
20+
int chunk_size,
21+
volatile int* noop_gmem,
22+
TensorListMetadata<1>& tl,
23+
float* output)
24+
{
25+
// I'd like this kernel to propagate infs/nans.
26+
// if(*noop_gmem == 1)
27+
// return;
28+
29+
int tensor_loc = tl.block_to_tensor[blockIdx.x];
30+
int chunk_idx = tl.block_to_chunk[blockIdx.x];
31+
int n = tl.sizes[tensor_loc];
32+
33+
x_t* x = (x_t*)tl.addresses[0][tensor_loc];
34+
x += chunk_idx*chunk_size;
35+
36+
n -= chunk_idx*chunk_size;
37+
38+
__shared__ float vals[512];
39+
40+
// Non-divergent exit condition for __syncthreads, not necessary here
41+
float val = 0;
42+
for(int i = threadIdx.x; i < n && i < chunk_size; i += blockDim.x)
43+
{
44+
float next = static_cast<float>(x[i]);
45+
val += next*next;
46+
}
47+
48+
float final = reduce_block_into_lanes(vals, val);
49+
50+
if(threadIdx.x == 0)
51+
{
52+
if(!isfinite(final))
53+
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
54+
output[blockIdx.x] += final;
55+
}
56+
}
57+
};
58+
59+
at::Tensor multi_tensor_l2norm_cuda(
60+
int chunk_size,
61+
at::Tensor noop_flag,
62+
std::vector<std::vector<at::Tensor>> tensor_lists)
63+
{
64+
auto output = at::zeros({320}, tensor_lists[0][0].options().dtype(at::ScalarType::Float));
65+
66+
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
67+
multi_tensor_apply<1>(
68+
BLOCK_SIZE,
69+
chunk_size,
70+
noop_flag,
71+
tensor_lists,
72+
L2NormFunctor<scalar_t_0>(),
73+
output.data<float>());)
74+
75+
AT_CUDA_CHECK(cudaGetLastError());
76+
77+
// AT_CUDA_CHECK(cudaDeviceSynchronize());
78+
79+
// This involves two more small kernel launches, but will be negligible end to end.
80+
// I could get rid of these by hacking the functor + multi tensor harness with persistence
81+
// logic, but keeping it simple for now
82+
return output.sum().sqrt();
83+
}

csrc/type_shim.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ struct TypeShim
3333
}
3434

3535

36-
template<typename T, typename ReduceOp>
36+
template<typename T>
3737
__device__ __forceinline__ T reduce_block_into_lanes
3838
(T *x,
3939
T val,
40-
int lanes,
41-
bool share_result) // lanes is intended to be <= 32.
40+
int lanes=1,
41+
bool share_result=false) // lanes is intended to be <= 32.
4242
{
4343
int tid = threadIdx.x + threadIdx.y*blockDim.x;
4444
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import unittest
2+
3+
import functools as ft
4+
import itertools as it
5+
6+
from apex import amp
7+
import torch
8+
from torch import nn
9+
import torch.nn.functional as F
10+
11+
from utils import common_init, HALF, FLOAT,\
12+
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
13+
14+
try:
15+
import amp_C
16+
from amp_C import multi_tensor_l2norm
17+
from apex.multi_tensor_apply import MultiTensorApply
18+
disabled = False
19+
except ImportError as err:
20+
print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
21+
disabled = True
22+
23+
24+
class TestMultiTensorL2Norm(unittest.TestCase):
25+
26+
def setUp(self):
27+
common_init(self)
28+
self.val = 4.0
29+
self.overflow_buf = torch.cuda.IntTensor(1).zero_()
30+
31+
def tearDown(self):
32+
pass
33+
34+
# The tensor creation here is written for convenience, not speed.
35+
def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type):
36+
self.overflow_buf.zero_()
37+
a = torch.cuda.FloatTensor(sizea).fill_(self.val)
38+
b = torch.cuda.FloatTensor(sizeb).fill_(self.val)
39+
40+
in_list = []
41+
for i in range(repeat_tensors):
42+
in_list += [a.clone().to(in_type), b.clone().to(in_type)]
43+
44+
45+
norm = applier(multi_tensor_l2norm, self.overflow_buf, [in_list])
46+
47+
reference = torch.cuda.FloatTensor((sizea + sizeb)*repeat_tensors).fill_(self.val).norm()
48+
49+
self.assertTrue(torch.allclose(norm, reference))
50+
self.assertTrue(self.overflow_buf.item() == 0)
51+
52+
@unittest.skipIf(disabled, "amp_C is unavailable")
53+
def test_fuzz(self):
54+
input_size_pairs = (
55+
(7777*77, 555*555),
56+
(777, 555),
57+
(555, 2048*32+1),
58+
(2048*32+1, 555),
59+
(555, 2048*32),
60+
(2048*32, 555),
61+
(33333, 555),
62+
(555, 33333))
63+
appliers = (
64+
MultiTensorApply(2048*32),
65+
MultiTensorApply(333),
66+
MultiTensorApply(33333))
67+
repeat_tensors = (
68+
1,
69+
55)
70+
71+
for sizea, sizeb in input_size_pairs:
72+
for applier in appliers:
73+
for repeat in repeat_tensors:
74+
for in_type in (torch.float32, torch.float16):
75+
self.l2norm(sizea, sizeb, applier, repeat, in_type, )
76+
77+
78+
79+
if __name__ == '__main__':
80+
unittest.main()

0 commit comments

Comments
 (0)