@@ -96,12 +96,12 @@ def train(data):
9696# - ``schedule`` - callable that takes step (int) as a single parameter
9797# and returns the profiler action to perform at each step.
9898#
99- # In this example with ``wait=1, warmup=1, active=3, repeat=2 ``,
99+ # In this example with ``wait=1, warmup=1, active=3, repeat=1 ``,
100100# profiler will skip the first step/iteration,
101101# start warming up on the second,
102102# record the following three iterations,
103103# after which the trace will become available and on_trace_ready (when set) is called.
104- # In total, the cycle repeats twice . Each cycle is called a "span" in TensorBoard plugin.
104+ # In total, the cycle repeats once . Each cycle is called a "span" in TensorBoard plugin.
105105#
106106# During ``wait`` steps, the profiler is disabled.
107107# During ``warmup`` steps, the profiler starts tracing but the results are discarded.
@@ -120,31 +120,31 @@ def train(data):
120120# clicking a stack frame will navigate to the specific code line.
121121
122122with torch .profiler .profile (
123- schedule = torch .profiler .schedule (wait = 1 , warmup = 1 , active = 3 , repeat = 2 ),
123+ schedule = torch .profiler .schedule (wait = 1 , warmup = 1 , active = 3 , repeat = 1 ),
124124 on_trace_ready = torch .profiler .tensorboard_trace_handler ('./log/resnet18' ),
125125 record_shapes = True ,
126126 profile_memory = True ,
127127 with_stack = True
128128) as prof :
129129 for step , batch_data in enumerate (train_loader ):
130- if step >= (1 + 1 + 3 ) * 2 :
130+ prof .step () # Need to call this at each step to notify profiler of steps' boundary.
131+ if step >= 1 + 1 + 3 :
131132 break
132133 train (batch_data )
133- prof .step () # Need to call this at the end of each step to notify profiler of steps' boundary.
134134
135135######################################################################
136136# Alternatively, the following non-context manager start/stop is supported as well.
137137prof = torch .profiler .profile (
138- schedule = torch .profiler .schedule (wait = 1 , warmup = 1 , active = 3 , repeat = 2 ),
138+ schedule = torch .profiler .schedule (wait = 1 , warmup = 1 , active = 3 , repeat = 1 ),
139139 on_trace_ready = torch .profiler .tensorboard_trace_handler ('./log/resnet18' ),
140140 record_shapes = True ,
141141 with_stack = True )
142142prof .start ()
143143for step , batch_data in enumerate (train_loader ):
144- if step >= (1 + 1 + 3 ) * 2 :
144+ prof .step ()
145+ if step >= 1 + 1 + 3 :
145146 break
146147 train (batch_data )
147- prof .step ()
148148prof .stop ()
149149
150150######################################################################
@@ -158,6 +158,10 @@ def train(data):
158158# 4. Use TensorBoard to view results and analyze model performance
159159# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
160160#
161+ # .. note::
162+ # TensorBoard Plugin support has been deprecated, so some of these functions may not
163+ # work as previously. Please take a look at the replacement, `HTA <https://github.com/pytorch/kineto/tree/main#holistic-trace-analysis>`_.
164+ #
161165# Install PyTorch Profiler TensorBoard Plugin.
162166#
163167# .. code-block::
@@ -395,5 +399,6 @@ def train(data):
395399# Take a look at the following documents to continue your learning,
396400# and feel free to open an issue `here <https://github.com/pytorch/kineto/issues>`_.
397401#
398- # - `Pytorch TensorBoard Profiler github <https://github.com/pytorch/kineto/tree/master/tb_plugin>`_
402+ # - `PyTorch TensorBoard Profiler Github <https://github.com/pytorch/kineto/tree/master/tb_plugin>`_
399403# - `torch.profiler API <https://pytorch.org/docs/master/profiler.html>`_
404+ # - `HTA <https://github.com/pytorch/kineto/tree/main#holistic-trace-analysis>`_
0 commit comments