Skip to content

bamojax.modified_blackjax.modified_tempered

TemperedSMCState

Bases: NamedTuple

Current state for the tempered SMC algorithm.

particles: PyTree

The particles' positions.

weights: Array The importance weights of the particles.

lmbda: float Current value of the tempering parameter.

Source code in bamojax/modified_blackjax/modified_tempered.py
28
29
30
31
32
33
34
35
36
37
38
39
40
class TemperedSMCState(NamedTuple):
    """Current state for the tempered SMC algorithm.

    particles: PyTree
        The particles' positions.
    weights: Array
        The importance weights of the particles; the leading dimension is the
        number of particles.
    lmbda: float
        Current value of the tempering parameter.

    """

    particles: ArrayTree
    weights: Array
    lmbda: float

update_and_take_last(mcmc_init_fn, tempered_logposterior_fn, shared_mcmc_step_fn, num_mcmc_steps, n_particles)

Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and returns the last values, wasting the previous num_mcmc_steps - 1 samples per chain.

Source code in bamojax/modified_blackjax/modified_tempered.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def update_and_take_last(
    mcmc_init_fn,
    tempered_logposterior_fn,
    shared_mcmc_step_fn,
    num_mcmc_steps,
    n_particles,
):
    """Build a particle update that keeps only the final state of each chain.

    Every particle seeds its own MCMC chain, which is advanced
    ``num_mcmc_steps`` times. Only the final position is kept; the preceding
    ``num_mcmc_steps - 1`` states of each chain are discarded.

    Returns
    -------
    A pair ``(update_fn, n_particles)``. ``update_fn`` is the vmapped per-chain
    kernel with signature ``(rng_keys, positions, step_parameters)``; every
    argument is batched along its leading axis. It returns the final positions
    and the per-step info stacked along the step axis.
    """

    def run_single_chain(chain_key, start_position, per_chain_params):
        def advance(current, step_key):
            # Shared parameters are already bound into shared_mcmc_step_fn;
            # only the per-particle ones are forwarded here.
            advanced, step_info = shared_mcmc_step_fn(
                step_key, current, tempered_logposterior_fn, **per_chain_params
            )
            return advanced, step_info

        initial = mcmc_init_fn(start_position, tempered_logposterior_fn)
        step_keys = jax.random.split(chain_key, num_mcmc_steps)
        final, infos = jax.lax.scan(advance, initial, step_keys)
        return final.position, infos

    batched_update = jax.vmap(run_single_chain)
    return batched_update, n_particles

build_kernel(logprior_fn, loglikelihood_fn, mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy=update_and_take_last)

Build the base Tempered SMC kernel.

Tempered SMC uses tempering to sample from a distribution given by

$$ p(x) \propto p_0(x) \exp(-V(x)) $$

where $p_0$ is the prior distribution, typically easy to sample from and for which the density is easy to compute, and $\exp(-V(x))$ is an unnormalized likelihood term for which $V(x)$ is easy to compute pointwise.

Parameters

logprior_fn A function that computes the log density of the prior distribution. loglikelihood_fn A function that returns the log-likelihood at a given position. mcmc_step_fn A function that creates an MCMC kernel from a log-probability density function. mcmc_init_fn: Callable A function that creates a new MCMC state from a position and a log-probability density function. resampling_fn A random function that resamples the generated particles based on their weights. update_strategy A function that builds the particle update function; defaults to update_and_take_last.

Returns

A callable that takes a rng_key and a TemperedSMCState that contains the current state of the chain and that returns a new state of the chain along with information about the transition.

