Log-Sum-Exp Trick

Often when we do Importance Sampling, we need to deal with weights \(w_1, \dots, w_N\), some of which can become very small.

We therefore often operate in the log domain with \(\log w_1, \dots, \log w_N\), where some \(\log w_i\) can approach \(-\infty\).

When we need to compute the log mean weight \(\log \left( \frac{1}{N} \sum_{i=1}^N w_i \right)\), or when we have unnormalized log weights \(\log W_1, \dots, \log W_N\) and need the sum \(\sum_i W_i\) to normalize them, we use the log-sum-exp trick.

We note that

\[\begin{align*} \log \left( \sum_{i=1}^N w_i \right) =& \log \left( \sum_{i=1}^N \exp(\log w_i) \right) \\ =& \log \left( \sum_{i=1}^N \exp(\log w_i - \max_j \log w_j) \right) + \max_j \log w_j \\ \end{align*}\]
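The identity above translates directly into NumPy. A minimal sketch (the function name `logsumexp` and the example values are ours, for illustration):

```python
import numpy as np

def logsumexp(log_w):
    """Compute log(sum(exp(log_w))) stably.

    Mirrors the identity: shift by the max, exponentiate,
    sum, take the log, then add the max back.
    """
    log_w = np.asarray(log_w, dtype=float)
    m = np.max(log_w)
    return m + np.log(np.sum(np.exp(log_w - m)))

# Agrees with the naive computation when the values are moderate:
print(logsumexp(np.log([1.0, 2.0, 3.0])))  # log(6) ≈ 1.7918
```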

Idea: let \(w_j\) be the largest weight. Since \(\log\) is monotone, \(\log w_j = \max_k \log w_k\) is also the largest log weight.

Then

\[\begin{align*} \log \sum w_i =& \log \sum_i \frac{w_i}{w_j} w_j \\ =& \log \sum_i \frac{w_i}{w_j} + \log w_j \\ =& \log( \sum_i \exp(\log w_i - \log w_j)) + \log w_j \\ \end{align*}\]

The reason we subtract the max log weight from each \(\log w_i\): suppose the log weights are \(-100, -99, -98\). After subtracting the max they become \(-2, -1, 0\), so exponentiating them is far less likely to underflow (the largest term is exactly \(\exp(0) = 1\)).
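This underflow is easy to demonstrate with more extreme (illustrative) values, where the naive computation collapses to \(-\infty\) while the shifted one does not:

```python
import numpy as np

log_w = np.array([-1000.0, -999.0, -998.0])

# Naive: every exp underflows to 0.0, so the log is -inf.
with np.errstate(divide='ignore'):
    naive = np.log(np.sum(np.exp(log_w)))

# Stable: subtract the max before exponentiating.
m = np.max(log_w)
stable = m + np.log(np.sum(np.exp(log_w - m)))

print(naive)   # -inf
print(stable)  # ≈ -997.59
```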

Following the same idea, the log mean weight is:

\[\begin{align*} \log \left( \frac{1}{N} \sum_{i=1}^N w_i \right) = \log \left( \sum_{i=1}^N w_i \right) - \log N \end{align*}\]
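As a quick sketch of this (reusing the illustrative extreme values from before):

```python
import numpy as np

log_w = np.array([-1000.0, -999.0, -998.0])
N = len(log_w)

# log mean = logsumexp(log_w) - log N
m = np.max(log_w)
log_mean = m + np.log(np.sum(np.exp(log_w - m))) - np.log(N)
print(log_mean)  # ≈ -998.69
```

In practice, `scipy.special.logsumexp` implements this same shift-by-the-max computation, so the sum term can be replaced by a library call.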

And for the normalized weights, i.e. \(w_i = \frac{W_i}{\sum_j W_j}\):

\[\begin{align*} \log w_i =& \log W_i - \log \sum_j W_j \\ =& \log W_i - \log \left( \sum_{j=1}^N \exp(\log W_j - \max_k \log W_k) \right) - \max_k \log W_k \\ \end{align*}\]
import numpy as np

weights = np.array([0.1, 0.2, 0.3, 0.4])
weights /= np.random.uniform(1, 1000)  # scale so the weights are unnormalized
log_W = np.log(weights)                # unnormalized log weights
print("unnormalized log weights:", log_W)

# Normalize in the log domain: log w_i = log W_i - logsumexp(log W)
log_W -= np.log(np.sum(np.exp(log_W - np.max(log_W)))) + np.max(log_W)
weights = np.exp(log_W)
print(weights)

unnormalized log weights: [-8.63722187 -7.94407469 -7.53860958 -7.25092751]
[0.1 0.2 0.3 0.4]