
bamojax.marginal_likelihoods.utility

iid_likelihood(L)

We typically have multiple observations and assume the likelihood factorizes as:

\[ \log p\left(Y \mid \theta\right) = \sum_{i=1}^N \log p\left(y_i \mid \theta\right) \enspace. \]
Source code in bamojax/marginal_likelihoods/utility.py
def iid_likelihood(L: Callable):
    r"""

    We typically have multiple observations and assume the likelihood factorizes 
    as: 

    $$    
        \log p\left(Y \mid \theta\right) = \sum_{i=1}^N \log p\left(y_i \mid \theta\right) \enspace.
    $$

    """
    return lambda x: jnp.sum(L()(x))
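A minimal usage sketch of the pattern above: `L` is a zero-argument callable returning a per-observation log-density, and the wrapper sums those log-densities into one scalar log-likelihood. The standard-normal `make_logpdf` here is a hypothetical stand-in for a bamojax likelihood node.

```python
import jax.numpy as jnp
from typing import Callable

def iid_likelihood(L: Callable):
    # Sum per-observation log-densities into a single scalar log-likelihood.
    return lambda x: jnp.sum(L()(x))

def make_logpdf():
    # Hypothetical per-observation log-density: standard normal log pdf.
    return lambda y: -0.5 * (y**2 + jnp.log(2 * jnp.pi))

y = jnp.array([0.0, 1.0, -1.0])
loglik = iid_likelihood(make_logpdf)(y)  # scalar: sum of three log-densities
```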

flatten_dict_to_array(samples)

Bamojax states are dictionaries, with one entry per model variable. Here we flatten them so the proposal distribution can be a single multivariate distribution.

Source code in bamojax/marginal_likelihoods/utility.py
def flatten_dict_to_array(samples: dict):
    """ Bamojax states are dictionaries, with entries per model variable. Here we flatten them so the 
    proposal distribution can be one single multivariate distribution.

    """

    leaves, treedef = jax.tree_util.tree_flatten(samples)

    Ns = [leaf.shape[0] for leaf in leaves]
    N = Ns[0]

    flat_leaves = [leaf.reshape(N, -1) for leaf in leaves]

    sizes = [f.shape[1] for f in flat_leaves]
    cumulative_sizes = jnp.cumsum(jnp.array(sizes))
    samples_flattened = jnp.concatenate(flat_leaves, axis=-1)  # (N, D)

    leaf_shapes = [leaf.shape[1:] for leaf in leaves]

    def unravel_one_sample(vec: jnp.ndarray):

        splits = jnp.split(vec, cumulative_sizes[:-1])
        reshaped = [v.reshape(s) for v, s in zip(splits, leaf_shapes)]
        return jax.tree_util.tree_unflatten(treedef, reshaped)

    return samples_flattened, unravel_one_sample
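A round-trip sketch under illustrative assumptions: the function above is reproduced so the snippet is self-contained, and the sample dictionary (5 draws of a scalar `mu` and a 2-vector `w`) is hypothetical. Flattening stacks the leaves into an `(N, D)` array; `unravel_one_sample` maps one row back to the original dictionary layout.

```python
import jax
import jax.numpy as jnp

def flatten_dict_to_array(samples: dict):
    leaves, treedef = jax.tree_util.tree_flatten(samples)
    N = leaves[0].shape[0]
    # Flatten each leaf to (N, d_i), then concatenate along the last axis.
    flat_leaves = [leaf.reshape(N, -1) for leaf in leaves]
    sizes = [f.shape[1] for f in flat_leaves]
    cumulative_sizes = jnp.cumsum(jnp.array(sizes))
    samples_flattened = jnp.concatenate(flat_leaves, axis=-1)  # (N, D)
    leaf_shapes = [leaf.shape[1:] for leaf in leaves]

    def unravel_one_sample(vec):
        # Split one flat vector at the per-leaf boundaries and restore shapes.
        splits = jnp.split(vec, cumulative_sizes[:-1])
        reshaped = [v.reshape(s) for v, s in zip(splits, leaf_shapes)]
        return jax.tree_util.tree_unflatten(treedef, reshaped)

    return samples_flattened, unravel_one_sample

# Hypothetical state: 5 samples of a scalar 'mu' and a 2-vector 'w'.
samples = {'mu': jnp.arange(5.0), 'w': jnp.arange(10.0).reshape(5, 2)}
flat, unravel = flatten_dict_to_array(samples)
restored = unravel(flat[0])  # first sample, back in dict form
```

Note that `jax.tree_util.tree_flatten` orders dictionary leaves by sorted key, so `unravel_one_sample` relies on the same treedef to restore the original structure.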