forked from pytorch/tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnestedtensor.py
More file actions
495 lines (414 loc) · 22.1 KB
/
nestedtensor.py
File metadata and controls
495 lines (414 loc) · 22.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
"""
NestedTensors
===============================================================
NestedTensors are similar to regular tensors, except for their shape:
* for a regular tensor, each dimension has a size
* for a nestedtensor, not all dimensions have regular sizes; some of them are jagged
Nestedtensors are a natural solution for representing sequential data within various domains:
* in NLP, sentences can have variable lengths, so a batch of sentences forms a nestedtensor
* in CV, images can have variable shapes, so a batch of images forms a nestedtensor
In this tutorial, we will demonstrate basic usage of nestedtensors and motivate their usefulness
for operating on sequential data of varying lengths with a real-world example.
NestedTensors are currently a prototype feature and are subject to change.
"""
import torch
import torch.nn.functional as F
# Run the examples on a GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
######################################################################
# NestedTensor Initialization
# ---------------------------
#
######################################################################
# From the Python frontend, a nestedtensor can be created from a list of tensors.
# We denote nt[i] as the ith tensor component of a nestedtensor.
# Here the two components have shapes (2, 6) and (3, 6).
nt = torch.nested.nested_tensor([torch.arange(12).reshape(
    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)
print(f"{nt=}")
######################################################################
# By padding every underlying tensor to the same shape,
# a nestedtensor can be converted to a regular tensor.
padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)
print(f"{padded_out_tensor=}")
######################################################################
# All tensors possess an ``is_nested`` attribute that tells whether they are nested:
print(f"nt is nested: {nt.is_nested}")
print(f"padded_out_tensor is nested: {padded_out_tensor.is_nested}")
######################################################################
# It is common to construct nestedtensors from batches of irregularly shaped tensors,
# i.e. dimension 0 is assumed to be the batch dimension.
# Indexing dimension 0 gives back the corresponding underlying tensor component.
print("First underlying tensor component:", nt[0], sep='\n')
print("last column of 2nd underlying tensor component:", nt[1, :, -1], sep='\n')
# When indexing a nestedtensor's 0th dimension, the result is a regular tensor.
print(f"First underlying tensor component is nested: {nt[0].is_nested}")
######################################################################
# An important note is that slicing in dimension 0 is not supported yet,
# which means it is not currently possible to construct a view that combines
# the underlying tensor components.
######################################################################
# Nested Tensor Operations
# ------------------------
#
######################################################################
# As each operation must be explicitly implemented for nestedtensors,
# operation coverage for nestedtensors is currently narrower than that of regular tensors.
# For now, only basic operations such as index, dropout, softmax, transpose, reshape, linear, and bmm are covered.
# However, coverage is being expanded.
# If you need certain operations, please file an `issue <https://github.com/pytorch/pytorch/issues>`__
# to help us prioritize coverage.
#
# **reshape**
#
# The reshape op is for changing the shape of a tensor.
# Its full semantics for regular tensors can be found
# `here <https://pytorch.org/docs/stable/generated/torch.reshape.html>`__.
# For regular tensors, when specifying the new shape,
# a single dimension may be -1, in which case it is inferred
# from the remaining dimensions and the number of elements.
#
# The semantics for nestedtensors are similar, except that -1 no longer infers.
# Instead, it inherits the old size (here 2 for ``nt[0]`` and 3 for ``nt[1]``).
# -1 is the only legal size to specify for a jagged dimension.
nt_reshaped = nt.reshape(2, -1, 2, 3)  # components: (2, 2, 3) and (3, 2, 3)
print(f"{nt_reshaped=}")
######################################################################
# **transpose**
#
# The transpose op is for swapping two dimensions of a tensor.
# Its full semantics can be found
# `here <https://pytorch.org/docs/stable/generated/torch.transpose.html>`__.
# Note that for nestedtensors dimension 0 is special;
# it is assumed to be the batch dimension,
# so transposes involving nestedtensor dimension 0 are not supported.
nt_transposed = nt_reshaped.transpose(1, 2)  # components: (2, 2, 3) and (2, 3, 3)
print(f"{nt_transposed=}")
######################################################################
# **others**
#
# Other operations have the same semantics as for regular tensors.
# Applying the operation on a nestedtensor is equivalent to
# applying the operation to the underlying tensor components,
# with the result being a nestedtensor as well.
nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)
# per component: (2, 2, 3) @ (2, 3, 4) -> (2, 2, 4) and (2, 3, 3) @ (2, 3, 5) -> (2, 3, 5)
nt3 = torch.matmul(nt_transposed, nt_mm)
print(f"Result of Matmul:\n {nt3}")
# F.dropout defaults to training=True, so elements are randomly zeroed here.
nt4 = F.dropout(nt3, 0.1)
print(f"Result of Dropout:\n {nt4}")
nt5 = F.softmax(nt4, -1)
print(f"Result of Softmax:\n {nt5}")
######################################################################
# Why Nested Tensor
# -----------------
#
######################################################################
# When data is sequential, it is often the case that each sample has a different length.
# For example, in a batch of sentences, each sentence has a different number of words.
# A common technique for handling varying sequences is to manually pad each data tensor
# to the same shape in order to form a batch.
# For example, we have 2 sentences with different lengths and a vocabulary.
# In order to represent this as a single tensor, we pad with 0 to the max length in the batch.
sentences = [["goodbye", "padding"],
             ["embrace", "nested", "tensor"]]
