Skip to content

Commit 04cf833

Browse files
authored
Merge pull request #127 from devmotion/dw/inverse_eltype_scalar
Fix `inverse_eltype` for `ScalarTransform`s
2 parents fe778b0 + c59967a commit 04cf833

2 files changed

Lines changed: 41 additions & 3 deletions

File tree

src/scalar.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ $(TYPEDEF)
66
77
Transform a scalar (real number) to another scalar.
88
9-
Subtypes mustdefine `transform`, `transform_and_logjac`, and `inverse`; other
10-
methods of of the interface should have the right defaults.
9+
Subtypes must define `transform`, `transform_and_logjac`, and `inverse`.
10+
Other methods of of the interface should have the right defaults.
11+
12+
!!! NOTE
13+
This type is for code organization within the package, and is not part of the public API.
1114
"""
1215
abstract type ScalarTransform <: AbstractTransform end
1316

@@ -26,7 +29,11 @@ function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y::Real)
2629
index + 1
2730
end
2831

29-
inverse_eltype(t::ScalarTransform, y::T) where {T <: Real} = float(T)
32+
function inverse_eltype(t::ScalarTransform, y::Real)
33+
# NOTE this is a shortcut to get sensible types for all subtypes of ScalarTransform, which
34+
# we test for. If it breaks it should be extended accordingly.
35+
return Base.promote_typejoin_union(Base.promote_op(inverse, typeof(t), typeof(y)))
36+
end
3037

3138
_domain_label(::ScalarTransform, index::Int) = ()
3239

test/runtests.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,3 +699,34 @@ end
699699
U = transform(t, x)
700700
@test isfinite(logabsdet(U)[1])
701701
end
702+
703+
@testset "inverse_eltype of scalar transforms with parameters" begin
704+
# `Float64` parameters and `Float32` input
705+
for t in (as(Real, 0.5, ∞), as(Real, -∞, 2.1), as(Real, 0.5, 2.1))
706+
@test @inferred(inverse_eltype(t, 1.1f0)) === Float64
707+
@test @inferred(inverse(t, 1.1f0)) isa Float64
708+
end
709+
710+
# Derivatives wrt parameters of the transforms
711+
d1 = ForwardDiff.derivative(5.3) do x
712+
return @inferred only(inverse(as(Vector, as(Real, x, ∞), 1), [10]))
713+
end
714+
d2 = ForwardDiff.derivative(5.3) do x
715+
return @inferred inverse(as(Real, x, ∞), 10)
716+
end
717+
@test d1 == d2
718+
d1 = ForwardDiff.derivative(-3) do x
719+
return @inferred only(inverse(as(Vector, as(Real, -∞, x), 1), [-6.1]))
720+
end
721+
d2 = ForwardDiff.derivative(-3) do x
722+
return @inferred inverse(as(Real, -∞, x), -6.1)
723+
end
724+
@test d1 == d2
725+
d1 = ForwardDiff.gradient([-0.3, 4.7]) do x
726+
return @inferred only(inverse(as(Vector, as(Real, x[1], x[2]), 1), [2.3]))
727+
end
728+
d2 = ForwardDiff.gradient([-0.3, 4.7]) do x
729+
return @inferred inverse(as(Real, x[1], x[2]), 2.3)
730+
end
731+
@test d1 == d2
732+
end

0 commit comments

Comments
 (0)