Skip to content

Commit 2e983fc

Browse files
ezyang authored and pytorchmergebot committed
Support unsigned int for randint, item, equality, fill, iinfo, tensor (pytorch#116805)
These are some basic utilities that are often used for testing.

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#116805
Approved by: https://github.com/albanD
1 parent 4a10e9e commit 2e983fc

22 files changed, with 137 additions (+137) and 65 deletions (-65).

aten/src/ATen/Dispatch_v2.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,18 @@
8585
// to be interpreted as being multiple arguments
8686
#define AT_WRAP(...) __VA_ARGS__
8787

88-
#define AT_FLOAT8_TYPES \
89-
kFloat8_e5m2, kFloat8_e5m2fnuz, kFloat8_e4m3fn, kFloat8_e4m3fnuz
90-
91-
#define AT_INTEGRAL_TYPES kByte, kChar, kInt, kLong, kShort
92-
#define AT_FLOATING_TYPES kDouble, kFloat
93-
#define AT_BAREBONES_UNSIGNED_TYPES kUInt16, kUInt32, kUInt64
94-
#define AT_COMPLEX_TYPES kComplexDouble, kComplexFloat
95-
#define AT_QINT_TYPES kQInt8, kQUInt8, kQInt32
88+
#define AT_FLOAT8_TYPES \
89+
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
90+
c10::kFloat8_e4m3fnuz
91+
92+
#define AT_INTEGRAL_TYPES \
93+
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
94+
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
95+
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
96+
#define AT_INTEGRAL_TYPES_V2 \
97+
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
98+
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
99+
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
96100
// NB: not *actually* all types
97101
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
98102
#define AT_ALL_TYPES_AND_COMPLEX \

aten/src/ATen/native/DistributionTemplates.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <ATen/core/Tensor.h>
44
#include <ATen/Dispatch.h>
5+
#include <ATen/Dispatch_v2.h>
56
#include <ATen/Generator.h>
67
#include <ATen/ExpandUtils.h>
78
#include <ATen/Tensor.h>
@@ -110,13 +111,21 @@ static void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMet
110111
WARN_OUT_OF_BOUNDS(from, "from", digits, dtype);
111112
WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype);
112113
});
114+
} else if (scalar_type == kUInt64) {
115+
// When you do a comparison between int64_t and uint64_t, the usual
116+
// arithmetic conversions say that the int64_t value is promoted to
117+
// unsigned. But this conversion wraps around: if I had -1 as my int64_t,
118+
// then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never
119+
// the right thing to do.
120+
CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype);
121+
CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype);
113122
} else if (isIntegralType(scalar_type, /*includeBool=*/true)) {
114-
AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, scalar_type, "check_random_integral_bounds", [&]() {
123+
AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() {
115124
const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
116125
const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
117126
CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
118127
CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
119-
});
128+
}), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool);
120129
} else {
121130
TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types");
122131
}
@@ -152,13 +161,13 @@ at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, c10::optional<in
152161
TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc);
153162
});
154163
} else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
155-
AT_DISPATCH_INTEGRAL_TYPES_AND(at::ScalarType::Bool, self.scalar_type(), "random_from_to_range_calc", [&] {
164+
AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] {
156165
if constexpr (std::is_same_v<scalar_t, bool>) {
157166
to_inc = static_cast<int64_t>(true);
158167
} else {
159168
to_inc = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
160169
}
161-
});
170+
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
162171
} else {
163172
TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types");
164173
}

aten/src/ATen/native/ReduceOpsUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ static inline void check_scalar_type_device_layout_equal(const Tensor& out, cons
104104

105105
static inline Tensor integer_upcast(const Tensor& self, c10::optional<ScalarType> dtype) {
106106
ScalarType scalarType = self.scalar_type();
107+
TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
107108
ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
108109
return self.toType(upcast_scalarType);
109110
}

aten/src/ATen/native/Scalar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Scalar item(const Tensor& self) {
2727
}
2828
}
2929

