|
14 | 14 | from torch.autograd.gradcheck import gradgradcheck, gradcheck |
15 | 15 | from torch.autograd.function import once_differentiable |
16 | 16 | from torch.autograd.profiler import profile |
| 17 | +from torch.utils.checkpoint import checkpoint |
17 | 18 | from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack, |
18 | 19 | suppress_warnings, skipIfRocm, |
19 | 20 | prod_single_zero, random_square_matrix_of_rank, |
20 | 21 | random_symmetric_matrix, random_symmetric_psd_matrix, |
21 | 22 | random_symmetric_pd_matrix, make_nonzero_det, |
22 | 23 | random_fullrank_matrix_distinct_singular_value, load_tests) |
| 24 | +from common_cuda import TEST_CUDA |
23 | 25 | from torch.autograd import Variable, Function, detect_anomaly |
24 | 26 | from torch.autograd.function import InplaceFunction |
25 | 27 | from torch.testing import make_non_contiguous, randn_like |
@@ -2722,6 +2724,36 @@ def fn(sparse): |
2722 | 2724 | with self.assertRaisesRegex(RuntimeError, 'gradcheck expects all tensor inputs are dense'): |
2723 | 2725 | gradcheck(fn, torch.rand(10).to_sparse().requires_grad_(True), check_sparse_nnz=False) |
2724 | 2726 |
|
| 2727 | + @unittest.skipIf(not TEST_CUDA, "Requires cuda for multi device") |
| 2728 | + def test_multi_device_reentrant_autograd(self): |
| 2729 | + # The output lives on the gpu so that this task is associated with the gpu thread |
| 2730 | + def fn_on_gpu(inp): |
| 2731 | + # `dummy` is intentionally unused: it artificially increases the priority of the |
| 2732 | + # next op to make sure it runs as soon as we reach it, before the ops of branch1. |
| 2733 | + dummy = inp * 2 * 2 * 2 * 2 |
| 2734 | + return inp.cuda() |
| 2735 | + |
| 2736 | + def parent_on_cpu(inp): |
| 2737 | + # Slow branch of ops on the gpu so that the work queue for the gpu thread |
| 2738 | + # does not empty too quickly. These ops also have lower priority than the |
| 2739 | + # ones created by fn_on_gpu. |
| 2740 | + branch1 = inp.cuda() |
| 2741 | + branch1 = branch1 / branch1 |
| 2742 | + branch1 = branch1 / branch1 |
| 2743 | + branch1 = branch1 / branch1 |
| 2744 | + # Run the checkpoint on cpu tensors, so that the last op performed in the reentrant |
| 2745 | + # autograd is an AccumulateGrad that runs on the cpu thread on behalf of the gpu thread. |
| 2746 | + # The cpu thread will then notify the gpu thread with an empty FunctionTask. |
| 2747 | + branch2 = checkpoint(fn_on_gpu, inp) |
| 2748 | + out = branch2 + branch1 |
| 2749 | + return out |
| 2750 | + |
| 2751 | + inp = torch.rand(2, requires_grad=True) |
| 2752 | + out = parent_on_cpu(inp) |
| 2753 | + # The backward pass will segfault if the empty FunctionTask is not handled properly |
| 2754 | + # in the gpu thread ReadyQueue. |
| 2755 | + out.sum().backward() |
| 2756 | + |
2725 | 2757 |
|
2726 | 2758 | def index_variable(shape, max_indices): |
2727 | 2759 | if not isinstance(shape, tuple): |
|
0 commit comments