vocabulary = {"goodbye": 1.0, "padding": 2.0,
              "embrace": 3.0, "nested": 4.0, "tensor": 5.0}
# Each row holds the vocabulary values of one sentence; the trailing 0.0 is padding.
padded_sentences = torch.tensor([[1.0, 2.0, 0.0],
                                 [3.0, 4.0, 5.0]])
nested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),
                                               torch.tensor([3.0, 4.0, 5.0])])
print(f"{padded_sentences=}")
print(f"{nested_sentences=}")
######################################################################
# This technique of padding a batch of data to its max length is not optimal.
# The padded data is not needed for computation and wastes memory by allocating
# larger tensors than necessary.
# Further, not all operations have the same semantics when applied to padded data.
# For matrix multiplications, one needs to pad with 0 in order to ignore the padded
# entries, while for softmax one has to pad with -inf to ignore them.
padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float("-inf")],
                                             [3.0, 4.0, 5.0]])
print(F.softmax(padded_sentences_for_softmax, -1))
print(F.softmax(nested_sentences, -1))
######################################################################
# Let us take a look at a practical example: the multi-head attention component
# utilized in `Transformers <https://arxiv.org/pdf/1706.03762.pdf>`__.
# The nestedtensor version is straightforward.
import math
def mha_nested(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,
               W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,
               b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,
               dropout_p: float = 0.0) -> torch.Tensor:
    """Multi-head attention written only with ops that nested tensors support.

    Because every sequence in a nested batch keeps its own length, no padding
    masks are needed: each query sequence attends only over its own keys.

    Args:
        query (torch.Tensor): (N, L_t, E_q) query batch; L_t may be jagged.
        key (torch.Tensor): (N, L_s, E_k) key batch; L_s may be jagged.
        value (torch.Tensor): (N, L_s, E_v) value batch.
        nheads (int): number of attention heads; must divide E_total.
        W_q (torch.Tensor): query projection weight, shape (E_total, E_q).
        W_k (torch.Tensor): key projection weight, shape (E_total, E_k).
        W_v (torch.Tensor): value projection weight, shape (E_total, E_v).
        W_out (torch.Tensor): output projection weight, shape (E_out, E_total).
        b_q (torch.Tensor, optional): query projection bias, shape (E_total,). Defaults to None.
        b_k (torch.Tensor, optional): key projection bias, shape (E_total,). Defaults to None.
        b_v (torch.Tensor, optional): value projection bias, shape (E_total,). Defaults to None.
        b_out (torch.Tensor, optional): output projection bias, shape (E_out,). Defaults to None.
        dropout_p (float, optional): dropout probability applied to the attention
            weights; 0.0 disables dropout. Defaults to 0.0.

    Where:
        N is the batch size, E_total = W_q.size(0) is the combined width of all
        heads, and each head has width E_total // nheads.

    Returns:
        torch.Tensor: output of shape (N, L_t, E_out).

    Raises:
        AssertionError: if E_total is not divisible by nheads.
    """
    batch = query.size(0)
    embed_total = W_q.size(0)
    assert embed_total % nheads == 0, "Embedding dim is not divisible by nheads"
    head_dim = embed_total // nheads

    def split_heads(x, weight, bias):
        # (N, L, E_in) -> (N, L, E_total) -> (N, L, nheads, head_dim) -> (N, nheads, L, head_dim)
        # The -1 keeps each sequence's own (possibly jagged) length.
        projected = F.linear(x, weight, bias)
        return projected.reshape(batch, -1, nheads, head_dim).transpose(1, 2)

    q = split_heads(query, W_q, b_q)
    k = split_heads(key, W_k, b_k)
    v = split_heads(value, W_v, b_v)

    # (N, nheads, L_t, head_dim) @ (N, nheads, head_dim, L_s) -> (N, nheads, L_t, L_s)
    scale = 1.0 / math.sqrt(head_dim)
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale
    probs = F.softmax(scores, dim=-1)
    if dropout_p > 0.0:
        probs = F.dropout(probs, p=dropout_p)

    # (N, nheads, L_t, L_s) @ (N, nheads, L_s, head_dim) -> (N, nheads, L_t, head_dim)
    context = torch.matmul(probs, v)
    # merge heads: (N, nheads, L_t, head_dim) -> (N, L_t, nheads, head_dim) -> (N, L_t, E_total)
    merged = context.transpose(1, 2).reshape(batch, -1, embed_total)
    # (N, L_t, E_total) -> (N, L_t, E_out)
    return F.linear(merged, W_out, b_out)
