Skip to content

bamojax.modified_blackjax.modified_elliptical_slice_nd

Public API for the Elliptical Slice sampling Kernel

EllipSliceState

Bases: NamedTuple

State of the Elliptical Slice sampling algorithm.

position: Current position of the chain. logdensity: Current value of the log-density (evaluated at the current position).

Source code in bamojax/modified_blackjax/modified_elliptical_slice_nd.py
33
34
35
36
37
38
39
40
41
42
43
44
class EllipSliceState(NamedTuple):
    """Current state of an Elliptical Slice sampling chain.

    Attributes
    ----------
    position
        Where the chain currently is.
    logdensity
        The log-density value computed at ``position``.

    """

    # Where the chain currently is.
    position: ArrayTree
    # Log-density evaluated at `position`.
    logdensity: ArrayTree

EllipSliceInfo

Bases: NamedTuple

Additional information on the Elliptical Slice sampling chain.

This additional information can be used for debugging or computing diagnostics.

momentum: The latent momentum variable returned at the end of the transition. theta: A value in $[-2\pi, 2\pi]$ identifying the point on the ellipse drawn from the position and momentum variables; it is the theta value of the accepted proposal. subiter: Number of sub-iterations needed to accept a proposal. The more sub-iterations are needed, the less efficient the algorithm is, and the more strongly the new value is likely to depend on the previous value.

Source code in bamojax/modified_blackjax/modified_elliptical_slice_nd.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class EllipSliceInfo(NamedTuple):
    r"""Per-transition diagnostics of the Elliptical Slice sampling chain.

    These values are not needed to continue the chain; they are returned
    for debugging and for computing diagnostics.

    Attributes
    ----------
    momentum
        The latent momentum variable at the end of the transition.
    theta
        Angle in :math:`[-2\pi, 2\pi]` locating the accepted proposal on the
        ellipse spanned by the position and momentum variables.
    subiter
        Number of sub-iterations the slice loop needed before a proposal was
        accepted. More sub-iterations mean a less efficient step, and a new
        value that is likely to depend more strongly on the previous one.

    """

    # Latent momentum at the end of the transition.
    momentum: ArrayTree
    # Angle of the accepted point on the ellipse.
    theta: float
    # Number of slice-loop iterations used.
    subiter: int

build_kernel(cov_matrix, mean, nd=None)

Build an Elliptical Slice sampling kernel (Murray et al., 2010; citation key `murray2010elliptical`).

Parameters

cov_matrix: The covariance matrix of the Gaussian prior distribution of the posterior we wish to sample.

Returns

A kernel that takes a rng_key and a Pytree that contains the current state of the chain and that returns a new state of the chain along with information about the transition.

