Marginal Likelihood Computation

Recall Lemma 2

For a linear-Gaussian pair \((x, y)\) of the same form as the HMM building blocks, we are interested in the marginal \(p(y)\):

\[\begin{align*} x &\sim \mathcal{N}(m_0, P_0) \\ y \mid x &\sim \mathcal{N}(A x + b, Q) \end{align*}\]

Then

\[\begin{align*} p(y) &= \int p(x) \, p(y \mid x) \, dx \\ &= \mathcal{N}(y; A m_0 + b, \, Q + A P_0 A^\top) \end{align*}\]
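Lemma 2 can be sanity-checked by ancestral sampling: draw \(x \sim \mathcal{N}(m_0, P_0)\), then \(y \mid x\), and compare the sample moments of \(y\) with \(A m_0 + b\) and \(Q + A P_0 A^\top\). This is a minimal sketch; the dimensions and parameter values below are arbitrary choices made only for illustration, and the helper names are hypothetical.

```python
import numpy as np

def lemma2_moments(A, b, Q, m0, P0):
    """Closed-form mean and covariance of the marginal p(y) from Lemma 2."""
    mean = A @ m0 + b
    cov = Q + A @ P0 @ A.T
    return mean, cov

def lemma2_monte_carlo(A, b, Q, m0, P0, n_samples=200_000, seed=0):
    """Estimate the mean and covariance of y by sampling x, then y | x."""
    rng = np.random.default_rng(seed)
    x = rng.multivariate_normal(m0, P0, size=n_samples)             # (n, dx)
    noise = rng.multivariate_normal(np.zeros(len(b)), Q, size=n_samples)
    y = x @ A.T + b + noise                                          # (n, dy)
    return y.mean(axis=0), np.cov(y, rowvar=False)

# Illustrative parameters (arbitrary values, dy = 3, dx = 2)
A = np.array([[1.0, 0.5], [0.0, 1.0], [2.0, -1.0]])
b = np.array([0.1, -0.2, 0.3])
Q = np.diag([0.5, 0.2, 1.0])
m0 = np.array([1.0, -1.0])
P0 = np.array([[1.0, 0.3], [0.3, 2.0]])

mean_cf, cov_cf = lemma2_moments(A, b, Q, m0, P0)
mean_mc, cov_mc = lemma2_monte_carlo(A, b, Q, m0, P0)
print(np.round(mean_cf, 3), np.round(mean_mc, 3))
print(np.round(cov_cf, 3))
print(np.round(cov_mc, 3))
```

With 200,000 samples the Monte Carlo moments should agree with the closed form to roughly two decimal places.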

Marginal Likelihood Form

Consider the HMM:

\[\begin{align*} \pi_0(x_0) &= \mathcal{N}(x_0; m_0, P_0) \\ \tau(x_t | x_{t-1}) &= \mathcal{N}(x_t; A x_{t-1}, Q) \\ p(y_t | x_t) &= \mathcal{N}(y_t; H x_t, R) \end{align*}\]

The exact marginal likelihood \(p(y_{1:T})\) is obtained by integrating the joint density over all hidden states:

\[ p(y_{1:T}) = \int p(x_{0:T}, y_{1:T}) dx_{0:T} \]
\[ = \int \pi_0(x_0) \prod_{t=1}^T \bigg(\tau(x_t | x_{t-1}) p(y_t | x_t) \bigg) dx_{0:T} \]
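Because every factor is Gaussian, \(y_{1:T}\) is itself jointly Gaussian, so for small \(T\) this integral can be evaluated directly without any filtering. The sketch below assumes a scalar state and observation to keep the bookkeeping short. It uses \(\mathbb{E}[x_t] = A\,\mathbb{E}[x_{t-1}]\), \(\mathrm{Var}(x_t) = A^2 \mathrm{Var}(x_{t-1}) + Q\), \(\mathrm{Cov}(x_s, x_t) = A^{t-s}\mathrm{Var}(x_s)\) for \(s \le t\), and \(\mathrm{Cov}(y_s, y_t) = H^2 \mathrm{Cov}(x_s, x_t) + R\,\delta_{st}\). The function name is hypothetical.

```python
import numpy as np

def joint_loglik_scalar(Y, A, Q, H, R, m0, P0):
    """Exact log p(y_{1:T}) for the scalar linear-Gaussian HMM, computed by
    treating y_{1:T} as one T-dimensional Gaussian vector (O(T^3) cost)."""
    Y = np.asarray(Y, dtype=float)
    T = len(Y)
    # Marginal means and variances of x_1, ..., x_T
    means, variances = np.empty(T), np.empty(T)
    m, v = m0, P0
    for t in range(T):
        m = A * m              # E[x_t] = A E[x_{t-1}]
        v = A * A * v + Q      # Var(x_t) = A^2 Var(x_{t-1}) + Q
        means[t], variances[t] = m, v
    # Cov(x_s, x_t) = A^{t-s} Var(x_s) for s <= t, filled symmetrically
    cov_x = np.empty((T, T))
    for s in range(T):
        for t in range(s, T):
            cov_x[s, t] = cov_x[t, s] = A ** (t - s) * variances[s]
    mu_y = H * means
    cov_y = H * H * cov_x + R * np.eye(T)
    # log N(Y; mu_y, cov_y) through a Cholesky factor cov_y = L L^T
    L = np.linalg.cholesky(cov_y)
    z = np.linalg.solve(L, Y - mu_y)
    return -0.5 * (z @ z) - np.log(np.diag(L)).sum() - 0.5 * T * np.log(2 * np.pi)

print(joint_loglik_scalar([0.7, -0.4, 1.2], A=0.9, Q=0.5, H=1.5, R=0.3, m0=0.2, P0=1.0))
```

The cubic cost in \(T\) is why the recursive factorization below is preferred in practice; this brute-force version is only useful as a reference value.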

Equivalently, by the chain rule of probability,

\[ p(y_{1:T}) = \prod_{t=1}^T p(y_t \mid y_{1:t-1}) \]

Where

\[ p(y_t \mid y_{1:t-1}) = \int p(y_t \mid x_t) p(x_t \mid y_{1:t-1}) \, dx_t \]

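Since the HMM is fully linear-Gaussian, the Kalman filter gives the predictive distribution \(p(x_t \mid y_{1:t-1}) = \mathcal{N}(x_t; \hat{m}_t, \hat{P}_t)\) in closed form. Here \(m_{t-1}, P_{t-1}\) are the mean and covariance of the filtering distribution \(p(x_{t-1} \mid y_{1:t-1})\), and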

\[\begin{align*} \hat{m}_t &= A m_{t-1} \\ \hat{P}_t &= A P_{t-1} A^\top + Q \end{align*}\]
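The prediction step amounts to two matrix products. Here is a minimal NumPy sketch, assuming vector-valued states. The function and argument names are illustrative and do not come from the original notes, and the constant-velocity model is an arbitrary example.

```python
import numpy as np

def kalman_predict(m_prev, P_prev, A, Q):
    """One Kalman prediction step: moments of p(x_t | y_{1:t-1})."""
    m_hat = A @ m_prev               # \hat{m}_t = A m_{t-1}
    P_hat = A @ P_prev @ A.T + Q     # \hat{P}_t = A P_{t-1} A^T + Q
    return m_hat, P_hat

# Example: constant-velocity model with unit time step (illustrative values)
A = np.array([[1.0, 1.0], [0.0, 1.0]])
Q = 0.1 * np.eye(2)
m_hat, P_hat = kalman_predict(np.array([0.0, 1.0]), np.eye(2), A, Q)
print(m_hat)   # [1. 1.]
print(P_hat)
```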

Substituting into the expression for \(p(y_t \mid y_{1:t-1})\) gives

\[ p(y_t \mid y_{1:t-1}) = \int \mathcal{N}(y_t; H x_t, R) \cdot \mathcal{N}(x_t; \hat{m}_t, \hat{P}_t) \, dx_t \]

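This integral has exactly the form of Lemma 2, with \(x_t\) in place of \(x\), \(m_0 \to \hat{m}_t\), \(P_0 \to \hat{P}_t\), \(A \to H\), \(b = 0\) and \(Q \to R\). Hence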

\[ p(y_t \mid y_{1:t-1}) = \mathcal{N}(y_t; H \hat{m}_t , H \hat{P}_t H^\top + R) = \mathcal{N}(y_t; H \hat{m}_t , S_t) \]

where \(S_t = H \hat{P}_t H^\top + R\) is the covariance of the innovation \(y_t - H \hat{m}_t\).
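Evaluating \(\log \mathcal{N}(y_t; H \hat{m}_t, S_t)\) is the only new computation the likelihood needs at each step. The sketch below uses a Cholesky factor of \(S_t\) instead of forming \(S_t^{-1}\) explicitly, which is a common numerical choice rather than something prescribed by the notes. The function name is hypothetical.

```python
import numpy as np

def predictive_logpdf(y, m_hat, P_hat, H, R):
    """log p(y_t | y_{1:t-1}) = log N(y_t; H m_hat, S_t), with S_t = H P_hat H^T + R."""
    S = H @ P_hat @ H.T + R                 # innovation covariance S_t
    resid = y - H @ m_hat                   # innovation y_t - H \hat{m}_t
    L = np.linalg.cholesky(S)               # S_t = L L^T
    z = np.linalg.solve(L, resid)           # z^T z = resid^T S_t^{-1} resid
    log_det_S = 2.0 * np.log(np.diag(L)).sum()
    return -0.5 * (z @ z + log_det_S + len(y) * np.log(2.0 * np.pi))

val = predictive_logpdf(np.array([1.0]), np.array([0.3]),
                        np.array([[1.0]]), np.array([[2.0]]), np.array([[0.5]]))
print(val)
```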

Kalman Filter for Marginal Likelihood

The full algorithm for computing \(\log p(y_{1:T}) = \sum_{t=1}^T \log p(y_t \mid y_{1:t-1})\) is given by

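  • Input: prior moments \( m_0, P_0 \), model matrices \( A, Q, H, R \), and the observations \( y_{1:T} \).

    Initialize the filtering moments at time 0 with the prior \( m_0, P_0 \) (these are the \( m_{t-1}, P_{t-1} \) used by the prediction step at \( t = 1 \)).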

  • Filtering:
    For \( t = 1, \dots, T \) do

    • Prediction step:

    \[\begin{align*} \hat{m}_t &= A m_{t-1} \\ \hat{P}_t &= A P_{t-1} A^\top + Q \end{align*}\]
    • Update step:

    \[\begin{align*} S_t &= H \hat{P}_t H^\top + R \\ K_t &= \hat{P}_t H^\top (S_t)^{-1} \\ m_t &= \hat{m}_t + K_t (y_t - H \hat{m}_t) \\ P_t &= (I - K_t H) \hat{P}_t \end{align*}\]

    End for

  • Return \( \hat{m}_{1:T}, S_{1:T}\)

And we output

\[ \log p(y_{1:T}) = \sum_{t=1}^T \log \mathcal{N}(y_t; H \hat{m}_{t}, S_t) \]
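Here is a minimal NumPy sketch of the algorithm above for vector states. It accumulates \(\log \mathcal{N}(y_t; H \hat{m}_t, S_t)\) inside the loop instead of returning \(\hat{m}_{1:T}, S_{1:T}\). It also computes \(K_t = \hat{P}_t H^\top S_t^{-1}\) with a linear solve rather than an explicit inverse. The function name is illustrative. All model inputs are assumed to be 2-D arrays (use \(1 \times 1\) arrays for a scalar model), and `Y` has shape `(T, dy)`.

```python
import numpy as np

def kalman_log_marginal_likelihood(Y, A, Q, H, R, m0, P0):
    """log p(y_{1:T}) for x_t = A x_{t-1} + N(0, Q), y_t = H x_t + N(0, R),
    x_0 ~ N(m0, P0). Y has shape (T, dy)."""
    m, P = m0, P0                         # filtering moments at t = 0 (the prior)
    dy = H.shape[0]
    log_ml = 0.0
    for y in Y:
        # Prediction step
        m_hat = A @ m
        P_hat = A @ P @ A.T + Q
        # Predictive density of y_t: N(H m_hat, S_t), evaluated via Cholesky
        S = H @ P_hat @ H.T + R
        resid = y - H @ m_hat
        L = np.linalg.cholesky(S)
        z = np.linalg.solve(L, resid)
        log_ml += -0.5 * (z @ z + 2.0 * np.log(np.diag(L)).sum() + dy * np.log(2.0 * np.pi))
        # Update step: K = P_hat H^T S^{-1} (S and P_hat are symmetric)
        K = np.linalg.solve(S, H @ P_hat).T
        m = m_hat + K @ resid
        P = (np.eye(len(m)) - K @ H) @ P_hat
    return log_ml

# Usage with a scalar model written as 1x1 arrays (illustrative values)
Y = np.array([[0.7], [-0.4], [1.2]])
print(kalman_log_marginal_likelihood(Y, np.array([[0.9]]), np.array([[0.5]]),
                                     np.array([[1.5]]), np.array([[0.3]]),
                                     np.array([0.2]), np.array([[1.0]])))
```

The cost is linear in \(T\), compared with the cubic cost of treating \(y_{1:T}\) as one large Gaussian vector.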

Bootstrap Particle Filter (BPF) for Marginal Likelihood

At each time \(t\), the unnormalized weights are \(p(y_t \mid x_t^{(i)})\), where \(x_t^{(i)}\) are the particles at time \(t\) after propagation through \(\tau\) and before resampling.

Multiplying the per-step average weights over time gives an unbiased estimate of the marginal likelihood. The construction starts from the recursion

\[ p(y_{1:n}) = p(y_{1:n-1}) p(y_n|y_{1:n-1}) \]

Where the conditional likelihood

\[ p(y_n|y_{1:n-1}) = \int p(y_n|x_n) p(x_n|y_{1:n-1}) \, \mathrm{d}x_n \]

can be estimated by BPF:

\[ p^N(y_n|y_{1:n-1}) = \frac{1}{N} \sum_{i=1}^{N} p(y_n|\tilde{x}_n^{(i)}) \]

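where \(\tilde{x}_n^{(i)}\) are the propagated particles at time \(n\). Each is drawn from \(\tau(\cdot \mid x_{n-1}^{(i)})\) using the resampled particles from time \(n-1\), so together they approximate \(p(x_n \mid y_{1:n-1})\). They are taken before the weighting and resampling at time \(n\); they are the \(\bar{x}_n^{(i)}\) of the pseudocode. The full marginal likelihood estimate is then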

\[ p^N(y_{1:n}) = \prod_{k=1}^{n} p^N(y_k \mid y_{1:k-1}) \]

To avoid numerical underflow we work in the log domain, where each factor becomes

\[\begin{align*} \log p^N(y_n | y_{1:n-1}) =& \log \sum_{i=1}^{N} \exp\left( \log p(y_n | \tilde{x}_n^{(i)}) \right) - \log N, \\ =& \log \sum_{i=1}^{N} \exp\left( \log W_n^{(i)} \right) - \log N \end{align*}\]
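In practice the log-sum-exp is evaluated with the shift identity \(\log \sum_i e^{a_i} = c + \log \sum_i e^{a_i - c}\) with \(c = \max_i a_i\), so the largest term becomes \(e^0 = 1\) and nothing underflows. The small sketch below, with a hypothetical helper name and made-up log-weights, shows why this matters: weights around \(e^{-1000}\) underflow to zero when exponentiated directly.

```python
import numpy as np

def log_mean_exp(log_w):
    """log((1/N) * sum_i exp(log_w[i])), computed stably via the log-sum-exp trick."""
    log_w = np.asarray(log_w, dtype=float)
    c = np.max(log_w)                                    # shift by the largest log-weight
    return c + np.log(np.sum(np.exp(log_w - c))) - np.log(len(log_w))

# Log-weights this negative underflow to 0.0 when exponentiated directly
log_w = np.array([-1000.0, -1001.0, -1002.0])
with np.errstate(divide="ignore"):
    naive = np.log(np.mean(np.exp(log_w)))   # -inf: every exp(...) underflows to 0.0
stable = log_mean_exp(log_w)                  # about -1000.69
print(naive, stable)
```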

Pseudocode

Recall the setup, now writing \(\theta\) for the transition matrix (the \(A\) of the previous sections):

\[\begin{align*} \pi_0(x_0) &= \mathcal{N}(x_0; m_0, P_0) \\ \tau(x_t | x_{t-1}) &= \mathcal{N}(x_t; \theta x_{t-1}, Q) \\ p(y_t | x_t) &= \mathcal{N}(y_t; H x_t, R) \end{align*}\]

With \(g_\theta(y_t \mid x_t) = p(y_t \mid x_t)\) denoting the observation density, the full BPF pseudocode is:

Sample:
\( x_0^{(i)} \sim \pi_0 \) for \( i = 1, \ldots, N. \)

for \( t = 1, \ldots, T \) do

  • Sample: \( \bar{x}_t^{(i)} \sim \tau_{\theta}(\cdot|x_{t-1}^{(i)}) \quad \text{for} \quad i = 1, \ldots, N. \)

  • Weight: \( W_t^{(i)} = g_{\theta}(y_t|\bar{x}_t^{(i)}), \) for \( i = 1, \ldots, N \).

  • Store: \( p^N(y_t \mid y_{1:t-1}) = \frac{1}{N} \sum_{i=1}^{N} W_t^{(i)} \)

  • Normalize \( w_t^{(i)} = \frac{W_t^{(i)}}{\sum_{j=1}^{N} W_t^{(j)}}, \)

  • Resample: Sample \( o_t(1), \dots, o_t(N) \sim \) Multinomial( \(w_t^{(1)}, \dots, w_t^{(N)}\) ), and set \( x_t^{(i)} = \bar{x}_t^{(o_t(i))} \) for \( i = 1, \ldots, N \).

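end for

Return \( p^N(y_{1:T}) = \prod_{t=1}^{T} p^N(y_t \mid y_{1:t-1}) \), or, in the log domain, \( \log p^N(y_{1:T}) = \sum_{t=1}^{T} \log p^N(y_t \mid y_{1:t-1}) \).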

# Example implementation: scalar BPF estimate of log p(y_{1:T})

import numpy as np

def log_p(y, y_pred, R):
    '''y: scalar, y_pred: (N,), R: scalar variance'''
    diff = y_pred - y  # (N,)
    mahalanobis = (diff ** 2) / R  # (N,)
    log_det_R = np.log(R)
    
    log_likelihoods = -0.5 * (mahalanobis + np.log(2 * np.pi) + log_det_R)
    return log_likelihoods

def bpf_q2(Y, A, Q, H, R, m0, P0, N=500):
    '''Bootstrap particle filter estimate of log p(y_{1:T}) for the scalar model
    x_t = A x_{t-1} + N(0, Q), y_t = H x_t + N(0, R), x_0 ~ N(m0, P0).'''
    log_marginal_likelihood = 0.0

    # Step 1: initialise particles x_0^{(i)} ~ pi_0
    x = np.random.normal(loc=m0, scale=np.sqrt(P0), size=N)  # (N,)

    for y in Y:
        # Step 2: propagate every particle through the transition tau
        noise = np.random.normal(loc=0.0, scale=np.sqrt(Q), size=N)
        x = A * x + noise  # (N,)

        # Step 3: log-weights log W_t^{(i)} = log p(y_t | x_t^{(i)})
        y_pred = H * x  # (N,)
        log_w = log_p(y, y_pred, R)

        # log p^N(y_t | y_{1:t-1}) = log sum exp(log W) - log N,
        # computed with the log-sum-exp trick (shift by the max log-weight)
        max_log = np.max(log_w)
        log_sum = np.log(np.sum(np.exp(log_w - max_log))) + max_log
        log_marginal_likelihood += log_sum - np.log(N)

        # Normalised weights w_t^{(i)} = W_t^{(i)} / sum_j W_t^{(j)}, reusing log_sum
        weights = np.exp(log_w - log_sum)
        weights /= np.sum(weights)  # guard against round-off before sampling

        # Step 4: multinomial resampling
        indices = np.random.choice(N, size=N, p=weights)
        x = x[indices]

    return log_marginal_likelihood
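The unbiasedness claim concerns \(p^N(y_{1:T})\) itself, not its logarithm. By Jensen's inequality, \(\mathbb{E}[\log p^N(y_{1:T})] \le \log p(y_{1:T})\). The claim can be checked empirically by averaging \(p^N(y_{1:T}) / p(y_{1:T})\) over many independent runs. The sketch below re-implements a compact scalar BPF and a scalar Kalman filter so that it runs on its own. It also simulates a short data set. All parameter values, \(N\), the number of runs and the helper names are arbitrary illustrative choices.

```python
import numpy as np

def kalman_loglik(Y, A, Q, H, R, m0, P0):
    """Exact log p(y_{1:T}) for the scalar linear-Gaussian HMM."""
    m, P, ll = m0, P0, 0.0
    for y in Y:
        m_hat, P_hat = A * m, A * A * P + Q           # prediction
        S = H * H * P_hat + R                          # innovation variance
        ll += -0.5 * (np.log(2 * np.pi * S) + (y - H * m_hat) ** 2 / S)
        K = P_hat * H / S                              # update
        m, P = m_hat + K * (y - H * m_hat), (1 - K * H) * P_hat
    return ll

def bpf_loglik(Y, A, Q, H, R, m0, P0, N, rng):
    """BPF estimate log p^N(y_{1:T}) with multinomial resampling at every step."""
    x = rng.normal(m0, np.sqrt(P0), size=N)
    ll = 0.0
    for y in Y:
        x = A * x + rng.normal(0.0, np.sqrt(Q), size=N)                 # propagate
        log_w = -0.5 * (np.log(2 * np.pi * R) + (y - H * x) ** 2 / R)   # log W_t^{(i)}
        c = log_w.max()
        ll += c + np.log(np.mean(np.exp(log_w - c)))                     # log-mean-exp
        w = np.exp(log_w - c)
        x = x[rng.choice(N, size=N, p=w / w.sum())]                      # resample
    return ll

# Simulate T observations from the model (illustrative parameters)
A, Q, H, R, m0, P0 = 0.9, 0.5, 1.0, 1.0, 0.0, 1.0
rng = np.random.default_rng(1)
T = 10
x, Y = rng.normal(m0, np.sqrt(P0)), []
for _ in range(T):
    x = A * x + rng.normal(0.0, np.sqrt(Q))
    Y.append(H * x + rng.normal(0.0, np.sqrt(R)))

exact = kalman_loglik(Y, A, Q, H, R, m0, P0)
estimates = np.array([bpf_loglik(Y, A, Q, H, R, m0, P0, N=100, rng=rng) for _ in range(500)])

# The ratio p^N / p should average to about 1; the mean of log p^N is expected
# (in expectation) to sit slightly below the exact value, by Jensen's inequality
ratio_mean = np.mean(np.exp(estimates - exact))
print(exact, estimates.mean(), ratio_mean)
```

Comparing the mean of `exp(estimates - exact)` rather than the mean of `estimates` is deliberate: only the former targets the quantity that is unbiased.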