Commit 43a23a7

ezyang authored and pytorchmergebot committed

Support uint{16,32,64} copy (pytorch#116806)

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#116806
Approved by: https://github.com/albanD
ghstack dependencies: pytorch#116805
1 parent 2e983fc commit 43a23a7
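
What the change means at the Python level: same-dtype and converting copies now dispatch for torch.uint16, torch.uint32, and torch.uint64. A minimal sketch of the newly covered behavior (assuming a build that includes this commit):

import torch

# Converting copy into a barebones unsigned dtype, then a same-dtype copy.
# Both go through the dispatch macros this commit widens.
src = torch.arange(6, dtype=torch.int64)
dst = torch.empty(6, dtype=torch.uint16)
dst.copy_(src)               # converting copy: int64 -> uint16
dup = torch.empty_like(dst)
dup.copy_(dst)               # same-dtype copy: uint16 -> uint16
back = dup.to(torch.int64)   # .to() is itself a copy under the hood
assert back.tolist() == [0, 1, 2, 3, 4, 5]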

File tree

aten/src/ATen/native/Copy.cpp
aten/src/ATen/native/cpu/CopyKernel.cpp
test/test_torch.py

3 files changed: +12 -12 lines

aten/src/ATen/native/Copy.cpp

Lines changed: 4 additions & 4 deletions

@@ -3,6 +3,7 @@
 
 #include <ATen/core/Tensor.h>
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/FunctionalTensorWrapper.h>
 #include <ATen/TensorIterator.h>
@@ -52,10 +53,9 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
 
 #if !defined(C10_MOBILE)
 #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
-      kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, \
-      kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, TYPE, NAME, \
-      __VA_ARGS__)
+  AT_DISPATCH_V2( \
+      TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, \
+      kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
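
The point of the migration above: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8 hard-codes how many extra dtypes it accepts, while AT_DISPATCH_V2 takes the body via AT_WRAP plus an open-ended dtype list, so AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), i.e. the uint16/32/64 set named in the commit title, can be spliced in without minting an AND11 variant. From Python this surfaces as cross-dtype copies with a uint type on either side; a small sketch (the truncating float-to-uint cast semantics are my assumption):

import torch

# uint dtypes on both sides of a converting copy.
f = torch.tensor([1.5, 2.0, 3.75, 40.0])    # float32
u = torch.empty(4, dtype=torch.uint32)
u.copy_(f)                                   # float32 -> uint32 (truncates)
g = torch.empty(4, dtype=torch.float32)
g.copy_(u)                                   # uint32 -> float32
assert g.tolist() == [1.0, 2.0, 3.0, 40.0]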

aten/src/ATen/native/cpu/CopyKernel.cpp

Lines changed: 7 additions & 7 deletions

@@ -1,5 +1,6 @@
 #define TORCH_ASSERT_NO_OPERATORS
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/native/Copy.h>
 #include <ATen/native/UnaryOps.h>
 #include <ATen/native/TensorIterator.h>
@@ -201,15 +202,14 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
 
 #if !defined(C10_MOBILE)
 #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
-      ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, \
-      ScalarType::BFloat16, ScalarType::Float8_e5m2, ScalarType::Float8_e4m3fn, \
-      ScalarType::Float8_e5m2fnuz, ScalarType::Float8_e4m3fnuz, \
-      TYPE, NAME, __VA_ARGS__)
+  AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
+      kComplexHalf, kHalf, kBool, \
+      kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \
+  AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
       kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
-      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, TYPE, NAME, __VA_ARGS__)
+      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
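
CopyKernel.cpp applies the same AT_DISPATCH_V2 migration to the CPU kernel's own dispatch, including the no-ComplexHalf variant _AT_DISPATCH_ALL_TYPES_NO_CF, so strided copies of the new dtypes work, not just the dispatch stub. A sketch; that this particular pattern hits the copy_transpose_valid fast path is an assumption, and the observable result is simply a correct copy:

import torch

# A transposed view is non-contiguous; .contiguous() forces a real
# uint16 copy through the CPU copy kernel.
a = torch.arange(16, dtype=torch.int64).to(torch.uint16).reshape(4, 4)
b = a.t().contiguous()
bi = b.to(torch.int64)        # convert back for easy comparison
assert bi[0, 1] == 4          # b[0, 1] came from a[1, 0]
assert bi[3, 0] == 3          # b[3, 0] came from a[0, 3]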

test/test_torch.py

Lines changed: 1 addition & 1 deletion

@@ -238,7 +238,7 @@ def test_tensor_storage_type(self, device, dtype):
         self.assertEqual(a.storage_type(), expected_storage_type)
 
     @onlyNativeDeviceTypes
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64))
     def test_tensor_from_storage(self, device, dtype):
         a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
         a_s = a.storage()
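
The widened @dtypes decorator makes test_tensor_from_storage run over the three new unsigned dtypes as well. A rough sketch of the round-trip being exercised; untyped_storage() and set_() here are illustrative stand-ins, not the test's exact code:

import torch

# Rebuild a tensor from its storage for each new unsigned dtype.
for dtype in (torch.uint16, torch.uint32, torch.uint64):
    a = torch.empty((4, 5, 3), dtype=dtype)
    a.copy_(torch.arange(60).reshape(4, 5, 3))   # fill via converting copy
    b = torch.empty(0, dtype=dtype).set_(a.untyped_storage(), 0, a.shape)
    assert b.to(torch.int64).equal(a.to(torch.int64))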
