Skip to content

Commit c327a02

Browse files
committed
Remove take_along_axis in favor of Aesara's implementation
1 parent 8520a2c commit c327a02

File tree

3 files changed

+3
-256
lines changed

3 files changed

+3
-256
lines changed

pymc/aesaraf.py

+1 -57
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@
5757
from aesara.tensor.var import TensorConstant, TensorVariable
5858

5959
from pymc.exceptions import ShapeError
60-
from pymc.vartypes import continuous_types, int_types, isgenerator, typefilter
60+
from pymc.vartypes import continuous_types, isgenerator, typefilter
6161

6262
PotentialShapeType = Union[
6363
int, np.ndarray, Tuple[Union[int, Variable], ...], List[Union[int, Variable]], Variable
@@ -80,7 +80,6 @@
8080
"generator",
8181
"set_at_rng",
8282
"at_rng",
83-
"take_along_axis",
8483
"convert_observed_data",
8584
]
8685

@@ -854,61 +853,6 @@ def largest_common_dtype(tensors):
854853
return np.stack([np.ones((), dtype=dtype) for dtype in dtypes]).dtype
855854

856855

857-
def _make_along_axis_idx(arr_shape, indices, axis):
858-
# compute dimensions to iterate over
859-
if str(indices.dtype) not in int_types:
860-
raise IndexError("`indices` must be an integer array")
861-
shape_ones = (1,) * indices.ndim
862-
dest_dims = list(range(axis)) + [None] + list(range(axis + 1, indices.ndim))
863-
864-
# build a fancy index, consisting of orthogonal aranges, with the
865-
# requested index inserted at the right location
866-
fancy_index = []
867-
for dim, n in zip(dest_dims, arr_shape):
868-
if dim is None:
869-
fancy_index.append(indices)
870-
else:
871-
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
872-
fancy_index.append(at.arange(n).reshape(ind_shape))
873-
874-
return tuple(fancy_index)
875-
876-
877-
def take_along_axis(arr, indices, axis=0):
878-
"""Take values from the input array by matching 1d index and data slices.
879-
880-
This iterates over matching 1d slices oriented along the specified axis in
881-
the index and data arrays, and uses the former to look up values in the
882-
latter. These slices can be different lengths.
883-
884-
Functions returning an index along an axis, like argsort and argpartition,
885-
produce suitable indices for this function.
886-
"""
887-
arr = at.as_tensor_variable(arr)
888-
indices = at.as_tensor_variable(indices)
889-
# normalize inputs
890-
if axis is None:
891-
arr = arr.flatten()
892-
arr_shape = (len(arr),) # flatiter has no .shape
893-
_axis = 0
894-
else:
895-
if axis < 0:
896-
_axis = arr.ndim + axis
897-
else:
898-
_axis = axis
899-
if _axis < 0 or _axis >= arr.ndim:
900-
raise ValueError(
901-
"Supplied `axis` value {} is out of bounds of an array with "
902-
"ndim = {}".format(axis, arr.ndim)
903-
)
904-
arr_shape = arr.shape
905-
if arr.ndim != indices.ndim:
906-
raise ValueError("`indices` and `arr` must have the same number of dimensions")
907-
908-
# use the fancy index
909-
return arr[_make_along_axis_idx(arr_shape, indices, _axis)]
910-
911-
912856
@local_optimizer(tracks=[CheckParameterValue])
913857
def local_remove_check_parameter(fgraph, node):
914858
"""Rewrite that removes Aeppl's CheckParameterValue

pymc/distributions/discrete.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131

3232
import pymc as pm
3333

34-
from pymc.aesaraf import floatX, intX, take_along_axis
34+
from pymc.aesaraf import floatX, intX
3535
from pymc.distributions.dist_math import (
3636
betaln,
3737
binomln,
@@ -1318,7 +1318,7 @@ def logp(value, p):
13181318
p = at.shape_padleft(p, value_clip.ndim - p_.ndim)
13191319
pattern = (p.ndim - 1,) + tuple(range(p.ndim - 1))
13201320
a = at.log(
1321-
take_along_axis(
1321+
at.take_along_axis(
13221322
p.dimshuffle(pattern),
13231323
value_clip,
13241324
)

pymc/tests/test_aesaraf.py

-197
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from itertools import product
16-
1715
import aesara
1816
import aesara.tensor as at
1917
import numpy as np
@@ -29,19 +27,16 @@
2927
from aesara.tensor.random.basic import normal, uniform
3028
from aesara.tensor.random.op import RandomVariable
3129
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
32-
from aesara.tensor.type import TensorType
3330
from aesara.tensor.var import TensorVariable
3431

3532
import pymc as pm
3633

3734
from pymc.aesaraf import (
38-
_conversion_map,
3935
change_rv_size,
4036
compile_pymc,
4137
convert_observed_data,
4238
extract_obs_data,
4339
rvs_to_value_vars,
44-
take_along_axis,
4540
walk_model,
4641
)
4742
from pymc.distributions.dist_math import check_parameters
@@ -166,198 +161,6 @@ def _make_along_axis_idx(arr_shape, indices, axis):
166161
return tuple(fancy_index)
167162

168163

169-
if hasattr(np, "take_along_axis"):
170-
np_take_along_axis = np.take_along_axis
171-
else:
172-
173-
def np_take_along_axis(arr, indices, axis):
174-
if arr.shape[axis] <= 32:
175-
# We can safely test with numpy's choose
176-
arr = np.moveaxis(arr, axis, 0)
177-
indices = np.moveaxis(indices, axis, 0)
178-
out = np.choose(indices, arr)
179-
return np.moveaxis(out, 0, axis)
180-
else:
181-
# numpy's choose cannot handle such a large axis so we
182-
# just use the implementation of take_along_axis. This is kind of
183-
# cheating because our implementation is the same as the one below
184-
if axis < 0:
185-
_axis = arr.ndim + axis
186-
else:
187-
_axis = axis
188-
if _axis < 0 or _axis >= arr.ndim:
189-
raise ValueError(f"Supplied axis {axis} is out of bounds")
190-
return arr[_make_along_axis_idx(arr.shape, indices, _axis)]
191-
192-
193-
class TestTakeAlongAxis:
194-
def setup_class(self):
195-
self.inputs_buffer = dict()
196-
self.output_buffer = dict()
197-
self.func_buffer = dict()
198-
199-
def _input_tensors(self, shape, floatX):
200-
intX = str(_conversion_map[floatX])
201-
ndim = len(shape)
202-
arr = TensorType(floatX, [False] * ndim)("arr")
203-
indices = TensorType(intX, [False] * ndim)("indices")
204-
arr.tag.test_value = np.zeros(shape, dtype=floatX)
205-
indices.tag.test_value = np.zeros(shape, dtype=intX)
206-
return arr, indices
207-
208-
def get_input_tensors(self, shape, floatX):
209-
ndim = len(shape)
210-
try:
211-
return self.inputs_buffer[(ndim, floatX)]
212-
except KeyError:
213-
arr, indices = self._input_tensors(shape, floatX)
214-
self.inputs_buffer[(ndim, floatX)] = arr, indices
215-
return arr, indices
216-
217-
def _output_tensor(self, arr, indices, axis):
218-
return take_along_axis(arr, indices, axis)
219-
220-
def get_output_tensors(self, shape, axis, floatX):
221-
ndim = len(shape)
222-
try:
223-
return self.output_buffer[(ndim, axis, floatX)]
224-
except KeyError:
225-
arr, indices = self.get_input_tensors(shape, floatX)
226-
out = self._output_tensor(arr, indices, axis)
227-
self.output_buffer[(ndim, axis, floatX)] = out
228-
return out
229-
230-
def _function(self, arr, indices, out):
231-
return aesara.function([arr, indices], [out])
232-
233-
def get_function(self, shape, axis, floatX):
234-
ndim = len(shape)
235-
try:
236-
return self.func_buffer[(ndim, axis, floatX)]
237-
except KeyError:
238-
arr, indices = self.get_input_tensors(shape, floatX)
239-
out = self.get_output_tensors(shape, axis, floatX)
240-
func = self._function(arr, indices, out)
241-
self.func_buffer[(ndim, axis, floatX)] = func
242-
return func
243-
244-
@staticmethod
245-
def get_input_values(shape, axis, samples, floatX):
246-
intX = str(_conversion_map[floatX])
247-
arr = np.random.randn(*shape).astype(floatX)
248-
size = list(shape)
249-
size[axis] = samples
250-
size = tuple(size)
251-
indices = np.random.randint(low=0, high=shape[axis], size=size, dtype=intX)
252-
return arr, indices
253-
254-
@pytest.mark.parametrize(
255-
["shape", "axis", "samples"],
256-
product(
257-
[
258-
(1,),
259-
(3,),
260-
(3, 1),
261-
(3, 2),
262-
(1, 1),
263-
(1, 2),
264-
(40, 40), # choose fails here
265-
(5, 1, 1),
266-
(5, 1, 2),
267-
(5, 3, 1),
268-
(5, 3, 2),
269-
],
270-
[0, -1],
271-
[1, 10],
272-
),
273-
ids=str,
274-
)
275-
@pytest.mark.parametrize("floatX", ["float32", "float64"])
276-
def test_take_along_axis(self, shape, axis, samples, floatX):
277-
with aesara.config.change_flags(floatX=floatX):
278-
arr, indices = self.get_input_values(shape, axis, samples, floatX)
279-
func = self.get_function(shape, axis, floatX)
280-
assert np.allclose(np_take_along_axis(arr, indices, axis=axis), func(arr, indices)[0])
281-
282-
@pytest.mark.parametrize(
283-
["shape", "axis", "samples"],
284-
product(
285-
[
286-
(1,),
287-
(3,),
288-
(3, 1),
289-
(3, 2),
290-
(1, 1),
291-
(1, 2),
292-
(40, 40), # choose fails here
293-
(5, 1, 1),
294-
(5, 1, 2),
295-
(5, 3, 1),
296-
(5, 3, 2),
297-
],
298-
[0, -1],
299-
[1, 10],
300-
),
301-
ids=str,
302-
)
303-
@pytest.mark.parametrize("floatX", ["float32", "float64"])
304-
def test_take_along_axis_grad(self, shape, axis, samples, floatX):
305-
with aesara.config.change_flags(floatX=floatX):
306-
if axis < 0:
307-
_axis = len(shape) + axis
308-
else:
309-
_axis = axis
310-
# Setup the aesara function
311-
t_arr, t_indices = self.get_input_tensors(shape, floatX)
312-
t_out2 = aesara.grad(
313-
at.sum(self._output_tensor(t_arr**2, t_indices, axis)),
314-
t_arr,
315-
)
316-
func = aesara.function([t_arr, t_indices], [t_out2])
317-
318-
# Test that the gradient gives the same output as what is expected
319-
arr, indices = self.get_input_values(shape, axis, samples, floatX)
320-
expected_grad = np.zeros_like(arr)
321-
slicer = [slice(None)] * len(shape)
322-
for i in range(indices.shape[axis]):
323-
slicer[axis] = i
324-
inds = indices[tuple(slicer)].reshape(shape[:_axis] + (1,) + shape[_axis + 1 :])
325-
inds = _make_along_axis_idx(shape, inds, _axis)
326-
expected_grad[inds] += 1
327-
expected_grad *= 2 * arr
328-
out = func(arr, indices)[0]
329-
assert np.allclose(out, expected_grad)
330-
331-
@pytest.mark.parametrize("axis", [-4, 4], ids=str)
332-
@pytest.mark.parametrize("floatX", ["float32", "float64"])
333-
def test_axis_failure(self, axis, floatX):
334-
with aesara.config.change_flags(floatX=floatX):
335-
arr, indices = self.get_input_tensors((3, 1), floatX)
336-
with pytest.raises(ValueError):
337-
take_along_axis(arr, indices, axis=axis)
338-
339-
@pytest.mark.parametrize("floatX", ["float32", "float64"])
340-
def test_ndim_failure(self, floatX):
341-
with aesara.config.change_flags(floatX=floatX):
342-
intX = str(_conversion_map[floatX])
343-
arr = TensorType(floatX, [False] * 3)("arr")
344-
indices = TensorType(intX, [False] * 2)("indices")
345-
arr.tag.test_value = np.zeros((1,) * arr.ndim, dtype=floatX)
346-
indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=intX)
347-
with pytest.raises(ValueError):
348-
take_along_axis(arr, indices)
349-
350-
@pytest.mark.parametrize("floatX", ["float32", "float64"])
351-
def test_dtype_failure(self, floatX):
352-
with aesara.config.change_flags(floatX=floatX):
353-
arr = TensorType(floatX, [False] * 3)("arr")
354-
indices = TensorType(floatX, [False] * 3)("indices")
355-
arr.tag.test_value = np.zeros((1,) * arr.ndim, dtype=floatX)
356-
indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=floatX)
357-
with pytest.raises(IndexError):
358-
take_along_axis(arr, indices)
359-
360-
361164
def test_extract_obs_data():
362165

363166
with pytest.raises(TypeError):

0 commit comments

Comments
 (0)