Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -1708,21 +1708,25 @@ func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...in
}

// matchRegexp return true if a specified regexp matches a string.
func matchRegexp(rx interface{}, str interface{}) bool {
func matchRegexp(rx interface{}, str interface{}) (bool, error) {
var r *regexp.Regexp
if rr, ok := rx.(*regexp.Regexp); ok {
r = rr
} else {
r = regexp.MustCompile(fmt.Sprint(rx))
var err error
r, err = regexp.Compile(fmt.Sprint(rx))
if err != nil {
return false, err
}
}

switch v := str.(type) {
case []byte:
return r.Match(v)
return r.Match(v), nil
case string:
return r.MatchString(v)
return r.MatchString(v), nil
default:
return r.MatchString(fmt.Sprint(v))
return r.MatchString(fmt.Sprint(v)), nil
}
}

Expand All @@ -1735,7 +1739,10 @@ func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface
h.Helper()
}

match := matchRegexp(rx, str)
match, err := matchRegexp(rx, str)
if err != nil {
return Fail(t, fmt.Sprintf("Invalid regexp: %v", err), msgAndArgs...)
}

if !match {
Fail(t, fmt.Sprintf("Expect \"%v\" to match \"%v\"", str, rx), msgAndArgs...)
Expand All @@ -1752,7 +1759,10 @@ func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interf
if h, ok := t.(tHelper); ok {
h.Helper()
}
match := matchRegexp(rx, str)
match, err := matchRegexp(rx, str)
if err != nil {
return Fail(t, fmt.Sprintf("Invalid regexp: %v", err), msgAndArgs...)
}

if match {
Fail(t, fmt.Sprintf("Expect \"%v\" to NOT match \"%v\"", str, rx), msgAndArgs...)
Expand Down
15 changes: 15 additions & 0 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2557,6 +2557,21 @@ func TestRegexp(t *testing.T) {
}
}

func TestRegexpInvalidExpression(t *testing.T) {
t.Parallel()

// \C is an invalid escape sequence, it should fail gracefully.
mockT := new(mockTestingT)
False(t, Regexp(mockT, `\C`, "some string"))
True(t, mockT.Failed())
Contains(t, mockT.errorString(), "Invalid regexp")

mockT2 := new(mockTestingT)
False(t, NotRegexp(mockT2, `\C`, "some string"))
True(t, mockT2.Failed())
Contains(t, mockT2.errorString(), "Invalid regexp")
}

func testAutogeneratedFunction() {
defer func() {
if err := recover(); err == nil {
Expand Down