diff --git a/server.go b/server.go index d0a0e9a8..904d9156 100644 --- a/server.go +++ b/server.go @@ -282,12 +282,14 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc) (err error) { + ctx := stream.Context() var traceInfo traceInfo if EnableTracing { traceInfo.tr = trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) defer traceInfo.tr.Finish() traceInfo.firstLine.client = false traceInfo.tr.LazyLog(&traceInfo.firstLine, false) + ctx = trace.NewContext(ctx, traceInfo.tr) defer func() { if err != nil && err != io.EOF { traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) @@ -322,7 +324,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. case compressionNone: statusCode := codes.OK statusDesc := "" - reply, appErr := md.Handler(srv.server, stream.Context(), s.opts.codec, req) + reply, appErr := md.Handler(srv.server, ctx, s.opts.codec, req) if appErr != nil { if err, ok := appErr.(rpcError); ok { statusCode = err.code @@ -331,12 +333,20 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. statusCode = convertCode(appErr) statusDesc = appErr.Error() } + if traceInfo.tr != nil && statusCode != codes.OK { + traceInfo.tr.LazyLog(stringer(statusDesc), true) + traceInfo.tr.SetError() + } + if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) return err } return nil } + if traceInfo.tr != nil { + traceInfo.tr.LazyLog(stringer("OK"), true) + } opts := &transport.Options{ Last: true, Delay: false, @@ -371,11 +381,13 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp p: &parser{s: stream}, codec: s.opts.codec, tracing: EnableTracing, + ctx: stream.Context(), } if ss.tracing { ss.traceInfo.tr = trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) ss.traceInfo.firstLine.client = false ss.traceInfo.tr.LazyLog(&ss.traceInfo.firstLine, false) + ss.ctx = trace.NewContext(ss.ctx, ss.traceInfo.tr) defer func() { ss.mu.Lock() if err != nil && err != io.EOF { @@ -396,10 +408,24 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.statusDesc = appErr.Error() } } + if ss.tracing { + ss.mu.Lock() + if ss.statusCode != codes.OK { + ss.traceInfo.tr.LazyLog(stringer(ss.statusDesc), true) + ss.traceInfo.tr.SetError() + } else { + ss.traceInfo.tr.LazyLog(stringer("OK"), true) + } + ss.mu.Unlock() + } return t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) } +type stringer string + +func (s stringer) String() string { return string(s) } + func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) { sm := stream.Method() if sm != "" && sm[0] == '/' { diff --git a/stream.go b/stream.go index e14664cb..91d8115d 100644 --- a/stream.go +++ b/stream.go @@ -113,6 +113,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth cs.traceInfo.firstLine.deadline = deadline.Sub(time.Now()) } cs.traceInfo.tr.LazyLog(&cs.traceInfo.firstLine, false) + ctx = trace.NewContext(ctx, cs.traceInfo.tr) } t, err := cc.wait(ctx) if err != nil { @@ -283,7 +284,8 @@ type serverStream struct { statusCode codes.Code statusDesc string - tracing bool // set to EnableTracing when the serverStream is created. + tracing bool // set to EnableTracing when the serverStream is created. + ctx context.Context // provides trace.FromContext when tracing mu sync.Mutex // protects traceInfo // traceInfo.tr is set when the serverStream is created (if EnableTracing is true), @@ -292,7 +294,7 @@ type serverStream struct { } func (ss *serverStream) Context() context.Context { - return ss.s.Context() + return ss.ctx } func (ss *serverStream) SendHeader(md metadata.MD) error { @@ -317,7 +319,6 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) ss.traceInfo.tr.SetError() } - ss.mu.Unlock() } }()