
Commit 7c1d5db

cyyever authored and pytorchmergebot committed
[2/N] Enable UBSAN tests (pytorch#141740)
Apply c10::load in more places. The function was introduced to cast a byte to a valid boolean value, thus fixing the UBSAN errors.

Pull Request resolved: pytorch#141740
Approved by: https://github.com/ezyang
1 parent 28efc17 commit 7c1d5db

12 files changed: +45 −52 lines changed
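For context: in C++, reading a bool object whose underlying byte is neither 0 nor 1 through a plain dereference is undefined behavior, and UBSAN reports it as a load of an invalid value for type 'bool'. c10::load sidesteps this by reading the raw byte and normalizing it before it is used as a bool. A minimal standalone sketch of the idea follows; the helper name and exact mechanics here are illustrative, the real implementation lives in c10/util/Load.h via detail::LoadImpl<bool>.

// Sketch only: a byte-safe bool load, not the c10 implementation.
#include <cstring>
#include <iostream>

bool load_bool_safely(const void* src) {
  unsigned char byte;
  std::memcpy(&byte, src, sizeof(byte));  // reading a byte through unsigned char is always well-defined
  return byte != 0;                       // normalize to a valid bool value
}

int main() {
  unsigned char storage = 0x02;  // a "non-standard" bool byte (see gh-54789)
  // bool direct = *reinterpret_cast<const bool*>(&storage);  // UB, flagged by UBSAN
  bool safe = load_bool_safely(&storage);
  std::cout << std::boolalpha << safe << '\n';  // prints: true
}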

aten/src/ATen/native/TensorAdvancedIndexing.cpp

+1 −1

@@ -1047,7 +1047,7 @@ TORCH_IMPL_FUNC(index_add_cpu_out)
 auto self_i = index_data[i];
 TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self");
 scalar_t *self_ip = result_ptr + self_i * result_stride;
-*self_ip += *(source_ptr + i * source_stride) * alpha_value;
+*self_ip += c10::load(source_ptr + i * source_stride) * alpha_value;
 }
 });
 });

aten/src/ATen/native/TriangularOps.cpp

+2 −2

@@ -61,7 +61,7 @@ void apply_triu_tril_single(
 }
 if (!inplace) { // copy the rest of the self if not inplace
 for (int64_t j = std::max(zero, i + k); j < m; j++) {
-result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
+result[i * res_row_stride + j * res_col_stride] = c10::load(&self[i * self_row_stride + j * self_col_stride]);
 }
 }
 }
@@ -74,7 +74,7 @@ void apply_triu_tril_single(
 }
 if (!inplace) { // copy the rest of the self if not inplace
 for (int64_t j = zero; j < std::min(m, i + k + 1); j++) {
-result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
+result[i * res_row_stride + j * res_col_stride] = c10::load(&self[i * self_row_stride + j * self_col_stride]);
 }
 }
 }

aten/src/ATen/native/cpu/ChannelShuffleKernel.cpp

+1 −1

@@ -48,7 +48,7 @@ void cpu_channel_shuffle(
 data_vec.store(output_ptr + d);
 }
 for (; d < image_size; d++) {
-output_ptr[d] = input_ptr[d];
+output_ptr[d] = c10::load(&(input_ptr[d]));
 }

 // move on to next output index

aten/src/ATen/native/cpu/IndexKernel.cpp

+19 −18

@@ -25,7 +25,7 @@ void index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef
 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
 iter.dtype(), "index_cpu", [&] {
 cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
-*(scalar_t*)dst = *(scalar_t*)(src + offset);
+*(scalar_t*)dst = c10::load((scalar_t*)(src + offset));
 });
 });
 }
@@ -128,14 +128,14 @@ void put_kernel(
 // Unlike the non-accumulate case, this needs to be thread-safe.
 cpu_take_put_kernel<scalar_t>(iter, self, true,
 [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
-indexed[idx] += iterated;
+indexed[idx] += c10::load(&iterated);
 },
 /*serial_execution=*/true);
 }
 } else {
 cpu_take_put_kernel<scalar_t>(iter, self, true,
 [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
-indexed[idx] = iterated;
+indexed[idx] = c10::load(&iterated);
 });
 }
 });
@@ -148,7 +148,7 @@ void take_kernel(
 iter.dtype(), "take_cpu", [&] {
 cpu_take_put_kernel<scalar_t>(iter, input, false,
 [](scalar_t& iterated, const scalar_t* indexed, const int64_t idx) {
-iterated = indexed[idx];
+iterated = c10::load(&(indexed[idx]));
 });
 });
 }
@@ -174,12 +174,12 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
 // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
 // this needs to be thread-safe.
 cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
-*(scalar_t*)(dst + offset) += *(scalar_t*)src;
+*(scalar_t*)(dst + offset) += c10::load(reinterpret_cast<scalar_t*>(src));
 }, /*serial_execution=*/true);
 }
 } else {
 cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
-*(scalar_t*)(dst + offset) = *(scalar_t*)src;
+*(scalar_t*)(dst + offset) = c10::load(reinterpret_cast<scalar_t*>(src));
 }, /*serial_execution=*/is_deterministic);
 }
 }),
@@ -270,7 +270,7 @@ void index_copy_kernel(
 "index_copy_(): index ", idx, " is out of bounds for dimension ",
 dim, " with size ", self_dim_size);

-self_data[idx * self_dim_stride] = *source_data;
+self_data[idx * self_dim_stride] = c10::load(source_data);

 self_data_bytes += strides[0];
 index_data_bytes += strides[1];
@@ -289,7 +289,7 @@ void index_copy_kernel(
 auto* self_data = reinterpret_cast<scalar_t*>(self_data_bytes);
 auto* source_data = reinterpret_cast<scalar_t*>(source_data_bytes);

-self_data[idx * self_dim_stride] = *source_data;
+self_data[idx * self_dim_stride] = c10::load(source_data);

 self_data_bytes += strides[0];
 source_data_bytes += strides[2];
@@ -320,7 +320,7 @@ void cpu_masked_fill_kernel(TensorIterator& iter, scalar_t value) {
 char* dst = data[0];
 char* mask = data[1];
 for (const auto i : c10::irange(n)) {
-bool mask_value = *reinterpret_cast<bool*>(mask + strides[1] * i);
+bool mask_value = c10::load(reinterpret_cast<bool*>(mask + strides[1] * i));

 if (mask_value) {
 *(scalar_t*)(dst + strides[0] * i) = value;
@@ -353,10 +353,11 @@ void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
 char* mask = data[1];
 const int64_t mask_stride = strides[1];
 for (const auto i : c10::irange(n)) {
-auto mask_value = *reinterpret_cast<bool*>(mask + mask_stride * i);
+auto mask_value = c10::load(reinterpret_cast<bool*>(mask + mask_stride * i));
+
 if (mask_value) {
 TORCH_CHECK(source_cntr < numel, "Number of elements of source < number of ones in mask");
-*(scalar_t*)(dst + dst_stride * i) = *(source_ptr);
+*(scalar_t*)(dst + dst_stride * i) = c10::load(source_ptr);
 source_ptr++;
 source_cntr++;
 }
@@ -387,7 +388,7 @@ void cpu_masked_select_serial_kernel(TensorIterator& iter, const func_t& f) {
 char* src = data[1];
 char* mask = data[2];
 for (const auto i : c10::irange(n)) {
-mask_t mask_value = *(mask_t*)(mask + strides[2] * i);
+mask_t mask_value = c10::load((mask_t*)(mask + strides[2] * i));
 if constexpr (!std::is_same_v<mask_t, bool>) {
 TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
 }
@@ -406,11 +407,11 @@ void masked_select_serial_kernel(TensorIterator& iter, int64_t result_stride) {
 auto mask_dtype = iter.input_dtype(1);
 if (mask_dtype == ScalarType::Bool) {
 cpu_masked_select_serial_kernel<scalar_t, bool>(iter, [result_stride](char* dst, char* src, int64_t offset) {
-*(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
+*(scalar_t*)(dst + offset*result_stride) = c10::load((scalar_t*)src);
 });
 } else {
 cpu_masked_select_serial_kernel<scalar_t, unsigned char>(iter, [result_stride](char* dst, char* src, int64_t offset) {
-*(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
+*(scalar_t*)(dst + offset*result_stride) = c10::load((scalar_t*)src);
 });
 }
 }),
@@ -430,7 +431,7 @@ void cpu_masked_select_kernel(TensorIterator& iter, const func_t& f) {
 char* mask = data[2];
 char* mask_prefix_sum = data[3];
 for (const auto i : c10::irange(n)) {
-mask_t mask_value = *(mask_t*)(mask + strides[2] * i);
+mask_t mask_value = c10::load((mask_t*)(mask + strides[2] * i));
 if constexpr (!std::is_same_v<mask_t, bool>) {
 TORCH_CHECK(mask_value == 0 || mask_value == 1, "Mask tensor can take 0 and 1 values only");
 }
@@ -449,7 +450,7 @@ void masked_select_kernel(TensorIterator& iter, int64_t result_stride) {
 auto mask_dtype = iter.input_dtype(1);
 if (mask_dtype == ScalarType::Bool) {
 cpu_masked_select_kernel<scalar_t, bool>(iter, [result_stride](char* dst, char* src, int64_t offset) {
-*(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src;
+*(scalar_t*)(dst + offset*result_stride) = c10::load((scalar_t*)src);
 });
 } else {
 cpu_masked_select_kernel<scalar_t, unsigned char>(iter, [result_stride](char* dst, char* src, int64_t offset) {
@@ -501,7 +502,7 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
 offset = (offset >= n) ? n : offset;
 for (; i < offset; i++) {
 scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
-*out_ptr = *(scalar_t *)(data[1] + i * stride);
+*out_ptr = c10::load((scalar_t *)(data[1] + i * stride));
 }
 // Empirically found that it is faster to process 3 data items together vs 2 or 4
 for (; i <= n - 3 * Vec::size(); i += 3 * Vec::size()) {
@@ -519,7 +520,7 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
 if (i < n) {
 for (; i < n; i++) {
 scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
-*out_ptr = *(scalar_t *)(data[1] + i * stride);
+*out_ptr = c10::load((scalar_t *)(data[1] + i * stride));
 }
 }

aten/src/ATen/native/cpu/Loops.h

+1 −1

@@ -208,7 +208,7 @@ vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, ve
 data[arg] = data_[arg];
 }

-Vec opt_scalar = Vec(S > 0 ? *(scalar_t*)data[S] : scalar_t(0));
+Vec opt_scalar = Vec(S > 0 ? c10::load((scalar_t*)data[S]) : scalar_t(0));
 int64_t i = 0;
 for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
 auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);

aten/src/ATen/native/cpu/PixelShuffleKernel.cpp

+3 −3

@@ -45,7 +45,7 @@ void cpu_pixel_shuffle(
 for (const auto i : c10::irange(begin, end)) {
 int64_t input_offset = n * stride_n + c * stride_c + s1 * stride_s1 +
 s2 * stride_s2 + h * stride_h + w;
-output_data[i] = input_data[input_offset];
+output_data[i] = c10::load(&input_data[input_offset]);

 data_index_step(n, nbatch, c, sub_channels, h, height, s1, S, w, width, s2, S);
 }
@@ -144,7 +144,7 @@ void cpu_pixel_unshuffle(
 for (const auto i : c10::irange(begin, end)) {
 int64_t input_offset = n * stride_n + c * stride_c + h * stride_h +
 s1 * stride_s1 + w * stride_w + s2 * stride_s2;
-output_data[i] = input_data[input_offset];
+output_data[i] = c10::load(&input_data[input_offset]);

 data_index_step(n, nbatch, c, sub_channels, s1, S, s2, S, h, height, w, width);
 }
@@ -186,7 +186,7 @@ void cpu_pixel_unshuffle_channels_last(
 for (const auto i : c10::irange(begin, end)) {
 int64_t input_offset = n * stride_n + h * stride_h + s1 * stride_s1 +
 w * stride_w + s2 * stride_s2 + c * stride_c;
-output_data[i] = input_data[input_offset];
+output_data[i] = c10::load(&input_data[input_offset]);

 data_index_step(n, nbatch, h, height, w, width, c, sub_channels, s1, S, s2, S);
 }

aten/src/ATen/native/cpu/ScatterGatherKernel.cpp

+11 −7

@@ -34,11 +34,11 @@ class ReduceMultiply {
 template <typename scalar_t>
 constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
 using opmath_t = at::opmath_type<scalar_t>;
-*self_data *= opmath_t(*src_data);
+*self_data *= opmath_t(c10::load(src_data));
 }

 constexpr void operator() (bool * self_data, bool * src_data) const {
-*self_data = *self_data && *src_data;
+*self_data = c10::load(self_data) && c10::load(src_data);
 }
 };
 static ReduceMultiply reduce_multiply;
@@ -48,7 +48,7 @@ class ReduceAdd {
 template <typename scalar_t>
 constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
 using opmath_t = at::opmath_type<scalar_t>;
-*self_data += opmath_t(*src_data);
+*self_data += opmath_t(c10::load(src_data));
 }
 };
 static ReduceAdd reduce_add;
@@ -58,7 +58,7 @@ class ReduceMean {
 template <typename scalar_t>
 constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
 using opmath_t = at::opmath_type<scalar_t>;
-*self_data += opmath_t(*src_data);
+*self_data += opmath_t(c10::load(src_data));
 }
 };
 static ReduceMean reduce_mean;
@@ -68,7 +68,9 @@ class ReduceMaximum {
 template <typename scalar_t>
 constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
 using opmath_t = at::opmath_type<scalar_t>;
-*self_data = at::_isnan<scalar_t>(*src_data) ? opmath_t(*src_data) : std::max(*self_data, opmath_t(*src_data));
+auto self_value = c10::load(self_data);
+auto src_value = c10::load(src_data);
+*self_data = at::_isnan<scalar_t>(src_value) ? opmath_t(src_value) : std::max(self_value, opmath_t(src_value));
 }
 };
 static ReduceMaximum reduce_maximum;
@@ -78,7 +80,9 @@ class ReduceMinimum {
 template <typename scalar_t>
 constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
 using opmath_t = at::opmath_type<scalar_t>;
-*self_data = at::_isnan<scalar_t>(*src_data) ? opmath_t(*src_data) : std::min(*self_data, opmath_t(*src_data));
+auto self_value = c10::load(self_data);
+auto src_value = c10::load(src_data);
+*self_data = at::_isnan<scalar_t>(src_value) ? opmath_t(src_value) : std::min(self_value, opmath_t(src_value));
 }
 };
 static ReduceMinimum reduce_minimum;
@@ -88,7 +92,7 @@ class TensorAssign {
 template <typename scalar_t>
 constexpr void operator() (at::opmath_type<scalar_t> * self_data, scalar_t * src_data) const {
 using opmath_t = at::opmath_type<scalar_t>;
-*self_data = opmath_t(*src_data);
+*self_data = opmath_t(c10::load(src_data));
 }
 };
 static TensorAssign tensor_assign;

aten/src/ATen/native/cpu/TensorCompareKernel.cpp

+1 −1

@@ -115,7 +115,7 @@ static void min_kernel_impl(
 scalar_t min_number = c10::load(self_data);
 int64_t index = 0;
 for (const auto i : c10::irange(self_dim_size)) {
-scalar_t value = self_data[i * self_dim_stride];
+scalar_t value = c10::load(&self_data[i * self_dim_stride]);
 if (!(zabs_(value) >= zabs_(min_number))) {
 min_number = value;
 index = i;

aten/src/ATen/native/cpu/utils.h

+1 −1

@@ -148,7 +148,7 @@ template <typename T>
 inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
 for (int64_t j = 0; j < N; j++) {
 for (int64_t i = 0; i < M; i++) {
-dst[j * ld_dst + i] = src[i * ld_src + j];
+dst[j * ld_dst + i] = c10::load(&(src[i * ld_src + j]));
 }
 }
 }

aten/src/ATen/native/im2col.h

+1 −1

@@ -71,7 +71,7 @@ static void im2col(
 int64_t w_im = w_col * stride_w - pad_w + w_offset * dilation_w;
 data_col[(c_col * height_col + h_col) * width_col + w_col] =
 (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width)
-? data_im[(c_im * height + h_im) * width + w_im]
+? c10::load(&(data_im[(c_im * height + h_im) * width + w_im]))
 : static_cast<T>(0);
 }
 }

c10/util/Load.h

+2 −2

@@ -26,12 +26,12 @@ struct LoadImpl<bool> {
 } // namespace detail

 template <typename T>
-C10_HOST_DEVICE T load(const void* src) {
+C10_HOST_DEVICE constexpr T load(const void* src) {
 return c10::detail::LoadImpl<T>::apply(src);
 }

 template <typename scalar_t>
-C10_HOST_DEVICE scalar_t load(const scalar_t* src) {
+C10_HOST_DEVICE constexpr scalar_t load(const scalar_t* src) {
 return c10::detail::LoadImpl<scalar_t>::apply(src);
 }
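A minimal usage sketch of the two overloads above; this assumes a build against the c10 headers (e.g. inside the PyTorch source tree), and the byte normalization itself happens in detail::LoadImpl<bool>, which only appears as context in this hunk.

// Usage sketch for c10::load on a bool whose storage byte is not 0x00/0x01.
#include <c10/util/Load.h>
#include <iostream>

int main() {
  unsigned char storage = 0x02;  // byte pattern that is not a valid bool representation
  const bool* p = reinterpret_cast<const bool*>(&storage);
  // A plain `*p` is the UB pattern this commit removes from the CPU kernels;
  // c10::load(p) dispatches to c10::detail::LoadImpl<bool> instead.
  bool v = c10::load(p);
  std::cout << std::boolalpha << v << '\n';  // expected: true
}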

test/test_ops.py

+2 −14

@@ -71,7 +71,6 @@
 TEST_WITH_ROCM,
 TEST_WITH_TORCHDYNAMO,
 TEST_WITH_TORCHINDUCTOR,
-TEST_WITH_UBSAN,
 TestCase,
 unMarkDynamoStrictTest,
 )
@@ -653,16 +652,6 @@ def test_python_ref_executor(self, device, dtype, op, executor):
 and dtype == torch.float16
 ):
 self.skipTest("Skipped on ROCm")
-# skip zero-dim tensors for some composites of reduction operations and view
-skip_zero_dim_ops = [
-"_refs.logsumexp",
-"_refs.log_softmax",
-"_refs.native_group_norm",
-"_refs.softmax",
-"_refs.sum_to_size",
-"ops.nvprims.view",
-]
-
 from copy import copy

 from torch._prims.executor import make_traced
@@ -1050,7 +1039,7 @@ def _case_zero_transform(t):
 try:
 info = torch.iinfo(t.dtype)
 return torch.full_like(t, info.max)
-except TypeError as te:
+except TypeError:
 # for non-integer types fills with NaN
 return torch.full_like(t, float("nan"))

@@ -1445,7 +1434,6 @@ def test_complex_half_reference_testing(self, device, dtype, op):
 self.assertEqual(actual, expected, exact_dtype=False)

 @ops(op_db, allowed_dtypes=(torch.bool,))
-@unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior")
 def test_non_standard_bool_values(self, device, dtype, op):
 # Test boolean values other than 0x00 and 0x01 (gh-54789)
 def convert_boolean_tensors(x):
@@ -2754,7 +2742,7 @@ def map_to_fake(e):

 try:
 op(input, *args, **kwargs)
-except Exception as e:
+except Exception:
 continue

 with TestPointwiseMode():

Comments (0)