Add custom ordering to OrderedBijector #453
penelopeysm
left a comment
Hi @a1ix2, thanks for the PR, and sorry that nobody responded until now! Please feel free to ping me if there are PRs that need to be reviewed.
I do like the direction of this PR and it is implemented cleanly. There are a handful of things I'd change, which are in the comments below.
One obvious missing piece is that the new functionality is not tested. The AD rules are tested, but I think it would be valuable to add tests to test/bijector/ordered.jl which exercise the functions you've implemented, and check that the invariants you are expecting hold true. For example, if you sample from ordered(MvNormal(...), Descending()) are the resulting samples reverse-sorted, and so on. I don't think you need to go to the extreme of doing the MCMC sampling in there, but some common-sense tests would go a long way.
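As a rough illustration of the kind of common-sense test meant here, the sketch below checks the "reverse-sorted" invariant without any sampling machinery. `to_descending` is a hypothetical stand-in written only for this example; it is not the PR's implementation, and a real test in test/bijector/ordered.jl would call the package's own functions instead.

```julia
using Test

# Hypothetical stand-in for a descending transform: maps an unconstrained
# vector `y` to a vector `x` with x[1] > x[2] > ... > x[end].
function to_descending(y::AbstractVector{<:Real})
    isempty(y) && return float.(y)
    x = similar(y, float(eltype(y)))
    x[end] = y[end]
    for i in (length(y) - 1):-1:1
        # exp(y[i]) is positive, so each entry exceeds the one after it
        x[i] = x[i + 1] + exp(y[i])
    end
    return x
end

@testset "descending invariants" begin
    for n in (1, 2, 5), _ in 1:20
        x = to_descending(randn(n))
        @test length(x) == n
        @test issorted(x; rev=true)
    end
    # Empty input should come back empty rather than erroring
    @test isempty(to_descending(Float64[]))
end
```

The same pattern (random input, transform, check the ordering and the length) would presumably carry over to the ascending and fixed-order cases.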
struct FixedOrder{ordertuple} <: AbstractOrdering
    function FixedOrder{ordertuple}() where {ordertuple}
        @assert (ordertuple isa Tuple{Int, Int, Vararg{Int}}) && allunique(ordertuple) && all(ordertuple .> 0)
        new{ordertuple}()
    end
    FixedOrder(ordertuple::Tuple{Int, Int, Vararg{Int}}) = FixedOrder{ordertuple}()
end
These checks look pretty useful! One thing I'd say is that @asserts are not the best way of enforcing them, for two reasons:
1. In principle, the Julia compiler might remove asserts (see the docstring of @assert) and thus they should not be used to enforce invariants;
2. The error message obtained is quite opaque.
Along the lines of (2), I think it's in principle great to enforce that length(ordertuple) >= 2 via the type system. However, it can be confusing for the end user if they pass a shorter tuple, because they will hit a MethodError instead of an explicit error message.
So, I'd suggest that we make these more explicit. For example, we could relax the type bound to just NTuple{N,Int} where {N}, and do
isvalid = N >= 2 && allunique(ordertuple) && all(i -> i > 0, ordertuple)
if !isvalid
    throw(ArgumentError("`ordertuple` must be a tuple of at least 2 unique positive integers"))
end
(Also, using all(i -> i > 0, ordertuple) instead of all(ordertuple .> 0) avoids creating an intermediate collection.)
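Put together, the suggestion might look roughly like the following self-contained sketch. It keeps the tuple as a type parameter, as in the PR, and moves all validation into the inner constructor; the exact error wording and the `isa Tuple{Vararg{Int}}` guard are assumptions made for this example.

```julia
abstract type AbstractOrdering end

struct FixedOrder{ordertuple} <: AbstractOrdering
    function FixedOrder{ordertuple}() where {ordertuple}
        # Explicit checks instead of @assert, so they cannot be compiled away
        isvalid = ordertuple isa Tuple{Vararg{Int}} &&
                  length(ordertuple) >= 2 &&
                  allunique(ordertuple) &&
                  all(i -> i > 0, ordertuple)
        if !isvalid
            throw(ArgumentError(
                "`ordertuple` must be a tuple of at least 2 unique positive integers, got $ordertuple"
            ))
        end
        return new{ordertuple}()
    end
end

# Relaxed type bound: any tuple of Ints reaches the inner constructor,
# which then reports a readable error for invalid input.
FixedOrder(ordertuple::NTuple{N,Int}) where {N} = FixedOrder{ordertuple}()
```

With this shape, `FixedOrder((1,))` raises an `ArgumentError` with an explicit message rather than a `MethodError`.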
OrderedBijector{OT}() where {OT <: AbstractOrdering} = new{OT}()
OrderedBijector(ordertuple::Tuple{Int, Int, Vararg{Int}}) = OrderedBijector{FixedOrder{ordertuple}}()
function OrderedBijector{FixedOrder{ordertuple}}() where {ordertuple}
    @assert (ordertuple isa Tuple{Int, Int, Vararg{Int}}) && all(ordertuple .> 0) && allunique(ordertuple)
Given that these checks are already inside the inner constructor of FixedOrder, they don't need to be repeated here.
@inbounds x[N] = y[N]
@inbounds for i in N-1:-1:1
I learnt fairly recently that peppering @inbounds everywhere is not really helpful: JuliaStats/Distributions.jl#2005
So far I don't think I've run into a case where they actually helped, so I think it would probably be better to remove them, unless there's a demonstrable case where they improve performance. I see that a lot of the existing codebase already uses them, but that's probably a historical artifact rather than a reflection of current best practice.
x = similar(y)
@assert !isempty(y)
Suggested change:
- x = similar(y)
- @assert !isempty(y)
+ isempty(y) && return y
+ x = similar(y)
I think it might be more graceful to return an empty array if the input is empty (after all, an empty array is trivially ordered in whatever order one might want).
"""
function ordered(d::ContinuousMultivariateDistribution)

ordered(d::ContinuousMultivariateDistribution, ordertuple::Tuple{Int, Int, Vararg{Int}}) = ordered(d, FixedOrder{ordertuple}())
As above, I think we can relax the type bound to NTuple{N,Int} where {N} and let the inner constructor of FixedOrder handle the checking.
Although, personally, in this case, I'd lean towards just not providing this extra method, and just making people write ordered(dist, FixedOrder(tuple)) -- because it saves very little work, and ordered(dist, tuple) is less clear to read (what does the tuple mean? -- one has to look at the implementation to figure it out).
"""
    ordered(d::Distribution)

Return a `Distribution` whose support are ordered vectors, i.e., vectors with increasingly ordered elements.
Regarding documentation:
- The docstring of ordered needs to be updated to describe the improvements in this PR, otherwise nobody can find out about this without trawling the source code. (I get that Bijectors' docs are a bit underwhelming, but let's not let it stay that way!)
- Also, the newly exported constructors Ascending, Descending, and FixedOrder should be added to the API docs inside docs/src/interface.md. You could conceivably pull ordered into a separate section together with them, but all that is a matter of taste.
abstract type AbstractOrdering end
struct Ascending <: AbstractOrdering end
struct Descending <: AbstractOrdering end
struct FixedOrder{ordertuple} <: AbstractOrdering
Does the tuple itself need to be a type parameter? Looking at the code below, I'm not sure any of it really needs the tuple to be known at the type level for type stability, because the ordered functions just turn a vector (or matrix) into a vector (or matrix). If we don't need it in the type, then it would probably be better to keep it as an ordinary field:
struct FixedOrder{T} <: AbstractOrdering
    ordertuple::T
end
as otherwise we might end up creating multiple method specialisations for different values of ordertuple even when it's not needed.
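To make the specialisation point concrete, here is a small sketch comparing the two designs. The bound `T<:Tuple{Vararg{Int}}` and the `describe` function are hypothetical additions for this example, and `Val` is used only as a stand-in for "tuple as a type parameter".

```julia
abstract type AbstractOrdering end

# Field-based variant: the tuple is ordinary runtime data.
struct FixedOrder{T<:Tuple{Vararg{Int}}} <: AbstractOrdering
    ordertuple::T
end

# Hypothetical function that takes an ordering as an argument.
describe(o::FixedOrder) = "fixed order over indices $(o.ordertuple)"

a = FixedOrder((3, 4, 1))
b = FixedOrder((1, 2, 3))

# Both orderings share one concrete type, so methods taking a `FixedOrder`
# of this length are compiled once and reused.
same_type_as_field = typeof(a) == typeof(b)

# With the tuple in the type (here via `Val`), every distinct tuple is a
# distinct type and gets its own method specialisation.
same_type_as_param = typeof(Val((3, 4, 1))) == typeof(Val((1, 2, 3)))
```

Here `same_type_as_field` is `true` while `same_type_as_param` is `false`, which is the trade-off in question: the type-parameter design only pays off if some code genuinely needs the tuple at compile time.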
if !((ordertuple isa Tuple{Int, Int, Vararg{Int}})
    && issubset(ordertuple, 1:length(d))
    && allunique(ordertuple))
    throw(ArgumentError("ordertuple must be a subset of 1:$(length(d)) of length at least 2 with no duplicates."))
end
For these checks, we'd only need to keep the issubset one; the others would have been caught earlier by the inner constructor of FixedOrder. But also we can probably make it even simpler. We know that they're all positive integers, so we only really need to check that they're all within bounds for d:
N = length(d)
if any(i -> i > N, ordertuple)
    throw(ArgumentError("all elements of `ordertuple` must be within bounds for distribution of length $N"))
end
# `OrderedBijector`
function ChainRulesCore.rrule(::typeof(_transform_ordered), y::AbstractVector)
function ChainRulesCore.rrule(::typeof(_transform_ordered), y::AbstractVector, ::Type{Ascending})
I haven't looked at the AD front, but do you need the ChainRulesCore implementations for your use case? If not, I'd say that we can really just get rid of all of them. They're only really used for Zygote, and Bijectors doesn't support Zygote anymore (there will always be a previous version available for people who need compat).
I work with pharmacokinetic models where you often have local but not global identifiability, and you need to impose an ordering constraint on some subsets of elements of a multivariate rv but not others. A strictly increasing ordering allows you to partially deal with that, but at the cost of introducing a lot of reordering gymnastics within the model specification itself.
This PR introduces a type parameter, so that OrderedBijector{OT<:AbstractOrdering} can now impose any fixed ordering, including an incomplete one, e.g. for some rv X ~ MvNormal(I(D)), you could have x[3] < x[4] < x[1] while all other indices are left unconstrained. The type parameter OT can be Ascending, Descending, or FixedOrder{ordertuple}. In the previous example this would lead to OrderedBijector{FixedOrder{(3,4,1)}} and OrderedDistribution{D, B, FixedOrder{(3,4,1)}}. The default is Ascending. I've also modified chainrules.jl to account for these changes.
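For readers unfamiliar with partial orderings, the sketch below shows one way such a constraint could be imposed. It is an illustration under assumptions, not the code in this PR: `to_fixed_order` is a hypothetical name, and the construction (adding a positive `exp` increment along the chain of constrained indices) is simply the usual ordered-vector trick applied to a subset of indices.

```julia
# Hypothetical sketch: constrain only the indices listed in `order` so that
# x[order[1]] < x[order[2]] < ..., and pass every other entry through.
function to_fixed_order(y::AbstractVector{<:Real}, order::Tuple{Vararg{Int}})
    x = float.(y)  # fresh floating-point copy; unconstrained entries stay as they are
    for k in 2:length(order)
        prev, cur = order[k - 1], order[k]
        # exp(...) is positive, so x[cur] lies above x[prev]
        x[cur] = x[prev] + exp(y[cur])
    end
    return x
end

y = [0.5, -1.0, 0.0, 0.0, 2.0]
x = to_fixed_order(y, (3, 4, 1))
```

In this example `x[3] < x[4] < x[1]` holds, while `x[2]` and `x[5]` are identical to `y[2]` and `y[5]`.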
It passes all tests, but I'm looking for feedback to bring everything up to the project's standards and style, and for someone to double-check the math.