Generalize catArray for contiguous inputs and dim != 0#17032
Conversation
|
Do our tests explicitly exercise this branch? If not, please add. |
|
I believe this test exercises the branch: https://github.com/pytorch/pytorch/blob/master/test/test_torch.py#L4203 EDIT: looks like the inputs may be noncontiguous, let me dig deeper EDIT2: I put a print in the branch and ran that test and it printed out, so looks like it's tested. |
|
Looks like it. Maybe double check and make sure the edge case you mentioned is covered as well. In general we prefer to avoid modifying TH and instead porting over the function to aten. Maybe I could ask you to spend a bit of time on seeing how feasible that is? |
|
I have a slight preference not to port this at the same time as we make changes. It's harder to review and less obvious where the problem is if there's a bug report. |
| int64_t outer = 1, inner = 1; | ||
|
|
||
| // Outer is the product of dimensions from the left up to (and not | ||
| // including the concatenation dimension). This becomes the number of times |
There was a problem hiding this comment.
nit: you want the ')' after including.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@jamesr66a is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
I noticed that we were sinking a lot of time into `cat` operations in machine translation on CPU, and drilled down to us doing the cat element-by-element, even though all the inputs were contiguous. The reason was we were doing the cat along a dimension that was not 0, and that caused us to not use the fast `memcpy` branch. This PR generalizes that branch.
Quick benchmark script:
```
import torch, time
tensors = [torch.rand(6, 2, 1024) for i in range(5)]
NITER = 1000
s = time.time()
for i in range(NITER):
torch.cat(tensors, dim=1)
print('time per iter ', (time.time() - s) / NITER)
```
Before:
```
time per iter 8.089399337768554e-05
```
After:
```
time per iter 2.183413505554199e-05
```
Pull Request resolved: pytorch/pytorch#17032
Differential Revision: D14090038
Pulled By: jamesr66a
fbshipit-source-id: 2c733a84915896008ac95f2233f44894bd2573de
Summary:
I noticed that we were sinking a lot of time into `cat` operations in machine translation on CPU, and drilled down to us doing the cat element-by-element, even though all the inputs were contiguous. The reason was we were doing the cat along a dimension that was not 0, and that caused us to not use the fast `memcpy` branch. This PR generalizes that branch.
Quick benchmark script:
```
import torch, time
tensors = [torch.rand(6, 2, 1024) for i in range(5)]
NITER = 1000
s = time.time()
for i in range(NITER):
torch.cat(tensors, dim=1)
print('time per iter ', (time.time() - s) / NITER)
```
Before:
```
time per iter 8.089399337768554e-05
```
After:
```
time per iter 2.183413505554199e-05
```
Pull Request resolved: pytorch#17032
Differential Revision: D14090038
Pulled By: jamesr66a
fbshipit-source-id: 2c733a84915896008ac95f2233f44894bd2573de
Summary:
I noticed that we were sinking a lot of time into `cat` operations in machine translation on CPU, and drilled down to us doing the cat element-by-element, even though all the inputs were contiguous. The reason was we were doing the cat along a dimension that was not 0, and that caused us to not use the fast `memcpy` branch. This PR generalizes that branch.
Quick benchmark script:
```
import torch, time
tensors = [torch.rand(6, 2, 1024) for i in range(5)]
NITER = 1000
s = time.time()
for i in range(NITER):
torch.cat(tensors, dim=1)
print('time per iter ', (time.time() - s) / NITER)
```
Before:
```
time per iter 8.089399337768554e-05
```
After:
```
time per iter 2.183413505554199e-05
```
Pull Request resolved: pytorch#17032
Differential Revision: D14090038
Pulled By: jamesr66a
fbshipit-source-id: 2c733a84915896008ac95f2233f44894bd2573de
I noticed that we were sinking a lot of time into
catoperations in machine translation on CPU, and drilled down to us doing the cat element-by-element, even though all the inputs were contiguous. The reason was we were doing the cat along a dimension that was not 0, and that caused us to not use the fastmemcpybranch. This PR generalizes that branch.Quick benchmark script:
Before:
After: