diff --git a/benchmark/worker/benchmark_client.go b/benchmark/worker/benchmark_client.go index 199bbe1f..dfe8c8f0 100644 --- a/benchmark/worker/benchmark_client.go +++ b/benchmark/worker/benchmark_client.go @@ -37,6 +37,7 @@ import ( "math" "runtime" "sync" + "syscall" "time" "golang.org/x/net/context" @@ -85,6 +86,7 @@ type benchmarkClient struct { lastResetTime time.Time histogramOptions stats.HistogramOptions lockingHistograms []lockingHistogram + rusageLastReset *syscall.Rusage } func printClientConfig(config *testpb.ClientConfig) { @@ -226,6 +228,9 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) return nil, err } + rusage := new(syscall.Rusage) + syscall.Getrusage(syscall.RUSAGE_SELF, rusage) + rpcCountPerConn := int(config.OutstandingRpcsPerChannel) bc := &benchmarkClient{ histogramOptions: stats.HistogramOptions{ @@ -236,9 +241,10 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) }, lockingHistograms: make([]lockingHistogram, rpcCountPerConn*len(conns), rpcCountPerConn*len(conns)), - stop: make(chan bool), - lastResetTime: time.Now(), - closeConns: closeConns, + stop: make(chan bool), + lastResetTime: time.Now(), + closeConns: closeConns, + rusageLastReset: rusage, } if err = performRPCs(config, conns, bc); err != nil { @@ -338,8 +344,9 @@ func (bc *benchmarkClient) doCloseLoopStreaming(conns []*grpc.ClientConn, rpcCou // getStats returns the stats for benchmark client. // It resets lastResetTime and all histograms if argument reset is true. func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats { - var timeElapsed float64 + var wallTimeElapsed, uTimeElapsed, sTimeElapsed float64 mergedHistogram := stats.NewHistogram(bc.histogramOptions) + latestRusage := new(syscall.Rusage) if reset { // Merging histogram may take some time. @@ -353,14 +360,21 @@ func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats { mergedHistogram.Merge(toMerge[i]) } - timeElapsed = time.Since(bc.lastResetTime).Seconds() + wallTimeElapsed = time.Since(bc.lastResetTime).Seconds() + syscall.Getrusage(syscall.RUSAGE_SELF, latestRusage) + uTimeElapsed, sTimeElapsed = cpuTimeDiff(bc.rusageLastReset, latestRusage) + + bc.rusageLastReset = latestRusage bc.lastResetTime = time.Now() } else { // Merge only, not reset. for i := range bc.lockingHistograms { bc.lockingHistograms[i].mergeInto(mergedHistogram) } - timeElapsed = time.Since(bc.lastResetTime).Seconds() + + wallTimeElapsed = time.Since(bc.lastResetTime).Seconds() + syscall.Getrusage(syscall.RUSAGE_SELF, latestRusage) + uTimeElapsed, sTimeElapsed = cpuTimeDiff(bc.rusageLastReset, latestRusage) } b := make([]uint32, len(mergedHistogram.Buckets), len(mergedHistogram.Buckets)) @@ -376,9 +390,9 @@ func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats { SumOfSquares: float64(mergedHistogram.SumOfSquares), Count: float64(mergedHistogram.Count), }, - TimeElapsed: timeElapsed, - TimeUser: 0, - TimeSystem: 0, + TimeElapsed: wallTimeElapsed, + TimeUser: uTimeElapsed, + TimeSystem: sTimeElapsed, } } diff --git a/benchmark/worker/benchmark_server.go b/benchmark/worker/benchmark_server.go index 667ef2c1..0d20581d 100644 --- a/benchmark/worker/benchmark_server.go +++ b/benchmark/worker/benchmark_server.go @@ -38,6 +38,7 @@ import ( "strconv" "strings" "sync" + "syscall" "time" "google.golang.org/grpc" @@ -55,11 +56,12 @@ var ( ) type benchmarkServer struct { - port int - cores int - closeFunc func() - mu sync.RWMutex - lastResetTime time.Time + port int + cores int + closeFunc func() + mu sync.RWMutex + lastResetTime time.Time + rusageLastReset *syscall.Rusage } func printServerConfig(config *testpb.ServerConfig) { @@ -156,18 +158,35 @@ func startBenchmarkServer(config *testpb.ServerConfig, serverPort int) (*benchma grpclog.Fatalf("failed to get port number from server address: %v", err) } - return &benchmarkServer{port: p, cores: numOfCores, closeFunc: closeFunc, lastResetTime: time.Now()}, nil + rusage := new(syscall.Rusage) + syscall.Getrusage(syscall.RUSAGE_SELF, rusage) + + return &benchmarkServer{ + port: p, + cores: numOfCores, + closeFunc: closeFunc, + lastResetTime: time.Now(), + rusageLastReset: rusage, + }, nil } // getStats returns the stats for benchmark server. // It resets lastResetTime if argument reset is true. func (bs *benchmarkServer) getStats(reset bool) *testpb.ServerStats { - // TODO wall time, sys time, user time. bs.mu.RLock() defer bs.mu.RUnlock() - timeElapsed := time.Since(bs.lastResetTime).Seconds() + wallTimeElapsed := time.Since(bs.lastResetTime).Seconds() + rusageLatest := new(syscall.Rusage) + syscall.Getrusage(syscall.RUSAGE_SELF, rusageLatest) + uTimeElapsed, sTimeElapsed := cpuTimeDiff(bs.rusageLastReset, rusageLatest) + if reset { bs.lastResetTime = time.Now() + bs.rusageLastReset = rusageLatest + } + return &testpb.ServerStats{ + TimeElapsed: wallTimeElapsed, + TimeUser: uTimeElapsed, + TimeSystem: sTimeElapsed, } - return &testpb.ServerStats{TimeElapsed: timeElapsed, TimeUser: 0, TimeSystem: 0} } diff --git a/benchmark/worker/main.go b/benchmark/worker/main.go index 17c52519..8a804062 100644 --- a/benchmark/worker/main.go +++ b/benchmark/worker/main.go @@ -38,6 +38,8 @@ import ( "fmt" "io" "net" + "net/http" + _ "net/http/pprof" "runtime" "strconv" "time" @@ -50,8 +52,10 @@ import ( ) var ( - driverPort = flag.Int("driver_port", 10000, "port for communication with driver") - serverPort = flag.Int("server_port", 0, "port for benchmark server if not specified by server config message") + driverPort = flag.Int("driver_port", 10000, "port for communication with driver") + serverPort = flag.Int("server_port", 0, "port for benchmark server if not specified by server config message") + pprofPort = flag.Int("pprof_port", -1, "Port for pprof debug server to listen on. Pprof server doesn't start if unset") + blockProfRate = flag.Int("block_prof_rate", 0, "fraction of goroutine blocking events to report in blocking profile") ) type byteBufCodec struct { @@ -227,5 +231,14 @@ func main() { s.Stop() }() + runtime.SetBlockProfileRate(*blockProfRate) + + if *pprofPort >= 0 { + go func() { + grpclog.Println("Starting pprof server on port " + strconv.Itoa(*pprofPort)) + grpclog.Println(http.ListenAndServe("localhost:"+strconv.Itoa(*pprofPort), nil)) + }() + } + s.Serve(lis) } diff --git a/benchmark/worker/util.go b/benchmark/worker/util.go index f0016ce4..6f9b2b03 100644 --- a/benchmark/worker/util.go +++ b/benchmark/worker/util.go @@ -36,6 +36,7 @@ import ( "log" "os" "path/filepath" + "syscall" ) // abs returns the absolute path the given relative file or directory path, @@ -52,6 +53,20 @@ func abs(rel string) string { return filepath.Join(v, rel) } +func cpuTimeDiff(first *syscall.Rusage, latest *syscall.Rusage) (float64, float64) { + var ( + utimeDiffs = latest.Utime.Sec - first.Utime.Sec + utimeDiffus = latest.Utime.Usec - first.Utime.Usec + stimeDiffs = latest.Stime.Sec - first.Stime.Sec + stimeDiffus = latest.Stime.Usec - first.Stime.Usec + ) + + uTimeElapsed := float64(utimeDiffs) + float64(utimeDiffus)*1.0e-6 + sTimeElapsed := float64(stimeDiffs) + float64(stimeDiffus)*1.0e-6 + + return uTimeElapsed, sTimeElapsed +} + func goPackagePath(pkg string) (path string, err error) { gp := os.Getenv("GOPATH") if gp == "" { diff --git a/call.go b/call.go index e92a4bc9..e370aeec 100644 --- a/call.go +++ b/call.go @@ -43,6 +43,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" "google.golang.org/grpc/transport" ) @@ -79,7 +80,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, msgSizeLimit int, t tr return } } - if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK { + if inPayload != nil && err == io.EOF && stream.Status().Code() == codes.OK { // TODO in the current implementation, inTrailer may be handled before inPayload in some cases. // Fix the order if necessary. dopts.copts.StatsHandler.HandleRPC(ctx, inPayload) @@ -267,7 +268,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. - if _, ok := err.(*rpcError); ok { + if _, ok := err.(status.Status); ok { return err } if err == errConnClosing || err == errConnUnavailable { @@ -321,6 +322,6 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli put() put = nil } - return Errorf(stream.StatusCode(), "%s", stream.StatusDesc()) + return stream.Status().Err() } } diff --git a/call_test.go b/call_test.go index 3c2165ea..63e87c21 100644 --- a/call_test.go +++ b/call_test.go @@ -46,6 +46,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/grpc/transport" ) @@ -99,21 +100,21 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { return } if v == "weird error" { - h.t.WriteStatus(s, codes.Internal, weirdError) + h.t.WriteStatus(s, status.New(codes.Internal, weirdError)) return } if v == "canceled" { canceled++ - h.t.WriteStatus(s, codes.Internal, "") + h.t.WriteStatus(s, status.New(codes.Internal, "")) return } if v == "port" { - h.t.WriteStatus(s, codes.Internal, h.port) + h.t.WriteStatus(s, status.New(codes.Internal, h.port)) return } if v != expectedRequest { - h.t.WriteStatus(s, codes.Internal, strings.Repeat("A", sizeLargeErr)) + h.t.WriteStatus(s, status.New(codes.Internal, strings.Repeat("A", sizeLargeErr))) return } } @@ -124,7 +125,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { return } h.t.Write(s, reply, &transport.Options{}) - h.t.WriteStatus(s, codes.OK, "") + h.t.WriteStatus(s, status.New(codes.OK, "")) } type server struct { @@ -239,7 +240,7 @@ func TestInvokeLargeErr(t *testing.T) { var reply string req := "hello" err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) - if _, ok := err.(*rpcError); !ok { + if _, ok := err.(status.Status); !ok { t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") } if Code(err) != codes.Internal || len(ErrorDesc(err)) != sizeLargeErr { @@ -255,7 +256,7 @@ func TestInvokeErrorSpecialChars(t *testing.T) { var reply string req := "weird error" err := Invoke(context.Background(), "/foo/bar", &req, &reply, cc) - if _, ok := err.(*rpcError); !ok { + if _, ok := err.(status.Status); !ok { t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") } if got, want := ErrorDesc(err), weirdError; got != want { diff --git a/rpc_util.go b/rpc_util.go index 832fd464..1fad9d47 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -37,7 +37,6 @@ import ( "bytes" "compress/gzip" "encoding/binary" - "fmt" "io" "io/ioutil" "math" @@ -50,6 +49,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" "google.golang.org/grpc/transport" ) @@ -189,7 +189,9 @@ func Trailer(md *metadata.MD) CallOption { // unary RPC. func Peer(peer *peer.Peer) CallOption { return afterCall(func(c *callInfo) { - *peer = *c.peer + if c.peer != nil { + *peer = *c.peer + } }) } @@ -370,88 +372,56 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ return nil } -// rpcError defines the status from an RPC. -type rpcError struct { - code codes.Code - desc string -} - -func (e *rpcError) Error() string { - return fmt.Sprintf("rpc error: code = %s desc = %s", e.code, e.desc) -} - // Code returns the error code for err if it was produced by the rpc system. // Otherwise, it returns codes.Unknown. +// +// Deprecated; use status.FromError and Code method instead. func Code(err error) codes.Code { - if err == nil { - return codes.OK - } - if e, ok := err.(*rpcError); ok { - return e.code + if s, ok := status.FromError(err); ok { + return s.Code() } return codes.Unknown } // ErrorDesc returns the error description of err if it was produced by the rpc system. // Otherwise, it returns err.Error() or empty string when err is nil. +// +// Deprecated; use status.FromError and Message method instead. func ErrorDesc(err error) string { - if err == nil { - return "" - } - if e, ok := err.(*rpcError); ok { - return e.desc + if s, ok := status.FromError(err); ok { + return s.Message() } return err.Error() } // Errorf returns an error containing an error code and a description; // Errorf returns nil if c is OK. +// +// Deprecated; use status.Errorf instead. func Errorf(c codes.Code, format string, a ...interface{}) error { - if c == codes.OK { - return nil - } - return &rpcError{ - code: c, - desc: fmt.Sprintf(format, a...), - } + return status.Errorf(c, format, a...) } -// toRPCErr converts an error into a rpcError. +// toRPCErr converts an error into an error from the status package. func toRPCErr(err error) error { switch e := err.(type) { - case *rpcError: + case status.Status: return err case transport.StreamError: - return &rpcError{ - code: e.Code, - desc: e.Desc, - } + return status.Error(e.Code, e.Desc) case transport.ConnectionError: - return &rpcError{ - code: codes.Internal, - desc: e.Desc, - } + return status.Error(codes.Internal, e.Desc) default: switch err { case context.DeadlineExceeded: - return &rpcError{ - code: codes.DeadlineExceeded, - desc: err.Error(), - } + return status.Error(codes.DeadlineExceeded, err.Error()) case context.Canceled: - return &rpcError{ - code: codes.Canceled, - desc: err.Error(), - } + return status.Error(codes.Canceled, err.Error()) case ErrClientConnClosing: - return &rpcError{ - code: codes.FailedPrecondition, - desc: err.Error(), - } + return status.Error(codes.FailedPrecondition, err.Error()) } - } - return Errorf(codes.Unknown, "%v", err) + return status.Error(codes.Unknown, err.Error()) } // convertCode converts a standard Go error into its canonical code. Note that diff --git a/rpc_util_test.go b/rpc_util_test.go index 375e42bc..f2b43f0f 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -41,8 +41,8 @@ import ( "testing" "github.com/golang/protobuf/proto" - "golang.org/x/net/context" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" perfpb "google.golang.org/grpc/test/codec_perf" "google.golang.org/grpc/transport" ) @@ -150,51 +150,21 @@ func TestToRPCErr(t *testing.T) { // input errIn error // outputs - errOut *rpcError + errOut error }{ - {transport.StreamError{codes.Unknown, ""}, Errorf(codes.Unknown, "").(*rpcError)}, - {transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc).(*rpcError)}, + {transport.StreamError{Code: codes.Unknown, Desc: ""}, status.Error(codes.Unknown, "")}, + {transport.ErrConnClosing, status.Error(codes.Internal, transport.ErrConnClosing.Desc)}, } { err := toRPCErr(test.errIn) - rpcErr, ok := err.(*rpcError) - if !ok { - t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, rpcError{}) + if _, ok := err.(status.Status); !ok { + t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error(codes.Unknown, "")) } - if *rpcErr != *test.errOut { + if !reflect.DeepEqual(err, test.errOut) { t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) } } } -func TestContextErr(t *testing.T) { - for _, test := range []struct { - // input - errIn error - // outputs - errOut transport.StreamError - }{ - {context.DeadlineExceeded, transport.StreamError{codes.DeadlineExceeded, context.DeadlineExceeded.Error()}}, - {context.Canceled, transport.StreamError{codes.Canceled, context.Canceled.Error()}}, - } { - err := transport.ContextErr(test.errIn) - if err != test.errOut { - t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) - } - } -} - -func TestErrorsWithSameParameters(t *testing.T) { - const description = "some description" - e1 := Errorf(codes.AlreadyExists, description).(*rpcError) - e2 := Errorf(codes.AlreadyExists, description).(*rpcError) - if e1 == e2 { - t.Fatalf("Error interfaces should not be considered equal - e1: %p - %v e2: %p - %v", e1, e1, e2, e2) - } - if Code(e1) != Code(e2) || ErrorDesc(e1) != ErrorDesc(e2) { - t.Fatalf("Expected errors to have same code and description - e1: %p - %v e2: %p - %v", e1, e1, e2, e2) - } -} - // bmEncode benchmarks encoding a Protocol Buffer message containing mSize // bytes. func bmEncode(b *testing.B, mSize int) { diff --git a/server.go b/server.go index 5049763d..e2f46831 100644 --- a/server.go +++ b/server.go @@ -56,6 +56,7 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" "google.golang.org/grpc/tap" "google.golang.org/grpc/transport" ) @@ -694,7 +695,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. stream.SetSendCompress(s.opts.cp.Type()) } p := &parser{r: stream} - for { + for { // TODO: delete pf, req, err := p.recvMsg(s.opts.maxReceiveMessageSize) if err == io.EOF { // The entire stream is done (for unary RPC only). @@ -704,36 +705,35 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. err = Errorf(codes.Internal, io.ErrUnexpectedEOF.Error()) } if err != nil { - switch err := err.(type) { - case *rpcError: - if e := t.WriteStatus(stream, err.code, err.desc); e != nil { + switch st := err.(type) { + case status.Status: + if e := t.WriteStatus(stream, st); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } case transport.ConnectionError: // Nothing to do here. case transport.StreamError: - if e := t.WriteStatus(stream, err.Code, err.Desc); e != nil { + if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } default: - panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err)) + panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", st, st)) } return err } if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { - switch err := err.(type) { - case *rpcError: - if e := t.WriteStatus(stream, err.code, err.desc); e != nil { + if st, ok := err.(status.Status); ok { + if e := t.WriteStatus(stream, st); e != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } return err - default: - if e := t.WriteStatus(stream, codes.Internal, err.Error()); e != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) - } - // TODO checkRecvPayload always return RPC error. Add a return here if necessary. } + if e := t.WriteStatus(stream, status.New(codes.Internal, err.Error())); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) + } + + // TODO checkRecvPayload always return RPC error. Add a return here if necessary. } var inPayload *stats.InPayload if sh != nil { @@ -741,8 +741,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. RecvTime: time.Now(), } } - statusCode := codes.OK - statusDesc := "" df := func(v interface{}) error { if inPayload != nil { inPayload.WireLength = len(req) @@ -751,20 +749,16 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. var err error req, err = s.opts.dc.Do(bytes.NewReader(req)) if err != nil { - if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) - } return Errorf(codes.Internal, err.Error()) } } if len(req) > s.opts.maxReceiveMessageSize { // TODO: Revisit the error code. Currently keep it consistent with // java implementation. - statusCode = codes.InvalidArgument - statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxReceiveMessageSize) + return status.Errorf(codes.InvalidArgument, "grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxReceiveMessageSize) } if err := s.opts.codec.Unmarshal(req, v); err != nil { - return err + return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } if inPayload != nil { inPayload.Payload = v @@ -779,21 +773,20 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) if appErr != nil { - if err, ok := appErr.(*rpcError); ok { - statusCode = err.code - statusDesc = err.desc - } else { - statusCode = convertCode(appErr) - statusDesc = appErr.Error() + appStatus, ok := status.FromError(appErr) + if !ok { + // Convert appErr if it is not a grpc status error. + appErr = status.Error(convertCode(appErr), appErr.Error()) + appStatus, _ = status.FromError(appErr) } - if trInfo != nil && statusCode != codes.OK { - trInfo.tr.LazyLog(stringer(statusDesc), true) + if trInfo != nil { + trInfo.tr.LazyLog(stringer(appStatus.Message()), true) trInfo.tr.SetError() } - if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) + if e := t.WriteStatus(stream, appStatus); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e) } - return Errorf(statusCode, statusDesc) + return appErr } if trInfo != nil { trInfo.tr.LazyLog(stringer("OK"), false) @@ -803,25 +796,17 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Delay: false, } if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { - switch err := err.(type) { - case transport.ConnectionError: - // Nothing to do here. - case transport.StreamError: - statusCode = err.Code - statusDesc = err.Desc - default: - statusCode = codes.Unknown - statusDesc = err.Error() - } + // TODO: Translate error into a status.Status error if necessary? + // TODO: Write status when appropriate. + return err } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) } - errWrite := t.WriteStatus(stream, statusCode, statusDesc) - if statusCode != codes.OK { - return Errorf(statusCode, statusDesc) - } - return errWrite + // TODO: Should we be logging if writing status failed here, like above? + // Should the logging be in WriteStatus? Should we ignore the WriteStatus + // error or allow the stats handler to see it? + return t.WriteStatus(stream, status.New(codes.OK, "")) } } @@ -891,32 +876,31 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp appErr = s.opts.streamInt(server, ss, info, sd.Handler) } if appErr != nil { - if err, ok := appErr.(*rpcError); ok { - ss.statusCode = err.code - ss.statusDesc = err.desc - } else if err, ok := appErr.(transport.StreamError); ok { - ss.statusCode = err.Code - ss.statusDesc = err.Desc - } else { - ss.statusCode = convertCode(appErr) - ss.statusDesc = appErr.Error() + switch err := appErr.(type) { + case status.Status: + // Do nothing + case transport.StreamError: + appErr = status.Error(err.Code, err.Desc) + default: + appErr = status.Error(convertCode(appErr), appErr.Error()) } + appStatus, _ := status.FromError(appErr) + if trInfo != nil { + ss.mu.Lock() + ss.trInfo.tr.LazyLog(stringer(appStatus.Message()), true) + ss.trInfo.tr.SetError() + ss.mu.Unlock() + } + t.WriteStatus(ss.s, appStatus) + // TODO: Should we log an error from WriteStatus here and below? + return appErr } if trInfo != nil { ss.mu.Lock() - if ss.statusCode != codes.OK { - ss.trInfo.tr.LazyLog(stringer(ss.statusDesc), true) - ss.trInfo.tr.SetError() - } else { - ss.trInfo.tr.LazyLog(stringer("OK"), false) - } + ss.trInfo.tr.LazyLog(stringer("OK"), false) ss.mu.Unlock() } - errWrite := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) - if ss.statusCode != codes.OK { - return Errorf(ss.statusCode, ss.statusDesc) - } - return errWrite + return t.WriteStatus(ss.s, status.New(codes.OK, "")) } @@ -932,7 +916,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.SetError() } errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) - if err := t.WriteStatus(stream, codes.InvalidArgument, errDesc); err != nil { + if err := t.WriteStatus(stream, status.New(codes.InvalidArgument, errDesc)); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() @@ -957,7 +941,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.SetError() } errDesc := fmt.Sprintf("unknown service %v", service) - if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil { + if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() @@ -987,7 +971,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str return } errDesc := fmt.Sprintf("unknown method %v", method) - if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil { + if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() diff --git a/stats/stats.go b/stats/stats.go index a82448a6..43d6f005 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -184,7 +184,7 @@ type End struct { Client bool // EndTime is the time when the RPC ends. EndTime time.Time - // Error is the error just happened. Its type is gRPC error. + // Error is the error just happened. It implements status.Status if non-nil. Error error } diff --git a/status/status.go b/status/status.go new file mode 100644 index 00000000..0e402081 --- /dev/null +++ b/status/status.go @@ -0,0 +1,160 @@ +/* + * + * Copyright 2017, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// Package status implements errors returned by gRPC. These errors are +// serialized and transmitted on the wire between server and client, and allow +// for additional data to be transmitted via the Details field in the status +// proto. gRPC service handlers should return an error created by this +// package, and gRPC clients should expect a corresponding error to be +// returned from the RPC call. +// +// This package upholds the invariants that a non-nil error may not +// contain an OK code, and an OK code must result in a nil error. +package status + +import ( + "fmt" + + "github.com/golang/protobuf/proto" + spb "github.com/google/go-genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" +) + +// Status provides access to grpc status details and is implemented by all +// errors returned from this package except nil errors, which are not typed. +// Note: gRPC users should not implement their own Statuses. Custom data may +// be attached to the spb.Status proto's Details field. +type Status interface { + // Code returns the status code. + Code() codes.Code + // Message returns the status message. + Message() string + // Proto returns a copy of the status in proto form. + Proto() *spb.Status + // Err returns an error representing the status. + Err() error +} + +// okStatus is a Status whose Code method returns codes.OK, but does not +// implement error. To represent an OK code as an error, use an untyped nil. +type okStatus struct{} + +func (okStatus) Code() codes.Code { + return codes.OK +} + +func (okStatus) Message() string { + return "" +} + +func (okStatus) Proto() *spb.Status { + return nil +} + +func (okStatus) Err() error { + return nil +} + +// statusError contains a status proto. It is embedded and not aliased to +// allow for accessor functions of the same name. It implements error and +// Status, and a nil statusError should never be returned by this package. +type statusError struct { + *spb.Status +} + +func (se *statusError) Error() string { + return fmt.Sprintf("rpc error: code = %s desc = %s", se.Code(), se.Message()) +} + +func (se *statusError) Code() codes.Code { + return codes.Code(se.Status.Code) +} + +func (se *statusError) Message() string { + return se.Status.Message +} + +func (se *statusError) Proto() *spb.Status { + return proto.Clone(se.Status).(*spb.Status) +} + +func (se *statusError) Err() error { + return se +} + +// New returns a Status representing c and msg. +func New(c codes.Code, msg string) Status { + if c == codes.OK { + return okStatus{} + } + return &statusError{Status: &spb.Status{Code: int32(c), Message: msg}} +} + +// Newf returns New(c, fmt.Sprintf(format, a...)). +func Newf(c codes.Code, format string, a ...interface{}) Status { + return New(c, fmt.Sprintf(format, a...)) +} + +// Error returns an error representing c and msg. If c is OK, returns nil. +func Error(c codes.Code, msg string) error { + return New(c, msg).Err() +} + +// Errorf returns Error(c, fmt.Sprintf(format, a...)). +func Errorf(c codes.Code, format string, a ...interface{}) error { + return Error(c, fmt.Sprintf(format, a...)) +} + +// ErrorProto returns an error representing s. If s.Code is OK, returns nil. +func ErrorProto(s *spb.Status) error { + return FromProto(s).Err() +} + +// FromProto returns a Status representing s. If s.Code is OK, Message and +// Details may be lost. +func FromProto(s *spb.Status) Status { + if s.GetCode() == int32(codes.OK) { + return okStatus{} + } + return &statusError{Status: proto.Clone(s).(*spb.Status)} +} + +// FromError returns a Status representing err if it was produced from this +// package, otherwise it returns nil, false. +func FromError(err error) (s Status, ok bool) { + if err == nil { + return okStatus{}, true + } + s, ok = err.(Status) + return s, ok +} diff --git a/status/status_test.go b/status/status_test.go new file mode 100644 index 00000000..34de196c --- /dev/null +++ b/status/status_test.go @@ -0,0 +1,110 @@ +/* + * + * Copyright 2017, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package status + +import ( + "reflect" + "testing" + + apb "github.com/golang/protobuf/ptypes/any" + spb "github.com/google/go-genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" +) + +func TestErrorsWithSameParameters(t *testing.T) { + const description = "some description" + e1 := Errorf(codes.AlreadyExists, description) + e2 := Errorf(codes.AlreadyExists, description) + if e1 == e2 || !reflect.DeepEqual(e1, e2) { + t.Fatalf("Errors should be equivalent but unique - e1: %v, %v e2: %p, %v", e1.(*statusError), e1, e2.(*statusError), e2) + } +} + +func TestFromToProto(t *testing.T) { + s := &spb.Status{ + Code: int32(codes.Internal), + Message: "test test test", + Details: []*apb.Any{{TypeUrl: "foo", Value: []byte{3, 2, 1}}}, + } + + err := FromProto(s) + if got := err.Proto(); !reflect.DeepEqual(s, got) { + t.Fatalf("Expected errors to be identical - s: %v got: %v", s, got) + } +} + +func TestError(t *testing.T) { + err := Error(codes.Internal, "test description") + if got, want := err.Error(), "rpc error: code = Internal desc = test description"; got != want { + t.Fatalf("err.Error() = %q; want %q", got, want) + } + s := err.(Status) + if got, want := s.Code(), codes.Internal; got != want { + t.Fatalf("err.Code() = %s; want %s", got, want) + } + if got, want := s.Message(), "test description"; got != want { + t.Fatalf("err.Message() = %s; want %s", got, want) + } +} + +func TestErrorOK(t *testing.T) { + err := Error(codes.OK, "foo") + if err != nil { + t.Fatalf("Error(codes.OK, _) = %p; want nil", err.(*statusError)) + } +} + +func TestErrorProtoOK(t *testing.T) { + s := &spb.Status{Code: int32(codes.OK)} + if got := ErrorProto(s); got != nil { + t.Fatalf("ErrorProto(%v) = %v; want nil", s, got) + } +} + +func TestFromError(t *testing.T) { + code, message := codes.Internal, "test description" + err := Error(code, message) + s, ok := FromError(err) + if !ok || s.Code() != code || s.Message() != message || s.Err() == nil { + t.Fatalf("FromError(%v) = %v, %v; want , true", err, s, ok, code, message) + } +} + +func TestFromErrorOK(t *testing.T) { + code, message := codes.OK, "" + s, ok := FromError(nil) + if !ok || s.Code() != code || s.Message() != message || s.Err() != nil { + t.Fatalf("FromError(nil) = %v, %v; want , true", s, ok, code, message) + } +} diff --git a/stream.go b/stream.go index 034bcf7e..2f054867 100644 --- a/stream.go +++ b/stream.go @@ -45,6 +45,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" "google.golang.org/grpc/transport" ) @@ -205,7 +206,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth t, put, err = cc.getTransport(ctx, gopts) if err != nil { // TODO(zhaoq): Probably revisit the error handling. - if _, ok := err.(*rpcError); ok { + if _, ok := err.(status.Status); ok { return nil, err } if err == errConnClosing || err == errConnUnavailable { @@ -268,11 +269,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth case <-s.Done(): // TODO: The trace of the RPC is terminated here when there is no pending // I/O, which is probably not the optimal solution. - if s.StatusCode() == codes.OK { - cs.finish(nil) - } else { - cs.finish(Errorf(s.StatusCode(), "%s", s.StatusDesc())) - } + cs.finish(s.Status().Err()) cs.closeTransportStream(nil) case <-s.GoAway(): cs.finish(errConnDrain) @@ -445,11 +442,11 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) } if err == io.EOF { - if cs.s.StatusCode() == codes.OK { - cs.finish(err) - return nil + if se := cs.s.Status().Err(); se != nil { + return se } - return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc()) + cs.finish(err) + return nil } return toRPCErr(err) } @@ -457,11 +454,11 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { cs.closeTransportStream(err) } if err == io.EOF { - if cs.s.StatusCode() == codes.OK { - // Returns io.EOF to indicate the end of the stream. - return + if statusErr := cs.s.Status().Err(); statusErr != nil { + return statusErr } - return Errorf(cs.s.StatusCode(), "%s", cs.s.StatusDesc()) + // Returns io.EOF to indicate the end of the stream. + return } return toRPCErr(err) } @@ -545,18 +542,16 @@ type ServerStream interface { // serverStream implements a server side Stream. type serverStream struct { - t transport.ServerTransport - s *transport.Stream - p *parser - codec Codec - cp Compressor - dc Decompressor - cbuf *bytes.Buffer + t transport.ServerTransport + s *transport.Stream + p *parser + codec Codec + cp Compressor + dc Decompressor + cbuf *bytes.Buffer maxReceiveMessageSize int maxSendMessageSize int - statusCode codes.Code - statusDesc string - trInfo *traceInfo + trInfo *traceInfo statsHandler stats.Handler diff --git a/test/end2end_test.go b/test/end2end_test.go index b857bb62..32f20c76 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -54,6 +54,8 @@ import ( "time" "github.com/golang/protobuf/proto" + anypb "github.com/golang/protobuf/ptypes/any" + spb "github.com/google/go-genproto/googleapis/rpc/status" "golang.org/x/net/context" "golang.org/x/net/http2" "google.golang.org/grpc" @@ -65,6 +67,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" "google.golang.org/grpc/tap" testpb "google.golang.org/grpc/test/grpc_testing" ) @@ -92,8 +95,16 @@ var ( malformedHTTP2Metadata = metadata.MD{ "Key": []string{"foo"}, } - testAppUA = "myApp1/1.0 myApp2/0.9" - failAppUA = "fail-this-RPC" + testAppUA = "myApp1/1.0 myApp2/0.9" + failAppUA = "fail-this-RPC" + detailedError = status.ErrorProto(&spb.Status{ + Code: int32(codes.DataLoss), + Message: "error for testing: " + failAppUA, + Details: []*anypb.Any{{ + TypeUrl: "url", + Value: []byte{6, 0, 0, 6, 1, 3}, + }}, + }) ) var raceMode bool // set by race_test.go in race mode @@ -111,7 +122,7 @@ func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.E // For testing purpose, returns an error if user-agent is failAppUA. // To test that client gets the correct error. if ua, ok := md["user-agent"]; !ok || strings.HasPrefix(ua[0], failAppUA) { - return nil, grpc.Errorf(codes.DataLoss, "error for testing: "+failAppUA) + return nil, detailedError } var str []string for _, entry := range md["user-agent"] { @@ -1815,7 +1826,7 @@ func testHealthCheckOnFailure(t *testing.T, e env) { cc := te.clientConn() wantErr := grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") - if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) { + if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !reflect.DeepEqual(err, wantErr) { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %s", err, codes.DeadlineExceeded) } awaitNewConnLogOutput() @@ -1837,7 +1848,7 @@ func testHealthCheckOff(t *testing.T, e env) { te.startServer(&testServer{security: e.security}) defer te.tearDown() want := grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1.Health") - if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) { + if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !reflect.DeepEqual(err, want) { t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want) } } @@ -1864,7 +1875,7 @@ func testUnknownHandler(t *testing.T, e env, unknownHandler grpc.StreamHandler) te.startServer(&testServer{security: e.security}) defer te.tearDown() want := grpc.Errorf(codes.Unauthenticated, "user unauthenticated") - if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !equalErrors(err, want) { + if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !reflect.DeepEqual(err, want) { t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want) } } @@ -1892,7 +1903,7 @@ func testHealthCheckServingStatus(t *testing.T, e env) { t.Fatalf("Got the serving status %v, want SERVING", out.Status) } wantErr := grpc.Errorf(codes.NotFound, "unknown service") - if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !equalErrors(err, wantErr) { + if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !reflect.DeepEqual(err, wantErr) { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %s", err, codes.NotFound) } hs.SetServingStatus("grpc.health.v1.Health", healthpb.HealthCheckResponse_SERVING) @@ -1974,8 +1985,8 @@ func testFailedEmptyUnary(t *testing.T, e env) { tc := testpb.NewTestServiceClient(te.clientConn()) ctx := metadata.NewContext(context.Background(), testMetadata) - wantErr := grpc.Errorf(codes.DataLoss, "error for testing: "+failAppUA) - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !equalErrors(err, wantErr) { + wantErr := detailedError + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(err, wantErr) { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) } } @@ -2141,6 +2152,29 @@ func testPeerClientSide(t *testing.T, e env) { } } +// TestPeerNegative tests that if call fails setting peer +// doesn't cause a segmentation fault. +// issue#1141 https://github.com/grpc/grpc-go/issues/1141 +func TestPeerNegative(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testPeerNegative(t, e) + } +} + +func testPeerNegative(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + peer := new(peer.Peer) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + tc.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(peer)) +} + func TestMetadataUnaryRPC(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { @@ -3055,7 +3089,7 @@ func testFailedServerStreaming(t *testing.T, e env) { t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want ", tc, err) } wantErr := grpc.Errorf(codes.DataLoss, "error for testing: "+failAppUA) - if _, err := stream.Recv(); !equalErrors(err, wantErr) { + if _, err := stream.Recv(); !reflect.DeepEqual(err, wantErr) { t.Fatalf("%v.Recv() = _, %v, want _, %v", stream, err, wantErr) } } @@ -4245,7 +4279,3 @@ func (fw *filterWriter) Write(p []byte) (n int, err error) { } return fw.dst.Write(p) } - -func equalErrors(l, r error) bool { - return grpc.Code(l) == grpc.Code(r) && grpc.ErrorDesc(l) == grpc.ErrorDesc(r) -} diff --git a/transport/handler_server.go b/transport/handler_server.go index 10b6dc0b..5bf63630 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -53,6 +53,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" ) // NewServerHandlerTransport returns a ServerTransport handling gRPC @@ -182,7 +183,7 @@ func (ht *serverHandlerTransport) do(fn func()) error { } } -func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error { +func (ht *serverHandlerTransport) WriteStatus(s *Stream, st status.Status) error { err := ht.do(func() { ht.writeCommonHeaders(s) @@ -192,10 +193,13 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, ht.rw.(http.Flusher).Flush() h := ht.rw.Header() - h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) - if statusDesc != "" { - h.Set("Grpc-Message", encodeGrpcMessage(statusDesc)) + h.Set("Grpc-Status", fmt.Sprintf("%d", st.Code())) + if m := st.Message(); m != "" { + h.Set("Grpc-Message", encodeGrpcMessage(m)) } + + // TODO: Support Grpc-Status-Details-Bin + if md := s.Trailer(); len(md) > 0 { for k, vv := range md { // Clients don't tolerate reading restricted headers after some non restricted ones were sent. @@ -234,6 +238,7 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { // and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers h.Add("Trailer", "Grpc-Status") h.Add("Trailer", "Grpc-Message") + // TODO: Support Grpc-Status-Details-Bin if s.sendCompress != "" { h.Set("Grpc-Encoding", s.sendCompress) diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go index 44adf2ee..84378485 100644 --- a/transport/handler_server_test.go +++ b/transport/handler_server_test.go @@ -46,6 +46,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { @@ -298,7 +299,7 @@ func TestHandlerTransport_HandleStreams(t *testing.T) { t.Errorf("stream method = %q; want %q", s.method, want) } st.bodyw.Close() // no body - st.ht.WriteStatus(s, codes.OK, "") + st.ht.WriteStatus(s, status.New(codes.OK, "")) } st.ht.HandleStreams( func(s *Stream) { go handleStream(s) }, @@ -328,7 +329,7 @@ func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) { func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) { st := newHandleStreamTest(t) handleStream := func(s *Stream) { - st.ht.WriteStatus(s, statusCode, msg) + st.ht.WriteStatus(s, status.New(statusCode, msg)) } st.ht.HandleStreams( func(s *Stream) { go handleStream(s) }, @@ -379,7 +380,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded) return } - ht.WriteStatus(s, codes.DeadlineExceeded, "too slow") + ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow")) } ht.HandleStreams( func(s *Stream) { go runStream(s) }, diff --git a/transport/http2_client.go b/transport/http2_client.go index d6e2998b..7d726989 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -35,7 +35,6 @@ package transport import ( "bytes" - "fmt" "io" "math" "net" @@ -54,6 +53,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" ) // http2Client implements the ClientTransport interface with HTTP2. @@ -311,7 +311,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { return s } -// NewStream creates a stream and register it into the transport as "active" +// NewStream creates a stream and registers it into the transport as "active" // streams. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { pr := &peer.Peer{ @@ -802,12 +802,9 @@ func (t *http2Client) handleData(f *http2.DataFrame) { return } if err := s.fc.onData(uint32(size)); err != nil { - s.state = streamDone - s.statusCode = codes.Internal - s.statusDesc = err.Error() s.rstStream = true s.rstError = http2.ErrCodeFlowControl - close(s.done) + s.finish(status.New(codes.Internal, err.Error())) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) return @@ -835,10 +832,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { s.mu.Unlock() return } - s.state = streamDone - s.statusCode = codes.Internal - s.statusDesc = "server closed the stream without sending trailers" - close(s.done) + s.finish(status.New(codes.Internal, "server closed the stream without sending trailers")) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) } @@ -854,18 +848,16 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { s.mu.Unlock() return } - s.state = streamDone if !s.headerDone { close(s.headerChan) s.headerDone = true } - s.statusCode, ok = http2ErrConvTab[http2.ErrCode(f.ErrCode)] + statusCode, ok := http2ErrConvTab[http2.ErrCode(f.ErrCode)] if !ok { grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode) - s.statusCode = codes.Unknown + statusCode = codes.Unknown } - s.statusDesc = fmt.Sprintf("stream terminated by RST_STREAM with error code: %d", f.ErrCode) - close(s.done) + s.finish(status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %d", f.ErrCode)) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) } @@ -944,18 +936,17 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } var state decodeState for _, hf := range frame.Fields { - state.processHeaderField(hf) - } - if state.err != nil { - s.mu.Lock() - if !s.headerDone { - close(s.headerChan) - s.headerDone = true + if err := state.processHeaderField(hf); err != nil { + s.mu.Lock() + if !s.headerDone { + close(s.headerChan) + s.headerDone = true + } + s.mu.Unlock() + s.write(recvMsg{err: err}) + // Something wrong. Stops reading even when there is remaining. + return } - s.mu.Unlock() - s.write(recvMsg{err: state.err}) - // Something wrong. Stops reading even when there is remaining. - return } endStream := frame.StreamEnded() @@ -998,10 +989,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { if len(state.mdata) > 0 { s.trailer = state.mdata } - s.statusCode = state.statusCode - s.statusDesc = state.statusDesc - close(s.done) - s.state = streamDone + s.finish(state.status()) s.mu.Unlock() s.write(recvMsg{err: io.EOF}) } diff --git a/transport/http2_server.go b/transport/http2_server.go index f3bc569d..9972a839 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -45,6 +45,7 @@ import ( "sync/atomic" "time" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" @@ -55,6 +56,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" "google.golang.org/grpc/tap" ) @@ -227,13 +229,12 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( var state decodeState for _, hf := range frame.Fields { - state.processHeaderField(hf) - } - if err := state.err; err != nil { - if se, ok := err.(StreamError); ok { - t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]}) + if err := state.processHeaderField(hf); err != nil { + if se, ok := err.(StreamError); ok { + t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]}) + } + return } - return } if frame.StreamEnded() { @@ -670,7 +671,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { // There is no further I/O operations being able to perform on this stream. // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early // OK is adopted. -func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error { +func (t *http2Server) WriteStatus(s *Stream, st status.Status) error { var headersSent, hasHeader bool s.mu.Lock() if s.state == streamDone { @@ -701,9 +702,24 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s t.hEnc.WriteField( hpack.HeaderField{ Name: "grpc-status", - Value: strconv.Itoa(int(statusCode)), + Value: strconv.Itoa(int(st.Code())), }) - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)}) + t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) + + if p := st.Proto(); p != nil && len(p.Details) > 0 { + stBytes, err := proto.Marshal(p) + if err != nil { + // TODO: return error instead, when callers are able to handle it. + panic(err) + } + + for k, v := range metadata.New(map[string]string{"grpc-status-details-bin": (string)(stBytes)}) { + for _, entry := range v { + t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) + } + } + } + // Attach the trailer metadata. for k, v := range s.trailer { // Clients don't tolerate reading restricted headers after some non restricted ones were sent. diff --git a/transport/http_util.go b/transport/http_util.go index 6b968848..57aad62d 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -44,11 +44,14 @@ import ( "sync/atomic" "time" + "github.com/golang/protobuf/proto" + spb "github.com/google/go-genproto/googleapis/rpc/status" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) const ( @@ -90,13 +93,15 @@ var ( // Records the states during HPACK decoding. Must be reset once the // decoding of the entire headers are finished. type decodeState struct { - err error // first error encountered decoding - encoding string - // statusCode caches the stream status received from the trailer - // the server sent. Client side only. - statusCode codes.Code - statusDesc string + // statusGen caches the stream status received from the trailer the server + // sent. Client side only. Do not access directly. After all trailers are + // parsed, use the status method to retrieve the status. + statusGen status.Status + // rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not + // intended for direct access outside of parsing. + rawStatusCode int32 + rawStatusMsg string // Server side only fields. timeoutSet bool timeout time.Duration @@ -119,6 +124,7 @@ func isReservedHeader(hdr string) bool { "grpc-message", "grpc-status", "grpc-timeout", + "grpc-status-details-bin", "te": return true default: @@ -137,12 +143,6 @@ func isWhitelistedPseudoHeader(hdr string) bool { } } -func (d *decodeState) setErr(err error) { - if d.err == nil { - d.err = err - } -} - func validContentType(t string) bool { e := "application/grpc" if !strings.HasPrefix(t, e) { @@ -156,31 +156,45 @@ func validContentType(t string) bool { return true } -func (d *decodeState) processHeaderField(f hpack.HeaderField) { +func (d *decodeState) status() status.Status { + if d.statusGen == nil { + // No status-details were provided; generate status using code/msg. + d.statusGen = status.New(codes.Code(d.rawStatusCode), d.rawStatusMsg) + } + return d.statusGen +} + +func (d *decodeState) processHeaderField(f hpack.HeaderField) error { switch f.Name { case "content-type": if !validContentType(f.Value) { - d.setErr(streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value)) - return + return streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value) } case "grpc-encoding": d.encoding = f.Value case "grpc-status": code, err := strconv.Atoi(f.Value) if err != nil { - d.setErr(streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err)) - return + return streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err) } - d.statusCode = codes.Code(code) + d.rawStatusCode = int32(code) case "grpc-message": - d.statusDesc = decodeGrpcMessage(f.Value) + d.rawStatusMsg = decodeGrpcMessage(f.Value) + case "grpc-status-details-bin": + _, v, err := metadata.DecodeKeyValue("grpc-status-details-bin", f.Value) + if err != nil { + return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + } + s := &spb.Status{} + if err := proto.Unmarshal([]byte(v), s); err != nil { + return streamErrorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) + } + d.statusGen = status.FromProto(s) case "grpc-timeout": d.timeoutSet = true var err error - d.timeout, err = decodeTimeout(f.Value) - if err != nil { - d.setErr(streamErrorf(codes.Internal, "transport: malformed time-out: %v", err)) - return + if d.timeout, err = decodeTimeout(f.Value); err != nil { + return streamErrorf(codes.Internal, "transport: malformed time-out: %v", err) } case ":path": d.method = f.Value @@ -192,11 +206,12 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { k, v, err := metadata.DecodeKeyValue(f.Name, f.Value) if err != nil { grpclog.Printf("Failed to decode (%q, %q): %v", f.Name, f.Value, err) - return + return nil } d.mdata[k] = append(d.mdata[k], v) } } + return nil } type timeoutUnit uint8 diff --git a/transport/transport.go b/transport/transport.go index 51716803..3b8bd01c 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -51,6 +51,7 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" + "google.golang.org/grpc/status" "google.golang.org/grpc/tap" ) @@ -212,9 +213,8 @@ type Stream struct { // true iff headerChan is closed. Used to avoid closing headerChan // multiple times. headerDone bool - // the status received from the server. - statusCode codes.Code - statusDesc string + // the status error received from the server. + status status.Status // rstStream indicates whether a RST_STREAM frame needs to be sent // to the server to signify that this stream is closing. rstStream bool @@ -284,14 +284,9 @@ func (s *Stream) Method() string { return s.method } -// StatusCode returns statusCode received from the server. -func (s *Stream) StatusCode() codes.Code { - return s.statusCode -} - -// StatusDesc returns statusDesc received from the server. -func (s *Stream) StatusDesc() string { - return s.statusDesc +// Status returns the status received from the server. +func (s *Stream) Status() status.Status { + return s.status } // SetHeader sets the header metadata. This can be called multiple times. @@ -338,6 +333,14 @@ func (s *Stream) Read(p []byte) (n int, err error) { return } +// finish sets the stream's state and status, and closes the done channel. +// s.mu must be held by the caller. +func (s *Stream) finish(st status.Status) { + s.status = st + s.state = streamDone + close(s.done) +} + // The key to save transport.Stream in the context. type streamKey struct{} @@ -503,10 +506,9 @@ type ServerTransport interface { // Write may not be called on all streams. Write(s *Stream, data []byte, opts *Options) error - // WriteStatus sends the status of a stream to the client. - // WriteStatus is the final call made on a stream and always - // occurs. - WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error + // WriteStatus sends the status of a stream to the client. WriteStatus is + // the final call made on a stream and always occurs. + WriteStatus(s *Stream, st status.Status) error // Close tears down the transport. Once it is called, the transport // should not be accessed any more. All the pending streams and their @@ -572,6 +574,8 @@ var ( ErrStreamDrain = streamErrorf(codes.Unavailable, "the server stops accepting new RPCs") ) +// TODO: See if we can replace StreamError with status package errors. + // StreamError is an error that only affects one stream within a connection. type StreamError struct { Code codes.Code diff --git a/transport/transport_test.go b/transport/transport_test.go index 3108b98c..4e986e56 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -39,6 +39,7 @@ import ( "io" "math" "net" + "reflect" "strconv" "strings" "sync" @@ -50,6 +51,7 @@ import ( "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/status" ) type server struct { @@ -100,7 +102,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { // send a response back to the client. h.t.Write(s, resp, &Options{}) // send the trailer to end the stream. - h.t.WriteStatus(s, codes.OK, "") + h.t.WriteStatus(s, status.New(codes.OK, "")) } // handleStreamSuspension blocks until s.ctx is canceled. @@ -142,7 +144,7 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) { // raw newline is not accepted by http2 framer so it must be encoded. - h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc) + h.t.WriteStatus(s, encodingTestStatus) } func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) { @@ -1070,8 +1072,11 @@ func TestServerWithMisbehavedClient(t *testing.T) { } // Server sent a resetStream for s already. code := http2ErrConvTab[http2.ErrCodeFlowControl] - if _, err := io.ReadFull(s, make([]byte, 1)); err != io.EOF || s.statusCode != code { - t.Fatalf("%v got err %v with statusCode %d, want err with statusCode %d", s, err, s.statusCode, code) + if _, err := io.ReadFull(s, make([]byte, 1)); err != io.EOF { + t.Fatalf("%v got err %v want ", s, err) + } + if s.status.Code() != code { + t.Fatalf("%v got status %v; want Code=%v", s, s.status, code) } if ss.fc.pendingData != 0 || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != 0 || sc.fc.pendingUpdate <= initialWindowSize { @@ -1125,9 +1130,14 @@ func TestClientWithMisbehavedServer(t *testing.T) { if s.fc.pendingData <= initialWindowSize || s.fc.pendingUpdate != 0 || conn.fc.pendingData <= initialWindowSize || conn.fc.pendingUpdate != 0 { t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want >%d, %d, >%d, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize, 0, initialWindowSize, 0) } - if err != io.EOF || s.statusCode != codes.Internal { - t.Fatalf("Got err %v and the status code %d, want and the code %d", err, s.statusCode, codes.Internal) + + if err != io.EOF { + t.Fatalf("Got err %v, want ", err) } + if s.status.Code() != codes.Internal { + t.Fatalf("Got s.status %v, want s.status.Code()=Internal", s.status) + } + conn.CloseStream(s, err) if s.fc.pendingData != 0 || s.fc.pendingUpdate != 0 || conn.fc.pendingData != 0 || conn.fc.pendingUpdate <= initialWindowSize { t.Fatalf("Client mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, >%d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize) @@ -1152,10 +1162,7 @@ func TestClientWithMisbehavedServer(t *testing.T) { server.stop() } -var ( - encodingTestStatusCode = codes.Internal - encodingTestStatusDesc = "\n" -) +var encodingTestStatus = status.New(codes.Internal, "\n") func TestEncodingRequiredStatus(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus) @@ -1178,8 +1185,8 @@ func TestEncodingRequiredStatus(t *testing.T) { if _, err := s.dec.Read(p); err != io.EOF { t.Fatalf("Read got error %v, want %v", err, io.EOF) } - if s.StatusCode() != encodingTestStatusCode || s.StatusDesc() != encodingTestStatusDesc { - t.Fatalf("stream with status code %d, status desc %v, want %d, %v", s.StatusCode(), s.StatusDesc(), encodingTestStatusCode, encodingTestStatusDesc) + if !reflect.DeepEqual(s.Status(), encodingTestStatus) { + t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus) } ct.Close() server.stop() @@ -1242,3 +1249,20 @@ func TestIsReservedHeader(t *testing.T) { } } } + +func TestContextErr(t *testing.T) { + for _, test := range []struct { + // input + errIn error + // outputs + errOut StreamError + }{ + {context.DeadlineExceeded, StreamError{codes.DeadlineExceeded, context.DeadlineExceeded.Error()}}, + {context.Canceled, StreamError{codes.Canceled, context.Canceled.Error()}}, + } { + err := ContextErr(test.errIn) + if err != test.errOut { + t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) + } + } +}