diff --git a/src/sampler.jl b/src/sampler.jl index 40418114e..974828e8b 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -156,6 +156,17 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() +function set_values!!( + varinfo::AbstractVarInfo, initial_params::AbstractVector, spl::AbstractSampler +) + throw( + ArgumentError( + "`initial_params` must be a vector of type `Union{Real,Missing}`. " * + "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.", + ), + ) +end + function set_values!!( varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}, @@ -164,7 +175,8 @@ function set_values!!( flattened_param_vals = varinfo[spl] length(flattened_param_vals) == length(initial_params) || throw( DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(flattened_param_vals)))", + "Provided initial value size ($(length(initial_params))) doesn't match " * + "the model size ($(length(flattened_param_vals))).", ), ) @@ -183,6 +195,24 @@ end function set_values!!( varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler ) + vars_in_varinfo = keys(varinfo) + for v in keys(initial_params) + vn = VarName{v}() + if !(vn in vars_in_varinfo) + for vv in vars_in_varinfo + if subsumes(vn, vv) + throw( + ArgumentError( + "The current model contains sub-variables of $v, such as ($vv). " * + "Using NamedTuple for initial_params is not supported in such a case. " * + "Please use AbstractVector for initial_params instead of NamedTuple.", + ), + ) + end + end + throw(ArgumentError("Variable $v not found in the model.")) + end + end initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) return update_values!!( varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) diff --git a/test/sampler.jl b/test/sampler.jl index e5fe6dc98..3b5424671 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -178,5 +178,30 @@ @test c1[1].metadata.s.vals == c2[1].metadata.s.vals end end + + @testset "error handling" begin + # https://github.com/TuringLang/Turing.jl/issues/2452 + @model function constrained_uniform(n) + Z ~ Uniform(10, 20) + X = Vector{Float64}(undef, n) + for i in 1:n + X[i] ~ Uniform(0, Z) + end + end + + n = 2 + initial_z = 15 + initial_x = [0.2, 0.5] + model = constrained_uniform(n) + vi = VarInfo(model) + + @test_throws ArgumentError DynamicPPL.initialize_parameters!!( + vi, [initial_z, initial_x], DynamicPPL.SampleFromPrior(), model + ) + + @test_throws ArgumentError DynamicPPL.initialize_parameters!!( + vi, (X=initial_x, Z=initial_z), DynamicPPL.SampleFromPrior(), model + ) + end end end