Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 42180bd

Browse files
Forward/backward compatibility around pytorch 3aeb78, to fix #191
1 parent 975ed32 commit 42180bd

10 files changed

Lines changed: 125 additions & 71 deletions

apex/optimizers/csrc/fused_adam_cuda_kernel.cu renamed to csrc/fused_adam_cuda_kernel.cu

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include "ATen/AccumulateType.h"
1111
#include <THC/THCGeneral.h>
1212

13+
#include "type_shim.h"
14+
1315
typedef enum{
1416
ADAM_MODE_0 =0, // eps under square root
1517
ADAM_MODE_1 =1 // eps outside square root
@@ -29,8 +31,8 @@ __global__ void adam_cuda_kernel(
2931
const float step_size,
3032
const size_t tsize,
3133
adamMode_t mode,
32-
const float decay) {
33-
34+
const float decay)
35+
{
3436
//Assuming 2D grids and 2D blocks
3537
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
3638
const int threadsPerBlock = blockDim.x * blockDim.y;
@@ -67,7 +69,9 @@ void fused_adam_cuda(
6769
int step,
6870
int mode,
6971
int bias_correction,
70-
float decay) {
72+
float decay)
73+
{
74+
// using namespace at;
7175

7276
//Get tensor size
7377
int tsize = p.numel();
@@ -91,7 +95,8 @@ void fused_adam_cuda(
9195
//all other values should be fp32 for half gradients
9296
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
9397
//dispatch is done on the gradient type
94-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
98+
using namespace at; // prevents "toString is undefined" errors
99+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(g.type()), "adam_cuda_kernel", ([&] {
95100
using accscalar_t = at::acc_type<scalar_t, true>;
96101
adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
97102
p.data<accscalar_t>(),
@@ -109,7 +114,8 @@ void fused_adam_cuda(
109114
decay);
110115
}));
111116
} else {
112-
AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
117+
using namespace at;
118+
AT_DISPATCH_FLOATING_TYPES(TypeShim(g.type()), "adam_cuda_kernel", ([&] {
113119
adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
114120
p.data<scalar_t>(),
115121
NULL, //don't output p_copy for fp32, it's wasted write
File renamed without changes.

apex/normalization/csrc/layer_norm_cuda_kernel.cu renamed to csrc/layer_norm_cuda_kernel.cu

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include <cuda.h>
77
#include <cuda_runtime.h>
88

9+
#include "type_shim.h"
10+
911
template<typename U> __device__
1012
void cuWelfordOnlineSum(
1113
const U curr,
@@ -675,7 +677,8 @@ void cuda_layer_norm(
675677
at::Tensor* beta,
676678
double epsilon)
677679
{
678-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "layer_norm_cuda_kernel", ([&] {
680+
using namespace at;
681+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input->type()), "layer_norm_cuda_kernel", ([&] {
679682
using accscalar_t = at::acc_type<scalar_t, true>;
680683
HostApplyLayerNorm(
681684
output->data<scalar_t>(),
@@ -772,7 +775,8 @@ void cuda_layer_norm_gradient(
772775
at::Tensor* grad_gamma,
773776
at::Tensor* grad_beta)
774777
{
775-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input->type(), "cuComputeGradInput", ([&] {
778+
using namespace at;
779+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(input->type()), "cuComputeGradInput", ([&] {
776780
using accscalar_t = at::acc_type<scalar_t, true>;
777781
HostLayerNormGradient(
778782
dout->data<scalar_t>(),

csrc/multi_tensor_apply.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
1515
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
1616

17-
template<int n> struct TensorList
17+
template<int n> struct TensorListMetadata
1818
{
1919
void* addresses[n][depth_to_max_tensors[n-1]];
2020
int sizes[depth_to_max_tensors[n-1]];
@@ -62,7 +62,7 @@ void multi_tensor_apply(
6262

6363
int ntensors = tensor_lists[0].size();
6464

65-
TensorList<depth> tl;
65+
TensorListMetadata<depth> tl;
6666

6767
auto stream = at::cuda::getCurrentCUDAStream();
6868

csrc/multi_tensor_scale_kernel.cu

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@
22
#include <ATen/AccumulateType.h>
33
#include <ATen/cuda/CUDAContext.h>
44
#include <ATen/cuda/Exceptions.h>
5-
#include "multi_tensor_apply.cuh"
5+
// Another possibility:
6+
// #include <torch/all.h>
67

78
#include <assert.h>
9+
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
10+
#include <sstream>
11+
12+
#include "type_shim.h"
13+
#include "multi_tensor_apply.cuh"
814

915
#define BLOCK_SIZE 512
1016
#define ILP 4
@@ -15,7 +21,7 @@ struct ScaleFunctor
1521
__device__ __forceinline__ void operator()(
1622
int chunk_size,
1723
volatile int* noop_gmem,
18-
TensorList<2>& tl,
24+
TensorListMetadata<2>& tl,
1925
float scale)
2026
{
2127
__shared__ int noop_smem;
@@ -87,15 +93,17 @@ void multi_tensor_scale_cuda(
8793
std::vector<std::vector<at::Tensor>> tensor_lists,
8894
float scale)
8995
{
96+
using namespace at;
9097
// The output (downscaled) type is always float.
9198
// If build times suffer, think about where to put this dispatch,
9299
// and what logic should be moved out of multi_tensor_apply.
93-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
100+
101+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TypeShim(tensor_lists[0][0].type()),
94102
"multi_tensor_scale_cuda",
95103
[&]
96104
{
97105
// using accscalar_t = acc_type<scalar_t, true>;
98-
switch(tensor_lists[1][0].type().scalarType())
106+
switch(tensor_lists[1][0].scalar_type())
99107
{
100108
case at::ScalarType::Half:
101109
multi_tensor_apply<2>(
@@ -116,8 +124,10 @@ void multi_tensor_scale_cuda(
116124
scale);
117125
break;
118126
default:
119-
AT_ERROR("multi_tensor_scale_cuda not implemented for output type = ",
120-
tensor_lists[1][0].type().toString());
127+
std::stringstream ss;
128+
ss << "multi_tensor_scale_cuda not implemented for output type = "
129+
<< tensor_lists[1][0].dtype();
130+
AT_ERROR(ss.str().c_str());
121131
}
122132
});
123133

csrc/type_shim.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include <ATen/ATen.h>
2+
3+
// Forward/backward compatibility hack around
4+
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
5+
// pending more future-proof guidance from upstream.
6+
// Thin wrapper holding a const at::Type& that converts implicitly to either
// `const at::Type&` or `at::ScalarType`, so one call site (e.g.
// `TypeShim(g.type())`) can be passed to code expecting either form.
struct TypeShim
7+
{
8+
// Wrapped type, held by reference (not copied); the referenced at::Type
// must stay alive for as long as this shim is used.
const at::Type& payload;
9+
// Wraps `type`; intentionally non-explicit so a TypeShim can be built inline.
TypeShim(const at::Type& type) : payload(type) {}
10+
// Enable trivial conversion to a const at::Type& for pre-3aeb78:
// returns the wrapped reference unchanged.
11+
operator const at::Type&(){ return payload; };
12+
// Enable dispatch switch statements to take *this directly for post-3aeb78:
// yields payload.scalarType().
13+
operator at::ScalarType(){ return payload.scalarType(); };
14+
};

0 commit comments

Comments
 (0)