diff --git a/root/findbracket.go b/root/findbracket.go new file mode 100644 index 0000000..5ba960b --- /dev/null +++ b/root/findbracket.go @@ -0,0 +1,59 @@ +// Copyright ©2025 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package root + +import "math" + +// FindBracketMono finds a bracket interval [a, b] where f(a)f(b) < 0. +// f must be a monotonically increasing function. +// guess is the initial guess of the bracket search. +func FindBracketMono(f func(float64) float64, guess float64) (a, b float64) { + // Make sure initial guess has the same sign as the root. + f0 := f(0) + if (guess < 0 && f0 < 0) || (guess > 0 && f0 > 0) { + guess *= -1 + } + + // r is the rate in which we adjust the interval. + var r float64 + a = guess + fa := f(a) + if (a > 0) == (fa < 0) { + r = 2 + } else { + r = 0.5 + } + + // Expand bracket until x-axis is crossed. + // maxiter value is based on https://github.com/boostorg/math/blob/boost-1.88.0/include/boost/math/policies/policy.hpp#L130 + const maxiter = 200 + crossed := false + b = a * r + fb := f(b) + for range maxiter { + if math.Signbit(fa) != math.Signbit(fb) || fa == 0 || fb == 0 { + crossed = true + break + } + a, fa = b, fb + b *= r + fb = f(b) + } + // If unable to cross x-axis, return the largest possible bracket. + if !crossed { + if r > 1 { + b = math.Inf(int(math.Copysign(1, b))) + } else { + b = 0 + } + } + + // Ensure a <= b + if a > b { + a, b = b, a + } + + return a, b +} diff --git a/root/findbracket_test.go b/root/findbracket_test.go new file mode 100644 index 0000000..ff4f1e1 --- /dev/null +++ b/root/findbracket_test.go @@ -0,0 +1,73 @@ +// Copyright ©2025 The Gonum Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package root_test + +import ( + "math" + "testing" + + "gonum.org/v1/exp/root" +) + +var findBracketMonoTests = []struct { + name string + f func(float64) float64 + guess float64 + validBracket func(a, b float64) bool +}{ + // Based on https://github.com/boostorg/math/blob/boost-1.88.0/test/test_toms748_solve.cpp + {name: "f4.4", f: func(x float64) float64 { return math.Pow(x, 4) - 0.2 }, guess: 2}, + {name: "f4.6", f: func(x float64) float64 { return math.Pow(x, 6) - 0.2 }, guess: 2}, + {name: "f4.8", f: func(x float64) float64 { return math.Pow(x, 8) - 0.2 }, guess: 2}, + {name: "f4.10", f: func(x float64) float64 { return math.Pow(x, 10) - 0.2 }, guess: 2}, + {name: "f4.12", f: func(x float64) float64 { return math.Pow(x, 12) - 0.2 }, guess: 2}, + + // Based on https://github.com/boostorg/math/blob/boost-1.88.0/test/test_root_finding_concepts.cpp + {name: "f1", f: func(x float64) float64 { return x*x*x - 27 }, guess: 27}, + + // Special cases. + {name: "+Inf", f: func(x float64) float64 { return math.Atan(x) - 2 }, guess: 3, validBracket: func(a, b float64) bool { return a > 0 && math.IsInf(b, 1) }}, + {name: "-Inf", f: func(x float64) float64 { return math.Atan(x) + 2 }, guess: 3, validBracket: func(a, b float64) bool { return math.IsInf(a, -1) && b < 0 }}, + {name: "tiny positive", f: func(x float64) float64 { + rt := math.SmallestNonzeroFloat64 + switch { + case x > rt: + return 1 + case x < rt: + return -1 + default: + return 0 + } + }, guess: 3, validBracket: func(a, b float64) bool { return a == 0 && b > 0 }}, + {name: "tiny negative", f: func(x float64) float64 { + rt := -math.SmallestNonzeroFloat64 + switch { + case x > rt: + return 1 + case x < rt: + return -1 + default: + return 0 + } + }, guess: -3, validBracket: func(a, b float64) bool { return a < 0 && b == 0 }}, +} + +func TestFindBracketMono(t *testing.T) { + t.Parallel() + + for _, test := range findBracketMonoTests { + t.Run(test.name, func(t *testing.T) { + validBracket := test.validBracket + if validBracket == nil { + validBracket = func(a, b float64) bool { return test.f(a)*test.f(b) < 0 && a <= b } + } + + a, b := root.FindBracketMono(test.f, test.guess) + if !validBracket(a, b) { + t.Errorf("%s: invalid bracket (%f, %f)", test.name, a, b) + } + }) + } +}