diff --git a/.travis.yml b/.travis.yml index d5fc4760..a3c54b6b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,7 +13,8 @@ matrix: env: RUN386=1 - go: 1.9.x env: GAE=1 - + - go: 1.10.x + env: GRPC_GO_RETRY=on go_import_path: google.golang.org/grpc diff --git a/call.go b/call.go index f73b7d55..180d79d0 100644 --- a/call.go +++ b/call.go @@ -63,31 +63,12 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli var unaryStreamDesc = &StreamDesc{ServerStreams: false, ClientStreams: false} func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error { - // TODO: implement retries in clientStream and make this simply - // newClientStream, SendMsg, RecvMsg. - firstAttempt := true - for { - csInt, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...) - if err != nil { - return err - } - cs := csInt.(*clientStream) - if err := cs.SendMsg(req); err != nil { - if !cs.c.failFast && cs.attempt.s.Unprocessed() && firstAttempt { - // TODO: Add a field to header for grpc-transparent-retry-attempts - firstAttempt = false - continue - } - return err - } - if err := cs.RecvMsg(reply); err != nil { - if !cs.c.failFast && cs.attempt.s.Unprocessed() && firstAttempt { - // TODO: Add a field to header for grpc-transparent-retry-attempts - firstAttempt = false - continue - } - return err - } - return nil + cs, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...) + if err != nil { + return err } + if err := cs.SendMsg(req); err != nil { + return err + } + return cs.RecvMsg(reply) } diff --git a/clientconn.go b/clientconn.go index d5abdb19..199f7747 100644 --- a/clientconn.go +++ b/clientconn.go @@ -26,6 +26,7 @@ import ( "reflect" "strings" "sync" + "sync/atomic" "time" "golang.org/x/net/context" @@ -39,6 +40,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/resolver" _ "google.golang.org/grpc/resolver/dns" // To register dns resolver. @@ -116,6 +118,7 @@ type dialOptions struct { waitForHandshake bool channelzParentID int64 disableServiceConfig bool + disableRetry bool } const ( @@ -126,15 +129,6 @@ const ( defaultReadBufSize = 32 * 1024 ) -func defaultDialOptions() dialOptions { - return dialOptions{ - copts: transport.ConnectOptions{ - WriteBufferSize: defaultWriteBufSize, - ReadBufferSize: defaultReadBufSize, - }, - } -} - // RegisterChannelz turns on channelz service. // This is an EXPERIMENTAL API. func RegisterChannelz() { @@ -453,6 +447,32 @@ func WithDisableServiceConfig() DialOption { } } +// WithDisableRetry returns a DialOption that disables retries, even if the +// service config enables them. This does not impact transparent retries, +// which will happen automatically if no data is written to the wire or if the +// RPC is unprocessed by the remote server. +// +// Retry support is currently disabled by default, but will be enabled by +// default in the future. Until then, it may be enabled by setting the +// environment variable "GRPC_GO_RETRY" to "on". +// +// This API is EXPERIMENTAL. 
+func WithDisableRetry() DialOption { + return func(o *dialOptions) { + o.disableRetry = true + } +} + +func defaultDialOptions() dialOptions { + return dialOptions{ + disableRetry: !envconfig.Retry, + copts: transport.ConnectOptions{ + WriteBufferSize: defaultWriteBufSize, + ReadBufferSize: defaultReadBufSize, + }, + } +} + // Dial creates a client connection to the given target. func Dial(target string, opts ...DialOption) (*ClientConn, error) { return DialContext(context.Background(), target, opts...) @@ -482,8 +502,10 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * dopts: defaultDialOptions(), blockingpicker: newPickerWrapper(), } + cc.retryThrottler.Store((*retryThrottler)(nil)) cc.ctx, cc.cancel = context.WithCancel(context.Background()) + cc.dopts = defaultDialOptions() for _, opt := range opts { opt(&cc.dopts) } @@ -717,6 +739,7 @@ type ClientConn struct { preBalancerName string // previous balancer name. curAddresses []resolver.Address balancerWrapper *ccBalancerWrapper + retryThrottler atomic.Value channelzID int64 // channelz unique identification number czmu sync.RWMutex @@ -1049,6 +1072,19 @@ func (cc *ClientConn) handleServiceConfig(js string) error { cc.mu.Lock() cc.scRaw = js cc.sc = sc + + if sc.retryThrottling != nil { + newThrottler := &retryThrottler{ + tokens: sc.retryThrottling.MaxTokens, + max: sc.retryThrottling.MaxTokens, + thresh: sc.retryThrottling.MaxTokens / 2, + ratio: sc.retryThrottling.TokenRatio, + } + cc.retryThrottler.Store(newThrottler) + } else { + cc.retryThrottler.Store((*retryThrottler)(nil)) + } + if sc.LB != nil && *sc.LB != grpclbName { // "grpclb" is not a valid balancer option in service config. if cc.curBalancerName == grpclbName { // If current balancer is grpclb, there's at least one grpclb @@ -1062,6 +1098,7 @@ func (cc *ClientConn) handleServiceConfig(js string) error { cc.balancerWrapper.handleResolvedAddrs(cc.curAddresses, nil) } } + cc.mu.Unlock() return nil } @@ -1591,6 +1628,43 @@ func (ac *addrConn) incrCallsFailed() { ac.czmu.Unlock() } +type retryThrottler struct { + max float64 + thresh float64 + ratio float64 + + mu sync.Mutex + tokens float64 // TODO(dfawley): replace with atomic and remove lock. +} + +// throttle subtracts a retry token from the pool and returns whether a retry +// should be throttled (disallowed) based upon the retry throttling policy in +// the service config. +func (rt *retryThrottler) throttle() bool { + if rt == nil { + return false + } + rt.mu.Lock() + defer rt.mu.Unlock() + rt.tokens-- + if rt.tokens < 0 { + rt.tokens = 0 + } + return rt.tokens <= rt.thresh +} + +func (rt *retryThrottler) successfulRPC() { + if rt == nil { + return + } + rt.mu.Lock() + defer rt.mu.Unlock() + rt.tokens += rt.ratio + if rt.tokens > rt.max { + rt.tokens = rt.max + } +} + // ErrClientConnTimeout indicates that the ClientConn cannot establish the // underlying connections within the specified timeout. // diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go new file mode 100644 index 00000000..3ee8740f --- /dev/null +++ b/internal/envconfig/envconfig.go @@ -0,0 +1,35 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package envconfig contains grpc settings configured by environment variables. +package envconfig + +import ( + "os" + "strings" +) + +const ( + prefix = "GRPC_GO_" + retryStr = prefix + "RETRY" +) + +var ( + // Retry is set if retry is explicitly enabled via "GRPC_GO_RETRY=on". + Retry = strings.EqualFold(os.Getenv(retryStr), "on") +) diff --git a/rpc_util.go b/rpc_util.go index 033801f3..836944f5 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -162,10 +162,15 @@ type callInfo struct { creds credentials.PerRPCCredentials contentSubtype string codec baseCodec + disableRetry bool + maxRetryRPCBufferSize int } func defaultCallInfo() *callInfo { - return &callInfo{failFast: true} + return &callInfo{ + failFast: true, + maxRetryRPCBufferSize: 256 * 1024, // 256KB + } } // CallOption configures a Call before it starts or extracts information from @@ -415,6 +420,27 @@ func (o CustomCodecCallOption) before(c *callInfo) error { } func (o CustomCodecCallOption) after(c *callInfo) {} +// MaxRetryRPCBufferSize returns a CallOption that limits the amount of memory +// used for buffering this RPC's requests for retry purposes. +// +// This API is EXPERIMENTAL. +func MaxRetryRPCBufferSize(bytes int) CallOption { + return MaxRetryRPCBufferSizeCallOption{bytes} +} + +// MaxRetryRPCBufferSizeCallOption is a CallOption indicating the amount of +// memory to be used for caching this RPC for retry purposes. +// This is an EXPERIMENTAL API. +type MaxRetryRPCBufferSizeCallOption struct { + MaxRetryRPCBufferSize int +} + +func (o MaxRetryRPCBufferSizeCallOption) before(c *callInfo) error { + c.maxRetryRPCBufferSize = o.MaxRetryRPCBufferSize + return nil +} +func (o MaxRetryRPCBufferSizeCallOption) after(c *callInfo) {} + // The format of the payload: compressed or not? type payloadFormat uint8 diff --git a/service_config.go b/service_config.go index 51101c7d..e0d73526 100644 --- a/service_config.go +++ b/service_config.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" ) @@ -56,6 +57,8 @@ type MethodConfig struct { // MaxRespSize is the maximum allowed payload size for an individual response in a // stream (server->client) in bytes. MaxRespSize *int + // RetryPolicy configures retry options for the method. + retryPolicy *retryPolicy } // ServiceConfig is provided by the service provider and contains parameters for how @@ -68,11 +71,84 @@ type ServiceConfig struct { // LB is the load balancer the service providers recommends. The balancer specified // via grpc.WithBalancer will override this. LB *string - // Methods contains a map for the methods in this service. - // If there is an exact match for a method (i.e. /service/method) in the map, use the corresponding MethodConfig. - // If there's no exact match, look for the default config for the service (/service/) and use the corresponding MethodConfig if it exists. - // Otherwise, the method has no MethodConfig to use. + + // Methods contains a map for the methods in this service. If there is an + // exact match for a method (i.e. 
/service/method) in the map, use the + // corresponding MethodConfig. If there's no exact match, look for the + // default config for the service (/service/) and use the corresponding + // MethodConfig if it exists. Otherwise, the method has no MethodConfig to + // use. Methods map[string]MethodConfig + + // If a retryThrottlingPolicy is provided, gRPC will automatically throttle + // retry attempts and hedged RPCs when the client’s ratio of failures to + // successes exceeds a threshold. + // + // For each server name, the gRPC client will maintain a token_count which is + // initially set to maxTokens, and can take values between 0 and maxTokens. + // + // Every outgoing RPC (regardless of service or method invoked) will change + // token_count as follows: + // + // - Every failed RPC will decrement the token_count by 1. + // - Every successful RPC will increment the token_count by tokenRatio. + // + // If token_count is less than or equal to maxTokens / 2, then RPCs will not + // be retried and hedged RPCs will not be sent. + retryThrottling *retryThrottlingPolicy +} + +// retryPolicy defines the go-native version of the retry policy defined by the +// service config here: +// https://github.com/grpc/proposal/blob/master/A6-client-retries.md#integration-with-service-config +type retryPolicy struct { + // MaxAttempts is the maximum number of attempts, including the original RPC. + // + // This field is required and must be two or greater. + maxAttempts int + + // Exponential backoff parameters. The initial retry attempt will occur at + // random(0, initialBackoffMS). In general, the nth attempt will occur at + // random(0, + // min(initialBackoffMS*backoffMultiplier**(n-1), maxBackoffMS)). + // + // These fields are required and must be greater than zero. + initialBackoff time.Duration + maxBackoff time.Duration + backoffMultiplier float64 + + // The set of status codes which may be retried. + // + // Status codes are specified as strings, e.g., "UNAVAILABLE". + // + // This field is required and must be non-empty. + // Note: a set is used to store this for easy lookup. + retryableStatusCodes map[codes.Code]bool +} + +type jsonRetryPolicy struct { + MaxAttempts int + InitialBackoff string + MaxBackoff string + BackoffMultiplier float64 + RetryableStatusCodes []codes.Code +} + +// retryThrottlingPolicy defines the go-native version of the retry throttling +// policy defined by the service config here: +// https://github.com/grpc/proposal/blob/master/A6-client-retries.md#integration-with-service-config +type retryThrottlingPolicy struct { + // The number of tokens starts at maxTokens. The token_count will always be + // between 0 and maxTokens. + // + // This field is required and must be greater than zero. + MaxTokens float64 + // The amount of tokens to add on each successful RPC. Typically this will + // be some number between 0 and 1, e.g., 0.1. + // + // This field is required and must be greater than zero. Up to 3 decimal + // places are supported. + TokenRatio float64 } func parseDuration(s *string) (*time.Duration, error) { @@ -142,12 +218,14 @@ type jsonMC struct { Timeout *string MaxRequestMessageBytes *int64 MaxResponseMessageBytes *int64 + RetryPolicy *jsonRetryPolicy } // TODO(lyuxuan): delete this struct after cleaning up old service config implementation. 
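For reference, the structures above map onto a service config document like the one below. This is only a sketch of the expected JSON shape, written as a Go constant; the service name and the specific backoff values are illustrative (the tests added later in this change use the same field names). For the nth retry, the backoff is drawn from random(0, min(initialBackoff*backoffMultiplier^(n-1), maxBackoff)).

// retryServiceConfig is an illustrative service config enabling retries for
// one service and turning on retry throttling. "grpc.testing.TestService" and
// the numeric values are examples, not part of this change.
const retryServiceConfig = `{
  "methodConfig": [{
    "name": [{"service": "grpc.testing.TestService"}],
    "waitForReady": true,
    "retryPolicy": {
      "MaxAttempts": 4,
      "InitialBackoff": ".01s",
      "MaxBackoff": ".1s",
      "BackoffMultiplier": 2.0,
      "RetryableStatusCodes": ["UNAVAILABLE"]
    }
  }],
  "retryThrottling": {
    "maxTokens": 10,
    "tokenRatio": 0.5
  }
}`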
type jsonSC struct { LoadBalancingPolicy *string MethodConfig *[]jsonMC + RetryThrottling *retryThrottlingPolicy } func parseServiceConfig(js string) (ServiceConfig, error) { @@ -158,8 +236,9 @@ func parseServiceConfig(js string) (ServiceConfig, error) { return ServiceConfig{}, err } sc := ServiceConfig{ - LB: rsc.LoadBalancingPolicy, - Methods: make(map[string]MethodConfig), + LB: rsc.LoadBalancingPolicy, + Methods: make(map[string]MethodConfig), + retryThrottling: rsc.RetryThrottling, } if rsc.MethodConfig == nil { return sc, nil @@ -179,6 +258,10 @@ func parseServiceConfig(js string) (ServiceConfig, error) { WaitForReady: m.WaitForReady, Timeout: d, } + if mc.retryPolicy, err = convertRetryPolicy(m.RetryPolicy); err != nil { + grpclog.Warningf("grpc: parseServiceConfig error unmarshaling %s due to %v", js, err) + return ServiceConfig{}, err + } if m.MaxRequestMessageBytes != nil { if *m.MaxRequestMessageBytes > int64(maxInt) { mc.MaxReqSize = newInt(maxInt) @@ -200,9 +283,56 @@ func parseServiceConfig(js string) (ServiceConfig, error) { } } + if sc.retryThrottling != nil { + if sc.retryThrottling.MaxTokens <= 0 || + sc.retryThrottling.MaxTokens >= 1000 || + sc.retryThrottling.TokenRatio <= 0 { + // Illegal throttling config; disable throttling. + sc.retryThrottling = nil + } + } return sc, nil } +func convertRetryPolicy(jrp *jsonRetryPolicy) (p *retryPolicy, err error) { + if jrp == nil { + return nil, nil + } + ib, err := parseDuration(&jrp.InitialBackoff) + if err != nil { + return nil, err + } + mb, err := parseDuration(&jrp.MaxBackoff) + if err != nil { + return nil, err + } + + if jrp.MaxAttempts <= 1 || + *ib <= 0 || + *mb <= 0 || + jrp.BackoffMultiplier <= 0 || + len(jrp.RetryableStatusCodes) == 0 { + grpclog.Warningf("grpc: ignoring retry policy %v due to illegal configuration", jrp) + return nil, nil + } + + rp := &retryPolicy{ + maxAttempts: jrp.MaxAttempts, + initialBackoff: *ib, + maxBackoff: *mb, + backoffMultiplier: jrp.BackoffMultiplier, + retryableStatusCodes: make(map[codes.Code]bool), + } + if rp.maxAttempts > 5 { + // TODO(retry): Make the max maxAttempts configurable. + rp.maxAttempts = 5 + } + for _, code := range jrp.RetryableStatusCodes { + rp.retryableStatusCodes[code] = true + } + return rp, nil +} + func min(a, b *int) *int { if *a < *b { return a diff --git a/stream.go b/stream.go index 2a93eda6..465e6753 100644 --- a/stream.go +++ b/stream.go @@ -21,6 +21,8 @@ package grpc import ( "errors" "io" + "math" + "strconv" "sync" "time" @@ -29,7 +31,9 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" "google.golang.org/grpc/encoding" + "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" "google.golang.org/grpc/status" @@ -57,7 +61,9 @@ type StreamDesc struct { // // All errors returned from Stream are compatible with the status package. type Stream interface { - // Context returns the context for this stream. + // Context returns the context for this stream. If called from the client, + // Should be done after Header or RecvMsg. Otherwise, retries may not be + // possible to perform. Context() context.Context // SendMsg is generally called by generated code. On error, SendMsg aborts // the stream and returns an RPC status on the client side. 
On the server @@ -228,15 +234,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } trInfo.tr.LazyLog(&trInfo.firstLine, false) ctx = trace.NewContext(ctx, trInfo.tr) - defer func() { - if err != nil { - // Need to call tr.finish() if error is returned. - // Because tr will not be returned to caller. - trInfo.tr.LazyPrintf("RPC: [%v]", err) - trInfo.tr.SetError() - trInfo.tr.Finish() - } - }() } ctx = newContextWithRPCInfo(ctx, c.failFast) sh := cc.dopts.copts.StatsHandler @@ -250,80 +247,41 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth FailFast: c.failFast, } sh.HandleRPC(ctx, begin) - defer func() { - if err != nil { - // Only handle end stats if err != nil. - end := &stats.End{ - Client: true, - Error: err, - BeginTime: beginTime, - EndTime: time.Now(), - } - sh.HandleRPC(ctx, end) - } - }() - } - - var ( - t transport.ClientTransport - s *transport.Stream - done func(balancer.DoneInfo) - ) - for { - // Check to make sure the context has expired. This will prevent us from - // looping forever if an error occurs for wait-for-ready RPCs where no data - // is sent on the wire. - select { - case <-ctx.Done(): - return nil, toRPCErr(ctx.Err()) - default: - } - - t, done, err = cc.getTransport(ctx, c.failFast) - if err != nil { - return nil, err - } - - s, err = t.NewStream(ctx, callHdr) - if err != nil { - if done != nil { - done(balancer.DoneInfo{Err: err}) - done = nil - } - // In the event of any error from NewStream, we never attempted to write - // anything to the wire, so we can retry indefinitely for non-fail-fast - // RPCs. - if !c.failFast { - continue - } - return nil, toRPCErr(err) - } - break } cs := &clientStream{ - opts: opts, - c: c, - cc: cc, - desc: desc, - codec: c.codec, - cp: cp, - comp: comp, - cancel: cancel, - attempt: &csAttempt{ - t: t, - s: s, - p: &parser{r: s}, - done: done, - dc: cc.dopts.dc, - ctx: ctx, - trInfo: trInfo, - statsHandler: sh, - beginTime: beginTime, - }, + callHdr: callHdr, + ctx: ctx, + methodConfig: &mc, + opts: opts, + callInfo: c, + cc: cc, + desc: desc, + codec: c.codec, + cp: cp, + comp: comp, + cancel: cancel, + beginTime: beginTime, + firstAttempt: true, } - cs.c.stream = cs - cs.attempt.cs = cs + if !cc.dopts.disableRetry { + cs.retryThrottler = cc.retryThrottler.Load().(*retryThrottler) + } + + cs.callInfo.stream = cs + // Only this initial attempt has stats/tracing. + // TODO(dfawley): move to newAttempt when per-attempt stats are implemented. + if err := cs.newAttemptLocked(sh, trInfo); err != nil { + cs.finish(err) + return nil, err + } + + op := func(a *csAttempt) error { return a.newStream() } + if err := cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }); err != nil { + cs.finish(err) + return nil, err + } + if desc != unaryStreamDesc { // Listen on cc and stream contexts to cleanup when the user closes the // ClientConn or cancels the stream context. 
In all other cases, an error @@ -342,12 +300,45 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth return cs, nil } +func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo traceInfo) error { + cs.attempt = &csAttempt{ + cs: cs, + dc: cs.cc.dopts.dc, + statsHandler: sh, + trInfo: trInfo, + } + + if err := cs.ctx.Err(); err != nil { + return toRPCErr(err) + } + t, done, err := cs.cc.getTransport(cs.ctx, cs.callInfo.failFast) + if err != nil { + return err + } + cs.attempt.t = t + cs.attempt.done = done + return nil +} + +func (a *csAttempt) newStream() error { + cs := a.cs + cs.callHdr.PreviousAttempts = cs.numRetries + s, err := a.t.NewStream(cs.ctx, cs.callHdr) + if err != nil { + return toRPCErr(err) + } + cs.attempt.s = s + cs.attempt.p = &parser{r: s} + return nil +} + // clientStream implements a client side Stream. type clientStream struct { - opts []CallOption - c *callInfo - cc *ClientConn - desc *StreamDesc + callHdr *transport.CallHdr + opts []CallOption + callInfo *callInfo + cc *ClientConn + desc *StreamDesc codec baseCodec cp Compressor @@ -355,13 +346,25 @@ type clientStream struct { cancel context.CancelFunc // cancels all attempts - sentLast bool // sent an end stream + sentLast bool // sent an end stream + beginTime time.Time - mu sync.Mutex // guards finished - finished bool // TODO: replace with atomic cmpxchg or sync.Once? + methodConfig *MethodConfig - attempt *csAttempt // the active client stream attempt + ctx context.Context // the application's context, wrapped by stats/tracing + + retryThrottler *retryThrottler // The throttler active when the RPC began. + + mu sync.Mutex + firstAttempt bool // if true, transparent retry is valid + numRetries int // exclusive of transparent retry attempt(s) + numRetriesSincePushback int // retries since pushback; to reset backoff + finished bool // TODO: replace with atomic cmpxchg or sync.Once? + attempt *csAttempt // the active client stream attempt // TODO(hedging): hedging will have multiple attempts simultaneously. + committed bool // active attempt committed for retry? + buffer []func(a *csAttempt) error // operations to replay on retry + bufferSize int // current size of buffer } // csAttempt implements a single transport stream attempt within a @@ -373,53 +376,294 @@ type csAttempt struct { p *parser done func(balancer.DoneInfo) + finished bool dc Decompressor decomp encoding.Compressor decompSet bool - ctx context.Context // the application's context, wrapped by stats/tracing - mu sync.Mutex // guards trInfo.tr // trInfo.tr is set when created (if EnableTracing is true), // and cleared when the finish method is called. trInfo traceInfo statsHandler stats.Handler - beginTime time.Time +} + +func (cs *clientStream) commitAttemptLocked() { + cs.committed = true + cs.buffer = nil +} + +func (cs *clientStream) commitAttempt() { + cs.mu.Lock() + cs.commitAttemptLocked() + cs.mu.Unlock() +} + +// shouldRetry returns nil if the RPC should be retried; otherwise it returns +// the error that should be returned by the operation. +func (cs *clientStream) shouldRetry(err error) error { + if cs.attempt.s == nil && !cs.callInfo.failFast { + // In the event of any error from NewStream (attempt.s == nil), we + // never attempted to write anything to the wire, so we can retry + // indefinitely for non-fail-fast RPCs. + return nil + } + if cs.finished || cs.committed { + // RPC is finished or committed; cannot retry. 
+ return err + } + if cs.firstAttempt && !cs.callInfo.failFast && (cs.attempt.s == nil || cs.attempt.s.Unprocessed()) { + // First attempt, wait-for-ready, stream unprocessed: transparently retry. + cs.firstAttempt = false + return nil + } + cs.firstAttempt = false + if cs.cc.dopts.disableRetry { + return err + } + + pushback := 0 + hasPushback := false + if cs.attempt.s != nil { + if to, toErr := cs.attempt.s.TrailersOnly(); toErr != nil { + // Context error; stop now. + return toErr + } else if !to { + return err + } + + // TODO(retry): Move down if the spec changes to not check server pushback + // before considering this a failure for throttling. + sps := cs.attempt.s.Trailer()["grpc-retry-pushback-ms"] + if len(sps) == 1 { + var e error + if pushback, e = strconv.Atoi(sps[0]); e != nil || pushback < 0 { + grpclog.Infof("Server retry pushback specified to abort (%q).", sps[0]) + cs.retryThrottler.throttle() // This counts as a failure for throttling. + return err + } + hasPushback = true + } else if len(sps) > 1 { + grpclog.Warningf("Server retry pushback specified multiple values (%q); not retrying.", sps) + cs.retryThrottler.throttle() // This counts as a failure for throttling. + return err + } + } + + var code codes.Code + if cs.attempt.s != nil { + code = cs.attempt.s.Status().Code() + } else { + code = status.Convert(err).Code() + } + + rp := cs.methodConfig.retryPolicy + if rp == nil || !rp.retryableStatusCodes[code] { + return err + } + + // Note: the ordering here is important; we count this as a failure + // only if the code matched a retryable code. + if cs.retryThrottler.throttle() { + return err + } + if cs.numRetries+1 >= rp.maxAttempts { + return err + } + + var dur time.Duration + if hasPushback { + dur = time.Millisecond * time.Duration(pushback) + cs.numRetriesSincePushback = 0 + } else { + fact := math.Pow(rp.backoffMultiplier, float64(cs.numRetriesSincePushback)) + cur := float64(rp.initialBackoff) * fact + if max := float64(rp.maxBackoff); cur > max { + cur = max + } + dur = time.Duration(grpcrand.Int63n(int64(cur))) + cs.numRetriesSincePushback++ + } + + // TODO(dfawley): we could eagerly fail here if dur puts us past the + // deadline, but unsure if it is worth doing. + t := time.NewTimer(dur) + select { + case <-t.C: + cs.numRetries++ + return nil + case <-cs.ctx.Done(): + t.Stop() + return status.FromContextError(cs.ctx.Err()).Err() + } +} + +// Returns nil if a retry was performed and succeeded; error otherwise. +func (cs *clientStream) retryLocked(lastErr error) error { + for { + if err := cs.shouldRetry(lastErr); err != nil { + cs.commitAttemptLocked() + return err + } + cs.attempt.finish(lastErr) + if err := cs.newAttemptLocked(nil, traceInfo{}); err != nil { + return err + } + if lastErr = cs.replayBufferLocked(); lastErr == nil { + return nil + } + } } func (cs *clientStream) Context() context.Context { - // TODO(retry): commit the current attempt (the context has peer-aware data). - return cs.attempt.context() + cs.commitAttempt() + // No need to lock before using attempt, since we know it is committed and + // cannot change. + return cs.attempt.s.Context() +} + +func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func()) error { + cs.mu.Lock() + for { + if cs.committed { + cs.mu.Unlock() + return op(cs.attempt) + } + a := cs.attempt + cs.mu.Unlock() + err := op(a) + cs.mu.Lock() + if a != cs.attempt { + // We started another attempt already. 
+ continue + } + if err == nil || err == io.EOF { + onSuccess() + cs.mu.Unlock() + return err + } + if err := cs.retryLocked(err); err != nil { + cs.mu.Unlock() + return err + } + } } func (cs *clientStream) Header() (metadata.MD, error) { - m, err := cs.attempt.header() + var m metadata.MD + err := cs.withRetry(func(a *csAttempt) error { + var err error + m, err = a.s.Header() + return toRPCErr(err) + }, cs.commitAttemptLocked) if err != nil { - // TODO(retry): maybe retry on error or commit attempt on success. - err = toRPCErr(err) cs.finish(err) } return m, err } func (cs *clientStream) Trailer() metadata.MD { - // TODO(retry): on error, maybe retry (trailers-only). - return cs.attempt.trailer() + // On RPC failure, we never need to retry, because usage requires that + // RecvMsg() returned a non-nil error before calling this function is valid. + // We would have retried earlier if necessary. + // + // Commit the attempt anyway, just in case users are not following those + // directions -- it will prevent races and should not meaningfully impact + // performance. + cs.commitAttempt() + if cs.attempt.s == nil { + return nil + } + return cs.attempt.s.Trailer() +} + +func (cs *clientStream) replayBufferLocked() error { + a := cs.attempt + for _, f := range cs.buffer { + if err := f(a); err != nil { + return err + } + } + return nil +} + +func (cs *clientStream) bufferForRetryLocked(sz int, op func(a *csAttempt) error) { + // Note: we still will buffer if retry is disabled (for transparent retries). + if cs.committed { + return + } + cs.bufferSize += sz + if cs.bufferSize > cs.callInfo.maxRetryRPCBufferSize { + cs.commitAttemptLocked() + return + } + cs.buffer = append(cs.buffer, op) } func (cs *clientStream) SendMsg(m interface{}) (err error) { - // TODO(retry): buffer message for replaying if not committed. - return cs.attempt.sendMsg(m) + defer func() { + if err != nil && err != io.EOF { + // Call finish on the client stream for errors generated by this SendMsg + // call, as these indicate problems created by this client. (Transport + // errors are converted to an io.EOF error in csAttempt.sendMsg; the real + // error will be returned from RecvMsg eventually in that case, or be + // retried.) + cs.finish(err) + } + }() + // TODO: Check cs.sentLast and error if we already ended the stream. + if !cs.desc.ClientStreams { + cs.sentLast = true + } + data, err := encode(cs.codec, m) + if err != nil { + return err + } + compData, err := compress(data, cs.cp, cs.comp) + if err != nil { + return err + } + hdr, payload := msgHeader(data, compData) + // TODO(dfawley): should we be checking len(data) instead? + if len(payload) > *cs.callInfo.maxSendMessageSize { + return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize) + } + op := func(a *csAttempt) error { + err := a.sendMsg(m, hdr, payload, data) + // nil out the message and uncomp when replaying; they are only needed for + // stats which is disabled for subsequent attempts. + m, data = nil, nil + return err + } + return cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) }) } -func (cs *clientStream) RecvMsg(m interface{}) (err error) { - // TODO(retry): maybe retry on error or commit attempt on success. 
- return cs.attempt.recvMsg(m) +func (cs *clientStream) RecvMsg(m interface{}) error { + err := cs.withRetry(func(a *csAttempt) error { + err := a.recvMsg(m) + if err != nil || !cs.desc.ServerStreams { + // err != nil or non-server-streaming indicates end of stream. + a.finish(err) + } + return err + }, cs.commitAttemptLocked) + if err != nil || !cs.desc.ServerStreams { + // err != nil or non-server-streaming indicates end of stream. + cs.finish(err) + } + return err } func (cs *clientStream) CloseSend() error { - cs.attempt.closeSend() + if cs.sentLast { + // TODO: return an error and finish the stream instead, due to API misuse? + return nil + } + cs.sentLast = true + op := func(a *csAttempt) error { return a.t.Write(a.s, nil, nil, &transport.Options{Last: true}) } + cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }) + // We never returned an error here for reasons. return nil } @@ -434,7 +678,11 @@ func (cs *clientStream) finish(err error) { return } cs.finished = true + cs.commitAttemptLocked() cs.mu.Unlock() + if err == nil { + cs.retryThrottler.successfulRPC() + } if channelz.IsOn() { if err != nil { cs.cc.incrCallsFailed() @@ -442,46 +690,20 @@ func (cs *clientStream) finish(err error) { cs.cc.incrCallsSucceeded() } } - // TODO(retry): commit current attempt if necessary. - cs.attempt.finish(err) - for _, o := range cs.opts { - o.after(cs.c) + if cs.attempt != nil { + cs.attempt.finish(err) + } + // after functions all rely upon having a stream. + if cs.attempt.s != nil { + for _, o := range cs.opts { + o.after(cs.callInfo) + } } cs.cancel() } -func (a *csAttempt) context() context.Context { - return a.s.Context() -} - -func (a *csAttempt) header() (metadata.MD, error) { - return a.s.Header() -} - -func (a *csAttempt) trailer() metadata.MD { - return a.s.Trailer() -} - -func (a *csAttempt) sendMsg(m interface{}) (err error) { - // TODO Investigate how to signal the stats handling party. - // generate error stats if err != nil && err != io.EOF? +func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error { cs := a.cs - defer func() { - // For non-client-streaming RPCs, we return nil instead of EOF on success - // because the generated code requires it. finish is not called; RecvMsg() - // will call it with the stream's status independently. - if err == io.EOF && !cs.desc.ClientStreams { - err = nil - } - if err != nil && err != io.EOF { - // Call finish on the client stream for errors generated by this SendMsg - // call, as these indicate problems created by this client. (Transport - // errors are converted to an io.EOF error below; the real error will be - // returned from RecvMsg eventually in that case, or be retried.) - cs.finish(err) - } - }() - // TODO: Check cs.sentLast and error if we already ended the stream. if EnableTracing { a.mu.Lock() if a.trInfo.tr != nil { @@ -489,44 +711,26 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) { } a.mu.Unlock() } - data, err := encode(cs.codec, m) - if err != nil { - return err - } - compData, err := compress(data, cs.cp, cs.comp) - if err != nil { - return err - } - hdr, payload := msgHeader(data, compData) - // TODO(dfawley): should we be checking len(data) instead? - if len(payload) > *cs.c.maxSendMessageSize { - return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. 
%d)", len(payload), *cs.c.maxSendMessageSize) - } - - if !cs.desc.ClientStreams { - cs.sentLast = true - } - err = a.t.Write(a.s, hdr, payload, &transport.Options{Last: !cs.desc.ClientStreams}) - if err == nil { - if a.statsHandler != nil { - a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, data, payload, time.Now())) + if err := a.t.Write(a.s, hdr, payld, &transport.Options{Last: !cs.desc.ClientStreams}); err != nil { + if !cs.desc.ClientStreams { + // For non-client-streaming RPCs, we return nil instead of EOF on error + // because the generated code requires it. finish is not called; RecvMsg() + // will call it with the stream's status independently. + return nil } - if channelz.IsOn() { - a.t.IncrMsgSent() - } - return nil + return io.EOF } - return io.EOF + if a.statsHandler != nil { + a.statsHandler.HandleRPC(cs.ctx, outPayload(true, m, data, payld, time.Now())) + } + if channelz.IsOn() { + a.t.IncrMsgSent() + } + return nil } func (a *csAttempt) recvMsg(m interface{}) (err error) { cs := a.cs - defer func() { - if err != nil || !cs.desc.ServerStreams { - // err != nil or non-server-streaming indicates end of stream. - cs.finish(err) - } - }() var inPayload *stats.InPayload if a.statsHandler != nil { inPayload = &stats.InPayload{ @@ -549,7 +753,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) { // Only initialize this state once per stream. a.decompSet = true } - err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, inPayload, a.decomp) + err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, inPayload, a.decomp) if err != nil { if err == io.EOF { if statusErr := a.s.Status().Err(); statusErr != nil { @@ -567,7 +771,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) { a.mu.Unlock() } if inPayload != nil { - a.statsHandler.HandleRPC(a.ctx, inPayload) + a.statsHandler.HandleRPC(cs.ctx, inPayload) } if channelz.IsOn() { a.t.IncrMsgRecv() @@ -579,7 +783,7 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) { // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.c.maxReceiveMessageSize, nil, a.decomp) + err = recv(a.p, cs.codec, a.s, a.dc, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decomp) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) } @@ -589,37 +793,39 @@ func (a *csAttempt) recvMsg(m interface{}) (err error) { return toRPCErr(err) } -func (a *csAttempt) closeSend() { - cs := a.cs - if cs.sentLast { - return - } - cs.sentLast = true - cs.attempt.t.Write(cs.attempt.s, nil, nil, &transport.Options{Last: true}) - // We ignore errors from Write. Any error it would return would also be - // returned by a subsequent RecvMsg call, and the user is supposed to always - // finish the stream by calling RecvMsg until it returns err != nil. -} - func (a *csAttempt) finish(err error) { a.mu.Lock() - a.t.CloseStream(a.s, err) + if a.finished { + return + } + a.finished = true + if err == io.EOF { + // Ending a stream with EOF indicates a success. 
+ err = nil + } + if a.s != nil { + a.t.CloseStream(a.s, err) + } if a.done != nil { + br := false + if a.s != nil { + br = a.s.BytesReceived() + } a.done(balancer.DoneInfo{ Err: err, - BytesSent: true, - BytesReceived: a.s.BytesReceived(), + BytesSent: a.s != nil, + BytesReceived: br, }) } if a.statsHandler != nil { end := &stats.End{ Client: true, - BeginTime: a.beginTime, + BeginTime: a.cs.beginTime, EndTime: time.Now(), Error: err, } - a.statsHandler.HandleRPC(a.ctx, end) + a.statsHandler.HandleRPC(a.cs.ctx, end) } if a.trInfo.tr != nil { if err == nil { diff --git a/test/end2end_test.go b/test/end2end_test.go index 67099000..e47b7534 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -1888,9 +1888,7 @@ func TestStreamingRPCWithTimeoutInServiceConfigRecv(t *testing.T) { te.resolverScheme = r.Scheme() te.nonBlockingDial = true - fmt.Println("1") cc := te.clientConn() - fmt.Println("10") tc := testpb.NewTestServiceClient(cc) r.NewAddress([]resolver.Address{{Addr: te.srvAddr}}) @@ -3075,7 +3073,9 @@ func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) { argSize = 1 respSize = -1 ) - ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = metadata.NewOutgoingContext(ctx, testMetadata) stream, err := tc.FullDuplexCall(ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) @@ -3151,20 +3151,20 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) { } } -func TestRetry(t *testing.T) { +func TestTransparentRetry(t *testing.T) { defer leakcheck.Check(t) for _, e := range listTestEnv() { if e.name == "handler-tls" { // Fails with RST_STREAM / FLOW_CONTROL_ERROR continue } - testRetry(t, e) + testTransparentRetry(t, e) } } -// This test make sure RPCs are retried times when they receive a RST_STREAM +// This test makes sure RPCs are retried times when they receive a RST_STREAM // with the REFUSED_STREAM error code, which the InTapHandle provokes. -func testRetry(t *testing.T, e env) { +func testTransparentRetry(t *testing.T, e env) { te := newTest(t, e) attempts := 0 successAttempt := 2 @@ -4845,8 +4845,11 @@ type stubServer struct { // A client connected to this service the test may use. Created in Start(). client testpb.TestServiceClient + cc *grpc.ClientConn cleanups []func() // Lambdas executed in Stop(); populated by Start(). + + r *manual.Resolver } func (ss *stubServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { @@ -4858,7 +4861,11 @@ func (ss *stubServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallSer } // Start starts the server and creates a client connected to it. -func (ss *stubServer) Start(sopts []grpc.ServerOption) error { +func (ss *stubServer) Start(sopts []grpc.ServerOption, dopts ...grpc.DialOption) error { + r, cleanup := manual.GenerateAndRegisterManualResolver() + ss.r = r + ss.cleanups = append(ss.cleanups, cleanup) + lis, err := net.Listen("tcp", "localhost:0") if err != nil { return fmt.Errorf(`net.Listen("tcp", "localhost:0") = %v`, err) @@ -4870,16 +4877,40 @@ func (ss *stubServer) Start(sopts []grpc.ServerOption) error { go s.Serve(lis) ss.cleanups = append(ss.cleanups, s.Stop) - cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock()) + target := ss.r.Scheme() + ":///" + lis.Addr().String() + + opts := append([]grpc.DialOption{grpc.WithInsecure()}, dopts...) + cc, err := grpc.Dial(target, opts...) 
if err != nil { - return fmt.Errorf("grpc.Dial(%q) = %v", lis.Addr().String(), err) + return fmt.Errorf("grpc.Dial(%q) = %v", target, err) } + ss.cc = cc + ss.r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}}) + if err := ss.waitForReady(cc); err != nil { + return err + } + ss.cleanups = append(ss.cleanups, func() { cc.Close() }) ss.client = testpb.NewTestServiceClient(cc) return nil } +func (ss *stubServer) waitForReady(cc *grpc.ClientConn) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for { + s := cc.GetState() + if s == connectivity.Ready { + return nil + } + if !cc.WaitForStateChange(ctx, s) { + // ctx got timeout or canceled. + return ctx.Err() + } + } +} + func (ss *stubServer) Stop() { for i := len(ss.cleanups) - 1; i >= 0; i-- { ss.cleanups[i]() @@ -5125,7 +5156,7 @@ func TestClientWriteFailsAfterServerClosesStream(t *testing.T) { } sopts := []grpc.ServerOption{} if err := ss.Start(sopts); err != nil { - t.Fatalf("Error starting endpoing server: %v", err) + t.Fatalf("Error starting endpoint server: %v", err) } defer ss.Stop() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -5143,7 +5174,6 @@ func TestClientWriteFailsAfterServerClosesStream(t *testing.T) { t.Fatalf("stream.Send(_) = %v, want io.EOF", err) } } - } type windowSizeConfig struct { diff --git a/test/retry_test.go b/test/retry_test.go new file mode 100644 index 00000000..c53b66db --- /dev/null +++ b/test/retry_test.go @@ -0,0 +1,551 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package test + +import ( + "fmt" + "io" + "os" + "reflect" + "strconv" + "strings" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/envconfig" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +func enableRetry() func() { + old := envconfig.Retry + envconfig.Retry = true + return func() { envconfig.Retry = old } +} + +func TestRetryUnary(t *testing.T) { + defer enableRetry()() + i := -1 + ss := &stubServer{ + emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + i++ + switch i { + case 0, 2, 5: + return &testpb.Empty{}, nil + case 6, 8, 11: + return nil, status.New(codes.Internal, "non-retryable error").Err() + } + return nil, status.New(codes.AlreadyExists, "retryable error").Err() + }, + } + if err := ss.Start([]grpc.ServerOption{}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + ss.r.NewServiceConfig(`{ + "methodConfig": [{ + "name": [{"service": "grpc.testing.TestService"}], + "waitForReady": true, + "retryPolicy": { + "MaxAttempts": 4, + "InitialBackoff": ".01s", + "MaxBackoff": ".01s", + "BackoffMultiplier": 1.0, + "RetryableStatusCodes": [ "ALREADY_EXISTS" ] + } + }]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + for { + if ctx.Err() != nil { + t.Fatalf("Timed out waiting for service config update") + } + if ss.cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall").WaitForReady != nil { + break + } + time.Sleep(time.Millisecond) + } + cancel() + + testCases := []struct { + code codes.Code + count int + }{ + {codes.OK, 0}, + {codes.OK, 2}, + {codes.OK, 5}, + {codes.Internal, 6}, + {codes.Internal, 8}, + {codes.Internal, 11}, + {codes.AlreadyExists, 15}, + } + for _, tc := range testCases { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}) + cancel() + if status.Code(err) != tc.code { + t.Fatalf("EmptyCall(_, _) = _, %v; want _, ", err, tc.code) + } + if i != tc.count { + t.Fatalf("i = %v; want %v", i, tc.count) + } + } +} + +func TestRetryDisabledByDefault(t *testing.T) { + if strings.EqualFold(os.Getenv("GRPC_GO_RETRY"), "on") { + return + } + i := -1 + ss := &stubServer{ + emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + i++ + switch i { + case 0: + return nil, status.New(codes.AlreadyExists, "retryable error").Err() + } + return &testpb.Empty{}, nil + }, + } + if err := ss.Start([]grpc.ServerOption{}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + ss.r.NewServiceConfig(`{ + "methodConfig": [{ + "name": [{"service": "grpc.testing.TestService"}], + "waitForReady": true, + "retryPolicy": { + "MaxAttempts": 4, + "InitialBackoff": ".01s", + "MaxBackoff": ".01s", + "BackoffMultiplier": 1.0, + "RetryableStatusCodes": [ "ALREADY_EXISTS" ] + } + }]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + for { + if ctx.Err() != nil { + t.Fatalf("Timed out waiting for service config update") + } + if ss.cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall").WaitForReady != nil { + break + } + time.Sleep(time.Millisecond) + } + cancel() + + testCases := []struct { + code codes.Code + count int + }{ + {codes.AlreadyExists, 0}, + } + for _, tc := range testCases { + ctx, cancel := 
context.WithTimeout(context.Background(), 10*time.Second) + _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}) + cancel() + if status.Code(err) != tc.code { + t.Fatalf("EmptyCall(_, _) = _, %v; want _, ", err, tc.code) + } + if i != tc.count { + t.Fatalf("i = %v; want %v", i, tc.count) + } + } +} + +func TestRetryThrottling(t *testing.T) { + defer enableRetry()() + i := -1 + ss := &stubServer{ + emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + i++ + switch i { + case 0, 3, 6, 10, 11, 12, 13, 14, 16, 18: + return &testpb.Empty{}, nil + } + return nil, status.New(codes.Unavailable, "retryable error").Err() + }, + } + if err := ss.Start([]grpc.ServerOption{}); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + ss.r.NewServiceConfig(`{ + "methodConfig": [{ + "name": [{"service": "grpc.testing.TestService"}], + "waitForReady": true, + "retryPolicy": { + "MaxAttempts": 4, + "InitialBackoff": ".01s", + "MaxBackoff": ".01s", + "BackoffMultiplier": 1.0, + "RetryableStatusCodes": [ "UNAVAILABLE" ] + } + }], + "retryThrottling": { + "maxTokens": 10, + "tokenRatio": 0.5 + } + }`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + for { + if ctx.Err() != nil { + t.Fatalf("Timed out waiting for service config update") + } + if ss.cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall").WaitForReady != nil { + break + } + time.Sleep(time.Millisecond) + } + cancel() + + testCases := []struct { + code codes.Code + count int + }{ + {codes.OK, 0}, // tokens = 10 + {codes.OK, 3}, // tokens = 8.5 (10 - 2 failures + 0.5 success) + {codes.OK, 6}, // tokens = 6 + {codes.Unavailable, 8}, // tokens = 5 -- first attempt is retried; second aborted. + {codes.Unavailable, 9}, // tokens = 4 + {codes.OK, 10}, // tokens = 4.5 + {codes.OK, 11}, // tokens = 5 + {codes.OK, 12}, // tokens = 5.5 + {codes.OK, 13}, // tokens = 6 + {codes.OK, 14}, // tokens = 6.5 + {codes.OK, 16}, // tokens = 5.5 + {codes.Unavailable, 17}, // tokens = 4.5 + } + for _, tc := range testCases { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}) + cancel() + if status.Code(err) != tc.code { + t.Errorf("EmptyCall(_, _) = _, %v; want _, ", err, tc.code) + } + if i != tc.count { + t.Errorf("i = %v; want %v", i, tc.count) + } + } +} + +func TestRetryStreaming(t *testing.T) { + defer enableRetry()() + req := func(b byte) *testpb.StreamingOutputCallRequest { + return &testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: []byte{b}}} + } + res := func(b byte) *testpb.StreamingOutputCallResponse { + return &testpb.StreamingOutputCallResponse{Payload: &testpb.Payload{Body: []byte{b}}} + } + + largePayload, _ := newPayload(testpb.PayloadType_COMPRESSABLE, 500) + + type serverOp func(stream testpb.TestService_FullDuplexCallServer) error + type clientOp func(stream testpb.TestService_FullDuplexCallClient) error + + // Server Operations + sAttempts := func(n int) serverOp { + return func(stream testpb.TestService_FullDuplexCallServer) error { + const key = "grpc-previous-rpc-attempts" + md, ok := metadata.FromIncomingContext(stream.Context()) + if !ok { + return status.Errorf(codes.Internal, "server: no header metadata received") + } + if got := md[key]; len(got) != 1 || got[0] != strconv.Itoa(n) { + return status.Errorf(codes.Internal, "server: metadata = %v; want ", md, key, n) + } + return nil + } + } + sReq := func(b byte) serverOp { + return func(stream 
testpb.TestService_FullDuplexCallServer) error { + want := req(b) + if got, err := stream.Recv(); err != nil || !proto.Equal(got, want) { + return status.Errorf(codes.Internal, "server: Recv() = %v, %v; want %v, ", got, err, want) + } + return nil + } + } + sReqPayload := func(p *testpb.Payload) serverOp { + return func(stream testpb.TestService_FullDuplexCallServer) error { + want := &testpb.StreamingOutputCallRequest{Payload: p} + if got, err := stream.Recv(); err != nil || !proto.Equal(got, want) { + return status.Errorf(codes.Internal, "server: Recv() = %v, %v; want %v, ", got, err, want) + } + return nil + } + } + sRes := func(b byte) serverOp { + return func(stream testpb.TestService_FullDuplexCallServer) error { + msg := res(b) + if err := stream.Send(msg); err != nil { + return status.Errorf(codes.Internal, "server: Send(%v) = %v; want ", msg, err) + } + return nil + } + } + sErr := func(c codes.Code) serverOp { + return func(stream testpb.TestService_FullDuplexCallServer) error { + return status.New(c, "").Err() + } + } + sCloseSend := func() serverOp { + return func(stream testpb.TestService_FullDuplexCallServer) error { + if msg, err := stream.Recv(); msg != nil || err != io.EOF { + return status.Errorf(codes.Internal, "server: Recv() = %v, %v; want , io.EOF", msg, err) + } + return nil + } + } + sPushback := func(s string) serverOp { + return func(stream testpb.TestService_FullDuplexCallServer) error { + stream.SetTrailer(metadata.MD{"grpc-retry-pushback-ms": []string{s}}) + return nil + } + } + + // Client Operations + cReq := func(b byte) clientOp { + return func(stream testpb.TestService_FullDuplexCallClient) error { + msg := req(b) + if err := stream.Send(msg); err != nil { + return fmt.Errorf("client: Send(%v) = %v; want ", msg, err) + } + return nil + } + } + cReqPayload := func(p *testpb.Payload) clientOp { + return func(stream testpb.TestService_FullDuplexCallClient) error { + msg := &testpb.StreamingOutputCallRequest{Payload: p} + if err := stream.Send(msg); err != nil { + return fmt.Errorf("client: Send(%v) = %v; want ", msg, err) + } + return nil + } + } + cRes := func(b byte) clientOp { + return func(stream testpb.TestService_FullDuplexCallClient) error { + want := res(b) + if got, err := stream.Recv(); err != nil || !proto.Equal(got, want) { + return fmt.Errorf("client: Recv() = %v, %v; want %v, ", got, err, want) + } + return nil + } + } + cErr := func(c codes.Code) clientOp { + return func(stream testpb.TestService_FullDuplexCallClient) error { + want := status.New(c, "").Err() + if c == codes.OK { + want = io.EOF + } + res, err := stream.Recv() + if res != nil || + ((err == nil) != (want == nil)) || + (want != nil && !reflect.DeepEqual(err, want)) { + return fmt.Errorf("client: Recv() = %v, %v; want , %v", res, err, want) + } + return nil + } + } + cCloseSend := func() clientOp { + return func(stream testpb.TestService_FullDuplexCallClient) error { + if err := stream.CloseSend(); err != nil { + return fmt.Errorf("client: CloseSend() = %v; want ", err) + } + return nil + } + } + var curTime time.Time + cGetTime := func() clientOp { + return func(_ testpb.TestService_FullDuplexCallClient) error { + curTime = time.Now() + return nil + } + } + cCheckElapsed := func(d time.Duration) clientOp { + return func(_ testpb.TestService_FullDuplexCallClient) error { + if elapsed := time.Since(curTime); elapsed < d { + return fmt.Errorf("Elapsed time: %v; want >= %v", elapsed, d) + } + return nil + } + } + cHdr := func() clientOp { + return func(stream 
testpb.TestService_FullDuplexCallClient) error { + _, err := stream.Header() + return err + } + } + cCtx := func() clientOp { + return func(stream testpb.TestService_FullDuplexCallClient) error { + stream.Context() + return nil + } + } + + testCases := []struct { + desc string + serverOps []serverOp + clientOps []clientOp + }{{ + desc: "Non-retryable error code", + serverOps: []serverOp{sReq(1), sErr(codes.Internal)}, + clientOps: []clientOp{cReq(1), cErr(codes.Internal)}, + }, { + desc: "One retry necessary", + serverOps: []serverOp{sReq(1), sErr(codes.Unavailable), sReq(1), sAttempts(1), sRes(1)}, + clientOps: []clientOp{cReq(1), cRes(1), cErr(codes.OK)}, + }, { + desc: "Exceed max attempts (4); check attempts header on server", + serverOps: []serverOp{ + sReq(1), sErr(codes.Unavailable), + sReq(1), sAttempts(1), sErr(codes.Unavailable), + sAttempts(2), sReq(1), sErr(codes.Unavailable), + sAttempts(3), sReq(1), sErr(codes.Unavailable), + }, + clientOps: []clientOp{cReq(1), cErr(codes.Unavailable)}, + }, { + desc: "Multiple requests", + serverOps: []serverOp{ + sReq(1), sReq(2), sErr(codes.Unavailable), + sReq(1), sReq(2), sRes(5), + }, + clientOps: []clientOp{cReq(1), cReq(2), cRes(5), cErr(codes.OK)}, + }, { + desc: "Multiple successive requests", + serverOps: []serverOp{ + sReq(1), sErr(codes.Unavailable), + sReq(1), sReq(2), sErr(codes.Unavailable), + sReq(1), sReq(2), sReq(3), sRes(5), + }, + clientOps: []clientOp{cReq(1), cReq(2), cReq(3), cRes(5), cErr(codes.OK)}, + }, { + desc: "No retry after receiving", + serverOps: []serverOp{ + sReq(1), sErr(codes.Unavailable), + sReq(1), sRes(3), sErr(codes.Unavailable), + }, + clientOps: []clientOp{cReq(1), cRes(3), cErr(codes.Unavailable)}, + }, { + desc: "No retry after header", + serverOps: []serverOp{sReq(1), sErr(codes.Unavailable)}, + clientOps: []clientOp{cReq(1), cHdr(), cErr(codes.Unavailable)}, + }, { + desc: "No retry after context", + serverOps: []serverOp{sReq(1), sErr(codes.Unavailable)}, + clientOps: []clientOp{cReq(1), cCtx(), cErr(codes.Unavailable)}, + }, { + desc: "Replaying close send", + serverOps: []serverOp{ + sReq(1), sReq(2), sCloseSend(), sErr(codes.Unavailable), + sReq(1), sReq(2), sCloseSend(), sRes(1), sRes(3), sRes(5), + }, + clientOps: []clientOp{cReq(1), cReq(2), cCloseSend(), cRes(1), cRes(3), cRes(5), cErr(codes.OK)}, + }, { + desc: "Negative server pushback - no retry", + serverOps: []serverOp{sReq(1), sPushback("-1"), sErr(codes.Unavailable)}, + clientOps: []clientOp{cReq(1), cErr(codes.Unavailable)}, + }, { + desc: "Non-numeric server pushback - no retry", + serverOps: []serverOp{sReq(1), sPushback("xxx"), sErr(codes.Unavailable)}, + clientOps: []clientOp{cReq(1), cErr(codes.Unavailable)}, + }, { + desc: "Multiple server pushback values - no retry", + serverOps: []serverOp{sReq(1), sPushback("100"), sPushback("10"), sErr(codes.Unavailable)}, + clientOps: []clientOp{cReq(1), cErr(codes.Unavailable)}, + }, { + desc: "1s server pushback - delayed retry", + serverOps: []serverOp{sReq(1), sPushback("1000"), sErr(codes.Unavailable), sReq(1), sRes(2)}, + clientOps: []clientOp{cGetTime(), cReq(1), cRes(2), cCheckElapsed(time.Second), cErr(codes.OK)}, + }, { + desc: "Overflowing buffer - no retry", + serverOps: []serverOp{sReqPayload(largePayload), sErr(codes.Unavailable)}, + clientOps: []clientOp{cReqPayload(largePayload), cErr(codes.Unavailable)}, + }} + + var serverOpIter int + var serverOps []serverOp + ss := &stubServer{ + fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + for 
serverOpIter < len(serverOps) { + op := serverOps[serverOpIter] + serverOpIter++ + if err := op(stream); err != nil { + return err + } + } + return nil + }, + } + if err := ss.Start([]grpc.ServerOption{}, grpc.WithDefaultCallOptions(grpc.MaxRetryRPCBufferSize(200))); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + ss.r.NewServiceConfig(`{ + "methodConfig": [{ + "name": [{"service": "grpc.testing.TestService"}], + "waitForReady": true, + "retryPolicy": { + "MaxAttempts": 4, + "InitialBackoff": ".01s", + "MaxBackoff": ".01s", + "BackoffMultiplier": 1.0, + "RetryableStatusCodes": [ "UNAVAILABLE" ] + } + }]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + for { + if ctx.Err() != nil { + t.Fatalf("Timed out waiting for service config update") + } + if ss.cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall").WaitForReady != nil { + break + } + time.Sleep(time.Millisecond) + } + cancel() + + for _, tc := range testCases { + func() { + serverOpIter = 0 + serverOps = tc.serverOps + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + stream, err := ss.client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v: Error while creating stream: %v", tc.desc, err) + } + for _, op := range tc.clientOps { + if err := op(stream); err != nil { + t.Errorf("%v: %v", tc.desc, err) + break + } + } + if serverOpIter != len(serverOps) { + t.Errorf("%v: serverOpIter = %v; want %v", tc.desc, serverOpIter, len(serverOps)) + } + }() + } +} diff --git a/transport/go16.go b/transport/go16.go index 5babcf9b..e0d00115 100644 --- a/transport/go16.go +++ b/transport/go16.go @@ -25,6 +25,7 @@ import ( "net/http" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "golang.org/x/net/context" ) @@ -34,15 +35,15 @@ func dialContext(ctx context.Context, network, address string) (net.Conn, error) return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address) } -// ContextErr converts the error from context package into a StreamError. -func ContextErr(err error) StreamError { +// ContextErr converts the error from context package into a status error. +func ContextErr(err error) error { switch err { case context.DeadlineExceeded: - return streamErrorf(codes.DeadlineExceeded, "%v", err) + return status.Error(codes.DeadlineExceeded, err.Error()) case context.Canceled: - return streamErrorf(codes.Canceled, "%v", err) + return status.Error(codes.Canceled, err.Error()) } - return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) + return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err) } // contextFromRequest returns a background context. diff --git a/transport/go17.go b/transport/go17.go index b7fa6bdb..4d515b00 100644 --- a/transport/go17.go +++ b/transport/go17.go @@ -26,6 +26,7 @@ import ( "net/http" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" netctx "golang.org/x/net/context" ) @@ -35,15 +36,15 @@ func dialContext(ctx context.Context, network, address string) (net.Conn, error) return (&net.Dialer{}).DialContext(ctx, network, address) } -// ContextErr converts the error from context package into a StreamError. -func ContextErr(err error) StreamError { +// ContextErr converts the error from context package into a status error. 
+func ContextErr(err error) error { switch err { case context.DeadlineExceeded, netctx.DeadlineExceeded: - return streamErrorf(codes.DeadlineExceeded, "%v", err) + return status.Error(codes.DeadlineExceeded, err.Error()) case context.Canceled, netctx.Canceled: - return streamErrorf(codes.Canceled, "%v", err) + return status.Error(codes.Canceled, err.Error()) } - return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) + return status.Errorf(codes.Internal, "Unexpected error from context packet: %v", err) } // contextFromRequest returns a context from the HTTP Request. diff --git a/transport/http2_client.go b/transport/http2_client.go index 968d4a92..528efcd4 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -22,6 +22,7 @@ import ( "io" "math" "net" + "strconv" "strings" "sync" "sync/atomic" @@ -373,6 +374,9 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(callHdr.ContentSubtype)}) headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"}) + if callHdr.PreviousAttempts > 0 { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-previous-rpc-attempts", Value: strconv.Itoa(callHdr.PreviousAttempts)}) + } if callHdr.SendCompress != "" { headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) @@ -627,7 +631,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) { rst = true rstCode = http2.ErrCodeCancel } - t.closeStream(s, err, rst, rstCode, nil, nil, false) + t.closeStream(s, err, rst, rstCode, status.Convert(err), nil, false) } func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.ErrCode, st *status.Status, mdata map[string][]string, eosReceived bool) { @@ -651,6 +655,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. close(s.done) // If headerChan isn't closed, then close it. if atomic.SwapUint32(&s.headerDone, 1) == 0 { + s.noHeaders = true close(s.headerChan) } cleanup := &cleanupStream{ @@ -709,7 +714,7 @@ func (t *http2Client) Close() error { } // Notify all active streams. for _, s := range streams { - t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, nil, nil, false) + t.closeStream(s, ErrConnClosing, false, http2.ErrCodeNo, status.New(codes.Unavailable, ErrConnClosing.Desc), nil, false) } if t.statsHandler != nil { connEnd := &stats.ConnEnd{ @@ -1054,7 +1059,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { atomic.StoreUint32(&s.bytesReceived, 1) var state decodeState if err := state.decodeResponseHeader(frame); err != nil { - t.closeStream(s, err, true, http2.ErrCodeProtocol, nil, nil, false) + t.closeStream(s, err, true, http2.ErrCodeProtocol, status.New(codes.Internal, err.Error()), nil, false) // Something wrong. Stops reading even when there is remaining. 
return } @@ -1090,6 +1095,8 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { if len(state.mdata) > 0 { s.header = state.mdata } + } else { + s.noHeaders = true } close(s.headerChan) } @@ -1140,7 +1147,9 @@ func (t *http2Client) reader() { t.mu.Unlock() if s != nil { // use error detail to provide better err message - t.closeStream(s, streamErrorf(http2ErrConvTab[se.Code], "%v", t.framer.fr.ErrorDetail()), true, http2.ErrCodeProtocol, nil, nil, false) + code := http2ErrConvTab[se.Code] + msg := t.framer.fr.ErrorDetail().Error() + t.closeStream(s, streamError(code, msg), true, http2.ErrCodeProtocol, status.New(code, msg), nil, false) } continue } else { diff --git a/transport/http_util.go b/transport/http_util.go index 456edcb8..0c1b2b1c 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -137,6 +137,9 @@ func isReservedHeader(hdr string) bool { "grpc-status", "grpc-timeout", "grpc-status-details-bin", + // Intentionally exclude grpc-previous-rpc-attempts and + // grpc-retry-pushback-ms, which are "reserved", but their API + // intentionally works via metadata. "te": return true default: @@ -144,8 +147,8 @@ func isReservedHeader(hdr string) bool { } } -// isWhitelistedHeader checks whether hdr should be propagated -// into metadata visible to users. +// isWhitelistedHeader checks whether hdr should be propagated into metadata +// visible to users, even though it is classified as "reserved", above. func isWhitelistedHeader(hdr string) bool { switch hdr { case ":authority", "user-agent": diff --git a/transport/transport.go b/transport/transport.go index f51f8788..d409356f 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -191,6 +191,8 @@ type Stream struct { header metadata.MD // the received header metadata. trailer metadata.MD // the key-value map of trailer metadata. + noHeaders bool // set if the client never received headers (set only after the stream is done). + // On the server-side, headerSent is atomically set to 1 when the headers are sent out. headerSent uint32 @@ -282,6 +284,19 @@ func (s *Stream) Header() (metadata.MD, error) { return nil, err } +// TrailersOnly blocks until a header or trailers-only frame is received and +// then returns true if the stream was trailers-only. If the stream ends +// before headers are received, returns true, nil. If a context error happens +// first, returns it as a status error. Client-side only. +func (s *Stream) TrailersOnly() (bool, error) { + err := s.waitOnHeader() + if err != nil { + return false, err + } + // if !headerDone, some other connection error occurred. + return s.noHeaders && atomic.LoadUint32(&s.headerDone) == 1, nil +} + // Trailer returns the cached trailer metedata. Note that if it is not called // after the entire stream is done, it could return an empty MD. Client // side only. @@ -534,6 +549,8 @@ type CallHdr struct { // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests // for more details. ContentSubtype string + + PreviousAttempts int // value of grpc-previous-rpc-attempts header to set } // ClientTransport is the common interface for all gRPC client-side transport @@ -630,6 +647,11 @@ func streamErrorf(c codes.Code, format string, a ...interface{}) StreamError { } } +// streamError creates an StreamError with the specified error code and description. +func streamError(c codes.Code, desc string) StreamError { + return StreamError{Code: c, Desc: desc} +} + // connectionErrorf creates an ConnectionError with the specified error description. 
func connectionErrorf(temp bool, e error, format string, a ...interface{}) ConnectionError { return ConnectionError{ diff --git a/transport/transport_test.go b/transport/transport_test.go index 744c0dfa..59045290 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -1171,8 +1171,8 @@ func TestLargeMessageSuspension(t *testing.T) { if err != errStreamDone { t.Fatalf("Write got %v, want io.EOF", err) } - expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) - if _, err := s.Read(make([]byte, 8)); err != expectedErr { + expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) + if _, err := s.Read(make([]byte, 8)); err.Error() != expectedErr.Error() { t.Fatalf("Read got %v of type %T, want %v", err, err, expectedErr) } ct.Close() @@ -1200,7 +1200,7 @@ func TestMaxStreams(t *testing.T) { pctx, cancel := context.WithCancel(context.Background()) defer cancel() timer := time.NewTimer(time.Second * 10) - expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) + expectedErr := status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) for { select { case <-timer.C: @@ -1214,7 +1214,7 @@ func TestMaxStreams(t *testing.T) { if str, err := ct.NewStream(ctx, callHdr); err == nil { slist = append(slist, str) continue - } else if err != expectedErr { + } else if err.Error() != expectedErr.Error() { t.Fatalf("ct.NewStream(_,_) = _, %v, want _, %v", err, expectedErr) } timer.Stop() @@ -1735,13 +1735,13 @@ func TestContextErr(t *testing.T) { // input errIn error // outputs - errOut StreamError + errOut error }{ - {context.DeadlineExceeded, StreamError{codes.DeadlineExceeded, context.DeadlineExceeded.Error()}}, - {context.Canceled, StreamError{codes.Canceled, context.Canceled.Error()}}, + {context.DeadlineExceeded, status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())}, + {context.Canceled, status.Error(codes.Canceled, context.Canceled.Error())}, } { err := ContextErr(test.errIn) - if err != test.errOut { + if err.Error() != test.errOut.Error() { t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) } }
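For orientation, below is a minimal, illustrative sketch (not part of the change above) of how a client could exercise the retry behavior this patch adds. It is hedged: it assumes retries are enabled (for example via the GRPC_GO_RETRY=on environment variable), that the service config shown in the retry test is delivered to the channel by the name resolver (as the test does with ss.r.NewServiceConfig), and that the target address and insecure credentials are placeholders.

// Hypothetical client setup; the address and buffer size are placeholders.
package main

import (
	"log"

	"google.golang.org/grpc"
)

// retryServiceConfig mirrors the policy used by the retry test above. In this
// version of gRPC-Go the config reaches the client through the name resolver;
// it is reproduced here only to show the shape of a retryPolicy block.
const retryServiceConfig = `{
  "methodConfig": [{
    "name": [{"service": "grpc.testing.TestService"}],
    "waitForReady": true,
    "retryPolicy": {
      "MaxAttempts": 4,
      "InitialBackoff": ".01s",
      "MaxBackoff": ".01s",
      "BackoffMultiplier": 1.0,
      "RetryableStatusCodes": ["UNAVAILABLE"]
    }
  }]
}`

func main() {
	// MaxRetryRPCBufferSize bounds how much outgoing data is buffered per RPC
	// so it can be replayed on a retry attempt (compare the 200-byte limit
	// used by the "Overflowing buffer - no retry" test case above).
	cc, err := grpc.Dial("localhost:50051",
		grpc.WithInsecure(),
		grpc.WithDefaultCallOptions(grpc.MaxRetryRPCBufferSize(1<<20)),
	)
	if err != nil {
		log.Fatalf("grpc.Dial: %v", err)
	}
	defer cc.Close()

	_ = retryServiceConfig // delivered via the resolver in practice, not via Dial
}

The buffer-size value here is arbitrary; the design point is simply that a retry can only be transparent to the application if the already-sent messages fit in the per-RPC replay buffer, which is why the overflowing-payload test case above expects no retry.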