@@ -20,11 +20,11 @@ typedef enum{
2020
2121using MATH_T = float ;
2222
23- template <typename T, typename FULL_T>
23+ template <typename T, typename FULL_T, typename index_t >
2424struct AdamFunctor
2525{
2626 __device__ __forceinline__ void operator ()(
27- int chunk_size,
27+ index_t chunk_size,
2828 volatile int * noop_gmem,
2929 TensorListMetadata<4 >& tl,
3030 const float beta1,
@@ -40,13 +40,13 @@ struct AdamFunctor
4040 // if(*noop_gmem == 1)
4141 // return;
4242
43- int tensor_loc = tl.block_to_tensor [blockIdx .x ];
43+ index_t tensor_loc = tl.block_to_tensor [blockIdx .x ];
4444
4545 // potentially use to pass in list of scalar
4646 // int tensor_num = tl.start_tensor_this_launch + tensor_loc;
4747
48- int chunk_idx = tl.block_to_chunk [blockIdx .x ];
49- int n = tl.sizes [tensor_loc];
48+ index_t chunk_idx = tl.block_to_chunk [blockIdx .x ];
49+ index_t n = tl.sizes [tensor_loc];
5050
5151 T* g = (T*)tl.addresses [0 ][tensor_loc];
5252 g += chunk_idx*chunk_size;
@@ -63,7 +63,7 @@ struct AdamFunctor
6363 n -= chunk_idx*chunk_size;
6464
6565 // see note in multi_tensor_scale_kernel.cu
66- for (int i_start = 0 ;
66+ for (index_t i_start = 0 ;
6767 i_start < n && i_start < chunk_size;
6868 i_start += blockDim .x *ILP)
6969 {
@@ -378,26 +378,61 @@ void multi_tensor_adam_cuda(
378378 bias_correction2 = 1 - std::pow (beta2, step);
379379 }
380380
381- // Assume single type across p,g,m1,m2 now
382- DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT (
383- tensor_lists[0 ][0 ].scalar_type (), 0 , " adam" ,
384- multi_tensor_apply<4 >(
385- BLOCK_SIZE,
386- chunk_size,
387- noop_flag,
388- tensor_lists,
389- AdamFunctor<scalar_t_0, float >(),
390- beta1,
391- beta2,
392- bias_correction1,
393- bias_correction2,
394- epsilon,
395- lr,
396- (adamMode_t) mode,
397- weight_decay); )
381+ size_t max_size = 0 ;
382+ bool requires_64bit_indexing = false ;
383+ for (auto it = tensor_lists.begin (); it != tensor_lists.end (); it++) {
384+ for (auto it2 = it->begin (); it2 != it->end (); it2++) {
385+ if (it2->numel () > max_size) {
386+ max_size = it2->numel ();
387+ if (max_size >= INT_MAX) {
388+ requires_64bit_indexing = true ;
389+ break ;
390+ }
391+ }
392+ }
393+ if (requires_64bit_indexing) {
394+ break ;
395+ }
396+ }
398397
398+ if (requires_64bit_indexing) {
399+ // Assume single type across p,g,m1,m2 now
400+ DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT (
401+ tensor_lists[0 ][0 ].scalar_type (), 0 , " adam" ,
402+ multi_tensor_apply<4 >(
403+ (int64_t ) BLOCK_SIZE,
404+ (int64_t ) chunk_size,
405+ noop_flag,
406+ tensor_lists,
407+ AdamFunctor<scalar_t_0, float , int64_t >(),
408+ beta1,
409+ beta2,
410+ bias_correction1,
411+ bias_correction2,
412+ epsilon,
413+ lr,
414+ (adamMode_t) mode,
415+ weight_decay); )
416+ } else {
417+ // Assume single type across p,g,m1,m2 now
418+ DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT (
419+ tensor_lists[0 ][0 ].scalar_type (), 0 , " adam" ,
420+ multi_tensor_apply<4 >(
421+ BLOCK_SIZE,
422+ chunk_size,
423+ noop_flag,
424+ tensor_lists,
425+ AdamFunctor<scalar_t_0, float , int32_t >(),
426+ beta1,
427+ beta2,
428+ bias_correction1,
429+ bias_correction2,
430+ epsilon,
431+ lr,
432+ (adamMode_t) mode,
433+ weight_decay); )
434+ }
399435 AT_CUDA_CHECK (cudaGetLastError ());
400-
401436}
402437
403438void multi_tensor_adam_capturable_cuda (
0 commit comments