-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathfused_weight_gradient_dense_cuda.cu
More file actions
83 lines (71 loc) · 3.84 KB
/
fused_weight_gradient_dense_cuda.cu
File metadata and controls
83 lines (71 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "type_shim.h"
/**
 * BF16 overload of the cublasGemmEx wrapper.
 *
 * Forwards every argument to cublasGemmEx, declaring A and B as CUDA_R_16BF
 * and C as CUDA_R_32F, with CUDA_R_32F as the compute type and
 * CUBLAS_GEMM_DEFAULT_TENSOR_OP as the algorithm selector.
 *
 * @param handle          cuBLAS handle the GEMM is issued on.
 * @param transa, transb  Operations applied to A and B.
 * @param m, n, k         GEMM dimensions, as defined by cublasGemmEx.
 * @param alpha, beta     Pointers to the FP32 scaling factors.
 * @param A, B            bfloat16 operand matrices with leading dimensions lda / ldb.
 * @param C               FP32 result matrix with leading dimension ldc.
 *
 * The returned status is passed through TORCH_CUDABLAS_CHECK. The call is kept
 * textually inside the macro because the macro receives the whole expression.
 */
void gemmex_wrapper(cublasHandle_t handle,
                    cublasOperation_t transa,
                    cublasOperation_t transb,
                    int m,
                    int n,
                    int k,
                    const float* alpha,
                    at::BFloat16* A,
                    int lda,
                    at::BFloat16* B,
                    int ldb,
                    const float* beta,
                    float* C,
                    int ldc) {
  TORCH_CUDABLAS_CHECK(
      cublasGemmEx(handle, transa, transb, m, n, k,
                   alpha,
                   A, CUDA_R_16BF, lda,
                   B, CUDA_R_16BF, ldb,
                   beta,
                   C, CUDA_R_32F, ldc,
                   CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
/**
 * FP16 overload of the cublasGemmEx wrapper.
 *
 * Forwards every argument to cublasGemmEx, declaring A and B as CUDA_R_16F
 * and C as CUDA_R_32F, with CUDA_R_32F as the compute type and
 * CUBLAS_GEMM_DEFAULT_TENSOR_OP as the algorithm selector.
 *
 * @param handle          cuBLAS handle the GEMM is issued on.
 * @param transa, transb  Operations applied to A and B.
 * @param m, n, k         GEMM dimensions, as defined by cublasGemmEx.
 * @param alpha, beta     Pointers to the FP32 scaling factors.
 * @param A, B            half-precision operand matrices with leading dimensions lda / ldb.
 * @param C               FP32 result matrix with leading dimension ldc.
 *
 * The returned status is passed through TORCH_CUDABLAS_CHECK. The call is kept
 * textually inside the macro because the macro receives the whole expression.
 */
void gemmex_wrapper(cublasHandle_t handle,
                    cublasOperation_t transa,
                    cublasOperation_t transb,
                    int m,
                    int n,
                    int k,
                    const float* alpha,
                    at::Half* A,
                    int lda,
                    at::Half* B,
                    int ldb,
                    const float* beta,
                    float* C,
                    int ldc) {
  TORCH_CUDABLAS_CHECK(
      cublasGemmEx(handle, transa, transb, m, n, k,
                   alpha,
                   A, CUDA_R_16F, lda,
                   B, CUDA_R_16F, ldb,
                   beta,
                   C, CUDA_R_32F, ldc,
                   CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
/**
 * FP32 overload of the cublasGemmEx wrapper.
 *
 * Forwards every argument to cublasGemmEx, declaring A, B and C as CUDA_R_32F,
 * with CUDA_R_32F as the compute type and CUBLAS_GEMM_DEFAULT_TENSOR_OP as the
 * algorithm selector.
 *
 * @param handle          cuBLAS handle the GEMM is issued on.
 * @param transa, transb  Operations applied to A and B.
 * @param m, n, k         GEMM dimensions, as defined by cublasGemmEx.
 * @param alpha, beta     Pointers to the FP32 scaling factors.
 * @param A, B            FP32 operand matrices with leading dimensions lda / ldb.
 * @param C               FP32 result matrix with leading dimension ldc.
 *
 * The returned status is passed through TORCH_CUDABLAS_CHECK. The call is kept
 * textually inside the macro because the macro receives the whole expression.
 */
void gemmex_wrapper(cublasHandle_t handle,
                    cublasOperation_t transa,
                    cublasOperation_t transb,
                    int m,
                    int n,
                    int k,
                    const float* alpha,
                    float* A,
                    int lda,
                    float* B,
                    int ldb,
                    const float* beta,
                    float* C,
                    int ldc) {
  TORCH_CUDABLAS_CHECK(
      cublasGemmEx(handle, transa, transb, m, n, k,
                   alpha,
                   A, CUDA_R_32F, lda,
                   B, CUDA_R_32F, ldb,
                   beta,
                   C, CUDA_R_32F, ldc,
                   CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
/**
 * Accumulates a weight gradient into an FP32 buffer with a single GEMM.
 *
 * Issues, in cuBLAS (column-major) terms,
 *     C(in_dim x out_dim) = 1 * A(in_dim x hidden_dim) * B(out_dim x hidden_dim)^T + 1 * C
 * with A = input (lda = in_dim), B = d_output (ldb = out_dim) and
 * C = d_weight (ldc = in_dim). Read as dense row-major buffers this is
 *     d_weight[out_dim, in_dim] += d_output[hidden_dim, out_dim]^T * input[hidden_dim, in_dim].
 * Because beta is 1, the existing contents of d_weight are kept and added to.
 *
 * @tparam T          Element type of input and d_output; must match one of the
 *                    gemmex_wrapper overloads (at::Half, at::BFloat16 or float).
 * @param input       Device pointer to the A operand.
 * @param d_output    Device pointer to the B operand.
 * @param d_weight    Device pointer to the FP32 accumulator, updated in place.
 * @param in_dim      GEMM m; also the leading dimension of input and d_weight.
 * @param hidden_dim  GEMM k, i.e. the dimension that is summed over.
 * @param out_dim     GEMM n; also the leading dimension of d_output.
 *
 * cuBLAS failures are reported by gemmex_wrapper via TORCH_CUDABLAS_CHECK.
 */
template <typename T>
void wgrad_gemm_accum_fp32_cuda(T* input, T* d_output, float* d_weight, int in_dim, int hidden_dim, int out_dim) {
  // With alpha = beta = 1 an empty problem leaves d_weight unchanged. Returning
  // early also avoids handing cuBLAS a zero leading dimension, which it rejects.
  if (in_dim == 0 || out_dim == 0 || hidden_dim == 0) {
    return;
  }

  // NOTE(review): the handle returned by ATen is presumably already bound to the
  // current CUDA stream, so no stream is queried or set here - confirm against ATen.
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

  const float alpha = 1.0f;
  const float beta = 1.0f;  // accumulate into d_weight rather than overwrite it

  gemmex_wrapper(handle, CUBLAS_OP_N, CUBLAS_OP_T, in_dim, out_dim, hidden_dim, &alpha, input, in_dim, d_output,
                 out_dim, &beta, d_weight, in_dim);
}

// Explicit instantiations for the element types that have a gemmex_wrapper overload.
template void wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half* input, at::Half* d_output, float* d_weight, int in_dim,
                                                   int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16* input, at::BFloat16* d_output, float* d_weight,
                                                       int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<float>(float* input, float* d_output, float* d_weight, int in_dim,
                                                int hidden_dim, int out_dim);
/**
 * Tensor-level entry point: d_weight += d_output^T * input, accumulated in FP32.
 *
 * input and d_output are flattened to 2-D by collapsing every leading dimension
 * into the first one, then dispatched on input's scalar type to
 * wgrad_gemm_accum_fp32_cuda, which receives raw data pointers.
 *
 * @param input     Tensor of shape [..., in_dim].
 * @param d_output  Tensor of shape [..., out_dim], with the same dtype and the
 *                  same number of flattened rows as input.
 * @param d_weight  FP32 tensor of shape [out_dim, in_dim]; updated in place.
 *
 * Raises (via TORCH_CHECK) when the tensors are not on the same CUDA device,
 * have inconsistent dtypes or shapes, are not contiguous after flattening, or
 * have a dimension that does not fit in the `int` expected by cuBLAS. These are
 * checked up front because the GEMM works on raw pointers with dense leading
 * dimensions, so any mismatch would read or write out of bounds.
 */
void wgrad_gemm_accum_fp32_cuda_stub(at::Tensor& input, at::Tensor& d_output, at::Tensor& d_weight) {
  TORCH_CHECK(input.is_cuda() && d_output.is_cuda() && d_weight.is_cuda(),
              "wgrad_gemm_accum_fp32: input, d_output and d_weight must all be CUDA tensors");
  TORCH_CHECK(input.device() == d_output.device() && input.device() == d_weight.device(),
              "wgrad_gemm_accum_fp32: input, d_output and d_weight must be on the same device");
  TORCH_CHECK(input.dim() >= 2, "wgrad_gemm_accum_fp32: input must have at least 2 dimensions, got ", input.dim());
  TORCH_CHECK(d_output.dim() >= 2, "wgrad_gemm_accum_fp32: d_output must have at least 2 dimensions, got ",
              d_output.dim());
  TORCH_CHECK(d_weight.dim() == 2, "wgrad_gemm_accum_fp32: d_weight must be 2-dimensional, got ", d_weight.dim());
  TORCH_CHECK(d_output.scalar_type() == input.scalar_type(),
              "wgrad_gemm_accum_fp32: input and d_output must have the same dtype, got ", input.scalar_type(),
              " and ", d_output.scalar_type());
  TORCH_CHECK(d_weight.scalar_type() == at::ScalarType::Float,
              "wgrad_gemm_accum_fp32: d_weight must be float32, got ", d_weight.scalar_type());

  at::Tensor input_2d, d_output_2d;
  // input tensor: collapse all leading dimensions into the first one
  auto in_sizes = input.sizes();
  if (input.dim() > 2) {
    input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
  } else {
    input_2d = input;
  }
  // d_output tensor: collapse all leading dimensions into the first one
  auto d_out_sizes = d_output.sizes();
  if (d_output.dim() > 2) {
    d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
  } else {
    d_output_2d = d_output;
  }

  const int64_t rows = input_2d.size(0);
  const int64_t in_features = input_2d.size(1);
  const int64_t out_features = d_weight.size(0);

  TORCH_CHECK(d_output_2d.size(0) == rows, "wgrad_gemm_accum_fp32: input and d_output must have the same number of ",
              "rows after flattening, got ", rows, " and ", d_output_2d.size(0));
  TORCH_CHECK(d_output_2d.size(1) == out_features, "wgrad_gemm_accum_fp32: d_output last dimension (",
              d_output_2d.size(1), ") must equal d_weight.size(0) (", out_features, ")");
  TORCH_CHECK(d_weight.size(1) == in_features, "wgrad_gemm_accum_fp32: d_weight.size(1) (", d_weight.size(1),
              ") must equal input last dimension (", in_features, ")");

  // The GEMM uses the logical sizes as leading dimensions, so memory must be dense.
  // d_weight in particular cannot be copied to a contiguous buffer: it is the output.
  TORCH_CHECK(input_2d.is_contiguous() && d_output_2d.is_contiguous() && d_weight.is_contiguous(),
              "wgrad_gemm_accum_fp32: input, d_output and d_weight must be contiguous");

  // cuBLAS takes 32-bit dimensions; refuse to truncate silently.
  constexpr int64_t kMaxDim = std::numeric_limits<int>::max();
  TORCH_CHECK(rows <= kMaxDim && in_features <= kMaxDim && out_features <= kMaxDim,
              "wgrad_gemm_accum_fp32: tensor dimensions exceed the 32-bit range supported by cuBLAS");

  // hidden_dim is the flattened row count, i.e. the dimension the GEMM sums over.
  const int hidden_dim = static_cast<int>(rows);
  const int in_dim = static_cast<int>(in_features);
  const int out_dim = static_cast<int>(out_features);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      input_2d.scalar_type(), 0, "wgrad_gemm_accum_fp32",
      wgrad_gemm_accum_fp32_cuda<scalar_t_0>(input_2d.data_ptr<scalar_t_0>(), d_output_2d.data_ptr<scalar_t_0>(),
                                             d_weight.data_ptr<float>(), in_dim, hidden_dim, out_dim););
}