bamojax.modified_blackjax.modified_adaptive_tempered
build_kernel(logprior_fn, loglikelihood_fn, mcmc_step_fn, mcmc_init_fn, resampling_fn, target_ess, root_solver=solver.dichotomy, **extra_parameters)
Build a Tempered SMC step using an adaptive schedule.
Parameters
logprior_fn: Callable
A function that computes the log-prior density.
loglikelihood_fn: Callable
A function that returns the log-likelihood density.
mcmc_step_fn: Callable
The MCMC step function used to update the particles.
mcmc_init_fn: Callable
The MCMC init function used to build an MCMC state from a particle position.
resampling_fn: Callable
A random function that resamples the generated particles based on their weights.
target_ess: float
The target ESS for the adaptive tempering schedule.
root_solver: Callable, optional
A solver utility to find the delta matching the target ESS. Its signature is root_solver(fun, delta_0, min_delta, max_delta); the default is a dichotomy solver.
use_log_ess: bool, optional
Use the ESS in log space to solve for delta; the default is True. This is usually more stable when using gradient-based solvers.
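The adaptive schedule chooses the next temperature increment delta so that the ESS of the incremental weights exp(delta * loglik) matches the target. The plain-Python sketch below illustrates that search with a bisection ("dichotomy") solver following the documented root_solver(fun, delta_0, min_delta, max_delta) signature, solving in log-ESS space. It assumes, as in upstream BlackJAX, that target_ess is a fraction of the number of particles and that delta is searched in [0, 1 - current temperature]. The names `log_ess`, `dichotomy` and `solve_delta` are hypothetical and are not bamojax's own functions.

```python
import math


def log_ess(log_weights):
    """log ESS, where ESS = (sum w)^2 / sum w^2, from unnormalised log-weights."""
    m = max(log_weights)  # shift for numerical stability
    s1 = sum(math.exp(lw - m) for lw in log_weights)
    s2 = sum(math.exp(2.0 * (lw - m)) for lw in log_weights)
    return 2.0 * math.log(s1) - math.log(s2)


def dichotomy(fun, delta_0, min_delta, max_delta, tol=1e-8, max_iter=100):
    """Bisection for a root of a non-increasing `fun` on [min_delta, max_delta].

    delta_0 is accepted only to mirror the documented signature; bisection
    does not need an initial guess.
    """
    lo, hi = min_delta, max_delta
    if fun(hi) >= 0.0:  # even the largest step keeps the ESS above target
        return hi
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        if fun(mid) > 0.0:
            lo = mid
        else:
            hi = mid
        if hi - lo < tol:
            break
    return 0.5 * (lo + hi)


def solve_delta(logliks, current_temp, target_ess):
    """Find delta with ESS(delta * logliks) == target_ess * n (assumed convention)."""
    n = len(logliks)
    log_target = math.log(target_ess * n)

    def fun(delta):
        # ESS shrinks as delta grows, so fun goes from positive to negative.
        return log_ess([delta * ll for ll in logliks]) - log_target

    return dichotomy(fun, 0.0, 0.0, 1.0 - current_temp)
```

If all particles have the same log-likelihood, the weights stay uniform and the solver jumps straight to the remaining distance 1 - current_temp, which is why the sketch checks the upper end of the interval first.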
Returns
A callable that takes a rng_key and a TemperedSMCState that contains the current state of the chain and that returns a new state of the chain along with information about the transition.
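Schematically, one call of the returned step reweights the particles by the incremental likelihood weights, resamples them, and advances the temperature by delta before the MCMC mutation (omitted here). The plain-Python sketch below shows that bookkeeping. `TemperedState` and its fields `particles`, `weights` and `temperature` are hypothetical stand-ins for TemperedSMCState, and systematic resampling is used only as one possible choice of `resampling_fn`.

```python
import math
import random
from typing import List, NamedTuple


class TemperedState(NamedTuple):  # hypothetical stand-in for TemperedSMCState
    particles: List[float]
    weights: List[float]          # normalised importance weights (all > 0)
    temperature: float            # tempering parameter in [0, 1]


def normalise(log_w):
    """Turn unnormalised log-weights into weights that sum to 1."""
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]
    s = sum(w)
    return [wi / s for wi in w]


def systematic_resample(rng, weights, n):
    """Return n ancestor indices using a single uniform offset."""
    u0 = rng.random() / n
    cumsum, idx, out = weights[0], 0, []
    for i in range(n):
        u = u0 + i / n
        # Advance to the first index whose cumulative weight reaches u.
        while u > cumsum and idx < len(weights) - 1:
            idx += 1
            cumsum += weights[idx]
        out.append(idx)
    return out


def reweight_and_resample(rng, state, loglik_fn, delta):
    """Reweight by exp(delta * loglik), resample, and advance the temperature."""
    log_w = [math.log(w) + delta * loglik_fn(x)
             for x, w in zip(state.particles, state.weights)]
    weights = normalise(log_w)
    n = len(state.particles)
    idx = systematic_resample(rng, weights, n)
    new_particles = [state.particles[i] for i in idx]
    # After resampling, every particle carries equal weight again.
    return TemperedState(new_particles, [1.0 / n] * n, state.temperature + delta)
```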
Source code in bamojax/modified_blackjax/modified_adaptive_tempered.py
as_top_level_api(logprior_fn, loglikelihood_fn, mcmc_step_fn, mcmc_init_fn, mcmc_parameters, resampling_fn, target_ess, root_solver=solver.dichotomy, num_mcmc_steps=10, **extra_parameters)
Implements the (basic) user interface for the Adaptive Tempered SMC kernel.
Parameters
logprior_fn
The log-prior function of the model we wish to draw samples from.
loglikelihood_fn
The log-likelihood function of the model we wish to draw samples from.
mcmc_step_fn
The MCMC step function used to update the particles.
mcmc_init_fn
The MCMC init function used to build an MCMC state from a particle position.
mcmc_parameters
The parameters of the MCMC step function. Parameters with a leading dimension of length 1 are shared among the particles.
resampling_fn
The function used to resample the particles.
target_ess
The effective sample size to aim for at each step.
root_solver
The solver used to adaptively compute the temperature given a target effective sample size.
num_mcmc_steps
The number of times the MCMC kernel is applied to the particles per step.
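To illustrate `num_mcmc_steps` and the parameter-sharing rule for `mcmc_parameters`, the plain-Python sketch below mutates each particle with a Gaussian random-walk Metropolis kernel targeting the tempered density log p(x) + λ log L(x), applied `num_mcmc_steps` times per particle. The random-walk kernel, the `step_size` parameter and the helper names are assumptions chosen for illustration; in bamojax the kernel is whatever `mcmc_step_fn` you pass.

```python
import math
import random


def per_particle_params(params, n):
    """Broadcast parameters: a length-1 list is shared by all n particles."""
    return {k: (v * n if len(v) == 1 else v) for k, v in params.items()}


def rwmh_step(rng, x, logdensity, step_size):
    """One Gaussian random-walk Metropolis step; returns (new_x, accepted)."""
    proposal = x + step_size * rng.gauss(0.0, 1.0)
    log_alpha = logdensity(proposal) - logdensity(x)
    if rng.random() < math.exp(min(0.0, log_alpha)):
        return proposal, True
    return x, False


def mutate(rng, particles, logprior_fn, loglik_fn, temperature, params,
           num_mcmc_steps=10):
    """Apply the kernel num_mcmc_steps times to every particle."""
    n = len(particles)
    sizes = per_particle_params(params, n)["step_size"]

    def tempered(x):
        # Tempered target: prior times likelihood raised to the temperature.
        return logprior_fn(x) + temperature * loglik_fn(x)

    out, accepted = [], 0
    for x, h in zip(particles, sizes):
        for _ in range(num_mcmc_steps):
            x, acc = rwmh_step(rng, x, tempered, h)
            accepted += acc
        out.append(x)
    return out, accepted / (n * num_mcmc_steps)
```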
Returns
A SamplingAlgorithm.
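A SamplingAlgorithm bundles an `init` and a `step` function, and a tempered SMC algorithm is typically driven by calling `step` until the tempering parameter reaches 1. The plain-Python sketch below shows that loop with a hypothetical `SamplingAlgorithm`-like NamedTuple and a dummy step that raises the temperature by a fixed 0.3; the `temperature` field name and the `max_steps` guard are assumptions. Real JAX code would split a `jax.random` PRNG key on each iteration instead of reusing a seeded `random.Random`.

```python
import random
from typing import Any, Callable, NamedTuple, Tuple


class SamplingAlgorithm(NamedTuple):  # hypothetical stand-in
    init: Callable[[Any], Any]
    step: Callable[[random.Random, Any], Tuple[Any, Any]]


class State(NamedTuple):  # hypothetical tempered state
    particles: list
    temperature: float


def run_tempered_smc(rng, algo, initial_particles, max_steps=1000):
    """Call algo.step until the temperature reaches 1; return final state and infos."""
    state = algo.init(initial_particles)
    infos = []
    while state.temperature < 1.0 and len(infos) < max_steps:
        state, info = algo.step(rng, state)
        infos.append(info)
    return state, infos


# Usage with a dummy step in place of the adaptive tempered kernel.
def dummy_step(rng, state):
    new_temp = min(1.0, state.temperature + 0.3)
    return State(state.particles, new_temp), {"delta": new_temp - state.temperature}


algo = SamplingAlgorithm(init=lambda p: State(list(p), 0.0), step=dummy_step)
final, infos = run_tempered_smc(random.Random(0), algo, [0.1, 0.2])
```

With the real kernel the number of iterations is not fixed in advance: it depends on how large a temperature increment the root solver can take while keeping the ESS at the target.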
Source code in bamojax/modified_blackjax/modified_adaptive_tempered.py