Skip to content

Commit a748f0e

Browse files
committed
Improve, type, fix array_equal across all array contexts
1 parent 2120e4a commit a748f0e

File tree

4 files changed

+91
-43
lines changed

4 files changed

+91
-43
lines changed

arraycontext/impl/jax/fake_numpy.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,16 @@
2727

2828
import jax.numpy as jnp
2929

30-
from arraycontext.container import NotAnArrayContainerError, serialize_container
30+
from arraycontext.container import (
31+
NotAnArrayContainerError,
32+
serialize_container,
33+
)
3134
from arraycontext.container.traversal import (
3235
rec_map_array_container,
3336
rec_map_reduce_array_container,
3437
rec_multimap_array_container,
3538
)
39+
from arraycontext.context import Array, ArrayOrContainer
3640
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace
3741

3842

@@ -156,29 +160,35 @@ def any(self, a):
156160
return rec_map_reduce_array_container(
157161
partial(reduce, jnp.logical_or), jnp.any, a)
158162

159-
def array_equal(self, a, b):
163+
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
160164
actx = self._array_context
161165

162166
# NOTE: not all backends support `bool` properly, so use `int8` instead
163-
true = actx.from_numpy(np.int8(True))
164-
false = actx.from_numpy(np.int8(False))
167+
true_ary = actx.from_numpy(np.int8(True))
168+
false_ary = actx.from_numpy(np.int8(False))
165169

166170
def rec_equal(x, y):
167171
if type(x) is not type(y):
168-
return false
172+
return false_ary
169173

170174
try:
171-
iterable = zip(serialize_container(x), serialize_container(y))
175+
serialized_x = serialize_container(x)
176+
serialized_y = serialize_container(y)
172177
except NotAnArrayContainerError:
173178
if x.shape != y.shape:
174-
return false
179+
return false_ary
175180
else:
176181
return jnp.all(jnp.equal(x, y))
177182
else:
183+
if len(serialized_x) != len(serialized_y):
184+
return false_ary
178185
return reduce(
179186
jnp.logical_and,
180-
[rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable],
181-
true)
187+
[(true_ary if kx_i == ky_i else false_ary)
188+
and rec_equal(x_i, y_i)
189+
for (kx_i, x_i), (ky_i, y_i)
190+
in zip(serialized_x, serialized_y)],
191+
true_ary)
182192

183193
return rec_equal(a, b)
184194

arraycontext/impl/numpy/fake_numpy.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525

2626
import numpy as np
2727

28-
from arraycontext.container import is_array_container
28+
from arraycontext.container import NotAnArrayContainerError, serialize_container
2929
from arraycontext.container.traversal import (
30-
multimap_reduce_array_container,
3130
rec_map_array_container,
3231
rec_map_reduce_array_container,
3332
rec_multimap_array_container,
3433
rec_multimap_reduce_array_container,
3534
)
35+
from arraycontext.context import Array, ArrayOrContainer
3636
from arraycontext.fake_numpy import (
3737
BaseFakeNumpyLinalgNamespace,
3838
BaseFakeNumpyNamespace,
@@ -131,18 +131,34 @@ def all(self, a):
131131
return rec_map_reduce_array_container(partial(reduce, np.logical_and),
132132
lambda subary: np.all(subary), a)
133133

134-
def array_equal(self, a, b):
135-
if type(a) != type(b):
136-
return False
137-
elif not is_array_container(a):
138-
if a.shape != b.shape:
139-
return False
134+
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
135+
def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> np.ndarray:
136+
false_ary = np.array(False)
137+
true_ary = np.array(True)
138+
if type(x) is not type(y):
139+
return false_ary
140+
141+
try:
142+
serialized_x = serialize_container(x)
143+
serialized_y = serialize_container(y)
144+
except NotAnArrayContainerError:
145+
assert isinstance(x, np.ndarray)
146+
assert isinstance(y, np.ndarray)
147+
return np.array(np.array_equal(x, y))
140148
else:
141-
return np.all(np.equal(a, b))
142-
else:
143-
return multimap_reduce_array_container(partial(reduce,
144-
np.logical_and),
145-
self.array_equal, a, b)
149+
if len(serialized_x) != len(serialized_y):
150+
return false_ary
151+
return reduce(
152+
np.logical_and,
153+
[(true_ary if kx_i == ky_i else false_ary)
154+
and rec_equal(x_i, y_i)
155+
for (kx_i, x_i), (ky_i, y_i)
156+
in zip(serialized_x, serialized_y)],
157+
true_ary)
158+
159+
result = rec_equal(a, b)
160+
161+
return result
146162

147163
def arange(self, *args, **kwargs):
148164
return np.arange(*args, **kwargs)

arraycontext/impl/pyopencl/fake_numpy.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
rec_multimap_array_container,
3939
rec_multimap_reduce_array_container,
4040
)
41+
from arraycontext.context import Array, ArrayOrContainer
4142
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
4243
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
4344
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
@@ -215,30 +216,40 @@ def _any(ary):
215216
result = result.get()[()]
216217
return result
217218

