diff --git a/pkg/util/errutil/errors.go b/pkg/util/errutil/errors.go index afdba2604f3..6feb0c7dac9 100644 --- a/pkg/util/errutil/errors.go +++ b/pkg/util/errutil/errors.go @@ -80,16 +80,40 @@ func (b Base) Errorf(format string, args ...interface{}) Error { } } +// Error makes Base implement the error type. Relying on this is +// discouraged, as the Error type can carry additional information +// that's valuable when debugging. +func (b Base) Error() string { + return b.Errorf("").Error() +} + +func (b Base) Status() StatusReason { + if b.reason == nil { + return StatusUnknown + } + return b.reason.Status() +} + // Is validates that an Error has the same reason and messageID as the // Base. func (b Base) Is(err error) bool { - gfErr := Error{} - ok := errors.As(err, &gfErr) - if !ok { + // The linter complains that it wants to use errors.As because it + // handles unwrapping, we don't want to do that here since we want + // to validate the equality between the two objects. + // errors.Is handles the unwrapping, should you want it. + //nolint:errorlint + base, isBase := err.(Base) + //nolint:errorlint + gfErr, isGrafanaError := err.(Error) + + switch { + case isGrafanaError: + return b.reason == gfErr.Reason && b.messageID == gfErr.MessageID + case isBase: + return b.reason == base.reason && b.messageID == base.messageID + default: return false } - - return b.reason.Status() == gfErr.Reason.Status() && b.messageID == gfErr.MessageID } // Error is the error type for errors within Grafana, extending @@ -138,12 +162,18 @@ func (e Error) Is(other error) bool { // to validate the equality between the two objects. // errors.Is handles the unwrapping, should you want it. //nolint:errorlint - o, ok := other.(Error) - if !ok { + o, isGrafanaError := other.(Error) + //nolint:errorlint + base, isBase := other.(Base) + + switch { + case isGrafanaError: + return o.Reason == e.Reason && o.MessageID == e.MessageID && o.Error() == e.Error() + case isBase: + return base.Is(e) + default: return false } - - return o.Reason == e.Reason && o.MessageID == e.MessageID && o.Error() == e.Error() } // PublicError is derived from Error and only contains information diff --git a/pkg/util/errutil/errors_test.go b/pkg/util/errutil/errors_test.go new file mode 100644 index 00000000000..3db114c2581 --- /dev/null +++ b/pkg/util/errutil/errors_test.go @@ -0,0 +1,74 @@ +package errutil + +import ( + "errors" + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBase_Is(t *testing.T) { + baseNotFound := NewBase(StatusNotFound, "test:not-found") + baseInternal := NewBase(StatusInternal, "test:internal") + + tests := []struct { + Base Base + Other error + Expect bool + ExpectUnwrapped bool + }{ + { + Base: Base{}, + Other: errors.New(""), + Expect: false, + }, + { + Base: Base{}, + Other: Base{}, + Expect: true, + }, + { + Base: Base{}, + Other: Error{}, + Expect: true, + }, + { + Base: baseNotFound, + Other: baseNotFound, + Expect: true, + }, + { + Base: baseNotFound, + Other: baseNotFound.Errorf("this is an error derived from baseNotFound, it is considered to be equal to baseNotFound"), + Expect: true, + }, + { + Base: baseNotFound, + Other: baseInternal, + Expect: false, + }, + { + Base: baseInternal, + Other: fmt.Errorf("wrapped, like a burrito: %w", baseInternal.Errorf("oh noes")), + Expect: false, + ExpectUnwrapped: true, + }, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf( + "Base '%s' == '%s' of type %s = %v (%v unwrapped)", + tc.Base.Error(), + tc.Other.Error(), + reflect.TypeOf(tc.Other), + tc.Expect, + tc.Expect || tc.ExpectUnwrapped, + ), func(t *testing.T) { + assert.Equal(t, tc.Expect, tc.Base.Is(tc.Other), "direct comparison") + assert.Equal(t, tc.Expect, errors.Is(tc.Base, tc.Other), "comparison using errors.Is with other as target") + assert.Equal(t, tc.Expect || tc.ExpectUnwrapped, errors.Is(tc.Other, tc.Base), "comparison using errors.Is with base as target, should unwrap other") + }) + } +}