Fix DeepCompile ZeRO-3 release parameter lifetime#8032
Conversation
7b2ae31 to
04ca757
Compare
|
The gathered buffer is registered in |
7af789e to
f075f74
Compare
Signed-off-by: Masahiro Tanaka <[email protected]>
f075f74 to
d715e86
Compare
|
Thank you for your review, @eternalNight! The actual issue is that, after #7489, the release op does not release the underlying gathered buffer storage. It is not an early-release issue. I updated the PR description to clarify this. This PR is much simpler now. |
Thanks for the clarification! That makes the picture much clearer. I still have one doubt, though. The You may capture a torch memory history which records exactly when (by recording the call stack) each storage is finally released. There may be more hints in the call stack about the residual reference. |
|
Hi @eternalNight, Thanks for the suggestion. Yes, that makes sense, and I agree there must still be a live reference after After #7489:
Those tensors no longer necessarily share the same That is why this patch resizes the old |
That is like manually freeing a buffer while we know there are some other pointers dangling around, which looks risky to me. Do you have a minimal reproduction of the issue? I can help investigate where the residual reference is and whether it is safe to assume that reference, while being held, will never be touched post |
|
Hi @eternalNight, I think the residual reference is already identified here: it is the true-shape view returned from For the safety part of your question, the guarantee I am relying on is that the final This follows the existing ZeRO-3 (no-DeepCompile) explicit-release model, which clears parameter data with |
That's what confused me. In torchinductor-generated graphs, references to tensors are droped after their last use. If the true-shape view is not consumed after I have met a similar issue before with amp enabled, because amp caches casted weights by default. The amp weight cache thus holds another reference to the gathered buffer till the end of the forward phase, preventing That's why I would like to have a repro and check its memory profile: storage references may be hold not because how torch inductor tracks data dependency, but because of other mechanisms under the hood. The unit tests do not serve the purpose because they already hold a reference to
|
|
@eternalNight I got your point, let me do small experiments to clarify it. |
|
Hi @eternalNight, I ran a separate small DeepCompile ZeRO-3 experiment to check where we keep the reference. At the Python level, the relevant part is roughly just a linear layer: out = linear(x)
loss = loss_fn(out, target)With DeepCompile / ZeRO-3, the generated code is conceptually closer to this: bias_buf = wait_allgather(allgather_param(linear.bias))
weight_buf = wait_allgather(allgather_param(linear.weight))
out = torch.addmm(bias_buf, x, weight_buf.t())
out = release_param(out, linear.bias)
out = release_param(out, linear.weight)
loss = loss_fn(out, target)So after the buf0 = torch.ops.dc.allgather_param.default(...)
buf1 = buf0
del buf0
buf2 = torch.ops.dc.allgather_param.default(...)
buf3 = buf2
del buf2
buf4 = torch.ops.dc.wait_allgather.default(buf3, ...)
buf5 = buf4
buf6 = torch.ops.dc.wait_allgather.default(buf1, ...)
buf7 = buf6
extern_kernels.addmm(buf5, x, reinterpret_tensor(buf7, ...), out=buf8)
del buf4
del buf6
buf9 = torch.ops.dc.release_param.default(buf8, ...)
buf10 = buf9
buf11 = torch.ops.dc.release_param.default(buf10, ...)
buf12 = buf11
triton_kernel.run(..., buf12, ...)This behavior is related to how these DeepCompile custom ops are registered as fallback kernels. Today DeepCompile intentionally marks the outputs as register_fallback_no_reuse(
torch.ops.dc.allgather_param.default,
never_reuse_input=False,
never_reuse_output=True,
)
register_fallback_no_reuse(
torch.ops.dc.wait_allgather.default,
never_reuse_input=True,
never_reuse_output=True,
)If I change But we cannot make this purely GC-driven either. ZeRO-3 needs the A more complete long-term direction might be to implement a custom lowering / aliasing model for these ops, so Inductor can understand the external lifetime more precisely. That would be a larger change, though. |
Thanks for the detailed investigation, @tohtana ! That explains why the residual reference remains till the end of the graph. Torch inductor marks a tensor to be unreusable by keeping its reference alive, which is simple but makes such tensors live longer and eventually increases the peak GPU memory usage.
I agree that a thorough solution would require more systematic analysis. Let's manually free the buffer to unblock the issue for now. |
|
Thank you, @eternalNight, for sharing your thoughtful insights! |
PR #7489 made ZeRO-3 all-gather allocate a padded base buffer for uneven shards and return a true-shape view into that buffer. That means the registry tensor and the tensor returned to the compiled graph no longer necessarily share the same
TensorImpl, although they still share the same underlying storage.The existing release path only did
set_data(empty)on the registry tensor before unregistering it. With the new base/view relationship, that clears the registry-side tensor metadata but does not resize the sharedStorageImplstill referenced by returned views. As a result, the padded gathered allocation can remain live after the finalrelease_param.This patch keeps the release graph ordering unchanged and makes final non-persistent release resize the registered gathered storage to 0 bytes before unregistering it.