Skip to content

bamojax.more_distributions

GaussianProcessFactory(cov_fn, mean_fn=Zero(), nd=None, jitter=1e-06)

Returns an instantiated Gaussian process distribution object.

This is essentially a dist.MultivariateNormal object, with its mean and covariance determined by the mean and covariance functions of the GP.

Parameters:

Name Type Description Default
cov_fn Callable

The GP covariance function. It is called as cov_fn(params=..., x=..., y=...), i.e. it assumes a signature of cov_fn(params: dict, x: Array, y: Array). This is provided by the jaxkern library, but others can be used as well.

required
mean_fn Callable

The GP mean function.

Zero()
nd Tuple[int, ...]

A tuple of integers indicating optional additional output dimensions (for multi-task GPs).

None
jitter float

A small value for numerical stability.

1e-06

Returns: The GaussianProcessInstance class, a Distribution subclass that is instantiated with the input locations and the mean/covariance parameters.

Source code in bamojax/more_distributions.py (lines 34–257):
def GaussianProcessFactory(cov_fn: Callable, mean_fn: Callable = Zero(), nd: Tuple[int, ...] = None, jitter: float = 1e-6):
    r""" Returns a Gaussian process distribution class bound to the given mean and covariance functions.

    Instances of the returned class are essentially a dist.MultivariateNormal object, with mean and covariance
    determined by the mean and covariance functions of the GP evaluated at the instance's input locations.

    Args:
        cov_fn: The GP covariance function. It is called as cov_fn(params=..., x=..., y=...), where `params` is a dict.
                This is provided by the `jaxkern` library, but others can be used as well.
        mean_fn: The GP mean function. It is called as mean_fn.mean(params=..., x=...).
        nd: A tuple of integers indicating optional additional output dimensions (for multi-task GPs).
        jitter: A small value added to the diagonal of covariance matrices for numerical stability.
    Returns:
        The GaussianProcessInstance class (a Distribution subclass). Instantiate it with the input
        locations and the mean/covariance parameters.

    """

    class GaussianProcessInstance(Distribution):
        """ A Gaussian process evaluated at fixed input locations, i.e. a multivariate Gaussian.

        """

        def __init__(self, input: Node, **params):
            self.input = input

            # In case of composite covariance functions the parameters arrive nested under 'params'.
            if 'params' in params:
                self.params = params['params']
            else:
                self.params = params

        #
        def sample(self, key, sample_shape=()):
            r""" Sample from the instantiated Gaussian process (i.e. multivariate Gaussian).

            Args:
                key: PRNG key.
                sample_shape: Leading shape of independent draws.

            Returns:
                An array of shape `sample_shape + event_shape`, i.e. `sample_shape + nd + (n,)`
                (without `nd` when it is None), where n is the number of input locations.

            """
            output_shape = tuple(sample_shape) + self.event_shape

            mu = self._get_mean()
            L = jnp.linalg.cholesky(self._get_cov())
            z = jrnd.normal(key, shape=output_shape)
            # Every length-n vector along the last axis of z is iid N(0, I); z @ L^T applies L to
            # each of them, giving N(0, L L^T) = N(0, K) draws while keeping the axis order intact.
            return mu + jnp.matmul(z, L.T)

        #
        def log_prob(self, value):
            """ Log-density of `value` under the multivariate Gaussian at the input locations.

            """
            mu = self._get_mean()
            cov = self._get_cov()
            return dist.MultivariateNormal(loc=mu, covariance_matrix=cov).log_prob(value=value)

        #
        def sample_predictive_batched(self, key: Array, x_pred: Array, f: Array, num_batches: int = 20):
            r""" Samples from the posterior predictive of the latent f, in batches to conserve memory.

            Args:
                key: PRNGkey
                x_pred: Array
                    The test locations. A 1-D array is treated as a single input dimension.
                f: Array
                    The trained GP to condition on
                num_batches: int
                    The number of batches to predict over. Every test location is covered even
                    when its count is not a multiple of `num_batches`.

            Returns:
                Samples from the posterior predictive distribution, concatenated along the first
                (test-location) axis:

                $$
                    \mathbf{f}^* \sim p(\mathbf{f}^* \mid \mathbf{f}, X, y, x^*) = \int p(\mathbf{f}^* \mid x^*, \mathbf{f}) p(\mathbf{f} \mid X, y) \,\text{d} \mathbf{f}
                $$

            Raises:
                ValueError: If `num_batches` is smaller than 1.

            """
            if num_batches < 1:
                raise ValueError(f'num_batches must be at least 1, but {num_batches} was provided.')

            if jnp.ndim(x_pred) == 1:
                x_pred = x_pred[:, jnp.newaxis]

            n_pred = x_pred.shape[0]
            fpreds = []
            for batch in range(num_batches):
                key, subkey = jrnd.split(key)
                # Integer-division bounds cover all n_pred points (the remainder is spread over the
                # batches) and coincide with equal-size batches when n_pred divides evenly.
                lb = batch * n_pred // num_batches
                ub = (batch + 1) * n_pred // num_batches
                if ub > lb:
                    fpreds.append(self.sample_predictive(subkey, x_pred[lb:ub, :], f))

            # sample_predictive puts test locations on axis 0 for both supported target ranks.
            return jnp.concatenate(fpreds, axis=0)

        #
        def sample_predictive(self, key: Array, x_pred: Array, f: Array):
            r"""Sample latent f for new points x_pred given one posterior sample.

            See Rasmussen & Williams. We are sampling from the posterior predictive for
            the latent GP f, at this point not concerned with an observation model yet.

            We have $[\mathbf{f}, \mathbf{f}^*]^\top \sim \mathcal{N}([m(x), m(x^*)], KK)$, where $KK$ can be partitioned as:

            $$
                KK = \begin{bmatrix} K(x,x) & K(x,x^*) \\ K(x,x^*)^\top & K(x^*,x^*)\end{bmatrix}
            $$

            This results in the conditional
            $$
            \mathbf{f}^* \mid x, x^*, \mathbf{f} \sim \mathcal{N}(\mu, \Sigma) \enspace,
            $$
            where

            $$
            \begin{align*}
                \mu &= m(x^*) + K(x^*, x)K(x,x)^{-1} (\mathbf{f} - m(x)) \enspace, \\
                \Sigma &= K(x^*, x^*) - K(x^*, x) K(x, x)^{-1} K(x, x^*) \enspace.
            \end{align*}
            $$

            If `obs_noise` is present in the parameters, its square is added to the diagonal of $K(x, x)$.

            Args:
                key: The jrnd.PRNGKey object
                x_pred: The prediction locations $x^*$
                f: A sample from the posterior, of shape (n,) or (n, nu, d)

            Returns:
                A single posterior predictive sample $\mathbf{f}^*$, of shape (n_pred,) or (n_pred, nu, d).

            Raises:
                NotImplementedError: If `f` is neither 1-D nor 3-D.

            """
            x = self.input
            n = x.shape[0]
            z = x_pred
            if 'obs_noise' in self.params:
                obs_noise = self.params['obs_noise']
                if jnp.isscalar(obs_noise) or jnp.ndim(obs_noise) == 0:
                    diagonal_noise = obs_noise**2 * jnp.eye(n)
                else:
                    diagonal_noise = jnp.diagflat(obs_noise)**2
            else:
                diagonal_noise = 0

            mean_x = self._get_mean()
            mean_z = mean_fn.mean(params=self.params, x=z)
            Kxx = self.get_cov()
            Kzx = cov_fn(params=self.params, x=z, y=x)
            Kzz = cov_fn(params=self.params, x=z, y=z)

            # Jitter only belongs on the diagonal of square covariance matrices; the cross-covariance
            # Kzx is left untouched. Note that get_cov() already added `jitter`, so Kxx carries 2 * jitter.
            Kxx += jitter * jnp.eye(*Kxx.shape)
            Kzz += jitter * jnp.eye(*Kzz.shape)

            L = jnp.linalg.cholesky(Kxx + diagonal_noise)
            v = jnp.linalg.solve(L, Kzx.T)

            # v^T v = Kzx (L L^T)^{-1} Kxz, i.e. the variance explained by conditioning on f.
            predictive_var = Kzz - jnp.dot(v.T, v)
            predictive_var += jitter * jnp.eye(*Kzz.shape)
            C = jnp.linalg.cholesky(predictive_var)

            def get_sample(u_, target_):
                # Condition on the deviation from the prior mean at the training inputs.
                alpha = jnp.linalg.solve(L.T, jnp.linalg.solve(L, target_ - mean_x))
                predictive_mean = mean_z + jnp.dot(Kzx, alpha)
                return predictive_mean + jnp.dot(C, u_)

            #
            if jnp.ndim(f) == 3:
                _, nu, d = f.shape
                u = jrnd.normal(key, shape=(len(z), nu, d))
                # Map over the nu and d axes so get_sample always sees (n,)-shaped targets;
                # the result is (nu, d, n_pred) and is moved back to (n_pred, nu, d).
                samples = jax.vmap(jax.vmap(get_sample, in_axes=1), in_axes=1)(u, f)
                return samples.transpose([2, 0, 1])
            elif jnp.ndim(f) == 1:
                u = jrnd.normal(key, shape=(len(z),))
                return get_sample(u, f)
            else:
                raise NotImplementedError(f'Shape of target must be (n,) or (n, nu, d), '
                                          f'but {f.shape} was provided.')

        #
        def _get_mean(self):
            """ Returns the mean of the GP at the input locations.

            """
            return mean_fn.mean(params=self.params, x=self.input)

        #
        def get_mean(self):
            """ Public alias of `_get_mean`.

            """
            return self._get_mean()

        #
        def _get_cov(self):
            """ Returns the covariance of the GP at the input locations, with `jitter` added to the diagonal.

            """
            x = self.input
            m = x.shape[0]
            return cov_fn(params=self.params, x=x, y=x) + jitter * jnp.eye(m)

        #
        def get_cov(self):
            """ Public alias of `_get_cov`.

            """
            return self._get_cov()

        #
        @property
        def event_shape(self):
            r""" Event shape in this case is the shape of a single draw of $F = (f(x_1), ..., f(x_n))$,
            prefixed by `nd` when additional output dimensions are requested.

            """

            output_shape = (self.input.shape[0], )
            if nd is not None:
                output_shape = nd + output_shape
            return output_shape

        #
        @property
        def batch_shape(self):
            return ()

        #

    #
    return GaussianProcessInstance

