diff --git a/server.go b/server.go index 274f7329..ba68a215 100644 --- a/server.go +++ b/server.go @@ -284,12 +284,15 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc) (err error) { + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() var traceInfo traceInfo if EnableTracing { traceInfo.tr = trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) defer traceInfo.tr.Finish() traceInfo.firstLine.client = false traceInfo.tr.LazyLog(&traceInfo.firstLine, false) + ctx = trace.NewContext(ctx, traceInfo.tr) defer func() { if err != nil && err != io.EOF { traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) @@ -318,13 +321,15 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. return err } if traceInfo.tr != nil { - traceInfo.tr.LazyLog(&payload{sent: false, msg: req}, true) + // TODO: set payload.msg to something that + // prints usefully with %s; req is a []byte. + traceInfo.tr.LazyLog(&payload{sent: false}, true) } switch pf { case compressionNone: statusCode := codes.OK statusDesc := "" - reply, appErr := md.Handler(srv.server, stream.Context(), s.opts.codec, req) + reply, appErr := md.Handler(srv.server, ctx, s.opts.codec, req) if appErr != nil { if err, ok := appErr.(rpcError); ok { statusCode = err.code @@ -333,12 +338,20 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. statusCode = convertCode(appErr) statusDesc = appErr.Error() } + if traceInfo.tr != nil && statusCode != codes.OK { + traceInfo.tr.LazyLog(stringer(statusDesc), true) + traceInfo.tr.SetError() + } + if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) return err } return nil } + if traceInfo.tr != nil { + traceInfo.tr.LazyLog(stringer("OK"), false) + } opts := &transport.Options{ Last: true, Delay: false, @@ -367,9 +380,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc) (err error) { + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() ss := &serverStream{ t: t, s: stream, + ctx: ctx, p: &parser{s: stream}, codec: s.opts.codec, tracing: EnableTracing, @@ -378,6 +394,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.traceInfo.tr = trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) ss.traceInfo.firstLine.client = false ss.traceInfo.tr.LazyLog(&ss.traceInfo.firstLine, false) + ss.ctx = trace.NewContext(ss.ctx, ss.traceInfo.tr) defer func() { ss.mu.Lock() if err != nil && err != io.EOF { @@ -398,6 +415,16 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.statusDesc = appErr.Error() } } + if ss.tracing { + ss.mu.Lock() + if ss.statusCode != codes.OK { + ss.traceInfo.tr.LazyLog(stringer(ss.statusDesc), true) + ss.traceInfo.tr.SetError() + } else { + ss.traceInfo.tr.LazyLog(stringer("OK"), false) + } + ss.mu.Unlock() + } return t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) } diff --git a/stream.go b/stream.go index a9d7c49c..11737112 100644 --- a/stream.go +++ b/stream.go @@ -132,6 +132,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth cs.traceInfo.firstLine.deadline = deadline.Sub(time.Now()) } cs.traceInfo.tr.LazyLog(&cs.traceInfo.firstLine, false) + ctx = trace.NewContext(ctx, cs.traceInfo.tr) } s, err := t.NewStream(ctx, callHdr) if err != nil { @@ -293,6 +294,7 @@ type ServerStream interface { type serverStream struct { t transport.ServerTransport s *transport.Stream + ctx context.Context // provides trace.FromContext when tracing p *parser codec Codec statusCode codes.Code @@ -307,7 +309,7 @@ type serverStream struct { } func (ss *serverStream) Context() context.Context { - return ss.s.Context() + return ss.ctx } func (ss *serverStream) SendHeader(md metadata.MD) error { @@ -332,7 +334,6 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) ss.traceInfo.tr.SetError() } - ss.mu.Unlock() } }() diff --git a/trace.go b/trace.go index 24635740..cde04fbf 100644 --- a/trace.go +++ b/trace.go @@ -114,3 +114,7 @@ type fmtStringer struct { func (f *fmtStringer) String() string { return fmt.Sprintf(f.format, f.a...) } + +type stringer string + +func (s stringer) String() string { return string(s) }