Skip to content

bamojax.modified_blackjax.modified_adaptive_tempered

build_kernel(logprior_fn, loglikelihood_fn, mcmc_step_fn, mcmc_init_fn, resampling_fn, target_ess, root_solver=solver.dichotomy, **extra_parameters)

Build a Tempered SMC step using an adaptive schedule.

Parameters

- logprior_fn: Callable — A function that computes the log-prior density.
- loglikelihood_fn: Callable — A function that returns the log-likelihood density.
- mcmc_step_fn: Callable — The MCMC step function used to update the particles.
- mcmc_init_fn: Callable — The MCMC init function used to build an MCMC state from a particle position.
- resampling_fn: Callable — A random function that resamples generated particles based on their weights.
- target_ess: float — The target ESS for the adaptive MCMC tempering.
- root_solver: Callable, optional — A solver utility to find the delta that matches the target ESS. Its signature is root_solver(fun, delta_0, min_delta, max_delta); the default is a dichotomy solver.
- \*\*extra_parameters — Additional keyword arguments, forwarded unchanged to the underlying tempered SMC kernel builder.

Note: this function has no use_log_ess parameter. A use_log_ess keyword would not affect how the tempering increment is solved for; it would only be forwarded via \*\*extra_parameters.

Returns

A callable that takes a rng_key and a TemperedSMCState that contains the current state of the chain and that returns a new state of the chain along with information about the transition.

Source code in bamojax/modified_blackjax/modified_adaptive_tempered.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def build_kernel(
    logprior_fn: Callable,
    loglikelihood_fn: Callable,
    mcmc_step_fn: Callable,
    mcmc_init_fn: Callable,
    resampling_fn: Callable,
    target_ess: float,
    root_solver: Callable = solver.dichotomy,
    **extra_parameters,
) -> Callable:
    r"""Build a Tempered SMC step using an adaptive schedule.

    At every step the tempering increment ``delta`` is obtained from
    ``ess.ess_solver`` so that the effective sample size matches ``target_ess``.
    The increment is clipped to ``[0, 1 - state.lmbda]``, and the resulting
    tempering parameter ``state.lmbda + delta`` is handed to the
    fixed-schedule kernel produced by ``tempered.build_kernel``.

    Parameters
    ----------
    logprior_fn: Callable
        A function that computes the log-prior density.
    loglikelihood_fn: Callable
        A function that returns the log-likelihood density. It is vectorized
        with ``jax.vmap`` before being handed, together with the particles,
        to ``ess.ess_solver``.
    mcmc_step_fn: Callable
        The MCMC step function used to move the particles; forwarded to
        ``tempered.build_kernel``.
    mcmc_init_fn: Callable
        The MCMC init function used to build an MCMC state from a particle
        position; forwarded to ``tempered.build_kernel``.
    resampling_fn: Callable
        A random function that resamples generated particles based on their
        weights; forwarded to ``tempered.build_kernel``.
    target_ess: float
        The target ESS for the adaptive MCMC tempering, passed to
        ``ess.ess_solver`` (see that helper for whether it is interpreted as
        an absolute count or as a fraction of the number of particles).
    root_solver: Callable, optional
        A solver utility to find delta matching the target ESS. Signature is
        `root_solver(fun, delta_0, min_delta, max_delta)`, default is a
        dichotomy solver.
    **extra_parameters
        Additional keyword arguments forwarded unchanged to
        ``tempered.build_kernel``.

    Returns
    -------
    A callable ``kernel(rng_key, state, num_mcmc_steps, mcmc_parameters)``.
    It takes a PRNG key and a ``TemperedSMCState`` holding the current state
    of the chain, and returns a tuple ``(new_state, info)`` with the new state
    of the chain and information about the transition.

    """

    def compute_delta(state: tempered.TemperedSMCState) -> float:
        """Return the tempering increment for ``state``, clipped to [0, 1 - lmbda]."""
        lmbda = state.lmbda
        # Largest admissible step: the tempering parameter must not exceed 1.
        max_delta = 1 - lmbda
        delta = ess.ess_solver(
            jax.vmap(loglikelihood_fn),
            state.particles,
            target_ess,
            max_delta,
            root_solver,
        )
        # Keep the solver output inside the admissible interval.
        delta = jnp.clip(delta, 0.0, max_delta)

        return delta

    # Fixed-schedule tempered kernel; the adaptive logic only chooses its lmbda.
    tempered_kernel = tempered.build_kernel(
        logprior_fn,
        loglikelihood_fn,
        mcmc_step_fn,
        mcmc_init_fn,
        resampling_fn,
        **extra_parameters,
    )

    def kernel(
        rng_key: PRNGKey,
        state: tempered.TemperedSMCState,
        num_mcmc_steps: int,
        mcmc_parameters: dict,
    ) -> tuple[tempered.TemperedSMCState, base.SMCInfo]:
        """Advance ``state`` by one adaptively-tempered SMC step."""
        delta = compute_delta(state)
        lmbda = delta + state.lmbda
        return tempered_kernel(rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters)

    return kernel

