bamojax.marginal_likelihoods.importance_sampling

importance_sampling(key, model, g_IS, num_samples=1000, iid_obs=True)

Importance sampling routine for a given BayesianModel.

Importance sampling is based on the following approximation to the marginal likelihood [Gronau et al., 2017]:

$$ p(D) \approx \frac{1}{N} \sum_{i=1}^N p\left(D \mid \theta_i\right) \frac{p(\theta_i)}{g_{\text{IS}}(\theta_i)}\enspace, $$ with \(\theta_i \sim g_{\text{IS}}(\theta)\). The routine returns the logarithm of this estimate.
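
In practice the average is evaluated in log space to avoid numerical under- and overflow, which is why the routine returns a log-mean-exp of the per-sample terms. A minimal sketch of that computation, assuming the per-sample log-likelihood, log-prior, and log importance density are already available as arrays:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_marginal_estimate(log_lik, log_prior, log_g):
    # log of (1/N) * sum_i exp(log_lik_i + log_prior_i - log_g_i),
    # computed stably with logsumexp
    log_weights = log_lik + log_prior - log_g
    return logsumexp(log_weights) - jnp.log(log_weights.shape[0])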

Here, g_IS is the importance density, which should meet these criteria:

  1. It is easy to evaluate.
  2. It has the same domain as the posterior \(p(\theta \mid D)\).
  3. It matches the posterior as closely as possible.
  4. It has fatter tails than the posterior.

There is no one-size-fits-all importance density; this needs to be crafted carefully for each specific problem.

Note that the importance density can also be a mixture distribution, which can make it easier to introduce heavy tails.
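
For example, a heavy-tailed importance density for a scalar parameter can be built as a two-component mixture: one component roughly matching the posterior, and a small-weight, much wider component supplying the fat tails. A minimal sketch, using NumPyro-style distribution classes purely for illustration (an assumption; the distribution types bamojax expects may differ):

import jax.numpy as jnp
import numpyro.distributions as dist

# Hypothetical: the posterior of a scalar parameter is roughly N(0.8, 0.2^2).
# 90% of the mass follows that rough fit; 10% comes from a wide, heavy-tailed
# component so that g_IS dominates the posterior in the tails.
components = dist.StudentT(df=jnp.array([50.0, 3.0]),
                           loc=jnp.array([0.8, 0.8]),
                           scale=jnp.array([0.2, 1.0]))
# arguments: mixing distribution over components, batched component distribution
g_theta = dist.MixtureSameFamily(dist.Categorical(probs=jnp.array([0.9, 0.1])),
                                 components)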

References:

  • Gronau, Q. F., Sarafoglou, A., Matzke, D., Ly, A., Boehm, U., Marsman, M., Leslie, D. S., Forster, J. J., Wagenmakers, E.-J., & Steingroever, H. (2017). A tutorial on bridge sampling. Journal of Mathematical Psychology, 81, 80-97. https://doi.org/10.1016/j.jmp.2017.09.005
Source code in bamojax/marginal_likelihoods/importance_sampling.py
def importance_sampling(key, 
                        model: Model, 
                        g_IS: Distribution,
                        num_samples: int = 1_000,
                        iid_obs: bool = True) -> Float:

    r"""Importance sampling routine for a given BayesianModel.

    Importance sampling is based on the following approximation to the marginal
    likelihood [Gronau et al., 2017]:

    $$
    p(D) \approx \frac{1}{N} \sum_{i=1}^N p\left(D \mid \theta_i\right) \frac{p(\theta_i)}{g_{\text{IS}}(\theta_i)}\enspace,
    $$
    with $\theta_i \sim g_{\text{IS}}(\theta)$. The routine returns the logarithm of this estimate.

    Here, g_IS is the importance density, which should meet these criteria:

    1. It is easy to evaluate.
    2. It has the same domain as the posterior $p(\theta \mid D)$.
    3. It matches the posterior as closely as possible.
    4. It has fatter tails than the posterior.

    There is no one-size-fits-all importance density; this needs to be crafted
    carefully for each specific problem.

    Note that the importance density can also be a mixture distribution, which 
    can make it easier to introduce heavy tails.

    References:

    - Gronau, Q. F., Sarafoglou, A., Matzke, D., Ly, A., Boehm, U., Marsman, M., Leslie, D. S., Forster, J. J., Wagenmakers, E.-J., & Steingroever, H. (2017). A tutorial on bridge sampling. Journal of Mathematical Psychology, 81, 80-97. https://doi.org/10.1016/j.jmp.2017.09.005


    """

    def g_eval(state):
        # Sum the log-density of each leaf importance distribution at the corresponding leaf value in `state`
        logprob = 0
        values_flat, _ = tree_flatten(state)
        for value, dist in zip(values_flat, g_flat):
            logprob += jnp.sum(dist.log_prob(value))
        return logprob

    # Per-sample log importance weight: log p(D | theta) + log p(theta) - log g_IS(theta)
    def adjusted_likelihood(state):
        return loglikelihood_fn(state) + logprior_fn(state) - g_eval(state)

    # Set up the model's log-likelihood and log-prior functions

    if iid_obs:
        loglikelihood_fn = iid_likelihood(model.loglikelihood_fn)
    else:
        loglikelihood_fn = model.loglikelihood_fn

    logprior_fn = model.logprior_fn()

    # Flatten the pytree of importance densities, treating distribution objects as leaves
    g_flat, g_treedef = tree_flatten(g_IS,
                                     lambda l: isinstance(l, (Distribution, TransformedDistribution)))

    samples = list()
    for g in g_flat:
        key, subkey = jrnd.split(key)
        samples.append(g.sample(key=subkey, sample_shape=(num_samples, )))

    importance_samples = tree_unflatten(g_treedef, samples)
    adjusted_likelihoods = jax.vmap(adjusted_likelihood)(importance_samples)
    # Log-mean-exp of the importance weights: log( (1/N) * sum_i exp(adjusted_likelihoods_i) )
    return logsumexp(adjusted_likelihoods) - jnp.log(num_samples)
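
A hedged usage sketch: `g_IS` is a pytree mirroring the model's free parameters, with a distribution object at each leaf. The model `my_model` and the parameter names `'mu'` and `'sigma'` below are hypothetical placeholders, and the NumPyro-style distribution classes are an assumption for illustration only; constructing the BayesianModel itself is outside the scope of this function:

import jax.random as jrnd
import numpyro.distributions as dist

# Hypothetical importance densities, one per free parameter of `my_model`
g_IS = {'mu': dist.StudentT(df=3.0, loc=0.8, scale=0.5),
        'sigma': dist.LogNormal(loc=-1.0, scale=0.5)}

key = jrnd.PRNGKey(0)
log_ml = importance_sampling(key, model=my_model, g_IS=g_IS, num_samples=5_000)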