|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -from itertools import product |
16 |
| - |
17 | 15 | import aesara
|
18 | 16 | import aesara.tensor as at
|
19 | 17 | import numpy as np
|
|
29 | 27 | from aesara.tensor.random.basic import normal, uniform
|
30 | 28 | from aesara.tensor.random.op import RandomVariable
|
31 | 29 | from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
|
32 |
| -from aesara.tensor.type import TensorType |
33 | 30 | from aesara.tensor.var import TensorVariable
|
34 | 31 |
|
35 | 32 | import pymc as pm
|
36 | 33 |
|
37 | 34 | from pymc.aesaraf import (
|
38 |
| - _conversion_map, |
39 | 35 | change_rv_size,
|
40 | 36 | compile_pymc,
|
41 | 37 | convert_observed_data,
|
42 | 38 | extract_obs_data,
|
43 | 39 | rvs_to_value_vars,
|
44 |
| - take_along_axis, |
45 | 40 | walk_model,
|
46 | 41 | )
|
47 | 42 | from pymc.distributions.dist_math import check_parameters
|
@@ -166,198 +161,6 @@ def _make_along_axis_idx(arr_shape, indices, axis):
|
166 | 161 | return tuple(fancy_index)
|
167 | 162 |
|
168 | 163 |
|
169 |
| -if hasattr(np, "take_along_axis"): |
170 |
| - np_take_along_axis = np.take_along_axis |
171 |
| -else: |
172 |
| - |
173 |
| - def np_take_along_axis(arr, indices, axis): |
174 |
| - if arr.shape[axis] <= 32: |
175 |
| - # We can safely test with numpy's choose |
176 |
| - arr = np.moveaxis(arr, axis, 0) |
177 |
| - indices = np.moveaxis(indices, axis, 0) |
178 |
| - out = np.choose(indices, arr) |
179 |
| - return np.moveaxis(out, 0, axis) |
180 |
| - else: |
181 |
| - # numpy's choose cannot handle such a large axis so we |
182 |
| - # just use the implementation of take_along_axis. This is kind of |
183 |
| - # cheating because our implementation is the same as the one below |
184 |
| - if axis < 0: |
185 |
| - _axis = arr.ndim + axis |
186 |
| - else: |
187 |
| - _axis = axis |
188 |
| - if _axis < 0 or _axis >= arr.ndim: |
189 |
| - raise ValueError(f"Supplied axis {axis} is out of bounds") |
190 |
| - return arr[_make_along_axis_idx(arr.shape, indices, _axis)] |
191 |
| - |
192 |
| - |
193 |
| -class TestTakeAlongAxis: |
194 |
| - def setup_class(self): |
195 |
| - self.inputs_buffer = dict() |
196 |
| - self.output_buffer = dict() |
197 |
| - self.func_buffer = dict() |
198 |
| - |
199 |
| - def _input_tensors(self, shape, floatX): |
200 |
| - intX = str(_conversion_map[floatX]) |
201 |
| - ndim = len(shape) |
202 |
| - arr = TensorType(floatX, [False] * ndim)("arr") |
203 |
| - indices = TensorType(intX, [False] * ndim)("indices") |
204 |
| - arr.tag.test_value = np.zeros(shape, dtype=floatX) |
205 |
| - indices.tag.test_value = np.zeros(shape, dtype=intX) |
206 |
| - return arr, indices |
207 |
| - |
208 |
| - def get_input_tensors(self, shape, floatX): |
209 |
| - ndim = len(shape) |
210 |
| - try: |
211 |
| - return self.inputs_buffer[(ndim, floatX)] |
212 |
| - except KeyError: |
213 |
| - arr, indices = self._input_tensors(shape, floatX) |
214 |
| - self.inputs_buffer[(ndim, floatX)] = arr, indices |
215 |
| - return arr, indices |
216 |
| - |
217 |
| - def _output_tensor(self, arr, indices, axis): |
218 |
| - return take_along_axis(arr, indices, axis) |
219 |
| - |
220 |
| - def get_output_tensors(self, shape, axis, floatX): |
221 |
| - ndim = len(shape) |
222 |
| - try: |
223 |
| - return self.output_buffer[(ndim, axis, floatX)] |
224 |
| - except KeyError: |
225 |
| - arr, indices = self.get_input_tensors(shape, floatX) |
226 |
| - out = self._output_tensor(arr, indices, axis) |
227 |
| - self.output_buffer[(ndim, axis, floatX)] = out |
228 |
| - return out |
229 |
| - |
230 |
| - def _function(self, arr, indices, out): |
231 |
| - return aesara.function([arr, indices], [out]) |
232 |
| - |
233 |
| - def get_function(self, shape, axis, floatX): |
234 |
| - ndim = len(shape) |
235 |
| - try: |
236 |
| - return self.func_buffer[(ndim, axis, floatX)] |
237 |
| - except KeyError: |
238 |
| - arr, indices = self.get_input_tensors(shape, floatX) |
239 |
| - out = self.get_output_tensors(shape, axis, floatX) |
240 |
| - func = self._function(arr, indices, out) |
241 |
| - self.func_buffer[(ndim, axis, floatX)] = func |
242 |
| - return func |
243 |
| - |
244 |
| - @staticmethod |
245 |
| - def get_input_values(shape, axis, samples, floatX): |
246 |
| - intX = str(_conversion_map[floatX]) |
247 |
| - arr = np.random.randn(*shape).astype(floatX) |
248 |
| - size = list(shape) |
249 |
| - size[axis] = samples |
250 |
| - size = tuple(size) |
251 |
| - indices = np.random.randint(low=0, high=shape[axis], size=size, dtype=intX) |
252 |
| - return arr, indices |
253 |
| - |
254 |
| - @pytest.mark.parametrize( |
255 |
| - ["shape", "axis", "samples"], |
256 |
| - product( |
257 |
| - [ |
258 |
| - (1,), |
259 |
| - (3,), |
260 |
| - (3, 1), |
261 |
| - (3, 2), |
262 |
| - (1, 1), |
263 |
| - (1, 2), |
264 |
| - (40, 40), # choose fails here |
265 |
| - (5, 1, 1), |
266 |
| - (5, 1, 2), |
267 |
| - (5, 3, 1), |
268 |
| - (5, 3, 2), |
269 |
| - ], |
270 |
| - [0, -1], |
271 |
| - [1, 10], |
272 |
| - ), |
273 |
| - ids=str, |
274 |
| - ) |
275 |
| - @pytest.mark.parametrize("floatX", ["float32", "float64"]) |
276 |
| - def test_take_along_axis(self, shape, axis, samples, floatX): |
277 |
| - with aesara.config.change_flags(floatX=floatX): |
278 |
| - arr, indices = self.get_input_values(shape, axis, samples, floatX) |
279 |
| - func = self.get_function(shape, axis, floatX) |
280 |
| - assert np.allclose(np_take_along_axis(arr, indices, axis=axis), func(arr, indices)[0]) |
281 |
| - |
282 |
| - @pytest.mark.parametrize( |
283 |
| - ["shape", "axis", "samples"], |
284 |
| - product( |
285 |
| - [ |
286 |
| - (1,), |
287 |
| - (3,), |
288 |
| - (3, 1), |
289 |
| - (3, 2), |
290 |
| - (1, 1), |
291 |
| - (1, 2), |
292 |
| - (40, 40), # choose fails here |
293 |
| - (5, 1, 1), |
294 |
| - (5, 1, 2), |
295 |
| - (5, 3, 1), |
296 |
| - (5, 3, 2), |
297 |
| - ], |
298 |
| - [0, -1], |
299 |
| - [1, 10], |
300 |
| - ), |
301 |
| - ids=str, |
302 |
| - ) |
303 |
| - @pytest.mark.parametrize("floatX", ["float32", "float64"]) |
304 |
| - def test_take_along_axis_grad(self, shape, axis, samples, floatX): |
305 |
| - with aesara.config.change_flags(floatX=floatX): |
306 |
| - if axis < 0: |
307 |
| - _axis = len(shape) + axis |
308 |
| - else: |
309 |
| - _axis = axis |
310 |
| - # Setup the aesara function |
311 |
| - t_arr, t_indices = self.get_input_tensors(shape, floatX) |
312 |
| - t_out2 = aesara.grad( |
313 |
| - at.sum(self._output_tensor(t_arr**2, t_indices, axis)), |
314 |
| - t_arr, |
315 |
| - ) |
316 |
| - func = aesara.function([t_arr, t_indices], [t_out2]) |
317 |
| - |
318 |
| - # Test that the gradient gives the same output as what is expected |
319 |
| - arr, indices = self.get_input_values(shape, axis, samples, floatX) |
320 |
| - expected_grad = np.zeros_like(arr) |
321 |
| - slicer = [slice(None)] * len(shape) |
322 |
| - for i in range(indices.shape[axis]): |
323 |
| - slicer[axis] = i |
324 |
| - inds = indices[tuple(slicer)].reshape(shape[:_axis] + (1,) + shape[_axis + 1 :]) |
325 |
| - inds = _make_along_axis_idx(shape, inds, _axis) |
326 |
| - expected_grad[inds] += 1 |
327 |
| - expected_grad *= 2 * arr |
328 |
| - out = func(arr, indices)[0] |
329 |
| - assert np.allclose(out, expected_grad) |
330 |
| - |
331 |
| - @pytest.mark.parametrize("axis", [-4, 4], ids=str) |
332 |
| - @pytest.mark.parametrize("floatX", ["float32", "float64"]) |
333 |
| - def test_axis_failure(self, axis, floatX): |
334 |
| - with aesara.config.change_flags(floatX=floatX): |
335 |
| - arr, indices = self.get_input_tensors((3, 1), floatX) |
336 |
| - with pytest.raises(ValueError): |
337 |
| - take_along_axis(arr, indices, axis=axis) |
338 |
| - |
339 |
| - @pytest.mark.parametrize("floatX", ["float32", "float64"]) |
340 |
| - def test_ndim_failure(self, floatX): |
341 |
| - with aesara.config.change_flags(floatX=floatX): |
342 |
| - intX = str(_conversion_map[floatX]) |
343 |
| - arr = TensorType(floatX, [False] * 3)("arr") |
344 |
| - indices = TensorType(intX, [False] * 2)("indices") |
345 |
| - arr.tag.test_value = np.zeros((1,) * arr.ndim, dtype=floatX) |
346 |
| - indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=intX) |
347 |
| - with pytest.raises(ValueError): |
348 |
| - take_along_axis(arr, indices) |
349 |
| - |
350 |
| - @pytest.mark.parametrize("floatX", ["float32", "float64"]) |
351 |
| - def test_dtype_failure(self, floatX): |
352 |
| - with aesara.config.change_flags(floatX=floatX): |
353 |
| - arr = TensorType(floatX, [False] * 3)("arr") |
354 |
| - indices = TensorType(floatX, [False] * 3)("indices") |
355 |
| - arr.tag.test_value = np.zeros((1,) * arr.ndim, dtype=floatX) |
356 |
| - indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=floatX) |
357 |
| - with pytest.raises(IndexError): |
358 |
| - take_along_axis(arr, indices) |
359 |
| - |
360 |
| - |
361 | 164 | def test_extract_obs_data():
|
362 | 165 |
|
363 | 166 | with pytest.raises(TypeError):
|
|
0 commit comments