as_top_level_api(logprior_fn, loglikelihood_fn, mcmc_step_fn, mcmc_init_fn, mcmc_parameters, resampling_fn, target_ess, root_solver=solver.dichotomy, num_mcmc_steps=10, **extra_parameters)

Implements the (basic) user interface for the Adaptive Tempered SMC kernel.

Parameters

- logprior_fn — The log-prior function of the model we wish to draw samples from.
- loglikelihood_fn — The log-likelihood function of the model we wish to draw samples from.
- mcmc_step_fn — The MCMC step function used to update the particles.
- mcmc_init_fn — The MCMC init function used to build an MCMC state from a particle position.
- mcmc_parameters — The parameters of the MCMC step function. Parameters with a leading dimension of length 1 are shared among the particles.
- resampling_fn — The function used to resample the particles.
- target_ess — The target effective sample size to aim for at each step.
- root_solver — The solver used to adaptively compute the temperature for a given target effective sample size.
- num_mcmc_steps — The number of times the MCMC kernel is applied to the particles at each step.

Returns

A SamplingAlgorithm.

Source code in bamojax/modified_blackjax/modified_adaptive_tempered.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def as_top_level_api(
    logprior_fn: Callable,
    loglikelihood_fn: Callable,
    mcmc_step_fn: Callable,
    mcmc_init_fn: Callable,
    mcmc_parameters: dict,
    resampling_fn: Callable,
    target_ess: float,
    root_solver: Callable = solver.dichotomy,
    num_mcmc_steps: int = 10,
    **extra_parameters,
) -> SamplingAlgorithm:
    """Implements the (basic) user interface for the Adaptive Tempered SMC kernel.

    The adaptive kernel is built once here. The returned step function then
    calls it with the ``num_mcmc_steps`` and ``mcmc_parameters`` captured at
    construction time.

    Parameters
    ----------
    logprior_fn
        Log-prior function of the model to sample from.
    loglikelihood_fn
        Log-likelihood function of the model to sample from.
    mcmc_step_fn
        MCMC step function that moves the particles.
    mcmc_init_fn
        MCMC init function that builds an MCMC state from a particle position.
    mcmc_parameters
        Parameters of the MCMC step function. Parameters whose leading
        dimension has length 1 are shared among the particles.
    resampling_fn
        Function used to resample the particles.
    target_ess
        Target effective sample size used to choose each tempering increment.
    root_solver
        Solver used to find the tempering increment that meets ``target_ess``.
    num_mcmc_steps
        How many times the MCMC kernel is applied to the particles per step.
    **extra_parameters
        Extra keyword arguments forwarded to ``build_kernel``.

    Returns
    -------
    A ``SamplingAlgorithm`` built from an init function and a step function.

    """
    adaptive_kernel = build_kernel(
        logprior_fn=logprior_fn,
        loglikelihood_fn=loglikelihood_fn,
        mcmc_step_fn=mcmc_step_fn,
        mcmc_init_fn=mcmc_init_fn,
        resampling_fn=resampling_fn,
        target_ess=target_ess,
        root_solver=root_solver,
        **extra_parameters,
    )

    def init_fn(position: ArrayLikeTree, rng_key=None):
        # The key is accepted but never used: `init` is called with the
        # positions only.
        return init(position)

    def step_fn(rng_key: PRNGKey, state):
        # The MCMC step count and parameters were fixed when the algorithm was built.
        return adaptive_kernel(rng_key, state, num_mcmc_steps, mcmc_parameters)

    return SamplingAlgorithm(init_fn, step_fn)