diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 8619b124be..79b4eb2de8 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -26,6 +26,7 @@ concatenate, constant, expand_dims, + extract_constant, stack, switch, ) @@ -1569,6 +1570,46 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False): return ret +def quantile(input, quant, axis=None): + """ + Computes the median along the given axis(es) of a tensor `input`. + Parameters + ---------- + input: TensorVariable + The input tensor. + quant: float + Probability for the quantiles to compute. + Values must be between 0 and 1 inclusive. + axis: None or int or (list of int) (see `Sum`) + Compute the quantile along this axis of the tensor. + None means computing along the flattened tensor. + """ + input = as_tensor_variable(input) + input_ndim = input.type.ndim + if axis is None: + axis = list(range(input_ndim)) + elif isinstance(axis, int | np.integer): + axis = [axis] + elif isinstance(axis, np.ndarray) and axis.ndim == 0: + axis = [int(axis)] + else: + axis = [int(a) for a in axis] + + new_axes_order = [i for i in range(input.ndim) if i not in axis] + axis + input = input.dimshuffle(new_axes_order) + + remaining_axis_size = shape(input)[: input.ndim - len(axis)] + flattened_axis_size = prod(shape(input)[input.ndim - len(axis) :]) + + input = input.reshape(concatenate([remaining_axis_size, [flattened_axis_size]])) + axis = -1 + + sorted_input = input.sort(axis=axis) + input_shape = input.shape[axis] + k = extract_constant(input_shape) * quant + return sorted_input[k] + + @scalar_elemwise(symbolname="scalar_maximum") def maximum(x, y): """elemwise maximum. See max for the maximum in one tensor"""