diff --git a/server.go b/server.go index f8ca25a8..41d8c253 100644 --- a/server.go +++ b/server.go @@ -248,18 +248,15 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str return t.Write(stream, p, opts) } -func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc) { - var ( - traceInfo traceInfo - err error - ) +func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc) (err error) { + var traceInfo traceInfo if EnableTracing { traceInfo.tr = trace.New("Recv."+methodFamily(stream.Method()), stream.Method()) defer traceInfo.tr.Finish() traceInfo.firstLine.client = false traceInfo.tr.LazyLog(&traceInfo.firstLine, false) defer func() { - if err != nil { + if err != nil || err != io.EOF { traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) traceInfo.tr.SetError() } @@ -270,7 +267,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. pf, req, err := p.recvMsg() if err == io.EOF { // The entire stream is done (for unary RPC only). - return + return err } if err != nil { switch err := err.(type) { @@ -283,7 +280,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. default: panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err)) } - return + return err } if traceInfo.tr != nil { traceInfo.tr.LazyLog(&payload{sent: false, msg: req}, true) @@ -303,8 +300,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) + return err } - return + return nil } opts := &transport.Options{ Last: true, @@ -312,17 +310,19 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if err := s.sendResponse(t, stream, reply, compressionNone, opts); err != nil { if _, ok := err.(transport.ConnectionError); ok { - return + return err } if e, ok := err.(transport.StreamError); ok { statusCode = e.Code statusDesc = e.Desc + return err } else { statusCode = codes.Unknown statusDesc = err.Error() + return err } } - t.WriteStatus(stream, statusCode, statusDesc) + return t.WriteStatus(stream, statusCode, statusDesc) if traceInfo.tr != nil { traceInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) } @@ -332,7 +332,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } } -func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc) { +func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc) (err error) { ss := &serverStream{ t: t, s: stream, @@ -344,23 +344,34 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.traceInfo.tr = trace.New("Recv."+methodFamily(stream.Method()), stream.Method()) ss.traceInfo.firstLine.client = false ss.traceInfo.tr.LazyLog(&ss.traceInfo.firstLine, false) + defer func() { + ss.mu.Lock() + ss.traceInfo.tr.Finish() + ss.traceInfo.tr = nil + ss.mu.Unlock() + }() + defer func() { + if err != nil { + ss.mu.Lock() + ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + ss.traceInfo.tr.SetError() + ss.mu.Unlock() + } + }() } if appErr := sd.Handler(srv.server, ss); appErr != nil { if err, ok := appErr.(rpcError); ok { ss.statusCode = err.code ss.statusDesc = err.desc + return err } else { ss.statusCode = convertCode(appErr) ss.statusDesc = appErr.Error() + return nil } } - t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) - if ss.tracing { - ss.mu.Lock() - ss.traceInfo.tr.Finish() - ss.traceInfo.tr = nil - ss.mu.Unlock() - } + return t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) + } func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) { diff --git a/stream.go b/stream.go index 3579979b..e4df3af6 100644 --- a/stream.go +++ b/stream.go @@ -307,14 +307,21 @@ func (ss *serverStream) SetTrailer(md metadata.MD) { return } -func (ss *serverStream) SendMsg(m interface{}) error { - if ss.tracing { - ss.mu.Lock() - if ss.traceInfo.tr != nil { - ss.traceInfo.tr.LazyLog(&payload{sent: true, msg: m}, true) +func (ss *serverStream) SendMsg(m interface{}) (err error) { + defer func() { + if ss.tracing { + ss.mu.Lock() + if ss.traceInfo.tr != nil { + if err == nil { + ss.traceInfo.tr.LazyLog(&payload{sent: true, msg: m}, true) + } else { + ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + ss.traceInfo.tr.SetError() + } + } + ss.mu.Unlock() } - ss.mu.Unlock() - } + }() out, err := encode(ss.codec, m, compressionNone) if err != nil { err = transport.StreamErrorf(codes.Internal, "grpc: %v", err) @@ -323,12 +330,17 @@ func (ss *serverStream) SendMsg(m interface{}) error { return ss.t.Write(ss.s, out, &transport.Options{Last: false}) } -func (ss *serverStream) RecvMsg(m interface{}) error { +func (ss *serverStream) RecvMsg(m interface{}) (err error) { defer func() { if ss.tracing { ss.mu.Lock() if ss.traceInfo.tr != nil { - ss.traceInfo.tr.LazyLog(&payload{sent: false, msg: m}, true) + if err == nil { + ss.traceInfo.tr.LazyLog(&payload{sent: false, msg: m}, true) + } else { + ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + ss.traceInfo.tr.SetError() + } } ss.mu.Unlock() }