You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Generalize catArray for contiguous inputs and dim != 0 (pytorch#17032)
Summary:
I noticed that we were sinking a lot of time into `cat` operations in machine translation on CPU, and drilled down to us doing the cat element-by-element, even though all the inputs were contiguous. The reason was we were doing the cat along a dimension that was not 0, and that caused us to not use the fast `memcpy` branch. This PR generalizes that branch.
Quick benchmark script:
```
import torch, time
tensors = [torch.rand(6, 2, 1024) for i in range(5)]
NITER = 1000
s = time.time()
for i in range(NITER):
torch.cat(tensors, dim=1)
print('time per iter ', (time.time() - s) / NITER)
```
Before:
```
time per iter 8.089399337768554e-05
```
After:
```
time per iter 2.183413505554199e-05
```
Pull Request resolved: pytorch#17032
Differential Revision: D14090038
Pulled By: jamesr66a
fbshipit-source-id: 2c733a84915896008ac95f2233f44894bd2573de
0 commit comments