

Commit fc15597

Test
1 parent a5a3c24 commit fc15597

1 file changed

Lines changed: 0 additions & 270 deletions

File tree

intermediate_source/custom_function_double_backward_tutorial.py

Lines changed: 0 additions & 270 deletions
@@ -16,274 +16,4 @@
# it is important to know when operations performed in a custom function
# are recorded by autograd, when they aren't, and most importantly, how
# `save_for_backward` works with all of this.
#
# Custom functions implicitly affect grad mode in two ways:
#
# - During forward, autograd does not record the graph for any
#   operations performed within the forward function. When forward
#   completes, the backward function of the custom function
#   becomes the `grad_fn` of each of the forward's outputs.
#
# - During backward, autograd records the computation graph used to
#   compute the backward pass if `create_graph` is specified.
#
# Next, to understand how `save_for_backward` interacts with the above,
# we can explore a couple of examples:
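
# To make this concrete, here is a minimal sketch using a hypothetical
# pass-through `Identity` function (not one of the examples below). It
# illustrates both points: the multiplication inside `forward` is not
# recorded, while the output of `apply` gets the custom backward as its
# `grad_fn`:
#
# .. code-block:: python
#
#     import torch
#
#     class Identity(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x):
#             y = x * 1.0
#             print(y.grad_fn)   # None: forward runs in no-grad mode
#             return y
#
#         @staticmethod
#         def backward(ctx, grad_out):
#             return grad_out
#
#     x = torch.tensor(1., requires_grad=True)
#     out = Identity.apply(x)
#     print(out.grad_fn)         # the custom function's backward node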

######################################################################
# Saving the Inputs
# -------------------------------------------------------------------
# Consider this simple squaring function. It saves an input tensor
# for backward. Double backward works automatically when autograd
# is able to record operations in the backward pass, so there is usually
# nothing to worry about when we save an input for backward, as
# the input should have a grad_fn if it is a function of any tensor
# that requires grad. This allows the gradients to be properly propagated.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Because we are saving one of the inputs, use `save_for_backward`.
        # Save non-tensors and non-inputs/non-outputs directly on ctx.
        ctx.save_for_backward(x)
        return x**2

    @staticmethod
    def backward(ctx, grad_out):
        # A function supports double backward automatically if autograd
        # is able to record the computations performed in backward.
        x, = ctx.saved_tensors
        return grad_out * 2 * x

# Use double precision because the finite differencing method magnifies errors
x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Square.apply, x)
# Use gradgradcheck to verify second-order derivatives
torch.autograd.gradgradcheck(Square.apply, x)
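
# We could also take the second derivative by hand to see this in action;
# a minimal sketch (the sample value of `x` here is only for illustration):
#
# .. code-block:: python
#
#     x = torch.tensor(3., requires_grad=True, dtype=torch.double)
#     out = Square.apply(x)
#     grad_x, = torch.autograd.grad(out, x, create_graph=True)  # dout/dx = 2x = 6
#     grad2_x, = torch.autograd.grad(grad_x, x)                 # d2out/dx2 = 2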

######################################################################
# We can use torchviz to visualize the graph to see why this works
#
# .. code-block:: python
#
#     import torchviz
#
#     x = torch.tensor(1., requires_grad=True).clone()
#     out = Square.apply(x)
#     grad_x, = torch.autograd.grad(out, x, create_graph=True)
#     torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})
#
# We can see that the gradient with respect to x is itself a function of x
# (dout/dx = 2x), and the graph of this function has been properly constructed.
#
# .. image:: https://user-images.githubusercontent.com/13428986/126559699-e04f3cb1-aaf2-4a9a-a83d-b8767d04fbd9.png
#    :width: 400

######################################################################
# Saving the Outputs
# -------------------------------------------------------------------
# A slight variation on the previous example is to save an output
# instead of an input. The mechanics are similar because outputs are also
# associated with a grad_fn.
class Exp(torch.autograd.Function):
    # Simple case where everything goes well
    @staticmethod
    def forward(ctx, x):
        # This time we save the output
        result = torch.exp(x)
        # Note that we should use `save_for_backward` here when
        # the tensor saved is an output (or an input).
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_out):
        result, = ctx.saved_tensors
        return result * grad_out

x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
# Validate our gradients using gradcheck
torch.autograd.gradcheck(Exp.apply, x)
torch.autograd.gradgradcheck(Exp.apply, x)
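
# Because the saved output is itself associated with Exp's grad_fn, the
# backward computation `result * grad_out` can be differentiated again.
# As a quick sanity check, a minimal sketch of the second derivative,
# which should again equal exp(x):
#
# .. code-block:: python
#
#     out = Exp.apply(x)
#     grad_x, = torch.autograd.grad(out, x, create_graph=True)
#     grad2_x, = torch.autograd.grad(grad_x, x)
#     assert torch.allclose(grad2_x, out)  # d2/dx2 exp(x) = exp(x)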

######################################################################
# Use torchviz to visualize the graph:
#
# .. code-block:: python
#
#     out = Exp.apply(x)
#     grad_x, = torch.autograd.grad(out, x, create_graph=True)
#     torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})
#
# .. image:: https://user-images.githubusercontent.com/13428986/126559780-d141f2ba-1ee8-4c33-b4eb-c9877b27a954.png
#    :width: 332

######################################################################
# Saving Intermediate Results
# -------------------------------------------------------------------
# A trickier case is when we need to save an intermediate result.
# We demonstrate this case by implementing:
#
# .. math::
#    \sinh(x) := \frac{e^x - e^{-x}}{2}
#
# Since the derivative of sinh is cosh, it might be useful to reuse
# `exp(x)` and `exp(-x)`, the two intermediate results from forward,
# in the backward computation.
#
# Intermediate results should not be directly saved and used in backward though.
# Because forward is performed in no-grad mode, if an intermediate result
# of the forward pass is used to compute gradients in the backward pass,
# the backward graph of the gradients would not include the operations
# that computed the intermediate result. This leads to incorrect gradients.
class Sinh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        expx = torch.exp(x)
        expnegx = torch.exp(-x)
        ctx.save_for_backward(expx, expnegx)
        # In order to be able to save the intermediate results, a trick is to
        # include them as our outputs, so that the backward graph is constructed
        return (expx - expnegx) / 2, expx, expnegx

    @staticmethod
    def backward(ctx, grad_out, _grad_out_exp, _grad_out_negexp):
        expx, expnegx = ctx.saved_tensors
        grad_input = grad_out * (expx + expnegx) / 2
        # We cannot skip accumulating these even though we won't use the outputs
        # directly. They will be used later in the second backward.
        grad_input += _grad_out_exp * expx
        grad_input -= _grad_out_negexp * expnegx
        return grad_input

def sinh(x):
    # Create a wrapper that only returns the first output
    return Sinh.apply(x)[0]

x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(sinh, x)
torch.autograd.gradgradcheck(sinh, x)
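
# As an extra sanity check, we could compare the first and second derivatives
# of our custom `sinh` against PyTorch's built-in `torch.sinh`; a minimal sketch:
#
# .. code-block:: python
#
#     grad_x, = torch.autograd.grad(sinh(x).sum(), x, create_graph=True)
#     grad_x_ref, = torch.autograd.grad(torch.sinh(x).sum(), x, create_graph=True)
#     assert torch.allclose(grad_x, grad_x_ref)      # both equal cosh(x)
#     grad2_x, = torch.autograd.grad(grad_x.sum(), x)
#     grad2_x_ref, = torch.autograd.grad(grad_x_ref.sum(), x)
#     assert torch.allclose(grad2_x, grad2_x_ref)    # both equal sinh(x)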

######################################################################
# Use torchviz to visualize the graph:
#
# .. code-block:: python
#
#     out = sinh(x)
#     grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
#     torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
#
# .. image:: https://user-images.githubusercontent.com/13428986/126560494-e48eba62-be84-4b29-8c90-a7f6f40b1438.png
#    :width: 460

######################################################################
# Saving Intermediate Results: What not to do
# -------------------------------------------------------------------
# Now we show what happens when we don't also return our intermediate
# results as outputs: `grad_x` would not even have a backward graph
# because it is purely a function of `expx` and `expnegx`, which don't
# require grad.
class SinhBad(torch.autograd.Function):
    # This is an example of what NOT to do!
    @staticmethod
    def forward(ctx, x):
        expx = torch.exp(x)
        expnegx = torch.exp(-x)
        ctx.expx = expx
        ctx.expnegx = expnegx
        return (expx - expnegx) / 2

    @staticmethod
    def backward(ctx, grad_out):
        expx = ctx.expx
        expnegx = ctx.expnegx
        grad_input = grad_out * (expx + expnegx) / 2
        return grad_input
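
# A quick way to see the problem without torchviz: the first derivative still
# comes out numerically correct, but it carries no graph, so a second backward
# is impossible. A minimal sketch:
#
# .. code-block:: python
#
#     out = SinhBad.apply(x)
#     grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
#     print(grad_x.grad_fn)  # None, even though create_graph=True was passed
#     # torch.autograd.grad(grad_x.sum(), x)  # would raise a RuntimeError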

######################################################################
# Use torchviz to visualize the graph. Notice that `grad_x` is not
# part of the graph!
#
# .. code-block:: python
#
#     out = SinhBad.apply(x)
#     grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
#     torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
#
# .. image:: https://user-images.githubusercontent.com/13428986/126565889-13992f01-55bc-411a-8aee-05b721fe064a.png
#    :width: 232

######################################################################
# When Backward is not Tracked
# -------------------------------------------------------------------
# Finally, let's consider an example where it may not be possible for
# autograd to track gradients for a function's backward at all.
# We can imagine cube_backward to be a function that may require a
# non-PyTorch library like SciPy or NumPy, or be written as a
# C++ extension. The workaround demonstrated here is to create another
# custom function CubeBackward where you also manually specify the
# backward of cube_backward!

def cube_forward(x):
    return x**3

def cube_backward(grad_out, x):
    return grad_out * 3 * x**2

def cube_backward_backward(grad_out, sav_grad_out, x):
    return grad_out * sav_grad_out * 6 * x

def cube_backward_backward_grad_out(grad_out, x):
    return grad_out * 3 * x**2

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return cube_forward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return CubeBackward.apply(grad_out, x)

class CubeBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_out, x):
        ctx.save_for_backward(x, grad_out)
        return cube_backward(grad_out, x)

    @staticmethod
    def backward(ctx, grad_out):
        x, sav_grad_out = ctx.saved_tensors
        dx = cube_backward_backward(grad_out, sav_grad_out, x)
        dgrad_out = cube_backward_backward_grad_out(grad_out, x)
        return dgrad_out, dx

x = torch.tensor(2., requires_grad=True, dtype=torch.double)

torch.autograd.gradcheck(Cube.apply, x)
torch.autograd.gradgradcheck(Cube.apply, x)
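
# Taking the derivatives by hand shows the nesting at work: the first call to
# `grad` runs Cube's backward (which applies CubeBackward), and the second call
# runs CubeBackward's backward. A minimal sketch (the expected values assume x = 2):
#
# .. code-block:: python
#
#     out = Cube.apply(x)
#     grad_x, = torch.autograd.grad(out, x, create_graph=True)  # 3x^2 = 12
#     grad2_x, = torch.autograd.grad(grad_x, x)                 # 6x = 12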

######################################################################
# Use torchviz to visualize the graph:
#
# .. code-block:: python
#
#     out = Cube.apply(x)
#     grad_x, = torch.autograd.grad(out, x, create_graph=True)
#     torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
#
# .. image:: https://user-images.githubusercontent.com/13428986/126559935-74526b4d-d419-4983-b1f0-a6ee99428531.png
#    :width: 352

######################################################################
# To conclude, whether double backward works for your custom function
# simply depends on whether the backward pass can be tracked by autograd.
# With the first two examples we show situations where double backward
# works out of the box. With the third and fourth examples, we demonstrate
# techniques that enable a backward function to be tracked when it
# otherwise would not be.
