@@ -74,10 +74,15 @@
 # implementation of ${api_name} if we have overloaded a function with
 # the same name (but different signature) already
 ZERO_DIM_CHECK = CodeTemplate("""\
-if(${check_name}.dim() == 0) {
+if (${check_name}.dim() == 0) {
     return static_cast<const Type*>(this)->${method_prefix}${api_name}(${zero_dim_actuals});
 }""")
 
+ZERO_DIM_ONLY = CodeTemplate("""\
+runtime_error("${api_name} only supports a 0-dimensional ${check_name} tensor, but got tensor "
+    "with %" PRId64 " dimension(s)", ${check_name}.dim());
+""")
+
 SPARSE_CHECK = CodeTemplate("""\
 if(${check_name}.type().isSparse()) {
     return static_cast<const Type*>(this)->${method_prefix}${api_name}(${sparse_actuals});
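For context, ZERO_DIM_ONLY expands to an unconditional runtime_error in the
generated C++. A minimal sketch of the expansion, using Python's
string.Template as a stand-in for ATen's CodeTemplate (plain ${...}
substitution is all this example needs) and a hypothetical operator name:

    from string import Template

    # Same template text as ZERO_DIM_ONLY above; the operator and argument
    # names below are hypothetical and only illustrate the substitution.
    zero_dim_only = Template("""\
    runtime_error("${api_name} only supports a 0-dimensional ${check_name} tensor, but got tensor "
        "with %" PRId64 " dimension(s)", ${check_name}.dim());
    """)

    print(zero_dim_only.substitute(api_name='fill_', check_name='value'))
    # runtime_error("fill_ only supports a 0-dimensional value tensor, but got tensor "
    #     "with %" PRId64 " dimension(s)", value.dim());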
@@ -136,8 +141,8 @@ def __init__(self, reason):
     'THIndexTensor*': 'Tensor',
     'THBoolTensor*': 'Tensor',
     'THIntegerTensor*': 'Tensor',
-    'real': 'Scalar',
-    'accreal': 'Scalar',
+    'real': 'Tensor',
+    'accreal': 'Tensor',
     'long': 'int64_t',
 }
 
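The effect of this remapping is that TH functions returning real or accreal
are now declared to return Tensor rather than Scalar in the generated API.
An illustrative sketch of the lookup (the table and helper below are made up
to mirror the diff, not the actual codegen structures):

    # return-type translation table, mirroring the hunk above
    TYPE_RETURN = {
        'THTensor*': 'Tensor',
        'real': 'Tensor',     # previously 'Scalar'
        'accreal': 'Tensor',  # previously 'Scalar'
        'long': 'int64_t',
    }

    def to_api_return_type(th_type):
        # fall back to the TH type itself when no translation is registered
        return TYPE_RETURN.get(th_type, th_type)

    assert to_api_return_type('accreal') == 'Tensor'
    assert to_api_return_type('void') == 'void'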
@@ -710,14 +715,24 @@ def is_actual_return_long(ret):
             return backend_type_env['AccScalarName'] == 'Long'
         return False
 
+    def get_zero_dim_dispatch_when_scalar(option):
+        return option.get('zero_dim_dispatch_when_scalar', False)
+
     def handle_zero_dim(env, option):
-        if 'zero_dim_dispatch_when_scalar' not in option:
+        zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option)
+        if not zero_dim_dispatch:
             return []
-        check_name = option['zero_dim_dispatch_when_scalar']
         zero_dim_actuals = [arg['name']
-                            if arg['name'] != check_name else "Scalar({})".format(arg['name'])
+                            if arg['name'] != zero_dim_dispatch else "Scalar({})".format(arg['name'])
                             for arg in option['formals_list']]
-        return [ZERO_DIM_CHECK.substitute(env, check_name=check_name, zero_dim_actuals=zero_dim_actuals)]
+        return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]
+
+    def handle_only_zero_dim(env, option):
+        if option.get('zero_dim_tensor_only', False):
+            check_name = get_zero_dim_dispatch_when_scalar(option)
+            return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)]
+        else:
+            return None
 
     def handle_sparse(env, option):
         if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']:
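The refactored handle_zero_dim can be exercised standalone. A runnable sketch
with a stubbed option entry (field values are hypothetical; the final
CodeTemplate substitution is elided):

    option = {
        'zero_dim_dispatch_when_scalar': 'other',
        'formals_list': [{'name': 'self'}, {'name': 'other'}],
    }

    # option.get(..., False) doubles as the missing-key sentinel that makes
    # handle_zero_dim bail out early
    zero_dim_dispatch = option.get('zero_dim_dispatch_when_scalar', False)

    # wrap only the dispatched argument in Scalar(...); pass the rest through
    zero_dim_actuals = [arg['name']
                        if arg['name'] != zero_dim_dispatch else "Scalar({})".format(arg['name'])
                        for arg in option['formals_list']]
    print(zero_dim_actuals)  # ['self', 'Scalar(other)']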
@@ -781,6 +796,12 @@ def emit_body(env, option):
         body = []
         body += handle_sparse(env, option)
         body += handle_zero_dim(env, option)
+        only_zero_dim_check = handle_only_zero_dim(env, option)
+        if only_zero_dim_check is not None:
+            # code below only_zero_dim_check is unreachable so we do not need to generate the rest.
+            body += only_zero_dim_check
+            return body
+
         body += handle_buffers(env, option)
         # arguments are potentially duplicated because of one argument
         # referencing another
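Because ZERO_DIM_ONLY expands to an unconditional runtime_error, nothing
emitted after it could ever execute, so emit_body returns early. A condensed
sketch of the control flow with the handlers stubbed out (these stubs are not
the real codegen functions):

    def handle_only_zero_dim_stub(option):
        if option.get('zero_dim_tensor_only', False):
            return ['runtime_error(...);']  # template expansion elided
        return None

    def emit_body_sketch(option):
        body = []
        only_zero_dim_check = handle_only_zero_dim_stub(option)
        if only_zero_dim_check is not None:
            # the error fires unconditionally, so skip the rest of the body
            return body + only_zero_dim_check
        body.append('// ... normal generated body ...')
        return body

    print(emit_body_sketch({'zero_dim_tensor_only': True}))  # ['runtime_error(...);']
    print(emit_body_sketch({}))                              # ['// ... normal generated body ...']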
@@ -933,6 +954,10 @@ def emit_body(env, option):
             return_tensor = "return Tensor((new ${Tensor}(context,${arg_name}))${maybe_scalar},false);"
             body.append(CodeTemplate(return_tensor).substitute(
                 env, arg_name=call, maybe_scalar=maybe_scalar))
+        # return the same underlying Tensor type for both real and accreal; this ensures
+        # e.g. x.sum(0) and x.sum() return the same type.
+        elif ret['type'] == 'accreal' or ret['type'] == 'real':
+            body.append('return scalarTensor({});'.format(call))
         else:
             # we using int64_t for long in the API, so correct it here...
             if is_actual_return_long(ret):
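With real and accreal now mapped to Tensor (see the type-table hunk above),
this branch wraps the raw TH scalar result in a 0-dim tensor via scalarTensor,
so a full reduction like x.sum() returns the same type as the
dimension-reducing x.sum(0). A sketch of the emitted line, with a hypothetical
call string:

    ret = {'type': 'accreal'}
    call = 'THTensor_(sumall)(self_->tensor)'  # hypothetical TH call string
    if ret['type'] == 'accreal' or ret['type'] == 'real':
        emitted = 'return scalarTensor({});'.format(call)
    print(emitted)  # return scalarTensor(THTensor_(sumall)(self_->tensor));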