10 | 10 | import itertools |
11 | 11 | import unittest |
12 | 12 | import warnings |
13 | | -from contextlib import ContextDecorator, nullcontext |
| 13 | +from contextlib import ContextDecorator, ExitStack, nullcontext |
14 | 14 | from functools import partial, wraps |
15 | 15 | from typing import Any, Callable, Optional, Union |
16 | 16 | from unittest.mock import patch |
@@ -6601,6 +6601,116 @@ def _inp(): |
6601 | 6601 | self.assertEqual(1, len(ctx.tangent_strides)) |
6602 | 6602 | self.assertEqual((128, 4, 16, 1), ctx.tangent_strides[0]) |
6603 | 6603 |
|
| 6604 | + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| 6605 | + def test_saved_tensors_hooks(self): |
| 6606 | + def _test_pack_hooks(fn, inp_fn, hooks): |
| 6607 | + # TODO XXX: Add Dynamo ID_MATCH guards on hooks |
| 6608 | + torch._dynamo.reset() |
| 6609 | + with ExitStack() as stack: |
| 6610 | + for hook in hooks: |
| 6611 | + pack, unpack = hook |
| 6612 | + stack.enter_context( |
| 6613 | + torch.autograd.graph.saved_tensors_hooks(pack, unpack) |
| 6614 | + ) |
| 6615 | + ref_x = inp_fn() |
| 6616 | + x = ref_x.detach().clone().requires_grad_() |
| 6617 | + |
| 6618 | + ref_y = fn(ref_x) |
| 6619 | + ref_y.sum().backward() |
| 6620 | + |
| 6621 | + torch._dynamo.mark_dynamic(x, 0) |
| 6622 | + torch._dynamo.mark_dynamic(x, 1) |
| 6623 | + y = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) |
| 6624 | + y.sum().backward() |
| 6625 | + self.assertEqual(ref_y, y, atol=1e-2, rtol=1e-2) |
| 6626 | + self.assertEqual(ref_x.grad, x.grad, atol=1e-2, rtol=1e-2) |
| 6627 | + |
| 6628 | + class SAF(torch.autograd.Function): |
| 6629 | + @staticmethod |
| 6630 | + def forward(ctx, x): |
| 6631 | + ctx.save_for_backward(x) |
| 6632 | + return x |
| 6633 | + |
| 6634 | + @staticmethod |
| 6635 | + def backward(ctx, gx): |
| 6636 | + (saved_x,) = ctx.saved_tensors |
| 6637 | + return gx + saved_x |
| 6638 | + |
| 6639 | + class AF(torch.autograd.Function): |
| 6640 | + @staticmethod |
| 6641 | + def forward(ctx, x): |
| 6642 | + ctx.save_for_backward(x) |
| 6643 | + ctx.d1 = x.size(1) |
| 6644 | + return x |
| 6645 | + |
| 6646 | + @staticmethod |
| 6647 | + def backward(ctx, gx): |
| 6648 | + (saved_x,) = ctx.saved_tensors |
| 6649 | + d1 = ctx.d1 |
| 6650 | + return gx + saved_x * d1 |
| 6651 | + |
| 6652 | + def fn(x): |
| 6653 | + x = x.relu() |
| 6654 | + x = x + 1 |
| 6655 | + x = 2 * x |
| 6656 | + x = AF.apply(x) |
| 6657 | + return x |
| 6658 | + |
| 6659 | + def simple_fn(x): |
| 6660 | + x = x + 1 |
| 6661 | + x = SAF.apply(x) |
| 6662 | + return x |
| 6663 | + |
| 6664 | + device = torch.device("cuda:0") |
| 6665 | + |
| 6666 | + def inp_fn(): |
| 6667 | + return torch.ones(2, 3, device=device, requires_grad=True) |
| 6668 | + |
| 6669 | + def pack_dev_sym_cpu(x): |
| 6670 | + return (x.device, x.size(0), 10 * x.cpu()) |
| 6671 | + |
| 6672 | + def unpack_dev_sym_cpu(packed): |
| 6673 | + device, dim0, tensor = packed |
| 6674 | + ret = tensor.to(device=device) * dim0 |
| 6675 | + return ret |
| 6676 | + |
| 6677 | + def pack_tensor(x): |
| 6678 | + return x.cpu() |
| 6679 | + |
| 6680 | + def unpack_tensor(packed): |
| 6681 | + t_cpu = packed |
| 6682 | + return t_cpu.to(device=device) |
| 6683 | + |
| 6684 | + def pack_bf16(x): |
| 6685 | + return x.to(dtype=torch.bfloat16) |
| 6686 | + |
| 6687 | + def unpack_bf16(x): |
| 6688 | + return x.to(dtype=torch.float) |
| 6689 | + |
| 6690 | + def pack_mul2(x): |
| 6691 | + return x * 2 |
| 6692 | + |
| 6693 | + def unpack_mul2(x): |
| 6694 | + return x / 2 |
| 6695 | + |
| 6696 | + def pack_float8(x): |
| 6697 | + return (x.dtype, x.to(torch.float8_e4m3fn)) |
| 6698 | + |
| 6699 | + def unpack_float8(packed): |
| 6700 | + dtype, tensor = packed |
| 6701 | + return tensor.to(dtype) |
| 6702 | + |
| 6703 | + for test_fn in [simple_fn, fn]: |
| 6704 | + _test_pack_hooks(test_fn, inp_fn, [(pack_bf16, unpack_bf16)]) |
| 6705 | + _test_pack_hooks(test_fn, inp_fn, [(pack_mul2, unpack_mul2)]) |
| 6706 | + _test_pack_hooks( |
| 6707 | + test_fn, inp_fn, [(pack_mul2, unpack_mul2), (pack_bf16, unpack_bf16)] |
| 6708 | + ) |
| 6709 | + _test_pack_hooks(test_fn, inp_fn, [(pack_float8, unpack_float8)]) |
| 6710 | + _test_pack_hooks(test_fn, inp_fn, [(pack_tensor, unpack_tensor)]) |
| 6711 | + _test_pack_hooks(test_fn, inp_fn, [(pack_dev_sym_cpu, unpack_dev_sym_cpu)]) |
| 6712 | + # TODO XXX: Test packing/unpacking to subclasses |
| 6713 | + |
6604 | 6714 |
|
6605 | 6715 | # entries in here don't work and need to be fixed. |
6606 | 6716 | # Each one of these is a bug (or needs to be investigated) |
|
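For reference, a minimal standalone sketch of the torch.autograd.graph.saved_tensors_hooks API that the added test exercises (not part of the diff; the hook names mirror the pack_bf16/unpack_bf16 pair above and are illustrative only):

import torch

# pack runs when autograd saves a tensor for the backward pass; unpack runs
# when backward retrieves it. Here the saved activation is kept in bfloat16.
def pack_bf16(t):
    return t.to(dtype=torch.bfloat16)

def unpack_bf16(t):
    return t.to(dtype=torch.float32)

x = torch.randn(2, 3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_bf16, unpack_bf16):
    y = (x * x).sum()  # x is saved for backward through the hooks
y.backward()           # unpack_bf16 restores the saved tensor here

The new test wraps this context manager in an ExitStack so that several pack/unpack pairs can be nested, then compares eager gradients against torch.compile(backend="aot_eager") with dynamic dims marked on the input.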