File tree Expand file tree Collapse file tree 1 file changed +11
-1
lines changed
torch/csrc/jit/tensorexpr Expand file tree Collapse file tree 1 file changed +11
-1
lines changed Original file line number Diff line number Diff line change @@ -937,9 +937,19 @@ Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
937
937
bufferArgs_.emplace_back (inBuffer);
938
938
break ;
939
939
}
940
+ ExprHandle flat_size = 1 ;
941
+ for (size_t i = 0 ; i < *tt->sizes ().size (); i++) {
942
+ auto size = *tt->sizes ()[i];
943
+ if (size == 0 ) {
944
+ flat_size = 0 ;
945
+ break ;
946
+ }
947
+ flat_size = flat_size + (size - 1 ) * *tt->strides ()[i];
948
+ }
949
+ flat_size = IRSimplifier::simplify (flat_size);
940
950
BufHandle inBuffer (
941
951
" t" + input_name_map_[input],
942
- {0 },
952
+ {flat_size },
943
953
ToDtype (static_cast <ScalarType>(*tt->scalarType ())));
944
954
std::vector<DimArg> inputTensorDims;
945
955
for (size_t i = 0 ; i < *tt->sizes ().size (); i++) {
You can’t perform that action at this time.
0 commit comments