Commit a12c0d8

Merge branch 'boyanl/silu_and_mul' into 'main'
add silu_and_mul kernel for MoE benchmark and sample

See merge request dl/tileir/cutile-python!42

2 parents: 7ea97fb + 90d05f0

File tree: 6 files changed (+153, -30 lines)

samples/MoE.py

Lines changed: 58 additions & 5 deletions
@@ -11,6 +11,7 @@

 
 ConstInt = ct.Constant[int]
+ConstBool = ct.Constant[bool]


 @ct.kernel
@@ -22,7 +23,7 @@ def fused_moe_kernel(
     sorted_token_ids,
     sorted_expert_ids,
     num_token_replicas: int,
-    mul_routed_weight: bool,
+    mul_routed_weight: ConstBool,
     TILE_M: ConstInt,
     TILE_N: ConstInt,
     TILE_K: ConstInt,
@@ -84,9 +85,30 @@ def fused_moe_kernel(
     ct.scatter(C, (token_ids[:, None], c_col_indices[None, :]), accumulator)


-def silu_and_mul_torch(input: torch.Tensor, out: torch.Tensor):
-    gate_result, up_result = input.chunk(2, dim=-1)
-    torch.mul(F.silu(gate_result), up_result, out=out)
+@ct.kernel
+def silu_and_mul_kernel(A, B, C, TILE_N: ConstInt):
+    """
+    Element-wise kernel that computes SiLU(A) * B.
+
+    Args:
+        A: Input tensor A.
+        B: Input tensor B.
+        C: Output tensor.
+    """
+
+    bid_m = ct.bid(0)
+    ta = ct.load(A, (bid_m, 0), (1, TILE_N)).astype(ct.float32)
+    tb = ct.load(B, (bid_m, 0), (1, TILE_N)).astype(ct.float32)
+
+    # Sigmoid(ta)
+    denom = ct.add(1, ct.exp(-ta), flush_to_zero=True)
+    sigmoid_ta = ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=ct.RoundingMode.APPROX)
+
+    # SiLU(ta) * tb
+    silu_ta = ct.mul(ta, sigmoid_ta, flush_to_zero=True)
+    tc = ct.mul(silu_ta, tb, flush_to_zero=True)
+
+    ct.store(C, (bid_m, 0), tc.astype(C.dtype))


 def moe_align_tile_size_torch(
@@ -244,7 +266,7 @@ def cutile_moe(
         tile_k=tile_k,
     )

-    silu_and_mul_torch(
+    invoke_silu_and_mul_kernel(
         intermediate_cache1.view(-1, intermediate_cache1.shape[-1]),
         intermediate_cache2,
     )
@@ -353,6 +375,37 @@ def invoke_fused_moe_kernel(
     )


+def invoke_silu_and_mul_kernel(
+    AB: torch.Tensor,
+    C: torch.Tensor
+):
+    A, B = AB.chunk(2, dim=-1)
+    ct.launch(
+        torch.cuda.current_stream(),
+        (AB.shape[0],),
+        silu_and_mul_kernel,
+        (
+            A,
+            B,
+            C,
+            next_power_of_2(C.shape[-1])
+        )
+    )
+
+
+def next_power_of_2(n: int):
+    """Return the smallest power of 2 greater than or equal to n"""
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n |= n >> 32
+    n += 1
+    return n
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
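Note: the removed silu_and_mul_torch helper can still serve as a host-side reference for checking the new kernel. A minimal sketch follows; silu_and_mul_reference, the test sizes, and the tolerances are hypothetical, and it assumes invoke_silu_and_mul_kernel from this file is in scope on a CUDA device:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(ab: torch.Tensor) -> torch.Tensor:
    # Same math as the removed silu_and_mul_torch: split the fused
    # gate/up projection along the last dim and compute SiLU(gate) * up.
    gate, up = ab.chunk(2, dim=-1)
    return F.silu(gate) * up

# Hypothetical smoke test against the kernel wrapper added above.
ab = torch.randn(128, 2 * 4096, device="cuda", dtype=torch.float16)
out = torch.empty(128, 4096, device="cuda", dtype=torch.float16)
invoke_silu_and_mul_kernel(ab, out)
torch.testing.assert_close(out, silu_and_mul_reference(ab), rtol=2e-2, atol=2e-2)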

samples/templates/MoE.py

Lines changed: 33 additions & 2 deletions
@@ -9,7 +9,7 @@
 import torch.nn.functional as F
 import cuda.tile as ct

-from test.kernels.fused_moe import fused_moe_kernel, moe_align_tile_size_torch, silu_and_mul_torch
+from test.kernels.fused_moe import fused_moe_kernel, moe_align_tile_size_torch, silu_and_mul_kernel


 # --- cuTile MoE Wrapper ------------------------------------------------------
@@ -85,7 +85,7 @@ def cutile_moe(
         tile_k=tile_k,
     )

-    silu_and_mul_torch(
+    invoke_silu_and_mul_kernel(
         intermediate_cache1.view(-1, intermediate_cache1.shape[-1]),
         intermediate_cache2,
     )
@@ -194,6 +194,37 @@ def invoke_fused_moe_kernel(
     )


+def invoke_silu_and_mul_kernel(
+    AB: torch.Tensor,
+    C: torch.Tensor
+):
+    A, B = AB.chunk(2, dim=-1)
+    ct.launch(
+        torch.cuda.current_stream(),
+        (AB.shape[0],),
+        silu_and_mul_kernel,
+        (
+            A,
+            B,
+            C,
+            next_power_of_2(C.shape[-1])
+        )
+    )
+
+
+def next_power_of_2(n: int):
+    """Return the smallest power of 2 greater than or equal to n"""
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n |= n >> 32
+    n += 1
+    return n
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
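A note on the launch configuration shared by both sample files: invoke_silu_and_mul_kernel launches one block per row of AB and rounds the output width up to a power of two, so a single tile spans the whole row. A rough shape sketch with hypothetical sizes:

import torch

M, N = 1024, 14336                      # hypothetical: rows = tokens * top_k, N = intermediate size
AB = torch.empty(M, 2 * N, device="cuda", dtype=torch.float16)   # fused gate/up projection
C = torch.empty(M, N, device="cuda", dtype=torch.float16)

A, B = AB.chunk(2, dim=-1)              # gate half and up half, each (M, N)
grid = (AB.shape[0],)                   # one block per row
tile_n = next_power_of_2(C.shape[-1])   # 14336 -> 16384: one tile covers the row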

test/bench_moe.py

Lines changed: 21 additions & 3 deletions
@@ -9,8 +9,8 @@
 import cuda.tile as ct

 from conftest import dtype_id, shape_id
-from util import estimate_bench_iter
-from kernels.fused_moe import fused_moe_kernel, silu_and_mul_torch, moe_align_tile_size_torch
+from util import estimate_bench_iter, next_power_of_2
+from kernels.fused_moe import fused_moe_kernel, silu_and_mul_kernel, moe_align_tile_size_torch


 @pytest.fixture(params=[
@@ -186,7 +186,7 @@ def cutile_moe(
         tile_k,
     )

-    silu_and_mul_torch(
+    invoke_silu_and_mul_kernel(
         intermediate_cache1.view(-1, intermediate_cache1.shape[-1]),
         intermediate_cache2,
     )
@@ -247,3 +247,21 @@ def invoke_fused_moe_kernel(
             tile_k,
         )
     )
+
+
+def invoke_silu_and_mul_kernel(
+    AB: torch.Tensor,
+    C: torch.Tensor
+):
+    A, B = AB.chunk(2, dim=-1)
+    ct.launch(
+        torch.cuda.current_stream(),
+        (AB.shape[0],),
+        silu_and_mul_kernel,
+        (
+            A,
+            B,
+            C,
+            next_power_of_2(C.shape[-1])
+        )
+    )

test/bench_rms_norm.py

Lines changed: 1 addition & 14 deletions
@@ -8,26 +8,13 @@
 import torch
 import cuda.tile as ct
 from math import ceil
-from util import estimate_bench_iter
+from util import estimate_bench_iter, next_power_of_2
 from kernels.rms_norm import (
     rms_norm_kernel, rms_norm_kernel_gather, rms_norm_kernel_static_persistent
 )
 from autotuner.autotuner import Autotuner, Config, SearchSpace, autotune


-def next_power_of_2(n: int):
-    """Return the smallest power of 2 greater than or equal to n"""
-    n -= 1
-    n |= n >> 1
-    n |= n >> 2
-    n |= n >> 4
-    n |= n >> 8
-    n |= n >> 16
-    n |= n >> 32
-    n += 1
-    return n
-
-
 @pytest.fixture(params=[
     (262144, 1024),
     (262144, 2048),

test/kernels/fused_moe.py

Lines changed: 27 additions & 6 deletions
@@ -4,11 +4,11 @@

 import cuda.tile as ct
 import torch
-import torch.nn.functional as F

 from kernels.matmul import swizzle_2d

 ConstInt = ct.Constant[int]
+ConstBool = ct.Constant[bool]


 @ct.kernel
@@ -20,7 +20,7 @@ def fused_moe_kernel(
     sorted_token_ids,
     sorted_expert_ids,
     num_token_replicas: int,
-    mul_routed_weight: bool,
+    mul_routed_weight: ConstBool,
     TILE_M: ConstInt,
     TILE_N: ConstInt,
     TILE_K: ConstInt,
@@ -82,12 +82,33 @@ def fused_moe_kernel(
     ct.scatter(C, (token_ids[:, None], c_col_indices[None, :]), accumulator)


-# -- PyTorch Utilities --
+@ct.kernel
+def silu_and_mul_kernel(A, B, C, TILE_N: ConstInt):
+    """
+    Element-wise kernel that computes SiLU(A) * B.
+
+    Args:
+        A: Input tensor A.
+        B: Input tensor B.
+        C: Output tensor.
+    """

-def silu_and_mul_torch(input: torch.Tensor, out: torch.Tensor):
-    gate_result, up_result = input.chunk(2, dim=-1)
-    torch.mul(F.silu(gate_result), up_result, out=out)
+    bid_m = ct.bid(0)
+    ta = ct.load(A, (bid_m, 0), (1, TILE_N)).astype(ct.float32)
+    tb = ct.load(B, (bid_m, 0), (1, TILE_N)).astype(ct.float32)

+    # Sigmoid(ta)
+    denom = ct.add(1, ct.exp(-ta), flush_to_zero=True)
+    sigmoid_ta = ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=ct.RoundingMode.APPROX)
+
+    # SiLU(ta) * tb
+    silu_ta = ct.mul(ta, sigmoid_ta, flush_to_zero=True)
+    tc = ct.mul(silu_ta, tb, flush_to_zero=True)
+
+    ct.store(C, (bid_m, 0), tc.astype(C.dtype))
+
+
+# -- PyTorch Utilities --

 def moe_align_tile_size_torch(
     topk_ids: torch.Tensor, tile_m: int, num_experts: int
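The kernel builds the sigmoid by hand from exp rather than calling a fused activation, evaluating SiLU(a) * b = a / (1 + exp(-a)) * b in float32. A plain PyTorch sketch of the same per-tile arithmetic (a hypothetical reference, not the cuTile code):

import torch

def silu_and_mul_tile_reference(ta: torch.Tensor, tb: torch.Tensor) -> torch.Tensor:
    # Mirrors the kernel arithmetic in float32:
    # sigmoid(ta) = 1 / (1 + exp(-ta)); SiLU(ta) = ta * sigmoid(ta)
    ta32, tb32 = ta.float(), tb.float()
    sigmoid_ta = 1.0 / (1.0 + torch.exp(-ta32))
    return (ta32 * sigmoid_ta) * tb32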

test/util.py

Lines changed: 13 additions & 0 deletions
@@ -196,3 +196,16 @@ def get_int_dtype_of_same_size(t: torch.dtype) -> torch.dtype:
         case torch.int16: return torch.int16
         case torch.int8: return torch.int8
         case _: raise NotImplementedError()
+
+
+def next_power_of_2(n: int):
+    """Return the smallest power of 2 greater than or equal to n"""
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n |= n >> 32
+    n += 1
+    return n
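This helper is the bit-twiddling version that previously lived in bench_rms_norm.py, now shared via test/util.py. For reference, an equivalent alternative using int.bit_length (a sketch, not what this patch uses) would be:

def next_power_of_2_alt(n: int) -> int:
    # Smallest power of 2 >= n, valid for n >= 1.
    return 1 << (n - 1).bit_length()

assert next_power_of_2_alt(1) == 1
assert next_power_of_2_alt(5) == 8
assert next_power_of_2_alt(4096) == 4096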
