@@ -109,6 +109,91 @@ struct L2NormFunctor
109109 }
110110};
111111
// Chunk functor for multi_tensor_apply<1>: every block processes the
// (tensor, chunk) pair recorded for its blockIdx.x in the metadata, multiplies
// each element by *inv_scale, and accumulates the squares in fp32.
//
//   chunk_size             number of elements per chunk
//   noop_gmem              flag set to 1 when a block's reduced value is non-finite
//   tl                     tensor-list metadata (addresses, sizes, block mappings)
//   inv_scale              pointer to the single scale factor applied to every element
//   output                 per-block accumulator, indexed by blockIdx.x
//   output_per_tensor      per-(tensor, chunk) results, written only if per_tensor
//   per_tensor             whether to fill output_per_tensor
//   max_chunks_per_tensor  row stride used to index output_per_tensor
//
// Uses a 512-float shared buffer as scratch for reduce_block_into_lanes
// (defined elsewhere; presumably a block-wide sum -- confirm against its
// definition).
template <typename x_t>
struct UnscaleL2NormFunctor
{
  __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int* noop_gmem,
      TensorListMetadata<1>& tl,
      const float* inv_scale,
      float* output,
      float* output_per_tensor,
      bool per_tensor,
      int max_chunks_per_tensor)
  {
    // There is intentionally no early exit on *noop_gmem == 1 here: this
    // kernel is meant to propagate infs/nans.

    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    const int chunk_offset = chunk_idx * chunk_size;

    // Elements left in the tensor counted from the start of this chunk.
    int remaining = tl.sizes[tensor_loc];
    remaining -= chunk_offset;

    // First element of this block's chunk.
    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_offset;

    // A block never reads past the end of its chunk or the end of the tensor.
    const int limit = remaining < chunk_size ? remaining : chunk_size;

    __shared__ float s_vals[512];

    // Per-thread partial sums (one per ILP slot) and the staging registers
    // used by the vectorized load; both are explicitly zeroed.
    float vals[ILP];
    x_t r_x[ILP];
    for (int slot = 0; slot < ILP; slot++)
    {
      vals[slot] = 0.f;
      r_x[slot] = 0;
    }

    const bool vectorizable =
        remaining % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x);

    if (vectorizable)
    {
      // Aligned path: each iteration pulls ILP elements through load_store.
      for (int vec = threadIdx.x; vec * ILP < limit; vec += blockDim.x)
      {
        load_store(r_x, x, 0, vec);
#pragma unroll
        for (int slot = 0; slot < ILP; slot++)
        {
          float scaled = static_cast<float>(r_x[slot]) * (*inv_scale);
          vals[slot] += scaled * scaled;
        }
      }
    }
    else
    {
      // General path: scalar loads with a bounds check on every element.
      for (int base = 0; base < limit; base += blockDim.x * ILP)
      {
#pragma unroll
        for (int slot = 0; slot < ILP; slot++)
        {
          int idx = base + threadIdx.x + slot * blockDim.x;
          if (idx < limit)
          {
            float scaled = static_cast<float>(x[idx]) * (*inv_scale);
            vals[slot] += scaled * scaled;
          }
        }
      }
    }

    // Collapse the ILP partial sums into one value per thread.
    float thread_sum = 0.f;
    for (int slot = 0; slot < ILP; slot++)
      thread_sum += vals[slot];

    float block_sum = reduce_block_into_lanes(s_vals, thread_sum);

    // Only thread 0 publishes the block's result.
    if (threadIdx.x != 0)
      return;

    // Blindly fire off a write. Blocks race on this flag, but they all store
    // the same value, so that's ok.
    if (!isfinite(block_sum))
      *noop_gmem = 1;

    output[blockIdx.x] += block_sum;

    if (per_tensor)
    {
      const int row = tl.start_tensor_this_launch + tensor_loc;
      output_per_tensor[row * max_chunks_per_tensor + chunk_idx] = block_sum;
    }
  }
};
196+
112197// Probably better to template, but since we are not likely to support other norm
113198template <typename x_t >
114199struct MaxNormFunctor
@@ -355,6 +440,73 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
355440 return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
356441}
357442
// Host entry point: runs UnscaleL2NormFunctor over every chunk of the tensors
// in tensor_lists[0], then launches the `cleanup` kernel to reduce the partial
// results.
//
//   chunk_size         number of elements per chunk handed to each block
//   noop_flag          flag tensor forwarded to multi_tensor_apply; the functor
//                      sets it to 1 when it reduces a non-finite value
//   tensor_lists       tensor_lists[0] holds the input tensors; must be
//                      non-empty because tensor_lists[0][0] supplies the
//                      device/dtype options
//   inv_scale          float tensor whose first element multiplies every input
//                      element before squaring
//   per_tensor_python  optional flag; when true, per-tensor results are also
//                      produced (defaults to false when absent)
//
// Returns (ret, ret_per_tensor): ret has shape {1}; ret_per_tensor has shape
// {ntensors} when per_tensor is set and shape {0} otherwise. Both are filled by
// `cleanup`, which is defined elsewhere -- presumably with the overall and
// per-tensor L2 norms; confirm against its definition.
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor inv_scale,
    at::optional<bool> per_tensor_python)
{
  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  // Per-block accumulator: the functor adds into output[blockIdx.x].
  // NOTE(review): 320 presumably matches the maximum number of blocks
  // multi_tensor_apply<1> uses per launch -- confirm against the harness.
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  if (per_tensor)
  {
    // The per-tensor scratch buffer is a dense [ntensors x max_chunks] grid,
    // so it is sized by the tensor with the most chunks.
    for (int t = 0; t < ntensors; t++)
    {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
      if (max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  }
  else
  {
    ret_per_tensor = at::empty({0}, float_options);
  }

  DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor_lists[0][0].scalar_type(), 0, " multi_tensor_unscale_l2norm_cuda",
    multi_tensor_apply<1>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      UnscaleL2NormFunctor<scalar_t_0>(),
      inv_scale.DATA_PTR<float>(),
      output.DATA_PTR<float>(),
      per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
      per_tensor,
      max_chunks_per_tensor);)

  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but it will be negligible end to end.
  // I could get rid of it by hacking the functor + multi tensor harness with persistence
  // logic, but keeping it simple for now.
  auto ret = at::empty({1}, output.options());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
    output.DATA_PTR<float>(),
    per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
    ret.DATA_PTR<float>(),
    per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
    per_tensor,
    max_chunks_per_tensor);
  // Kernel launches do not report errors themselves; check here so that a bad
  // launch configuration for `cleanup` is reported at this call site instead
  // of at some later, unrelated CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
509+
358510
359511// Compute and update grad norm
360512// Here use a per tensor norm, and blend new norm(n) and old norm(gn) by
0 commit comments