Skip to content

Commit 5e550ca

Browse files
committed
add default values to arange
1 parent 9506c28 commit 5e550ca

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

Diff for: sharpy/__init__.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,21 @@ def _validate_device(device):
9696
raise ValueError(f"Invalid device string: {device}")
9797

9898

99+
def arange(start, /, end=None, step=1, dtype=int64, device="", team=1):
100+
if end is None:
101+
end = start
102+
start = 0
103+
if (end - start) / step < 0:
104+
# invalid range, return empty array
105+
start = end = 0
106+
step = 1
107+
return ndarray(
108+
_csp.Creator.arange(
109+
start, end, step, dtype, _validate_device(device), team
110+
)
111+
)
112+
113+
99114
for func in api.api_categories["Creator"]:
100115
FUNC = func.upper()
101116
if func == "full":
@@ -114,10 +129,6 @@ def _validate_device(device):
114129
exec(
115130
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 0, dtype, _validate_device(device), team))"
116131
)
117-
elif func == "arange":
118-
exec(
119-
f"{func} = lambda start, end, step, dtype=int64, device='', team=1: ndarray(_csp.Creator.arange(start, end, step, dtype, _validate_device(device), team))"
120-
)
121132
elif func == "linspace":
122133
exec(
123134
f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device(device), team))"

0 commit comments

Comments
 (0)