Skip to content

Commit e0f21a4

Browse files
Roy Lifacebook-github-bot
Roy Li
authored andcommitted
restore caffe2 strides (pytorch#12883)
Summary: Pull Request resolved: pytorch#12883 Attempting to do this again. last try broke oss ci: D10421896 Reallocation of strides_ if there's no change in dim seems to cause the error that broke internal flow last time. This fixes that. Found a potential race condition in caffe2 counter ops that might be the cause, we will investigate that. Reviewed By: ezyang Differential Revision: D10469960 fbshipit-source-id: 478186ff0d2f3dba1fbff6231db715322418d79c
1 parent 88f70fc commit e0f21a4

File tree

3 files changed

+34
-24
lines changed

3 files changed

+34
-24
lines changed

aten/src/ATen/core/TensorImpl.cpp

-10
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,13 @@ IntList TensorImpl::sizes() const {
4646
}
4747

4848
IntList TensorImpl::strides() const {
49-
AT_ASSERTM(strides_,
50-
"Caffe2 tensors don't (yet) have meaningful strides and cannot "
51-
"be used in PyTorch.");
5249
return IntList{strides_.get(), sizes_.size()};
5350
}
5451

5552
bool TensorImpl::compute_contiguous() const {
5653
bool is_contiguous = true;
5754
if (is_empty())
5855
return is_contiguous;
59-
if (!strides_) {
60-
// Special case for Caffe2 tensors which don't have strides set.
61-
return true;
62-
}
6356
int64_t z = 1;
6457
for (int64_t d = dim() - 1; d >= 0; d--) {
6558
if (size(d) != 1) {
@@ -90,9 +83,6 @@ int64_t TensorImpl::size(int64_t d) const {
9083
}
9184

9285
int64_t TensorImpl::stride(int64_t d) const {
93-
AT_ASSERTM(strides_,
94-
"Caffe2 tensors don't (yet) have meaningful strides and cannot "
95-
"be used in PyTorch.");
9686
d = at::maybe_wrap_dim(d, dim(), false);
9787
return strides_[d];
9888
}

aten/src/ATen/core/TensorImpl.h

+27-14
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
166166
explicit TensorImpl(at::Storage storage) : storage_(std::move(storage)), storage_offset_(0) {
167167
AT_ASSERT(storage_);
168168
data_type_ = storage_.dtype();
169+
strides_.reset(new int64_t[1]);
170+
strides_[0] = 1;
169171
}
170172

171173
TensorImpl(const TensorImpl&) = default;
@@ -317,15 +319,17 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
317319
// assumes that old values are preserved
318320
auto old_dim = sizes_.size();
319321
sizes_.resize(ndim);
320-
auto new_strides = c10::guts::make_unique<int64_t[]>(ndim);
321-
for (size_t i = 0; i < std::min(old_dim, static_cast<size_t>(ndim)); i++) {
322-
new_strides[i] = strides_[i];
323-
}
324-
for (size_t i = old_dim; i < static_cast<size_t>(ndim); i++) {
325-
// If ndim < old_dim, this loop never executes
326-
new_strides[i] = 0;
322+
if (old_dim != sizes_.size()) {
323+
auto new_strides = c10::guts::make_unique<int64_t[]>(ndim);
324+
for (size_t i = 0; i < std::min(old_dim, static_cast<size_t>(ndim)); i++) {
325+
new_strides[i] = strides_[i];
326+
}
327+
for (size_t i = old_dim; i < static_cast<size_t>(ndim); i++) {
328+
// If ndim < old_dim, this loop never executes
329+
new_strides[i] = 0;
330+
}
331+
strides_ = std::move(new_strides);
327332
}
328-
strides_ = std::move(new_strides);
329333
refresh_numel();
330334
refresh_contiguous();
331335
}
@@ -337,8 +341,6 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
337341
}
338342

339343
virtual void set_stride(int64_t dim, int64_t new_stride) {
340-
AT_ASSERTM(strides_, "Caffe2 tensors don't have meaningful strides and "
341-
"cannot be used in PyTorch");
342344
strides_[dim] = new_stride;
343345
refresh_numel();
344346
refresh_contiguous();
@@ -630,8 +632,9 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
630632
" The old caffe2 mixes Reshape and Resize but this behavior has "
631633
"been changed. If you find this error, most likely you will need "
632634
"to change corresponding code from Reshape to Resize.");
635+
auto old_dim = sizes_.size();
633636
sizes_ = dims;
634-
update_to_contiguous_strides();
637+
update_to_contiguous_strides(old_dim);
635638
}
636639

637640
/**
@@ -801,13 +804,14 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
801804
typename = typename std::enable_if<std::is_integral<T>::value>::type>
802805
bool SetDimsTemplate(at::ArrayRef<T> src) {
803806
auto old_numel = numel_;
807+
auto old_dim = sizes_.size();
804808
sizes_.resize(src.size());
805809
int64_t new_numel = 1;
806810
for (size_t i = 0; i < src.size(); ++i) {
807811
new_numel *= src[i];
808812
sizes_[i] = src[i];
809813
}
810-
update_to_contiguous_strides();
814+
update_to_contiguous_strides(old_dim);
811815
numel_ = new_numel;
812816
return numel_ != old_numel;
813817
}
@@ -844,8 +848,17 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
844848
return SetDims(at::IntList{d0, d1, d2, d3});
845849
}
846850

847-
inline void update_to_contiguous_strides() {
848-
strides_.reset();
851+
inline void update_to_contiguous_strides(size_t old_dim) {
852+
if (old_dim != sizes_.size()) {
853+
strides_ = c10::guts::make_unique<int64_t[]>(sizes_.size());
854+
}
855+
if (dim() > 0) {
856+
int last_idx = dim() - 1;
857+
strides_[last_idx] = 1;
858+
for (auto i = last_idx - 1; i >= 0; --i) {
859+
strides_[i] = strides_[i + 1] * std::max<int64_t>(sizes_[i + 1], 1);
860+
}
861+
}
849862
is_contiguous_ = true;
850863
}
851864

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#include "gtest/gtest.h"
2+
#include "caffe2/core/tensor.h"
3+
4+
TEST(TensorImplTest, Caffe2Constructor) {
5+
caffe2::Tensor tensor(caffe2::CPU);
6+
ASSERT_EQ(tensor.strides()[0], 1);
7+
}

0 commit comments

Comments
 (0)