Fix DeepCompile ZeRO-3 release parameter lifetime#8032

Merged
tohtana merged 2 commits into deepspeedai:master from tohtana:tohtana/fix/deepcompile-release-return-storage
Jun 2, 2026

Conversation

@tohtana
Collaborator

@tohtana tohtana commented May 28, 2026

PR #7489 made ZeRO-3 all-gather allocate a padded base buffer for uneven shards and return a true-shape view into that buffer. That means the registry tensor and the tensor returned to the compiled graph no longer necessarily share the same TensorImpl, although they still share the same underlying storage.

The existing release path only did set_data(empty) on the registry tensor before unregistering it. With the new base/view relationship, that clears the registry-side tensor metadata but does not resize the shared StorageImpl still referenced by returned views. As a result, the padded gathered allocation can remain live after the final release_param.

This patch keeps the release graph ordering unchanged and makes final non-persistent release resize the registered gathered storage to 0 bytes before unregistering it.
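
The base/view behavior described above can be reproduced outside DeepCompile. The following is a minimal CPU-only sketch, not the patch itself: `base` and `view` are hypothetical stand-ins for the registry tensor and the returned true-shape view, and `untyped_storage().resize_(0)` is used as a Python-level analogue of the storage resize the patch performs.

```python
import torch

# Hypothetical stand-in for the padded base buffer kept by the registry.
base = torch.empty(10, dtype=torch.float32)
# Hypothetical stand-in for the true-shape view returned to the compiled graph.
view = base[:7]

# Keep a handle on the shared storage before touching the registry tensor.
storage = base.untyped_storage()
padded_nbytes = storage.nbytes()

# Old release path: point the registry tensor at an empty tensor.
# This only rebinds `base`; the view still references the old storage.
base.data = torch.empty(0, dtype=torch.float32)
nbytes_after_set_data = view.untyped_storage().nbytes()

# Analogue of the patched path: shrink the shared storage itself.
# `view` must not be read or written after this point.
storage.resize_(0)
nbytes_after_resize = view.untyped_storage().nbytes()

print(padded_nbytes, nbytes_after_set_data, nbytes_after_resize)
```

In this sketch the allocation stays at its padded size after `set_data` and only drops to zero bytes once the storage is resized, which mirrors the lifetime issue described in this PR.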

@tohtana tohtana requested review from loadams and tjruwase as code owners May 28, 2026 16:20
@tohtana tohtana force-pushed the tohtana/fix/deepcompile-release-return-storage branch from 7b2ae31 to 04ca757 Compare May 28, 2026 16:21
@tohtana tohtana requested a review from eternalNight May 29, 2026 22:09
@eternalNight
Contributor

The gathered buffer is registered in DSParamRegistry, which holds a reference to the tensor until release_param removes it from the registry. So I think torch will not release the buffer storage even if Inductor drops the gathered view early. Why does that ref-count-based mechanism fail in your case?

@tohtana tohtana force-pushed the tohtana/fix/deepcompile-release-return-storage branch from 7af789e to f075f74 Compare May 30, 2026 05:05
@tohtana tohtana force-pushed the tohtana/fix/deepcompile-release-return-storage branch from f075f74 to d715e86 Compare May 30, 2026 05:46
@tohtana
Collaborator Author

tohtana commented May 30, 2026

Thank you for your review, @eternalNight!

The actual issue is that, after #7489, the release op does not release the underlying gathered buffer storage. It is not an early-release issue. I updated the PR description to clarify this.
I also reverted the graph modification after verifying that the existing ordering is still safe.

This PR is much simpler now.

@eternalNight
Contributor

Thanks for the clarification! That makes the picture much clearer. I still have one doubt, though.

The DSParamRegistry holds a reference to the raw, possibly-padded buffer. The buffer is expected to be released once that reference is dropped in unregisterGatheredParam. If it still persists after being unregistered, there must be a reference alive somewhere else.

You could capture a PyTorch memory history, which records the call stack of each allocation and free and therefore shows exactly when each storage is finally released. The call stacks may give more hints about the residual reference.
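
A sketch of how such a history could be captured, assuming a CUDA build of PyTorch recent enough to provide the underscore-prefixed helpers `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._dump_snapshot` (being private, they may change between releases). The names `capture_memory_history`, `step_fn`, and the snapshot path are hypothetical.

```python
import torch


def capture_memory_history(step_fn, path="snapshot.pickle", max_entries=100_000):
    """Run step_fn while recording CUDA allocator events, then dump a snapshot.

    Returns False without doing anything when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return False
    # Start recording allocation/free events together with their call stacks.
    torch.cuda.memory._record_memory_history(max_entries=max_entries)
    try:
        step_fn()
        torch.cuda.synchronize()
        # Write the recorded history to disk for offline inspection.
        torch.cuda.memory._dump_snapshot(path)
    finally:
        # Stop recording so later work is not slowed down.
        torch.cuda.memory._record_memory_history(enabled=None)
    return True
```

`step_fn` would be one training step of the reproduction; the resulting file can then be inspected with PyTorch's memory visualizer to see which stack frees the gathered buffer.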

@tohtana
Collaborator Author

tohtana commented May 31, 2026

Hi @eternalNight,

Thanks for the suggestion. Yes, that makes sense, and I agree there must still be a live reference after unregisterGatheredParam. My current understanding is that the reference is the true-shape view returned from allgather_param, rather than an unknown extra reference.

After #7489:

  • The registry stores the padded base tensor
  • allgather_param returns a true-shape view

Those tensors no longer necessarily share the same TensorImpl, but they still share the same StorageImpl. So set_data(empty) on the registry tensor only updates the registry-side tensor metadata. It does not resize the storage that is still visible through the returned view.

That is why this patch resizes the old StorageImpl directly with resize_bytes_cuda(storage.unsafeGetStorageImpl(), 0). The intent is not to change the registry tensor metadata again, but to release the underlying padded storage that may still be referenced by the returned view.

@eternalNight
Contributor

That is like manually freeing a buffer while we know there are some other pointers dangling around, which looks risky to me. Do you have a minimal reproduction of the issue? I can help investigate where the residual reference is and whether it is safe to assume that reference, while being held, will never be touched post release_param.

@tohtana
Collaborator Author

tohtana commented Jun 1, 2026

Hi @eternalNight,

I think the residual reference is already identified here: it is the true-shape view returned from allgather_param, which can outlive the registry entry while sharing the same StorageImpl. That view is a normal graph value passed to downstream consumers, so it is not something the registry can simply delete.
I added tests/torch_compile/test_deepcompile_z3_release.py as the minimal repro for this case.

For the safety part of your question, the guarantee I am relying on is that the final release_param is placed after the last consumer of that view. The view object may still exist after release, but no later op should read from or write to it.

This follows the existing ZeRO-3 (no-DeepCompile) explicit-release model, which clears parameter data with set_data(empty). After #7489, that is no longer enough in DeepCompile because it only updates the registry tensor, while the returned view can still share the underlying StorageImpl. If you see a path where the returned view can be accessed after release_param, that is the unsafe case I should fix.
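
The ordering guarantee above can be stated as a simple invariant: no step may read a gathered buffer after the step that releases it. The following toy checker illustrates that invariant only; the schedule format (`op`, `reads`, `releases`) and the function name are hypothetical and are not part of DeepCompile.

```python
def reads_after_release(schedule):
    """Return (step_index, buffer_name) pairs that read an already released buffer."""
    released = set()
    violations = []
    for index, (op, reads, releases) in enumerate(schedule):
        for name in reads:
            if name in released:
                violations.append((index, name))
        # Buffers released by this step are off limits for all later steps.
        released.update(releases)
    return violations


# Conceptual schedule for one linear layer: both releases follow the last consumer.
safe = [
    ("addmm", ["bias_buf", "x", "weight_buf"], []),
    ("release_param", ["out"], ["bias_buf"]),
    ("release_param", ["out"], ["weight_buf"]),
    ("loss_fn", ["out", "target"], []),
]

# Same schedule with a hypothetical late reader of bias_buf inserted after its release.
unsafe = safe[:2] + [("relu", ["bias_buf"], [])] + safe[2:]

print(reads_after_release(safe))
print(reads_after_release(unsafe))
```

An empty result corresponds to the safe case described above; any entry would be the unsafe access the release path must never allow.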

@eternalNight
Contributor

That's what confused me. In TorchInductor-generated graphs, references to tensors are dropped after their last use. If the true-shape view is not consumed after release_param, its reference should be dropped, which allows torch to release the storage.

I ran into a similar issue before with AMP enabled, because autocast caches cast weights by default. The AMP weight cache then holds another reference to the gathered buffer until the end of the forward phase, which prevents release_param from releasing the weight buffers early. It can be worked around by passing cache_enabled=False to torch.autocast.
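
For reference, a minimal sketch of that workaround. It runs on CPU with bfloat16 so it stays portable; a real DeepCompile run would presumably use `device_type="cuda"`. The layer and input are hypothetical.

```python
import torch

linear = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)

# cache_enabled=False asks autocast not to keep cast copies of the weights
# alive for the rest of the autocast region.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=False):
    cache_flag = torch.is_autocast_cache_enabled()
    out = linear(x)

print(cache_flag, out.dtype)
```

Disabling the cache trades repeated casts for shorter weight lifetimes, so it is mainly useful when the cached copies are what keeps gathered buffers alive.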

That's why I would like to have a repro and check its memory profile: storage references may be held not because of how TorchInductor tracks data dependencies, but because of other mechanisms under the hood. The unit tests do not serve this purpose because they already hold a reference to the storage until the end for assertion checking.

@tohtana
Collaborator Author

tohtana commented Jun 1, 2026

@eternalNight I see your point. Let me run a few small experiments to clarify it.

@tohtana
Collaborator Author

tohtana commented Jun 1, 2026

Hi @eternalNight,

I ran a separate small DeepCompile ZeRO-3 experiment to check where we keep the reference. At the Python level, the relevant part is roughly just a linear layer:

out = linear(x)
loss = loss_fn(out, target)

With DeepCompile / ZeRO-3, the generated code is conceptually closer to this:

bias_buf = wait_allgather(allgather_param(linear.bias))
weight_buf = wait_allgather(allgather_param(linear.weight))

out = torch.addmm(bias_buf, x, weight_buf.t())

out = release_param(out, linear.bias)
out = release_param(out, linear.weight)

loss = loss_fn(out, target)

So after the addmm, the gathered parameter views should no longer be consumed. However, the generated wrapper I observed keeps local aliases (buf1/buf3/buf5/buf7 below) of the gathered views alive across the later release_param calls:

buf0 = torch.ops.dc.allgather_param.default(...)
buf1 = buf0
del buf0

buf2 = torch.ops.dc.allgather_param.default(...)
buf3 = buf2
del buf2

buf4 = torch.ops.dc.wait_allgather.default(buf3, ...)
buf5 = buf4

buf6 = torch.ops.dc.wait_allgather.default(buf1, ...)
buf7 = buf6

extern_kernels.addmm(buf5, x, reinterpret_tensor(buf7, ...), out=buf8)
del buf4
del buf6

buf9 = torch.ops.dc.release_param.default(buf8, ...)
buf10 = buf9
buf11 = torch.ops.dc.release_param.default(buf10, ...)
buf12 = buf11

triton_kernel.run(..., buf12, ...)
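
The effect of the `buf1 = buf0` / `del buf0` pattern can be illustrated in plain Python. This is a toy sketch that relies on CPython reference counting; `Buffer` is a hypothetical stand-in for a gathered tensor, not a torch type.

```python
import weakref


class Buffer:
    """Hypothetical stand-in for a gathered view returned by allgather_param."""


def wrapper_like():
    buf0 = Buffer()
    probe = weakref.ref(buf0)  # observes the object without keeping it alive
    buf1 = buf0                # alias, as in the generated wrapper
    del buf0                   # removes only the name buf0
    alive_after_del_buf0 = probe() is not None
    del buf1                   # last strong reference goes away here
    alive_after_del_buf1 = probe() is not None
    return alive_after_del_buf0, alive_after_del_buf1


print(wrapper_like())
```

As long as an alias such as `buf1` is never deleted, the object it names, and any storage it shares, stays reachable until the wrapper function returns.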

This behavior is related to how these DeepCompile custom ops are registered as fallback kernels. Today DeepCompile intentionally marks the outputs as never_reuse_output=True:

register_fallback_no_reuse(
    torch.ops.dc.allgather_param.default,
    never_reuse_input=False,
    never_reuse_output=True,
)
register_fallback_no_reuse(
    torch.ops.dc.wait_allgather.default,
    never_reuse_input=True,
    never_reuse_output=True,
)

If I change wait_allgather and allgather_param to never_reuse_output=False, Inductor emits deletes for those aliases. However, this version hit a CUDA illegal memory access in my experiment. This matches the reason those outputs are marked no-reuse.

But we cannot make this purely GC-driven either. ZeRO-3 needs the Parameter object itself to remain stable across the module lifetime. My understanding is that this is why the existing non-DeepCompile ZeRO-3 path used set_data(empty).
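
A small sketch of why `set_data(empty)` suits that requirement: assigning to `.data` swaps the tensor contents while the Parameter object, and everything that references it, stays the same. The module and optimizer here are hypothetical stand-ins, and the CPU tensors only illustrate object identity, not ZeRO-3 partitioning.

```python
import torch

linear = torch.nn.Linear(4, 3)
param = linear.weight
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)
full = param.data.clone()  # stand-in for the gathered values

# "Release": swap in an empty tensor without replacing the Parameter object.
param.data = torch.empty(0)
same_object = linear.weight is param
opt_sees_same = optimizer.param_groups[0]["params"][0] is param
released_numel = param.numel()

# "Re-gather": restore the full data on the very same Parameter.
param.data = full
restored_shape = tuple(param.shape)

print(same_object, opt_sees_same, released_numel, restored_shape)
```

Relying on garbage collection instead would mean dropping and recreating the Parameter, which would invalidate references held by the module and the optimizer.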

A more complete long-term direction might be to implement a custom lowering / aliasing model for these ops, so Inductor can understand the external lifetime more precisely. That would be a larger change, though.

@eternalNight
Contributor

Thanks for the detailed investigation, @tohtana! That explains why the residual reference remains until the end of the graph. TorchInductor marks a tensor as non-reusable by keeping its reference alive, which is simple but makes such tensors live longer and eventually increases peak GPU memory usage.

I agree that a thorough solution would require more systematic analysis. Let's manually free the buffer to unblock the issue for now.

@tohtana tohtana merged commit 3e486fe into deepspeedai:master Jun 2, 2026
12 checks passed
@tohtana
Collaborator Author

tohtana commented Jun 2, 2026

Thank you, @eternalNight, for sharing your thoughtful insights!