AutoRegressionFactory(ar_fn)

Generates an autoregressive distribution with Gaussian emissions.

This is a factory function that returns a distrax-style Distribution class; instances of that class can be queried for their log probability during inference.

Parameters:

Name Type Description Default
ar_fn Callable

A Callable function that takes innovations \(\epsilon \sim \mathcal{N}(0, \sigma^2)\), and the previous instances \(x(t-1), ..., x(t-p)\), and performs whatever computation the user requires.

required
Source code in bamojax/more_distributions.py (lines 260–357):
def AutoRegressionFactory(ar_fn: Callable):
    r""" Generates an autoregressive distribution with Gaussian emissions.

    This is a factory function that returns a distrax-style Distribution class; instances of it can be
    queried for their log probability for inference.

    Args:
        ar_fn: A Callable function that takes innovations $\epsilon \sim \mathcal{N}(0, \sigma^2)$, and the previous
            instances $x(t-1), ..., x(t-p)$, and performs whatever computation the user requires. It is called
            as `ar_fn(y_prev=..., **parameters)` and its output is used as the mean of the next value.

    Returns:
        The ARInstance class. Instantiate it with keyword parameters; the methods read `y0` (initial values,
        scalar or 1-D), `scale` (standard deviation of the Gaussian emissions) and `T` (series length), and
        all keyword parameters are also forwarded to `ar_fn`.

    """

    # TODO: migrate from `distrax` format to `numpyro` format

    class ARInstance(Distribution):
        """ An instantiated autoregressive distribution object.

        """

        def __init__(self, **kwargs):
            # Kept as one dict: read below (y0, scale, T) and forwarded wholesale to ar_fn.
            self.parameters = kwargs

        #
        def _construct_lag_matrix(self, y, y_init):
            r""" Construct the lagged copies of $y$, one per lag up to the AR order.

            Row k of the result is y shifted right by k+1 positions, with the vacated leading entries
            filled from y_init: y_init[0] ends up immediately before y[0], y_init[1] before that, and so on.
            The result therefore has shape (order, len(y)), with order = len(y_init).

            """
            # A scalar y0 denotes an AR(1) initial value; promote it so it can be indexed.
            y_init = jnp.atleast_1d(y_init)
            order = y_init.shape[0]

            def update_fn(carry, i):
                y_shifted = jnp.roll(carry, shift=1)
                y_shifted = y_shifted.at[0].set(y_init[i])
                return y_shifted, y_shifted

            #
            _, columns = jax.lax.scan(update_fn, y, jnp.arange(order))
            return columns

        #
        def log_prob(self, value):
            r""" Returns the log-density of the complete AR distribution.

            NOTE(review): the result is elementwise (one term per time step), not summed over the
            series; confirm callers reduce it.

            """
            y_lagged = self._construct_lag_matrix(y=value, y_init=self.parameters['y0'])
            mu = ar_fn(y_prev=y_lagged, **self.parameters)
            return dist.Normal(loc=mu, scale=self.parameters['scale']).log_prob(value)

        #
        def _sample_n(self, key, n):
            r""" Draw n independent AR series, one per split of `key`.

            """
            keys = jrnd.split(key, n)
            return jax.vmap(self._sample_predictive)(keys)

        #
        def _sample_predictive(self, key):
            r""" Sample from the AR(p) model.

            Let:

            $$
            \begin{align*}
                \epsilon_t &\sim \mathcal{N}(0, \sigma_y^2) \\
                y_t &= f(y_{t-1}, \ldots, y_{t-p}, \theta) + \epsilon_t
            \end{align*}
            $$ for $t = p+1, \ldots, T$, where $\sigma_y$ is `scale`.

            The returned series of length T starts with the p initial values `y0`.

            NOTE(review): here `y_prev` is passed oldest-first as a vector of length p, whereas
            `log_prob` passes a lag matrix whose first row is lag 1; confirm `ar_fn` handles both.

            """
            def ar_step(carry, epsilon_t):
                y_t = ar_fn(y_prev=carry, **self.parameters) + epsilon_t
                # Drop the oldest value and append the newest one.
                new_carry = jnp.concatenate([carry[1:], jnp.array([y_t])])
                return new_carry, y_t

            #
            # A scalar y0 denotes an AR(1) initial value; promote it so it can be sliced and concatenated.
            y_init = jnp.atleast_1d(self.parameters['y0'])
            order = y_init.shape[0]
            innovations = self.parameters['scale'] * jrnd.normal(key, shape=(self.parameters['T'] - order, ))
            _, ys = jax.lax.scan(ar_step, y_init, innovations)
            return jnp.concatenate([y_init, ys])

        #
        @property
        def batch_shape(self):
            return ( )

        #
        @property
        def event_shape(self):
            # The series length is stored with the other keyword parameters.
            return (self.parameters['T'], )

        #

    #
    return ARInstance