@@ -175,7 +175,7 @@ def evaluate(mod, inp):
 ######################################################################
 # And indeed, we can see that running our model with ``torch.compile``
 # results in a significant speedup. On an NVIDIA A100 GPU, we observe a
-# 2.2x speedup. Speedup mainly comes from reducing Python overhead and
+# 2.3x speedup. Speedup mainly comes from reducing Python overhead and
 # GPU read/writes, and so the observed speedup may vary on factors such as model
 # architecture and batch size. For example, if a model's architecture is simple
 # and the amount of data is large, then the bottleneck would be
@@ -197,16 +197,16 @@ def evaluate(mod, inp):
 opt = torch.optim.Adam(model.parameters())

 def train(mod, data):
+    opt.zero_grad(True)
     pred = mod(data[0])
     loss = torch.nn.CrossEntropyLoss()(pred, data[1])
     loss.backward()
+    opt.step()

 eager_times = []
 for i in range(N_ITERS):
     inp = generate_data(16)
-    opt.zero_grad(True)
     _, eager_time = timed(lambda: train(model, inp))
-    opt.step()
     eager_times.append(eager_time)
     print(f"eager train time {i}: {eager_time}")
 print("~" * 10)
@@ -218,9 +218,7 @@ def train(mod, data):
 compile_times = []
 for i in range(N_ITERS):
     inp = generate_data(16)
-    opt.zero_grad(True)
     _, compile_time = timed(lambda: train_opt(model, inp))
-    opt.step()
     compile_times.append(compile_time)
     print(f"compile train time {i}: {compile_time}")
 print("~" * 10)
@@ -235,13 +233,7 @@ def train(mod, data):
 # Again, we can see that ``torch.compile`` takes longer in the first
 # iteration, as it must compile the model, but afterward, we see
 # significant speedups compared to eager. On an NVIDIA A100 GPU, we
-# observe a 1.8x speedup.
-#
-# One thing to note is that, as of now, we cannot place optimizer code --
-# ``opt.zero_grad`` and ``opt.step`` -- inside of an optimized function.
-# The rest of the training loop -- the forward pass and the backward pass --
-# can be optimized. We are currently working on enabling optimizers to be
-# compatible with ``torch.compile``.
+# observe a 2.2x speedup.

 ######################################################################
 # Comparison to TorchScript and FX Tracing
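
For reference, the pattern this change adopts can be sketched in isolation: the entire training step (gradient reset, forward pass, backward pass, and optimizer step) lives in one function, and that whole function is passed to ``torch.compile``. The toy model and random data below are hypothetical stand-ins for the tutorial's actual model and ``generate_data`` helper; the only point of the sketch is that the optimizer calls now sit inside the compiled function.

import torch

# Hypothetical stand-ins for the tutorial's model and data helpers.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
    # The optimizer calls are now part of the compiled region.
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

# Compile the full training step, not just the forward pass.
train_opt = torch.compile(train)

for _ in range(3):
    inp = (torch.randn(16, 64), torch.randint(0, 10, (16,)))
    train_opt(model, inp)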