@@ -7,11 +7,11 @@
 `torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ provides convenience methods for mixed precision,
 where some operations use the ``torch.float32`` (``float``) datatype and other operations
 use ``torch.float16`` (``half``). Some ops, like linear layers and convolutions,
-are much faster in ``float16``. Other ops, like reductions, often require the dynamic
+are much faster in ``float16`` or ``bfloat16``. Other ops, like reductions, often require the dynamic
 range of ``float32``. Mixed precision tries to match each op to its appropriate datatype,
 which can reduce your network's runtime and memory footprint.
 
-Ordinarily, "automatic mixed precision training" uses `torch.cuda.amp.autocast <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_ and
+Ordinarily, "automatic mixed precision training" uses `torch.autocast <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast>`_ and
 `torch.cuda.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler>`_ together.
 
 This recipe measures the performance of a simple network in default precision,
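
For orientation, this is the pattern the recipe builds up to, as it looks after this change. A minimal sketch only: ``net``, ``opt``, ``loss_fn``, ``data``, and ``targets`` stand in for the recipe's own definitions, and a CUDA device is assumed::

    scaler = torch.cuda.amp.GradScaler()
    for input, target in zip(data, targets):
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)
        # backward and the optimizer step run outside the autocast region
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()
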
@@ -116,7 +116,7 @@ def make_model(in_size, out_size, num_layers):
 for epoch in range(0): # 0 epochs, this section is for illustration only
     for input, target in zip(data, targets):
         # Runs the forward pass under autocast.
-        with torch.cuda.amp.autocast():
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
             output = net(input)
             # output is float16 because linear layers autocast to float16.
             assert output.dtype is torch.float16
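
The dtype assertion in this hunk can be checked in isolation. A standalone sketch, assuming a CUDA device is available (``a`` and ``b`` are throwaway tensors, not part of the recipe)::

    import torch

    a = torch.rand((8, 8), device='cuda')
    b = torch.rand((8, 8), device='cuda')
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        c = torch.mm(a, b)
        # matmuls autocast to float16 on CUDA
        assert c.dtype is torch.float16
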
@@ -151,7 +151,7 @@ def make_model(in_size, out_size, num_layers):
 
 for epoch in range(0): # 0 epochs, this section is for illustration only
     for input, target in zip(data, targets):
-        with torch.cuda.amp.autocast():
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
             output = net(input)
             loss = loss_fn(output, target)
 
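
This hunk cuts off before the scaler calls; in the recipe the loop continues with the standard ``GradScaler`` sequence, sketched here with ``opt`` standing in for the recipe's optimizer::

    # Scales the loss, and calls backward() on the scaled loss to create
    # scaled gradients.
    scaler.scale(loss).backward()
    # step() first unscales the optimizer's gradients; if they contain
    # infs or NaNs, opt.step() is skipped.
    scaler.step(opt)
    # Updates the scale factor for the next iteration.
    scaler.update()
    opt.zero_grad()
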
@@ -184,7 +184,7 @@ def make_model(in_size, out_size, num_layers):
 start_timer()
 for epoch in range(epochs):
     for input, target in zip(data, targets):
-        with torch.cuda.amp.autocast(enabled=use_amp):
+        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
             output = net(input)
             loss = loss_fn(output, target)
         scaler.scale(loss).backward()
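
With ``enabled=use_amp``, the same loop can be timed with and without mixed precision: when the flag is ``False``, both autocast and the scaler become no-ops. The scaler is presumably constructed with the same flag earlier in the recipe; a sketch::

    use_amp = True
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
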
@@ -202,7 +202,7 @@ def make_model(in_size, out_size, num_layers):
 
 for epoch in range(0): # 0 epochs, this section is for illustration only
     for input, target in zip(data, targets):
-        with torch.cuda.amp.autocast():
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
             output = net(input)
             loss = loss_fn(output, target)
         scaler.scale(loss).backward()
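
This last hunk appears to belong to the recipe's gradient inspection/clipping example; that placement is an assumption. If so, the loop continues by unscaling before clipping, sketched here with ``opt`` and the ``max_norm`` value as placeholders::

    # Unscales the gradients of the optimizer's assigned params in-place.
    scaler.unscale_(opt)
    # Gradients are now unscaled, so clip as usual.
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)
    scaler.step(opt)  # step() will not unscale a second time
    scaler.update()
    opt.zero_grad()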