Skip to content

Commit 6b74a68

Browse files
authored
Add arange default values (#20)
1 parent 9506c28 commit 6b74a68

File tree

4 files changed

+65
-24
lines changed

4 files changed

+65
-24
lines changed

examples/shallow_water.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode):
6161
def transpose(a):
6262
return np.permute_dims(a, [1, 0])
6363

64-
all_axes = [0, 1]
6564
init(False)
6665

6766
elif backend == "numpy":
@@ -76,7 +75,6 @@ def transpose(a):
7675
transpose = np.transpose
7776

7877
fini = sync = lambda x=None: None
79-
all_axes = None
8078
else:
8179
raise ValueError(f'Unknown backend: "{backend}"')
8280

@@ -207,11 +205,11 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
207205
# set bathymetry
208206
h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly)
209207
# steady state potential energy
210-
pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
208+
pe_offset = 0.5 * g * float(np.sum(h**2.0)) / nx / ny
211209

212210
# compute time step
213211
alpha = 0.5
214-
h_max = float(np.max(h, all_axes))
212+
h_max = float(np.max(h))
215213
c = (g * h_max) ** 0.5
216214
dt = alpha * dx / c
217215
dt = t_export / int(math.ceil(t_export / dt))
@@ -344,22 +342,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
344342
t = i * dt
345343

346344
if t >= next_t_export - 1e-8:
347-
_elev_max = np.max(e, all_axes)
348-
_u_max = np.max(u, all_axes)
349-
_q_max = np.max(q, all_axes)
350-
_total_v = np.sum(e + h, all_axes)
345+
_elev_max = np.max(e)
346+
_u_max = np.max(u)
347+
_q_max = np.max(q)
348+
_total_v = np.sum(e + h)
351349

352350
# potential energy
353351
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
354-
_total_pe = np.sum(_pe, all_axes)
352+
_total_pe = np.sum(_pe)
355353

356354
# kinetic energy
357355
u2 = u * u
358356
v2 = v * v
359357
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
360358
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
361359
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
362-
_total_ke = np.sum(_ke, all_axes)
360+
_total_ke = np.sum(_ke)
363361

364362
total_pe = float(_total_pe) * dx * dy
365363
total_ke = float(_total_ke) * dx * dy
@@ -406,7 +404,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
406404
2
407405
]
408406
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
409-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
407+
err_L2 = math.sqrt(float(np.sum(err2)))
410408
info(f"L2 error: {err_L2:7.15e}")
411409

412410
if nx < 128 or ny < 128:

examples/wave_equation.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def run(n, backend, datatype, benchmark_mode):
6161
def transpose(a):
6262
return np.permute_dims(a, [1, 0])
6363

64-
all_axes = [0, 1]
6564
init(False)
6665

6766
elif backend == "numpy":
@@ -76,7 +75,6 @@ def transpose(a):
7675
transpose = np.transpose
7776

7877
fini = sync = lambda x=None: None
79-
all_axes = None
8078
else:
8179
raise ValueError(f'Unknown backend: "{backend}"')
8280

@@ -240,9 +238,9 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
240238
t = i * dt
241239

242240
if t >= next_t_export - 1e-8:
243-
_elev_max = np.max(e, all_axes)
244-
_u_max = np.max(u, all_axes)
245-
_total_v = np.sum(e + h, all_axes)
241+
_elev_max = np.max(e)
242+
_u_max = np.max(u)
243+
_total_v = np.sum(e + h)
246244

247245
elev_max = float(_elev_max)
248246
u_max = float(_u_max)
@@ -279,7 +277,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
279277

280278
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
281279
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
282-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
280+
err_L2 = math.sqrt(float(np.sum(err2)))
283281
info(f"L2 error: {err_L2:7.5e}")
284282

285283
if nx == 128 and ny == 128 and not benchmark_mode:

sharpy/__init__.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,22 @@ def _validate_device(device):
9696
raise ValueError(f"Invalid device string: {device}")
9797

9898

99+
def arange(start, /, end=None, step=1, dtype=int64, device="", team=1):
100+
if end is None:
101+
end = start
102+
start = 0
103+
assert step != 0, "step cannot be zero"
104+
if (end - start) * step < 0:
105+
# invalid range, return empty array
106+
start = end = 0
107+
step = 1
108+
return ndarray(
109+
_csp.Creator.arange(
110+
start, end, step, dtype, _validate_device(device), team
111+
)
112+
)
113+
114+
99115
for func in api.api_categories["Creator"]:
100116
FUNC = func.upper()
101117
if func == "full":
@@ -114,10 +130,6 @@ def _validate_device(device):
114130
exec(
115131
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 0, dtype, _validate_device(device), team))"
116132
)
117-
elif func == "arange":
118-
exec(
119-
f"{func} = lambda start, end, step, dtype=int64, device='', team=1: ndarray(_csp.Creator.arange(start, end, step, dtype, _validate_device(device), team))"
120-
)
121133
elif func == "linspace":
122134
exec(
123135
f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device(device), team))"

test/test_create.py

+36-3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,41 @@ def creator(request):
2626
return request.param[0], request.param[1]
2727

2828

29+
def test_arange():
30+
n = 10
31+
a = sp.arange(0, n, 1, dtype=sp.int32, device=device)
32+
assert tuple(a.shape) == (n,)
33+
assert numpy.allclose(sp.to_numpy(a), list(range(n)))
34+
35+
36+
def test_arange2():
37+
n = 10
38+
a = sp.arange(0, n, dtype=sp.int32, device=device)
39+
assert tuple(a.shape) == (n,)
40+
assert numpy.allclose(sp.to_numpy(a), list(range(n)))
41+
42+
43+
def test_arange3():
44+
n = 10
45+
a = sp.arange(n, device=device)
46+
assert tuple(a.shape) == (n,)
47+
assert numpy.allclose(sp.to_numpy(a), list(range(n)))
48+
49+
50+
def test_arange_empty():
51+
n = 10
52+
a = sp.arange(n, 0, device=device)
53+
assert tuple(a.shape) == (0,)
54+
assert numpy.allclose(sp.to_numpy(a), list())
55+
56+
57+
def test_arange_empty2():
58+
n = 10
59+
a = sp.arange(0, n, -1, device=device)
60+
assert tuple(a.shape) == (0,)
61+
assert numpy.allclose(sp.to_numpy(a), list())
62+
63+
2964
def test_create_datatypes(creator, datatype):
3065
shape = (6, 4)
3166
func, expected_value = creator
@@ -67,9 +102,7 @@ def test_full_invalid_shape():
67102
sp.full(shape, value, dtype=datatype, device=device)
68103

69104

70-
@pytest.mark.parametrize(
71-
"start,end,step", [(0, 10, -1), (0, -10, 1), (0, 99999999999999999999, 1)]
72-
)
105+
@pytest.mark.parametrize("start,end,step", [(0, 99999999999999999999, 1)])
73106
def tests_arange_invalid(start, end, step):
74107
with pytest.raises(TypeError):
75108
sp.arange(start, end, step, dtype=sp.int32, device=device)

0 commit comments

Comments
 (0)