######################################################################
# The 0-padded tensor version additionally requires masks
# for more complicated treatments at padded entries.
def mha_padded(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,
               attn_mask_q: torch.Tensor, attn_mask_kv: torch.Tensor,
               W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,
               b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,
               dropout_p: float = 0.0) -> torch.Tensor:
    """Compute multi-head attention for padded-out dense tensors.

    Args:
        query (torch.Tensor): query of shape (N, L_t, E_q)
        key (torch.Tensor): key of shape (N, L_s, E_k)
        value (torch.Tensor): value of shape (N, L_s, E_v)
        nheads (int): number of heads in multi-head attention
        attn_mask_q (torch.Tensor): boolean mask of shape (N, L_t); True marks padded
            query positions, whose output rows are set to 0
        attn_mask_kv (torch.Tensor): boolean mask of shape (N, L_s); True marks padded
            key/value positions, which receive zero attention weight
        W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)
        W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)
        W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)
        W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)
        b_q (torch.Tensor, optional): Bias for query input projection of shape E_total. Defaults to None.
        b_k (torch.Tensor, optional): Bias for key input projection of shape E_total. Defaults to None.
        b_v (torch.Tensor, optional): Bias for value input projection of shape E_total. Defaults to None.
        b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Defaults to None.
        dropout_p (float, optional): Dropout probability. Defaults to 0.0.

    Where:
        N is the batch size
        L_t is the target sequence length (padded)
        L_s is the source sequence length (padded); it may differ from L_t
        E_q is the embedding size for query
        E_k is the embedding size for key
        E_v is the embedding size for value
        E_total is the embedding size for all heads combined
        E_out is the output embedding size

    Returns:
        torch.Tensor: Output of shape (N, L_t, E_out); rows at padded query
        positions are 0.

    Raises:
        AssertionError: if E_total is not divisible by nheads.
    """
    N = query.size(0)
    L_s = key.size(1)
    E_total = W_q.size(0)
    assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
    E_head = E_total // nheads
    # apply input projection
    # (N, L_t, E_q) -> (N, L_t, E_total)
    query = F.linear(query, W_q, b_q)
    # (N, L_s, E_k) -> (N, L_s, E_total)
    key = F.linear(key, W_k, b_k)
    # (N, L_s, E_v) -> (N, L_s, E_total)
    value = F.linear(value, W_v, b_v)
    # reshape query, key, value to separate by head
    # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) -> (N * nheads, L_t, E_head)
    query = query.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
    key = key.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
    # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)
    value = value.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)
    # query bmm key^T
    # (N * nheads, L_t, E_head) x (N * nheads, L_s, E_head)^T -> (N * nheads, L_t, L_s)
    attn_weights = torch.bmm(query, key.transpose(-1, -2))
    # scale down
    attn_weights = attn_weights * (1.0 / math.sqrt(E_head))
    # Padding lives in the *key* dimension (the last dim of attn_weights), so the
    # key/value mask is the one to apply:
    # (N, L_s) -> (N, 1, 1, L_s) -> (N, nheads, 1, L_s) -> (N * nheads, 1, L_s),
    # which broadcasts over every query row. Move it to wherever the scores live.
    key_padding_mask = (
        attn_mask_kv.to(attn_weights.device)
        .reshape(N, 1, 1, L_s)
        .expand(-1, nheads, -1, -1)
        .reshape(N * nheads, 1, L_s)
    )
    # -inf scores become exactly 0 weight after softmax
    attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
    # A row whose keys are all padding is all -inf and softmaxes to NaN; map it to 0.
    attn_weights = F.softmax(attn_weights, dim=-1).nan_to_num_(0.0)
    # dropout
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)
    # attention_weights bmm value
    # (N * nheads, L_t, L_s) x (N * nheads, L_s, E_head) -> (N * nheads, L_t, E_head)
    attn_output = attn_weights.bmm(value)
    # merge heads
    # (N * nheads, L_t, E_head) -> (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
    attn_output = attn_output.reshape(N, nheads, -1, E_head).transpose(1, 2).reshape(N, -1, E_total)
    # apply output projection
    # (N, L_t, E_total) -> (N, L_t, E_out)
    attn_output = F.linear(attn_output, W_out, b_out)
    # Padded query rows still produce values (projection biases, attention over
    # valid keys); zero them so the result matches a 0-padded nested output.
    attn_output[attn_mask_q.to(attn_output.device)] = 0.0
    return attn_output
