diff --git a/call.go b/call.go index 4efc5bdd..7f1345b2 100644 --- a/call.go +++ b/call.go @@ -50,7 +50,8 @@ import ( // On error, it returns the error and indicates whether the call should be retried. // // TODO(zhaoq): Check whether the received message sequence is valid. -func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) { +// TODO ctx is userCtx, not stream.Context. It is used for stats handling. Change this later if necessary. +func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) { // Try to acquire header metadata from the server if there is any. defer func() { if err != nil { @@ -81,7 +82,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK { // TODO in the current implementation, inTrailer may be handled before inStats in some cases. // Fix the order if necessary. - stats.Handle(stream.Context(), inPayload) + stats.Handle(ctx, inPayload) } c.trailerMD = stream.Trailer() return nil @@ -117,10 +118,12 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd if err != nil { return nil, Errorf(codes.Internal, "grpc: %v", err) } - err = t.Write(stream, outBuf, opts) if outPayload != nil { outPayload.SentTime = time.Now() - stats.Handle(stream.Context(), outPayload) + } + err = t.Write(stream, outBuf, opts) + if outPayload != nil { + stats.Handle(ctx, outPayload) } // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method // does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following @@ -247,7 +250,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } - err = recvResponse(cc.dopts, t, &c, stream, reply) + err = recvResponse(ctx, cc.dopts, t, &c, stream, reply) if err != nil { if put != nil { put() diff --git a/server.go b/server.go index bac34637..3c4eb067 100644 --- a/server.go +++ b/server.go @@ -552,16 +552,16 @@ func (s *Server) removeConn(c io.Closer) { func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error { var ( - cbuf *bytes.Buffer - outStats *stats.OutPayload + cbuf *bytes.Buffer + outPayload *stats.OutPayload ) if cp != nil { cbuf = new(bytes.Buffer) } if stats.On() { - outStats = &stats.OutPayload{} + outPayload = &stats.OutPayload{} } - p, err := encode(s.opts.codec, msg, cp, cbuf, outStats) + p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) if err != nil { // This typically indicates a fatal issue (e.g., memory // corruption or hardware faults) the application program @@ -572,11 +572,12 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str // the optimal option. grpclog.Fatalf("grpc: Server failed to encode response %v", err) } + if outPayload != nil { + outPayload.SentTime = time.Now() + } err = t.Write(stream, p, opts) - if outStats != nil { - outStats.SentTime = time.Now() - - stats.Handle(stream.Context(), outStats) + if outPayload != nil { + stats.Handle(stream.Context(), outPayload) } return err } diff --git a/stream.go b/stream.go index 08a93476..c3838e0b 100644 --- a/stream.go +++ b/stream.go @@ -213,6 +213,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth tracing: EnableTracing, trInfo: trInfo, + + userCtx: ctx, } if cc.dopts.cp != nil { cs.cbuf = new(bytes.Buffer) @@ -265,6 +267,10 @@ type clientStream struct { // trInfo.tr is set when the clientStream is created (if EnableTracing is true), // and is set to nil when the clientStream's finish method is called. trInfo traceInfo + + // Keep the user context for stats handling. + // All stats handling should use the user context instead of the stream context. + userCtx context.Context } func (cs *clientStream) Context() context.Context { @@ -280,7 +286,7 @@ func (cs *clientStream) Header() (_ metadata.MD, err error) { EndTime: time.Now(), Error: err, } - stats.Handle(cs.s.Context(), end) + stats.Handle(cs.userCtx, end) } }() m, err := cs.s.Header() @@ -311,7 +317,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { Client: true, Error: err, } - stats.Handle(cs.s.Context(), end) + stats.Handle(cs.userCtx, end) } }() defer func() { @@ -336,13 +342,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } err = toRPCErr(err) }() - var outStats *stats.OutPayload + var outPayload *stats.OutPayload if stats.On() { - outStats = &stats.OutPayload{ + outPayload = &stats.OutPayload{ Client: true, } } - out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outStats) + out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload) defer func() { if cs.cbuf != nil { cs.cbuf.Reset() @@ -351,10 +357,12 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if err != nil { return Errorf(codes.Internal, "grpc: %v", err) } + if outPayload != nil { + outPayload.SentTime = time.Now() + } err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) - if outStats != nil { - outStats.SentTime = time.Now() - stats.Handle(cs.s.Context(), outStats) + if outPayload != nil { + stats.Handle(cs.userCtx, outPayload) } return err } @@ -371,7 +379,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { EndTime: time.Now(), Error: e, } - stats.Handle(cs.s.Context(), end) + stats.Handle(cs.userCtx, end) } }() var inStats *stats.InPayload @@ -396,7 +404,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { cs.mu.Unlock() } if inStats != nil { - stats.Handle(cs.s.Context(), inStats) + stats.Handle(cs.userCtx, inStats) } if !cs.desc.ClientStreams || cs.desc.ServerStreams { return @@ -557,11 +565,11 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { ss.mu.Unlock() } }() - var outStats *stats.OutPayload + var outPayload *stats.OutPayload if stats.On() { - outStats = &stats.OutPayload{} + outPayload = &stats.OutPayload{} } - out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outStats) + out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload) defer func() { if ss.cbuf != nil { ss.cbuf.Reset() @@ -571,12 +579,14 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { err = Errorf(codes.Internal, "grpc: %v", err) return err } + if outPayload != nil { + outPayload.SentTime = time.Now() + } if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil { return toRPCErr(err) } - if outStats != nil { - outStats.SentTime = time.Now() - stats.Handle(ss.s.Context(), outStats) + if outPayload != nil { + stats.Handle(ss.s.Context(), outPayload) } return nil } diff --git a/transport/http2_client.go b/transport/http2_client.go index 3d034740..edb78853 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -277,6 +277,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if t.authInfo != nil { pr.AuthInfo = t.authInfo } + userCtx := ctx ctx = peer.NewContext(ctx, pr) authData := make(map[string]string) for _, c := range t.creds { @@ -348,6 +349,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, ErrConnClosing } s := t.newStream(ctx, callHdr) + s.userCtx = userCtx t.activeStreams[s.id] = s // This stream is not counted when applySetings(...) initialize t.streamsQuota. @@ -459,7 +461,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea Encryption: callHdr.SendCompress, FailFast: callHdr.FailFast, } - stats.Handle(s.Context(), outHeader) + stats.Handle(s.userCtx, outHeader) } t.writableChan <- 0 return s, nil @@ -896,13 +898,13 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { Client: true, WireLength: int(frame.Header().Length), } - stats.Handle(s.ctx, inHeader) + stats.Handle(s.userCtx, inHeader) } else { inTrailer := &stats.InTrailer{ Client: true, WireLength: int(frame.Header().Length), } - stats.Handle(s.ctx, inTrailer) + stats.Handle(s.userCtx, inTrailer) } } }() diff --git a/transport/transport.go b/transport/transport.go index 44d77488..b941d8e4 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -168,6 +168,9 @@ type Stream struct { id uint32 // nil for client side Stream. st ServerTransport + // Keep the user context for stats handling. + // All stats handling should use the user context instead of the stream context. + userCtx context.Context // ctx is the associated context of the stream. ctx context.Context // cancel is always nil for client side Stream.