@@ -272,9 +272,23 @@ def preload(self):
272272 self .next_input = None
273273 self .next_target = None
274274 return
275+ # if record_stream() doesn't work, another option is to make sure device inputs are created
276+ # on the main stream.
277+ # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
278+ # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
279+ # Need to make sure the memory allocated for next_* is not still in use by the main stream
280+ # at the time we start copying to next_*:
281+ # self.stream.wait_stream(torch.cuda.current_stream())
275282 with torch .cuda .stream (self .stream ):
276283 self .next_input = self .next_input .cuda (non_blocking = True )
277284 self .next_target = self .next_target .cuda (non_blocking = True )
285+ # more code for the alternative if record_stream() doesn't work:
286+ # copy_ will record the use of the pinned source tensor in this side stream.
287+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
288+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
289+ # self.next_input = self.next_input_gpu
290+ # self.next_target = self.next_target_gpu
291+
278292 # With Amp, it isn't necessary to manually convert data to half.
279293 # if args.fp16:
280294 # self.next_input = self.next_input.half()
@@ -286,6 +300,8 @@ def next(self):
286300 torch .cuda .current_stream ().wait_stream (self .stream )
287301 input = self .next_input
288302 target = self .next_target
303+ input .record_stream (torch .cuda .current_stream ())
304+ target .record_stream (torch .cuda .current_stream ())
289305 self .preload ()
290306 return input , target
291307
0 commit comments