218-
def array_equal(self, a, b):
219+
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
219220
actx = self._array_context
220221
queue = actx.queue
221222

222223
# NOTE: pyopencl doesn't like `bool` much, so use `int8` instead
223-
true = actx.from_numpy(np.int8(True))
224-
false = actx.from_numpy(np.int8(False))
224+
true_ary = actx.from_numpy(np.int8(True))
225+
false_ary = actx.from_numpy(np.int8(False))
225226

226-
def rec_equal(x, y):
227+
def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array:
227228
if type(x) is not type(y):
228-
return false
229+
return false_ary
229230

230231
try:
231-
iterable = zip(serialize_container(x), serialize_container(y))
232+
serialized_x = serialize_container(x)
233+
serialized_y = serialize_container(y)
232234
except NotAnArrayContainerError:
235+
assert isinstance(x, cl_array.Array)
236+
assert isinstance(y, cl_array.Array)
237+
233238
if x.shape != y.shape:
234-
return false
239+
return false_ary
235240
else:
236241
return (x == y).all()
237242
else:
243+
if len(serialized_x) != len(serialized_y):
244+
return false_ary
245+
238246
return reduce(
239247
partial(cl_array.minimum, queue=queue),
240-
[rec_equal(x_i, y_i)for (_, x_i), (_, y_i) in iterable],
241-
true)
248+
[(true_ary if kx_i == ky_i else false_ary)
249+
and rec_equal(x_i, y_i)
250+
for (kx_i, x_i), (ky_i, y_i)
251+
in zip(serialized_x, serialized_y)],
252+
true_ary)
242253

243254
result = rec_equal(a, b)
244255
if not self._array_context._force_device_scalars:

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
THE SOFTWARE.
2323
"""
2424
from functools import partial, reduce
25-
from typing import Any
25+
from typing import Any, cast
2626

2727
import numpy as np
2828

@@ -34,6 +34,7 @@
3434
rec_map_reduce_array_container,
3535
rec_multimap_array_container,
3636
)
37+
from arraycontext.context import Array, ArrayOrContainer
3738
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
3839
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace
3940

@@ -171,31 +172,41 @@ def any(self, a):
171172
partial(reduce, pt.logical_or),
172173
lambda subary: pt.any(subary), a)
173174

174-
def array_equal(self, a, b):
175+
def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
175176
actx = self._array_context
176177

177178
# NOTE: not all backends support `bool` properly, so use `int8` instead
178-
true = actx.from_numpy(np.int8(True))
179-
false = actx.from_numpy(np.int8(False))
179+
true_ary = actx.from_numpy(np.int8(True))
180+
false_ary = actx.from_numpy(np.int8(False))
180181

181-
def rec_equal(x, y):
182+
def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> pt.Array:
182183
if type(x) is not type(y):
183-
return false
184+
return false_ary
184185

185186
try:
186-
iterable = zip(serialize_container(x), serialize_container(y))
187+
serialized_x = serialize_container(x)
188+
serialized_y = serialize_container(y)
187189
except NotAnArrayContainerError:
190+
assert isinstance(x, pt.Array)
191+
assert isinstance(y, pt.Array)
192+
188193
if x.shape != y.shape:
189-
return false
194+
return false_ary
190195
else:
191-
return pt.all(pt.equal(x, y))
196+
return pt.all(cast(pt.Array, pt.equal(x, y)))
192197
else:
198+
if len(serialized_x) != len(serialized_y):
199+
return false_ary
200+
193201
return reduce(
194202
pt.logical_and,
195-
[rec_equal(x_i, y_i) for (_, x_i), (_, y_i) in iterable],
196-
true)
203+
[(true_ary if kx_i == ky_i else false_ary)
204+
and rec_equal(x_i, y_i)
205+
for (kx_i, x_i), (ky_i, y_i)
206+
in zip(serialized_x, serialized_y)],
207+
true_ary)
197208

198-
return rec_equal(a, b)
209+
return cast(Array, rec_equal(a, b))
199210

200211
# }}}
201212

0 commit comments

Comments
 (0)