Source code in bamojax/modified_blackjax/modified_tempered.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def build_kernel(
    logprior_fn: Callable,
    loglikelihood_fn: Callable,
    mcmc_step_fn: Callable,
    mcmc_init_fn: Callable,
    resampling_fn: Callable,
    update_strategy: Callable = update_and_take_last,
) -> Callable:
    """Build the base Tempered SMC kernel.

    Tempered SMC uses tempering to sample from a distribution given by

    .. math::
        p(x) \\propto p_0(x) \\exp(-V(x))

    where :math:`p_0` is the prior distribution, typically easy to sample from
    and for which the density is easy to compute, and :math:`\\exp(-V(x))` is an
    unnormalized likelihood term for which :math:`V(x)` is easy to compute
    pointwise.

    Parameters
    ----------
    logprior_fn
        A function that computes the log density of the prior distribution.
    loglikelihood_fn
        A function that computes the log-likelihood at a given position.
    mcmc_step_fn
        The MCMC step function used to move the particles. It is called as
        ``mcmc_step_fn(rng_key, state, logdensity_fn, **params)``.
    mcmc_init_fn: Callable
        A function that creates a new MCMC state from a position and a
        log-probability density function.
    resampling_fn
        A random function that resamples the particles based on their weights.
    update_strategy
        A function that builds the per-particle update function together with
        the number of particles to resample; defaults to
        ``update_and_take_last``.

    Returns
    -------
    A callable that takes a rng_key and a TemperedSMCState that contains the current state
    of the chain and that returns a new state of the chain along with
    information about the transition.

    """

    def kernel(
        rng_key: PRNGKey,
        state: TemperedSMCState,
        num_mcmc_steps: int,
        lmbda: float,
        mcmc_parameters: dict,
    ) -> tuple[TemperedSMCState, smc.base.SMCInfo]:
        """Move the particles one step using the Tempered SMC algorithm.

        Parameters
        ----------
        rng_key
            JAX PRNGKey for randomness
        state
            Current state of the tempered SMC algorithm
        num_mcmc_steps
            Number of MCMC steps applied to each particle.
        lmbda
            Value of the tempering parameter to move to in this step.
        mcmc_parameters
            The parameters of the MCMC step function. Parameters with leading
            dimension length of 1 (or scalars) are shared amongst the
            particles. This dict is not modified.

        Returns
        -------
        state
            The new state of the tempered SMC algorithm
        info
            Additional information on the SMC step

        """
        delta = lmbda - state.lmbda

        # [MODIFICATION]
        # Expose the current temperature to the MCMC step function as a shared
        # parameter (leading dimension 1). Build a new dict instead of writing
        # into ``mcmc_parameters``. The caller's dict is typically captured in
        # a closure and reused on every step, and under ``jax.jit`` an in-place
        # write would leave a traced value behind in it.
        mcmc_parameters = {
            **mcmc_parameters,
            'temperature': state.lmbda * jnp.eye(1),
        }
        # [MODIFICATION]

        shared_mcmc_parameters = {}
        unshared_mcmc_parameters = {}
        for name, value in mcmc_parameters.items():
            if jnp.ndim(value) == 0:
                # A scalar has no particle axis, so it is necessarily shared.
                shared_mcmc_parameters[name] = value
            elif value.shape[0] == 1:
                # Leading dimension 1: shared, strip the singleton axis.
                shared_mcmc_parameters[name] = value[0, ...]
            else:
                # One value per particle; vmapped by the update strategy.
                unshared_mcmc_parameters[name] = value

        def log_weights_fn(position: ArrayLikeTree) -> float:
            # Incremental importance weight for moving from state.lmbda to lmbda.
            return delta * loglikelihood_fn(position)

        def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
            logprior = logprior_fn(position)
            tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
            return logprior + tempered_loglikelihood

        shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

        update_fn, num_resampled = update_strategy(
            mcmc_init_fn,
            tempered_logposterior_fn,
            shared_mcmc_step_fn,
            n_particles=state.weights.shape[0],
            num_mcmc_steps=num_mcmc_steps,
        )

        smc_state, info = smc.base.step(
            rng_key,
            SMCState(state.particles, state.weights, unshared_mcmc_parameters),
            update_fn,
            jax.vmap(log_weights_fn),
            resampling_fn,
            num_resampled,
        )

        tempered_state = TemperedSMCState(
            smc_state.particles, smc_state.weights, state.lmbda + delta
        )

        return tempered_state, info

    return kernel

as_top_level_api(logprior_fn, loglikelihood_fn, mcmc_step_fn, mcmc_init_fn, mcmc_parameters, resampling_fn, num_mcmc_steps=10, update_strategy=update_and_take_last)

Implements the (basic) user interface for the Tempered SMC kernel. The caller supplies the tempering parameter for each step; it is not chosen adaptively.

Parameters

logprior_fn The log-prior function of the model we wish to draw samples from. loglikelihood_fn The log-likelihood function of the model we wish to draw samples from. mcmc_step_fn The MCMC step function used to update the particles. mcmc_init_fn The MCMC init function used to build a MCMC state from a particle position. mcmc_parameters The parameters of the MCMC step function. Parameters with leading dimension length of 1 are shared amongst the particles. resampling_fn The function used to resample the particles. num_mcmc_steps The number of times the MCMC kernel is applied to the particles per step. update_strategy The function that builds the particle update function; defaults to update_and_take_last.

Returns

A SamplingAlgorithm.

Source code in bamojax/modified_blackjax/modified_tempered.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
def as_top_level_api(
    logprior_fn: Callable,
    loglikelihood_fn: Callable,
    mcmc_step_fn: Callable,
    mcmc_init_fn: Callable,
    mcmc_parameters: dict,
    resampling_fn: Callable,
    num_mcmc_steps: Optional[int] = 10,
    update_strategy=update_and_take_last,
) -> SamplingAlgorithm:
    """Implements the (basic) user interface for the Tempered SMC kernel.

    The tempering schedule is not adapted here. Each call to the returned
    ``step`` function receives the next value of the tempering parameter
    explicitly.

    Parameters
    ----------
    logprior_fn
        The log-prior function of the model we wish to draw samples from.
    loglikelihood_fn
        The log-likelihood function of the model we wish to draw samples from.
    mcmc_step_fn
        The MCMC step function used to update the particles.
    mcmc_init_fn
        The MCMC init function used to build a MCMC state from a particle position.
    mcmc_parameters
        The parameters of the MCMC step function.  Parameters with leading dimension
        length of 1 are shared amongst the particles.
    resampling_fn
        The function used to resample the particles.
    num_mcmc_steps
        The number of times the MCMC kernel is applied to the particles per step.
    update_strategy
        The function that builds the per-particle update function; defaults
        to ``update_and_take_last``.

    Returns
    -------
    A ``SamplingAlgorithm``.

    """

    kernel = build_kernel(
        logprior_fn,
        loglikelihood_fn,
        mcmc_step_fn,
        mcmc_init_fn,
        resampling_fn,
        update_strategy,
    )

    def init_fn(position: ArrayLikeTree, rng_key=None):
        # The rng_key is accepted for API compatibility but unused.
        del rng_key
        return init(position)

    def step_fn(rng_key: PRNGKey, state, lmbda):
        """Advance ``state`` to tempering parameter ``lmbda``."""
        # Hand the kernel a shallow copy of the captured parameter dict.
        # The kernel may add entries to it (e.g. a 'temperature' entry), and
        # those writes must not leak into the user's ``mcmc_parameters``,
        # which is reused on every step.
        return kernel(
            rng_key,
            state,
            num_mcmc_steps,
            lmbda,
            dict(mcmc_parameters),
        )

    return SamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]