Importance Weighted AutoEncoder#

This strategy applies both to IWAE (which achieves a tighter ELBO) and to importance resampling (which reduces variance).

Setup#

Suppose we know \(q(z \mid x)\) and \(p(x, z)\), and wish to sample from the importance-weighted distribution \(q^k(z \mid x)\)

\[\begin{align*} q^k(z \mid x) &= \frac{w(z)}{\frac{1}{k} \sum_{j=1}^k w(z_j)} q(z \mid x) \\ \text{where} \quad w(z) &= \frac{p(x, z)}{q(z \mid x)} \end{align*}\]

Notice that if we let \(\tilde{w}_j = \frac{w_j}{\sum_{l=1}^k w_l}\) denote the normalized weights, then for each sample \(z_j\),

\[\begin{align*} q^k(z_j \mid x) \propto \tilde{w}_j \, q(z_j \mid x) \end{align*}\]

We can sample \( \hat{z} \sim q^k(z \mid x) \) as follows (a minimal code sketch is given after the list):

  • Sample \( z_1, z_2, \ldots, z_k \sim q(z \mid x) \)

  • Compute importance weights:

\[ w_j = \frac{p(x, z_j)}{q(z_j \mid x)} \quad \text{for } j = 1, \ldots, k \]
  • Normalize the weights:

\[ \tilde{w}_j = \frac{w_j}{\sum_{l=1}^k w_l} \]
  • Sample index \( j \sim \text{Categorical}(\tilde{w}_1, \ldots, \tilde{w}_k) \)

  • Output \( \hat{z} = z_j \)
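
As a quick illustration, here is a minimal NumPy sketch of this sampling-importance-resampling procedure. The helper names (resample_from_qk, log_joint, log_q) and the toy densities inside them are illustrative assumptions, not part of the setup above.

import numpy as np

def resample_from_qk(x, sample_q, log_joint, log_q, k, rng):
    """Draw one sample from q^k(z|x) by sampling-importance-resampling."""
    # Step 1: draw z_1, ..., z_k ~ q(z|x)
    zs = sample_q(x, k, rng)                        # shape (k,)
    # Step 2: log importance weights  log w_j = log p(x, z_j) - log q(z_j|x)
    log_w = log_joint(x, zs) - log_q(x, zs)         # shape (k,)
    # Step 3: normalize in log space for numerical stability
    log_w -= log_w.max()
    w_tilde = np.exp(log_w)
    w_tilde /= w_tilde.sum()
    # Step 4: sample an index j ~ Categorical(w_tilde_1, ..., w_tilde_k)
    j = rng.choice(k, p=w_tilde)
    # Step 5: return the selected particle
    return zs[j]

# Toy example (illustrative): prior z ~ N(0,1), likelihood x|z ~ N(z,1),
# proposal q(z|x) = N(0,1), i.e. the prior.
rng = np.random.default_rng(0)
sample_q = lambda x, k, rng: rng.normal(0.0, 1.0, size=k)
log_q = lambda x, z: -0.5 * (z**2 + np.log(2 * np.pi))
log_joint = lambda x, z: log_q(x, z) - 0.5 * ((x - z)**2 + np.log(2 * np.pi))

z_hat = resample_from_qk(x=1.5, sample_q=sample_q, log_joint=log_joint,
                         log_q=log_q, k=50, rng=rng)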

VAE Application#

IWAE objective#

\[ \mathcal{L}_{IWAE}^k(\theta, \phi; x) := \mathbb{E}_{z_1,\ldots,z_k\sim q_\phi({z}\mid {x})} \left[ \log \frac{1}{k} \sum_{i=1}^k \frac{p_\theta({x}, {z_i})}{q_\phi({z_i}\mid {x})} \right], \]

The IWAE objective can be used as an optimisation objective in place of the VAE ELBO \(\mathcal{L}_{VAE}\). It was introduced in the Importance Weighted Autoencoders paper (Burda, Grosse & Salakhutdinov, 2015), where it is shown that \(\mathcal{L}_{IWAE}^k\) is a tighter lower bound on the marginal likelihood than the ELBO:

(1)#\[\begin{align} \log p_\theta(x) \ge \mathcal{L}_{IWAE}^{k+1} \ge \mathcal{L}_{IWAE}^k \end{align}\]

It is also shown (under certain conditions on \(q_\phi(z\mid x)\)) that \(\mathcal{L}_{IWAE}^k \to \log p_\theta(x)\) as \(k\to\infty\).
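
As a hedged sanity check of these claims, the sketch below estimates \(\mathcal{L}_{IWAE}^k\) for increasing \(k\) on a toy conjugate-Gaussian model where \(\log p_\theta(x)\) is known in closed form. The model, the deliberately loose proposal, and the helper iwae_bound are assumptions made for illustration only (the stabilised log-mean-exp it uses is explained in the next section).

import numpy as np
from scipy.stats import norm

# Toy model (illustrative): z ~ N(0, 1), x | z ~ N(z, 1), so p(x) = N(0, 2).
# We use the prior as a (deliberately loose) proposal q(z|x) = N(0, 1),
# so that w_i = p(x, z_i) / q(z_i | x) = p(x | z_i).
x = 1.5
log_px_exact = norm(0.0, np.sqrt(2.0)).logpdf(x)

def iwae_bound(k, n_repeats=20_000, rng=np.random.default_rng(0)):
    """Monte Carlo estimate of L_IWAE^k = E[log (1/k) sum_i w_i]."""
    z = rng.normal(0.0, 1.0, size=(n_repeats, k))   # z_i ~ q(z|x)
    log_w = norm(z, 1.0).logpdf(x)                  # log w_i = log p(x|z_i)
    # log-mean-exp over the k samples, stabilised by subtracting the max
    m = log_w.max(axis=1, keepdims=True)
    log_mean_w = m.squeeze(1) + np.log(np.exp(log_w - m).mean(axis=1))
    return log_mean_w.mean()

for k in [1, 5, 25, 125]:
    print(f"k={k:4d}  L_IWAE^k estimate: {iwae_bound(k):.4f}   (log p(x) = {log_px_exact:.4f})")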

Log-Mean-Exp trick#

For \( z_1,\ldots,z_k\sim q_\phi({z}\mid {x}) \), denote

\[ w_i = \frac{p_\theta({x}, {z_i})}{q_\phi({z_i}\mid {x})} \]

Whereas previously we computed a single \( \log w_i \) as our one-sample MC estimate of the ELBO, we now compute

\[ \log \frac{1}{k} \sum_i w_i \]

We proceed by first obtaining, for each data point, the length-\(k\) vector of log-weights (a (batch_size, k) tensor over the batch):

\[ \log w = ( \log w_1, \dots \log w_k ) \]

Then we use the log-mean-exp trick to obtain \(\log(\text{mean } w)\) with numerical stability:

\[ \log \text{mean}(\exp (\log w)) = \log \text{mean}(w_1, \dots, w_k) = \log \left( \frac{1}{k} \sum_i w_i \right) \]

In practice, stability comes from subtracting the maximum \(m = \max_i \log w_i\) before exponentiating:

\[ \log \frac{1}{k} \sum_i w_i = m + \log \frac{1}{k} \sum_i \exp(\log w_i - m) \]
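
A minimal sketch of this helper using keras.ops (the function name log_mean_exp is my own; Keras 3 is assumed, as in the reconstruction code further below):

from keras import ops

def log_mean_exp(log_w, axis=-1):
    """Numerically stable log(mean(exp(log_w))) along `axis`."""
    m = ops.max(log_w, axis=axis, keepdims=True)                 # subtract the max ...
    stable = ops.log(ops.mean(ops.exp(log_w - m), axis=axis))
    return stable + ops.squeeze(m, axis=axis)                    # ... and add it back

# e.g. for a (batch_size, k) tensor of log-weights, the IWAE objective
# would be estimated as ops.mean(log_mean_exp(log_w, axis=1))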

Resample from IWAE proposal#

We define the distribution \(q^k_{IWAE}(z\mid x)\) by

\[ q^k_{IWAE}(z\mid x) = \mathbb{E}_{z_2,\ldots,z_k\sim q_\phi(z\mid x)}\left[ \frac{p_\theta(x, z)}{\frac{1}{k} \left( \frac{p_\theta(x, z)}{q_\phi(z \mid x)} + \sum_{j=2}^k \frac{p_\theta(x, z_j)}{q_\phi(z_j \mid x)} \right)} \right]. \]

Proof that this is a valid density#

We first prove that \(q^k_{IWAE}(z\mid x)\) is a valid normalised distribution; that is, \(\int_z q^k_{IWAE}(z\mid x) dz = 1\).

We denote the weights as \( w_j = \frac{p(x, z_j)}{q(z_j \mid x)} \) for \( j = 2, \dots, k \).

With a slight abuse of notation, we also write \( w_1 = \frac{p(x, z_1)}{q(z_1 \mid x)} \), where \(z_1 := z\) is the variable we integrate over.

Then we have that

\[ q^k_{IWAE}(z_1 \mid x) = \mathbb{E}_{z_2,\ldots,z_k \sim q} \left[ \frac{w_1 \cdot q(z_1 \mid x)}{\frac{1}{k} \sum_{j=1}^k w_j} \right] \]

and integrating over \(z_1\):

\[ \int q^k_{IWAE}(z_1 \mid x) \, dz_1 = \mathbb{E}_{z_2,\ldots,z_k \sim q} \left[ \int \frac{w_1 \cdot q(z_1 \mid x)}{\frac{1}{k} \sum_{j=1}^k w_j} \, dz_1 \right] \]
\[ = \mathbb{E}_{z_2,\ldots,z_k \sim q} \left[ \mathbb{E}_{z_1 \sim q(z_1 \mid x)} \left[ \frac{w_1}{\frac{1}{k} \sum_{j=1}^k w_j} \right] \right] \]

Since \(z_1, \ldots, z_k \sim q(z \mid x)\) are iid, we can merge the nested expectations into a single expectation over the joint distribution of \(z_1, \ldots, z_k\):

\[ = \mathbb{E}_{z_1,\ldots,z_k \sim q} \left[ \frac{w_1}{\frac{1}{k} \sum_{j=1}^k w_j} \right] \]

We note that by symmetry, the ratios \( \frac{w_i}{\sum_{j=1}^k w_j} \), where each \(z_j \sim q(z \mid x)\) iid, all have the same expectation under the joint distribution \(\prod_j q(z_j \mid x)\). Since these ratios sum to 1, each expectation must equal \(\frac{1}{k}\),

i.e. for all \( i = 1, \dots, k \):

\[ \mathbb{E}_{z_1,\ldots,z_k \sim q} \left[ \frac{w_i}{\sum_{j=1}^k w_j} \right] = \frac{1}{k} \]

So that

\[ \int q^k_{IWAE}(z_1 \mid x) \, dz_1 = \mathbb{E}_{z_1,\ldots,z_k \sim q} \left[ \frac{w_1}{\frac{1}{k} \sum_{j=1}^k w_j} \right] = \frac{ \frac{1}{k} }{ \frac{1}{k}} = 1 \]
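
A quick Monte Carlo check of the identity \(\mathbb{E}\left[\frac{w_i}{\sum_j w_j}\right] = \frac{1}{k}\), on an assumed toy Gaussian example (the densities below are illustrative, not the VAE's):

import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
k, n_trials, x = 5, 200_000, 1.5

# Toy densities (illustrative): proposal q(z|x) = N(0, 1),
# joint p(x, z) = N(z; 0, 1) * N(x; z, 1), so w_j = p(x, z_j)/q(z_j|x) = p(x|z_j).
z = rng.normal(0.0, 1.0, size=(n_trials, k))   # z_j ~ q(z|x), iid
log_w = norm(z, 1.0).logpdf(x)                 # log w_j = log p(x|z_j)
w = np.exp(log_w)
ratios = w[:, 0] / w.sum(axis=1)               # w_1 / sum_j w_j for each trial

print(ratios.mean(), 1 / k)                    # both should be close to 0.2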

Sampling from the IWAE posterior#

To construct the self-normalised sample from the IWAE posterior with k=50, we reuse our code from compute_NLL_diag (defined earlier) to obtain \(\log w = (\log w_1, \dots, \log w_k) \), then again subtract the maximum before exponentiating and normalising (the log-sum-exp trick) to ensure numerical stability.

This gives the normalized weights \(\hat{w}_1, \dots, \hat{w}_k\).

We then sample a single index \(j \sim \text{Categorical}(\hat{w}_1, \dots, \hat{w}_k)\) and return \(z_j\).

In practice, the latent samples z are of shape (128, k, 2). We therefore sample one index per row, obtaining indices of shape (128,), where each row picks a single index from 1, 2, ..., k. Applying these indices along the second dimension of z gives the re-sampled latents of shape (128, 2), which are fed into the decoder to obtain x_recon for plotting.

import numpy as np
import torch
import keras
from keras import ops  # assumes Keras 3 with the PyTorch backend (torch is used below)

def recon_IWAE_posterior(model, data, k=50):
    """
    Self-normalised IWAE reconstruction: resample a single z per image
    from the IWAE posterior q^k_IWAE(z|x) and decode it.
    """

    # we first use code from compute_NLL_diag to obtain z1..zk ~ q(z|x)
    z_mean, z_log_var = model.encoder(data)
    batch_size = ops.shape(z_mean)[0]

    # z_mean and z_log_var have (128, 2) shape
    # expand to (128, k, 2)
    z_mean = ops.broadcast_to(ops.expand_dims(z_mean, 1), 
                              (batch_size, k, -1))
    z_log_var = ops.broadcast_to(ops.expand_dims(z_log_var, 1), 
                                 (batch_size, k, -1))
    
    eps = keras.random.normal(ops.shape(z_mean))
    z = z_mean + ops.exp(0.5 * z_log_var) * eps # (128, k, 2)
    z_flat = ops.reshape(z, (-1, 2)) # (128 * k, 2)
    x_recon = model.decoder(z_flat) # (128 * k, 28, 28, 1)
    
    # reshape to (128, k, 28, 28, 1)
    x_recon = ops.reshape(x_recon, (batch_size, k, 28, 28, 1))
    
    # expand data[0] to (128, k, 28, 28, 1)
    x_expand = ops.broadcast_to(ops.expand_dims(data[0], 1), 
                                (batch_size, k, 28, 28, 1))

    # -log p(x|z): BCE
    nll = ops.binary_crossentropy(x_expand, x_recon)
    nll = ops.sum(nll, axis=[2, 3, 4]) # (128, k)
    
    # log p(z)
    log_p_z = -0.5 * ops.sum(z**2 + np.log(2 * np.pi), axis=-1)

    # log q(z|x) = log N(z_mean, exp z_log_var)
    log_q_z = -0.5 * ops.sum(
        (z - z_mean)**2 / ops.exp(z_log_var) 
        + z_log_var + np.log(2 * np.pi), axis=-1
    )  # (batch, k)

    # obtained (logw1 ... logwk) 
    log_w = -nll + log_p_z - log_q_z  # (batch, k)
    
    # remove max for numerical stability
    max_log_w = ops.max(log_w, axis=1, keepdims=True) 
    # log-sum-exp trick  
    w_unnorm = ops.exp(log_w - max_log_w)  # (batch, k)             
    w_sum = ops.sum(w_unnorm, axis=1, keepdims=True)     
    w_norm = w_unnorm / w_sum   

    # sample from categorical(k) for each data in batch
    # obtain indices of (128, ) shape
    indices = torch.multinomial(w_norm, num_samples=1).squeeze()

    print('sampled indices shape', indices.shape)

    batch_indices = torch.arange(batch_size)

    # Final sampled output: shape (128, 2)
    z_IWAE_sampled = z[batch_indices, indices] 
    print('sampled z shape', z_IWAE_sampled.shape)
    
    # using the sampled z, reconstruct the image
    x_mean = model.decoder(z_IWAE_sampled)
    return x_mean

# test shape
for batch in test_loader:
    print(recon_IWAE_posterior(vae_iwae_diag, batch, k=50).shape)
    break