@@ -76,11 +76,14 @@ def make_model(in_size, out_size, num_layers):
 num_batches = 50
 epochs = 3
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.set_default_device(device)
+
 # Creates data in default precision.
 # The same data is used for both default and mixed precision trials below.
 # You don't need to manually change inputs' ``dtype`` when enabling mixed precision.
-data = [torch.randn(batch_size, in_size, device="cuda") for _ in range(num_batches)]
-targets = [torch.randn(batch_size, out_size, device="cuda") for _ in range(num_batches)]
+data = [torch.randn(batch_size, in_size) for _ in range(num_batches)]
+targets = [torch.randn(batch_size, out_size) for _ in range(num_batches)]
 
 loss_fn = torch.nn.MSELoss().cuda()
 
@@ -116,7 +119,7 @@ def make_model(in_size, out_size, num_layers):
 for epoch in range(0): # 0 epochs, this section is for illustration only
     for input, target in zip(data, targets):
         # Runs the forward pass under ``autocast``.
-        with torch.autocast(device_type='cuda', dtype=torch.float16):
+        with torch.autocast(device_type=device, dtype=torch.float16):
             output = net(input)
             # output is float16 because linear layers ``autocast`` to float16.
             assert output.dtype is torch.float16
@@ -151,7 +154,7 @@ def make_model(in_size, out_size, num_layers):
 
 for epoch in range(0): # 0 epochs, this section is for illustration only
     for input, target in zip(data, targets):
-        with torch.autocast(device_type='cuda', dtype=torch.float16):
+        with torch.autocast(device_type=device, dtype=torch.float16):
             output = net(input)
             loss = loss_fn(output, target)
 
@@ -184,7 +187,7 @@ def make_model(in_size, out_size, num_layers):
 start_timer()
 for epoch in range(epochs):
     for input, target in zip(data, targets):
-        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
+        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
             output = net(input)
             loss = loss_fn(output, target)
             scaler.scale(loss).backward()
@@ -202,7 +205,7 @@ def make_model(in_size, out_size, num_layers):
 
 for epoch in range(0): # 0 epochs, this section is for illustration only
     for input, target in zip(data, targets):
-        with torch.autocast(device_type='cuda', dtype=torch.float16):
+        with torch.autocast(device_type=device, dtype=torch.float16):
             output = net(input)
             loss = loss_fn(output, target)
             scaler.scale(loss).backward()
0 commit comments