Commit 1d57a2d

[ATen][Scalars] Remove Scalar from return types of functions. (pytorch#3557)
* Add direct C-type scalar conversions from Tensor, e.g. toCFloat() as an alias for Scalar(x).toFloat().
* Provide tensor overloads for fill_, masked_fill_, index_fill_.
* Everything up to scalar overload.
* Fix pytorch build for aten scalar return type changes.
* Use valid expression instead of dangling else.
* Simplify code generation.
* Fix test_jit (why didn't this compile locally?)
1 parent 22d1e37 commit 1d57a2d
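
In practice this means reductions such as sum() and mean() now return a zero-dimensional Tensor rather than a Scalar, the new toC<Type>() helpers convert such a tensor directly to a C value, and fill_, masked_fill_, and index_fill_ accept a zero-dimensional tensor in place of a Scalar value. A minimal usage sketch of the changed API, illustrative only and modeled on the tests touched below:

    #include <ATen/ATen.h>
    using namespace at;

    int main() {
      Tensor t = CPU(kFloat).ones({3, 4});
      Tensor s = t.sum();          // now a 0-dim Tensor, no longer a Scalar
      double v = s.toCDouble();    // new direct C-type conversion, i.e. Scalar(s).toDouble()
      t.fill_(s);                  // new Tensor overload; the value must be a 0-dim tensor
      return v == 12. ? 0 : 1;
    }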

File tree

12 files changed: +145 additions, -66 deletions

aten/contrib/meter/APMeter.cc (+1, -1)

@@ -82,7 +82,7 @@ void APMeter::value(Tensor& val) {
       if(targetbuffer_d[n] != 0.)
         val_d[k] += precision_d[n];
     }
-    auto norm = sum(targetbuffer).toDouble();
+    auto norm = sum(targetbuffer).toCDouble();
     if(norm > 0)
       val_d[k] /= norm;
   }

aten/src/ATen/Declarations.cwrap (+35, -15)

@@ -126,11 +126,16 @@
 ]]
 [[
   name: fill_
-  cname: fill
   return: self
-  arguments:
-    - THTensor* self
-    - real value
+  cname: fill
+  options:
+    - arguments:
+      - THTensor* self
+      - real value
+    - zero_dim_tensor_only: True
+      arguments:
+      - THTensor* self
+      - THTensor* value
 ]]
 [[
   name: isContiguous
@@ -156,11 +161,18 @@
   cname: maskedFill
   python_name: masked_fill_
   return: self
-  arguments:
-    - arg: THTensor* self
-      broadcast: mask inplace fallback types:Byte
-    - THBoolTensor* mask
-    - real value
+  options:
+    - arguments:
+      - arg: THTensor* self
+        broadcast: mask inplace fallback types:Byte
+      - THBoolTensor* mask
+      - real value
+    - zero_dim_tensor_only: True
+      arguments:
+      - arg: THTensor* self
+        broadcast: mask inplace fallback types:Byte
+      - THBoolTensor* mask
+      - THTensor* value
 ]]
 [[
   name: maskedCopy_
@@ -364,12 +376,20 @@
   python_name: index_fill_
   cname: indexFill
   return: argument 0
-  arguments:
-    - THTensor* self
-    - arg: long dim
-      wrap_dim: self
-    - THIndexTensor* index
-    - real value
+  options:
+    - arguments:
+      - THTensor* self
+      - arg: long dim
+        wrap_dim: self
+      - THIndexTensor* index
+      - real value
+    - zero_dim_tensor_only: True
+      arguments:
+      - THTensor* self
+      - arg: long dim
+        wrap_dim: self
+      - THIndexTensor* index
+      - THTensor* value
 ]]
 [[
   name: narrow
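
Each entry under options above becomes a separate overload in the generated ATen API, so fill_, masked_fill_, and index_fill_ gain a Tensor-valued variant alongside the existing Scalar one, with the zero_dim_tensor_only flag restricting the Tensor variant to zero-dimensional values. A rough sketch of the resulting method pairs (signatures paraphrased for illustration, not copied from the generated headers):

    // Scalar overload (first option) and Tensor overload (second option,
    // which accepts only 0-dimensional value tensors)
    Tensor & fill_(Scalar value);
    Tensor & fill_(const Tensor & value);
    Tensor & masked_fill_(const Tensor & mask, Scalar value);
    Tensor & masked_fill_(const Tensor & mask, const Tensor & value);
    Tensor & index_fill_(int64_t dim, const Tensor & index, Scalar value);
    Tensor & index_fill_(int64_t dim, const Tensor & index, const Tensor & value);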

aten/src/ATen/function_wrapper.py (+32, -7)

@@ -74,10 +74,15 @@
 # implementation of ${api_name} if we have overloaded a function with
 # the same name (but different signature) already
 ZERO_DIM_CHECK = CodeTemplate("""\
-if(${check_name}.dim() == 0) {
+if (${check_name}.dim() == 0) {
     return static_cast<const Type*>(this)->${method_prefix}${api_name}(${zero_dim_actuals});
 }""")
 
+ZERO_DIM_ONLY = CodeTemplate("""\
+runtime_error("${api_name} only supports a 0-dimensional ${check_name} tensor, but got tensor "
+    "with %" PRId64 " dimension(s)", ${check_name}.dim());
+""")
+
 SPARSE_CHECK = CodeTemplate("""\
 if(${check_name}.type().isSparse()) {
   return static_cast<const Type*>(this)->${method_prefix}${api_name}(${sparse_actuals});
@@ -136,8 +141,8 @@ def __init__(self, reason):
     'THIndexTensor*': 'Tensor',
     'THBoolTensor*': 'Tensor',
     'THIntegerTensor*': 'Tensor',
-    'real': 'Scalar',
-    'accreal': 'Scalar',
+    'real': 'Tensor',
+    'accreal': 'Tensor',
     'long': 'int64_t',
 }
 
@@ -710,14 +715,24 @@ def is_actual_return_long(ret):
             return backend_type_env['AccScalarName'] == 'Long'
         return False
 
+    def get_zero_dim_dispatch_when_scalar(option):
+        return option.get('zero_dim_dispatch_when_scalar', False)
+
     def handle_zero_dim(env, option):
-        if 'zero_dim_dispatch_when_scalar' not in option:
+        zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option)
+        if not zero_dim_dispatch:
             return []
-        check_name = option['zero_dim_dispatch_when_scalar']
         zero_dim_actuals = [arg['name']
-                            if arg['name'] != check_name else "Scalar({})".format(arg['name'])
+                            if arg['name'] != zero_dim_dispatch else "Scalar({})".format(arg['name'])
                             for arg in option['formals_list']]
-        return [ZERO_DIM_CHECK.substitute(env, check_name=check_name, zero_dim_actuals=zero_dim_actuals)]
+        return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]
+
+    def handle_only_zero_dim(env, option):
+        if option.get('zero_dim_tensor_only', False):
+            check_name = get_zero_dim_dispatch_when_scalar(option)
+            return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)]
+        else:
+            return None
 
     def handle_sparse(env, option):
         if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']:
@@ -781,6 +796,12 @@ def emit_body(env, option):
         body = []
         body += handle_sparse(env, option)
         body += handle_zero_dim(env, option)
+        only_zero_dim_check = handle_only_zero_dim(env, option)
+        if only_zero_dim_check is not None:
+            # code below only_zero_dim_check is unreachable so we do not need to generate the rest.
+            body += only_zero_dim_check
+            return body
+
         body += handle_buffers(env, option)
         # arguments are potentially duplicated because of one argument
         # referencing another
@@ -933,6 +954,10 @@ def emit_body(env, option):
             return_tensor = "return Tensor((new ${Tensor}(context,${arg_name}))${maybe_scalar},false);"
             body.append(CodeTemplate(return_tensor).substitute(
                 env, arg_name=call, maybe_scalar=maybe_scalar))
+        # return the same underlying Tensor type for both real and accreal; this ensures
+        # e.g. x.sum(0) and x.sum() return the same type.
+        elif ret['type'] == 'accreal' or ret['type'] == 'real':
+            body.append('return scalarTensor({});'.format(call))
         else:
             # we using int64_t for long in the API, so correct it here...
             if is_actual_return_long(ret):
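
For an option marked zero_dim_tensor_only, emit_body() first emits the ZERO_DIM_CHECK re-dispatch and then the unconditional ZERO_DIM_ONLY error, returning early because nothing after that point is reachable. Hand-expanding the two templates for the Tensor overload of fill_ (with check_name = value) gives roughly the following generated body; this expansion is illustrative only, not copied from generated output:

    if (value.dim() == 0) {
      // ZERO_DIM_CHECK: a 0-dim value tensor is converted and re-dispatched
      // to the Scalar overload
      return static_cast<const Type*>(this)->fill_(self, Scalar(value));
    }
    // ZERO_DIM_ONLY: any other value tensor is rejected; the rest of the
    // method body is never generated
    runtime_error("fill_ only supports a 0-dimensional value tensor, but got tensor "
                  "with %" PRId64 " dimension(s)", value.dim());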

aten/src/ATen/templates/Tensor.h (+5, -0)

@@ -108,6 +108,11 @@ struct Tensor : public detail::TensorBase {
   AT_FORALL_SCALAR_TYPES(TO_TYPE_DATA)
   #undef TO_TYPE_DATA
 
+  #define TO_C_TYPE(T,name,_) \
+  T toC##name () const;
+  AT_FORALL_SCALAR_TYPES(TO_C_TYPE)
+  #undef TO_C_TYPE
+
   template<typename T, size_t N>
   TensorAccessor<T,N> accessor() {
     static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data<T>()");

aten/src/ATen/templates/TensorMethods.h (+6, -0)

@@ -46,4 +46,10 @@ inline T* Tensor::to##name##Data() const { return data<T>(); }
 AT_FORALL_SCALAR_TYPES(DEFINE_CAST)
 #undef DEFINE_CAST
 
+#define DEFINE_TO_C_TYPE(T,name,_) \
+inline T Tensor::toC##name () const { return Scalar(*this).to##name (); }
+
+AT_FORALL_SCALAR_TYPES(DEFINE_TO_C_TYPE)
+#undef DEFINE_TO_C_TYPE
+
 } //namespace at
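
The macro in Tensor.h declares one toC<Name>() accessor per entry of AT_FORALL_SCALAR_TYPES, and this companion macro in TensorMethods.h defines it by going through Scalar, which is what the toCDouble()/toCFloat() calls in the tests below rely on. For the Double entry the expansion is roughly (hand-expanded here for illustration):

    // declared in Tensor.h
    double toCDouble() const;

    // defined in TensorMethods.h: wrap the 0-dim tensor in a Scalar,
    // then extract the C value
    inline double Tensor::toCDouble() const { return Scalar(*this).toDouble(); }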

aten/src/ATen/test/basic.cpp (+7, -7)

@@ -33,11 +33,11 @@ static void test(Type & type) {
     std::cout << "ones and dot:" << std::endl;
     Tensor b = type.ones({3, 4});
     std::cout << b << std::endl;
-    ASSERT(24 == (b+b).sum().toDouble());
+    ASSERT(24 == (b+b).sum().toCDouble());
     std::cout << b.numel() << std::endl;
     ASSERT(12 == b.numel());
     std::cout << b.dot(b) << std::endl;
-    ASSERT(b.dot(b).toDouble() == 12);
+    ASSERT(b.dot(b).toCDouble() == 12);
   }
 
   {
@@ -97,8 +97,8 @@ static void test(Type & type) {
     }
     auto end = std::chrono::high_resolution_clock::now();
     std::cout << std::dec << " " << std::chrono::duration_cast<std::chrono::milliseconds>(end-begin).count() << " ms" << std::endl;
-    ASSERT(norm(100000*d).toDouble() == norm(r).toDouble());
-    std::cout << " norm: " << norm(r).toDouble() << std::endl;
+    ASSERT(norm(100000*d).toCDouble() == norm(r).toCDouble());
+    std::cout << " norm: " << norm(r).toCDouble() << std::endl;
   }
 
   {
@@ -111,8 +111,8 @@ static void test(Type & type) {
     }
     auto end = std::chrono::high_resolution_clock::now();
     std::cout << std::dec << " " << std::chrono::duration_cast<std::chrono::milliseconds>(end-begin).count() << " ms" << std::endl;
-    ASSERT(norm(100000*d).toDouble() == norm(r).toDouble());
-    std::cout << " norm: " << norm(r).toDouble() << std::endl;
+    ASSERT(norm(100000*d).toCDouble() == norm(r).toCDouble());
+    std::cout << " norm: " << norm(r).toCDouble() << std::endl;
   }
 
   {
@@ -247,7 +247,7 @@ static void test(Type & type) {
     std::cout << c << std::endl;
 
     Tensor e = CPU(kFloat).rand({});
-    ASSERT(*e.data<float>()== e.sum().toFloat());
+    ASSERT(*e.data<float>()== e.sum().toCFloat());
   }
   {
     Tensor b = CPU(kFloat).ones({3,7})*.0000001f;

aten/src/ATen/test/scalar_tensor_test.cpp (+11, -0)

@@ -172,6 +172,17 @@ int main() {
       ASSERT(false);
     } catch (std::runtime_error &e) {}
   }
+
+  // fill_
+  if (t.dim() > 0 && t.numel() != 0) {
+    try {
+      // can only fill_ 0-dim tensors
+      t.fill_(t.sum(0));
+      assert(t.dim() == 1);
+    } catch (std::runtime_error &e) {
+      assert(t.dim() != 1);
+    }
+  }
 }
 
 for (auto lhs_it = sizes.begin(); lhs_it != sizes.end(); ++lhs_it) {

aten/src/ATen/test/scalar_test.cpp (+13, -1)

@@ -105,7 +105,7 @@ int main() {
   auto t = CPU(Float).ones({4,4});
 
   auto wha2 = CPU(Float).zeros({4,4}).add(t).sum();
-  cout << wha2.toDouble() << " <-ndim\n";
+  cout << wha2.toCDouble() << " <-ndim\n";
 
   cout << t.sizes() << " " << t.strides() << "\n";
 
@@ -143,6 +143,18 @@ int main() {
   ASSERT(Scalar(CPU(kFloat).ones({})).toTensor().type().scalarType() == kFloat);
 
   dispatch<Foo>(x.type(),x,prev_h);
+
+  // test direct C-scalar type conversions
+  try {
+    auto x = T.ones({1,2});
+    x.toCFloat();
+    ASSERT(false);
+  } catch (std::runtime_error &e) {}
+  auto float_one = T.ones({});
+  ASSERT(float_one.toCFloat() == 1);
+  ASSERT(float_one.toCInt() == 1);
+  ASSERT(float_one.toCHalf() == 1);
+
   return 0;
 
 }

tools/autograd/templates/Functions.cpp (+7, -7)

@@ -76,7 +76,7 @@ Tensor norm_backward(Tensor grad, const Tensor & self, const Scalar & p_, Tensor
 
 Tensor reduce_to(const Tensor & grad, IntList sizes) {
   if (sizes.size() == 0) {
-    return grad.sum().toTensor();
+    return grad.sum();
   }
   Tensor result = grad;
   while (result.dim() > (int64_t)sizes.size()) {
@@ -306,9 +306,9 @@ Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input
 Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, bool size_average, bool reduce) {
   auto result = kl_div_backward(grad, input, target, size_average, false);
   if (reduce && size_average) {
-    return result.mean().toTensor();
+    return result.mean();
   } else if (reduce) {
-    return result.sum().toTensor();
+    return result.sum();
   }
   return result;
 }
@@ -343,9 +343,9 @@ Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_outp
 Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, bool size_average, bool reduce) {
   auto output = l1_loss_backward(grad, input, target, size_average, false);
   if (reduce and size_average) {
-    return output.mean().toTensor();
+    return output.mean();
   } else if (reduce) {
-    return output.sum().toTensor();
+    return output.sum();
   }
   return output;
 }
@@ -364,7 +364,7 @@ Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Ten
     return smooth_l1_loss_backward(grad, input, target, size_average, reduce);
   }
   auto r = smooth_l1_loss_backward(ones_like(grad_output), input, target, size_average, true);
-  return (r * grad).sum().toTensor().view({1});
+  return (r * grad).sum().view({1});
 }
 
 Tensor max_pool2d_double_backward(const Tensor & grad, const Tensor & indices) {
@@ -389,7 +389,7 @@ Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor &
     return mse_loss_backward(grad, input, target, size_average, reduce);
   }
   auto r = mse_loss_backward(ones_like(grad_output), input, target, size_average, true);
-  return (r * grad).sum().toTensor().view({1});
+  return (r * grad).sum().view({1});
 }
 
 Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, bool size_average) {

torch/csrc/jit/test_jit.cpp (+4, -4)

@@ -98,7 +98,7 @@ static void fusionTests() {
     auto o = at::CUDA(at::kFloat).zeros({3,4});
     comp.debugLaunchGraph(graph, {a,b}, {o});
     auto o2 = a*b;
-    float max_diff = (o2 - o).abs().max().toDouble();
+    float max_diff = (o2 - o).abs().max().toCDouble();
     //std::cout << "max diff: " << max_diff << "\n";
     JIT_ASSERT(max_diff == 0);
   };
@@ -160,7 +160,7 @@ static void fusionTests() {
     //auto out0 = inputs[0]*inputs[1];
     comp.debugLaunchGraph(graph, inputs, outputs);
     JIT_ASSERT(out0.is_same_size(outputs.front()));
-    float max_diff = (outputs.front() - out0).abs().max().toDouble();
+    float max_diff = (outputs.front() - out0).abs().max().toCDouble();
     JIT_ASSERT(max_diff < 1e-6);
 
   };
@@ -191,9 +191,9 @@ static void fusionTests() {
     auto o2 = at::CUDA(at::kFloat).zeros(o2_r.sizes());
     comp.debugLaunchGraph(graph, {a,b}, {o, o2});
 
-    float max_diff = (o_r - o).abs().max().toDouble();
+    float max_diff = (o_r - o).abs().max().toCDouble();
     JIT_ASSERT(max_diff == 0);
-    float max_diff2 = (o2_r - o2).abs().max().toDouble();
+    float max_diff2 = (o2_r - o2).abs().max().toCDouble();
     JIT_ASSERT(max_diff2 == 0);
   };
   testConcat(0);
