1+ """
2+ Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
3+ ================================================================================
4+
5+ """
6+
7+
8+ ######################################################################
9+ # Summary
10+ # ~~~~~~~~
11+ #
12+ # In this tutorial, we want to highlight a new ``torch.nn.functional`` function
13+ # that can be helpful for implementing transformer architectures. The
14+ # function is named ``torch.nn.functional.scaled_dot_product_attention``.
15+ # For detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
16+ # This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
#
# Overview
# ~~~~~~~~
# At a high level, this PyTorch function calculates the
# scaled dot product attention (SDPA) between query, key, and value according to
# the definition found in the paper `Attention is all you
# need <https://arxiv.org/abs/1706.03762>`__. While this function can
# be written in PyTorch using existing functions, a fused implementation can provide
# large performance benefits over a naive implementation.
#
# Fused implementations
# ~~~~~~~~~~~~~~~~~~~~~
#
# For CUDA tensor inputs, the function will dispatch into one of the following
# implementations:
#
# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
# * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
# * A PyTorch implementation defined in C++
#

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
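
# As mentioned above, the same computation can be expressed with existing
# PyTorch ops. The following is a minimal sketch of the math formulation from
# the paper (no masking, dropout, or numerical tricks); it is a reference
# sketch, not the code the fused kernels actually run:
import math

def naive_scaled_dot_product_attention(query, key, value):
    # Scale by the inverse square root of the head dimension
    scale_factor = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1)
    return attn_weight @ value

# Sanity-check the sketch against the built-in (loose tolerances, since the
# dispatched kernel may use a different reduction order):
torch.testing.assert_close(
    naive_scaled_dot_product_attention(query, key, value),
    F.scaled_dot_product_attention(query, key, value),
    atol=1e-3, rtol=1e-3,
)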


######################################################################
# Explicit Dispatcher Control
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While the function will implicitly dispatch to one of the three
# implementations, the user can also explicitly control the dispatch via
# the use of a context manager. This context manager allows users to
# explicitly disable certain implementations. If a user wants to ensure
# the function is indeed using the fastest implementation for their
# specific inputs, the context manager can be used to sweep through each
# implementation and measure its performance.
#

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arg mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")


with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")


######################################################################
# Hardware dependence
# ~~~~~~~~~~~~~~~~~~~
#
# Depending on what machine you ran the above cell on and what hardware is
# available, your results might be different.
#
# - If you don’t have a GPU and are running on CPU, then the context manager
#   will have no effect and all three runs should return similar timings.
# - Depending on the compute capability your graphics card supports,
#   flash attention or memory-efficient attention might have failed.
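
# As a quick check, you can print the compute capability of your GPU. The
# fused kernels target recent GPU architectures, so this is a useful
# heuristic (not an authoritative support check):
if device == "cuda":
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU compute capability: sm_{major}{minor}")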


######################################################################
# Causal Self Attention
# ~~~~~~~~~~~~~~~~~~~~~
#
# Below is an example implementation of a multi-headed causal self
# attention block inspired by Andrej Karpathy’s
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
#

class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
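
# A quick smoke test on a small random input; the shapes here are chosen
# purely for illustration:
sanity_input = torch.rand(2, 16, embed_dimension, device="cuda", dtype=dtype)
print(model(sanity_input).shape)  # expected: torch.Size([2, 16, 512])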


######################################################################
# NestedTensor and Dense tensor support
# -------------------------------------
#
# SDPA supports both ``NestedTensor`` and Dense tensor inputs. ``NestedTensors`` handle the case where the input is a batch of variable length sequences
# without needing to pad each sequence to the maximum length in the batch. For more information about ``NestedTensors`` see
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and the `NestedTensors Tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`__.
#

import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension,
                            dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support NestedTensor for training
model.eval()

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")


######################################################################
# Using SDPA with ``torch.compile``
# =================================
#
# With the release of PyTorch 2.0, a new feature called
# ``torch.compile()`` has been introduced, which can provide
# significant performance improvements over eager mode.
# Scaled dot product attention is fully composable with ``torch.compile()``.
# To demonstrate this, let's compile the ``CausalSelfAttention`` module using
# ``torch.compile()`` and observe the resulting performance improvements.
#

batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len,
               embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")


compiled_model = torch.compile(model)
# Let's warm up the compiled model; compilation happens on the first call
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")


######################################################################
#
# The exact execution time depends on the machine; however, here are the results
# on mine::
#
#    The non compiled module runs in 166.616 microseconds
#    The compiled module runs in 166.726 microseconds
#
# That is not what we were expecting. Let's dig a little deeper.
# PyTorch comes with an amazing built-in profiler that you can use to
# inspect the performance characteristics of your code.
#

from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results:
# prof.export_chrome_trace("compiled_causal_attention_trace.json")



######################################################################
# The previous code snippet generates a report of the top 10 PyTorch functions
# that consumed the most GPU execution time, for both the compiled and non-compiled module.
# The analysis reveals that the majority of time spent on the GPU is concentrated
# on the same set of functions for both modules.
# The reason for this here is that ``torch.compile`` is very good at removing the
# framework overhead associated with PyTorch. If your model is launching
# large, efficient CUDA kernels, which in this case ``CausalSelfAttention``
# is, then the overhead of PyTorch can be hidden.
#
# In reality, your module does not normally consist of a singular
# ``CausalSelfAttention`` block. When experimenting with Andrej Karpathy’s
# `NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
# the module took the time per train step from ``6090.49ms`` to
# ``3273.17ms``! This was done on commit ``ae3a8d5`` of NanoGPT training on
# the Shakespeare dataset.
#


######################################################################
# Conclusion
# ==========
#
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdp_kernel`` context manager can be used to assert that a certain
# implementation is used on GPU. As well, we built a simple
# ``CausalSelfAttention`` module that works with ``NestedTensor`` and is torch
# compilable. In the process we have shown how the profiling tools can
# be used to explore the performance characteristics of a user defined
# module.
#