30-
#define AT_SD_BASE_TYPES AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), kComplexHalf, kHalf, kBool, kBFloat16
30+
#define AT_SD_BASE_TYPES AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
3131
#if !defined(C10_MOBILE)
3232
#define AT_SD_TYPES AT_EXPAND(AT_SD_BASE_TYPES), AT_EXPAND(AT_FLOAT8_TYPES)
3333
#else

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <cmath>
55

66
#include <ATen/Dispatch.h>
7+
#include <ATen/Dispatch_v2.h>
78
#include <ATen/OpMathType.h>
89
#include <ATen/Parallel.h>
910
#include <ATen/cpu/vec/functional.h>
@@ -81,31 +82,32 @@ void atan2_kernel(TensorIteratorBase& iter) {
8182

8283
#if !defined(C10_MOBILE)
8384
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
84-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
85+
AT_DISPATCH_V2( \
86+
TYPE, \
87+
NAME, \
88+
AT_WRAP(__VA_ARGS__), \
8589
kComplexHalf, \
8690
kHalf, \
8791
kBool, \
8892
kBFloat16, \
8993
kFloat8_e5m2, \
9094
kFloat8_e5m2fnuz, \
9195
kFloat8_e4m3fn, \
92-
kFloat8_e4m3fnuz, \
93-
TYPE, \
94-
NAME, \
95-
__VA_ARGS__)
96+
kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
9697
#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
97-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
98+
AT_DISPATCH_V2( \
99+
TYPE, \
100+
NAME, \
101+
AT_WRAP(__VA_ARGS__), \
98102
kComplexHalf, \
99103
kHalf, \
100104
kBFloat16, \
101105
kFloat8_e5m2, \
102106
kFloat8_e4m3fn, \
103-
TYPE, \
104-
NAME, \
105-
__VA_ARGS__)
107+
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
106108
#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
107-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
108-
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, TYPE, NAME, __VA_ARGS__)
109+
AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
110+
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
109111
#else
110112
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
111113
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \

aten/src/ATen/native/cpu/DistributionTemplates.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <ATen/CPUApplyUtils.h>
44
#include <ATen/Dispatch.h>
5+
#include <ATen/Dispatch_v2.h>
56
#include <ATen/ExpandBase.h>
67
#include <ATen/core/DistributionsHelper.h>
78
#include <ATen/native/TensorIterator.h>
@@ -25,13 +26,13 @@ namespace {
2526

2627
template<typename RNG>
2728
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
28-
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_from_to_kernel_cpu", [&] {
29+
AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cpu", AT_WRAP([&] {
2930
std::lock_guard<std::mutex> lock(generator->mutex_);
3031
cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
3132
uniform_int_from_to_distribution<scalar_t> random(range, base);
3233
return random(generator);
3334
});
34-
});
35+
}), kBool, kHalf, kBFloat16, AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
3536
}
3637

3738
// This is the special kernel to handle single specific case:

aten/src/ATen/native/cpu/FillKernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
5252
[=]() -> scalar_t { return value; },
5353
[=]() { return Vectorized<scalar_t>(value); });
5454
}),
55-
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kBool
55+
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kBool, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
5656
);
5757
}
5858
}

aten/src/ATen/native/cuda/CUDAScalar.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
22
#include <ATen/core/Tensor.h>
3-
#include <ATen/Dispatch.h>
3+
#include <ATen/Dispatch_v2.h>
44

