TuringCallbacks

Getting started

As the package is not yet officially released, it has to be added from the GitHub repository:

julia> ]
pkg> add https://github.com/TuringLang/TuringCallbacks.jl

Visualizing sampling on-the-fly

TensorBoardCallback is a wrapper around Base.CoreLogging.AbstractLogger which can be used to create a callback compatible with Turing.sample.

To actually visualize the results of the logging, you need to have tensorboard installed (it is a Python package). If you do not have tensorboard installed, it should be sufficient to run

pip3 install tensorboard

Then you can start up the TensorBoard:

python3 -m tensorboard.main --logdir tensorboard_logs/run

Now we're ready to actually write some Julia code.

The following snippet demonstrates the use of TensorBoardCallback on a simple model. This will write a set of statistics at each iteration to an event file compatible with TensorBoard:

using Turing, TuringCallbacks

@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, √s)
    for i in eachindex(x)
        x[i] ~ Normal(m, √s)
    end
end

xs = randn(100) .+ 1;
model = demo(xs);

# Number of MCMC samples/steps
num_samples = 10_000
num_adapts = 100

# Sampling algorithm to use
alg = NUTS(num_adapts, 0.65)

# Create the callback
callback = TensorBoardCallback("tensorboard_logs/run")

# Sample
chain = sample(model, alg, num_samples; callback = callback)

While this is sampling, you can head right over to localhost:6006 in your web browser and you should be seeing some plots!

TensorBoard dashboard

In particular, note the "Distributions" tab in the above picture. Clicking this, you should see something similar to:

TensorBoard dashboard

And finally, the "Histogram" tab shows a slightly more visually pleasing version of the marginal distributions:

TensorBoard dashboard

Note that the names of the stats follow the naming scheme $variable_name/..., where $variable_name refers to the name of the variable in the model.

Choosing what and how you log

Statistics

In the above example we didn't provide any statistics explicitly, so the default statistics were used, e.g. Mean and Variance. But using other statistics is easy! Here's a much more interesting example:

# Create the stats (estimators are sub-types of `OnlineStats.OnlineStat`)
stats = Skip(
    num_adapts, # Skip the adaptation steps
    Series(
        # Estimators using the entire chain
        Series(Mean(), Variance(), AutoCov(10), KHist(100)),
        # Estimators using the entire chain but only every 10-th sample
        Thin(10, Series(Mean(), Variance(), AutoCov(10), KHist(100))),
        # Estimators using only the last 1000 samples
        WindowStat(1000, Series(Mean(), Variance(), AutoCov(10), KHist(100)))
    )
)
# Create the callback
callback = TensorBoardCallback("tensorboard_logs/run", stats)

# Sample
chain = sample(model, alg, num_samples; callback = callback)

Tada! Now you should be seeing waaaay more interesting statistics in your TensorBoard dashboard. See the OnlineStats.jl documentation for more on the different statistics, with the exception of Thin, Skip, and WindowStat, which are implemented in this package.

Note that these statistic estimators are stateful, and therefore the following is bad:

julia> s = AutoCov(5)
AutoCov: n=0 | value=[NaN, NaN, NaN, NaN, NaN, NaN]

julia> stat = Series(s, s) # => 10 samples but `n=20` since we've called `fit!` twice for each observation
Series
├─ AutoCov: n=0 | value=[NaN, NaN, NaN, NaN, NaN, NaN]
└─ AutoCov: n=0 | value=[NaN, NaN, NaN, NaN, NaN, NaN]

julia> fit!(stat, randn(10))
Series
├─ AutoCov: n=20 | value=[0.58937, 0.256619, -0.0761323, -0.00642576, 0.0632807, -0.0191778]
└─ AutoCov: n=20 | value=[0.58937, 0.256619, -0.0761323, -0.00642576, 0.0632807, -0.0191778]

while the following is good:

julia> stat = Series(AutoCov(5), AutoCov(5)) # => 10 samples AND `n=10`; great!
Series
├─ AutoCov: n=0 | value=[NaN, NaN, NaN, NaN, NaN, NaN]
└─ AutoCov: n=0 | value=[NaN, NaN, NaN, NaN, NaN, NaN]

julia> fit!(stat, randn(10))
Series
├─ AutoCov: n=10 | value=[0.631969, -0.316744, -0.0267961, 0.240407, -0.280301, 0.213093]
└─ AutoCov: n=10 | value=[0.631969, -0.316744, -0.0267961, 0.240407, -0.280301, 0.213093]

At the moment, the only supported statistics are sub-types of OnlineStats.OnlineStat. If you want to log some custom statistic, you currently have to make a sub-type and implement OnlineStats.fit! and OnlineStats.value for it. By default, an OnlineStat is passed to TensorBoard by simply calling OnlineStats.value(stat). Therefore, if you also want to customize how a stat is passed to TensorBoard, you need to overload TensorBoardLogger.preprocess(name, stat, data) accordingly.
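As a sketch, a custom statistic tracking the running maximum of a chain might look like the following. RunningMax is a hypothetical name; this assumes the OnlineStatsBase.jl interface, where a sub-type implements `_fit!` and, by default, `value` returns the first field and `nobs` returns the `n` field:

```julia
using OnlineStatsBase

# Hypothetical custom statistic: tracks the running maximum.
mutable struct RunningMax <: OnlineStatsBase.OnlineStat{Number}
    value::Float64  # current maximum; returned by the default `value(stat)`
    n::Int          # number of observations seen; returned by the default `nobs(stat)`
end
RunningMax() = RunningMax(-Inf, 0)

# `fit!` dispatches to `_fit!` for each individual observation.
function OnlineStatsBase._fit!(o::RunningMax, x)
    o.n += 1
    o.value = max(o.value, x)
    return o
end
```

Once defined, a RunningMax() could then be passed to the callback like any other estimator, e.g. inside a Series.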

Filter variables to log

Maybe you only want to log stats for certain variables, e.g. in the above example we might want to exclude m as well as the sampler statistics:

callback = TensorBoardCallback(
    "tensorboard_logs/run", stats;
    exclude = ["m", ], include_extras = false
)

Or you can create the filter (a function (varname, value) -> Bool) yourself:

var_filter(varname, value) = varname != "m"
callback = TensorBoardCallback(
    "tensorboard_logs/run", stats;
    filter = var_filter
)

Supporting TensorBoardCallback with your own sampler

It's also possible to make your own sampler compatible with TensorBoardCallback.

To do so, you need to implement the following method:

TuringCallbacks.params_and_values — Function
params_and_values(model, transition[, state]; kwargs...)
params_and_values(model, sampler, transition, state; kwargs...)

Return an iterator over parameter names and values from a transition.


If you don't have any particular names for your parameters, you're free to make use of the convenience method TuringCallbacks.default_param_names_for_values.

Note

The method params_and_values(model, sampler, transition, state; kwargs...) is not usually overloaded, but overloading it can sometimes be useful for defining more complex behaviors.

For example, if the transition for your MySampler is just a Vector{Float64}, a basic implementation of TuringCallbacks.params_and_values would just be

function TuringCallbacks.params_and_values(model, transition::Vector{Float64}; kwargs...)
    param_names = TuringCallbacks.default_param_names_for_values(transition)
    return zip(param_names, transition)
end

Or sometimes the user might pass the parameter names as a keyword argument, and so you might want to support that with something like

function TuringCallbacks.params_and_values(model, transition::Vector{Float64}; param_names = nothing, kwargs...)
    param_names = isnothing(param_names) ? TuringCallbacks.default_param_names_for_values(transition) : param_names
    return zip(param_names, transition)
end

Finally, if you want to log "extra" information in addition, e.g. some sampler statistics you're keeping track of, you also need to implement

TuringCallbacks.extras — Function
extras(model, transition[, state]; kwargs...)
extras(model, sampler, transition, state; kwargs...)

Return an iterator with elements of the form (name, value) for additional statistics in transition.

Default implementation returns an empty iterator.

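For example, for a hypothetical sampler whose transitions carry an acceptance rate and a step size, an implementation of extras might look like the following sketch (MyTransition and its fields are made up for illustration):

```julia
using TuringCallbacks

# Hypothetical transition type carrying sampler statistics.
struct MyTransition
    params::Vector{Float64}
    acceptance_rate::Float64
    step_size::Float64
end

# Sketch: expose the sampler statistics as `(name, value)` pairs.
function TuringCallbacks.extras(model, transition::MyTransition; kwargs...)
    return (
        ("acceptance_rate", transition.acceptance_rate),
        ("step_size", transition.step_size),
    )
end
```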

Types & Functions

TuringCallbacks.Skip — Type
mutable struct Skip{T, O<:OnlineStat{T}} <: OnlineStat{T}

Usage

Skip(b::Int, stat::OnlineStat)

Skips the first b observations before passing them on to stat.

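A quick sketch of Skip in use, assuming OnlineStats.jl is loaded alongside this package:

```julia
using OnlineStats, TuringCallbacks

# Skip the first 2 observations before fitting the Mean.
s = Skip(2, Mean())
fit!(s, [10.0, 10.0, 1.0, 3.0])
# Per the documented semantics, only 1.0 and 3.0 reach the Mean,
# so value(s) is their mean.
```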
TuringCallbacks.TensorBoardCallback — Type
struct TensorBoardCallback{L, F1, F2, F3}

Wraps a CoreLogging.AbstractLogger to construct a callback to be passed to AbstractMCMC.step.

Usage

TensorBoardCallback(; kwargs...)
TensorBoardCallback(directory::String[, stats]; kwargs...)
TensorBoardCallback(lg::AbstractLogger[, stats]; kwargs...)

Constructs an instance of a TensorBoardCallback, creating a TBLogger if directory is provided instead of lg.

Arguments

  • lg: an instance of an AbstractLogger which implements TuringCallbacks.increment_step!.
  • stats = nothing: an OnlineStat or a lookup from variable name to statistic estimator. If stats isa OnlineStat, we will create a DefaultDict which copies stats for unseen variable names. If nothing, then a DefaultDict whose default constructor returns an OnlineStats.Series estimator with Mean(), Variance(), and KHist(num_bins) will be used.

Keyword arguments

  • num_bins::Int = 100: Number of bins to use in the histograms.
  • filter = nothing: Filter determining whether or not we should log stats for a particular variable and value; expected signature is filter(varname, value). If isnothing a default-filter constructed from exclude and include will be used.
  • exclude = String[]: If non-empty, these variables will not be logged.
  • include = String[]: If non-empty, only these variables will be logged.
  • include_extras::Bool = true: Include extra statistics from transitions.
  • extras_include = String[]: If non-empty, only these extra statistics will be logged.
  • extras_exclude = String[]: If non-empty, these extra statistics will not be logged.
  • extras_filter = nothing: Filter determining whether or not we should log extra statistics; expected signature is filter(extra, value). If isnothing a default-filter constructed from extras_exclude and extras_include will be used.
  • include_hyperparams::Bool = true: Include hyperparameters.
  • hyperparam_include = String[]: If non-empty, only these hyperparameters will be logged.
  • hyperparam_exclude = String[]: If non-empty, these hyperparameters will not be logged.
  • hyperparam_filter = nothing: Filter determining whether or not we should log hyperparameters; expected signature is filter(hyperparam, value). If isnothing a default-filter constructed from hyperparam_exclude and hyperparam_include will be used.
  • directory::String = nothing: if specified, will together with comment be used to define the logging directory.
  • comment::String = nothing: if specified, will together with directory be used to define the logging directory.
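Putting the stats lookup and a couple of the keyword arguments together, a construction might look like the following sketch. The variable names "s" and "m" match the demo model from earlier; whether a plain Dict suffices as the lookup is an assumption based on the description of stats above:

```julia
using OnlineStats, TuringCallbacks

# Sketch: different estimators per variable name, sampler extras disabled.
stats = Dict(
    "s" => Series(Mean(), Variance()),
    "m" => Series(Mean(), KHist(50)),
)
callback = TensorBoardCallback(
    "tensorboard_logs/run", stats;
    include_extras = false,
)
```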

Fields

  • logger::Base.CoreLogging.AbstractLogger: Underlying logger.

  • stats::Any: Lookup for variable name to statistic estimate.

  • variable_filter::Any: Filter determining whether to include stats for a particular variable.

  • include_extras::Bool: Include extra statistics from transitions.

  • extras_filter::Any: Filter determining whether to include a particular extra statistic.

  • include_hyperparams::Bool: Include hyperparameters.

  • hyperparam_filter::Any: Filter determining whether to include a particular hyperparameter.

  • param_prefix::String: Prefix used for logging realizations/parameters

  • extras_prefix::String: Prefix used for logging extra statistics

TuringCallbacks.Thin — Type
mutable struct Thin{T, O<:OnlineStat{T}} <: OnlineStat{T}

Usage

Thin(b::Int, stat::OnlineStat)

Thins stat with an interval b, i.e. only passes every b-th observation to stat.

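A sketch of Thin in use, assuming OnlineStats.jl is loaded alongside this package:

```julia
using OnlineStats, TuringCallbacks

# Only every 2nd observation reaches the underlying Mean,
# so roughly half of the stream is used for the estimate.
t = Thin(2, Mean())
fit!(t, randn(100))
value(t)  # mean estimated from the retained observations
```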
TuringCallbacks.WindowStat — Type
struct WindowStat{T, O} <: OnlineStat{T}

Usage

WindowStat(b::Int, stat::O) where {O <: OnlineStat}

"Wraps" stat in a MovingWindow of length b.

value(o::WindowStat) will then return an OnlineStat of the same type as stat, which is only fitted on the batched data contained in the MovingWindow.

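A sketch of WindowStat, assuming OnlineStats.jl is loaded: only the most recent b observations influence the returned stat.

```julia
using OnlineStats, TuringCallbacks

# Mean over a moving window of the last 3 observations.
w = WindowStat(3, Mean())
fit!(w, [1.0, 2.0, 3.0, 4.0, 5.0])
# Per the docstring, value(w) returns a Mean fitted only on the
# window contents, i.e. 3.0, 4.0, and 5.0.
value(w)
```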

Internals

TuringCallbacks.extras — Method
extras(model, transition[, state]; kwargs...)
extras(model, sampler, transition, state; kwargs...)

Return an iterator with elements of the form (name, value) for additional statistics in transition.

Default implementation returns an empty iterator.

TuringCallbacks.hyperparams — Method
hyperparams(model, sampler[, transition, state]; kwargs...)

Return an iterator with elements of the form (name, value) for hyperparameters in model.

TuringCallbacks.params_and_values — Method
params_and_values(model, transition[, state]; kwargs...)
params_and_values(model, sampler, transition, state; kwargs...)

Return an iterator over parameter names and values from a transition.

TuringCallbacks.tb_name — Method
tb_name(args...)

Returns a string representing the name for arg or args in TensorBoard.

If length(args) > 1, args are joined together by "/".

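The documented joining behavior can be sketched in plain Julia (my_tb_name is a hypothetical stand-in, not the actual implementation):

```julia
# Hypothetical stand-in illustrating the documented behavior of `tb_name`:
# arguments are stringified and joined by "/".
my_tb_name(args...) = join(string.(args), "/")

my_tb_name("m")          # "m"
my_tb_name("m", "Mean")  # "m/Mean"
```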

Index