Correlated Gaussian
This example explores a highly correlated Gaussian using Models.CorrelatedGaussian. This model uses a conjugate Gaussian prior; see its docstring for the mathematical definition.
Setup
For this example, you'll need to add the following packages:
julia> ]add Distributions MCMCChains Measurements NestedSamplers StatsBase StatsPlots
Define model
using NestedSamplers
# set up a 4-dimensional Gaussian
D = 4
model, logz = Models.CorrelatedGaussian(D)
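Models.CorrelatedGaussian returns the analytic log-evidence logz along with the model, and that is the value the final estimate is compared against. As a sanity check, the sketch below recomputes it with the LinearAlgebra standard library. It assumes our reading of the docstring: prior θ ~ N(2·1, I) and likelihood N(θ | 0, Σ), where Σ has unit variances and 0.95 off-diagonal correlations, so the evidence is the Gaussian convolution Z = N(2·1 | 0, Σ + I). The helpers corr_cov, mvnormal_logpdf, and analytic_logz are hypothetical names, not part of NestedSamplers.

```julia
using LinearAlgebra

# Hypothetical helper: D×D covariance with unit variances and correlation ρ
# (ρ = 0.95 is an assumption about Models.CorrelatedGaussian).
corr_cov(D, ρ=0.95) = fill(ρ, D, D) + (1 - ρ) * I

# Log-density of the multivariate normal N(x | μ, Σ) via a Cholesky factorization.
function mvnormal_logpdf(x, μ, Σ)
    C = cholesky(Symmetric(Σ))
    r = C.L \ (x - μ)        # whitened residual, so r'r = (x - μ)'Σ⁻¹(x - μ)
    return -0.5 * (length(x) * log(2π) + logdet(C) + dot(r, r))
end

# Evidence of the assumed model: prior N(m, I) with m = 2·1 and likelihood
# N(θ | 0, Σ) give Z = ∫ N(θ | 0, Σ) N(θ | m, I) dθ = N(m | 0, Σ + I).
function analytic_logz(D; ρ=0.95, prior_mean=2.0)
    Σ = corr_cov(D, ρ)
    m = fill(prior_mean, D)
    return mvnormal_logpdf(m, zeros(D), Σ + I)
end

println(analytic_logz(4))
```

If the assumed definition is right, analytic_logz(4) matches the logz value printed at the end of this example.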
Let's look at two of the parameters, holding the other two at zero, to see what the likelihood surface looks like:
using StatsPlots
θ1 = range(-1, 1, length=1000)
θ2 = range(-1, 1, length=1000)
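# extract the log-likelihood function from the model
loglike = model.prior_transform_and_loglikelihood.loglikelihood
# evaluate it on a grid over (θ1, θ2) with θ3 = θ4 = 0;
# rows follow θ2 (y-axis) and columns follow θ1 (x-axis), as heatmap expects
logf = [loglike([t1, t2, 0, 0]) for t2 in θ2, t1 in θ1]
heatmap(
    θ1, θ2, exp.(logf),
    aspect_ratio=1,
    xlims=extrema(θ1),
    ylims=extrema(θ2),
    xlabel="θ1",
    ylabel="θ2"
)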
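Because θ3 and θ4 are held at zero, the heatmap shows a slice of the likelihood, not a marginal. Up to normalization that slice is the conditional Gaussian of (θ1, θ2) given θ3 = θ4 = 0, whose covariance is the Schur complement Σ_aa - Σ_ab Σ_bb⁻¹ Σ_ba. The sketch below computes the correlation within that slice, assuming the likelihood covariance Σ has unit variances and 0.95 off-diagonal correlations (our reading of the docstring); slice_correlation is a hypothetical helper.

```julia
using LinearAlgebra

# Correlation between θ1 and θ2 in the slice θ3 = ... = θD = 0 of N(θ | 0, Σ),
# assuming Σ has unit variances and off-diagonal correlation ρ (0.95 assumed).
function slice_correlation(D; ρ=0.95)
    D > 2 || throw(ArgumentError("need at least one pinned parameter (D > 2)"))
    Σ = fill(ρ, D, D) + (1 - ρ) * I
    a, b = 1:2, 3:D                      # free and pinned indices
    # conditional covariance = Schur complement Σ_aa - Σ_ab Σ_bb⁻¹ Σ_ba
    S = Σ[a, a] - Σ[a, b] * (Σ[b, b] \ Σ[b, a])
    return S[1, 2] / sqrt(S[1, 1] * S[2, 2])
end

println(slice_correlation(4))
```

Under that assumption slice_correlation(4) is about 0.33, so the slice is far less elongated than the 0.95 pairwise correlation of the full likelihood might suggest.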
Sample
using MCMCChains
using StatsBase
# use a single ellipsoid to bound the live points
# use Gibbs-style slice sampling to propose new points
# 50D live points (200 for D = 4)
sampler = Nested(D, 50D;
    bounds=Bounds.Ellipsoid,
    proposal=Proposals.Slice()
)
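param_names = ["θ_$i" for i in 1:D]
# stop once the estimated remaining contribution to log Z falls below dlogz
chain, state = sample(model, sampler; dlogz=0.01, param_names=param_names)
# resample the chain using its statistical weights to get equally weighted draws
chain_resampled = sample(chain, Weights(vec(chain[:weights])), length(chain));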
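The last line turns the weighted dead points into equally weighted draws by sampling rows with replacement, with probability proportional to chain[:weights]. A minimal standard-library sketch of that multinomial resampling step follows; resample_indices is a hypothetical helper, and the example itself relies on StatsBase's sample with Weights rather than this code.

```julia
using Random

# Multinomial resampling (hypothetical helper): draw n indices with replacement,
# picking index i with probability w[i] / sum(w), by inverse-CDF lookup on the
# cumulative weights.
function resample_indices(rng::AbstractRNG, w::AbstractVector{<:Real}, n::Integer)
    all(>=(0), w) || throw(ArgumentError("weights must be non-negative"))
    cdf = cumsum(w)
    total = cdf[end]
    total > 0 || throw(ArgumentError("weights must not all be zero"))
    # first index whose cumulative weight exceeds u, so zero-weight indices never win
    return [searchsortedlast(cdf, rand(rng) * total) + 1 for _ in 1:n]
end

rng = MersenneTwister(42)
w = [0.0, 1.0, 3.0]
idx = resample_indices(rng, w, 10_000)
println("counts: ", [count(==(i), idx) for i in eachindex(w)])
```

Index 1 never appears, and index 3 appears roughly three times as often as index 2, mirroring how high-weight dead points are duplicated in chain_resampled.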
Results
chain_resampled
Chains MCMC chain (2386×5×1 Array{Float64, 3}):
Iterations = 1:2386
Number of chains = 1
Samples per chain = 2386
parameters = θ_1, θ_2, θ_3, θ_4
internals = weights
Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯
         θ_1    1.6108    0.4664    0.0097   2328.6362   1936.1660    0.9996   ⋯
         θ_2    1.6247    0.4596    0.0096   2268.8746   1958.8120    1.0001   ⋯
         θ_3    1.5988    0.4543    0.0095   2307.4521   2194.8809    1.0001   ⋯
         θ_4    1.6134    0.4677    0.0097   2331.0871   2434.4555    0.9998   ⋯
                                                                1 column omitted
Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64
         θ_1    0.7428    1.2820    1.5899    1.9481    2.4679
         θ_2    0.8042    1.3181    1.5984    1.9191    2.5305
         θ_3    0.7181    1.2721    1.5976    1.8958    2.4717
         θ_4    0.7550    1.3243    1.5942    1.9329    2.5527
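Because the prior is conjugate, the posterior is Gaussian too, which gives an independent check on these summary statistics. Assuming prior N(2·1, I) and likelihood N(θ | 0, Σ) with unit variances and 0.95 off-diagonal correlations (our reading of the docstring), precisions add: the posterior covariance is (I + Σ⁻¹)⁻¹ and the posterior mean is (I + Σ⁻¹)⁻¹ · 2·1. The helper analytic_posterior below is hypothetical.

```julia
using LinearAlgebra

# Posterior of the assumed conjugate model: prior N(m, I) with m = 2·1 and
# likelihood N(θ | 0, Σ), where Σ has unit variances and correlation ρ.
function analytic_posterior(D; ρ=0.95, prior_mean=2.0)
    Σ = fill(ρ, D, D) + (1 - ρ) * I
    m = fill(prior_mean, D)
    # posterior precision = prior precision (I) + likelihood precision (Σ⁻¹)
    P = Symmetric(I + inv(Σ))
    C = inv(P)                 # posterior covariance
    μ = C * m                  # = (I + Σ⁻¹)⁻¹ (I·m + Σ⁻¹·0)
    return μ, sqrt.(diag(C))
end

μ, σ = analytic_posterior(4)
println("posterior mean: ", round.(μ; digits=4))
println("posterior std:  ", round.(σ; digits=4))
```

Under these assumptions every parameter for D = 4 has an analytic posterior mean of about 1.588 and standard deviation of about 0.484, which can be compared with the sampled values in the summary table.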
corner(chain_resampled)
using Measurements
logz_est = state.logz ± state.logzerr
logz_diff = logz_est - logz
println("logz: $logz")
println("estimate: $logz_est")
println("diff: $logz_diff")
logz: -6.187913267630009
estimate: -6.23 ± 0.13
diff: -0.043 ± 0.13
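The uncertainty state.logzerr reflects the statistical nature of nested sampling. Each iteration removes the lowest-likelihood live point and credits it with the prior mass between successive likelihood contours; with N live points that mass shrinks in expectation as X_i ≈ exp(-i/N), so log Z accumulates from the weights L_i (X_{i-1} - X_i). A common error estimate is sqrt(H/N), where H is the information, updated with Skilling's recurrence. The toy 1-D sketch below shows that bookkeeping under stated assumptions: a uniform prior on [-5, 5], a standard normal likelihood, and exact constrained draws (possible only because each likelihood contour here is an interval). It is not NestedSamplers' implementation, whether state.logzerr is computed exactly this way is an assumption, and logaddexp, fold_point, and toy_nested_sampling are hypothetical helpers.

```julia
using Random

# Numerically stable log(exp(a) + exp(b)).
logaddexp(a, b) = max(a, b) + log1p(exp(-abs(a - b)))

# Fold one dead point into the running log-evidence logz and information h
# (Skilling's recurrence). logwt = log(L * ΔX), logl = log(L).
function fold_point(logz, h, logwt, logl)
    logz_new = logaddexp(logz, logwt)
    old = isfinite(logz) ? exp(logz - logz_new) * (h + logz) : 0.0
    h_new = exp(logwt - logz_new) * logl + old - logz_new
    return logz_new, h_new
end

# Toy 1-D nested sampler (hypothetical, for illustration only):
# uniform prior on [-5, 5], standard normal likelihood.
function toy_nested_sampling(rng::AbstractRNG; nlive=400, niter=4000)
    loglike(x) = -0.5 * (x^2 + log(2π))
    live = -5.0 .+ 10.0 .* rand(rng, nlive)      # initial draws from the prior
    logl = loglike.(live)
    logz, h = -Inf, 0.0
    logshell = log1p(-exp(-1 / nlive))           # log((X_{i-1} - X_i) / X_{i-1})
    for i in 1:niter
        worst = argmin(logl)                     # lowest-likelihood live point
        logx_prev = -(i - 1) / nlive             # expected log prior volume X_{i-1}
        logz, h = fold_point(logz, h, logl[worst] + logx_prev + logshell, logl[worst])
        # Replace it with a prior draw constrained to L > L*. Here that region is
        # the interval |x| < |x_worst|, so it can be sampled exactly; real samplers
        # need bounds and proposals (ellipsoids, slice sampling, ...) instead.
        r = abs(live[worst])
        live[worst] = r * (2 * rand(rng) - 1)
        logl[worst] = loglike(live[worst])
    end
    # The remaining volume X_niter is split equally among the final live points.
    logx_end = -niter / nlive
    for l in logl
        logz, h = fold_point(logz, h, l + logx_end - log(nlive), l)
    end
    return logz, sqrt(h / nlive)                 # estimate and sqrt(H/N) error
end

logz_toy, logzerr_toy = toy_nested_sampling(MersenneTwister(1))
println("toy estimate: $logz_toy ± $logzerr_toy, exact ≈ $(log(0.1))")
```

With this prior the exact evidence is very nearly 1/10, so the toy estimate should land within a few multiples of its reported error of log(0.1) ≈ -2.303.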