# it is important to know when operations performed in a custom function
# are recorded by autograd, when they aren't, and most importantly, how
# `save_for_backward` works with all of this.
#
# Custom functions implicitly affect grad mode in two ways:
#
# - During forward, autograd does not record the graph for any
#   operations performed within the forward function. When forward
#   completes, the backward function of the custom function
#   becomes the `grad_fn` of each of the forward's outputs.
#
# - During backward, autograd records the computation graph used to
#   compute the backward pass if `create_graph` is specified.
#
# Next, to understand how `save_for_backward` interacts with the above,
# we can explore a couple of examples:

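######################################################################
# Before moving on to those examples, here is a minimal sketch of the first
# point above (an addition to the original tutorial; the `Peek` function is
# purely illustrative): operations performed inside a custom function's
# forward are not recorded by autograd, while the custom backward becomes
# the `grad_fn` of the output.

import torch

class Peek(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        tmp = x * 2
        print(tmp.grad_fn)  # None: this multiply was not recorded by autograd
        return tmp

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

y = Peek.apply(torch.tensor(1., requires_grad=True))
print(y.grad_fn)  # the custom Function's backward node (e.g. PeekBackward)
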
######################################################################
# Saving the Inputs
# -------------------------------------------------------------------
# Consider this simple squaring function. It saves an input tensor
# for backward. Double backward works automatically when autograd
# is able to record operations in the backward pass, so there is usually
# nothing to worry about when we save an input for backward: the input
# should have a `grad_fn` if it is a function of any tensor that requires
# grad. This allows gradients to be properly propagated.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Because we are saving one of the inputs, use `save_for_backward`.
        # Save non-tensors and non-inputs/non-outputs directly on ctx.
        ctx.save_for_backward(x)
        return x**2

    @staticmethod
    def backward(ctx, grad_out):
        # A function supports double backward automatically if autograd
        # is able to record the computations performed in backward.
        x, = ctx.saved_tensors
        return grad_out * 2 * x

# Use double precision because the finite-differencing method magnifies errors
x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Square.apply, x)
# Use gradgradcheck to verify second-order derivatives
torch.autograd.gradgradcheck(Square.apply, x)

######################################################################
# We can use torchviz to visualize the graph to see why this works.
#
# .. code-block:: python
#
#    import torchviz
#
#    x = torch.tensor(1., requires_grad=True).clone()
#    out = Square.apply(x)
#    grad_x, = torch.autograd.grad(out, x, create_graph=True)
#    torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})
#
# We can see that the gradient with respect to x is itself a function of x
# (dout/dx = 2x), and the graph of this function has been properly constructed.
#
# .. image:: https://user-images.githubusercontent.com/13428986/126559699-e04f3cb1-aaf2-4a9a-a83d-b8767d04fbd9.png
#    :width: 400
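
######################################################################
# As a quick numerical check (an addition to the original example), we can
# differentiate `grad_x` once more and confirm that the second derivative
# of `x**2` is the constant 2:

x = torch.tensor(1., requires_grad=True, dtype=torch.double)
out = Square.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
grad2_x, = torch.autograd.grad(grad_x, x)
print(grad2_x)  # tensor(2., dtype=torch.float64)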

######################################################################
# Saving the Outputs
# -------------------------------------------------------------------
# A slight variation on the previous example is to save an output
# instead of an input. The mechanics are similar because outputs are also
# associated with a grad_fn.
class Exp(torch.autograd.Function):
    # Simple case where everything goes well
    @staticmethod
    def forward(ctx, x):
        # This time we save the output
        result = torch.exp(x)
        # Note that we should use `save_for_backward` here when
        # the tensor saved is an output (or an input).
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_out):
        result, = ctx.saved_tensors
        return result * grad_out

x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
# Validate our gradients using gradcheck
torch.autograd.gradcheck(Exp.apply, x)
torch.autograd.gradgradcheck(Exp.apply, x)

######################################################################
# Use torchviz to visualize the graph:
#
# .. code-block:: python
#
#    out = Exp.apply(x)
#    grad_x, = torch.autograd.grad(out, x, create_graph=True)
#    torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})
#
# .. image:: https://user-images.githubusercontent.com/13428986/126559780-d141f2ba-1ee8-4c33-b4eb-c9877b27a954.png
#    :width: 332
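
######################################################################
# As a quick numerical check (an addition to the original example), the
# second derivative of `exp` is `exp` itself, and autograd recovers it
# because the saved output carries a `grad_fn`:

out = Exp.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
grad2_x, = torch.autograd.grad(grad_x, x)
print(torch.allclose(grad2_x, out))  # True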

######################################################################
# Saving Intermediate Results
# -------------------------------------------------------------------
# A trickier case is when we need to save an intermediate result.
# We demonstrate this case by implementing:
#
# .. math::
#    sinh(x) := \frac{e^x - e^{-x}}{2}
#
# Since the derivative of sinh is cosh, it might be useful to reuse
# `exp(x)` and `exp(-x)`, the two intermediate results from forward,
# in the backward computation.
#
# Intermediate results should not be directly saved and used in backward,
# though. Because forward is performed in no-grad mode, if an intermediate
# result of the forward pass is used to compute gradients in the backward
# pass, the backward graph of the gradients would not include the operations
# that computed the intermediate result. This leads to incorrect gradients.
class Sinh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        expx = torch.exp(x)
        expnegx = torch.exp(-x)
        ctx.save_for_backward(expx, expnegx)
        # In order to be able to save the intermediate results, a trick is to
        # include them as our outputs, so that the backward graph is constructed
        return (expx - expnegx) / 2, expx, expnegx

    @staticmethod
    def backward(ctx, grad_out, _grad_out_exp, _grad_out_negexp):
        expx, expnegx = ctx.saved_tensors
        grad_input = grad_out * (expx + expnegx) / 2
        # We cannot skip accumulating these even though we won't use the outputs
        # directly. They will be used later in the second backward.
        grad_input += _grad_out_exp * expx
        grad_input -= _grad_out_negexp * expnegx
        return grad_input

def sinh(x):
    # Create a wrapper that only returns the first output
    return Sinh.apply(x)[0]

x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(sinh, x)
torch.autograd.gradgradcheck(sinh, x)

######################################################################
# Use torchviz to visualize the graph:
#
# .. code-block:: python
#
#    out = sinh(x)
#    grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
#    torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
#
# .. image:: https://user-images.githubusercontent.com/13428986/126560494-e48eba62-be84-4b29-8c90-a7f6f40b1438.png
#    :width: 460
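
######################################################################
# As a small sanity check (an addition to the original example), our custom
# sinh matches `torch.sinh`, and its second derivative (which is sinh again)
# flows through the intermediates we returned as extra outputs:

out = sinh(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
grad2_x, = torch.autograd.grad(grad_x.sum(), x)
print(torch.allclose(out, torch.sinh(x)))      # True
print(torch.allclose(grad2_x, torch.sinh(x)))  # True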

######################################################################
# Saving Intermediate Results: What not to do
# -------------------------------------------------------------------
# Now we show what happens when we don't also return our intermediate
# results as outputs: `grad_x` would not even have a backward graph,
# because it is purely a function of `expx` and `expnegx`, which don't
# require grad.
class SinhBad(torch.autograd.Function):
    # This is an example of what NOT to do!
    @staticmethod
    def forward(ctx, x):
        expx = torch.exp(x)
        expnegx = torch.exp(-x)
        ctx.expx = expx
        ctx.expnegx = expnegx
        return (expx - expnegx) / 2

    @staticmethod
    def backward(ctx, grad_out):
        expx = ctx.expx
        expnegx = ctx.expnegx
        grad_input = grad_out * (expx + expnegx) / 2
        return grad_input

######################################################################
# Use torchviz to visualize the graph. Notice that `grad_x` is not
# part of the graph!
#
# .. code-block:: python
#
#    out = SinhBad.apply(x)
#    grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
#    torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
#
# .. image:: https://user-images.githubusercontent.com/13428986/126565889-13992f01-55bc-411a-8aee-05b721fe064a.png
#    :width: 232
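
######################################################################
# As a quick check (an addition to the original example), the gradient
# produced by `SinhBad` has no `grad_fn` even when `create_graph=True` is
# passed, so a second backward through it is impossible:

x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
out = SinhBad.apply(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
print(grad_x.grad_fn)  # None: the backward of SinhBad was not tracked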

######################################################################
# When Backward is not Tracked
# -------------------------------------------------------------------
# Finally, let's consider an example where it may not be possible for
# autograd to track gradients for a function's backward at all.
# We can imagine cube_backward to be a function that may require a
# non-PyTorch library like SciPy or NumPy, or one written as a
# C++ extension. The workaround demonstrated here is to create another
# custom function, CubeBackward, where you also manually specify the
# backward of cube_backward!

def cube_forward(x):
    return x**3

def cube_backward(grad_out, x):
    return grad_out * 3 * x**2

def cube_backward_backward(grad_out, sav_grad_out, x):
    return grad_out * sav_grad_out * 6 * x

def cube_backward_backward_grad_out(grad_out, x):
    return grad_out * 3 * x**2

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return cube_forward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return CubeBackward.apply(grad_out, x)

class CubeBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_out, x):
        ctx.save_for_backward(x, grad_out)
        return cube_backward(grad_out, x)

    @staticmethod
    def backward(ctx, grad_out):
        x, sav_grad_out = ctx.saved_tensors
        dx = cube_backward_backward(grad_out, sav_grad_out, x)
        dgrad_out = cube_backward_backward_grad_out(grad_out, x)
        return dgrad_out, dx

x = torch.tensor(2., requires_grad=True, dtype=torch.double)

torch.autograd.gradcheck(Cube.apply, x)
torch.autograd.gradgradcheck(Cube.apply, x)

######################################################################
# Use torchviz to visualize the graph:
#
# .. code-block:: python
#
#    out = Cube.apply(x)
#    grad_x, = torch.autograd.grad(out, x, create_graph=True)
#    torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
#
# .. image:: https://user-images.githubusercontent.com/13428986/126559935-74526b4d-d419-4983-b1f0-a6ee99428531.png
#    :width: 352
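
######################################################################
# As a quick numerical check (an addition to the original example), at x = 2
# the second derivative of x**3 is 6 * x = 12, and it flows through the
# nested CubeBackward function:

out = Cube.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
grad2_x, = torch.autograd.grad(grad_x, x)
print(grad2_x)  # tensor(12., dtype=torch.float64)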

######################################################################
# To conclude, whether double backward works for your custom function
# simply depends on whether the backward pass can be tracked by autograd.
# In the first two examples we show situations where double backward
# works out of the box. In the third and fourth examples, we demonstrate
# techniques that enable a backward function to be tracked when it
# otherwise would not be.