Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 87c4deb

Browse files
authored
64-bit indexing Adam (#1765)
* all i want for christmas is larger binaries and longer compile times
* actually compare
* woops
1 parent c07a4cf commit 87c4deb

3 files changed

Lines changed: 78 additions & 29 deletions

File tree

csrc/multi_tensor_adam.cu

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ typedef enum{
2020

2121
using MATH_T = float;
2222

23-
template<typename T, typename FULL_T>
23+
template<typename T, typename FULL_T, typename index_t>
2424
struct AdamFunctor
2525
{
2626
__device__ __forceinline__ void operator()(
27-
int chunk_size,
27+
index_t chunk_size,
2828
volatile int* noop_gmem,
2929
TensorListMetadata<4>& tl,
3030
const float beta1,
@@ -40,13 +40,13 @@ struct AdamFunctor
4040
// if(*noop_gmem == 1)
4141
// return;
4242

43-
int tensor_loc = tl.block_to_tensor[blockIdx.x];
43+
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];
4444

4545
// potentially use to pass in list of scalar
4646
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
4747

48-
int chunk_idx = tl.block_to_chunk[blockIdx.x];
49-
int n = tl.sizes[tensor_loc];
48+
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
49+
index_t n = tl.sizes[tensor_loc];
5050

5151
T* g = (T*)tl.addresses[0][tensor_loc];
5252
g += chunk_idx*chunk_size;
@@ -63,7 +63,7 @@ struct AdamFunctor
6363
n -= chunk_idx*chunk_size;
6464

6565
// see note in multi_tensor_scale_kernel.cu
66-
for(int i_start = 0;
66+
for(index_t i_start = 0;
6767
i_start < n && i_start < chunk_size;
6868
i_start += blockDim.x*ILP)
6969
{
@@ -378,26 +378,61 @@ void multi_tensor_adam_cuda(
378378
bias_correction2 = 1 - std::pow(beta2, step);
379379
}
380380

381-
// Assume single type across p,g,m1,m2 now
382-
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
383-
tensor_lists[0][0].scalar_type(), 0, "adam",
384-
multi_tensor_apply<4>(
385-
BLOCK_SIZE,
386-
chunk_size,
387-
noop_flag,
388-
tensor_lists,
389-
AdamFunctor<scalar_t_0, float>(),
390-
beta1,
391-
beta2,
392-
bias_correction1,
393-
bias_correction2,
394-
epsilon,
395-
lr,
396-
(adamMode_t) mode,
397-
weight_decay); )
381+
size_t max_size = 0;
382+
bool requires_64bit_indexing = false;
383+
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
384+
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
385+
if (it2->numel() > max_size) {
386+
max_size = it2->numel();
387+
if (max_size >= INT_MAX) {
388+
requires_64bit_indexing = true;
389+
break;
390+
}
391+
}
392+
}
393+
if (requires_64bit_indexing) {
394+
break;
395+
}
396+
}
398397

398+
if (requires_64bit_indexing) {
399+
// Assume single type across p,g,m1,m2 now
400+
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
401+
tensor_lists[0][0].scalar_type(), 0, "adam",
402+
multi_tensor_apply<4>(
403+
(int64_t) BLOCK_SIZE,
404+
(int64_t) chunk_size,
405+
noop_flag,
406+
tensor_lists,
407+
AdamFunctor<scalar_t_0, float, int64_t>(),
408+
beta1,
409+
beta2,
410+
bias_correction1,
411+
bias_correction2,
412+
epsilon,
413+
lr,
414+
(adamMode_t) mode,
415+
weight_decay); )
416+
} else {
417+
// Assume single type across p,g,m1,m2 now
418+
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
419+
tensor_lists[0][0].scalar_type(), 0, "adam",
420+
multi_tensor_apply<4>(
421+
BLOCK_SIZE,
422+
chunk_size,
423+
noop_flag,
424+
tensor_lists,
425+
AdamFunctor<scalar_t_0, float, int32_t>(),
426+
beta1,
427+
beta2,
428+
bias_correction1,
429+
bias_correction2,
430+
epsilon,
431+
lr,
432+
(adamMode_t) mode,
433+
weight_decay); )
434+
}
399435
AT_CUDA_CHECK(cudaGetLastError());
400-
401436
}
402437

403438
void multi_tensor_adam_capturable_cuda(

csrc/multi_tensor_apply.cuh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ template<int n> struct TensorListMetadata
2828

2929
template<typename T, typename U, typename... ArgTypes>
3030
__global__ void multi_tensor_apply_kernel(
31-
int chunk_size,
31+
int64_t chunk_size,
3232
volatile int* noop_flag,
3333
T tl,
3434
U callable,
@@ -40,8 +40,8 @@ __global__ void multi_tensor_apply_kernel(
4040

4141
template<int depth, typename T, typename... ArgTypes>
4242
void multi_tensor_apply(
43-
int block_size,
44-
int chunk_size,
43+
int64_t block_size,
44+
int64_t chunk_size,
4545
const at::Tensor& noop_flag,
4646
const std::vector<std::vector<at::Tensor>>& tensor_lists,
4747
T callable,
@@ -85,9 +85,9 @@ void multi_tensor_apply(
8585
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
8686
loc_tensor_info++;
8787

88-
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
88+
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
8989

90-
for(int chunk = 0; chunk < chunks_this_tensor; chunk++)
90+
for(auto chunk = 0; chunk < chunks_this_tensor; chunk++)
9191
{
9292
// std::cout << chunks_this_tensor << std::endl;
9393
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;

tests/L0/run_optimizers/test_adam.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,20 @@ def testNative(self):
232232

233233
self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
234234

235+
def testLargeTensor(self):
236+
t = torch.zeros(2359332864, dtype=torch.half, device='cuda')
237+
t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda')
238+
grad = torch.randn_like(t)
239+
t.grad = grad
240+
t2.grad = grad
241+
params = [t]
242+
params2 = [t2]
243+
optimizer = apex.optimizers.FusedAdam(params, lr=self.lr)
244+
optimizer.step()
245+
optimizer2 = torch.optim.Adam(params2, lr=self.lr)
246+
torch.testing.assert_close(t, t2)
247+
torch.cuda.synchronize()
248+
235249

236250
if __name__ == '__main__':
237251
unittest.main()

0 commit comments

Comments
 (0)