Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Numeric/SpecFunctions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ module Numeric.SpecFunctions (
, log1p
, log1pmx
, log2
-- * Log-sum-exp
, logSumExp
, logSumExpPair
-- * Exponent
, expm1
-- * Factorial
Expand Down
29 changes: 29 additions & 0 deletions Numeric/SpecFunctions/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,35 @@ log1pmx x
where
ax = abs x

-- | Compute log(sum(exp(x_i))) in a numerically stable way using
-- the log-sum-exp trick. This is useful when working with log
-- probabilities to avoid overflow and underflow.
--
-- Uses the identity:
--
-- \[
-- \log \sum_i \exp(x_i) = m + \log \sum_i \exp(x_i - m)
-- \]
--
-- where \(m = \max_i x_i\).
--
-- Returns @-Infinity@ for an empty vector.
logSumExp :: U.Vector Double -> Double
logSumExp xs
| U.null xs = m_neg_inf
| otherwise = m + log (U.sum (U.map (\x -> exp (x - m)) xs))
where
m = U.maximum xs

-- | Compute @log(exp(a) + exp(b))@ in a numerically stable way.
--
-- This is a special case of 'logSumExp' for two arguments, useful
-- when combining two log-probabilities.
logSumExpPair :: Double -> Double -> Double
logSumExpPair a b
| a >= b = a + log1p (exp (b - a))
| otherwise = b + log1p (exp (a - b))

-- | /O(log n)/ Compute the logarithm in base 2 of the given value.
log2 :: Int -> Int
log2 v0
Expand Down
25 changes: 25 additions & 0 deletions tests/Tests/SpecFunctions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,31 @@ tests = testGroup "Special functions"
checkTabularPure 1 (show x) exact (log1p x)
]
----------------
, testGroup "logSumExp"
[ testProperty "logSumExp [a] == a" $ \a ->
not (isNaN a) ==> logSumExp (U.singleton a) == a
, testProperty "logSumExpPair commutative" $ \a b ->
not (isNaN a) && not (isNaN b) ==>
logSumExpPair a b == logSumExpPair b a
, testProperty "logSumExpPair a a == a + log 2" $ \a ->
not (isNaN a) && not (isInfinite a) ==>
within 2 (logSumExpPair a a) (a + log 2)
, testProperty "logSumExp recovers log(sum(exp(x)))" $ \(getNonEmpty -> xs) ->
let v = U.fromList xs
naive = log (U.sum (U.map exp v))
stable = logSumExp v
in not (any isNaN xs) && not (any isInfinite xs) && all (\x -> abs x < 300) xs ==>
within 4 naive stable
, testCase "logSumExp empty == -Infinity" $
assertBool "should be -Infinity" (isInfinite (logSumExp U.empty) && logSumExp U.empty < 0)
, testCase "logSumExp with large values" $ do
let result = logSumExp (U.fromList [1000, 1001, 1002])
assertBool "should be close to 1002.41" (within 4 result 1002.4076059644443)
, testCase "logSumExpPair log-probability combination" $ do
let result = logSumExpPair (-1000) (-1001)
assertBool "should be close to -999.687" (within 4 result (-999.6867383124818))
]
----------------
, testGroup "gamma function"
[ testCase "logGamma table [fractional points" $
forTable "tests/tables/loggamma.dat" $ \[x, exact] -> do
Expand Down