@@ -33,27 +33,9 @@ inline at::Tensor scalar_to_tensor(
     const Device device = at::kCPU) {
   // This is the fast track we have for CPU scalar tensors.
   if (device == at::kCPU) {
-    if (s.isFloatingPoint()) {
-      return at::detail::scalar_tensor_static(s, at::kDouble, at::kCPU);
-    } else if (s.isComplex()) {
-      return at::detail::scalar_tensor_static(s, at::kComplexDouble, at::kCPU);
-    } else if (s.isBoolean()) {
-      return at::detail::scalar_tensor_static(s, at::kBool, at::kCPU);
-    } else {
-      AT_ASSERT(s.isIntegral(false));
-      return at::detail::scalar_tensor_static(s, at::kLong, at::kCPU);
-    }
-  }
-  if (s.isFloatingPoint()) {
-    return at::scalar_tensor(s, at::device(device).dtype(at::kDouble));
-  } else if (s.isBoolean()) {
-    return at::scalar_tensor(s, at::device(device).dtype(at::kBool));
-  } else if (s.isComplex()) {
-    return at::scalar_tensor(s, at::device(device).dtype(at::kComplexDouble));
-  } else {
-    AT_ASSERT(s.isIntegral(false));
-    return at::scalar_tensor(s, at::device(device).dtype(at::kLong));
+    return at::detail::scalar_tensor_static(s, s.type(), at::kCPU);
   }
+  return at::scalar_tensor(s, at::device(device).dtype(s.type()));
 }

 } // namespace c10
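For context, a minimal sketch of why the per-type branches collapse into a single call: c10::Scalar::type() reports the scalar's default dtype from its internal tag (floating point -> kDouble, complex -> kComplexDouble, boolean -> kBool, integral -> kLong), which is exactly the mapping the removed if/else chain performed by hand. The snippet below is illustrative only and assumes an ATen/PyTorch build environment; it is not part of this change.

// Illustrative sketch (assumes ATen/PyTorch headers and libraries are available).
#include <ATen/ATen.h>
#include <c10/core/Scalar.h>
#include <iostream>

int main() {
  // Scalar::type() maps the scalar's internal tag to its default dtype,
  // so the caller no longer needs to branch on isFloatingPoint()/isComplex()/etc.
  c10::Scalar s(3.14);
  std::cout << s.type() << std::endl;  // prints the default dtype (Double for a double literal)

  // The simplified non-CPU path simply forwards that dtype to scalar_tensor.
  at::Tensor t = at::scalar_tensor(s, at::device(at::kCPU).dtype(s.type()));
  std::cout << t.dtype() << std::endl;
  return 0;
}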