|
1 | 1 | { |
2 | 2 | "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Creating extensions using numpy and scipy\n", |
| 8 | + "\n", |
| 9 | + "In this notebook, we shall go through two tasks:\n", |
| 10 | + "\n", |
| 11 | + "1. Create a neural network layer with no parameters. \n", |
| 12 | + "   - This calls into **numpy** as part of its implementation\n", |
| 13 | + "2. Create a neural network layer that has learnable weights\n", |
| 14 | + "   - This calls into **SciPy** as part of its implementation" |
| 15 | + ] |
| 16 | + }, |
3 | 17 | { |
4 | 18 | "cell_type": "code", |
5 | | - "execution_count": 92, |
| 19 | + "execution_count": 37, |
6 | 20 | "metadata": { |
7 | 21 | "collapsed": false |
8 | 22 | }, |
9 | 23 | "outputs": [], |
10 | 24 | "source": [ |
11 | 25 | "import torch\n", |
12 | | - "from torch.autograd import Function" |
| 26 | + "from torch.autograd import Function\n", |
| 27 | + "from torch.autograd import Variable" |
13 | 28 | ] |
14 | 29 | }, |
15 | 30 | { |
16 | 31 | "cell_type": "markdown", |
17 | 32 | "metadata": {}, |
18 | 33 | "source": [ |
19 | | - "# Parameter-less example" |
| 34 | + "## Parameter-less example\n", |
| 35 | + "\n", |
| 36 | + "This layer doesn't particularly do anything useful or mathematically correct.\n", |
| 37 | + "\n", |
| 38 | + "It is aptly named BadFFTFunction\n", |
| 39 | + "\n", |
| 40 | + "**Layer Implementation**" |
20 | 41 | ] |
21 | 42 | }, |
22 | 43 | { |
23 | 44 | "cell_type": "code", |
24 | | - "execution_count": 93, |
| 45 | + "execution_count": 38, |
25 | 46 | "metadata": { |
26 | 47 | "collapsed": false |
27 | 48 | }, |
|
40 | 61 | " result = irfft2(numpy_go)\n", |
41 | 62 | " return torch.FloatTensor(result)\n", |
42 | 63 | "\n", |
| 64 | + "# since this layer does not have any parameters, we can\n", |
| 65 | + "# simply declare this as a function, rather than as an nn.Module class\n", |
43 | 66 | "def incorrect_fft(input):\n", |
44 | | - " return FFTFunction()(input)" |
| 67 | + " return BadFFTFunction()(input)" |
| 68 | + ] |
| 69 | + }, |
| 70 | + { |
| 71 | + "cell_type": "markdown", |
| 72 | + "metadata": {}, |
| 73 | + "source": [ |
| 74 | + "**Example usage of the created layer:**" |
45 | 75 | ] |
46 | 76 | }, |
47 | 77 | { |
48 | 78 | "cell_type": "code", |
49 | | - "execution_count": 94, |
| 79 | + "execution_count": 39, |
50 | 80 | "metadata": { |
51 | 81 | "collapsed": false |
52 | 82 | }, |
|
56 | 86 | "output_type": "stream", |
57 | 87 | "text": [ |
58 | 88 | "\n", |
59 | | - " 3.0878 7.1403 7.5860 1.7596 3.0176\n", |
60 | | - " 6.3160 15.2517 11.1081 0.9172 6.8577\n", |
61 | | - " 8.6503 2.2013 6.3555 11.1981 1.9266\n", |
62 | | - " 3.9919 6.8862 8.8132 5.7938 4.2413\n", |
63 | | - " 12.2501 10.7839 6.7181 12.1096 1.1942\n", |
64 | | - " 3.9919 9.3072 2.6704 3.3263 4.2413\n", |
65 | | - " 8.6503 6.8158 12.4148 2.6462 1.9266\n", |
66 | | - " 6.3160 15.2663 9.8261 5.8583 6.8577\n", |
| 89 | + " 4.7742 8.5149 9.8856 10.2735 8.4410\n", |
| 90 | + " 3.8592 2.2888 5.0019 5.9478 5.1993\n", |
| 91 | + " 4.6596 3.4522 5.9725 11.0878 7.8076\n", |
| 92 | + " 8.2634 6.6598 6.0634 15.5515 6.9418\n", |
| 93 | + " 0.6407 7.4943 0.8726 4.4138 7.1496\n", |
| 94 | + " 8.2634 6.8300 2.8353 8.3108 6.9418\n", |
| 95 | + " 4.6596 1.9511 6.3037 5.1471 7.8076\n", |
| 96 | + " 3.8592 7.3977 7.2260 1.6832 5.1993\n", |
67 | 97 | "[torch.FloatTensor of size 8x5]\n", |
68 | 98 | "\n", |
69 | 99 | "\n", |
70 | | - " 0.0569 -0.3193 0.0401 0.1293 0.0318 0.1293 0.0401 -0.3193\n", |
71 | | - " 0.0570 0.0161 -0.0421 -0.1272 0.0414 0.0121 -0.0592 -0.0874\n", |
72 | | - "-0.1144 -0.0146 0.0604 -0.0023 0.0222 0.0622 0.0825 -0.1057\n", |
73 | | - "-0.0451 0.1061 0.0329 -0.0274 0.0302 -0.0347 0.0227 -0.1079\n", |
74 | | - " 0.1287 0.1796 -0.0766 -0.0698 0.0929 -0.0698 -0.0766 0.1796\n", |
75 | | - "-0.0451 -0.1079 0.0227 -0.0347 0.0302 -0.0274 0.0329 0.1061\n", |
76 | | - "-0.1144 -0.1057 0.0825 0.0622 0.0222 -0.0023 0.0604 -0.0146\n", |
77 | | - " 0.0570 -0.0874 -0.0592 0.0121 0.0414 -0.1272 -0.0421 0.0161\n", |
| 100 | + " 0.1044 0.0067 -0.0247 -0.0800 -0.1355 -0.0800 -0.0247 0.0067\n", |
| 101 | + "-0.1948 -0.0138 -0.1396 -0.0084 0.0774 0.0370 0.1352 0.1332\n", |
| 102 | + "-0.0153 -0.0668 0.1799 0.0574 0.0394 0.1392 0.0268 -0.1462\n", |
| 103 | + " 0.0199 0.0676 -0.1475 -0.0332 0.1312 0.0740 -0.1128 -0.1948\n", |
| 104 | + "-0.0416 -0.0159 -0.0166 -0.0070 0.1471 -0.0070 -0.0166 -0.0159\n", |
| 105 | + " 0.0199 -0.1948 -0.1128 0.0740 0.1312 -0.0332 -0.1475 0.0676\n", |
| 106 | + "-0.0153 -0.1462 0.0268 0.1392 0.0394 0.0574 0.1799 -0.0668\n", |
| 107 | + "-0.1948 0.1332 0.1352 0.0370 0.0774 -0.0084 -0.1396 -0.0138\n", |
78 | 108 | "[torch.FloatTensor of size 8x8]\n", |
79 | 109 | "\n" |
80 | 110 | ] |
|
92 | 122 | "cell_type": "markdown", |
93 | 123 | "metadata": {}, |
94 | 124 | "source": [ |
95 | | - "# Parametrized example" |
| 125 | + "## Parametrized example\n", |
| 126 | + "\n", |
| 127 | + "This implements a layer with learnable weights.\n", |
| 128 | + "\n", |
| 129 | + "It implements the Cross-correlation with a learnable kernel.\n", |
| 130 | + "\n", |
| 131 | + "In deep learning literature, it's confusingly referred to as Convolution.\n", |
| 132 | + "\n", |
| 133 | + "The backward computes the gradients wrt the input and gradients wrt the filter.\n", |
| 134 | + "\n", |
| 135 | + "**Implementation:**\n", |
| 136 | + "\n", |
| 137 | + "*Please note that the implementation serves as an illustration, and we did not verify its correctness*" |
96 | 138 | ] |
97 | 139 | }, |
98 | 140 | { |
99 | 141 | "cell_type": "code", |
100 | | - "execution_count": 95, |
| 142 | + "execution_count": 40, |
101 | 143 | "metadata": { |
102 | 144 | "collapsed": false |
103 | 145 | }, |
|
116 | 158 | " def backward(self, grad_output):\n", |
117 | 159 | " input, filter = self.saved_tensors\n", |
118 | 160 | " grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')\n", |
119 | | - " grad_filter = convolve2d(grad_output.numpy(), input.numpy(), mode='valid')\n", |
| 161 | + " grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')\n", |
120 | 162 | " return torch.FloatTensor(grad_input), torch.FloatTensor(grad_filter)\n", |
121 | 163 | "\n", |
122 | 164 | "\n", |
|
131 | 173 | " return ScipyConv2dFunction()(input, self.filter)" |
132 | 174 | ] |
133 | 175 | }, |
| 176 | + { |
| 177 | + "cell_type": "markdown", |
| 178 | + "metadata": {}, |
| 179 | + "source": [ |
| 180 | + "**Example usage:**" |
| 181 | + ] |
| 182 | + }, |
134 | 183 | { |
135 | 184 | "cell_type": "code", |
136 | | - "execution_count": 96, |
| 185 | + "execution_count": 41, |
137 | 186 | "metadata": { |
138 | 187 | "collapsed": false |
139 | 188 | }, |
|
143 | 192 | "output_type": "stream", |
144 | 193 | "text": [ |
145 | 194 | "[Variable containing:\n", |
146 | | - "-1.5070 1.2195 0.3059\n", |
147 | | - "-0.9716 -1.6591 0.0582\n", |
148 | | - " 0.3959 1.4859 0.5762\n", |
| 195 | + "-1.0235 0.9875 0.2565\n", |
| 196 | + " 0.1980 -0.6102 0.1088\n", |
| 197 | + "-0.2887 0.4421 0.4697\n", |
149 | 198 | "[torch.FloatTensor of size 3x3]\n", |
150 | 199 | "]\n", |
151 | 200 | "Variable containing:\n", |
152 | | - " 0.8031 -2.6673 -3.7764 0.3957 -3.7494 -1.7617 -1.0052 -5.8402\n", |
153 | | - " 1.3038 6.2255 3.8769 2.4016 -1.7805 -3.1314 4.7049 11.2956\n", |
154 | | - " -3.4491 0.1618 -2.5647 2.3304 -0.2030 0.9072 -3.5095 -1.4599\n", |
155 | | - " 1.7574 0.6292 0.5140 -0.9045 -0.7373 -1.2061 -2.2977 3.6035\n", |
156 | | - " 0.4435 -1.0651 -0.5496 0.6387 1.7522 4.5231 -0.5720 -3.3034\n", |
157 | | - " -0.8580 -0.4809 2.4041 7.1462 -6.4747 -5.3665 2.0541 4.8248\n", |
158 | | - " -3.3959 0.2333 -0.2029 -2.6130 2.9378 2.5276 -0.8665 -2.6157\n", |
159 | | - " 4.6814 -5.2214 5.0351 0.9138 -5.0147 -3.1597 1.9054 -1.2458\n", |
| 201 | + " 0.7426 -0.4963 2.1839 -0.0167 -1.6349 -0.7259 -0.2989 0.0568\n", |
| 202 | + "-0.3100 2.2298 -2.2832 0.5753 4.0489 0.1377 0.1672 0.6429\n", |
| 203 | + "-1.8680 1.3115 1.8970 0.3323 -4.5448 -0.0464 -2.3960 1.5496\n", |
| 204 | + "-0.6578 0.6759 0.5512 -0.3498 2.6668 1.3984 1.9388 -1.6464\n", |
| 205 | + "-0.5867 0.5676 2.8697 -0.5566 -2.8876 1.2372 -1.1336 -0.0219\n", |
| 206 | + "-2.1587 1.1444 -0.5513 -0.5551 1.8229 0.6331 -0.0577 -1.4510\n", |
| 207 | + " 2.6664 1.4183 2.1640 0.4424 -0.3112 -2.0792 1.7458 -3.3291\n", |
| 208 | + "-0.4942 -2.1142 -0.2624 0.8993 1.4487 2.1706 -1.4943 0.8073\n", |
160 | 209 | "[torch.FloatTensor of size 8x8]\n", |
161 | 210 | "\n", |
162 | 211 | "\n", |
163 | | - " 0.1741 -1.9989 -0.2740 3.8120 0.3502 0.6712 3.0274 1.7058 0.4150 -0.3298\n", |
164 | | - "-1.8919 -2.6355 -3.2564 3.6947 2.5255 -6.7857 0.2239 -1.5672 -0.2663 -1.1211\n", |
165 | | - " 2.8815 2.5121 -4.7712 3.5822 -4.3752 0.7339 -0.7228 -1.7776 -2.0243 0.5019\n", |
166 | | - "-0.8926 0.1823 -4.3306 1.6298 1.4614 -1.5850 3.6988 3.1788 -1.2472 1.7891\n", |
167 | | - "-0.4497 2.5219 -0.0277 -2.5140 8.4283 -2.7177 -0.7160 2.5198 4.2670 -1.8847\n", |
168 | | - "-2.7016 -4.0250 2.7055 -0.6101 3.5926 0.5576 -1.8934 -3.3632 5.5995 -4.8563\n", |
169 | | - " 2.6918 -1.4062 1.1848 -1.7458 2.4408 0.9058 -3.6130 -3.0862 -0.1350 -1.6894\n", |
170 | | - "-0.2913 2.1607 4.0600 -1.4186 -4.5283 3.7960 -5.8559 -0.2632 -1.5944 1.9401\n", |
171 | | - " 0.4020 -2.5734 2.3380 -0.0078 -3.0894 3.5005 -1.3228 1.2757 0.7101 1.7986\n", |
172 | | - " 0.1187 -0.4283 -0.0142 -0.5494 -0.2744 0.8786 0.2644 0.7838 0.6230 0.4126\n", |
| 212 | + " 0.2528 0.6793 1.4519 0.8932 -1.6100 0.2802 0.7728 -1.7915 0.6271 -0.4103\n", |
| 213 | + " 1.1033 0.9326 -0.6076 0.0806 2.0530 -1.5469 -0.4001 2.3436 -1.4082 0.6746\n", |
| 214 | + "-2.2699 0.4997 -1.0990 -0.9396 -2.2007 -0.3414 -1.1383 1.5647 -0.8794 0.9267\n", |
| 215 | + "-0.0902 -2.0114 1.1145 -1.1107 0.4190 -0.7028 2.7191 -0.6072 1.3405 -0.2114\n", |
| 216 | + " 3.1340 -1.3749 0.5132 0.1247 1.3468 0.2727 -1.0975 0.5712 0.2452 -1.0394\n", |
| 217 | + "-1.7159 2.4817 -0.0412 -0.9571 0.8877 0.5806 0.1002 0.0128 -0.6611 -0.6181\n", |
| 218 | + "-1.6527 -2.9061 -3.1407 0.1848 -1.4983 0.1549 0.0607 -1.4082 0.7121 -0.5538\n", |
| 219 | + " 0.8319 2.1323 -0.5079 -1.8576 -0.9979 -1.6148 -1.2104 -0.2222 -0.6102 0.1271\n", |
| 220 | + "-0.0115 -0.5239 2.0231 1.3474 0.3604 1.7257 -0.3180 1.3881 0.0142 0.9140\n", |
| 221 | + "-0.0512 -0.3274 -0.1038 -0.1919 0.4578 1.0406 0.5750 1.0693 0.4735 0.4023\n", |
173 | 222 | "[torch.FloatTensor of size 10x10]\n", |
174 | 223 | "\n" |
175 | 224 | ] |
|
182 | 231 | "output = module(input)\n", |
183 | 232 | "print(output)\n", |
184 | 233 | "output.backward(torch.randn(8, 8))\n", |
185 | | - "print(input.grad)\n" |
| 234 | + "print(input.grad)" |
186 | 235 | ] |
187 | | - }, |
188 | | - { |
189 | | - "cell_type": "code", |
190 | | - "execution_count": null, |
191 | | - "metadata": { |
192 | | - "collapsed": true |
193 | | - }, |
194 | | - "outputs": [], |
195 | | - "source": [] |
196 | 236 | } |
197 | 237 | ], |
198 | 238 | "metadata": { |
199 | 239 | "kernelspec": { |
200 | | - "display_name": "Python 3", |
| 240 | + "display_name": "Python 2", |
201 | 241 | "language": "python", |
202 | | - "name": "python3" |
| 242 | + "name": "python2" |
203 | 243 | }, |
204 | 244 | "language_info": { |
205 | 245 | "codemirror_mode": { |
206 | 246 | "name": "ipython", |
207 | | - "version": 3 |
| 247 | + "version": 2 |
208 | 248 | }, |
209 | 249 | "file_extension": ".py", |
210 | 250 | "mimetype": "text/x-python", |
211 | 251 | "name": "python", |
212 | 252 | "nbconvert_exporter": "python", |
213 | | - "pygments_lexer": "ipython3", |
214 | | - "version": "3.5.2" |
| 253 | + "pygments_lexer": "ipython2", |
| 254 | + "version": "2.7.12" |
215 | 255 | } |
216 | 256 | }, |
217 | 257 | "nbformat": 4, |
|
0 commit comments