Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 90e5b05

Browse files
Fix end-of-epoch with record_stream
1 parent 1ccaaf4 commit 90e5b05

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

examples/imagenet/main_amp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,10 @@ def next(self):
300300
torch.cuda.current_stream().wait_stream(self.stream)
301301
input = self.next_input
302302
target = self.next_target
303-
input.record_stream(torch.cuda.current_stream())
304-
target.record_stream(torch.cuda.current_stream())
303+
if input is not None:
304+
input.record_stream(torch.cuda.current_stream())
305+
if target is not None:
306+
target.record_stream(torch.cuda.current_stream())
305307
self.preload()
306308
return input, target
307309

0 commit comments

Comments
 (0)