######################################################################
# Hyperparameters: embedding size 512 and 8 heads, as in
# `the Transformer paper <https://arxiv.org/pdf/1706.03762.pdf>`__; N is the batch size.
N = 512
E_q = E_k = E_v = E_total = E_out = 512
nheads = 8
######################################################################
# Dropout is the exception: it is disabled (0.0) so the nested and padded
# implementations can be compared for correctness.
dropout_p = 0.0
######################################################################
# Let us generate some realistic fake data from Zipf's law.
import numpy as np
def zipf_sentence_lengths(alpha: float, batch_size: int,
                          end_ranks: tuple = (3, 386, 858)) -> np.ndarray:
    """Generate fake sentence lengths from a unigram Zipf word model.

    Each sentence is built word by word: every word is a rank drawn from
    ``np.random.zipf(alpha)``, and the sentence ends (inclusive) at the first
    word whose rank is in ``end_ranks``. The default ranks are those of ".",
    "!" and "?" in the wikitext-2 corpus.

    Args:
        alpha: Zipf distribution parameter (numpy requires ``alpha > 1``).
        batch_size: number of sentence lengths to generate.
        end_ranks: word ranks that terminate a sentence.

    Returns:
        Integer array of shape (batch_size,); every entry is >= 1.

    Raises:
        ValueError: if ``end_ranks`` is empty (no sentence could ever end).
    """
    stop_ranks = frozenset(int(rank) for rank in end_ranks)
    if not stop_ranks:
        raise ValueError("end_ranks must contain at least one rank")
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        length = 1
        # keep drawing words until one of them terminates the sentence
        while np.random.zipf(alpha) not in stop_ranks:
            length += 1
        sentence_lengths[ibatch] = length
    return sentence_lengths
alpha = 1.2
sentence_lengths = zipf_sentence_lengths(alpha, N)
# Use a Python int (not a numpy scalar) since it is used in tensor shapes below.
L_t = int(np.max(sentence_lengths))
L_s = L_t
######################################################################
# create inputs
# create parameters
W_q, b_q = torch.randn((E_total, E_q), device=device), torch.randn(E_total, device=device)
W_k, b_k = torch.randn((E_total, E_k), device=device), torch.randn(E_total, device=device)
W_v, b_v = torch.randn((E_total, E_v), device=device), torch.randn(E_total, device=device)
W_out, b_out = torch.randn((E_out, E_total), device=device), torch.randn(E_out, device=device)
# create nested input: one (length, E) tensor per sentence; keys and values use
# the same length as their query
queries, keys, values = [], [], []
for length in sentence_lengths.tolist():
    queries.append(torch.randn((length, E_q), device=device))
    keys.append(torch.randn((length, E_k), device=device))
    values.append(torch.randn((length, E_v), device=device))
query = torch.nested.nested_tensor(queries)
key = torch.nested.nested_tensor(keys)
value = torch.nested.nested_tensor(values)
# pad input
padded_query = torch.nested.to_padded_tensor(query, 0.0, (N, L_t, E_q))
padded_key = torch.nested.to_padded_tensor(key, 0.0, (N, L_s, E_k))
padded_value = torch.nested.to_padded_tensor(value, 0.0, (N, L_s, E_v))
# create attention masks: True marks padding, i.e. positions at or beyond each
# sentence's length. Built in one vectorized comparison on the same device as the data.
lengths = torch.as_tensor(sentence_lengths, device=device)
attn_mask_q = torch.arange(L_t, device=device).unsqueeze(0) >= lengths.unsqueeze(1)
attn_mask_kv = torch.arange(L_s, device=device).unsqueeze(0) >= lengths.unsqueeze(1)
######################################################################
# check correctness and performance
import timeit


