diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 5963fdcb3619..62407a317e9f 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -995,6 +995,9 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel(), param.dtype) self.reduce_ipg_grads() if self.contiguous_gradients and self.overlap_comm: + if not get_accelerator().resolves_data_dependency(): + self.reduction_stream.wait_stream(get_accelerator().current_stream()) + get_accelerator().current_stream().wait_stream(self.reduction_stream) # Swap index between 0 and 1 bucket.index = 1 - bucket.index self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel(), param.dtype)