74 | 74 | # implementation of ${api_name} if we have overloaded a function with |
75 | 75 | # the same name (but different signature) already |
76 | 76 | ZERO_DIM_CHECK = CodeTemplate("""\ |
77 | | -if(${check_name}.dim() == 0) { |
| 77 | +if (${check_name}.dim() == 0) { |
78 | 78 |     return static_cast<const Type*>(this)->${method_prefix}${api_name}(${zero_dim_actuals}); |
79 | 79 | }""") |
80 | 80 |
|
| 81 | +ZERO_DIM_ONLY = CodeTemplate("""\ |
| 82 | +runtime_error("${api_name} only supports a 0-dimensional ${check_name} tensor, but got tensor " |
| 83 | +    "with %" PRId64 " dimension(s)", ${check_name}.dim()); |
| 84 | +""") |
| 85 | + |
81 | 86 | SPARSE_CHECK = CodeTemplate("""\ |
82 | 87 | if(${check_name}.type().isSparse()) { |
83 | 88 |     return static_cast<const Type*>(this)->${method_prefix}${api_name}(${sparse_actuals}); |
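
For reference, the sketch below shows roughly how one of these templates expands. `substitute` is a simplified stand-in for ATen's `CodeTemplate.substitute` (the real class also handles lists and indentation), and the environment values (`add`, `other`, and so on) are made up for illustration:

    import re

    def substitute(template, **env):
        # Replace each ${name} placeholder with its value from env.
        return re.sub(r'\$\{(\w+)\}', lambda m: str(env[m.group(1)]), template)

    expanded = substitute(
        'if (${check_name}.dim() == 0) {\n'
        '    return static_cast<const Type*>(this)->'
        '${method_prefix}${api_name}(${zero_dim_actuals});\n'
        '}',
        check_name='other',
        method_prefix='',
        api_name='add',
        zero_dim_actuals='self, Scalar(other)')
    print(expanded)
    # if (other.dim() == 0) {
    #     return static_cast<const Type*>(this)->add(self, Scalar(other));
    # }
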
@@ -136,8 +141,8 @@ def __init__(self, reason): |
136 | 141 |     'THIndexTensor*': 'Tensor', |
137 | 142 |     'THBoolTensor*': 'Tensor', |
138 | 143 |     'THIntegerTensor*': 'Tensor', |
139 | | -    'real': 'Scalar', |
140 | | -    'accreal': 'Scalar', |
| 144 | +    'real': 'Tensor', |
| 145 | +    'accreal': 'Tensor', |
141 | 146 |     'long': 'int64_t', |
142 | 147 | } |
143 | 148 |
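
With `real` and `accreal` now mapped to `Tensor`, TH kernels that return a raw C scalar are declared to return a (0-dimensional) `Tensor` in the public API. A hypothetical sketch of the lookup, with the table abbreviated to the entries shown above:

    # Abbreviated mirror of the return-type table in this hunk.
    TYPE_RETURN = {
        'THTensor*': 'Tensor',
        'real': 'Tensor',     # previously 'Scalar'
        'accreal': 'Tensor',  # previously 'Scalar'
        'long': 'int64_t',
    }

    def to_api_return_type(th_type):
        # Fall back to the TH name if there is no mapping.
        return TYPE_RETURN.get(th_type, th_type)

    # A reduction whose TH kernel returns accreal (e.g. a full sum) is now
    # declared to return Tensor, matching the dim-reduced overload.
    assert to_api_return_type('accreal') == 'Tensor'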
|
@@ -710,14 +715,24 @@ def is_actual_return_long(ret): |
710 | 715 |         return backend_type_env['AccScalarName'] == 'Long' |
711 | 716 |     return False |
712 | 717 |
|
| 718 | +    def get_zero_dim_dispatch_when_scalar(option): |
| 719 | +        return option.get('zero_dim_dispatch_when_scalar', False) |
| 720 | + |
713 | 721 |     def handle_zero_dim(env, option): |
714 | | -        if 'zero_dim_dispatch_when_scalar' not in option: |
| 722 | +        zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option) |
| 723 | +        if not zero_dim_dispatch: |
715 | 724 |             return [] |
716 | | -        check_name = option['zero_dim_dispatch_when_scalar'] |
717 | 725 |         zero_dim_actuals = [arg['name'] |
718 | | -                            if arg['name'] != check_name else "Scalar({})".format(arg['name']) |
| 726 | +                            if arg['name'] != zero_dim_dispatch else "Scalar({})".format(arg['name']) |
719 | 727 |                             for arg in option['formals_list']] |
720 | | -        return [ZERO_DIM_CHECK.substitute(env, check_name=check_name, zero_dim_actuals=zero_dim_actuals)] |
| 728 | +        return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)] |
| 729 | + |
| 730 | +    def handle_only_zero_dim(env, option): |
| 731 | +        if option.get('zero_dim_tensor_only', False): |
| 732 | +            check_name = get_zero_dim_dispatch_when_scalar(option) |
| 733 | +            return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)] |
| 734 | +        else: |
| 735 | +            return None |
721 | 736 |
|
722 | 737 |     def handle_sparse(env, option): |
723 | 738 |         if 'when_sparse_dispatch' not in option or 'Sparse' in backend_type_env['Backend']: |
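
A self-contained sketch of the decision `handle_zero_dim` encodes, using the same option keys as this file but with made-up values:

    # Hypothetical option entry: dispatch fill_ to its Scalar overload when
    # `value` is a 0-dim tensor.
    option = {
        'api_name': 'fill_',
        'zero_dim_dispatch_when_scalar': 'value',
        'formals_list': [{'name': 'self'}, {'name': 'value'}],
    }

    zero_dim_dispatch = option.get('zero_dim_dispatch_when_scalar', False)
    if zero_dim_dispatch:
        # The checked argument is forwarded wrapped in Scalar(...);
        # everything else passes through unchanged.
        zero_dim_actuals = [
            arg['name'] if arg['name'] != zero_dim_dispatch
            else 'Scalar({})'.format(arg['name'])
            for arg in option['formals_list']]
        print(zero_dim_actuals)  # ['self', 'Scalar(value)']
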
@@ -781,6 +796,12 @@ def emit_body(env, option): |
781 | 796 |         body = [] |
782 | 797 |         body += handle_sparse(env, option) |
783 | 798 |         body += handle_zero_dim(env, option) |
| 799 | +        only_zero_dim_check = handle_only_zero_dim(env, option) |
| 800 | +        if only_zero_dim_check is not None: |
| 801 | +            # code below only_zero_dim_check is unreachable, so we do not need to generate the rest. |
| 802 | +            body += only_zero_dim_check |
| 803 | +            return body |
| 804 | + |
784 | 805 |         body += handle_buffers(env, option) |
785 | 806 |         # arguments are potentially duplicated because of one argument |
786 | 807 |         # referencing another |
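
The early return matters because the generated `runtime_error(...)` never returns, so any code emitted after it would be dead. A runnable sketch of that control flow, with the helpers stubbed out and placeholder strings standing in for real generated C++:

    def handle_zero_dim(env, option):
        if not option.get('zero_dim_dispatch_when_scalar'):
            return []
        return ['/* redispatch to scalar overload */']

    def handle_only_zero_dim(env, option):
        if option.get('zero_dim_tensor_only', False):
            return ['/* runtime_error(...) */']
        return None

    def emit_body(env, option):
        body = []
        body += handle_zero_dim(env, option)
        only_zero_dim_check = handle_only_zero_dim(env, option)
        if only_zero_dim_check is not None:
            # runtime_error never returns, so stop generating here.
            body += only_zero_dim_check
            return body
        body += ['/* normal implementation */']
        return body

    print(emit_body({}, {'zero_dim_dispatch_when_scalar': 'value',
                         'zero_dim_tensor_only': True}))
    # ['/* redispatch to scalar overload */', '/* runtime_error(...) */']
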
@@ -933,6 +954,10 @@ def emit_body(env, option): |
933 | 954 |                 return_tensor = "return Tensor((new ${Tensor}(context,${arg_name}))${maybe_scalar},false);" |
934 | 955 |                 body.append(CodeTemplate(return_tensor).substitute( |
935 | 956 |                     env, arg_name=call, maybe_scalar=maybe_scalar)) |
| 957 | +            # return the same underlying Tensor type for both real and accreal; this ensures |
| 958 | +            # e.g. x.sum(0) and x.sum() return the same type. |
| 959 | +            elif ret['type'] == 'accreal' or ret['type'] == 'real': |
| 960 | +                body.append('return scalarTensor({});'.format(call)) |
936 | 961 |             else: |
937 | 962 |                 # we're using int64_t for long in the API, so correct it here... |
938 | 963 |                 if is_actual_return_long(ret): |
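
A sketch of the new `real`/`accreal` return path; `call` stands for the generated TH call expression, and the value shown here is hypothetical:

    ret = {'type': 'accreal'}
    call = 'THTensor_(sumall)(self_->tensor)'

    if ret['type'] in ('accreal', 'real'):
        # Wrap the raw C scalar in a 0-dim Tensor so that, e.g., x.sum()
        # and x.sum(0) both come back as Tensor.
        stmt = 'return scalarTensor({});'.format(call)
        print(stmt)  # return scalarTensor(THTensor_(sumall)(self_->tensor));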
|