Commit 013fe5e

[aotd] Support saved tensors hooks in aot_autograd

ghstack-source-id: 3075c63
Pull Request resolved: #150032

1 parent 7bb9c36 commit 013fe5e
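
For context, a minimal usage sketch of what this commit works toward (an assumed illustration, not code from the commit): saved-tensors hooks wrapped around a torch.compile'd function, so that pack/unpack become part of the AOTAutograd forward and backward graphs. The aot_eager backend and fullgraph flag mirror the test added below; the function and tensor are placeholders.

    import torch

    def f(x):
        return (x.relu() + 1).sum()

    x = torch.randn(8, requires_grad=True)
    # pack compresses the saved activation to bf16; unpack restores fp32 for backward
    with torch.autograd.graph.saved_tensors_hooks(
        lambda t: t.to(torch.bfloat16),
        lambda t: t.to(torch.float32),
    ):
        y = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
    y.backward()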

File tree

8 files changed: +349 -13 lines changed


aten/src/ATen/SavedTensorHooks.cpp

Lines changed: 14 additions & 6 deletions

@@ -62,22 +62,30 @@ void SavedTensorDefaultHooks::lazy_initialize() {
 
 void SavedTensorDefaultHooks::push_hooks(SafePyObject pack_hook, SafePyObject unpack_hook) {
   TORCH_INTERNAL_ASSERT(is_initialized);
   assertSavedTensorHooksNotDisabled();
-  tls.stack.emplace(std::move(pack_hook), std::move(unpack_hook));
+  tls.stack.emplace_back(std::move(pack_hook), std::move(unpack_hook));
 }
 
 std::pair<SafePyObject, SafePyObject> SavedTensorDefaultHooks::pop_hooks() {
   TORCH_INTERNAL_ASSERT(is_initialized && !tls.stack.empty());
-  std::pair<SafePyObject, SafePyObject> hooks = std::move(tls.stack.top());
-  tls.stack.pop();
+  std::pair<SafePyObject, SafePyObject> hooks = std::move(tls.stack.back());
+  tls.stack.pop_back();
   return hooks;
 }
 
-std::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks() {
+std::optional<std::pair<SafePyObject, SafePyObject>> SavedTensorDefaultHooks::get_hooks(bool ignore_is_tracing) {
   // For tls.is_tracing, see NOTE: [Deferring tensor pack/unpack hooks until runtime]
-  if (!is_initialized || tls.stack.empty() || tls.is_tracing) {
+  if (!is_initialized || tls.stack.empty() || (!ignore_is_tracing && tls.is_tracing)) {
     return std::nullopt;
   }
-  return tls.stack.top();
+  return tls.stack.back();
+}
+
+std::optional<std::vector<std::pair<SafePyObject, SafePyObject>>>
+SavedTensorDefaultHooks::get_all_hooks(bool ignore_is_tracing) {
+  if (!is_initialized || tls.stack.empty() || (!ignore_is_tracing && tls.is_tracing)) {
+    return std::nullopt;
+  }
+  return tls.stack;
 }
 
 }
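
The std::stack to std::vector switch exists so the new get_all_hooks() can return the whole hook stack without popping. A small eager-mode sketch of the stack semantics it preserves (nested saved_tensors_hooks is a real API; the payloads here are arbitrary illustrations):

    import torch

    x = torch.randn(4, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(lambda t: ("outer", t), lambda p: p[1]):
        with torch.autograd.graph.saved_tensors_hooks(lambda t: ("inner", t), lambda p: p[1]):
            y = (x * x).sum()  # x is packed by the innermost ("inner") pair only
    y.backward()
    assert torch.equal(x.grad, 2 * x)  # pack/unpack round-trips, grads unaffected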

aten/src/ATen/SavedTensorHooks.h

Lines changed: 4 additions & 2 deletions

@@ -15,7 +15,7 @@ namespace impl {
 
 struct TORCH_API SavedTensorDefaultHooksTLS {
   // PyObject is defined in c10/util/python_stub.h
-  std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
+  std::vector<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
 
   // See NOTE: [Disabling SavedTensorDefaultHooks] for context
   // NOTE: [disabled_error_message invariant]
@@ -36,7 +36,9 @@ struct TORCH_API SavedTensorDefaultHooks {
       c10::SafePyObject unpack_hook);
   static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
   static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
-  get_hooks();
+  get_hooks(bool ignore_is_tracing = false);
+  static std::optional<std::vector<std::pair<SafePyObject, SafePyObject>>>
+  get_all_hooks(bool ignore_is_tracing = false);
   static void lazy_initialize();
 
   static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
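
A hypothetical Python model of the TLS semantics declared above (illustration only, not PyTorch source): get_hooks returns the innermost pair, get_all_hooks the whole stack, and both return nothing during tracing unless ignore_is_tracing is set.

    class SavedTensorHooksTLSModel:
        def __init__(self):
            self.stack = []        # list of (pack_hook, unpack_hook) pairs
            self.is_tracing = False

        def get_hooks(self, ignore_is_tracing=False):
            if not self.stack or (not ignore_is_tracing and self.is_tracing):
                return None        # mirrors std::nullopt
            return self.stack[-1]  # innermost pair, mirrors tls.stack.back()

        def get_all_hooks(self, ignore_is_tracing=False):
            if not self.stack or (not ignore_is_tracing and self.is_tracing):
                return None
            return list(self.stack)  # whole stack, mirrors get_all_hooks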

test/functorch/test_aotdispatch.py

Lines changed: 122 additions & 1 deletion

@@ -10,7 +10,7 @@
 import itertools
 import unittest
 import warnings
-from contextlib import ContextDecorator, nullcontext
+from contextlib import ContextDecorator, nullcontext, ExitStack
 from functools import partial, wraps
 from typing import Any, Callable, Optional, Union
 from unittest.mock import patch
@@ -6601,6 +6601,127 @@ def _inp():
         self.assertEqual(1, len(ctx.tangent_strides))
         self.assertEqual((128, 4, 16, 1), ctx.tangent_strides[0])
 
+    def test_saved_tensors_hooks(self):
+        def _test_pack_hooks(fn, inp_fn, hooks):
+            torch._dynamo.reset()
+            with ExitStack() as stack:
+                for hook in hooks:
+                    pack, unpack = hook
+                    stack.enter_context(
+                        torch.autograd.graph.saved_tensors_hooks(pack, unpack)
+                    )
+                ref_x = inp_fn()
+                x = ref_x.detach().clone().requires_grad_()
+
+                print("XXX EAGER BEGIN")
+                ref_y = fn(ref_x)
+                ref_y.sum().backward()
+                print("XXX EAGER END")
+
+                torch._dynamo.mark_dynamic(x, 0)
+                torch._dynamo.mark_dynamic(x, 1)
+                y = torch.compile(fn, backend="aot_eager", fullgraph=True)(x)
+                y.sum().backward()
+                self.assertEqual(ref_y, y, atol=1e-2, rtol=1e-2)
+                print(f"XXX REF_X.GRAD:{ref_x.grad}")
+                print(f"XXX X.GRAD:{x.grad}")
+                self.assertEqual(ref_x.grad, x.grad, atol=1e-2, rtol=1e-2)
+
+        from torch.utils._traceback import CapturedTraceback
+
+        def _print_traceback():
+            print("".join(CapturedTraceback.extract(cpp=True).format()))
+
+        class SAF(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                ctx.save_for_backward(x)
+                return x
+
+            @staticmethod
+            def backward(ctx, gx):
+                (saved_x,) = ctx.saved_tensors
+                return gx + saved_x
+
+        class AF(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                ctx.save_for_backward(x)
+                ctx.d1 = x.size(1)
+                return x
+
+            @staticmethod
+            def backward(ctx, gx):
+                (saved_x,) = ctx.saved_tensors
+                d1 = ctx.d1
+                return gx + saved_x * d1
+
+        def fn(x):
+            x = x.relu()
+            x = x + 1
+            x = 2 * x
+            x = AF.apply(x)
+            return x
+
+        def simple_fn(x):
+            x = x + 1
+            x = SAF.apply(x)
+            return x
+
+        device = torch.device("cuda:0")
+
+        def inp_fn():
+            return torch.ones(2, 3, device=device, requires_grad=True)
+
+        def pack_dev_sym_cpu(x):
+            return (x.device, x.size(0), x.cpu())
+
+        def unpack_dev_sym_cpu(packed):
+            device, dim0, tensor = packed
+            ret = tensor.to(device=device) * dim0
+            return ret
+
+        def pack_tensor(x):
+            return x.cpu()
+
+        def unpack_tensor(packed):
+            t_cpu = packed
+            return t_cpu.to(device=device)
+
+        def pack_bf16(x):
+            print("XXX PACK_BF16")
+            return x.to(dtype=torch.bfloat16)
+
+        def unpack_bf16(x):
+            print("XXX UNPACK_BF16")
+            return x.to(dtype=torch.float)
+
+        def pack_mul2(x):
+            print("XXX PACK_MUL2")
+            return x * 2
+
+        def unpack_mul2(x):
+            print("XXX UNPACK_MUL2")
+            return x / 2
+
+        def pack_two_tensor(x):
+            return TwoTensor(x, x)
+
+        def unpack_two_tensor(x):
+            return x.a
+
+        for test_fn in [simple_fn]:
+            # print("XXX 0")
+            # _test_pack_hooks(test_fn, inp_fn, [(pack_bf16, unpack_bf16)])
+            # print("XXX 1")
+            # _test_pack_hooks(test_fn, inp_fn, [(pack_mul2, unpack_mul2)])
+            # print("XXX 2")
+            # _test_pack_hooks(test_fn, inp_fn, [(pack_mul2, unpack_mul2), (pack_bf16, unpack_bf16)])
+            print("XXX 3")
+            _test_pack_hooks(test_fn, inp_fn, [(pack_dev_sym_cpu, unpack_dev_sym_cpu)])
+            # print("XXX 4")
+            # _test_pack_hooks(test_fn, inp_fn, [(pack_tensor, unpack_tensor)])
+            # print("XXX 5")
+            # _test_pack_hooks(test_fn, inp_fn, [(pack_two_tensor, unpack_two_tensor)])
+        # TODO XXX: Add packing to subclasses
+
 
 # entries in here don't work and need to be fixed.
 # Each one of these is a bug (or needs to be investigated)
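
The eager-mode baseline the test compares against can be reproduced without CUDA. A minimal sketch with the device swapped to cpu (hook names mirror the test's pack_tensor/unpack_tensor; the arithmetic is only for illustration):

    import torch

    device = torch.device("cpu")

    def pack_tensor(x):
        return x.cpu()                   # offload the saved activation

    def unpack_tensor(packed):
        return packed.to(device=device)  # restore it for backward

    x = torch.ones(2, 3, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack_tensor, unpack_tensor):
        y = (x + 1) ** 2                 # (x + 1) is saved via pack_tensor
    y.sum().backward()
    assert torch.equal(x.grad, torch.full((2, 3), 4.0))  # 2 * (x + 1) = 4 at x = 1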

torch/_dynamo/variables/builder.py

Lines changed: 5 additions & 1 deletion

@@ -2458,7 +2458,11 @@ def _wrap_fx_proxy(
     assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"
 
     # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
-    with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
+    import contextlib
+
+    with (
+        contextlib.nullcontext()
+    ):  # torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
         # with preserve_rng_state():
         # only allow_non_graph_fake in this instance because we handle the non-fake
         # cases properly below.
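
The hunk above swaps the hook-disabling guard for contextlib.nullcontext(), a no-op context manager, keeping the with-block structure while leaving saved-tensor hooks visible during tracing. A sketch of the pattern (the wrapper name and flag are hypothetical):

    import contextlib

    import torch

    def tracing_guard(disable_hooks: bool):
        # nullcontext() does nothing on enter/exit, so with disable_hooks=False
        # the wrapped region runs with saved-tensor hooks left active.
        if disable_hooks:
            return torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing()
        return contextlib.nullcontext()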

torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py

Lines changed: 152 additions & 3 deletions

@@ -21,6 +21,7 @@
 from typing import Any, Callable, Optional, TYPE_CHECKING
 
 import torch
+import torch.utils._pytree as pytree
 import torch.utils.dlpack
 from torch import Tensor
 from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code
@@ -29,12 +30,13 @@
 from torch._subclasses import FakeTensor
 from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx.experimental._backward_state import BackwardState
-from torch.fx.experimental.proxy_tensor import is_sym_node
+from torch.fx.experimental.proxy_tensor import is_sym_node, make_fx
 from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals
 from torch.fx.graph_module import GraphModule
 from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars
 from torch.multiprocessing.reductions import StorageWeakRef
 from torchgen.utils import dataclass_repr
+from torch.types import py_sym_types
 
 from .. import config
 from .autograd_cache import (
@@ -877,13 +879,160 @@ def aot_dispatch_autograd(
     # we only need to bookkeep the symints that are saved for bw, not any symints
     # the user forward might have returned in its own output
     fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:]
-    num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw)
+    num_saved_for_bw = len(fw_outs_saved_for_bw)
     symint_outs_saved_for_bw = [
         n for n in fw_outs_saved_for_bw if is_sym_node(n)
     ]
+    num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
+    num_saved_tensors = len(fw_outs_saved_for_bw) - num_symints_saved_for_bw
+
+    # TODO XXX: Note about handling saved_tensors_hooks
+    #
+    hooks = torch._C._autograd._top_saved_tensors_default_hooks(True)
+
+    print(f"XXX JIT_COMP_RUNTIME_WRAP hooks:{hooks}")
+    # TODO XXX: Set compilation guards on hooks py objects, to recompile if previous hooks changed
+    if hooks:
+        # TODO: XXX Support stacked hooks
+        pack_hook, unpack_hook = hooks
+        assert pack_hook and unpack_hook
+        fw_g = fw_module.graph
+        bw_g = bw_module.graph
+        bw_g_inputs = bw_g.find_nodes(op="placeholder")
+        print(f"XXX FW_GRAPH BEFORE:{fw_g}")
+        print(f"XXX BW_GRAPH BEFORE:{bw_g}")
+
+        fw_out_n = fw_g.output_node()
+        fw_out_args = list(fw_out_n.args[0])
+        fw_outs_insert_tensors = []
+        fw_outs_insert_non_tensors = []
+        for saved in fw_outs_saved_for_bw:
+            val = saved.meta["val"]
+            if isinstance(val, torch.Tensor):
+                pack_gm = make_fx(pack_hook)(val)
+                pack_g = pack_gm.graph
+                print(f"XXX PACK_GRAPH:{pack_g}")
+                pack_out_val = pack_gm(val)
+                # Install pack_g as epilogue of fw_module and replace saved outputs with pack_g outputs
+                pack_g_inputs = pack_g.find_nodes(op="placeholder")
+                assert len(pack_g_inputs) == 1
+                env = {pack_g_inputs[0]: saved}
+                with fw_g.inserting_after(saved):
+                    new_out_n = None
+                    for node in pack_g.nodes:
+                        if node.op == "placeholder":
+                            continue
+                        new_n = fw_g.node_copy(node, lambda n: env[n])
+                        env[node] = new_n
+                        if node.op == "output":
+                            new_out_n = new_n
+
+                assert new_out_n
+                for n in pytree.tree_leaves(new_out_n.args[0]):
+                    if not isinstance(n, torch.fx.Node):
+                        continue
+
+                    out_val = n.meta["val"]
+                    if isinstance(out_val, torch.Tensor):
+                        fw_outs_insert_tensors.append(n)
+                    elif is_sym_node(n):
+                        fw_outs_insert_non_tensors.append(n)
+
+                fw_g.erase_node(new_out_n)
+
+                # Install unpack_g as prologue of bw_module
+                unpack_gm = make_fx(unpack_hook)(pack_out_val)
+                unpack_out_val = unpack_gm(pack_out_val)
+                unpack_g = unpack_gm.graph
+                print(f"XXX PACK_OUT_VAL:{pack_out_val}")
+                print(f"XXX UNPACK_OUT_VAL:{unpack_out_val}")
+                print(f"XXX UNPACK_GRAPH:{unpack_g}")

+                def find_saved_in_bw_inputs(bw_inputs):
+                    for n in bw_inputs:
+                        # TODO: XXX Recheck validity of this identification :)
+                        if n.name == saved.name:
+                            return n
+
+                bw_g_input = find_saved_in_bw_inputs(bw_g_inputs)
+                assert bw_g_input
+                # Replace bw_g input with copy of output of pack
+
+                unpack_g_inputs = unpack_g.find_nodes(op="placeholder")
+                env = {}
+                # Adding unpack inputs to bw graph instead of saved
+                for unp_in_n, val in zip(
+                    unpack_g_inputs, pytree.tree_leaves(pack_out_val)
+                ):
+                    is_sym = isinstance(val, py_sym_types)
+                    if isinstance(val, torch.Tensor) or is_sym:
+                        new_node_name = bw_g_input.name + "_" + unp_in_n.name
+                        # Backward calling convention: ctx_symints...ctx_saved_tensors...
+                        if is_sym:
+                            with bw_g.inserting_before(bw_g_inputs[0]):
+                                new_n = bw_g.placeholder(new_node_name)
+                        else:
+                            with bw_g.inserting_before(bw_g_inputs[num_saved_for_bw]):
+                                new_n = bw_g.placeholder(new_node_name)
+                        new_n.meta["val"] = val
+                        env[unp_in_n] = new_n
+                    else:
+                        # Inline values of non-Tensor, non-SymScalars
+                        env[unp_in_n] = val
+
+                new_out_n = None
+                with bw_g.inserting_before(bw_g_inputs[-1]):
+                    for node in unpack_g.nodes:
+                        if node.op == "placeholder":
+                            continue
+                        new_n = bw_g.node_copy(node, lambda n: env[n])
+                        env[node] = new_n
+                        if node.op == "output":
+                            new_out_n = new_n
+
+                # TODO XXX: Debug why unpack graph produces [node] instead of node
+                # For unpack function
+                # def unpack_dev_sym_cpu(packed):
+                #     device, dim0, tensor = packed
+                #     return tensor.to(device=device) * dim0
+                #
+                # assert len(new_out_n.args) == 1
+                # print(f"XXX NEW_OUT_N.ARGS:{new_out_n.args}")
+                # unpack_saved_tensor_n = new_out_n.args[0]
+                unpack_saved_tensor_n = pytree.tree_leaves(new_out_n.args)[0]
+
+                bw_g_input.replace_all_uses_with(unpack_saved_tensor_n)
+                bw_g.erase_node(new_out_n)
+                bw_g.erase_node(bw_g_input)
+        fw_out_n.args = (
+            tuple(
+                pytree.tree_leaves((
+                    fw_outs[:num_inner_fwd_outputs],
+                    fw_outs_insert_tensors,
+                    fw_outs_insert_non_tensors,
+                    symint_outs_saved_for_bw,
+                ))
+            ),
+        )
+        print(f"\nXXX FW_GRAPH AFTER:{fw_g}")
+        print(f"\nXXX BW_GRAPH AFTER:{bw_g}")
+        fw_g.lint()
+        bw_g.lint()
+        # TODO: Refactor compute below to deduplicate
+        fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0]
+        fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:]
+        num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw)
+        num_saved_for_bw = len(fw_outs_saved_for_bw)
+        symint_outs_saved_for_bw = [
+            n for n in fw_outs_saved_for_bw if is_sym_node(n)
+        ]
+        num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
+        num_saved_tensors = len(fw_outs_saved_for_bw) - num_symints_saved_for_bw
+
     fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
     inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
-    num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
+
 
     if torch._functorch.config.donated_buffer:
         fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs(
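
A self-contained sketch of the splicing technique this hunk relies on (the toy hook and names are illustrative, not the commit's code): trace a pack-style hook with make_fx, then inline its nodes into a host graph via node_copy and an env map, the same pattern as the forward-epilogue install above.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def hook(t):
        return t.to(torch.bfloat16)   # stands in for a pack hook

    def host_fn(x):
        return x.relu()

    host_gm = make_fx(host_fn)(torch.randn(3))
    hook_gm = make_fx(hook)(torch.randn(3))

    host_g = host_gm.graph
    out_n = host_g.output_node()
    anchor = out_n.args[0]            # the value we pretend is "saved for backward"

    # Map the hook's placeholder onto the anchor, then copy the remaining nodes
    # into the host graph: the node_copy/env pattern used in the hunk above.
    env = {next(iter(hook_gm.graph.find_nodes(op="placeholder"))): anchor}
    for node in hook_gm.graph.nodes:
        if node.op == "placeholder":
            continue
        if node.op == "output":
            out_n.args = (env[node.args[0]],)   # rewire host output to packed value
            continue
        with host_g.inserting_before(out_n):
            env[node] = host_g.node_copy(node, lambda n: env[n])

    host_g.lint()
    host_gm.recompile()
    print(host_gm(torch.randn(3)).dtype)  # torch.bfloat16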
