Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit d5e2bb4

Browse files
Fix rare caching allocator race condition in imagenet prefetcher
1 parent c3bcf18 commit d5e2bb4

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

examples/imagenet/main_amp.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -272,9 +272,23 @@ def preload(self):
272272
self.next_input = None
273273
self.next_target = None
274274
return
275+
# if record_stream() doesn't work, another option is to make sure device inputs are created
276+
# on the main stream.
277+
# self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
278+
# self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
279+
# Need to make sure the memory allocated for next_* is not still in use by the main stream
280+
# at the time we start copying to next_*:
281+
# self.stream.wait_stream(torch.cuda.current_stream())
275282
with torch.cuda.stream(self.stream):
276283
self.next_input = self.next_input.cuda(non_blocking=True)
277284
self.next_target = self.next_target.cuda(non_blocking=True)
285+
# more code for the alternative if record_stream() doesn't work:
286+
# copy_ will record the use of the pinned source tensor in this side stream.
287+
# self.next_input_gpu.copy_(self.next_input, non_blocking=True)
288+
# self.next_target_gpu.copy_(self.next_target, non_blocking=True)
289+
# self.next_input = self.next_input_gpu
290+
# self.next_target = self.next_target_gpu
291+
278292
# With Amp, it isn't necessary to manually convert data to half.
279293
# if args.fp16:
280294
# self.next_input = self.next_input.half()
@@ -286,6 +300,8 @@ def next(self):
286300
torch.cuda.current_stream().wait_stream(self.stream)
287301
input = self.next_input
288302
target = self.next_target
303+
input.record_stream(torch.cuda.current_stream())
304+
target.record_stream(torch.cuda.current_stream())
289305
self.preload()
290306
return input, target
291307

0 commit comments

Comments
 (0)