Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit fd4ae7d

Browse files
authored
fix a bug in fused rope (NVIDIA#1750)
Signed-off-by: Xin Yao <[email protected]>
1 parent 08f7402 commit fd4ae7d

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

csrc/megatron/fused_rotary_positional_embedding.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ __global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2,
5252
int offset_head = offset_block + head_id * hn;
5353
#pragma unroll
5454
for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
55-
int offset_src_dst = offset_head + hn_id;
56-
dst[offset_src_dst] = src[offset_src_dst];
55+
dst[offset_head + hn_id] = src[offset_head + hn_id];
5756
}
5857
}
5958
}
@@ -89,7 +88,7 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
8988
int offset_head = offset_block + head_id * hn;
9089
#pragma unroll
9190
for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
92-
dst[offset_head + hn_id] = 1.0;
91+
dst[offset_head + hn_id] = src[offset_head + hn_id];
9392
}
9493
}
9594
}

tests/L0/run_transformer/test_fused_rope.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,15 @@ def test_forward_backward(self):
8484

8585
# unfused
8686
output_unfused = apply_rotary_pos_emb(t, emb)
87-
output_unfused.sum().backward()
87+
loss_unfused = output_unfused.sum() * 2
88+
loss_unfused.backward()
8889
grad_unfused = t.grad.detach().clone()
8990
t.grad = None
9091

9192
# fused
9293
output_fused = fused_apply_rotary_pos_emb(t, emb)
93-
output_fused.sum().backward()
94+
loss_fused = output_fused.sum() * 2
95+
loss_fused.backward()
9496
grad_fused = t.grad.detach().clone()
9597

9698
self.assertEqual(

0 commit comments

Comments
 (0)