Importance Weighted AutoEncoder#
This strategy applies both to the IWAE objective (which achieves a tighter ELBO) and to importance resampling (which reduces variance).
Setup#
Suppose we know \(q(z \mid x)\) and \(p(x, z)\), and wish to sample from the mixture \(q^k(z \mid x)\).
Notice that if we let \(\tilde{w}_j = \frac{w_j}{\sum_{l=1}^k w_l}\) be the normalized weights, then, conditional on draws \(z_1, \ldots, z_k \sim q(z \mid x)\), the mixture \(q^k(z \mid x)\) is given by

\[
q^k(z \mid x) = \sum_{j=1}^{k} \tilde{w}_j \, \delta_{z_j}(z).
\]
We can sample \( z \sim q^k(z \mid x) \) as follows (a minimal code sketch is given after the list):

1. Sample \( z_1, z_2, \ldots, z_k \sim q(z \mid x) \)
2. Compute the importance weights \( w_i = \frac{p(x, z_i)}{q(z_i \mid x)} \)
3. Normalize the weights: \( \tilde{w}_i = \frac{w_i}{\sum_{l=1}^k w_l} \)
4. Sample an index \( j \sim \text{Categorical}(\tilde{w}_1, \ldots, \tilde{w}_k) \)
5. Output \( \hat{z} = z_j \)
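Below is a minimal NumPy sketch of this sampling-importance-resampling step; `sample_q`, `log_p_xz`, and `log_q_z` are placeholder callables standing in for \(q(z \mid x)\), \(\log p(x, z)\) and \(\log q(z \mid x)\) (assumptions for illustration, not names from the code later in this notebook).

import numpy as np

def sir_resample(x, sample_q, log_p_xz, log_q_z, k=50, rng=None):
    """Draw one z approximately from q^k(z|x) via sampling-importance-resampling."""
    rng = rng or np.random.default_rng()
    z = sample_q(x, k)                       # step 1: z_1..z_k ~ q(z|x), shape (k, d)
    log_w = log_p_xz(x, z) - log_q_z(x, z)   # step 2: log importance weights, shape (k,)
    log_w = log_w - log_w.max()              # subtract max before exponentiating (stability)
    w = np.exp(log_w)
    w_tilde = w / w.sum()                    # step 3: normalized weights
    j = rng.choice(k, p=w_tilde)             # step 4: j ~ Categorical(w_tilde)
    return z[j]                              # step 5: output z_j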
VAE Application#
IWAE objective#
The IWAE objective \(\mathcal{L}_{IWAE}^k\) can be used as an optimisation objective in place of the VAE ELBO \(\mathcal{L}_{VAE}\). It was introduced in *Importance Weighted Autoencoders* (Burda, Grosse and Salakhutdinov, 2015) and is defined as

\[
\mathcal{L}_{IWAE}^k = \mathbb{E}_{z_1, \ldots, z_k \sim q_\phi(z \mid x)}\left[\log \frac{1}{k} \sum_{i=1}^{k} \frac{p_\theta(x, z_i)}{q_\phi(z_i \mid x)}\right],
\]

where it is shown that \(\mathcal{L}_{IWAE}^k\) is a tighter lower bound on the marginal likelihood than the ELBO:

\[
\log p_\theta(x) \;\ge\; \mathcal{L}_{IWAE}^{k+1} \;\ge\; \mathcal{L}_{IWAE}^{k} \;\ge\; \mathcal{L}_{IWAE}^{1} = \mathcal{L}_{VAE}.
\]
It is also shown (under certain conditions on \(q_\phi(z\mid x)\)) that \(\mathcal{L}_{IWAE}^k \to \log p_\theta(x)\) as \(k\to\infty\).
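As an illustration (not from the original notebook), the toy example below estimates \(\mathcal{L}_{IWAE}^k\) by Monte Carlo for a 1-D conjugate Gaussian model where \(\log p(x)\) is available in closed form, using the prior as a deliberately poor proposal; the estimated bound tightens towards \(\log p(x)\) as \(k\) grows. The model and all names here are assumptions made for the example.

import numpy as np

rng = np.random.default_rng(0)
x = 1.5

def gauss_logpdf(v, mean, var):
    return -0.5 * (np.log(2 * np.pi * var) + (v - mean) ** 2 / var)

# toy model: p(z) = N(0, 1), p(x|z) = N(z, 1)  =>  p(x) = N(0, 2)
# proposal q(z|x) = p(z), i.e. it deliberately ignores x
true_log_px = gauss_logpdf(x, 0.0, 2.0)

def iwae_bound(k, n_outer=20000):
    # with q = p(z): log w_i = log p(x, z_i) - log q(z_i|x) = log p(x|z_i)
    z = rng.standard_normal((n_outer, k))
    log_w = gauss_logpdf(x, z, 1.0)                      # (n_outer, k)
    m = log_w.max(axis=1, keepdims=True)                 # stabilise the log-mean-exp
    log_mean_w = m.squeeze(1) + np.log(np.exp(log_w - m).mean(axis=1))
    return float(log_mean_w.mean())                      # MC estimate of L_IWAE^k

for k in [1, 5, 50]:
    print(f"k={k:3d}  IWAE bound ~ {iwae_bound(k):.4f}   log p(x) = {true_log_px:.4f}")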
Log-Mean-Exp trick#
For \( z_1,\ldots,z_k\sim q_\phi(z\mid x) \), denote

\[
\log w_i = \log p_\theta(x, z_i) - \log q_\phi(z_i \mid x).
\]

Whereas previously we computed a single \( \log w_i \) as our one-sample MC estimate of the ELBO, we now compute

\[
\log \frac{1}{k}\sum_{i=1}^{k} w_i = \log \frac{1}{k}\sum_{i=1}^{k} \exp(\log w_i).
\]

We proceed by first obtaining the \((\text{batch\_size}, k)\) tensor of log-weights \((\log w_1, \ldots, \log w_k)\), and then using the log-mean-exp trick to compute \(\log\!\left(\frac{1}{k}\sum_i w_i\right)\) with numerical stability.
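A minimal sketch of such a log-mean-exp helper using keras.ops (the name `log_mean_exp` and its interface are illustrative assumptions, not taken from the notebook):

from keras import ops

def log_mean_exp(log_w, axis=-1):
    # log( mean( exp(log_w) ) ) computed stably:
    # subtract the per-row max before exponentiating to avoid overflow
    max_log_w = ops.max(log_w, axis=axis, keepdims=True)
    log_mean = ops.log(ops.mean(ops.exp(log_w - max_log_w), axis=axis))
    return log_mean + ops.squeeze(max_log_w, axis=axis)

For log_w of shape (batch_size, k), log_mean_exp(log_w, axis=1) returns the per-example IWAE objective of shape (batch_size,).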
Resample from IWAE proposal#
We define the distribution \(q^k_{IWAE}(z\mid x)\) by

\[
q^k_{IWAE}(z \mid x) = \mathbb{E}_{z_2, \ldots, z_k \sim q(z \mid x)}\left[\frac{w_1}{\frac{1}{k}\sum_{j=1}^{k} w_j}\right] q(z \mid x),
\qquad w_1 = \frac{p(x, z)}{q(z \mid x)}, \quad w_j = \frac{p(x, z_j)}{q(z_j \mid x)} \ \text{for } j \ge 2.
\]
Proof that this is a valid density#
We first prove that \(q^k_{IWAE}(z\mid x)\) is a valid normalised distribution; that is, \(\int_z q^k_{IWAE}(z\mid x) dz = 1\).
We denote the weights as \( w_j = \frac{p(x, z_j)}{q(z_j \mid x)} \), where \( j = 2 \dots k \)
We write \( w_1 = \frac{p(x, z_1)}{q(z_1 \mid x)} \), where we denote \(z_1 := z\), the variable we are integrating over, with a slight abuse of notation
Then we have that

\[
q^k_{IWAE}(z_1 \mid x) = q(z_1 \mid x)\, \mathbb{E}_{z_2, \ldots, z_k \sim q(z \mid x)}\left[\frac{w_1}{\frac{1}{k}\sum_{j=1}^{k} w_j}\right],
\]

and integrating over \(z_1\):

\[
\int q^k_{IWAE}(z_1 \mid x)\, dz_1
= \mathbb{E}_{z_1 \sim q(z \mid x)}\left[\mathbb{E}_{z_2, \ldots, z_k \sim q(z \mid x)}\left[\frac{w_1}{\frac{1}{k}\sum_{j=1}^{k} w_j}\right]\right].
\]

Since \(z_1, \ldots, z_k \sim q(z \mid x)\) iid, we can merge the outer expectation over \(z_1\) with the inner one:

\[
\int q^k_{IWAE}(z_1 \mid x)\, dz_1
= \mathbb{E}_{z_1, \ldots, z_k \sim q(z \mid x)}\left[\frac{w_1}{\frac{1}{k}\sum_{j=1}^{k} w_j}\right].
\]

We note that by symmetry each term \(\frac{w_i}{\frac{1}{k}\sum_{j} w_j}\), where \(w_i = \frac{p(x, z_i)}{q(z_i \mid x)}\) and the \(z_i \sim q(z \mid x)\) are iid, has the same expectation under the joint distribution \(\prod_j q(z_j \mid x)\), i.e. for all \( i = 1 \dots k\):

\[
\mathbb{E}_{z_1, \ldots, z_k}\left[\frac{w_i}{\frac{1}{k}\sum_{j=1}^{k} w_j}\right]
= \mathbb{E}_{z_1, \ldots, z_k}\left[\frac{w_1}{\frac{1}{k}\sum_{j=1}^{k} w_j}\right].
\]

So that

\[
\int q^k_{IWAE}(z_1 \mid x)\, dz_1
= \frac{1}{k}\sum_{i=1}^{k} \mathbb{E}\left[\frac{w_i}{\frac{1}{k}\sum_{j=1}^{k} w_j}\right]
= \frac{1}{k}\, \mathbb{E}\left[\frac{\sum_{i=1}^{k} w_i}{\frac{1}{k}\sum_{j=1}^{k} w_j}\right]
= \frac{1}{k} \cdot k = 1.
\]
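As a numerical sanity check (an illustration added here, not part of the original notebook), the sketch below integrates \(q^k_{IWAE}(z \mid x)\) over a grid for the same toy 1-D Gaussian model used above, estimating the inner expectation over \(z_2, \ldots, z_k\) by Monte Carlo; the integral comes out close to 1.

import numpy as np

rng = np.random.default_rng(0)
x, k, n_mc = 1.5, 5, 4000

def gauss_pdf(v, mean, var):
    return np.exp(-0.5 * (v - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

# toy model: p(z) = N(0, 1), p(x|z) = N(z, 1), proposal q(z|x) = p(z)
def w(z):
    # importance weight p(x, z) / q(z|x); with q = p(z) this reduces to p(x|z)
    return gauss_pdf(z, 0.0, 1.0) * gauss_pdf(x, z, 1.0) / gauss_pdf(z, 0.0, 1.0)

z_grid = np.linspace(-8.0, 8.0, 801)              # grid over z_1 (the z we integrate over)
dz = z_grid[1] - z_grid[0]
z_rest = rng.standard_normal((n_mc, k - 1))       # z_2..z_k ~ q(z|x) for the inner expectation

density = np.empty_like(z_grid)
for i, z1 in enumerate(z_grid):
    denom = (w(z1) + w(z_rest).sum(axis=1)) / k   # (1/k) * sum_j w_j
    density[i] = gauss_pdf(z1, 0.0, 1.0) * np.mean(w(z1) / denom)

print("integral of q^k_IWAE over z:", float(np.sum(density) * dz))   # ~ 1.0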
Sampling from it#
To construct self-normalised sampling from the IWAE posterior with \(k=50\), we reuse the code from compute_NLL_diag defined before to obtain \(\log w = (\log w_1, \dots, \log w_k)\), then again subtract the maximum and apply the log-sum-exp trick to ensure numerical stability, which gives the normalized weights \(\tilde{w}_1, \dots, \tilde{w}_k\).
We then sample a single index \(j \sim \text{Categorical}(\tilde{w}_1, \dots, \tilde{w}_k)\) and return \(z_j\).
In practice, the tensor of latent samples z has shape (128, k, 2), so this means obtaining sample indices of shape (128,): on each row, a single index from \(1, 2, \ldots, k\) is sampled. We then apply the sampled indices along the second dimension of z to get the re-sampled latents of shape (128, 2), which are fed into the decoder to obtain x_recon for plotting.
def recon_IWAE_posterior(model, data, k=50):
"""
Self-normalised IWAE reconstruction.
"""
# we first use code from compute_NLL_diag to obtain z1..zk ~ q(z|x)
z_mean, z_log_var = model.encoder(data)
    batch_size = ops.shape(z_mean)[0]
    latent_dim = ops.shape(z_mean)[-1]
    # z_mean and z_log_var have shape (128, 2)
    # expand to (128, k, 2); use the explicit latent_dim
    # (passing -1 to broadcast_to is backend-dependent)
    z_mean = ops.broadcast_to(ops.expand_dims(z_mean, 1),
                              (batch_size, k, latent_dim))
    z_log_var = ops.broadcast_to(ops.expand_dims(z_log_var, 1),
                                 (batch_size, k, latent_dim))
eps = keras.random.normal(ops.shape(z_mean))
z = z_mean + ops.exp(0.5 * z_log_var) * eps # (128, k, 2)
z_flat = ops.reshape(z, (-1, 2)) # (128 * k, 2)
x_recon = model.decoder(z_flat) # (128 * k, 28, 28, 1)
# reshape to (128, k, 28, 28, 1)
x_recon = ops.reshape(x_recon, (batch_size, k, 28, 28, 1))
# expand data[0] to (128, k, 28, 28, 1)
x_expand = ops.broadcast_to(ops.expand_dims(data[0], 1),
(batch_size, k, 28, 28, 1))
# -log p(x|z): BCE
nll = ops.binary_crossentropy(x_expand, x_recon)
nll = ops.sum(nll, axis=[2, 3, 4]) # (128, k)
# log p(z)
log_p_z = -0.5 * ops.sum(z**2 + np.log(2 * np.pi), axis=-1)
# log q(z|x) = log N(z_mean, exp z_log_var)
log_q_z = -0.5 * ops.sum(
(z - z_mean)**2 / ops.exp(z_log_var)
+ z_log_var + np.log(2 * np.pi), axis=-1
) # (batch, k)
# obtained (logw1 ... logwk)
log_w = -nll + log_p_z - log_q_z # (batch, k)
# remove max for numerical stability
max_log_w = ops.max(log_w, axis=1, keepdims=True)
# log-sum-exp trick
w_unnorm = ops.exp(log_w - max_log_w) # (batch, k)
w_sum = ops.sum(w_unnorm, axis=1, keepdims=True)
w_norm = w_unnorm / w_sum
    # sample j ~ Categorical(w_norm) for each example in the batch
    # (torch.multinomial is used directly, which assumes the Keras torch backend)
    # obtain indices of shape (128,)
indices = torch.multinomial(w_norm, num_samples=1).squeeze()
print('sampled indices shape', indices.shape)
batch_indices = torch.arange(batch_size)
# Final sampled output: shape (128, 2)
z_IWAE_sampled = z[batch_indices, indices]
print('sampled z shape', z_IWAE_sampled.shape)
# using the sampled z, reconstruct the image
x_mean = model.decoder(z_IWAE_sampled)
return x_mean
# test shape
for batch in test_loader:
print(recon_IWAE_posterior(vae_iwae_diag, batch, k=50).shape)
break