Log-Sum-Exp Trick
In importance sampling we often have to deal with weights \(w_1, \dots, w_N\), some of which can become very small.
We therefore usually work in the log domain, i.e. with \(\log w_1, \dots, \log w_N\), where some \(\log w_i\) can approach \(-\infty\).
When we need the log mean weight \(\log \left( \frac{1}{N} \sum_{i=1}^N w_i \right)\), or when we have unnormalized log weights \(\log W_1, \dots, \log W_N\) and need the sum \(\sum_i W_i\) to normalize them, we use the log-sum-exp trick.
We note that
\[
\log \sum_{i=1}^N w_i = \log \sum_{i=1}^N \exp(\log w_i).
\]
Idea: let \(w_j\) be the largest weight. Since \(\log\) is monotone, \(\log w_j\) is also the largest log weight.
Then
\[
\log \sum_{i=1}^N w_i = \log w_j + \log \sum_{i=1}^N \exp(\log w_i - \log w_j).
\]
The reason we subtract the maximum log weight from each \(\log w_i\) before exponentiating:
Suppose the log weights are \(-100, -99, -98\). After subtracting the maximum we get \(-2, -1, 0\).
Exponentiating these is far less likely to underflow, since the largest term is always \(\exp(0) = 1\).
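A quick numerical check (a minimal sketch; the extreme log weights below are made-up values chosen to force underflow in float64):

import numpy as np

log_w = np.array([-1000.0, -999.0, -998.0])  # illustrative log weights, extreme enough to underflow

# Naive approach: exp(-1000) underflows to 0.0, so the sum is 0 and its log is -inf
naive = np.log(np.sum(np.exp(log_w)))

# Log-sum-exp trick: subtract the maximum before exponentiating, add it back afterwards
max_log_w = np.max(log_w)
stable = max_log_w + np.log(np.sum(np.exp(log_w - max_log_w)))

print(naive)   # -inf (NumPy also emits a divide-by-zero warning)
print(stable)  # about -997.59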
Following the same idea, the log mean weight is
\[
\log \left( \frac{1}{N} \sum_{i=1}^N w_i \right) = \log w_j + \log \sum_{i=1}^N \exp(\log w_i - \log w_j) - \log N.
\]
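A direct translation of this formula (a sketch; log_mean_exp is an illustrative helper name, and the example values are the same made-up ones as above):

import numpy as np

def log_mean_exp(log_w):
    # Stably computes log( (1/N) * sum_i exp(log_w[i]) ).
    # np.mean folds the 1/N factor, i.e. the -log N term, into the computation.
    max_log_w = np.max(log_w)
    return max_log_w + np.log(np.mean(np.exp(log_w - max_log_w)))

log_w = np.array([-1000.0, -999.0, -998.0])
print(log_mean_exp(log_w))  # about -998.69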
And for the normalized weights, i.e. \(w_k = \frac{W_k}{\sum_i W_i}\), working in the log domain gives
\[
\log w_k = \log W_k - \left( \log W_{\max} + \log \sum_{i=1}^N \exp(\log W_i - \log W_{\max}) \right),
\]
where \(\log W_{\max} = \max_i \log W_i\). This is exactly what the code below does:
import numpy as np

weights = np.array([0.1, 0.2, 0.3, 0.4])
weights /= np.random.uniform(1, 1000)   # scale by a random constant -> unnormalized weights
log_W = np.log(weights)                 # unnormalized log weights
print("unnormalized log weights:", log_W)

# Normalize in the log domain using the log-sum-exp trick
max_log_W = np.max(log_W)
log_W -= max_log_W + np.log(np.sum(np.exp(log_W - max_log_W)))
weights = np.exp(log_W)                 # normalized weights
print(weights)
unnormalized log weights: [-8.63722187 -7.94407469 -7.53860958 -7.25092751]
[0.1 0.2 0.3 0.4]
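In practice you can also use scipy.special.logsumexp, which implements the same trick. A sketch of the equivalent normalization, assuming SciPy is installed (the scale factor 563.0 is arbitrary, just to make the weights unnormalized):

import numpy as np
from scipy.special import logsumexp

log_W = np.log(np.array([0.1, 0.2, 0.3, 0.4]) / 563.0)  # arbitrary unnormalized log weights
log_w = log_W - logsumexp(log_W)                         # normalize in the log domain
print(np.exp(log_w))                                     # [0.1 0.2 0.3 0.4]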