diff --git a/test/combinators/conditional.jl b/test/combinators/conditional.jl new file mode 100644 index 00000000..2ebfc559 --- /dev/null +++ b/test/combinators/conditional.jl @@ -0,0 +1,29 @@ +# test/combinators/conditional.jl +using Test +using MeasureBase +using Random: MersenneTwister + +@testset "Conditional" begin + # Create a simple conditional measure + base_measure = StdNormal() + condition(x) = abs(x) <= 2 # Only accept values in [-2, 2] + + cond_measure = @inferred Conditional(base_measure, condition) + + # Test basic properties + @test basemeasure(cond_measure) === base_measure + + # Test sampling with rejection sampling + rng = MersenneTwister(123) + samples = [rand(rng, cond_measure) for _ in 1:100] + @test all(condition, samples) + + # Test density + x = 1.0 + @test logdensityof(cond_measure, x) ≈ logdensityof(base_measure, x) + @test logdensityof(cond_measure, 3.0) == -Inf # Outside condition + + # Test support + @test insupport(cond_measure, 1.0) + @test !insupport(cond_measure, 3.0) +end \ No newline at end of file diff --git a/test/combinators/half.jl b/test/combinators/half.jl new file mode 100644 index 00000000..eb46752e --- /dev/null +++ b/test/combinators/half.jl @@ -0,0 +1,74 @@ +using Test +using MeasureBase +using Random: MersenneTwister +using LogExpFunctions: loghalf +using IrrationalConstants: log2π + +@testset "Half" begin + rng = MersenneTwister(42) + + @testset "Basic properties" begin + μ = Half(StdNormal()) + + # Test show method + @test sprint(show, μ) == "Half(StdNormal())" + + # Test unhalf + @test unhalf(μ) === StdNormal() + + # Test basemeasure + @test basemeasure(μ) isa WeightedMeasure + @test _logweight(basemeasure(μ)) ≈ log(2) + end + + @testset "Sampling and density" begin + μ = Half(StdNormal()) + n_samples = 1000 + samples = [rand(rng, μ) for _ in 1:n_samples] + + # All samples should be non-negative + @test all(x -> x ≥ 0, samples) + + # Test density at specific points + x = 1.0 + expected_log_density = -0.5 * (x^2 + log2π) - loghalf + @test logdensityof(μ, x) ≈ expected_log_density + + # Test density at negative points + @test logdensityof(μ, -1.0) == -Inf + + # Test density at zero + @test isfinite(logdensityof(μ, 0.0)) + end + + @testset "Transport" begin + μ = Half(StdNormal()) + + # Test transport to/from uniform + u = 0.7 # arbitrary point in (0,1) + x = transport_to(μ, StdUniform(), u) + @test x ≥ 0 + @test transport_to(StdUniform(), μ, x) ≈ u + + # Test edge cases + @test transport_to(μ, StdUniform(), 0.0) == 0.0 + @test transport_to(μ, StdUniform(), 1.0) > 0 + end + + @testset "SMF (Standardized Measure Function)" begin + μ = Half(StdUniform()) + + # Test SMF properties + @test smf(μ, -1.0) == -1.0 # Below support + @test smf(μ, 0.0) == -1.0 # At lower bound + @test smf(μ, 0.5) == 0.0 # Midpoint + @test smf(μ, 1.0) == 1.0 # At upper bound + + # Test inverse SMF + for p in [0.0, 0.25, 0.5, 0.75, 1.0] + x = invsmf(μ, p) + @test smf(μ, x) ≈ p + @test 0 ≤ x ≤ 1 + end + end +end \ No newline at end of file diff --git a/test/combinators/implicitlymapped.jl b/test/combinators/implicitlymapped.jl index 7fa75bb0..c8b5b7f7 100644 --- a/test/combinators/implicitlymapped.jl +++ b/test/combinators/implicitlymapped.jl @@ -93,3 +93,57 @@ using AffineMaps, PropertyFunctions obs, ) end + +using Test +using MeasureBase +using Static: static +using Random: MersenneTwister + +@testset "TakeAny" begin + rng = MersenneTwister(42) + + @testset "Basic properties" begin + take2 = TakeAny(2) + take_static2 = TakeAny(static(2)) + + # Test with various collection types + arr = [1,2,3,4,5] + @test length(take2(arr)) == 2 + @test length(take_static2(arr)) == 2 + + # Test consistency + @test take2(arr) == take2(arr) # Same elements when called multiple times + + # Test with different sized inputs + @test length(take2(1:10)) == 2 + @test length(take2(1:1)) == 1 # Should handle cases where input is smaller than n + end + + @testset "Implicit mapping with TakeAny" begin + # Create a kernel that produces a product measure + kernel = par -> StdNormal()^3 + + # Create mapped version that only looks at first two components + mapped_kernel = ImplicitlyMapped(kernel, TakeAny(2)) + + # Test with some parameter value + par = 1.0 + full_measure = kernel(par) + mapped_measure = explicit_kernel(mapped_kernel, rand(rng, full_measure))(par) + + # Check dimensions + @test getdof(mapped_measure) == 2 + @test getdof(full_measure) == 3 + + # Test consistency of mapping + obs1 = rand(rng, full_measure) + obs2 = rand(rng, full_measure) + + mapped1 = mapped_kernel.mapfunc(obs1) + mapped2 = mapped_kernel.mapfunc(obs2) + + # Same elements should be selected consistently + @test length(mapped1) == 2 + @test mapped_kernel.mapfunc(obs1) == mapped1 # Consistent mapping + end +end \ No newline at end of file diff --git a/test/domains.jl b/test/domains.jl new file mode 100644 index 00000000..1138636e --- /dev/null +++ b/test/domains.jl @@ -0,0 +1,48 @@ +# test/domains.jl +using Test +using MeasureBase +using Static: static +using Random: MersenneTwister + +@testset "Domains" begin + @testset "BoundedInts" begin + bounded = ℤ[1:5] + @test 3 ∈ bounded + @test -1 ∉ bounded + @test 6 ∉ bounded + @test 1.5 ∉ bounded + @test minimum(bounded) == 1 + @test maximum(bounded) == 5 + @test testvalue(bounded) == 0 + + # Test show method + @test sprint(show, bounded) == "ℤ[1:5]" + end + + @testset "ZeroSet" begin + # Simple quadratic function and its gradient + f(x) = sum(x.^2) + ∇f(x) = 2x + zs = ZeroSet(f, ∇f) + + # Test points + @test zeros(3) ∈ zs + @test [1e-8, -1e-8, 1e-8] ∈ zs + @test [0.1, 0.1, 0.1] ∉ zs + + # Test with different floating point types + @test zeros(Float32, 2) ∈ zs + @test zeros(Float64, 2) ∈ zs + end + + @testset "IntegerNumbers" begin + @test minimum(ℤ) == static(-Inf) + @test maximum(ℤ) == static(Inf) + + # Test membership + @test 42 ∈ ℤ + @test -42 ∈ ℤ + @test 3.14 ∉ ℤ + @test 2.0 ∈ ℤ # Integer-valued floats should be in ℤ + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 52d6695a..1811707e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,6 +21,7 @@ using JET # include("test_aqua.jl") include("static.jl") +include("domains.jl") include("test_primitive.jl") include("test_standard.jl") @@ -33,5 +34,6 @@ include("smf.jl") include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") include("combinators/implicitlymapped.jl") - +include("combinators/conditional.jl") +include("combinators/half.jl") include("test_docs.jl") diff --git a/test/static.jl b/test/static.jl index f618124b..9350d93f 100644 --- a/test/static.jl +++ b/test/static.jl @@ -3,6 +3,7 @@ using Test import MeasureBase import Static +using StaticArrays using Static: static import FillArrays @@ -32,3 +33,20 @@ import FillArrays @test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) isa Static.StaticInt @test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) == static(7) end + +@testset "maybestatic_size" begin + # Test regular array + arr = rand(MersenneTwister(123), 3, 4) + @test MeasureBase.maybestatic_size(arr) == (3, 4) + + # Test static array + static_arr = SMatrix{2,2}([1 2; 3 4]) + @test MeasureBase.maybestatic_size(static_arr) == (static(2), static(2)) + + # Test mixed static/dynamic array + mixed = zeros(static(2), 3) # Create a matrix with static first dimension + size_result = MeasureBase.maybestatic_size(mixed) + @test size_result[1] isa Static.StaticInt + @test size_result[2] isa Int + @test size_result == (static(2), 3) +end \ No newline at end of file