@@ -58,7 +58,7 @@ def forward(self, x):
5858# we will need a function that accepts the parameters of the model and a single
5959# input (as opposed to a batch of inputs!) and returns a single output.
6060#
61- # We'll use ``torch.func.functional_call``, which allows us to call an nn.Module
61+ # We'll use ``torch.func.functional_call``, which allows us to call an ``nn.Module``
6262# using different parameters/buffers, to help accomplish the first step.
6363#
6464# Keep in mind that the model was originally written to accept a batch of input
@@ -200,21 +200,21 @@ def func_x2(params):
200200 output , vjp_fn = vjp (func_x1 , params )
201201
202202 def get_ntk_slice (vec ):
203- # This computes vec @ J(x2).T
203+ # This computes ``vec @ J(x2).T``
204204 # `vec` is some unit vector (a single slice of the Identity matrix)
205205 vjps = vjp_fn (vec )
206- # This computes J(X1) @ vjps
206+ # This computes ``J(X1) @ vjps``
207207 _ , jvps = jvp (func_x2 , (params ,), vjps )
208208 return jvps
209209
210210 # Here's our identity matrix
211211 basis = torch .eye (output .numel (), dtype = output .dtype , device = output .device ).view (output .numel (), - 1 )
212212 return vmap (get_ntk_slice )(basis )
213213
214- # get_ntk(x1, x2) computes the NTK for a single data point x1, x2
215- # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,
214+ # ``get_ntk(x1, x2)`` computes the NTK for a single data point x1, x2
215+ # Since the x1, x2 inputs to ``empirical_ntk_ntk_vps`` are batched,
216216 # we actually wish to compute the NTK between every pair of data points
217- # between {x1} and {x2}. That's what the vmaps here do.
217+ # between {x1} and {x2}. That's what the ``vmaps`` here do.
218218 result = vmap (vmap (get_ntk , (None , 0 )), (0 , None ))(x1 , x2 )
219219
220220 if compute == 'full' :
0 commit comments