TuringCallbacks
Getting started
As the package is not yet officially released, it has to be added from the GitHub repository:
julia> ]
pkg> add https://github.com/TuringLang/TuringCallbacks.jl
Visualizing sampling on-the-fly
TensorBoardCallback is a wrapper around TensorBoardLogger.TBLogger which can be used to create a callback compatible with Turing.sample.
To actually visualize the results of the logging, you need to have tensorboard installed in Python. If you do not have tensorboard installed, it should be sufficient to just run
pip3 install tensorboard
Then you can start up TensorBoard:
python3 -m tensorboard.main --logdir tensorboard_logs/run
Now we're ready to actually write some Julia code.
The following snippet demonstrates the usage of TensorBoardCallback on a simple model. This will write a set of statistics at each iteration to an event file compatible with TensorBoard:
using Turing, TuringCallbacks
@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, √s)
    for i in eachindex(x)
        x[i] ~ Normal(m, √s)
    end
end
xs = randn(100) .+ 1;
model = demo(xs);
# Number of MCMC samples/steps
num_samples = 10_000
num_adapts = 100
# Sampling algorithm to use
alg = NUTS(num_adapts, 0.65)
# Create the callback
callback = TensorBoardCallback("tensorboard_logs/run")
# Sample
chain = sample(model, alg, num_samples; callback = callback)
While this is sampling, you can head right over to localhost:6006 in your web browser and you should see some plots!
In particular, note the "Distributions" tab in the above picture. Clicking this, you should see something similar to:
And finally, the "Histogram" tab shows a slightly more visually pleasing version of the marginal distributions:
Note that the names of the stats follow the naming pattern $variable_name/..., where $variable_name refers to the name of the variable in the model.
Choosing what and how you log
Statistics
In the above example we didn't provide any statistics explicitly, so the default statistics, e.g. Mean and Variance, were used. But using other statistics is easy! Here's a much more interesting example:
# Create the stats (estimators are sub-types of `OnlineStats.OnlineStat`)
stats = Skip(
    num_adapts, # Skip the adaptation steps
    Series(
        # Estimators using the entire chain
        Series(Mean(), Variance(), AutoCov(10), KHist(100)),
        # Estimators using the entire chain but only every 10th sample
        Thin(10, Series(Mean(), Variance(), AutoCov(10), KHist(100))),
        # Estimators using only the last 1000 samples
        WindowStat(1000, Series(Mean(), Variance(), AutoCov(10), KHist(100)))
    )
)
# Create the callback
callback = TensorBoardCallback("tensorboard_logs/run", stats)
# Sample
chain = sample(model, alg, num_samples; callback = callback)
Tada! Now you should be seeing waaaay more interesting statistics in your TensorBoard dashboard. See the OnlineStats.jl documentation for more on the different statistics, with the exception of Thin, Skip and WindowStat, which are implemented in this package.
Note that these statistic estimators are stateful, and therefore the following is bad:
julia> s = AutoCov(5)
AutoCov: n=0 | value=[NaN, NaN, NaN, NaN, NaN, NaN]
julia> stat = Series(s, s)
Series
├─ AutoCov: n=0 | value=[NaN, NaN, NaN, NaN, NaN, NaN]
└─ AutoCov: n=0 | value=[NaN, NaN, NaN, NaN, NaN, NaN]
julia> # => 10 samples but `n=20` since we've called `fit!` twice for each observation
       fit!(stat, randn(10))
Series
├─ AutoCov: n=20 | value=[0.912205, 0.260832, -0.390541, -0.132463, 0.125615, 0.0644281]
└─ AutoCov: n=20 | value=[0.912205, 0.260832, -0.390541, -0.132463, 0.125615, 0.0644281]
while the following is good:
julia> stat = Series(AutoCov(5), AutoCov(5))
Series
├─ AutoCov: n=0 | value=[NaN, NaN, NaN, NaN, NaN, NaN]
└─ AutoCov: n=0 | value=[NaN, NaN, NaN, NaN, NaN, NaN]
julia> # => 10 samples AND `n=10`; great!
       fit!(stat, randn(10))
Series
├─ AutoCov: n=10 | value=[1.25101, -0.242006, -0.487156, 0.115114, 0.102494, -0.28756]
└─ AutoCov: n=10 | value=[1.25101, -0.242006, -0.487156, 0.115114, 0.102494, -0.28756]
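The pitfall above comes down to object identity: a Series forwards each observation to every child, so listing the same object twice updates it twice. The sketch below reproduces this in plain Julia, without OnlineStats, using a hypothetical CountStat type; fit_all! is a simplified stand-in for how a Series feeds its children, not the actual OnlineStats code.

```julia
# Hypothetical minimal stateful "estimator": it only counts observations.
mutable struct CountStat
    n::Int
end
CountStat() = CountStat(0)

# Simplified stand-in for a Series: forward one observation to every child.
function fit_all!(children, x)
    for child in children
        child.n += 1
    end
    return children
end

# Bad: the same object appears twice, so it is updated twice per observation.
s = CountStat()
shared = (s, s)
for x in 1:10
    fit_all!(shared, x)
end

# Good: two independent objects, each updated once per observation.
separate = (CountStat(), CountStat())
for x in 1:10
    fit_all!(separate, x)
end

println(shared[1].n)    # 20
println(separate[1].n)  # 10
```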
At the moment, the only supported statistics are sub-types of OnlineStats.OnlineStat. If you want to log a custom statistic, you therefore have to make a sub-type and implement OnlineStats.fit! and OnlineStats.value. By default, an OnlineStat is passed to TensorBoard by simply calling OnlineStats.value(stat). Therefore, if you also want to customize how a stat is passed to TensorBoard, you need to overload TensorBoardLogger.preprocess(name, stat, data) accordingly.
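A minimal sketch of a custom statistic, assuming the OnlineStatsBase interface that OnlineStats builds on: there, the user-facing fit! forwards each observation to OnlineStatsBase._fit!, so a new type implements _fit! and, optionally, value, and keeps an n field that nobs reads. AbsMax is a hypothetical statistic tracking the largest absolute value seen so far. The block needs OnlineStatsBase installed and does not cover the TensorBoardLogger.preprocess customization.

```julia
using OnlineStatsBase

# Hypothetical custom statistic: the largest absolute value seen so far.
mutable struct AbsMax <: OnlineStatsBase.OnlineStat{Number}
    value::Float64  # current maximum of |x|
    n::Int          # number of observations, read by `nobs`
    AbsMax() = new(0.0, 0)
end

# `fit!` calls `_fit!` once per observation.
function OnlineStatsBase._fit!(o::AbsMax, y)
    o.n += 1
    o.value = max(o.value, abs(y))
    return o
end

# Optional here: the default `value` already returns the first field.
OnlineStatsBase.value(o::AbsMax) = o.value

o = OnlineStatsBase.fit!(AbsMax(), [1.0, -3.0, 2.0])
OnlineStatsBase.value(o)  # 3.0
OnlineStatsBase.nobs(o)   # 3
```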
Filter variables to log
Maybe you only want to log stats for certain variables, e.g. in the above example we might want to exclude m and the sampler statistics:
callback = TensorBoardCallback(
    "tensorboard_logs/run", stats;
    exclude = ["m"], include_extras = false
)
Or you can create the filter (a mapping variable_name -> ::Bool) yourself:
var_filter(varname) = varname != "m"
callback = TensorBoardCallback(
    "tensorboard_logs/run", stats;
    variable_filter = var_filter
)
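The default filter built from exclude and include can be sketched as a closure. make_filter is a hypothetical helper, not part of TuringCallbacks, and the way it combines the two lists (exclusion wins, an empty include admits everything) is an assumption based on the keyword descriptions under TensorBoardCallback.

```julia
# Hypothetical helper mirroring the documented `exclude`/`include` semantics.
function make_filter(; exclude::Vector{String} = String[], include::Vector{String} = String[])
    return function (varname::String)
        # Explicitly excluded variables are never logged.
        varname in exclude && return false
        # With an empty `include`, everything not excluded is logged.
        return isempty(include) || varname in include
    end
end

var_filter = make_filter(exclude = ["m"])
var_filter("s")  # true
var_filter("m")  # false
```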
Types & Functions
TuringCallbacks.Skip (Type)
mutable struct Skip{T, O<:OnlineStat{T}} <: OnlineStat{T}
Usage
Skip(b::Int, stat::OnlineStat)
Skips the first b observations before passing them on to stat.
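The skipping logic can be sketched in plain Julia. SkipSketch and fit_sketch! are hypothetical names; instead of wrapping an OnlineStat, the sketch forwards observations to an arbitrary function so that it runs without OnlineStats.

```julia
# Hypothetical stand-in for `Skip`: forwards observations to `sink`
# only after the first `b` have been seen.
mutable struct SkipSketch{F}
    b::Int      # number of leading observations to drop
    seen::Int   # observations seen so far
    sink::F     # called with each forwarded observation
end
SkipSketch(b::Int, sink) = SkipSketch(b, 0, sink)

function fit_sketch!(o::SkipSketch, x)
    o.seen += 1
    o.seen > o.b && o.sink(x)
    return o
end

forwarded = Int[]
o = SkipSketch(2, x -> push!(forwarded, x))
for x in 1:5
    fit_sketch!(o, x)
end
forwarded  # [3, 4, 5]
```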
TuringCallbacks.TensorBoardCallback (Type)
struct TensorBoardCallback{F, L}
Wraps a TensorBoardLogger.TBLogger to construct a callback to be passed to Turing.sample.
Usage
TensorBoardCallback(; kwargs...)
TensorBoardCallback(directory::String[, stats]; kwargs...)
TensorBoardCallback(lg::TBLogger[, stats]; kwargs...)
Constructs an instance of a TensorBoardCallback, creating a TBLogger if directory is provided instead of lg.
Arguments
stats = nothing: OnlineStat or lookup for variable name to statistic estimator. If stats isa OnlineStat, we will create a DefaultDict which copies stats for unseen variable names. If isnothing(stats), then a DefaultDict with a default constructor returning an OnlineStats.Series estimator with Mean(), Variance(), and KHist(num_bins) will be used.
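The "copy stats for unseen variable names" behaviour can be sketched with a plain Dict and Base's get!, which only calls its function argument when the key is missing. This stands in for the DefaultDict the package uses; TEMPLATE, LOOKUP and estimator_for are hypothetical names, a vector accumulator plays the role of a stateful estimator, and the use of deepcopy is an assumption. Giving each variable its own copy avoids the shared-state problem described under Statistics.

```julia
# Stand-in for a stateful estimator: an accumulator of observations.
const TEMPLATE = Float64[]

# Lookup from variable name to that variable's own estimator.
const LOOKUP = Dict{String, Vector{Float64}}()

# On first access, store a fresh deep copy of the template for `name`.
estimator_for(name::String) = get!(() -> deepcopy(TEMPLATE), LOOKUP, name)

push!(estimator_for("m"), 1.0)
push!(estimator_for("m"), 2.0)
push!(estimator_for("s"), 3.0)
LOOKUP["m"]  # [1.0, 2.0]
LOOKUP["s"]  # [3.0]
```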
Keyword arguments
num_bins::Int = 100: Number of bins to use in the histograms.
variable_filter = nothing: Filter determining whether or not we should log stats for a particular variable. If nothing, a default filter constructed from exclude and include will be used.
exclude = String[]: If non-empty, these variables will not be logged.
include = String[]: If non-empty, only these variables will be logged.
include_extras::Bool = true: Include extra statistics from transitions.
directory::String = nothing: If specified, will be used together with comment to define the logging directory.
comment::String = nothing: If specified, will be used together with directory to define the logging directory.
Fields
logger::TensorBoardLogger.TBLogger
Underlying logger.
variable_filter::Any
Filter determining whether or not we should log stats for a particular variable.
include_extras::Bool
Include extra statistics from transitions.
stats::Any
Lookup for variable name to statistic estimate.
TuringCallbacks.Thin (Type)
mutable struct Thin{T, O<:OnlineStat{T}} <: OnlineStat{T}
Usage
Thin(b::Int, stat::OnlineStat)
Thins stat with an interval b, i.e. only passes every b-th observation to stat.
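A plain-Julia sketch of thinning, with the hypothetical names ThinSketch and fit_sketch!. Which observations count as "every b-th" is an assumption here: the sketch forwards observations b, 2b, 3b, ..., whereas the package may use a different offset.

```julia
# Hypothetical stand-in for `Thin`: forwards every `b`-th observation to `sink`.
mutable struct ThinSketch{F}
    b::Int      # thinning interval
    seen::Int   # observations seen so far
    sink::F     # called with each forwarded observation
end
ThinSketch(b::Int, sink) = ThinSketch(b, 0, sink)

function fit_sketch!(o::ThinSketch, x)
    o.seen += 1
    o.seen % o.b == 0 && o.sink(x)  # forwards observations b, 2b, 3b, ...
    return o
end

forwarded = Int[]
o = ThinSketch(3, x -> push!(forwarded, x))
foreach(x -> fit_sketch!(o, x), 1:10)
forwarded  # [3, 6, 9]
```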
TuringCallbacks.WindowStat (Type)
struct WindowStat{T, O} <: OnlineStat{T}
Usage
WindowStat(b::Int, stat::O) where {O <: OnlineStat}
"Wraps" stat in a MovingWindow of length b.
value(o::WindowStat) will then return an OnlineStat of the same type as stat, which is only fitted on the batched data contained in the MovingWindow.
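A plain-Julia sketch of the windowing idea: keep only the last b observations and compute the statistic from scratch on them. The package builds on OnlineStats' MovingWindow; here a plain vector takes its place and the statistic is fixed to the mean. WindowSketch, fit_sketch! and window_mean are hypothetical names.

```julia
# Hypothetical stand-in for `WindowStat`: a buffer of the last `b` observations.
mutable struct WindowSketch
    b::Int                   # window length
    buffer::Vector{Float64}  # most recent observations, oldest first
end
WindowSketch(b::Int) = WindowSketch(b, Float64[])

function fit_sketch!(o::WindowSketch, x::Real)
    push!(o.buffer, x)
    # Drop the oldest observation once the window is over-full.
    length(o.buffer) > o.b && popfirst!(o.buffer)
    return o
end

# Recompute a statistic from scratch on the window only (here, the mean).
window_mean(o::WindowSketch) = sum(o.buffer) / length(o.buffer)

w = WindowSketch(3)
foreach(x -> fit_sketch!(w, x), 1:10)
window_mean(w)  # 9.0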
Internals
TuringCallbacks.tb_name (Method)
tb_name(args...)
Returns a string representing the name for arg or args in TensorBoard.
If length(args) > 1, args are joined together by "/".
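The naming rule can be sketched as two methods. tb_name_sketch is a hypothetical re-implementation, and the real tb_name may have additional methods for specific argument types.

```julia
# Single argument: its string representation.
tb_name_sketch(arg) = string(arg)

# Several arguments: name each one, then join with "/".
tb_name_sketch(args...) = join(map(tb_name_sketch, args), "/")

tb_name_sketch("m", "Mean")      # "m/Mean"
tb_name_sketch(:s, "Variance")   # "s/Variance"
```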