@@ -166,6 +166,8 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
166
166
explicit TensorImpl (at::Storage storage) : storage_(std::move(storage)), storage_offset_(0 ) {
167
167
AT_ASSERT (storage_);
168
168
data_type_ = storage_.dtype ();
169
+ strides_.reset (new int64_t [1 ]);
170
+ strides_[0 ] = 1 ;
169
171
}
170
172
171
173
TensorImpl (const TensorImpl&) = default ;
@@ -317,15 +319,17 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
317
319
// assumes that old values are preserved
318
320
auto old_dim = sizes_.size ();
319
321
sizes_.resize (ndim);
320
- auto new_strides = c10::guts::make_unique<int64_t []>(ndim);
321
- for (size_t i = 0 ; i < std::min (old_dim, static_cast <size_t >(ndim)); i++) {
322
- new_strides[i] = strides_[i];
323
- }
324
- for (size_t i = old_dim; i < static_cast <size_t >(ndim); i++) {
325
- // If ndim < old_dim, this loop never executes
326
- new_strides[i] = 0 ;
322
+ if (old_dim != sizes_.size ()) {
323
+ auto new_strides = c10::guts::make_unique<int64_t []>(ndim);
324
+ for (size_t i = 0 ; i < std::min (old_dim, static_cast <size_t >(ndim)); i++) {
325
+ new_strides[i] = strides_[i];
326
+ }
327
+ for (size_t i = old_dim; i < static_cast <size_t >(ndim); i++) {
328
+ // If ndim < old_dim, this loop never executes
329
+ new_strides[i] = 0 ;
330
+ }
331
+ strides_ = std::move (new_strides);
327
332
}
328
- strides_ = std::move (new_strides);
329
333
refresh_numel ();
330
334
refresh_contiguous ();
331
335
}
@@ -337,8 +341,6 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
337
341
}
338
342
339
343
virtual void set_stride (int64_t dim, int64_t new_stride) {
340
- AT_ASSERTM (strides_, " Caffe2 tensors don't have meaningful strides and "
341
- " cannot be used in PyTorch" );
342
344
strides_[dim] = new_stride;
343
345
refresh_numel ();
344
346
refresh_contiguous ();
@@ -630,8 +632,9 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
630
632
" The old caffe2 mixes Reshape and Resize but this behavior has "
631
633
" been changed. If you find this error, most likely you will need "
632
634
" to change corresponding code from Reshape to Resize." );
635
+ auto old_dim = sizes_.size ();
633
636
sizes_ = dims;
634
- update_to_contiguous_strides ();
637
+ update_to_contiguous_strides (old_dim );
635
638
}
636
639
637
640
/* *
@@ -801,13 +804,14 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
801
804
typename = typename std::enable_if<std::is_integral<T>::value>::type>
802
805
bool SetDimsTemplate (at::ArrayRef<T> src) {
803
806
auto old_numel = numel_;
807
+ auto old_dim = sizes_.size ();
804
808
sizes_.resize (src.size ());
805
809
int64_t new_numel = 1 ;
806
810
for (size_t i = 0 ; i < src.size (); ++i) {
807
811
new_numel *= src[i];
808
812
sizes_[i] = src[i];
809
813
}
810
- update_to_contiguous_strides ();
814
+ update_to_contiguous_strides (old_dim );
811
815
numel_ = new_numel;
812
816
return numel_ != old_numel;
813
817
}
@@ -844,8 +848,17 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
844
848
return SetDims (at::IntList{d0, d1, d2, d3});
845
849
}
846
850
847
- inline void update_to_contiguous_strides () {
848
- strides_.reset ();
851
+ inline void update_to_contiguous_strides (size_t old_dim) {
852
+ if (old_dim != sizes_.size ()) {
853
+ strides_ = c10::guts::make_unique<int64_t []>(sizes_.size ());
854
+ }
855
+ if (dim () > 0 ) {
856
+ int last_idx = dim () - 1 ;
857
+ strides_[last_idx] = 1 ;
858
+ for (auto i = last_idx - 1 ; i >= 0 ; --i) {
859
+ strides_[i] = strides_[i + 1 ] * std::max<int64_t >(sizes_[i + 1 ], 1 );
860
+ }
861
+ }
849
862
is_contiguous_ = true ;
850
863
}
851
864
0 commit comments