def _synchronize():
    """Wait for all queued CUDA work to finish; a no-op on the CPU."""
    # CUDA kernels launch asynchronously; without a sync the timer would only
    # measure launch overhead instead of the actual computation.
    if device.type == 'cuda':
        torch.cuda.synchronize()


_synchronize()  # don't bill still-pending work (e.g. input creation) to the first timing
t0 = timeit.default_timer()
out_nested = mha_nested(
    query, key, value, nheads,
    W_q, W_k, W_v, W_out,
    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
    dropout_p=dropout_p)
_synchronize()
t1 = timeit.default_timer()
out_padded = mha_padded(
    padded_query, padded_key, padded_value, nheads,
    attn_mask_q, attn_mask_kv,
    W_q, W_k, W_v, W_out,
    b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
    dropout_p=dropout_p)
_synchronize()
t2 = timeit.default_timer()
print("nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_nested, 0.0, (N, L_t, E_out)) - out_padded).abs().max().item())
print("nestedtensor multi-head attention takes", t1 - t0, "seconds")
print("padded tensor multi-head attention takes", t2 - t1, "seconds")
######################################################################
# Although the nestedtensor version avoids wasted computation on padding, it is not faster
# than the equivalent padded tensor version. This is because the nestedtensor version
# implements a few of the kernels, like softmax, in a non-optimal way.
#
# There are plans to implement performance-critical operations using the new PyTorch 2.0 stack.
# For now, some performant kernels are provided for specific use cases, e.g.
# self-attention evaluation by the multi-head attention formula.
# Query, key and value all share one embedding size here (kdim/vdim are left at
# their default, which is embed_dim); the module is put in eval mode right away.
E = E_total
mha_lib = torch.nn.MultiheadAttention(
    embed_dim=E,
    num_heads=nheads,
    batch_first=True,
    device=device,
).eval()
######################################################################
# extract parameters for correctness check
mha_lib.in_proj_weight.requires_grad_(False)
mha_lib.in_proj_bias.requires_grad_(False)
mha_lib.out_proj.weight.requires_grad_(False)
mha_lib.out_proj.bias.requires_grad_(False)
# in_proj_weight stacks the query, key and value projections along dim 0
W_q, b_q = mha_lib.in_proj_weight[: E, :], mha_lib.in_proj_bias[: E]
W_k, b_k = mha_lib.in_proj_weight[E : 2 * E, :], mha_lib.in_proj_bias[E : 2 * E]
W_v, b_v = mha_lib.in_proj_weight[2 * E :, :], mha_lib.in_proj_bias[2 * E :]
W_out, b_out = mha_lib.out_proj.weight, mha_lib.out_proj.bias
######################################################################
# If we set need_weights to False this will enable the fast path in the library.
# Under the hood this will call _scaled_dot_product_attention. If your tensors
# are on CUDA, then a fused, efficient attention kernel will be used. For
# more detailed performance characteristics look at the benchmark in
# pytorch/benchmarks/transformer/sdp.py


def _synchronize():
    """Wait for all queued CUDA work to finish; a no-op on the CPU."""
    # CUDA kernels launch asynchronously; without a sync the timer would only
    # measure launch overhead instead of the actual computation.
    if device.type == 'cuda':
        torch.cuda.synchronize()


with torch.inference_mode():
    _synchronize()
    t0 = timeit.default_timer()
    out_lib, out_lib_weights = mha_lib(query, query, query, need_weights=False)
    _synchronize()
    t1 = timeit.default_timer()
    # self-attention: the query padding mask is also the key/value padding mask
    padded_out = mha_padded(
        padded_query, padded_query, padded_query, nheads,
        attn_mask_q, attn_mask_q,
        W_q, W_k, W_v, W_out,
        b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,
        dropout_p=dropout_p)
    _synchronize()
    t2 = timeit.default_timer()
nested_time = t1 - t0
padded_time = t2 - t1
print("Nested and padded calculations differ by", (torch.nested.to_padded_tensor(out_lib, 0.0) - padded_out).abs().max().item())
print("Nested library multi-head attention takes", nested_time, "seconds")
print("Padded tensor multi-head attention takes", padded_time, "seconds")
print(f"Nested Speedup: {padded_time / nested_time:.3f}")