 
 
 ConstInt = ct.Constant[int]
+ConstBool = ct.Constant[bool]
 
 
 @ct.kernel
@@ -22,7 +23,7 @@ def fused_moe_kernel(
     sorted_token_ids,
     sorted_expert_ids,
     num_token_replicas: int,
-    mul_routed_weight: bool,
+    mul_routed_weight: ConstBool,
     TILE_M: ConstInt,
     TILE_N: ConstInt,
     TILE_K: ConstInt,
@@ -84,9 +85,30 @@ def fused_moe_kernel(
     ct.scatter(C, (token_ids[:, None], c_col_indices[None, :]), accumulator)
 
 
-def silu_and_mul_torch(input: torch.Tensor, out: torch.Tensor):
-    gate_result, up_result = input.chunk(2, dim=-1)
-    torch.mul(F.silu(gate_result), up_result, out=out)
+@ct.kernel
+def silu_and_mul_kernel(A, B, C, TILE_N: ConstInt):
+    """
+    Element-wise kernel that computes SiLU(A) * B.
+
+    Args:
+        A: Input tensor A.
+        B: Input tensor B.
+        C: Output tensor.
+    """
+
+    bid_m = ct.bid(0)
+    ta = ct.load(A, (bid_m, 0), (1, TILE_N)).astype(ct.float32)
+    tb = ct.load(B, (bid_m, 0), (1, TILE_N)).astype(ct.float32)
+
+    # Sigmoid(ta)
+    denom = ct.add(1, ct.exp(-ta), flush_to_zero=True)
+    sigmoid_ta = ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=ct.RoundingMode.APPROX)
+
+    # SiLU(ta) * tb
+    silu_ta = ct.mul(ta, sigmoid_ta, flush_to_zero=True)
+    tc = ct.mul(silu_ta, tb, flush_to_zero=True)
+
+    ct.store(C, (bid_m, 0), tc.astype(C.dtype))
 
 
 def moe_align_tile_size_torch(
@@ -244,7 +266,7 @@ def cutile_moe(
         tile_k=tile_k,
     )
 
-    silu_and_mul_torch(
+    invoke_silu_and_mul_kernel(
         intermediate_cache1.view(-1, intermediate_cache1.shape[-1]),
         intermediate_cache2,
     )
@@ -353,6 +375,37 @@ def invoke_fused_moe_kernel(
     )
 
 
+def invoke_silu_and_mul_kernel(
+    AB: torch.Tensor,
+    C: torch.Tensor
+):
+    A, B = AB.chunk(2, dim=-1)
+    ct.launch(
+        torch.cuda.current_stream(),
+        (AB.shape[0],),
+        silu_and_mul_kernel,
+        (
+            A,
+            B,
+            C,
+            next_power_of_2(C.shape[-1])
+        )
+    )
+
+
+def next_power_of_2(n: int):
+    """Return the smallest power of 2 greater than or equal to n."""
+    n -= 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n |= n >> 32
+    n += 1
+    return n
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(