* [ ] Refactor
  * Function: multi_head_attention (see the sketch below)
    * ref: https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#how-to-use-flashattention
    * ref: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
  * Op: MultiHeadAttentionFwdOp
  * Kernel
* [ ] Test
* [ ] Benchmark
  * Baselines: flash-attention, triton
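
A minimal sketch of what the multi_head_attention forward could look like when built on the referenced torch.nn.functional.scaled_dot_product_attention API; the input layout (batch, seq, embed_dim), the absence of masking/dropout, and the num_heads argument are assumptions for illustration, not the project's actual signature.

```python
import torch
import torch.nn.functional as F


def multi_head_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         num_heads: int) -> torch.Tensor:
    """Forward-only MHA sketch: q/k/v are (batch, seq, embed_dim), no mask/dropout."""
    batch, seq_q, embed_dim = q.shape
    head_dim = embed_dim // num_heads

    def split_heads(x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, embed_dim) -> (batch, num_heads, seq, head_dim)
        b, s, _ = x.shape
        return x.view(b, s, num_heads, head_dim).transpose(1, 2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # Fused attention; PyTorch dispatches to a FlashAttention-style kernel when eligible.
    out = F.scaled_dot_product_attention(q, k, v)

    # Merge heads back: (batch, num_heads, seq, head_dim) -> (batch, seq, embed_dim)
    return out.transpose(1, 2).reshape(batch, seq_q, embed_dim)


if __name__ == "__main__":
    q = torch.randn(2, 128, 64)
    k = torch.randn(2, 128, 64)
    v = torch.randn(2, 128, 64)
    print(multi_head_attention(q, k, v, num_heads=8).shape)  # torch.Size([2, 128, 64])
```

For the benchmark item, the same call shape lines up with the flash-attention and Triton baselines, since both expect per-head tensors; the planned MultiHeadAttentionFwdOp would presumably wrap the split/merge plus the fused kernel in one op.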