55
#ifndef AT_PER_OPERATOR_HEADERS
66
#include <ATen/NativeFunctions.h>
@@ -14,13 +14,13 @@ namespace at::native {
1414

1515
Scalar _local_scalar_dense_cuda(const Tensor& self) {
1616
Scalar r;
17-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
18-
kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
17+
AT_DISPATCH_V2(
18+
self.scalar_type(), "_local_scalar_dense_cuda", AT_WRAP([&] {
1919
scalar_t value;
2020
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
2121
at::cuda::memcpy_and_sync(&value, self.const_data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream);
2222
r = Scalar(value);
23-
});
23+
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
2424
return r;
2525
}
2626

aten/src/ATen/native/cuda/CompareEQKernel.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#define TORCH_ASSERT_NO_OPERATORS
22
#include <ATen/Dispatch.h>
3+
#include <ATen/Dispatch_v2.h>
34
#include <ATen/native/BinaryOps.h>
45
#include <ATen/native/DispatchStub.h>
56
#include <ATen/native/TensorIterator.h>
@@ -29,11 +30,10 @@ struct CompareEqFunctor{
2930
}
3031

3132
C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) {
32-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(kComplexHalf, kHalf, kBFloat16, kBool, kFloat8_e4m3fn, kFloat8_e5m2,
33-
iter.common_dtype(), "compare_eq_ne_cuda", [&]() {
33+
AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_cuda", AT_WRAP([&]() {
3434
opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
3535
iter, CompareEqFunctor<scalar_t>(op));
36-
});
36+
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
3737
}
3838

3939
void eq_kernel_cuda(TensorIteratorBase& iter) {

aten/src/ATen/native/cuda/Copy.cu

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <ATen/core/Tensor.h>
33
#include <ATen/Context.h>
44
#include <ATen/Dispatch.h>
5+
#include <ATen/Dispatch_v2.h>
56
#include <ATen/cuda/CachingHostAllocator.h>
67
#include <ATen/cuda/CUDAContext.h>
78
#include <ATen/cuda/CUDAEvent.h>
@@ -98,6 +99,8 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
9899
}
99100
}
100101

102+
// TODO: We probably can use the opaque type trick to avoid creating duplicate
103+
// kernels for equivalent bit lengths
101104
void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
102105
ScalarType dtype = iter.dtype(0);
103106
if (isQIntType(dtype)) {
@@ -115,10 +118,10 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
115118
});
116119
#endif /* !defined(USE_ROCM) */
117120
} else {
118-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
119-
kHalf, kBool, kBFloat16, kComplexHalf,dtype, "copy_", [&] {
121+
AT_DISPATCH_V2(
122+
dtype, "copy_", AT_WRAP([&] {
120123
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
121-
});
124+
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kHalf, kBool, kBFloat16, kComplexHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
122125
}
123126
}
124127

aten/src/ATen/native/cuda/DistributionTemplates.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <ATen/AccumulateType.h>
44
#include <ATen/Dispatch.h>
5+
#include <ATen/Dispatch_v2.h>
56
#include <ATen/ExpandBase.h>
67
#include <ATen/OpMathType.h>
78
#include <ATen/native/TensorIterator.h>
@@ -285,7 +286,7 @@ namespace cuda {
285286

286287
template<typename RNG>
287288
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
288-
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_from_to_kernel_cuda", [&] {
289+
AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
289290
if ((
290291
std::is_same<scalar_t, int64_t>::value ||
291292
std::is_same<scalar_t, double>::value ||
@@ -317,7 +318,7 @@ void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t bas
317318
},
318319
random_func);
319320
}
320-
});
321+
}), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
321322
}
322323

323324
// This is the special kernel to handle single specific case:

aten/src/ATen/native/cuda/FillKernel.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#define TORCH_ASSERT_NO_OPERATORS
22
#include <ATen/Dispatch.h>
3+
#include <ATen/Dispatch_v2.h>
34
#include <ATen/native/cuda/Loops.cuh>
45
#include <ATen/native/DispatchStub.h>
56
#include <ATen/native/TensorIterator.h>
@@ -19,9 +20,9 @@ struct FillFunctor {
1920
};
2021

