-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathupdate_scale_hysteresis.cu
More file actions
55 lines (48 loc) · 2.13 KB
/
update_scale_hysteresis.cu
File metadata and controls
55 lines (48 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <limits>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// One step of dynamic loss-scale adjustment with hysteresis.
//
// Launch layout: intended for exactly one thread. Every pointer is read/written at
// element 0 with plain (non-atomic) stores and no thread indexing, so more than one
// thread would race on the same scalars.
//
// Parameters (each pointer refers to a single device-resident scalar):
//   current_scale      - loss scale, updated in place.
//   growth_tracker     - count of consecutive steps without inf/NaN.
//   hysteresis_tracker - overflow steps still tolerated before the scale backs off.
//   found_inf          - any non-zero value (including NaN) marks this step as overflowed.
//   growth_factor      - multiplier applied when growth_interval clean steps are reached.
//   backoff_factor     - multiplier applied when an overflow exhausts the hysteresis.
//   growth_interval    - clean steps required before the scale grows.
//   hysteresis         - value hysteresis_tracker is restored to after a clean step.
//
// Assuming hysteresis_tracker was last restored to `hysteresis` (H), the scale backs
// off on the H-th consecutive overflowed step and on every overflowed step after that;
// earlier overflowed steps only reset growth_tracker.
__global__ void update_scale_hysteresis_cuda_kernel(float* current_scale, int* growth_tracker, int* hysteresis_tracker,
                                                    const float* found_inf, double growth_factor, double backoff_factor,
                                                    int growth_interval, int hysteresis) {
  // Evaluate the flag once so every decision below uses the same predicate.
  // NaN compares != 0, so a NaN flag is treated as an overflow (the safe choice).
  const bool inf_found = (*found_inf != 0.0f);

  if (inf_found) {
    *hysteresis_tracker -= 1;
    // Hysteresis remaining: tolerate this overflow without touching the scale,
    // but restart the count of consecutive clean steps.
    if (*hysteresis_tracker > 0) {
      *growth_tracker = 0;
      return;
    }
    // Hysteresis exhausted: shrink the scale. The tracker is intentionally left
    // as-is (it keeps going down) until the next clean step restores it.
    *current_scale = static_cast<float>((*current_scale) * backoff_factor);
    *growth_tracker = 0;
    return;
  }

  // Entering this path means we just carried out a successful step,
  // so growth_tracker is incremented before comparing to growth_interval.
  const int successful = (*growth_tracker) + 1;
  if (successful == growth_interval) {
    const auto new_scale = static_cast<float>((*current_scale) * growth_factor);
    // Do not grow the scale past fp32 bounds to inf.
    if (isfinite(new_scale)) {
      *current_scale = new_scale;
    }
    *growth_tracker = 0;
  } else {
    *growth_tracker = successful;
  }

  // No inf/NaN this step: restore the full hysteresis budget.
  *hysteresis_tracker = hysteresis;
}
// Host entry point: validates the scaler-state tensors, then enqueues
// update_scale_hysteresis_cuda_kernel as a single thread on the current CUDA
// stream of current_scale's device. The launch is asynchronous with respect to the
// host; current_scale (and both trackers) are updated in place, and current_scale is
// also returned.
//
// Checked preconditions:
//   - all four tensors are 1-element CUDA tensors on the same device;
//   - current_scale and found_inf are float32, growth_tracker and
//     hysteresis_tracker are int32;
//   - growth_interval fits in a 32-bit int (the kernel takes an int).
// Violations raise c10::Error via TORCH_CHECK.
at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale, at::Tensor growth_tracker,
                                        at::Tensor hysteresis_tracker, at::Tensor found_inf, const double growth_factor,
                                        const double backoff_factor, const int64_t growth_interval,
                                        const int hysteresis) {
  // Validate one scalar-state tensor. Without these checks a CPU tensor would pass a
  // host pointer to the kernel, and extra elements would be silently ignored.
  const auto check_scalar_state = [&](const at::Tensor& t, const char* name, at::ScalarType dtype) {
    TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor.");
    TORCH_CHECK(t.device() == current_scale.device(), name, " must be on the same device as current_scale.");
    TORCH_CHECK(t.numel() == 1, name, " must be a 1-element tensor.");
    TORCH_CHECK(t.scalar_type() == dtype, name, " must be a ", dtype, " tensor.");
  };
  check_scalar_state(current_scale, "current_scale", at::kFloat);
  check_scalar_state(growth_tracker, "growth_tracker", at::kInt);
  check_scalar_state(hysteresis_tracker, "hysteresis_tracker", at::kInt);
  check_scalar_state(found_inf, "found_inf", at::kFloat);

  // The kernel parameter is a 32-bit int; reject values that would be truncated.
  TORCH_CHECK(growth_interval >= std::numeric_limits<int>::min() &&
                  growth_interval <= std::numeric_limits<int>::max(),
              "growth_interval must fit in a 32-bit int, got ", growth_interval);

  // Make the tensors' device current so the stream below belongs to it.
  const c10::cuda::CUDAGuard device_guard(current_scale.device());

  update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
      current_scale.mutable_data_ptr<float>(), growth_tracker.mutable_data_ptr<int>(),
      hysteresis_tracker.mutable_data_ptr<int>(), found_inf.const_data_ptr<float>(), growth_factor, backoff_factor,
      static_cast<int>(growth_interval), hysteresis);
  // Surfaces launch-configuration errors; execution errors appear at a later sync.
  AT_CUDA_CHECK(cudaGetLastError());
  return current_scale;
}