Eggbox

This example will explore the classic eggbox function using Models.Eggbox.
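For reference, the eggbox is usually written in the nested-sampling literature (it is one of the MultiNest test problems) as a two-dimensional likelihood with a uniform prior θ₁, θ₂ ~ U(0, 10π). This example assumes Models.Eggbox follows this form; how it maps its unit-square parameters onto θ is not spelled out here.

```latex
\mathcal{L}(\theta_1, \theta_2) = \exp\!\left[\left(2 + \cos\frac{\theta_1}{2}\,\cos\frac{\theta_2}{2}\right)^{5}\right],
\qquad
\log\mathcal{L}_{\max} = (2 + 1)^{5} = 243 .
```

The peaks sit wherever both cosines equal +1 or both equal -1. This produces the regular grid of modes that gives the function its name and makes it a standard stress test for multimodal samplers.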

Setup

For this example, you'll need to add the following packages:

julia>]add Distributions MCMCChains Measurements NestedSamplers StatsBase StatsPlots

Define model

using NestedSamplers

# Models.Eggbox returns the model together with its analytic log-evidence
model, logz = Models.Eggbox()

Let's evaluate the log-likelihood on a grid over both parameters to see what its surface looks like.

using StatsPlots

x = range(0, 1, length=1000)
y = range(0, 1, length=1000)
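# log-likelihood function stored in the model
loglike = model.prior_transform_and_loglikelihood.loglikelihood
# rows index y and columns index x, the layout heatmap(x, y, z) expects
logf = [loglike([xi, yi]) for yi in y, xi in x]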
heatmap(
    x, y, logf,
    xlims=extrema(x),
    ylims=extrema(y),
    xlabel="x",
    ylabel="y",
)
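To see what the comprehension above builds without loading the package, here is a self-contained sketch. `eggbox_loglike` is a hypothetical stand-in for the model's log-likelihood: it assumes the standard eggbox form with the unit square mapped to t = 10πx - 5π, a mapping chosen so that the peaks land on the dashed lines drawn at 0.1, 0.3, ..., 0.9 in the Results section. Using grids of different lengths makes the matrix layout visible.

```julia
# Hypothetical stand-in for the model's log-likelihood (an assumption):
# the standard eggbox surface with the unit square mapped to t = 10πx - 5π
function eggbox_loglike(p)
    t = @. 10π * p - 5π
    return (2 + cos(t[1] / 2) * cos(t[2] / 2))^5
end

x = range(0, 1, length=101)   # step 0.01
y = range(0, 1, length=51)    # step 0.02; a different length exposes the layout
logf = [eggbox_loglike([xi, yi]) for yi in y, xi in x]
println(size(logf))           # (length(y), length(x)): rows follow y, columns follow x

# grid points sitting on a global peak, where logL reaches (2 + 1)^5 = 243
peaks = [(x[j], y[i]) for i in axes(logf, 1) for j in axes(logf, 2) if logf[i, j] > 242.9]
println(length(peaks))
```

On this assumed surface the grid hits 13 global peaks: 9 where both cosines are +1 and 4 where both are -1.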

Sample

using MCMCChains
using StatsBase
# using multi-ellipsoid for bounds
# using default rejection sampler for proposals
# 2 parameters, 500 live points
sampler = Nested(2, 500)
# stop once the live points can add less than dlogz to log Z
chain, state = sample(model, sampler; dlogz=0.01, param_names=["x", "y"])
# resample chain using statistical weights
chain_resampled = sample(chain, Weights(vec(chain[:weights])), length(chain));
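The two steps above (run the sampler, then resample by weight) can be illustrated with a toy nested sampler in plain Julia. This is a simplified sketch, not NestedSamplers' implementation: it assumes a uniform prior on the unit hypercube, draws replacement points by rejection from the whole prior instead of from multi-ellipsoid bounds, and uses the deterministic volume estimate X_i = exp(-i/N). The functions `toy_nested`, `resample`, and `logaddexp` are hypothetical names introduced here, and the test target is a normalized Gaussian whose log-evidence is close to 0.

```julia
using Random

# log(exp(a) + exp(b)) without overflow; returns b when a == -Inf
logaddexp(a, b) = max(a, b) + log1p(exp(-abs(a - b)))

# Toy nested sampler (hypothetical, for illustration only).
# Assumes a uniform prior on the unit hypercube; replacements come from plain
# rejection sampling over the whole prior, and X_i = exp(-i / nlive).
function toy_nested(rng, loglike, ndims, nlive; dlogz=0.01, maxiter=100_000)
    live = [rand(rng, ndims) for _ in 1:nlive]
    live_logl = [loglike(p) for p in live]
    points = Vector{Vector{Float64}}()
    logwts = Float64[]            # log(L_i * (X_{i-1} - X_i)) per dead point
    logz = -Inf
    logx = 0.0                    # log prior volume still enclosed by live points
    for i in 1:maxiter
        # stop once the live points could raise logz by less than dlogz
        logaddexp(logz, maximum(live_logl) + logx) - logz < dlogz && break
        worst = argmin(live_logl)
        logl_min = live_logl[worst]
        logx_new = -i / nlive
        logwt = logl_min + logx + log1p(-exp(logx_new - logx))
        push!(points, live[worst])
        push!(logwts, logwt)
        logz = logaddexp(logz, logwt)
        # replace the worst point by a prior draw with L > L_min
        while true
            p = rand(rng, ndims)
            lp = loglike(p)
            if lp > logl_min
                live[worst], live_logl[worst] = p, lp
                break
            end
        end
        logx = logx_new
    end
    # the remaining live points share the final volume equally
    for (p, lp) in zip(live, live_logl)
        logwt = lp + logx - log(nlive)
        push!(points, p)
        push!(logwts, logwt)
        logz = logaddexp(logz, logwt)
    end
    return points, exp.(logwts .- logz), logz   # weights sum to 1
end

# Draw n points with replacement, each with probability proportional to its weight
function resample(rng, points, weights, n)
    cdf = cumsum(weights)
    cdf ./= cdf[end]
    # first index whose cumulative weight exceeds u; zero-weight points are never picked
    return [points[searchsortedlast(cdf, rand(rng)) + 1] for _ in 1:n]
end

# Usage on a normalized 2-D Gaussian (sigma = 0.1) inside the unit square, log Z ≈ 0
rng = MersenneTwister(42)
gauss_logl(p) = -sum(abs2, p .- 0.5) / (2 * 0.1^2) - log(2π * 0.1^2)
pts, w, logz_toy = toy_nested(rng, gauss_logl, 2, 100)
posterior = resample(rng, pts, w, 2000)
println("toy log Z = ", round(logz_toy, digits=3))
```

Plain rejection from the full prior becomes very slow as the likelihood constraint tightens, since its acceptance rate is roughly the remaining prior volume X_i. That is why practical samplers restrict proposals to bounding regions such as the multi-ellipsoids used above.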

Results

chain_resampled
Chains MCMC chain (6331×3×1 Array{Float64, 3}):

Iterations        = 1:6331
Number of chains  = 1
Samples per chain = 6331
parameters        = x, y
internals         = weights

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

           x    0.4383    0.2813    0.0035   6289.3358   6441.7408    1.0000   ⋯
           y    0.5272    0.2856    0.0037   6133.9008   5587.3118    1.0002   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           x    0.0959    0.1029    0.4966    0.6980    0.9030
           y    0.0971    0.3000    0.5005    0.8955    0.9040

# joint and marginal densities of the resampled (equally weighted) posterior
marginalkde(chain_resampled[:x], chain_resampled[:y])
plot!(xlims=(0, 1), ylims=(0, 1), sp=2)
plot!(xlims=(0, 1), sp=1)
plot!(ylims=(0, 1), sp=3)

density(chain_resampled, xlims=(0, 1))
# dashed lines at 0.1, 0.3, 0.5, 0.7, 0.9, where the eggbox peaks lie
vline!(0.1:0.2:0.9, c=:black, ls=:dash, sp=1)
vline!(0.1:0.2:0.9, c=:black, ls=:dash, sp=2)
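The position of the dashed lines can be checked numerically under an assumed form of the likelihood. The sketch below uses a hypothetical helper `eggcos` that encodes the assumed mapping t = 10πu - 5π of the standard eggbox; it sums the likelihood over y on a grid to get the unnormalized x-marginal posterior under a uniform prior, then lists that marginal's interior local maxima.

```julia
# Hypothetical helper (an assumption): cos(t/2) with u in [0, 1] mapped to t = 10πu - 5π
eggcos(u) = cos((10π * u - 5π) / 2)

x = range(0, 1, length=1001)
y = range(0, 1, length=1001)
cy = eggcos.(y)

# log of the unnormalized x-marginal, log Σ_y exp(logL(x, y)), via log-sum-exp
function log_xmarginal(xi)
    l = (2 .+ eggcos(xi) .* cy) .^ 5
    lmax = maximum(l)
    return lmax + log(sum(exp.(l .- lmax)))
end
logm = log_xmarginal.(x)

# interior grid points that are higher than both neighbours
modes = [x[i] for i in 2:length(x)-1 if logm[i] > logm[i-1] && logm[i] > logm[i+1]]
println(modes)
```

On this assumed surface the marginal peaks fall at 0.1, 0.3, 0.5, 0.7 and 0.9, matching the lines drawn on the density plot.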

using Measurements
# pair the evidence estimate with its uncertainty
logz_est = state.logz ± state.logzerr
diff = logz_est - logz
println("logz: $logz")
println("estimate: $logz_est")
println("diff: $diff")

logz: 235.88
estimate: 235.95 ± 0.11
diff: 0.07 ± 0.11
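The uncertainty bookkeeping that Measurements performs here is simple, because the analytic log-evidence is exact: the difference inherits the estimator's error bar unchanged. A plain-Julia sketch using the rounded values printed above:

```julia
logz_ref = 235.88                  # analytic value returned by Models.Eggbox()
logz_mean, logz_err = 235.95, 0.11 # nested-sampling estimate (rounded, as printed)

delta = logz_mean - logz_ref
delta_err = logz_err               # the exact reference adds no variance
zscore = delta / delta_err
println("diff: ", round(delta, digits=2), " ± ", delta_err,
        " (", round(zscore, digits=2), " standard errors)")
```

A difference of about 0.64 standard errors means the estimate is consistent with the analytic value.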