2122
void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) {
22-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(kComplexHalf, kBool, kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2, iter.dtype(), "fill_cuda", [&]() {
23+
AT_DISPATCH_V2(iter.dtype(), "fill_cuda", AT_WRAP([&]() {
2324
gpu_kernel(iter, FillFunctor<scalar_t>(value.to<scalar_t>()));
24-
});
25+
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
2526
}
2627

2728
REGISTER_DISPATCH(fill_stub, &fill_kernel_cuda);

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <ATen/core/Tensor.h>
88
#include <ATen/ceil_div.h>
99
#include <ATen/Dispatch.h>
10+
#include <ATen/Dispatch_v2.h>
1011
#include <ATen/ExpandUtils.h>
1112
#include <ATen/MemoryOverlap.h>
1213
#include <ATen/TensorOperators.h>
@@ -1481,14 +1482,16 @@ Tensor& index_select_out_cuda(
14811482
index_select_out_cuda_impl<scalar_t>(out, self, dim, index);
14821483
});
14831484
} else {
1484-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
1485-
at::ScalarType::ComplexHalf,
1486-
at::ScalarType::Half,
1487-
at::ScalarType::Bool,
1488-
at::ScalarType::BFloat16,
1485+
AT_DISPATCH_V2(
14891486
out.scalar_type(),
14901487
"index_select_cuda",
1491-
[&] { index_select_out_cuda_impl<scalar_t>(out, self, dim, index); });
1488+
AT_WRAP([&] { index_select_out_cuda_impl<scalar_t>(out, self, dim, index); }),
1489+
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
1490+
kComplexHalf,
1491+
kHalf,
1492+
kBool,
1493+
kBFloat16
1494+
);
14921495
}
14931496

14941497
return out;

aten/src/ATen/native/cuda/Shape.cu

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <ATen/native/TypeProperties.h>
1010
#include <ATen/native/TensorShape.h>
1111
#include <ATen/Dispatch.h>
12+
#include <ATen/Dispatch_v2.h>
1213
#include <c10/core/MemoryFormat.h>
1314
#include <c10/util/Optional.h>
1415

@@ -431,12 +432,10 @@ TORCH_IMPL_FUNC(cat_out_cuda)
431432
parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
432433
});
433434
} else {
434-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
435-
kComplexHalf, kHalf, kBool, kBFloat16,
436-
result.scalar_type(), "cat_cuda", [&]() {
435+
AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() {
437436
using dtype = OpaqueType<sizeof(scalar_t)>;
438437
parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
439-
});
438+
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
440439
}
441440
} else if (materialized.size() > 1 &&
442441
result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
@@ -451,12 +450,10 @@ TORCH_IMPL_FUNC(cat_out_cuda)
451450
parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
452451
});
453452
} else {
454-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
455-
kComplexHalf, kHalf, kBool, kBFloat16,
456-
result.scalar_type(), "cat_cuda", [&]() {
453+
AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() {
457454
using dtype = OpaqueType<sizeof(scalar_t)>;
458455
parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
459-
});
456+
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
460457
}
461458
} else {
462459
int64_t offset = 0;

c10/core/DynamicCast.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ C10_HOST_DEVICE inline dest_t fetch_and_cast(
7171
const void* ptr) {
7272
switch (src_type) {
7373
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE)
74+
FETCH_AND_CAST_CASE(uint16_t, UInt16)
75+
FETCH_AND_CAST_CASE(uint32_t, UInt32)
76+
FETCH_AND_CAST_CASE(uint64_t, UInt64)
7477
default:
7578
ERROR_UNSUPPORTED_CAST
7679
}
@@ -90,6 +93,9 @@ C10_HOST_DEVICE inline void cast_and_store(
9093
src_t value) {
9194
switch (dest_type) {
9295
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE)
96+
CAST_AND_STORE_CASE(uint16_t, UInt16)
97+
CAST_AND_STORE_CASE(uint32_t, UInt32)
98+
CAST_AND_STORE_CASE(uint64_t, UInt64)
9399
default:;
94100
}
95101
ERROR_UNSUPPORTED_CAST

0 commit comments

Comments
 (0)