diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index 3a7e6eb108f1b9..e1dcaadc69cfcc 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -2420,6 +2420,89 @@ void FusedMultiTransformerInt8InferMeta(
   out->set_dtype(x.dtype());
 }
 
+void FusedPartialRopeInferMeta(const MetaTensor& x,
+                               const MetaTensor& cos,
+                               const MetaTensor& sin,
+                               MetaTensor* out) {
+  const auto x_dims = x.dims();
+  PADDLE_ENFORCE_EQ(
+      x_dims.size(),
+      4,
+      common::errors::InvalidArgument("The input x must be a 4D tensor"));
+
+  const int64_t batch_size = x_dims[0];
+  const int64_t seq_len = x_dims[1];
+  const int64_t num_heads = x_dims[2];
+  const int64_t head_dim = x_dims[3];
+
+  PADDLE_ENFORCE_LE(
+      batch_size * seq_len * num_heads,
+      std::numeric_limits<int32_t>::max(),
+      common::errors::InvalidArgument("Currently only supports batch_size * "
+                                      "seq_len * num_heads <= INT_MAX"));
+  PADDLE_ENFORCE_LE(head_dim,
+                    std::numeric_limits<int32_t>::max(),
+                    common::errors::InvalidArgument(
+                        "Currently only supports head_dim <= INT_MAX"));
+
+  const auto cos_dims = cos.dims();
+  PADDLE_ENFORCE_EQ(
+      cos_dims.size(),
+      4,
+      common::errors::InvalidArgument("The input cos must be a 4D tensor"));
+  PADDLE_ENFORCE_EQ(
+      cos_dims[0],
+      1,
+      common::errors::InvalidArgument("The batch_size of cos must be 1"));
+  PADDLE_ENFORCE_EQ(
+      cos_dims[1],
+      seq_len,
+      common::errors::InvalidArgument("The seq_len of cos must match x"));
+  PADDLE_ENFORCE_EQ(
+      cos_dims[2],
+      1,
+      common::errors::InvalidArgument("The num_heads of cos must be 1"));
+
+  const int64_t pe_head_dim = cos_dims[3];
+  PADDLE_ENFORCE_LE(pe_head_dim,
+                    head_dim,
+                    common::errors::InvalidArgument(
+                        "pe_head_dim must be no larger than head_dim"));
+  PADDLE_ENFORCE_EQ(
+      pe_head_dim % 2,
+      0,
+      common::errors::InvalidArgument("pe_head_dim must be multiple of 2"));
+  PADDLE_ENFORCE_LE(pe_head_dim,
+                    1024,
+                    common::errors::InvalidArgument(
+                        "Currently only supports pe_head_dim <= 1024"));
+
+  const auto sin_dims = sin.dims();
+  PADDLE_ENFORCE_EQ(
+      sin_dims.size(),
+      4,
+      common::errors::InvalidArgument("The input sin must be a 4D tensor"));
+  PADDLE_ENFORCE_EQ(
+      sin_dims[0],
+      1,
+      common::errors::InvalidArgument("The batch_size of sin must be 1"));
+  PADDLE_ENFORCE_EQ(
+      sin_dims[1],
+      seq_len,
+      common::errors::InvalidArgument("The seq_len of sin must match x"));
+  PADDLE_ENFORCE_EQ(
+      sin_dims[2],
+      1,
+      common::errors::InvalidArgument("The num_heads of sin must be 1"));
+  PADDLE_ENFORCE_EQ(
+      sin_dims[3],
+      pe_head_dim,
+      common::errors::InvalidArgument("The pe_head_dim of sin must match cos"));
+
+  out->set_dims(x.dims());
+  out->set_dtype(x.dtype());
+}
+
 void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
                                        const MetaTensor& input_scales,
                                        const IntArray& tokens_per_expert,
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index c1f6a988bf59b1..4cc2a65253d5df 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -668,6 +668,11 @@ void FusedMultiTransformerInt8InferMeta(
     std::vector<MetaTensor*> cache_kv_out,
     MetaTensor* out);
 
+void FusedPartialRopeInferMeta(const MetaTensor& x,
+                               const MetaTensor& cos,
+                               const MetaTensor& sin,
+                               MetaTensor* out);
+
 void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
                                        const MetaTensor& input_scales,
                                        const IntArray& tokens_per_expert,
diff --git a/paddle/phi/kernels/fusion/gpu/fused_partial_rope_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_partial_rope_grad_kernel.cu
new file mode 100644
index 00000000000000..44597795491982
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/fused_partial_rope_grad_kernel.cu
@@ -0,0 +1,154 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/fusion/gpu/fused_partial_rope_utils.h"
+
+namespace phi {
+namespace fusion {
+
+using FastDivMod = phi::funcs::FastDivMod;
+
+template <typename T, int VecSize, int NopeSize, int PeSize>
+__global__ void rope_grad_kernel(const T* __restrict__ cos,
+                                 const T* __restrict__ sin,
+                                 const T* __restrict__ out_grad,
+                                 T* __restrict__ x_grad,
+                                 FastDivMod seq_len,
+                                 FastDivMod num_heads,
+                                 uint32_t nope_head_dim,
+                                 uint32_t pe_head_dim,
+                                 uint32_t block_num) {
+  using VT = phi::kps::details::VectorType<T, VecSize>;
+  extern __shared__ T shm[];
+
+  const uint32_t block_idx = blockIdx.x * 8 + threadIdx.y;
+  if (block_idx >= block_num) return;
+  const uint32_t seq_idx = seq_len.Divmod(num_heads.Div(block_idx))[1];
+  const size_t block_offset =
+      static_cast<size_t>(block_idx) * (nope_head_dim + pe_head_dim);
+  T* const pe_buffer = shm + threadIdx.y * pe_head_dim;
+
+  // copy nope part
+  LOOP_WITH_SIZE_HINT(
+      i, threadIdx.x * VecSize, nope_head_dim, 32 * VecSize, NopeSize) {
+    size_t idx = block_offset + i;
+    *reinterpret_cast<VT*>(x_grad + idx) =
+        *reinterpret_cast<const VT*>(out_grad + idx);
+  }
+
+  // load pe part, apply embedding and transpose in shared memory
+  LOOP_WITH_SIZE_HINT(
+      i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
+    VT grad = *reinterpret_cast<const VT*>(out_grad + block_offset +
+                                           nope_head_dim + i);
+    VT grad_rot;
+    if (i < pe_head_dim / 2) {
+      grad_rot = *reinterpret_cast<const VT*>(
+          out_grad + block_offset + nope_head_dim + (i + pe_head_dim / 2));
+    } else {
+      grad_rot = *reinterpret_cast<const VT*>(
+          out_grad + block_offset + nope_head_dim + (i - pe_head_dim / 2));
+    }
+
+    VT cos_v = *reinterpret_cast<const VT*>(cos + seq_idx * pe_head_dim + i);
+    VT sin_v;
+    if (i < pe_head_dim / 2) {
+      sin_v = *reinterpret_cast<const VT*>(sin + seq_idx * pe_head_dim +
+                                           (i + pe_head_dim / 2));
+    } else {
+      sin_v = *reinterpret_cast<const VT*>(sin + seq_idx * pe_head_dim +
+                                           (i - pe_head_dim / 2));
+    }
+
+    for (uint32_t j = 0; j < VecSize; j++) {
+      uint32_t pe_idx = i + j;
+      if (pe_idx < pe_head_dim / 2) {
+        pe_buffer[pe_idx * 2] =
+            grad.val[j] * cos_v.val[j] + grad_rot.val[j] * sin_v.val[j];
+      } else {
+        pe_buffer[(pe_idx - pe_head_dim / 2) * 2 + 1] =
+            grad.val[j] * cos_v.val[j] - grad_rot.val[j] * sin_v.val[j];
+      }
+    }
+  }
+#ifdef PADDLE_WITH_HIP
+  __syncthreads();
+#else
+  __syncwarp();
+#endif
+
+  // store
+  LOOP_WITH_SIZE_HINT(
+      i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
+    VT tmp;
+    for (uint32_t j = 0; j < VecSize; j++) {
+      tmp.val[j] = pe_buffer[i + j];
+    }
+    *reinterpret_cast<VT*>(x_grad + block_offset + nope_head_dim + i) = tmp;
+  }
+}
+
+template <typename T, typename Context>
+void FusedPartialRoPEGradKernel(const Context& dev_ctx,
+                                const DenseTensor& cos,
+                                const DenseTensor& sin,
+                                const DenseTensor& out_grad,
+                                DenseTensor* x_grad) {
+  const auto x_dims = out_grad.dims();
+  const int64_t batch_size = x_dims[0];
+  const int64_t seq_len = x_dims[1];
+  const int64_t num_heads = x_dims[2];
+  const int64_t head_dim = x_dims[3];
+  const int64_t pe_head_dim = cos.dims()[3];
+  const int64_t nope_head_dim = head_dim - pe_head_dim;
+
+  // Allocate x_grad
+  dev_ctx.template Alloc<T>(x_grad);
+
+  if (batch_size == 0 || seq_len == 0 || num_heads == 0 || head_dim == 0) {
+    return;
+  }
+
+  // Launch kernel
+  int64_t block_num = batch_size * seq_len * num_heads;
+  dim3 grid((block_num + 7) / 8);
+  dim3 block(32, 8);
+  int64_t shm_size = block.y * pe_head_dim * sizeof(T);
+
+  auto kernel = [&]() {
+    SWITCH_ROPE_KERNEL(nope_head_dim, pe_head_dim, {
+      return rope_grad_kernel<T, VecSize, NopeSize, PeSize>;
+    });
+  }();
+
+  kernel<<<grid, block, shm_size, dev_ctx.stream()>>>(
+      cos.data<T>(),
+      sin.data<T>(),
+      out_grad.data<T>(),
+      x_grad->data<T>(),
+      static_cast<uint32_t>(seq_len),
+      static_cast<uint32_t>(num_heads),
+      static_cast<uint32_t>(nope_head_dim),
+      static_cast<uint32_t>(pe_head_dim),
+      static_cast<uint32_t>(block_num));
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_partial_rope_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedPartialRoPEGradKernel,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/fusion/gpu/fused_partial_rope_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_partial_rope_kernel.cu
new file mode 100644
index 00000000000000..fbf79347d7ae84
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/fused_partial_rope_kernel.cu
@@ -0,0 +1,138 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/fusion/gpu/fused_partial_rope_utils.h"
+
+namespace phi {
+namespace fusion {
+
+using FastDivMod = phi::funcs::FastDivMod;
+
+template <typename T, int VecSize, int NopeSize, int PeSize>
+__global__ void rope_kernel(const T* __restrict__ x,
+                            const T* __restrict__ cos,
+                            const T* __restrict__ sin,
+                            T* __restrict__ out,
+                            FastDivMod seq_len,
+                            FastDivMod num_heads,
+                            uint32_t nope_head_dim,
+                            uint32_t pe_head_dim,
+                            uint32_t block_num) {
+  using VT = phi::kps::details::VectorType<T, VecSize>;
+  extern __shared__ T shm[];
+
+  const uint32_t block_idx = blockIdx.x * 8 + threadIdx.y;
+  if (block_idx >= block_num) return;
+  const uint32_t seq_idx = seq_len.Divmod(num_heads.Div(block_idx))[1];
+  const size_t block_offset =
+      static_cast<size_t>(block_idx) * (nope_head_dim + pe_head_dim);
+  T* const pe_buffer = shm + threadIdx.y * pe_head_dim;
+
+  // copy nope part
+  LOOP_WITH_SIZE_HINT(
+      i, threadIdx.x * VecSize, nope_head_dim, 32 * VecSize, NopeSize) {
+    size_t idx = block_offset + i;
+    *reinterpret_cast<VT*>(out + idx) = *reinterpret_cast<const VT*>(x + idx);
+  }
+
+  // load pe part and transpose in shared memory
+  LOOP_WITH_SIZE_HINT(
+      i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
+    VT tmp = *reinterpret_cast<const VT*>(x + block_offset + nope_head_dim + i);
+    for (uint32_t j = 0; j < VecSize; j++) {
+      uint32_t pe_idx = i + j;
+      if (pe_idx % 2 == 0) {
+        pe_buffer[pe_idx / 2] = tmp.val[j];
+      } else {
+        pe_buffer[pe_idx / 2 + pe_head_dim / 2] = tmp.val[j];
+      }
+    }
+  }
+#ifdef PADDLE_WITH_HIP
+  __syncthreads();
+#else
+  __syncwarp();
+#endif
+
+  // apply embedding and store
+  LOOP_WITH_SIZE_HINT(
+      i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
+    VT cos_v = *reinterpret_cast<const VT*>(cos + seq_idx * pe_head_dim + i);
+    VT sin_v = *reinterpret_cast<const VT*>(sin + seq_idx * pe_head_dim + i);
+    VT tmp;
+    for (uint32_t j = 0; j < VecSize; j++) {
+      uint32_t pe_idx = i + j;
+      T x_pe = pe_buffer[pe_idx];
+      T x_pe_rot = (pe_idx < pe_head_dim / 2)
+                       ? -pe_buffer[pe_idx + pe_head_dim / 2]
+                       : pe_buffer[pe_idx - pe_head_dim / 2];
+      tmp.val[j] = (x_pe * cos_v.val[j]) + (x_pe_rot * sin_v.val[j]);
+    }
+    *reinterpret_cast<VT*>(out + block_offset + nope_head_dim + i) = tmp;
+  }
+}
+
+template <typename T, typename Context>
+void FusedPartialRoPEKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& cos,
+                            const DenseTensor& sin,
+                            DenseTensor* out) {
+  const auto x_dims = x.dims();
+  const int64_t batch_size = x_dims[0];
+  const int64_t seq_len = x_dims[1];
+  const int64_t num_heads = x_dims[2];
+  const int64_t head_dim = x_dims[3];
+  const int64_t pe_head_dim = cos.dims()[3];
+  const int64_t nope_head_dim = head_dim - pe_head_dim;
+
+  // Allocate out
+  dev_ctx.template Alloc<T>(out);
+
+  if (batch_size == 0 || seq_len == 0 || num_heads == 0 || head_dim == 0) {
+    return;
+  }
+
+  // Launch kernel
+  int64_t block_num = batch_size * seq_len * num_heads;
+  dim3 grid((block_num + 7) / 8);
+  dim3 block(32, 8);
+  int64_t shm_size = block.y * pe_head_dim * sizeof(T);
+
+  auto kernel = [&]() {
+    SWITCH_ROPE_KERNEL(nope_head_dim, pe_head_dim, {
+      return rope_kernel<T, VecSize, NopeSize, PeSize>;
+    });
+  }();
+
+  kernel<<<grid, block, shm_size, dev_ctx.stream()>>>(
+      x.data<T>(),
+      cos.data<T>(),
+      sin.data<T>(),
+      out->data<T>(),
+      static_cast<uint32_t>(seq_len),
+      static_cast<uint32_t>(num_heads),
+      static_cast<uint32_t>(nope_head_dim),
+      static_cast<uint32_t>(pe_head_dim),
+      static_cast<uint32_t>(block_num));
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_partial_rope,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedPartialRoPEKernel,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/fusion/gpu/fused_partial_rope_utils.h b/paddle/phi/kernels/fusion/gpu/fused_partial_rope_utils.h
new file mode 100644
index 00000000000000..3d5b6e3e970462
--- /dev/null
+++ b/paddle/phi/kernels/fusion/gpu/fused_partial_rope_utils.h
@@ -0,0 +1,85 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/primitive/datamover_primitives.h"
+
+#define SWITCH_NOPE_HEAD_DIM(__dim, ...) \
+  if (__dim == 32) {                     \
+    constexpr int NopeSize = 32;         \
+    { __VA_ARGS__ }                      \
+  } else if (__dim == 64) {              \
+    constexpr int NopeSize = 64;         \
+    { __VA_ARGS__ }                      \
+  } else if (__dim == 96) {              \
+    constexpr int NopeSize = 96;         \
+    { __VA_ARGS__ }                      \
+  } else if (__dim == 128) {             \
+    constexpr int NopeSize = 128;        \
+    { __VA_ARGS__ }                      \
+  } else {                               \
+    constexpr int NopeSize = 0;          \
+    { __VA_ARGS__ }                      \
+  }
+
+#define SWITCH_PE_HEAD_DIM(__dim, ...) \
+  if (__dim == 32) {                   \
+    constexpr int PeSize = 32;         \
+    { __VA_ARGS__ }                    \
+  } else if (__dim == 64) {            \
+    constexpr int PeSize = 64;         \
+    { __VA_ARGS__ }                    \
+  } else if (__dim == 96) {            \
+    constexpr int PeSize = 96;         \
+    { __VA_ARGS__ }                    \
+  } else if (__dim == 128) {           \
+    constexpr int PeSize = 128;        \
+    { __VA_ARGS__ }                    \
+  } else {                             \
+    constexpr int PeSize = 0;          \
+    { __VA_ARGS__ }                    \
+  }
+
+// Note: pe_head_dim must be divisible by twice the vector size.
+#define SWITCH_VEC_SIZE(__nope_head_dim, __pe_head_dim, ...)       \
+  if (__nope_head_dim % 4 == 0 && __nope_head_dim >= 128 &&        \
+      __pe_head_dim % 8 == 0 && __pe_head_dim >= 128) {            \
+    constexpr int VecSize = 4;                                     \
+    { __VA_ARGS__ }                                                \
+  } else if (__nope_head_dim % 2 == 0 && __nope_head_dim >= 64 &&  \
+             __pe_head_dim % 4 == 0 && __pe_head_dim >= 64) {      \
+    constexpr int VecSize = 2;                                     \
+    { __VA_ARGS__ }                                                \
+  } else {                                                         \
+    constexpr int VecSize = 1;                                     \
+    { __VA_ARGS__ }                                                \
+  }
+
+#define SWITCH_ROPE_KERNEL(__nope_head_dim, __pe_head_dim, ...)           \
+  SWITCH_NOPE_HEAD_DIM(                                                   \
+      __nope_head_dim,                                                    \
+      SWITCH_PE_HEAD_DIM(                                                 \
+          __pe_head_dim,                                                  \
+          SWITCH_VEC_SIZE(__nope_head_dim, __pe_head_dim, {__VA_ARGS__})))
+
+#define LOOP_WITH_SIZE_HINT(__index, __init, __size, __stride, __hint) \
+  for (uint32_t __index = (__init), __offset = 0;                      \
+       (__hint) > 0 ? __offset < (__hint) : __index < (__size);        \
+       __index += (__stride), __offset += (__stride))                  \
+    if ((__hint) == 0 || (__hint) % (__stride) == 0 ||                 \
+        __offset + (__stride) < (__hint) || __index < (__size))
diff --git a/paddle/phi/ops/yaml/fused_backward.yaml b/paddle/phi/ops/yaml/fused_backward.yaml
index 7a0f8239630af1..69544691c06dc7 100644
--- a/paddle/phi/ops/yaml/fused_backward.yaml
+++ b/paddle/phi/ops/yaml/fused_backward.yaml
@@ -65,6 +65,17 @@
   optional: x, intermediate_out
   no_need_buffer: x, y
 
+- backward_op : fused_partial_rope_grad
+  forward: fused_partial_rope (Tensor x, Tensor cos, Tensor sin) -> Tensor(out)
+  args : (Tensor cos, Tensor sin, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param: [out_grad]
+  kernel :
+    func : fused_partial_rope_grad
+  support_dygraph_mode : true
+
 - backward_op : fused_rotary_position_embedding_grad
   forward: fused_rotary_position_embedding (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style, bool time_major, float rotary_emb_base) -> Tensor(out_q), Tensor(out_k), Tensor(out_v)
   args : (Tensor sin, Tensor cos, Tensor position_ids, Tensor out_q_grad, Tensor out_k_grad,Tensor out_v_grad, bool use_neox_rotary_style, bool time_major, float rotary_emb_base)
diff --git a/paddle/phi/ops/yaml/fused_ops.yaml b/paddle/phi/ops/yaml/fused_ops.yaml
index 991b1ab8c0ab6d..0b22345aa1733a 100644
--- a/paddle/phi/ops/yaml/fused_ops.yaml
+++ b/paddle/phi/ops/yaml/fused_ops.yaml
@@ -430,6 +430,16 @@
     data_type : x
   optional : cache_kv, pre_caches, rotary_pos_emb, time_step, seq_lengths, src_mask, gather_index
 
+- op : fused_partial_rope
+  args: (Tensor x, Tensor cos, Tensor sin)
+  output: Tensor(out)
+  infer_meta:
+    func: FusedPartialRopeInferMeta
+  kernel:
+    func: fused_partial_rope
+  backward: fused_partial_rope_grad
+  support_dygraph_mode : true
+
 - op : fused_rotary_position_embedding
   args : (Tensor q, Tensor k, Tensor v, Tensor sin, Tensor cos, Tensor position_ids, bool use_neox_rotary_style = true, bool time_major = false, float rotary_emb_base = 10000.0)
   output : Tensor(out_q), Tensor(out_k), Tensor(out_v)
diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py
index 50eaca9dbf62ad..1b0f78e65da4f0 100644
--- a/python/paddle/incubate/nn/functional/__init__.py
+++ b/python/paddle/incubate/nn/functional/__init__.py
@@ -48,6 +48,7 @@
     fused_linear_activation,
     fused_matmul_bias,
 )
+from .fused_partial_rope import fused_partial_rope
 from .fused_rms_norm import fused_rms_norm
 from .fused_rms_norm_ext import fused_rms_norm_ext
 from .fused_rotary_position_embedding import fused_rotary_position_embedding
diff --git a/python/paddle/incubate/nn/functional/fused_partial_rope.py b/python/paddle/incubate/nn/functional/fused_partial_rope.py
new file mode 100644
index 00000000000000..edec341f95e6f5
--- /dev/null
+++ b/python/paddle/incubate/nn/functional/fused_partial_rope.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from paddle import _C_ops
+from paddle.framework import in_dynamic_or_pir_mode
+
+if TYPE_CHECKING:
+    from paddle import Tensor
+
+
+def fused_partial_rope(
+    x: Tensor,
+    cos: Tensor,
+    sin: Tensor,
+) -> Tensor:
+    r"""
+    Applies rotary position embedding to the trailing ``pe_head_dim`` elements of each attention head of the input, leaving the remaining elements unchanged.
+
+    Args:
+        x (Tensor): The input tensor. The data type is bfloat16. The shape of x must be [batch_size, seq_len, num_heads, head_dim].
+        cos (Tensor): The cosine position embedding tensor. The data type is bfloat16. The shape of cos must be [1, seq_len, 1, pe_head_dim], where pe_head_dim must be a multiple of 2 and must not exceed head_dim.
+        sin (Tensor): The sine position embedding tensor. The data type is bfloat16. The shape of sin must be [1, seq_len, 1, pe_head_dim], where pe_head_dim must be a multiple of 2 and must not exceed head_dim.
+
+    Returns:
+        out: Tensor with the rotary position embedding applied, having the same shape and data type as `x` .
+
+
+    Examples:
+
+        .. code-block:: python
+
+            >>> # doctest: +REQUIRES(env:GPU)
+            >>> import paddle
+            >>> from paddle.incubate.nn.functional import fused_partial_rope
+
+            >>> paddle.set_device('gpu')
+            >>> paddle.seed(2025)
+
+            >>> # x: [batch_size, seq_len, num_heads, head_dim]
+            >>> x = paddle.randn([2, 2, 2, 4], dtype='bfloat16')
+
+            >>> # sin, cos: [1, seq_len, 1, pe_head_dim]
+            >>> cos = paddle.randn([1, 2, 1, 2], dtype='bfloat16')
+            >>> sin = paddle.randn([1, 2, 1, 2], dtype='bfloat16')
+
+            >>> # out: [batch_size, seq_len, num_heads, head_dim]
+            >>> out = fused_partial_rope(x, cos, sin)
+            >>> print(out)
+            Tensor(shape=[2, 2, 2, 4], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True,
+                   [[[[-0.17968750,  0.28125000, -0.34765625, -0.92187500],
+                      [-0.83593750,  2.        , -0.13476562, -0.67187500]],
+                     [[ 0.38281250, -0.63281250,  0.25000000, -1.03125000],
+                      [-1.92187500,  2.12500000,  1.92968750, -4.21875000]]],
+                    [[[-0.90625000, -1.62500000, -0.22167969, -0.68359375],
+                      [-0.76562500,  0.23828125,  0.36523438,  0.53515625]],
+                     [[ 0.92578125, -0.85156250, -0.75000000,  1.50000000],
+                      [ 0.41992188, -1.13281250,  0.73437500, -2.18750000]]]])
+    """
+    if in_dynamic_or_pir_mode():
+        return _C_ops.fused_partial_rope(x, cos, sin)
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index ceaf163d39329e..5d2bbf3721c3ac 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -74,6 +74,7 @@ if(NOT WITH_GPU)
   list(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
   list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_op)
   list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
+  list(REMOVE_ITEM TEST_OPS test_fused_partial_rope_op)
   list(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
   list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
   list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
diff --git a/test/legacy_test/test_fused_partial_rope_op.py b/test/legacy_test/test_fused_partial_rope_op.py
new file mode 100644
index 00000000000000..162cb5e5349ab2
--- /dev/null
+++ b/test/legacy_test/test_fused_partial_rope_op.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.incubate.nn.functional import fused_partial_rope
+
+
+def fused_partial_rope_ref(x, cos, sin):
+    x_nope = x[..., : -cos.shape[-1]]
+    x_pe = x[..., -cos.shape[-1] :]
+
+    b, s, h, d = x_pe.shape  # [bs, seq_len, num_heads, pe_head_dim]
+    x_pe = (
+        x_pe.reshape([b, s, h, d // 2, 2])
+        .transpose([0, 1, 2, 4, 3])
+        .reshape([b, s, h, d])
+    )
+
+    cos = cos[:, :s, :, :]  # [1, seq_len, 1, pe_head_dim]
+    sin = sin[:, :s, :, :]
+
+    x1 = x_pe[..., : x_pe.shape[-1] // 2]
+    x2 = x_pe[..., x_pe.shape[-1] // 2 :]
+    x_pe_rotate_half = paddle.concat([-x2, x1], axis=-1)
+
+    x_pe = (x_pe * cos) + (x_pe_rotate_half * sin)
+
+    return paddle.concat([x_nope, x_pe], axis=-1)
+
+
+class TestFusedPartialRoPEOp(unittest.TestCase):
+    def eval(self, batch_size, seq_len, num_heads, head_dim, pe_head_dim):
+        x = paddle.randn([batch_size, seq_len, num_heads, head_dim], 'bfloat16')
+        x.stop_gradient = False
+        x_ref = paddle.clone(x).detach()
+        x_ref.stop_gradient = False
+
+        cos = paddle.randn([1, seq_len, 1, pe_head_dim], 'bfloat16')
+        sin = paddle.randn_like(cos)
+
+        # Test forward
+        out = fused_partial_rope(x, cos, sin)
+        out_ref = fused_partial_rope_ref(x_ref, cos, sin)
+
+        np.testing.assert_allclose(
+            out.astype('float32'), out_ref.astype('float32')
+        )
+
+        # Test backward
+        out_grad = paddle.randn_like(out)
+        paddle.autograd.backward([out], [out_grad])
+        paddle.autograd.backward([out_ref], [out_grad])
+
+        np.testing.assert_allclose(
+            x.grad.astype('float32'), x_ref.grad.astype('float32')
+        )
+
+    def test_0_size_in_batch_size(self):
+        self.eval(0, 32, 64, 128, 64)
+
+    def test_0_size_in_seq_len(self):
+        self.eval(32, 0, 64, 128, 64)
+
+    def test_all_pe_head_dim(self):
+        self.eval(1, 8, 1, 128, 128)
+
+    def test_medium_1x_vec(self):
+        self.eval(1, 8, 16, 75, 50)
+
+    def test_medium_2x_vec(self):
+        self.eval(4, 1, 16, 200, 100)
+
+    def test_medium_4x_vec(self):
+        self.eval(2, 4, 8, 192, 64)
+
+    def test_large(self):
+        self.eval(1, 2, 16, 1024, 384)
+
+
+if __name__ == "__main__":
+    unittest.main()
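
Reviewer note: the sketch below restates in plain NumPy the computation that both CUDA kernels implement and that `fused_partial_rope_ref` above checks. The trailing `pe_head_dim` channels of each head are read in interleaved (even/odd) order, de-interleaved into a [first-half, second-half] layout, and then rotate-half RoPE is applied; the leading `nope_head_dim` channels pass through unchanged. The helper name `partial_rope_numpy` and the toy shapes are illustrative only and are not part of the patch.

import numpy as np


def partial_rope_numpy(x, cos, sin):
    # x: [batch, seq, heads, head_dim]; cos/sin: [1, seq, 1, pe_head_dim],
    # where pe_head_dim is even and no larger than head_dim.
    pe = cos.shape[-1]
    x_nope, x_pe = x[..., :-pe], x[..., -pe:]

    # De-interleave: even channels go to the first half, odd to the second.
    b, s, h, d = x_pe.shape
    x_pe = x_pe.reshape(b, s, h, d // 2, 2).transpose(0, 1, 2, 4, 3)
    x_pe = x_pe.reshape(b, s, h, d)

    # Rotate-half RoPE applied to the pe slice only.
    x1, x2 = x_pe[..., : d // 2], x_pe[..., d // 2 :]
    x_pe = x_pe * cos + np.concatenate([-x2, x1], axis=-1) * sin

    return np.concatenate([x_nope, x_pe], axis=-1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 3, 4, 8)).astype(np.float32)
    cos = rng.standard_normal((1, 3, 1, 4)).astype(np.float32)
    sin = rng.standard_normal((1, 3, 1, 4)).astype(np.float32)
    print(partial_rope_numpy(x, cos, sin).shape)  # (2, 3, 4, 8)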