Source code in bamojax/modified_blackjax/modified_elliptical_slice_nd.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def build_kernel(cov_matrix: Array, mean: Array, nd: tuple = None):
    """Build an Elliptical Slice sampling kernel :cite:p:`murray2010elliptical`.

    This variant treats the sampled variable as ``prod(nd)`` stacked copies of
    a length-``n`` vector (``n = cov_matrix.shape[0]``) sharing one Gaussian
    prior. Internally the position is flattened to a vector of length
    ``prod(nd) * n``. The position returned by the kernel is reshaped to
    ``nd + (n,)``.

    Parameters
    ----------
    cov_matrix
        The covariance matrix of the Gaussian prior distribution of the
        posterior we wish to sample. It is either a 2-D ``(n, n)`` matrix, or a
        1-D vector interpreted as the diagonal of the covariance.
    mean
        The prior mean shared by every copy. It is tiled ``prod(nd)`` times to
        line up with the flattened position, so it is expected to have length
        ``n``.
    nd
        Leading (batch) shape of the sampled variable: ``(nu,)`` or
        ``(d, nu)``. ``None`` (the default) or ``()`` means a single copy,
        i.e. the position has shape ``(n,)``.

    Returns
    -------
    A kernel that takes a rng_key, the current state of the chain and the
    log-likelihood function, and returns a new state of the chain along with
    information about the transition.

    Raises
    ------
    NotImplementedError
        If ``nd`` has more than two entries.
    ValueError
        If ``cov_matrix`` is neither 1-D nor 2-D.

    """
    # The default nd=None used to crash on len(None); treat it (and ()) as a
    # single, un-batched copy of the prior.
    nd = () if nd is None else tuple(nd)

    if len(nd) == 0:
        d, nu = 1, 1
    elif len(nd) == 1:
        d, nu = 1, nd[0]
    elif len(nd) == 2:
        d, nu = nd
    else:
        raise NotImplementedError(f'Elliptical slice sampling is not implemented for nd = {nd}')

    ndim = jnp.ndim(cov_matrix)  # type: ignore[arg-type]
    if ndim == 1:  # diagonal covariance matrix, stored as a vector
        cov_matrix_sqrt = jnp.sqrt(cov_matrix)
    elif ndim == 2:
        cov_matrix_sqrt = jax.lax.linalg.cholesky(cov_matrix)
    else:
        raise ValueError(
            "The mass matrix has the wrong number of dimensions:"
            f" expected 1 or 2, got {jnp.ndim(cov_matrix)}."  # type: ignore[arg-type]
        )

    # Read n only after validating ndim, so that a scalar covariance reports
    # the ValueError above instead of an IndexError.
    n = cov_matrix.shape[0]
    flat_shape = (d * nu * n,)
    nd_shape = nd + (n,)

    def _reshape(tree, shape):
        """Reshape every leaf of ``tree`` to ``shape``."""
        return jax.tree_util.tree_map(lambda leaf: jnp.reshape(leaf, shape=shape), tree)

    def to_flat(u):
        return _reshape(u, flat_shape)

    def to_nd(u):
        return _reshape(u, nd_shape)

    def momentum_generator(rng_key, position):
        """Draw one prior sample per copy and return them flattened.

        ``position`` is the flattened position. It is passed to
        ``generate_gaussian_noise`` (presumably only as a shape/dtype template).
        """
        def generate_noise_fn(k, p):
            return generate_gaussian_noise(k, p, mean, cov_matrix_sqrt)

        # [MODIFICATIONS] vmap the noise generation over the dimensions nd
        # (one key per copy), then reshape into a flattened array.
        # Each branch reshapes the position to exactly the axes it vmaps over.
        # Previously nd=(1, nu) with nu > 1 vmapped nu keys against a leading
        # axis of size 1 and failed.
        if d == 1 and nu == 1:
            z = generate_noise_fn(rng_key, position)
        elif d > 1 and nu == 1:
            keys_d = jax.random.split(rng_key, d)
            u = _reshape(position, (d, n))
            z = jax.vmap(generate_noise_fn, in_axes=(0, 0))(keys_d, u)
        elif d == 1 and nu > 1:
            keys_nu = jax.random.split(rng_key, nu)
            u = _reshape(position, (nu, n))
            z = jax.vmap(generate_noise_fn, in_axes=(0, 0))(keys_nu, u)
        elif d > 1 and nu > 1:
            keys_d = jax.random.split(rng_key, d)
            keys_dnu = jax.vmap(lambda k: jax.random.split(k, nu))(keys_d)
            u = _reshape(position, (d, nu, n))
            z = jax.vmap(jax.vmap(generate_noise_fn, in_axes=(0, 0)), in_axes=(0, 0))(keys_dnu, u)
        else:
            raise NotImplementedError
        return to_flat(z)

    def elliptical_proposal(
        logdensity_fn: Callable,
        momentum_generator: Callable,
        mean: Array,
    ) -> Callable:
        """Build an Elliptical slice sampling kernel.

        The algorithm samples a latent parameter, traces an ellipse connecting the
        initial position and the latent parameter and does slice sampling on this
        ellipse to output a new sample from the posterior distribution.

        Parameters
        ----------
        logdensity_fn
            A function that returns the log-likelihood at a given position.
        momentum_generator
            A function that generates a new latent momentum variable.
        mean
            The prior mean of a single copy. It is tiled to match the
            flattened position.

        Returns
        -------
        A kernel that takes a rng_key and a Pytree that contains the current state
        of the chain and that returns a new state of the chain along with
        information about the transition.

        """
        # Tile the per-copy mean so it lines up with the row-major flattening
        # of an ``nd + (n,)`` position.
        mean_flat = jnp.tile(mean, reps=d * nu)

        def generate(
            rng_key: PRNGKey, state: EllipSliceState
        ) -> tuple[EllipSliceState, EllipSliceInfo]:
            position, logdensity = state
            position = to_flat(position)
            key_slice, key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 4)
            # step 1: sample momentum
            momentum = momentum_generator(key_momentum, position)
            # step 2: get slice (y)
            logy = logdensity + jnp.log(jax.random.uniform(key_uniform))
            # step 3: get theta (ellipse move), set initial interval
            theta = 2 * jnp.pi * jax.random.uniform(key_theta)
            theta_min = theta - 2 * jnp.pi
            theta_max = theta
            # step 4: proposal
            p, m = ellipsis(position, momentum, theta, mean_flat)
            # step 5: acceptance
            # NOTE(review): logdensity_fn is evaluated on the *flattened*
            # proposal here, whereas the state's logdensity was presumably
            # computed (by init) on the un-flattened position. Confirm that
            # logdensity_fn accepts both layouts.
            logdensity = logdensity_fn(p)

            def slice_fn(vals):
                """Perform slice sampling around the ellipse.

                Checks if the proposed position's likelihood is larger than the slice
                variable. Returns the position if True, shrinks the bracket for sampling
                `theta` and samples a new proposal if False.

                As the bracket `[theta_min, theta_max]` shrinks, the proposal gets closer
                to the original position, which has likelihood larger than the slice variable.
                It is guaranteed to stop in a finite number of iterations as long as the
                likelihood is continuous with respect to the parameter being sampled.

                """
                _, subiter, theta, theta_min, theta_max, *_ = vals
                thetak = jax.random.fold_in(key_slice, subiter)
                theta = jax.random.uniform(thetak, minval=theta_min, maxval=theta_max)
                p, m = ellipsis(position, momentum, theta, mean_flat)
                logdensity = logdensity_fn(p)
                theta_min = jnp.where(theta < 0, theta, theta_min)
                theta_max = jnp.where(theta > 0, theta, theta_max)
                subiter += 1
                return logdensity, subiter, theta, theta_min, theta_max, p, m

            logdensity, subiter, theta, *_, position, momentum = jax.lax.while_loop(
                lambda vals: vals[0] <= logy,
                slice_fn,
                (logdensity, 1, theta, theta_min, theta_max, p, m),
            )
            # Hand the position back in its batched ``nd + (n,)`` layout.
            position = to_nd(position)
            return (
                EllipSliceState(position, logdensity),
                EllipSliceInfo(momentum, theta, subiter),
            )

        return generate

    def kernel(
        rng_key: PRNGKey,
        state: EllipSliceState,
        logdensity_fn: Callable,
    ) -> tuple[EllipSliceState, EllipSliceInfo]:
        proposal_generator = elliptical_proposal(
            logdensity_fn, momentum_generator, mean,
        )
        return proposal_generator(rng_key, state)

    return kernel

