Commit f1da989
Generalize catArray for contiguous inputs and dim != 0 (pytorch#17032)
Summary:
I noticed that we were sinking a lot of time into `cat` operations in machine translation on CPU, and drilled down to us doing the cat element-by-element, even though all the inputs were contiguous. The reason was we were doing the cat along a dimension that was not 0, and that caused us to not use the fast `memcpy` branch. This PR generalizes that branch.
Quick benchmark script:
```
import torch, time
tensors = [torch.rand(6, 2, 1024) for i in range(5)]
NITER = 1000
s = time.time()
for i in range(NITER):
torch.cat(tensors, dim=1)
print('time per iter ', (time.time() - s) / NITER)
```
Before:
```
time per iter 8.089399337768554e-05
```
After:
```
time per iter 2.183413505554199e-05
```
Pull Request resolved: pytorch#17032
Differential Revision: D14090038
Pulled By: jamesr66a
fbshipit-source-id: 2c733a84915896008ac95f2233f44894bd2573de1 parent f3dd556 commit f1da989
1 file changed
+31
-14
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
779 | 779 | | |
780 | 780 | | |
781 | 781 | | |
782 | | - | |
| 782 | + | |
783 | 783 | | |
784 | 784 | | |
785 | | - | |
| 785 | + | |
| 786 | + | |
| 787 | + | |
| 788 | + | |
| 789 | + | |
| 790 | + | |
| 791 | + | |
| 792 | + | |
| 793 | + | |
| 794 | + | |
| 795 | + | |
| 796 | + | |
| 797 | + | |
| 798 | + | |
| 799 | + | |
| 800 | + | |
| 801 | + | |
786 | 802 | | |
787 | 803 | | |
788 | | - | |
789 | | - | |
790 | | - | |
791 | | - | |
792 | | - | |
793 | | - | |
794 | | - | |
795 | | - | |
796 | | - | |
797 | | - | |
798 | | - | |
799 | | - | |
| 804 | + | |
| 805 | + | |
| 806 | + | |
| 807 | + | |
| 808 | + | |
| 809 | + | |
| 810 | + | |
| 811 | + | |
| 812 | + | |
| 813 | + | |
| 814 | + | |
| 815 | + | |
| 816 | + | |
800 | 817 | | |
801 | 818 | | |
802 | 819 | | |
| |||
0 commit comments