Skip to content

Commit 531b045

Browse files
huiguoofacebook-github-bot
authored andcommitted
[tensorexpr] Fix the buf size of discontiguous tensors (pytorch#69657)
Summary: Pull Request resolved: pytorch#69657 Test Plan: Imported from OSS Reviewed By: ZolotukhinM Differential Revision: D32974473 Pulled By: huiguoo fbshipit-source-id: 52dcd13d0ad7f7e4f1beb69dcaabc8ceb386ffca
1 parent aab67c6 commit 531b045

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

torch/csrc/jit/tensorexpr/kernel.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -937,9 +937,19 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
937937
bufferArgs_.emplace_back(inBuffer);
938938
break;
939939
}
940+
ExprHandle flat_size = 1;
941+
for (size_t i = 0; i < *tt->sizes().size(); i++) {
942+
auto size = *tt->sizes()[i];
943+
if (size == 0) {
944+
flat_size = 0;
945+
break;
946+
}
947+
flat_size = flat_size + (size - 1) * *tt->strides()[i];
948+
}
949+
flat_size = IRSimplifier::simplify(flat_size);
940950
BufHandle inBuffer(
941951
"t" + input_name_map_[input],
942-
{0},
952+
{flat_size},
943953
ToDtype(static_cast<ScalarType>(*tt->scalarType())));
944954
std::vector<DimArg> inputTensorDims;
945955
for (size_t i = 0; i < *tt->sizes().size(); i++) {

0 commit comments

Comments
 (0)