NNop.jl

Flash Attention & friends in pure Julia.

GPU backends with CI: AMDGPU, CUDA.

Kernels (with ChainRules.jl integration):

  • Flash Attention
  • Fused Softmax
  • Fused RMS Norm
  • Fused Layer Norm

Benchmarking

See benchmarks/main.jl for comparison scripts between naïve & fused versions.

Flash Attention

Implementation of FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.

using AMDGPU, Zygote
using NNop

# Embedding dim, sequence length, number of heads, batch size.
E, L, H, B = 64, 4096, 4, 4
causal = false

q = ROCArray(rand(Float32, E, L, H, B))
k = ROCArray(rand(Float32, E, L, H, B))
v = ROCArray(rand(Float32, E, L, H, B))

o = NNop.flash_attention(q, k, v; causal)
∇ = Zygote.gradient(q, k, v) do q, k, v
    sum(NNop.flash_attention(q, k, v; causal))
end
                     Naïve attention   Flash Attention
FWD
  Execution time     60.987 ms         18.380 ms
  Peak memory usage  5.044 GiB         16.500 MiB
FWD + BWD
  Execution time     1.154 s           306.960 ms
  Peak memory usage  19.164 GiB        80.813 MiB

Features:

  • Forward & backward passes.
  • Arbitrary sequence length.
  • FP32, FP16, BF16 support.
  • Variable sequence length.
  • Causal masking.
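
For context, the naïve baseline in the table above corresponds to standard scaled dot-product attention, which materializes the full (L, L) score matrix per head. A minimal sketch of such a baseline (an illustration only, not the benchmark's exact code, which lives in benchmarks/main.jl):

function naive_attention(q, k, v; causal::Bool=false)
    E, L, H, B = size(q)
    scale = inv(sqrt(Float32(E)))
    o = similar(q)
    for b in 1:B, h in 1:H
        Q, K, V = view(q, :, :, h, b), view(k, :, :, h, b), view(v, :, :, h, b)
        s = (K' * Q) .* scale                    # full (L, L) score matrix
        if causal
            # Mask keys that come after their query position.
            s = ifelse.((1:L) .> (1:L)', -Inf32, s)
        end
        p = exp.(s .- maximum(s; dims=1))        # numerically stable softmax
        p = p ./ sum(p; dims=1)
        o[:, :, h, b] = V * p
    end
    return o
end

That O(L²) intermediate is what accounts for the 5 GiB vs 16.5 MiB forward-pass gap above.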

Fused Softmax

Implementation of Online normalizer calculation for softmax.

x = ROCArray(rand(Float32, 8192, 1024))
y = NNop.online_softmax(x)
                     Naïve Softmax   Online Softmax
  Execution time     745.123 μs      61.600 μs
  Peak memory usage  64.258 MiB      32.000 MiB
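
The core idea, sketched on a plain vector: a single pass keeps a running maximum m and normalizer d, rescaling d whenever the maximum grows, so no separate max pass over the data is needed. (The function below is illustrative; the fused kernel performs this reduction on the GPU.)

function online_softmax_reference(x::AbstractVector{T}) where {T<:AbstractFloat}
    m, d = typemin(T), zero(T)
    for xi in x
        m_new = max(m, xi)
        d = d * exp(m - m_new) + exp(xi - m_new)  # rescale old sum, add new term
        m = m_new
    end
    return exp.(x .- m) ./ d
end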

Fused RMS Norm

x = ROCArray(rand(Float32, 1024, 1024))
w = ROCArray(rand(Float32, 1024))
y = NNop.rms_norm(x, w)
∇ = Zygote.gradient(x, w) do x, w
    sum(NNop.rms_norm(x, w))
end
                     Naïve RMS Norm   Fused RMS Norm
FWD
  Execution time     171.124 μs       48.432 μs
  Peak memory usage  8.004 MiB        4.004 MiB
FWD + BWD
  Execution time     902.919 μs       241.838 μs
  Peak memory usage  44.043 MiB       13.008 MiB
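
For reference, a non-fused RMSNorm over the first (feature) dimension looks like the following sketch; the eps value here is an assumption, not necessarily NNop's default:

using Statistics

rms_norm_reference(x, w; eps=1f-6) =
    x ./ sqrt.(mean(abs2, x; dims=1) .+ eps) .* w

The fused kernel computes the same result in a single pass, avoiding the intermediate allocations the broadcast version incurs.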

Fused Layer Norm

x = ROCArray(rand(Float32, 1024, 1024))
w = ROCArray(rand(Float32, 1024))
b = ROCArray(rand(Float32, 1024))
y = NNop.layer_norm(x, w, b)
∇ = Zygote.gradient(x, w, b) do x, w, b
    sum(NNop.layer_norm(x, w, b))
end
                     Naïve Layer Norm   Fused Layer Norm
FWD
  Execution time     188.392 μs         48.175 μs
  Peak memory usage  4.008 MiB          4.004 MiB
FWD + BWD
  Execution time     1.150 ms           293.969 μs
  Peak memory usage  52.055 MiB         14.016 MiB
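
And the corresponding non-fused LayerNorm sketch, normalizing over the feature dimension (eps again an assumed default, not necessarily NNop's):

using Statistics

function layer_norm_reference(x, w, b; eps=1f-6)
    mu = mean(x; dims=1)
    s2 = var(x; dims=1, mean=mu, corrected=false)
    return (x .- mu) ./ sqrt.(s2 .+ eps) .* w .+ b
end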
