83 changes: 83 additions & 0 deletions paddle/phi/infermeta/fusion.cc
@@ -2420,6 +2420,89 @@ void FusedMultiTransformerInt8InferMeta(
out->set_dtype(x.dtype());
}

void FusedPartialRopeInferMeta(const MetaTensor& x,
const MetaTensor& cos,
const MetaTensor& sin,
MetaTensor* out) {
const auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
4,
common::errors::InvalidArgument("The input x must be a 4D tensor"));

const int64_t batch_size = x_dims[0];
const int64_t seq_len = x_dims[1];
const int64_t num_heads = x_dims[2];
const int64_t head_dim = x_dims[3];

PADDLE_ENFORCE_LE(
batch_size * seq_len * num_heads,
std::numeric_limits<int>::max(),
common::errors::InvalidArgument("Currently only supports batch_size * "
"seq_len * num_heads <= INT_MAX"));
PADDLE_ENFORCE_LE(head_dim,
std::numeric_limits<int>::max(),
common::errors::InvalidArgument(
"Currently only supports head_dim <= INT_MAX"));

const auto cos_dims = cos.dims();
PADDLE_ENFORCE_EQ(
cos_dims.size(),
4,
common::errors::InvalidArgument("The input cos must be a 4D tensor"));
PADDLE_ENFORCE_EQ(
cos_dims[0],
1,
common::errors::InvalidArgument("The batch_size of cos must be 1"));
PADDLE_ENFORCE_EQ(
cos_dims[1],
seq_len,
common::errors::InvalidArgument("The seq_len of cos must match x"));
PADDLE_ENFORCE_EQ(
cos_dims[2],
1,
common::errors::InvalidArgument("The num_heads of cos must be 1"));

const int64_t pe_head_dim = cos_dims[3];
PADDLE_ENFORCE_LE(pe_head_dim,
head_dim,
common::errors::InvalidArgument(
"pe_head_dim must be no larger than head_dim"));
PADDLE_ENFORCE_EQ(
pe_head_dim % 2,
0,
      common::errors::InvalidArgument("pe_head_dim must be a multiple of 2"));
PADDLE_ENFORCE_LE(pe_head_dim,
1024,
common::errors::InvalidArgument(
"Currently only supports pe_head_dim <= 1024"));

const auto sin_dims = sin.dims();
PADDLE_ENFORCE_EQ(
sin_dims.size(),
4,
common::errors::InvalidArgument("The input sin must be a 4D tensor"));
PADDLE_ENFORCE_EQ(
sin_dims[0],
1,
common::errors::InvalidArgument("The batch_size of sin must be 1"));
PADDLE_ENFORCE_EQ(
sin_dims[1],
seq_len,
common::errors::InvalidArgument("The seq_len of sin must match x"));
PADDLE_ENFORCE_EQ(
sin_dims[2],
1,
common::errors::InvalidArgument("The num_heads of sin must be 1"));
PADDLE_ENFORCE_EQ(
sin_dims[3],
pe_head_dim,
common::errors::InvalidArgument("The pe_head_dim of sin must match cos"));

out->set_dims(x.dims());
out->set_dtype(x.dtype());
}

void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
const MetaTensor& input_scales,
const IntArray& tokens_per_expert,
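For orientation, the shape contract that FusedPartialRopeInferMeta enforces on the new op can be summarized as follows (a comment-style sketch; the bracketed labels are ours, not part of the diff):

// x   : [batch_size, seq_len, num_heads, head_dim]
// cos : [1, seq_len, 1, pe_head_dim]
// sin : [1, seq_len, 1, pe_head_dim]   (must match cos exactly)
// out : same dims and dtype as x
// with pe_head_dim even, pe_head_dim <= head_dim, pe_head_dim <= 1024,
// and batch_size * seq_len * num_heads <= INT_MAX.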
5 changes: 5 additions & 0 deletions paddle/phi/infermeta/fusion.h
@@ -668,6 +668,11 @@ void FusedMultiTransformerInt8InferMeta(
std::vector<MetaTensor*> cache_kv_out,
MetaTensor* out);

void FusedPartialRopeInferMeta(const MetaTensor& x,
const MetaTensor& cos,
const MetaTensor& sin,
MetaTensor* out);

void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
const MetaTensor& input_scales,
const IntArray& tokens_per_expert,
154 changes: 154 additions & 0 deletions paddle/phi/kernels/fusion/gpu/fused_partial_rope_grad_kernel.cu
@@ -0,0 +1,154 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fusion/gpu/fused_partial_rope_utils.h"

namespace phi {
namespace fusion {

using FastDivMod = phi::funcs::FastDivMod<uint32_t>;

template <typename T, int VecSize, int NopeSize, int PeSize>
__global__ void rope_grad_kernel(const T* __restrict__ cos,
const T* __restrict__ sin,
const T* __restrict__ out_grad,
T* __restrict__ x_grad,
FastDivMod seq_len,
FastDivMod num_heads,
uint32_t nope_head_dim,
uint32_t pe_head_dim,
uint32_t block_num) {
using VT = phi::kps::details::VectorType<T, VecSize>;
extern __shared__ T shm[];

const uint32_t block_idx = blockIdx.x * 8 + threadIdx.y;
if (block_idx >= block_num) return;
const uint32_t seq_idx = seq_len.Divmod(num_heads.Div(block_idx))[1];
const size_t block_offset =
static_cast<size_t>(block_idx) * (nope_head_dim + pe_head_dim);
T* const pe_buffer = shm + threadIdx.y * pe_head_dim;

// copy nope part
LOOP_WITH_SIZE_HINT(
i, threadIdx.x * VecSize, nope_head_dim, 32 * VecSize, NopeSize) {
size_t idx = block_offset + i;
*reinterpret_cast<VT*>(x_grad + idx) =
*reinterpret_cast<const VT*>(out_grad + idx);
}

// load pe part, apply embedding and transpose in shared memory
LOOP_WITH_SIZE_HINT(
i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
VT grad = *reinterpret_cast<const VT*>(out_grad + block_offset +
nope_head_dim + i);
VT grad_rot;
if (i < pe_head_dim / 2) {
grad_rot = *reinterpret_cast<const VT*>(
out_grad + block_offset + nope_head_dim + (i + pe_head_dim / 2));
} else {
grad_rot = *reinterpret_cast<const VT*>(
out_grad + block_offset + nope_head_dim + (i - pe_head_dim / 2));
}

VT cos_v = *reinterpret_cast<const VT*>(cos + seq_idx * pe_head_dim + i);
VT sin_v;
if (i < pe_head_dim / 2) {
sin_v = *reinterpret_cast<const VT*>(sin + seq_idx * pe_head_dim +
(i + pe_head_dim / 2));
} else {
sin_v = *reinterpret_cast<const VT*>(sin + seq_idx * pe_head_dim +
(i - pe_head_dim / 2));
}

for (uint32_t j = 0; j < VecSize; j++) {
uint32_t pe_idx = i + j;
if (pe_idx < pe_head_dim / 2) {
pe_buffer[pe_idx * 2] =
grad.val[j] * cos_v.val[j] + grad_rot.val[j] * sin_v.val[j];
} else {
pe_buffer[(pe_idx - pe_head_dim / 2) * 2 + 1] =
grad.val[j] * cos_v.val[j] - grad_rot.val[j] * sin_v.val[j];
}
}
}
#ifdef PADDLE_WITH_HIP
__syncthreads();
#else
__syncwarp();
#endif

// store
LOOP_WITH_SIZE_HINT(
i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
VT tmp;
for (uint32_t j = 0; j < VecSize; j++) {
tmp.val[j] = pe_buffer[i + j];
}
*reinterpret_cast<VT*>(x_grad + block_offset + nope_head_dim + i) = tmp;
}
}

template <typename T, typename Context>
void FusedPartialRoPEGradKernel(const Context& dev_ctx,
const DenseTensor& cos,
const DenseTensor& sin,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
const auto x_dims = out_grad.dims();
const int64_t batch_size = x_dims[0];
const int64_t seq_len = x_dims[1];
const int64_t num_heads = x_dims[2];
const int64_t head_dim = x_dims[3];
const int64_t pe_head_dim = cos.dims()[3];
const int64_t nope_head_dim = head_dim - pe_head_dim;

// Allocate x_grad
dev_ctx.template Alloc<T>(x_grad);

if (batch_size == 0 || seq_len == 0 || num_heads == 0 || head_dim == 0) {
return;
}

// Launch kernel
int64_t block_num = batch_size * seq_len * num_heads;
dim3 grid((block_num + 7) / 8);
dim3 block(32, 8);
int64_t shm_size = block.y * pe_head_dim * sizeof(T);

auto kernel = [&]() {
SWITCH_ROPE_KERNEL(nope_head_dim, pe_head_dim, {
return rope_grad_kernel<T, VecSize, NopeSize, PeSize>;
});
}();

kernel<<<grid, block, shm_size, dev_ctx.stream()>>>(
cos.data<T>(),
sin.data<T>(),
out_grad.data<T>(),
x_grad->data<T>(),
static_cast<uint32_t>(seq_len),
static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(nope_head_dim),
static_cast<uint32_t>(pe_head_dim),
static_cast<uint32_t>(block_num));
}

} // namespace fusion
} // namespace phi

PD_REGISTER_KERNEL(fused_partial_rope_grad,
GPU,
ALL_LAYOUT,
phi::fusion::FusedPartialRoPEGradKernel,
phi::dtype::bfloat16) {}
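As a cross-check on the two branches in rope_grad_kernel, the per-channel formula follows from differentiating the rotate-half rotation applied in the forward kernel. Writing d for pe_head_dim, g for the incoming out_grad on the PE channels, and x' for the de-interleaved PE channels, the derivation (ours, not part of the diff) gives:

\frac{\partial L}{\partial x'_i} =
\begin{cases}
g_i \cos_i + g_{i+d/2}\,\sin_{i+d/2}, & i < d/2, \\
g_i \cos_i - g_{i-d/2}\,\sin_{i-d/2}, & i \ge d/2.
\end{cases}

The kernel then writes channel i < d/2 to interleaved slot 2i and channel i >= d/2 to slot 2(i - d/2) + 1, which is exactly the inverse of the even/odd de-interleave performed by the forward kernel.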
138 changes: 138 additions & 0 deletions paddle/phi/kernels/fusion/gpu/fused_partial_rope_kernel.cu
@@ -0,0 +1,138 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fusion/gpu/fused_partial_rope_utils.h"

namespace phi {
namespace fusion {

using FastDivMod = phi::funcs::FastDivMod<uint32_t>;

template <typename T, int VecSize, int NopeSize, int PeSize>
__global__ void rope_kernel(const T* __restrict__ x,
const T* __restrict__ cos,
const T* __restrict__ sin,
T* __restrict__ out,
FastDivMod seq_len,
FastDivMod num_heads,
uint32_t nope_head_dim,
uint32_t pe_head_dim,
uint32_t block_num) {
using VT = phi::kps::details::VectorType<T, VecSize>;
extern __shared__ T shm[];

const uint32_t block_idx = blockIdx.x * 8 + threadIdx.y;
if (block_idx >= block_num) return;
const uint32_t seq_idx = seq_len.Divmod(num_heads.Div(block_idx))[1];
const size_t block_offset =
static_cast<size_t>(block_idx) * (nope_head_dim + pe_head_dim);
T* const pe_buffer = shm + threadIdx.y * pe_head_dim;

// copy nope part
LOOP_WITH_SIZE_HINT(
i, threadIdx.x * VecSize, nope_head_dim, 32 * VecSize, NopeSize) {
size_t idx = block_offset + i;
*reinterpret_cast<VT*>(out + idx) = *reinterpret_cast<const VT*>(x + idx);
}

// load pe part and transpose in shared memory
LOOP_WITH_SIZE_HINT(
i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
VT tmp = *reinterpret_cast<const VT*>(x + block_offset + nope_head_dim + i);
for (uint32_t j = 0; j < VecSize; j++) {
uint32_t pe_idx = i + j;
if (pe_idx % 2 == 0) {
pe_buffer[pe_idx / 2] = tmp.val[j];
} else {
pe_buffer[pe_idx / 2 + pe_head_dim / 2] = tmp.val[j];
}
}
}
#ifdef PADDLE_WITH_HIP
__syncthreads();
#else
__syncwarp();
#endif

// apply embedding and store
LOOP_WITH_SIZE_HINT(
i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
VT cos_v = *reinterpret_cast<const VT*>(cos + seq_idx * pe_head_dim + i);
VT sin_v = *reinterpret_cast<const VT*>(sin + seq_idx * pe_head_dim + i);
VT tmp;
for (uint32_t j = 0; j < VecSize; j++) {
uint32_t pe_idx = i + j;
T x_pe = pe_buffer[pe_idx];
T x_pe_rot = (pe_idx < pe_head_dim / 2)
? -pe_buffer[pe_idx + pe_head_dim / 2]
: pe_buffer[pe_idx - pe_head_dim / 2];
tmp.val[j] = (x_pe * cos_v.val[j]) + (x_pe_rot * sin_v.val[j]);
}
*reinterpret_cast<VT*>(out + block_offset + nope_head_dim + i) = tmp;
}
}

template <typename T, typename Context>
void FusedPartialRoPEKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& cos,
const DenseTensor& sin,
DenseTensor* out) {
const auto x_dims = x.dims();
const int64_t batch_size = x_dims[0];
const int64_t seq_len = x_dims[1];
const int64_t num_heads = x_dims[2];
const int64_t head_dim = x_dims[3];
const int64_t pe_head_dim = cos.dims()[3];
const int64_t nope_head_dim = head_dim - pe_head_dim;

// Allocate out
dev_ctx.template Alloc<T>(out);

if (batch_size == 0 || seq_len == 0 || num_heads == 0 || head_dim == 0) {
return;
}

// Launch kernel
int64_t block_num = batch_size * seq_len * num_heads;
dim3 grid((block_num + 7) / 8);
dim3 block(32, 8);
int64_t shm_size = block.y * pe_head_dim * sizeof(T);

auto kernel = [&]() {
SWITCH_ROPE_KERNEL(nope_head_dim, pe_head_dim, {
return rope_kernel<T, VecSize, NopeSize, PeSize>;
});
}();

kernel<<<grid, block, shm_size, dev_ctx.stream()>>>(
x.data<T>(),
cos.data<T>(),
sin.data<T>(),
out->data<T>(),
static_cast<uint32_t>(seq_len),
static_cast<uint32_t>(num_heads),
static_cast<uint32_t>(nope_head_dim),
static_cast<uint32_t>(pe_head_dim),
static_cast<uint32_t>(block_num));
}

} // namespace fusion
} // namespace phi

PD_REGISTER_KERNEL(fused_partial_rope,
GPU,
ALL_LAYOUT,
phi::fusion::FusedPartialRoPEKernel,
phi::dtype::bfloat16) {}
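For reference, the forward transform implemented by rope_kernel can be stated compactly. The first head_dim - pe_head_dim channels of each head (the "nope" part) are copied through unchanged; the remaining pe_head_dim channels are de-interleaved in shared memory (even channels to the first half, odd channels to the second half) and then rotated. Writing d for pe_head_dim and x' for the de-interleaved PE channels (our notation, not part of the diff):

\mathrm{out}_i = x'_i \cos_i + \mathrm{rot}(x')_i \sin_i,
\qquad
\mathrm{rot}(x')_i =
\begin{cases}
-x'_{i+d/2}, & i < d/2, \\
\phantom{-}x'_{i-d/2}, & i \ge d/2.
\end{cases}

On the launch side, each CUDA block covers 8 (batch, seq, head) rows with one 32-thread warp per row, so the grid is ceil(batch_size * seq_len * num_heads / 8) and the dynamic shared memory per block is 8 * pe_head_dim * sizeof(T), one pe_head_dim-sized transpose buffer per warp.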