
Commit f1da989

James Reed authored and facebook-github-bot committed
Generalize catArray for contiguous inputs and dim != 0 (pytorch#17032)
Summary: I noticed that we were sinking a lot of time into `cat` operations in machine translation on CPU, and drilled down to find that we were doing the cat element by element, even though all the inputs were contiguous. The reason: the cat was along a dimension other than 0, which kept us off the fast `memcpy` branch. This PR generalizes that branch to any concatenation dimension.

Quick benchmark script:

```
import torch, time

tensors = [torch.rand(6, 2, 1024) for i in range(5)]

NITER = 1000
s = time.time()
for i in range(NITER):
    torch.cat(tensors, dim=1)
print('time per iter ', (time.time() - s) / NITER)
```

Before:

```
time per iter  8.089399337768554e-05
```

After:

```
time per iter  2.183413505554199e-05
```

Pull Request resolved: pytorch#17032
Differential Revision: D14090038
Pulled By: jamesr66a
fbshipit-source-id: 2c733a84915896008ac95f2233f44894bd2573de
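As a quick correctness check (assumed usage, not part of the commit), the generalized fast path can be compared against the non-contiguous path: transposed views are non-contiguous, so concatenating them should take the element-wise branch, and the two results must agree exactly:

```
import torch

tensors = [torch.rand(6, 2, 1024) for i in range(5)]

# Contiguous inputs, dim != 0: this now hits the generalized memcpy branch.
fast = torch.cat(tensors, dim=1)

# Transposed views are non-contiguous, so this cat takes the element-wise
# path; transposing back must reproduce `fast` exactly.
slow = torch.cat([t.transpose(0, 1) for t in tensors], dim=0).transpose(0, 1)

assert torch.equal(fast, slow)
```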
1 parent: f3dd556 · commit: f1da989

File tree

1 file changed: +31 −14 lines

aten/src/TH/generic/THTensor.cpp

```
@@ -779,24 +779,41 @@ void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int
   }
   allContiguous = allContiguous && THTensor_(isContiguous)(result);
 
-  // First path is for contiguous inputs along dim 0
+  // First path is for contiguous inputs
   // Second path for non-contiguous
   int64_t offset;
-  if (dimension == 0 && allContiguous) {
+  if (allContiguous) {
+    int64_t outer = 1, inner = 1;
+
+    // Outer is the product of dimensions from the left up to (and not
+    // including) the concatenation dimension. This becomes the number of times
+    // we have to replicate the memcpy call.
+    for (int i = 0; i < dimension; ++i) {
+      outer *= size[i];
+    }
+
+    // Inner is the product of dimensions to the right of the concatenation
+    // dimension. We go on to multiply this by the size of the concat dimension
+    // for each input tensor.
+    for (int i = dimension + 1; i < size.size(); ++i) {
+      inner *= size[i];
+    }
+
     scalar_t* result_data = THStorage_(data)(THTensor_getStoragePtr(result)) + result->storage_offset();
     offset = 0;
-    for (int j = 0; j < numInputs; j++) {
-      if (!should_skip(inputs[j])) {
-        THTensor* input0 = inputs[j];
-        scalar_t* input0_data = THStorage_(data)(THTensor_getStoragePtr(input0)) + input0->storage_offset();
-        int64_t input0_size = THTensor_(nElement)(input0);
-        // C standard says you can't pass nullptrs to memcpy, even if the size is 0; ubsan checks this.
-        if (input0_size != 0) {
-          memcpy(result_data + offset, input0_data, input0_size*sizeof(scalar_t));
-        }
-        offset += input0_size;
-      }
-    }
+    for (int o = 0; o < outer; ++o) {
+      for (int j = 0; j < numInputs; ++j) {
+        if (!should_skip(inputs[j])) {
+          THTensor* input0 = inputs[j];
+          scalar_t* input0_data = THStorage_(data)(THTensor_getStoragePtr(input0)) + input0->storage_offset();
+          int64_t local_inner = inner * input0->size(dimension);
+          if (local_inner != 0) {
+            memcpy(result_data + offset, input0_data + o*local_inner, local_inner*sizeof(scalar_t));
+          } // local_inner != 0
+          offset += local_inner;
+        } // should_skip
+      } // for j
+    } // for o
   } else {
     offset = 0;
     for (int j = 0; j < numInputs; j++) {
```
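For intuition: with the benchmark shapes above, five inputs of shape (6, 2, 1024) concatenated along dim=1 give outer = 6 and inner = 1024, so each input contributes a block of local_inner = 2 × 1024 contiguous elements per outer iteration, and the whole cat collapses to 6 × 5 = 30 `memcpy` calls. Below is a minimal Python sketch of the same outer/inner scheme (illustrative only, not from the commit; `cat_contiguous` is a made-up helper that models contiguous tensors as flat lists plus shapes):

```
def cat_contiguous(flats, shapes, dim):
    # Shape of the result: same as the inputs except along `dim`.
    out_shape = list(shapes[0])
    out_shape[dim] = sum(shape[dim] for shape in shapes)

    # outer: product of dims left of `dim` -> number of block copies per input.
    outer = 1
    for d in range(dim):
        outer *= out_shape[d]

    # inner: product of dims right of `dim`; each input then contributes
    # inner * shape[dim] contiguous elements per outer iteration.
    inner = 1
    for d in range(dim + 1, len(out_shape)):
        inner *= out_shape[d]

    result = []
    for o in range(outer):
        for flat, shape in zip(flats, shapes):
            local_inner = inner * shape[dim]  # block size for this input
            # Stands in for memcpy(result + offset, input + o*local_inner, ...).
            result.extend(flat[o * local_inner:(o + 1) * local_inner])
    return result, out_shape


# Two 2x2 "tensors" concatenated along dim 1.
a = [0, 1, 2, 3]        # [[0, 1], [2, 3]]
b = [10, 11, 12, 13]    # [[10, 11], [12, 13]]
flat, shape = cat_contiguous([a, b], [(2, 2), (2, 2)], dim=1)
assert flat == [0, 1, 10, 11, 2, 3, 12, 13] and shape == [2, 4]
```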