as_top_level_api(loglikelihood_fn, *, mean, cov, nd)

Implements the (basic) user interface for the Elliptical Slice sampling kernel.

Examples

A new Elliptical Slice sampling kernel can be initialized and used with the following code:

.. code::

ellip_slice = as_top_level_api(loglikelihood_fn, mean=mean, cov=cov_matrix, nd=nd)
state = ellip_slice.init(position)
new_state, info = ellip_slice.step(rng_key, state)

We can JIT-compile the step function for better performance

.. code::

step = jax.jit(ellip_slice.step)
new_state, info = step(rng_key, state)

Parameters

loglikelihood_fn: Only the log-likelihood function of the posterior distribution we wish to sample. cov: The covariance matrix of the Gaussian prior distribution of the posterior we wish to sample.

Returns

A SamplingAlgorithm.

Source code in bamojax/modified_blackjax/modified_elliptical_slice_nd.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
def as_top_level_api(
    loglikelihood_fn: Callable,
    *,
    mean: Array,
    cov: Array,
    nd: tuple,
) -> SamplingAlgorithm:
    """Implements the (basic) user interface for the Elliptical Slice sampling kernel.

    Examples
    --------

    A new Elliptical Slice sampling kernel can be initialized and used with the following code:

    .. code::

        ellip_slice = as_top_level_api(loglikelihood_fn, mean=mean, cov=cov_matrix, nd=nd)
        state = ellip_slice.init(position)
        new_state, info = ellip_slice.step(rng_key, state)

    We can JIT-compile the step function for better performance

    .. code::

        step = jax.jit(ellip_slice.step)
        new_state, info = step(rng_key, state)

    Parameters
    ----------
    loglikelihood_fn
        Only the log-likelihood function of the posterior distribution we wish to sample.
    mean
        Mean of the Gaussian prior, forwarded to ``build_kernel``.
    cov
        Covariance matrix of the Gaussian prior (a 2-D matrix, or a 1-D vector
        holding its diagonal), forwarded to ``build_kernel``.
    nd
        Leading (batch) shape of the sampled variable, e.g. ``(nu,)`` or
        ``(d, nu)``, forwarded to ``build_kernel``.

    Returns
    -------
    A ``SamplingAlgorithm``. Its ``init`` accepts an optional ``rng_key``,
    which is ignored.
    """
    kernel = build_kernel(cov, mean, nd)

    def init_fn(position: ArrayLikeTree, rng_key=None):
        # Initialization is deterministic; the key is accepted only for API
        # compatibility.
        del rng_key
        return init(position, loglikelihood_fn)

    def step_fn(rng_key: PRNGKey, state):
        return kernel(
            rng_key,
            state,
            loglikelihood_fn,
        )

    return SamplingAlgorithm(init_fn, step_fn)

ellipsis(position, momentum, theta, mean)

Generate proposal from the ellipsis.

Given a scalar theta indicating a point on the circumference of the ellipse, and the mean vector shared by the position and momentum variables, generate a proposed position and momentum, to be accepted or rejected later depending on the slice variable.

Source code in bamojax/modified_blackjax/modified_elliptical_slice_nd.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
def ellipsis(position, momentum, theta, mean):
    """Rotate (position, momentum) by angle ``theta`` around ``mean``.

    Both pytrees are flattened, centred on the shared ``mean`` vector, and
    rotated to the point of the ellipse selected by the scalar ``theta``. The
    resulting proposal (position and momentum, both unflattened with the
    position's structure) is later accepted or rejected against the slice
    variable.

    """
    flat_position, unflatten = jax.flatten_util.ravel_pytree(position)
    flat_momentum = jax.flatten_util.ravel_pytree(momentum)[0]

    x = flat_position - mean
    v = flat_momentum - mean
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)

    proposed_position = unflatten(x * cos_t + v * sin_t + mean)
    proposed_momentum = unflatten(v * cos_t - x * sin_t + mean)
    return proposed_position, proposed_momentum