Commit 5e130ef

[WIP][aotd] Support saved tensors hooks in aot_autograd

ghstack-source-id: c944f6a
Pull Request resolved: #150032

1 parent 7bb9c36 commit 5e130ef

8 files changed, +512 -149 lines changed

aten/src/ATen/SavedTensorHooks.cpp
Lines changed: 4 additions & 4 deletions

@@ -26,9 +26,9 @@ bool SavedTensorDefaultHooks::is_enabled() {
   return !tls.disabled_error_message.has_value();
 }
 
-void SavedTensorDefaultHooks::disable(const std::string& message) {
+void SavedTensorDefaultHooks::disable(const std::string& message, const bool fail_if_non_empty) {
   tls.disabled_error_message = message;
-  if (!tls.stack.empty()) {
+  if (fail_if_non_empty && !tls.stack.empty()) {
     assertSavedTensorHooksNotDisabled();
   }
 }
@@ -72,9 +72,9 @@ std::pair<SafePyObject, SafePyObject> SavedTensorDefaultHooks::pop_hooks() {
   return hooks;
 }
 
-std::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks() {
+std::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks(bool ignore_is_tracing) {
   // For tls.is_tracing, see NOTE: [Deferring tensor pack/unpack hooks until runtime]
-  if (!is_initialized || tls.stack.empty() || tls.is_tracing) {
+  if (!is_initialized || tls.stack.empty() || (!ignore_is_tracing && tls.is_tracing)) {
     return std::nullopt;
   }
   return tls.stack.top();
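
The observable effect of the new fail_if_non_empty flag, as a minimal Python sketch against the private torch._C._autograd bindings (stubs updated in torch/_C/_autograd.pyi below); this reflects the semantics the patch introduces, so treat it as illustrative rather than definitive:

    import torch
    from torch.autograd.graph import saved_tensors_hooks

    # Identity hooks, just to make the saved-tensor hook stack non-empty.
    with saved_tensors_hooks(lambda x: x, lambda x: x):
        # Previously, disable() with a non-empty hook stack always raised.
        # With fail_if_non_empty=False the disabled-error message is recorded
        # but the non-empty stack is tolerated.
        torch._C._autograd._saved_tensors_hooks_disable(
            "saved tensors hooks are disabled while tracing",
            fail_if_non_empty=False,
        )
        torch._C._autograd._saved_tensors_hooks_enable()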

aten/src/ATen/SavedTensorHooks.h
Lines changed: 4 additions & 2 deletions

@@ -36,7 +36,7 @@ struct TORCH_API SavedTensorDefaultHooks {
       c10::SafePyObject unpack_hook);
   static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
   static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
-  get_hooks();
+  get_hooks(bool ignore_is_tracing = false);
   static void lazy_initialize();
 
   static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
@@ -48,7 +48,9 @@ struct TORCH_API SavedTensorDefaultHooks {
   // disabled, then the following will raise an error:
   // - Attempting to push_hooks
   // - calling disable(message) with a non-zero stack (hooks) size
-  static void disable(const std::string& error_message);
+  static void disable(
+      const std::string& error_message,
+      const bool fail_if_non_empty = true);
   static void enable();
   static bool is_enabled();
   static const std::optional<std::string>& get_disabled_error_message();
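
For context, SavedTensorDefaultHooks is the C++ state behind the public torch.autograd.graph.saved_tensors_hooks context manager; a minimal eager-mode refresher of the pack/unpack contract these declarations serve:

    import torch

    def pack(x):
        # Runs when autograd saves a tensor for backward; the returned
        # object (here a bfloat16 copy) is stored instead of the tensor.
        return x.to(torch.bfloat16)

    def unpack(packed):
        # Runs when backward needs the saved tensor back.
        return packed.to(torch.float32)

    x = torch.randn(4, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        y = (x * x).sum()  # x is saved for backward through pack()
    y.backward()           # unpack() restores it; x.grad is approximately 2 * x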

test/functorch/test_aotdispatch.py
Lines changed: 111 additions & 1 deletion

@@ -10,7 +10,7 @@
 import itertools
 import unittest
 import warnings
-from contextlib import ContextDecorator, nullcontext
+from contextlib import ContextDecorator, ExitStack, nullcontext
 from functools import partial, wraps
 from typing import Any, Callable, Optional, Union
 from unittest.mock import patch
@@ -6601,6 +6601,116 @@ def _inp():
         self.assertEqual(1, len(ctx.tangent_strides))
         self.assertEqual((128, 4, 16, 1), ctx.tangent_strides[0])
 
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
+    def test_saved_tensors_hooks(self):
+        def _test_pack_hooks(fn, inp_fn, hooks):
+            # TODO XXX: Add Dynamo ID_MATCH guards on hooks
+            torch._dynamo.reset()
+            with ExitStack() as stack:
+                for hook in hooks:
+                    pack, unpack = hook
+                    stack.enter_context(
+                        torch.autograd.graph.saved_tensors_hooks(pack, unpack)
+                    )
+                ref_x = inp_fn()
+                x = ref_x.detach().clone().requires_grad_()
+
+                ref_y = fn(ref_x)
+                ref_y.sum().backward()
+
+                torch._dynamo.mark_dynamic(x, 0)
+                torch._dynamo.mark_dynamic(x, 1)
+                y = torch.compile(fn, backend="aot_eager", fullgraph=True)(x)
+                y.sum().backward()
+                self.assertEqual(ref_y, y, atol=1e-2, rtol=1e-2)
+                self.assertEqual(ref_x.grad, x.grad, atol=1e-2, rtol=1e-2)
+
+        class SAF(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                ctx.save_for_backward(x)
+                return x
+
+            @staticmethod
+            def backward(ctx, gx):
+                (saved_x,) = ctx.saved_tensors
+                return gx + saved_x
+
+        class AF(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                ctx.save_for_backward(x)
+                ctx.d1 = x.size(1)
+                return x
+
+            @staticmethod
+            def backward(ctx, gx):
+                (saved_x,) = ctx.saved_tensors
+                d1 = ctx.d1
+                return gx + saved_x * d1
+
+        def fn(x):
+            x = x.relu()
+            x = x + 1
+            x = 2 * x
+            x = AF.apply(x)
+            return x
+
+        def simple_fn(x):
+            x = x + 1
+            x = SAF.apply(x)
+            return x
+
+        device = torch.device("cuda:0")
+
+        def inp_fn():
+            return torch.ones(2, 3, device=device, requires_grad=True)
+
+        def pack_dev_sym_cpu(x):
+            return (x.device, x.size(0), 10 * x.cpu())
+
+        def unpack_dev_sym_cpu(packed):
+            device, dim0, tensor = packed
+            ret = tensor.to(device=device) * dim0
+            return ret
+
+        def pack_tensor(x):
+            return x.cpu()
+
+        def unpack_tensor(packed):
+            t_cpu = packed
+            return t_cpu.to(device=device)
+
+        def pack_bf16(x):
+            return x.to(dtype=torch.bfloat16)
+
+        def unpack_bf16(x):
+            return x.to(dtype=torch.float)
+
+        def pack_mul2(x):
+            return x * 2
+
+        def unpack_mul2(x):
+            return x / 2
+
+        def pack_float8(x):
+            return (x.dtype, x.to(torch.float8_e4m3fn))
+
+        def unpack_float8(packed):
+            dtype, tensor = packed
+            return tensor.to(dtype)
+
+        for test_fn in [simple_fn, fn]:
+            _test_pack_hooks(test_fn, inp_fn, [(pack_bf16, unpack_bf16)])
+            _test_pack_hooks(test_fn, inp_fn, [(pack_mul2, unpack_mul2)])
+            _test_pack_hooks(
+                test_fn, inp_fn, [(pack_mul2, unpack_mul2), (pack_bf16, unpack_bf16)]
+            )
+            _test_pack_hooks(test_fn, inp_fn, [(pack_float8, unpack_float8)])
+            _test_pack_hooks(test_fn, inp_fn, [(pack_tensor, unpack_tensor)])
+            _test_pack_hooks(test_fn, inp_fn, [(pack_dev_sym_cpu, unpack_dev_sym_cpu)])
+            # TODO XXX: Test packing/unpacking to subclasses
+
 
 # entries in here don't work and need to be fixed.
 # Each one of these is a bug (or needs to be investigated)
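
To run just this test locally (standard PyTorch test-runner invocation; it is skipped without a CUDA build, per the decorator above):

    python test/functorch/test_aotdispatch.py -k test_saved_tensors_hooks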

torch/_C/_autograd.pyi
Lines changed: 4 additions & 1 deletion

@@ -116,14 +116,17 @@ def _push_saved_tensors_default_hooks(
     unpack_hook: Callable[[Any], torch.Tensor],
 ) -> None: ...
 def _pop_saved_tensors_default_hooks() -> None: ...
+def _top_saved_tensors_default_hooks(
+    ignore_is_tracing: bool,
+) -> tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]]: ...
 def _unsafe_set_version_counter(
     t: tuple[torch.Tensor, ...], prev_version: tuple[int, ...]
 ) -> None: ...
 def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
 def _disable_profiler_legacy() -> list[list[ProfilerEvent]]: ...
 def _profiler_type() -> ActiveProfilerType: ...
 def _saved_tensors_hooks_enable() -> None: ...
-def _saved_tensors_hooks_disable(message: str) -> None: ...
+def _saved_tensors_hooks_disable(message: str, fail_if_non_empty=True) -> None: ...
 def _saved_tensors_hooks_get_disabled_error_message() -> str | None: ...
 def _saved_tensors_hooks_set_tracing(is_tracing: bool) -> bool: ...
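
A short sketch of the new _top_saved_tensors_default_hooks binding, assuming it behaves as the stub suggests: it returns the top-of-stack pack/unpack pair, and ignore_is_tracing=True mirrors get_hooks(/*ignore_is_tracing=*/true) on the C++ side, returning the hooks even while the TLS is marked as tracing:

    import torch
    from torch.autograd.graph import saved_tensors_hooks

    with saved_tensors_hooks(lambda x: x * 2, lambda x: x / 2):
        pack, unpack = torch._C._autograd._top_saved_tensors_default_hooks(True)
        t = torch.ones(3)
        assert torch.equal(unpack(pack(t)), t)  # the pair round-trips a tensor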
