|
21 | 21 | from typing import Any, Callable, Optional, TYPE_CHECKING |
22 | 22 |
|
23 | 23 | import torch |
| 24 | +import torch.utils._pytree as pytree |
24 | 25 | import torch.utils.dlpack |
25 | 26 | from torch import Tensor |
26 | 27 | from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code |
|
29 | 30 | from torch._subclasses import FakeTensor |
30 | 31 | from torch._subclasses.meta_utils import is_sparse_any |
31 | 32 | from torch.fx.experimental._backward_state import BackwardState |
32 | | -from torch.fx.experimental.proxy_tensor import is_sym_node |
| 33 | +from torch.fx.experimental.proxy_tensor import is_sym_node, make_fx |
33 | 34 | from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals |
34 | 35 | from torch.fx.graph_module import GraphModule |
35 | 36 | from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars |
36 | 37 | from torch.multiprocessing.reductions import StorageWeakRef |
37 | 38 | from torchgen.utils import dataclass_repr |
| 39 | +from torch.types import py_sym_types |
38 | 40 |
|
39 | 41 | from .. import config |
40 | 42 | from .autograd_cache import ( |
@@ -877,13 +879,160 @@ def aot_dispatch_autograd( |
877 | 879 | # we only need to bookkeep the symints that are saved for bw, not any symints |
878 | 880 | # the user forward might have returned in its own output |
879 | 881 | fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] |
880 | | - num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw) |
| 882 | + num_saved_for_bw = len(fw_outs_saved_for_bw) |
881 | 883 | symint_outs_saved_for_bw = [ |
882 | 884 | n for n in fw_outs_saved_for_bw if is_sym_node(n) |
883 | 885 | ] |
| 886 | + num_symints_saved_for_bw = len(symint_outs_saved_for_bw) |
| 887 | + num_saved_tensors = len(fw_outs_saved_for_bw) - num_symints_saved_for_bw |
| 888 | + |
| 889 | + # Note [Handling saved_tensors_hooks]: when hooks are installed, the pack hook is traced and |
| 890 | + # spliced into the forward graph as an epilogue, and the unpack hook into the backward graph as a prologue. |
| 891 | + hooks = torch._C._autograd._top_saved_tensors_default_hooks(True) |
| 892 | + |
| 893 | + print(f"XXX JIT_COMP_RUNTIME_WRAP hooks:{hooks}") |
| 894 | + # TODO XXX: Set compilation guards on the hook py objects, to recompile if the installed hooks change |
| 895 | + if hooks: |
| 896 | + # TODO: XXX Support stacked hooks |
| 897 | + pack_hook, unpack_hook = hooks |
| 898 | + assert pack_hook and unpack_hook |
| 899 | + fw_g = fw_module.graph |
| 900 | + bw_g = bw_module.graph |
| 901 | + bw_g_inputs = bw_g.find_nodes(op="placeholder") |
| 902 | + print(f"XXX FW_GRAPH BEFORE:{fw_g}") |
| 903 | + print(f"XXX BW_GRAPH BEFORE:{bw_g}") |
| 904 | + |
| 905 | + fw_out_n = fw_g.output_node() |
| 906 | + fw_out_args = list(fw_out_n.args[0]) |
| 907 | + fw_outs_insert_tensors = [] |
| 908 | + fw_outs_insert_non_tensors = [] |
| 909 | + for saved in fw_outs_saved_for_bw: |
| 910 | + val = saved.meta["val"] |
| 911 | + if isinstance(val, torch.Tensor): |
| 912 | + pack_gm = make_fx(pack_hook)(val) |
| 913 | + pack_g = pack_gm.graph |
| 914 | + print(f"XXX PACK_GRAPH:{pack_g}") |
| 915 | + pack_out_val = pack_gm(val) |
| 916 | + # Install pack_g as an epilogue of fw_module and replace the saved outputs with pack_g's outputs |
| 917 | + pack_g_inputs = pack_g.find_nodes(op="placeholder") |
| 918 | + assert len(pack_g_inputs) == 1 |
| 919 | + env = {pack_g_inputs[0]: saved} |
| 920 | + with fw_g.inserting_after(saved): |
| 921 | + new_out_n = None |
| 922 | + for node in pack_g.nodes: |
| 923 | + if node.op == "placeholder": |
| 924 | + continue |
| 925 | + new_n = fw_g.node_copy(node, lambda n: env[n]) |
| 926 | + env[node] = new_n |
| 927 | + if node.op == "output": |
| 928 | + new_out_n = new_n |
| 929 | + |
| 930 | + assert new_out_n |
| 931 | + for n in pytree.tree_leaves(new_out_n.args[0]): |
| 932 | + if not isinstance(n, torch.fx.Node): |
| 933 | + continue |
| 934 | + |
| 935 | + out_val = n.meta["val"] |
| 936 | + if isinstance(out_val, torch.Tensor): |
| 937 | + fw_outs_insert_tensors.append(n) |
| 938 | + elif is_sym_node(n): |
| 939 | + fw_outs_insert_non_tensors.append(n) |
| 940 | + |
| 941 | + fw_g.erase_node(new_out_n) |
| 942 | + |
| 943 | + # Install unpack_g as prologue of bw_module |
| 944 | + unpack_gm = make_fx(unpack_hook)(pack_out_val) |
| 945 | + unpack_out_val = unpack_gm(pack_out_val) |
| 946 | + unpack_g = unpack_gm.graph |
| 947 | + print(f"XXX PACK_OUT_VAL:{pack_out_val}") |
| 948 | + print(f"XXX UNPACK_OUT_VAL:{unpack_out_val}") |
| 949 | + print(f"XXX UNPACK_GRAPH:{unpack_g}") |
| 950 | + |
| 951 | + |
| 952 | + def find_saved_in_bw_inputs(bw_inputs): |
| 953 | + for n in bw_inputs: |
| 954 | + # TODO: XXX Recheck validity of this identification (matching by node name) |
| 955 | + if n.name == saved.name: |
| 956 | + return n |
| 957 | + |
| 958 | + bw_g_input = find_saved_in_bw_inputs(bw_g_inputs) |
| 959 | + assert bw_g_input |
| 960 | + # Replace bw_g input with copy of output of pack |
| 961 | + |
| 962 | + unpack_g_inputs = unpack_g.find_nodes(op="placeholder") |
| 963 | + env = {} |
| 964 | + # Adding unpack inputs to bw graph instead of saved |
| 965 | + for unp_in_n, val in zip( |
| 966 | + unpack_g_inputs, pytree.tree_leaves(pack_out_val) |
| 967 | + ): |
| 968 | + is_sym = isinstance(val, py_sym_types) |
| 969 | + if isinstance(val, torch.Tensor) or is_sym: |
| 970 | + new_node_name = bw_g_input.name + "_" + unp_in_n.name |
| 971 | + # Backward calling convention: ctx_symints...ctx_saved_tensors... |
| 972 | + if is_sym: |
| 973 | + with bw_g.inserting_before(bw_g_inputs[0]): |
| 974 | + new_n = bw_g.placeholder(new_node_name) |
| 975 | + else: |
| 976 | + with bw_g.inserting_before(bw_g_inputs[num_saved_for_bw]): |
| 977 | + new_n = bw_g.placeholder(new_node_name) |
| 978 | + new_n.meta["val"] = val |
| 979 | + env[unp_in_n] = new_n |
| 980 | + else: |
| 981 | + # Inline the values of leaves that are neither Tensors nor sym scalars |
| 982 | + env[unp_in_n] = val |
| 983 | + |
| 984 | + new_out_n = None |
| 985 | + with bw_g.inserting_before(bw_g_inputs[-1]): |
| 986 | + for node in unpack_g.nodes: |
| 987 | + if node.op == "placeholder": |
| 988 | + continue |
| 989 | + new_n = bw_g.node_copy(node, lambda n: env[n]) |
| 990 | + env[node] = new_n |
| 991 | + if node.op == "output": |
| 992 | + new_out_n = new_n |
| 993 | + |
| 994 | + # TODO XXX: Debug why the unpack graph produces [node] instead of node |
| 995 | + # for an unpack function like: |
| 996 | + # def unpack_dev_sym_cpu(packed): |
| 997 | + # device, dim0, tensor = packed |
| 998 | + # return tensor.to(device=device) * dim0 |
| 999 | + # |
| 1000 | + # assert len(new_out_n.args) == 1 |
| 1001 | + # print(f"XXX NEW_OUT_N.ARGS:{new_out_n.args}") |
| 1002 | + # unpack_saved_tensor_n = new_out_n.args[0] |
| 1003 | + unpack_saved_tensor_n = pytree.tree_leaves(new_out_n.args)[0] |
| 1004 | + |
| 1005 | + bw_g_input.replace_all_uses_with(unpack_saved_tensor_n) |
| 1006 | + bw_g.erase_node(new_out_n) |
| 1007 | + bw_g.erase_node(bw_g_input) |
| 1008 | + fw_out_n.args = ( |
| 1009 | + tuple( |
| 1010 | + pytree.tree_leaves(( |
| 1011 | + fw_outs[:num_inner_fwd_outputs], |
| 1012 | + fw_outs_insert_tensors, |
| 1013 | + fw_outs_insert_non_tensors, |
| 1014 | + symint_outs_saved_for_bw, |
| 1015 | + )) |
| 1016 | + ), |
| 1017 | + ) |
| 1018 | + print(f"\nXXX FW_GRAPH AFTER:{fw_g}") |
| 1019 | + print(f"\nXXX BW_GRAPH AFTER:{bw_g}") |
| 1020 | + fw_g.lint() |
| 1021 | + bw_g.lint() |
| 1022 | + # TODO: Refactor the recomputation below to deduplicate it with the code above |
| 1023 | + fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0] |
| 1024 | + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] |
| 1025 | + num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw) |
| 1026 | + num_saved_for_bw = len(fw_outs_saved_for_bw) |
| 1027 | + symint_outs_saved_for_bw = [ |
| 1028 | + n for n in fw_outs_saved_for_bw if is_sym_node(n) |
| 1029 | + ] |
| 1030 | + num_symints_saved_for_bw = len(symint_outs_saved_for_bw) |
| 1031 | + num_saved_tensors = len(fw_outs_saved_for_bw) - num_symints_saved_for_bw |
| 1032 | + |
884 | 1033 | fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw) |
885 | 1034 | inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw) |
886 | | - num_symints_saved_for_bw = len(symint_outs_saved_for_bw) |
| 1035 | + |
887 | 1036 |
|
888 | 1037 | if torch._functorch.config.donated_buffer: |
889 | 1038 | fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs( |
|
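For context, a minimal sketch (not part of this patch) of the eager `torch.autograd.graph.saved_tensors_hooks` API whose pack/unpack callables the code above traces with `make_fx` and splices into the compiled forward and backward graphs. The hook bodies are illustrative only; any real hooks installed by the user would be traced the same way.

```python
import torch

def pack_to_cpu(t):
    # Pack hook: runs when autograd saves `t` for backward; whatever it returns
    # is what gets stored (here the original device plus a CPU copy).
    return t.device, t.to("cpu")

def unpack_from_cpu(packed):
    # Unpack hook: runs when backward needs the saved tensor back.
    device, t = packed
    return t.to(device)

x = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = x.sin()  # sin saves its input for backward, so pack_to_cpu runs here
y.sum().backward()  # unpack_from_cpu runs when the saved input is needed
```

In the patch above, `pack_hook` becomes a forward-graph epilogue (its outputs replace the saved activations in the forward outputs) and `unpack_hook` becomes a backward-graph prologue (its output replaces the corresponding saved-tensor placeholder).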
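A rough sketch of the tracing-and-splicing pattern the patch relies on, under the assumption that the hook is traceable by `make_fx`: the hook is traced into its own FX graph, and its nodes are copied into a host graph via `node_copy` with an `env` mapping, after which the copied output node is erased so only the computation remains. `pack_to_bf16` and `host` are hypothetical stand-ins, not names from the patch.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def pack_to_bf16(t):
    # Hypothetical pack hook: downcast the saved tensor to bfloat16.
    return t.to(torch.bfloat16)

def host(x):
    # Stand-in for the traced forward function.
    return x.sin()

pack_gm = make_fx(pack_to_bf16)(torch.randn(4))
host_gm = make_fx(host)(torch.randn(4))
host_g = host_gm.graph

# Splice the traced hook in right after the node whose value it should pack,
# mirroring the forward-epilogue insertion in the diff.
anchor = next(iter(host_g.find_nodes(op="call_function")))        # the aten.sin node
pack_in = next(iter(pack_gm.graph.find_nodes(op="placeholder")))  # the hook's input
env = {pack_in: anchor}
out_copy = None
with host_g.inserting_after(anchor):
    for node in pack_gm.graph.nodes:
        if node.op == "placeholder":
            continue
        new_n = host_g.node_copy(node, lambda n: env[n])
        env[node] = new_n
        if node.op == "output":
            out_copy = new_n

# Drop the copied output node; the copied compute nodes stay in the host graph
# and can be wired into the host graph's own outputs.
host_g.erase_node(out_copy)
host_g.lint()
host_gm.recompile()
```

The same `env`-driven `node_copy` loop is what the patch uses twice: once to append `pack_g` to `fw_module` and once to prepend `unpack_g` to `bw_module`.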