Skip to content

Commit 003c900

Browse files
ezyangpytorchmergebot
authored andcommitted
Add _assert_scalar (pytorch#117378)
Peeled off from pytorch#114148, because that PR is going to take a while to actually land. Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#117378 Approved by: https://github.com/jansel
1 parent 1a85451 commit 003c900

File tree

5 files changed

+22
-0
lines changed

5 files changed

+22
-0
lines changed

aten/src/ATen/native/TensorCompare.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include <ATen/ops/_aminmax_native.h>
2323
#include <ATen/ops/_assert_async_native.h>
2424
#include <ATen/ops/_functional_assert_async_native.h>
25+
#include <ATen/ops/_assert_scalar_native.h>
26+
#include <ATen/ops/_functional_assert_scalar_native.h>
2527
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
2628
#include <ATen/ops/_unique.h>
2729
#include <ATen/ops/allclose_native.h>
@@ -421,6 +423,15 @@ void _assert_async_msg_cpu(const Tensor& self, c10::string_view assert_msg) {
421423
TORCH_CHECK(native::is_nonzero(self), assert_msg != "" ? assert_msg : "Assertion is failed");
422424
}
423425

426+
void _assert_scalar(const Scalar& scalar, c10::string_view assert_msg) {
427+
TORCH_SYM_CHECK(scalar.toSymBool(), assert_msg != "" ? assert_msg : "Assertion is failed");
428+
}
429+
430+
Tensor _functional_assert_scalar(const Scalar& scalar, c10::string_view assert_msg, const Tensor& dep_token) {
431+
_assert_scalar(scalar, assert_msg);
432+
return dep_token.clone();
433+
}
434+
424435
Tensor _functional_assert_async_msg_cpu(
425436
const Tensor& self,
426437
c10::string_view assert_msg,

aten/src/ATen/native/native_functions.yaml

+8
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,14 @@
175175
CPU: _assert_async_msg_cpu
176176
CUDA: _assert_async_msg_cuda
177177

178+
- func: _assert_scalar(Scalar self, str assert_msg) -> ()
179+
dispatch:
180+
CompositeExplicitAutograd: _assert_scalar
181+
182+
- func: _functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor
183+
dispatch:
184+
CompositeExplicitAutograd: _functional_assert_scalar
185+
178186
- func: _functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor
179187
dispatch:
180188
CPU: _functional_assert_async_msg_cpu

test/expect/HasDecompTest.test_has_decomposition.expect

+1
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ aten::_foreach_zero
334334
aten::_foreach_zero.out
335335
aten::_foreach_zero_
336336
aten::_functional_assert_async.msg
337+
aten::_functional_assert_scalar
337338
aten::_functional_sym_constrain_range
338339
aten::_functional_sym_constrain_range_for_size
339340
aten::_fused_adam

torch/fx/node.py

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
torch._assert,
4242
torch._assert_async,
4343
_ops.aten._assert_async.msg,
44+
_ops.aten._assert_scalar.default,
4445
_ops.aten.copy_.default,
4546
_ops.aten.sym_constrain_range.default,
4647
_ops.aten.sym_constrain_range_for_size.default,

torchgen/native_function_generation.py

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
"_assert_async", # no return
5252
"_assert_async.msg", # no return
5353
"_cslt_sparse_mm_search", # returns an int
54+
"_assert_scalar", # no return
5455
"_dimI", # returns an int
5556
"_dimV", # returns an int
5657
"_has_same_storage_numel", # returns a boolean

0 commit comments

Comments
 (0)