# ``torch.compile`` makes PyTorch code run faster by
# JIT-compiling it into optimized kernels,
# all while requiring minimal code changes.
#
# In this tutorial, we cover basic ``torch.compile`` usage,
# and demonstrate the advantages of ``torch.compile`` over
# previous PyTorch compiler solutions, such as
# `TorchScript <https://pytorch.org/docs/stable/jit.html>`__ and
# `FX Tracing <https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace>`__.
#
# **Contents**
#
# - Basic Usage
# - Demonstrating Speedups
# - Comparison to TorchScript and FX Tracing
#
# ``torch.compile`` is included in the latest PyTorch.
# Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly
# binary. If Triton is still missing, try installing ``torchtriton`` via pip
# (``pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"``
# for CUDA 11.7).
#
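# As a quick preview of the basic usage covered below: ``torch.compile``
# wraps an arbitrary Python function or ``nn.Module`` and returns an
# optimized callable with the same signature. A minimal sketch
# (``toy_fn`` is an illustrative name, not part of the tutorial):

import torch

def toy_fn(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

# The compiled callable is a drop-in replacement for the original.
opt_toy_fn = torch.compile(toy_fn)
print(opt_toy_fn(torch.randn(10, 10), torch.randn(10, 10)))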
######################################################################
# Demonstrating Speedups
# -----------------------
#
# Let's now demonstrate that using ``torch.compile`` can speed
# up real models. We will compare standard eager mode and
# ``torch.compile`` by evaluating and training a ``torchvision`` model on random data.
#
# Before we start, we need to define some utility functions.
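#
# For instance, a GPU timing helper along these lines (a sketch; the
# tutorial's actual utilities may differ in detail):

def timed(fn):
    # CUDA events give accurate GPU timings, unlike wall-clock time,
    # because kernel launches are asynchronous.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time returns milliseconds; convert to seconds.
    return result, start.elapsed_time(end) / 1000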
######################################################################
# Comparison to TorchScript and FX Tracing
# -----------------------------------------
#
# We have seen that ``torch.compile`` can speed up PyTorch code.
# Why else should we use ``torch.compile`` over existing PyTorch
# compiler solutions, such as TorchScript or FX Tracing? Primarily, the
# advantage of ``torch.compile`` lies in its ability to handle
# arbitrary Python code with minimal changes to existing code.
#
# One case that ``torch.compile`` can handle, but that other compiler
# solutions struggle with, is data-dependent control flow (the
# ``if x.sum() < 0:`` line below).

def f1(x, y):
    # Data-dependent control flow: the branch depends on the value of x.
    if x.sum() < 0:
        return -y
    return y
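
# ``torch.compile`` handles the branch directly (a brief check; variable
# names here are illustrative). By contrast, ``torch.jit.trace`` would bake
# in the branch taken during tracing, and ``torch.fx.symbolic_trace`` would
# raise an error on the data-dependent condition.
compiled_f1 = torch.compile(f1)
inp1, inp2 = torch.randn(10), torch.randn(10)
print(torch.allclose(compiled_f1(inp1, inp2), f1(inp1, inp2)))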
######################################################################
# TorchDynamo and FX Graphs
# --------------------------
#
# One important component of ``torch.compile`` is TorchDynamo.
# TorchDynamo is responsible for JIT compiling arbitrary Python code into
# `FX graphs <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__, which can
# then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode
# at runtime and detecting calls to PyTorch operations.
#
# Normally, TorchInductor, another component of ``torch.compile``,
# further compiles the FX graphs into optimized kernels,
# but TorchDynamo allows for different backends to be used. In order to inspect
# the FX graphs that TorchDynamo extracts, we can write a custom backend that
# prints the graph and returns its unoptimized ``forward`` method, as sketched below.
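
# A sketch of such a debugging backend (``custom_backend`` is our
# illustrative name; a TorchDynamo backend receives the traced
# ``GraphModule`` and example inputs):
from typing import List

def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    # Return the unoptimized forward so behavior is unchanged.
    return gm.forward

# Usage: torch.compile(fn, backend=custom_backend)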
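# For reference, ``bar`` (analyzed with ``explain`` below) is defined earlier
# in the tutorial; a sketch of a function whose data-dependent branch causes
# a graph break:
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b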

# Reset since we are using a different backend.
torch._dynamo.reset()
explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
print(explain_output)

######################################################################
# In order to maximize speedup, graph breaks should be limited.
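# We can force TorchDynamo to raise an error upon the first graph break by
# passing ``fullgraph=True`` (a sketch; ``init_model`` and ``generate_data``
# are utility functions defined earlier in the tutorial):
opt_model = torch.compile(init_model(), fullgraph=True)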
print(opt_model(generate_data(16)[0]))

######################################################################
# <!----TODO: replace this section with a link to the torch.export tutorial when done --->
#
# Finally, if we simply want TorchDynamo to output the FX graph for export,
# we can use ``torch._dynamo.export``. Note that ``torch._dynamo.export``, like
# ``fullgraph=True``, raises an error if TorchDynamo breaks the graph.

try:
    torch._dynamo.export(bar)(torch.randn(10), torch.randn(10))
except:
    tb.print_exc()

model_exp = torch._dynamo.export(init_model())(generate_data(16)[0])
print(model_exp[0](generate_data(16)[0]))

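# ``torch._dynamo.export`` returns a tuple whose first element is the
# captured ``torch.fx.GraphModule``; its graph can be inspected (an
# illustrative peek):
model_exp[0].graph.print_tabular()
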
######################################################################