diff --git a/call.go b/call.go index 772c817e..5d9214d1 100644 --- a/call.go +++ b/call.go @@ -42,6 +42,7 @@ import ( "golang.org/x/net/context" "golang.org/x/net/trace" "google.golang.org/grpc/codes" + "google.golang.org/grpc/stats" "google.golang.org/grpc/transport" ) @@ -49,7 +50,8 @@ import ( // On error, it returns the error and indicates whether the call should be retried. // // TODO(zhaoq): Check whether the received message sequence is valid. -func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) { +// TODO ctx is used for stats collection and processing. It is the context passed from the application. +func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) { // Try to acquire header metadata from the server if there is any. defer func() { if err != nil { @@ -63,14 +65,25 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s return } p := &parser{r: stream} + var inPayload *stats.InPayload + if stats.On() { + inPayload = &stats.InPayload{ + Client: true, + } + } for { - if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil { + if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, inPayload); err != nil { if err == io.EOF { break } return } } + if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK { + // TODO in the current implementation, inTrailer may be handled before inPayload in some cases. + // Fix the order if necessary. + stats.Handle(ctx, inPayload) + } c.trailerMD = stream.Trailer() return nil } @@ -89,15 +102,27 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd } } }() - var cbuf *bytes.Buffer + var ( + cbuf *bytes.Buffer + outPayload *stats.OutPayload + ) if compressor != nil { cbuf = new(bytes.Buffer) } - outBuf, err := encode(codec, args, compressor, cbuf) + if stats.On() { + outPayload = &stats.OutPayload{ + Client: true, + } + } + outBuf, err := encode(codec, args, compressor, cbuf, outPayload) if err != nil { return nil, Errorf(codes.Internal, "grpc: %v", err) } err = t.Write(stream, outBuf, opts) + if err == nil && outPayload != nil { + outPayload.SentTime = time.Now() + stats.Handle(ctx, outPayload) + } // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method // does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following // recvResponse to get the final status. @@ -118,7 +143,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli return invoke(ctx, method, args, reply, cc, opts...) } -func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { +func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) { c := defaultCallInfo for _, o := range opts { if err := o.before(&c); err != nil { @@ -140,12 +165,30 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli c.traceInfo.tr.LazyLog(&c.traceInfo.firstLine, false) // TODO(dsymonds): Arrange for c.traceInfo.firstLine.remoteAddr to be set. 
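The sendRequest and recvResponse hunks above establish the pattern this change applies on both sides of the wire: allocate the stats struct only when stats.On() reports that collection is enabled, let the (de)serialization path populate it, and report it via stats.Handle only once the operation has succeeded. A minimal sketch of that pattern in isolation, using the stats API introduced by this PR; instrumentedSend and writeToWire are hypothetical names, with writeToWire standing in for the transport write:

    package example

    import (
        "time"

        "golang.org/x/net/context"
        "google.golang.org/grpc/stats"
    )

    // instrumentedSend mirrors sendRequest above: the OutPayload is
    // allocated only when collection is on, and reported only after a
    // successful write, stamped with the send time.
    func instrumentedSend(ctx context.Context, data []byte, writeToWire func([]byte) error) error {
        var outPayload *stats.OutPayload
        if stats.On() {
            outPayload = &stats.OutPayload{Client: true, Data: data, Length: len(data)}
        }
        err := writeToWire(data)
        if err == nil && outPayload != nil {
            outPayload.SentTime = time.Now()
            stats.Handle(ctx, outPayload)
        }
        return err
    }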
defer func() { - if err != nil { - c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + if e != nil { + c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{e}}, true) c.traceInfo.tr.SetError() } }() } + if stats.On() { + begin := &stats.Begin{ + Client: true, + BeginTime: time.Now(), + FailFast: c.failFast, + } + stats.Handle(ctx, begin) + } + defer func() { + if stats.On() { + end := &stats.End{ + Client: true, + EndTime: time.Now(), + Error: e, + } + stats.Handle(ctx, end) + } + }() topts := &transport.Options{ Last: true, Delay: false, @@ -205,7 +248,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } - err = recvResponse(cc.dopts, t, &c, stream, reply) + err = recvResponse(ctx, cc.dopts, t, &c, stream, reply) if err != nil { if put != nil { put() diff --git a/call_test.go b/call_test.go index 589e63ea..3c2165ea 100644 --- a/call_test.go +++ b/call_test.go @@ -118,7 +118,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { } } // send a response back to end the stream. - reply, err := encode(testCodec{}, &expectedResponse, nil, nil) + reply, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) if err != nil { t.Errorf("Failed to encode the response: %v", err) return @@ -185,6 +185,8 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) { } go st.HandleStreams(func(s *transport.Stream) { go h.handleStream(t, s) + }, func(ctx context.Context, method string) context.Context { + return ctx }) } } diff --git a/rpc_util.go b/rpc_util.go index a25eaa8a..66d08b5a 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -42,11 +42,13 @@ import ( "io/ioutil" "math" "os" + "time" "github.com/golang/protobuf/proto" "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" "google.golang.org/grpc/transport" ) @@ -255,9 +257,11 @@ func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err erro // encode serializes msg and prepends the message header. If msg is nil, it // generates the message header of 0 message length. -func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer) ([]byte, error) { - var b []byte - var length uint +func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, error) { + var ( + b []byte + length uint + ) if msg != nil { var err error // TODO(zhaoq): optimize to reduce memory alloc and copying. @@ -265,6 +269,12 @@ func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer) ([]byte if err != nil { return nil, err } + if outPayload != nil { + outPayload.Payload = msg + // TODO truncate large payload. 
+ outPayload.Data = b + outPayload.Length = len(b) + } if cp != nil { if err := cp.Do(cbuf, b); err != nil { return nil, err @@ -295,6 +305,10 @@ func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer) ([]byte // Copy encoded msg to buf copy(buf[5:], b) + if outPayload != nil { + outPayload.WireLength = len(buf) + } + return buf, nil } @@ -311,11 +325,14 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er return nil } -func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error { +func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int, inPayload *stats.InPayload) error { pf, d, err := p.recvMsg(maxMsgSize) if err != nil { return err } + if inPayload != nil { + inPayload.WireLength = len(d) + } if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil { return err } @@ -333,6 +350,13 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{ if err := c.Unmarshal(d, m); err != nil { return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err) } + if inPayload != nil { + inPayload.RecvTime = time.Now() + inPayload.Payload = m + // TODO truncate large payload. + inPayload.Data = d + inPayload.Length = len(d) + } return nil } diff --git a/rpc_util_test.go b/rpc_util_test.go index 0ba2d44c..375e42bc 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -114,7 +114,7 @@ func TestEncode(t *testing.T) { }{ {nil, nil, []byte{0, 0, 0, 0, 0}, nil}, } { - b, err := encode(protoCodec{}, test.msg, nil, nil) + b, err := encode(protoCodec{}, test.msg, nil, nil, nil) if err != test.err || !bytes.Equal(b, test.b) { t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err) } @@ -199,12 +199,12 @@ func TestErrorsWithSameParameters(t *testing.T) { // bytes. func bmEncode(b *testing.B, mSize int) { msg := &perfpb.Buffer{Body: make([]byte, mSize)} - encoded, _ := encode(protoCodec{}, msg, nil, nil) + encoded, _ := encode(protoCodec{}, msg, nil, nil, nil) encodedSz := int64(len(encoded)) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - encode(protoCodec{}, msg, nil, nil) + encode(protoCodec{}, msg, nil, nil, nil) } b.SetBytes(encodedSz) } diff --git a/server.go b/server.go index 142a9a5b..3af001ac 100644 --- a/server.go +++ b/server.go @@ -54,6 +54,7 @@ import ( "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" "google.golang.org/grpc/tap" "google.golang.org/grpc/transport" ) @@ -466,6 +467,12 @@ func (s *Server) serveStreams(st transport.ServerTransport) { defer wg.Done() s.handleStream(st, stream, s.traceInfo(st, stream)) }() + }, func(ctx context.Context, method string) context.Context { + if !EnableTracing { + return ctx + } + tr := trace.New("grpc.Recv."+methodFamily(method), method) + return trace.NewContext(ctx, tr) }) wg.Wait() } @@ -515,15 +522,17 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // traceInfo returns a traceInfo and associates it with stream, if tracing is enabled. // If tracing is not enabled, it returns nil. 
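The serveStreams hunk above moves trace creation into the new per-stream context callback passed to HandleStreams, and the traceInfo rewrite just below recovers that trace from the stream's context instead of constructing one itself. The round trip uses the golang.org/x/net/trace context helpers; a minimal sketch, with placeholder family and method strings:

    package example

    import (
        "golang.org/x/net/context"
        "golang.org/x/net/trace"
    )

    func traceRoundTrip() {
        // What the HandleStreams callback now does: create the trace and
        // attach it to the context the stream will carry.
        tr := trace.New("grpc.Recv.TestService", "/grpc.testing.TestService/UnaryCall")
        ctx := trace.NewContext(context.Background(), tr)

        // What traceInfo now does: recover the same trace later; a missing
        // trace means tracing is disabled, so traceInfo returns nil.
        if tr2, ok := trace.FromContext(ctx); ok {
            tr2.LazyPrintf("handling RPC")
            tr2.Finish()
        }
    }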
func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) { - if !EnableTracing { + tr, ok := trace.FromContext(stream.Context()) + if !ok { return nil } + trInfo = &traceInfo{ - tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()), + tr: tr, } trInfo.firstLine.client = false trInfo.firstLine.remoteAddr = st.RemoteAddr() - stream.TraceContext(trInfo.tr) + if dl, ok := stream.Context().Deadline(); ok { trInfo.firstLine.deadline = dl.Sub(time.Now()) } @@ -550,11 +559,17 @@ func (s *Server) removeConn(c io.Closer) { } func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options) error { - var cbuf *bytes.Buffer + var ( + cbuf *bytes.Buffer + outPayload *stats.OutPayload + ) if cp != nil { cbuf = new(bytes.Buffer) } - p, err := encode(s.opts.codec, msg, cp, cbuf) + if stats.On() { + outPayload = &stats.OutPayload{} + } + p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) if err != nil { // This typically indicates a fatal issue (e.g., memory // corruption or hardware faults) the application program @@ -565,10 +580,32 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str // the optimal option. grpclog.Fatalf("grpc: Server failed to encode response %v", err) } - return t.Write(stream, p, opts) + err = t.Write(stream, p, opts) + if err == nil && outPayload != nil { + outPayload.SentTime = time.Now() + stats.Handle(stream.Context(), outPayload) + } + return err } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { + if stats.On() { + begin := &stats.Begin{ + BeginTime: time.Now(), + } + stats.Handle(stream.Context(), begin) + } + defer func() { + if stats.On() { + end := &stats.End{ + EndTime: time.Now(), + } + if err != nil && err != io.EOF { + end.Error = toRPCErr(err) + } + stats.Handle(stream.Context(), end) + } + }() if trInfo != nil { defer trInfo.tr.Finish() trInfo.firstLine.client = false @@ -597,14 +634,14 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if err != nil { switch err := err.(type) { case *rpcError: - if err := t.WriteStatus(stream, err.code, err.desc); err != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) + if e := t.WriteStatus(stream, err.code, err.desc); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } case transport.ConnectionError: // Nothing to do here. case transport.StreamError: - if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) + if e := t.WriteStatus(stream, err.Code, err.Desc); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } default: panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err)) @@ -615,20 +652,29 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. 
if err := checkRecvPayload(pf, stream.RecvCompress(), s.opts.dc); err != nil { switch err := err.(type) { case *rpcError: - if err := t.WriteStatus(stream, err.code, err.desc); err != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) + if e := t.WriteStatus(stream, err.code, err.desc); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } + return err default: - if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil { - grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) + if e := t.WriteStatus(stream, codes.Internal, err.Error()); e != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e) } - + // TODO checkRecvPayload always returns an RPC error. Add a return here if necessary. + } + } + var inPayload *stats.InPayload + if stats.On() { + inPayload = &stats.InPayload{ + RecvTime: time.Now(), } - return err } statusCode := codes.OK statusDesc := "" df := func(v interface{}) error { + if inPayload != nil { + inPayload.WireLength = len(req) + } if pf == compressionMade { var err error req, err = s.opts.dc.Do(bytes.NewReader(req)) @@ -636,7 +682,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if err := t.WriteStatus(stream, codes.Internal, err.Error()); err != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err) } - return err + return Errorf(codes.Internal, err.Error()) } } if len(req) > s.opts.maxMsgSize { @@ -648,6 +694,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if err := s.opts.codec.Unmarshal(req, v); err != nil { return err } + if inPayload != nil { + inPayload.Payload = v + inPayload.Data = req + inPayload.Length = len(req) + stats.Handle(stream.Context(), inPayload) + } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) } @@ -668,9 +720,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) - return err } - return nil + return Errorf(statusCode, statusDesc) } if trInfo != nil { trInfo.tr.LazyLog(stringer("OK"), false) } @@ -695,11 +746,32 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) } - return t.WriteStatus(stream, statusCode, statusDesc) + errWrite := t.WriteStatus(stream, statusCode, statusDesc) + if statusCode != codes.OK { + return Errorf(statusCode, statusDesc) + } + return errWrite } } func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { + if stats.On() { + begin := &stats.Begin{ + BeginTime: time.Now(), + } + stats.Handle(stream.Context(), begin) + } + defer func() { + if stats.On() { + end := &stats.End{ + EndTime: time.Now(), + } + if err != nil && err != io.EOF { + end.Error = toRPCErr(err) + } + stats.Handle(stream.Context(), end) + } + }() if s.opts.cp != nil { stream.SetSendCompress(s.opts.cp.Type()) } @@ -762,7 +834,11 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp } ss.mu.Unlock() } - return t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) + errWrite := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) + if ss.statusCode != codes.OK { + return Errorf(ss.statusCode, ss.statusDesc) + } + return errWrite } @@ -777,7 +853,8 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true) trInfo.tr.SetError() } - if err := t.WriteStatus(stream, codes.InvalidArgument, fmt.Sprintf("malformed method name: %q", stream.Method())); err != nil { + errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) + if err := t.WriteStatus(stream, codes.InvalidArgument, errDesc); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() @@ -797,7 +874,8 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.LazyLog(&fmtStringer{"Unknown service %v", []interface{}{service}}, true) trInfo.tr.SetError() } - if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown service %v", service)); err != nil { + errDesc := fmt.Sprintf("unknown service %v", service) + if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() @@ -822,7 +900,8 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str trInfo.tr.LazyLog(&fmtStringer{"Unknown method %v", []interface{}{method}}, true) trInfo.tr.SetError() } - if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown method %v", method)); err != nil { + errDesc := fmt.Sprintf("unknown method %v", method) + if err := t.WriteStatus(stream, codes.Unimplemented, errDesc); err != nil { if trInfo != nil { trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) trInfo.tr.SetError() diff --git a/stats/grpc_testing/test.pb.go b/stats/grpc_testing/test.pb.go new file mode 100644 index 00000000..b24dcd8d --- /dev/null +++ b/stats/grpc_testing/test.pb.go @@ -0,0 +1,225 @@ +// Code generated by protoc-gen-go. +// source: test.proto +// DO NOT EDIT! + +/* +Package grpc_testing is a generated protocol buffer package. 
+ +It is generated from these files: + test.proto + +It has these top-level messages: + SimpleRequest + SimpleResponse +*/ +package grpc_testing + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +// Unary request. +type SimpleRequest struct { + Id int32 `protobuf:"varint,2,opt,name=id" json:"id,omitempty"` +} + +func (m *SimpleRequest) Reset() { *m = SimpleRequest{} } +func (m *SimpleRequest) String() string { return proto.CompactTextString(m) } +func (*SimpleRequest) ProtoMessage() {} +func (*SimpleRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +// Unary response, as configured by the request. +type SimpleResponse struct { + Id int32 `protobuf:"varint,3,opt,name=id" json:"id,omitempty"` +} + +func (m *SimpleResponse) Reset() { *m = SimpleResponse{} } +func (m *SimpleResponse) String() string { return proto.CompactTextString(m) } +func (*SimpleResponse) ProtoMessage() {} +func (*SimpleResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} } + +func init() { + proto.RegisterType((*SimpleRequest)(nil), "grpc.testing.SimpleRequest") + proto.RegisterType((*SimpleResponse)(nil), "grpc.testing.SimpleResponse") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for TestService service + +type TestServiceClient interface { + // One request followed by one response. + // The server returns the client id as-is. + UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) + // A sequence of requests with each request served by the server immediately. + // As one request could lead to multiple responses, this interface + // demonstrates the idea of full duplexing. + FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) +} + +type testServiceClient struct { + cc *grpc.ClientConn +} + +func NewTestServiceClient(cc *grpc.ClientConn) TestServiceClient { + return &testServiceClient{cc} +} + +func (c *testServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) { + out := new(SimpleResponse) + err := grpc.Invoke(ctx, "/grpc.testing.TestService/UnaryCall", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *testServiceClient) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[0], c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...) 
+ if err != nil { + return nil, err + } + x := &testServiceFullDuplexCallClient{stream} + return x, nil +} + +type TestService_FullDuplexCallClient interface { + Send(*SimpleRequest) error + Recv() (*SimpleResponse, error) + grpc.ClientStream +} + +type testServiceFullDuplexCallClient struct { + grpc.ClientStream +} + +func (x *testServiceFullDuplexCallClient) Send(m *SimpleRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServiceFullDuplexCallClient) Recv() (*SimpleResponse, error) { + m := new(SimpleResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Server API for TestService service + +type TestServiceServer interface { + // One request followed by one response. + // The server returns the client id as-is. + UnaryCall(context.Context, *SimpleRequest) (*SimpleResponse, error) + // A sequence of requests with each request served by the server immediately. + // As one request could lead to multiple responses, this interface + // demonstrates the idea of full duplexing. + FullDuplexCall(TestService_FullDuplexCallServer) error +} + +func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) { + s.RegisterService(&_TestService_serviceDesc, srv) +} + +func _TestService_UnaryCall_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SimpleRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TestServiceServer).UnaryCall(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/grpc.testing.TestService/UnaryCall", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TestServiceServer).UnaryCall(ctx, req.(*SimpleRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _TestService_FullDuplexCall_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).FullDuplexCall(&testServiceFullDuplexCallServer{stream}) +} + +type TestService_FullDuplexCallServer interface { + Send(*SimpleResponse) error + Recv() (*SimpleRequest, error) + grpc.ServerStream +} + +type testServiceFullDuplexCallServer struct { + grpc.ServerStream +} + +func (x *testServiceFullDuplexCallServer) Send(m *SimpleResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServiceFullDuplexCallServer) Recv() (*SimpleRequest, error) { + m := new(SimpleRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _TestService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "grpc.testing.TestService", + HandlerType: (*TestServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "UnaryCall", + Handler: _TestService_UnaryCall_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "FullDuplexCall", + Handler: _TestService_FullDuplexCall_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "test.proto", +} + +func init() { proto.RegisterFile("test.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 167 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e, + 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x49, 0x2f, 0x2a, 0x48, 0xd6, 0x03, 0x09, 0x64, + 0xe6, 0xa5, 0x2b, 0xc9, 0x73, 0xf1, 0x06, 0x67, 0xe6, 0x16, 0xe4, 0xa4, 0x06, 0xa5, 0x16, 0x96, + 0xa6, 0x16, 0x97, 0x08, 0xf1, 0x71, 0x31, 0x65, 0xa6, 
0x48, 0x30, 0x29, 0x30, 0x6a, 0xb0, 0x06, + 0x31, 0x65, 0xa6, 0x28, 0x29, 0x70, 0xf1, 0xc1, 0x14, 0x14, 0x17, 0xe4, 0xe7, 0x15, 0xa7, 0x42, + 0x55, 0x30, 0xc3, 0x54, 0x18, 0x2d, 0x63, 0xe4, 0xe2, 0x0e, 0x49, 0x2d, 0x2e, 0x09, 0x4e, 0x2d, + 0x2a, 0xcb, 0x4c, 0x4e, 0x15, 0x72, 0xe3, 0xe2, 0x0c, 0xcd, 0x4b, 0x2c, 0xaa, 0x74, 0x4e, 0xcc, + 0xc9, 0x11, 0x92, 0xd6, 0x43, 0xb6, 0x4e, 0x0f, 0xc5, 0x2e, 0x29, 0x19, 0xec, 0x92, 0x50, 0x7b, + 0xfc, 0xb9, 0xf8, 0xdc, 0x4a, 0x73, 0x72, 0x5c, 0x4a, 0x0b, 0x72, 0x52, 0x2b, 0x28, 0x34, 0x4c, + 0x83, 0xd1, 0x80, 0x31, 0x89, 0x0d, 0x1c, 0x00, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x8d, + 0x82, 0x5b, 0xdd, 0x0e, 0x01, 0x00, 0x00, +} diff --git a/stats/grpc_testing/test.proto b/stats/grpc_testing/test.proto new file mode 100644 index 00000000..54e6f744 --- /dev/null +++ b/stats/grpc_testing/test.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; + +package grpc.testing; + +message SimpleRequest { + int32 id = 2; +} + +message SimpleResponse { + int32 id = 3; +} + +// A simple test service. +service TestService { + // One request followed by one response. + // The server returns the client id as-is. + rpc UnaryCall(SimpleRequest) returns (SimpleResponse); + + // A sequence of requests with each request served by the server immediately. + // As one request could lead to multiple responses, this interface + // demonstrates the idea of full duplexing. + rpc FullDuplexCall(stream SimpleRequest) returns (stream SimpleResponse); +} diff --git a/stats/stats.go b/stats/stats.go new file mode 100644 index 00000000..4b030d98 --- /dev/null +++ b/stats/stats.go @@ -0,0 +1,219 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// Package stats is for collecting and reporting various network and RPC stats. +// This package is for monitoring purpose only. All fields are read-only. +// All APIs are experimental. 
+package stats // import "google.golang.org/grpc/stats" + +import ( + "net" + "sync/atomic" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/grpclog" +) + +// RPCStats contains stats information about RPCs. +// All stats types in this package implement this interface. +type RPCStats interface { + // IsClient returns true if this RPCStats is from client side. + IsClient() bool +} + +// Begin contains stats when an RPC begins. +// FailFast is only valid if Client is true. +type Begin struct { + // Client is true if this Begin is from client side. + Client bool + // BeginTime is the time when the RPC begins. + BeginTime time.Time + // FailFast indicates if this RPC is failfast. + FailFast bool +} + +// IsClient indicates if this is from client side. +func (s *Begin) IsClient() bool { return s.Client } + +// InPayload contains the information for an incoming payload. +type InPayload struct { + // Client is true if this InPayload is from client side. + Client bool + // Payload is the payload with original type. + Payload interface{} + // Data is the serialized message payload. + Data []byte + // Length is the length of uncompressed data. + Length int + // WireLength is the length of data on wire (compressed, signed, encrypted). + WireLength int + // RecvTime is the time when the payload is received. + RecvTime time.Time +} + +// IsClient indicates if this is from client side. +func (s *InPayload) IsClient() bool { return s.Client } + +// InHeader contains stats when a header is received. +// FullMethod, addresses and Compression are only valid if Client is false. +type InHeader struct { + // Client is true if this InHeader is from client side. + Client bool + // WireLength is the wire length of header. + WireLength int + + // FullMethod is the full RPC method string, i.e., /package.service/method. + FullMethod string + // RemoteAddr is the remote address of the corresponding connection. + RemoteAddr net.Addr + // LocalAddr is the local address of the corresponding connection. + LocalAddr net.Addr + // Compression is the compression algorithm used for the RPC. + Compression string +} + +// IsClient indicates if this is from client side. +func (s *InHeader) IsClient() bool { return s.Client } + +// InTrailer contains stats when a trailer is received. +type InTrailer struct { + // Client is true if this InTrailer is from client side. + Client bool + // WireLength is the wire length of trailer. + WireLength int +} + +// IsClient indicates if this is from client side. +func (s *InTrailer) IsClient() bool { return s.Client } + +// OutPayload contains the information for an outgoing payload. +type OutPayload struct { + // Client is true if this OutPayload is from client side. + Client bool + // Payload is the payload with original type. + Payload interface{} + // Data is the serialized message payload. + Data []byte + // Length is the length of uncompressed data. + Length int + // WireLength is the length of data on wire (compressed, signed, encrypted). + WireLength int + // SentTime is the time when the payload is sent. + SentTime time.Time +} + +// IsClient indicates if this is from client side. +func (s *OutPayload) IsClient() bool { return s.Client } + +// OutHeader contains stats when a header is sent. +// FullMethod, addresses and Compression are only valid if Client is true. +type OutHeader struct { + // Client is true if this OutHeader is from client side. + Client bool + // WireLength is the wire length of header. + WireLength int + + // FullMethod is the full RPC method string, i.e., /package.service/method. + FullMethod string + // RemoteAddr is the remote address of the corresponding connection. + RemoteAddr net.Addr + // LocalAddr is the local address of the corresponding connection. + LocalAddr net.Addr + // Compression is the compression algorithm used for the RPC. + Compression string +} + +// IsClient indicates if this is from client side. +func (s *OutHeader) IsClient() bool { return s.Client } + +// OutTrailer contains stats when a trailer is sent. +type OutTrailer struct { + // Client is true if this OutTrailer is from client side. + Client bool + // WireLength is the wire length of trailer. + WireLength int +} + +// IsClient indicates if this is from client side. +func (s *OutTrailer) IsClient() bool { return s.Client } + +// End contains stats when an RPC ends. +type End struct { + // Client is true if this End is from client side. + Client bool + // EndTime is the time when the RPC ends. + EndTime time.Time + // Error is the error the RPC ended with. Its type is a gRPC error. + Error error +} + +// IsClient indicates if this is from client side. +func (s *End) IsClient() bool { return s.Client } + +var ( + on = new(int32) + handler func(context.Context, RPCStats) +) + +// On indicates whether stats is started. +func On() bool { + return atomic.CompareAndSwapInt32(on, 1, 1) +} + +// Handle processes the stats using the callback function registered by the user. +func Handle(ctx context.Context, s RPCStats) { + handler(ctx, s) +} + +// RegisterHandler registers the user handler function. +// If another handler was registered before, this new handler will overwrite the old one. +// This handler function will be called to process the stats. +func RegisterHandler(f func(context.Context, RPCStats)) { + handler = f +} + +// Start starts the stats collection and reporting if there is a registered stats handler. +func Start() { + if handler == nil { + grpclog.Println("handler is nil when starting stats. Stats is not started") + return + } + atomic.StoreInt32(on, 1) +} + +// Stop stops the stats collection and processing. +// Stop does not unregister the handler. +func Stop() { + atomic.StoreInt32(on, 0) +} diff --git a/stats/stats_test.go b/stats/stats_test.go new file mode 100644 index 00000000..96037c9e --- /dev/null +++ b/stats/stats_test.go @@ -0,0 +1,1192 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package stats_test + +import ( + "fmt" + "io" + "net" + "reflect" + "sync" + "testing" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" + testpb "google.golang.org/grpc/stats/grpc_testing" +) + +func TestStartStop(t *testing.T) { + stats.RegisterHandler(nil) + stats.Start() + if stats.On() != false { + t.Fatalf("stats.Start() with nil handler, stats.On() = true, want false") + } + stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {}) + if stats.On() != false { + t.Fatalf("after stats.RegisterHandler(), stats.On() = true, want false") + } + stats.Start() + if stats.On() != true { + t.Fatalf("after stats.Start(_), stats.On() = false, want true") + } + stats.Stop() + if stats.On() != false { + t.Fatalf("after stats.Stop(), stats.On() = true, want false") + } +} + +var ( + // For headers: + testMetadata = metadata.MD{ + "key1": []string{"value1"}, + "key2": []string{"value2"}, + } + // For trailers: + testTrailerMetadata = metadata.MD{ + "tkey1": []string{"trailerValue1"}, + "tkey2": []string{"trailerValue2"}, + } + // The id for which the service handler should return error. + errorID int32 = 32202 +) + +type testServer struct{} + +func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + md, ok := metadata.FromContext(ctx) + if ok { + if err := grpc.SendHeader(ctx, md); err != nil { + return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err) + } + if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil { + return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err) + } + } + + if in.Id == errorID { + return nil, fmt.Errorf("got error id: %v", in.Id) + } + + return &testpb.SimpleResponse{Id: in.Id}, nil +} + +func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { + md, ok := metadata.FromContext(stream.Context()) + if ok { + if err := stream.SendHeader(md); err != nil { + return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) + } + stream.SetTrailer(testTrailerMetadata) + } + for { + in, err := stream.Recv() + if err == io.EOF { + // read done. + return nil + } + if err != nil { + return err + } + + if in.Id == errorID { + return fmt.Errorf("got error id: %v", in.Id) + } + + if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil { + return err + } + } +} + +// test is an end-to-end test. It should be created with the newTest +// func, modified as needed, and then started with its startServer method. +// It should be cleaned up with the tearDown method.
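The tests below exercise the new package end to end through a real client and server. For reference, the application-side wiring is only a handler registration plus Start; a minimal sketch, using only the APIs defined in stats.go above:

    package example

    import (
        "log"

        "golang.org/x/net/context"
        "google.golang.org/grpc/stats"
    )

    func enableStats() {
        // The handler runs synchronously on the RPC path, so it should be
        // cheap; heavy processing belongs on another goroutine.
        stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
            switch st := s.(type) {
            case *stats.Begin:
                log.Printf("RPC begin, client=%v, failfast=%v", st.IsClient(), st.FailFast)
            case *stats.End:
                log.Printf("RPC end, err=%v", st.Error)
            }
        })
        stats.Start() // logs and stays off if no handler is registered
    }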
+type test struct { + t *testing.T + compress string + + ctx context.Context // valid for life of test, before tearDown + cancel context.CancelFunc + + testServer testpb.TestServiceServer // nil means none + // srv and srvAddr are set once startServer is called. + srv *grpc.Server + srvAddr string + + cc *grpc.ClientConn // nil until requested via clientConn +} + +func (te *test) tearDown() { + if te.cancel != nil { + te.cancel() + te.cancel = nil + } + if te.cc != nil { + te.cc.Close() + te.cc = nil + } + te.srv.Stop() +} + +// newTest returns a new test using the provided testing.T and +// environment. It is returned with default values. Tests should +// modify it before calling its startServer and clientConn methods. +func newTest(t *testing.T, compress string) *test { + te := &test{t: t, compress: compress} + te.ctx, te.cancel = context.WithCancel(context.Background()) + return te +} + +// startServer starts a gRPC server listening. Callers should defer a +// call to te.tearDown to clean up. +func (te *test) startServer(ts testpb.TestServiceServer) { + te.testServer = ts + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + te.t.Fatalf("Failed to listen: %v", err) + } + var opts []grpc.ServerOption + if te.compress == "gzip" { + opts = append(opts, + grpc.RPCCompressor(grpc.NewGZIPCompressor()), + grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), + ) + } + s := grpc.NewServer(opts...) + te.srv = s + if te.testServer != nil { + testpb.RegisterTestServiceServer(s, te.testServer) + } + _, port, err := net.SplitHostPort(lis.Addr().String()) + if err != nil { + te.t.Fatalf("Failed to parse listener address: %v", err) + } + addr := "127.0.0.1:" + port + + go s.Serve(lis) + te.srvAddr = addr +} + +func (te *test) clientConn() *grpc.ClientConn { + if te.cc != nil { + return te.cc + } + opts := []grpc.DialOption{grpc.WithInsecure()} + if te.compress == "gzip" { + opts = append(opts, + grpc.WithCompressor(grpc.NewGZIPCompressor()), + grpc.WithDecompressor(grpc.NewGZIPDecompressor()), + ) + } + + var err error + te.cc, err = grpc.Dial(te.srvAddr, opts...) + if err != nil { + te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err) + } + return te.cc +} + +type rpcConfig struct { + count int // Number of requests and responses for streaming RPCs. + success bool // Whether the RPC should succeed or return error. 
+ failfast bool +} + +func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) { + var ( + resp *testpb.SimpleResponse + req *testpb.SimpleRequest + err error + ) + tc := testpb.NewTestServiceClient(te.clientConn()) + if c.success { + req = &testpb.SimpleRequest{Id: errorID + 1} + } else { + req = &testpb.SimpleRequest{Id: errorID} + } + ctx := metadata.NewContext(context.Background(), testMetadata) + + resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(c.failfast)) + if err != nil { + return req, resp, err + } + + return req, resp, err +} + +func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest, []*testpb.SimpleResponse, error) { + var ( + reqs []*testpb.SimpleRequest + resps []*testpb.SimpleResponse + err error + ) + tc := testpb.NewTestServiceClient(te.clientConn()) + stream, err := tc.FullDuplexCall(metadata.NewContext(context.Background(), testMetadata), grpc.FailFast(c.failfast)) + if err != nil { + return reqs, resps, err + } + var startID int32 + if !c.success { + startID = errorID + } + for i := 0; i < c.count; i++ { + req := &testpb.SimpleRequest{ + Id: int32(i) + startID, + } + reqs = append(reqs, req) + if err = stream.Send(req); err != nil { + return reqs, resps, err + } + var resp *testpb.SimpleResponse + if resp, err = stream.Recv(); err != nil { + return reqs, resps, err + } + resps = append(resps, resp) + } + if err = stream.CloseSend(); err != nil { + return reqs, resps, err + } + if _, err = stream.Recv(); err != io.EOF { + return reqs, resps, err + } + + return reqs, resps, err +} + +type expectedData struct { + method string + serverAddr string + compression string + reqIdx int + requests []*testpb.SimpleRequest + respIdx int + responses []*testpb.SimpleResponse + err error + failfast bool +} + +type gotData struct { + ctx context.Context + client bool + s stats.RPCStats +} + +const ( + begin int = iota + end + inpay + inHeader + inTrailer + outPayload + outHeader + outTrailer +) + +func checkBegin(t *testing.T, d *gotData, e *expectedData) { + var ( + ok bool + st *stats.Begin + ) + if st, ok = d.s.(*stats.Begin); !ok { + t.Fatalf("got %T, want Begin", d.s) + } + if d.ctx == nil { + t.Fatalf("d.ctx = nil, want <non-nil>") + } + if st.BeginTime.IsZero() { + t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime) + } + if d.client { + if st.FailFast != e.failfast { + t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast) + } + } +} + +func checkInHeader(t *testing.T, d *gotData, e *expectedData) { + var ( + ok bool + st *stats.InHeader + ) + if st, ok = d.s.(*stats.InHeader); !ok { + t.Fatalf("got %T, want InHeader", d.s) + } + if d.ctx == nil { + t.Fatalf("d.ctx = nil, want <non-nil>") + } + // TODO check real length, not just > 0.
+ if st.WireLength <= 0 { + t.Fatalf("st.WireLength = 0, want > 0") + } + if !d.client { + if st.FullMethod != e.method { + t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) + } + if st.LocalAddr.String() != e.serverAddr { + t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr) + } + if st.Compression != e.compression { + t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) + } + } +} + +func checkInPayload(t *testing.T, d *gotData, e *expectedData) { + var ( + ok bool + st *stats.InPayload + ) + if st, ok = d.s.(*stats.InPayload); !ok { + t.Fatalf("got %T, want InPayload", d.s) + } + if d.ctx == nil { + t.Fatalf("d.ctx = nil, want <non-nil>") + } + if d.client { + b, err := proto.Marshal(e.responses[e.respIdx]) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) + } + if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) { + t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx]) + } + e.respIdx++ + if string(st.Data) != string(b) { + t.Fatalf("st.Data = %v, want %v", st.Data, b) + } + if st.Length != len(b) { + t.Fatalf("st.Length = %v, want %v", st.Length, len(b)) + } + } else { + b, err := proto.Marshal(e.requests[e.reqIdx]) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) + } + if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) { + t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx]) + } + e.reqIdx++ + if string(st.Data) != string(b) { + t.Fatalf("st.Data = %v, want %v", st.Data, b) + } + if st.Length != len(b) { + t.Fatalf("st.Length = %v, want %v", st.Length, len(b)) + } + } + // TODO check WireLength and RecvTime. + if st.RecvTime.IsZero() { + t.Fatalf("st.RecvTime = %v, want <non-zero>", st.RecvTime) + } +} + +func checkInTrailer(t *testing.T, d *gotData, e *expectedData) { + var ( + ok bool + st *stats.InTrailer + ) + if st, ok = d.s.(*stats.InTrailer); !ok { + t.Fatalf("got %T, want InTrailer", d.s) + } + if d.ctx == nil { + t.Fatalf("d.ctx = nil, want <non-nil>") + } + // TODO check real length, not just > 0. + if st.WireLength <= 0 { + t.Fatalf("st.WireLength = 0, want > 0") + } +} + +func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { + var ( + ok bool + st *stats.OutHeader + ) + if st, ok = d.s.(*stats.OutHeader); !ok { + t.Fatalf("got %T, want OutHeader", d.s) + } + if d.ctx == nil { + t.Fatalf("d.ctx = nil, want <non-nil>") + } + // TODO check real length, not just > 0.
+ if st.WireLength <= 0 { + t.Fatalf("st.WireLength = 0, want > 0") + } + if d.client { + if st.FullMethod != e.method { + t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) + } + if st.RemoteAddr.String() != e.serverAddr { + t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr) + } + if st.Compression != e.compression { + t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) + } + } +} + +func checkOutPayload(t *testing.T, d *gotData, e *expectedData) { + var ( + ok bool + st *stats.OutPayload + ) + if st, ok = d.s.(*stats.OutPayload); !ok { + t.Fatalf("got %T, want OutPayload", d.s) + } + if d.ctx == nil { + t.Fatalf("d.ctx = nil, want <non-nil>") + } + if d.client { + b, err := proto.Marshal(e.requests[e.reqIdx]) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) + } + if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) { + t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx]) + } + e.reqIdx++ + if string(st.Data) != string(b) { + t.Fatalf("st.Data = %v, want %v", st.Data, b) + } + if st.Length != len(b) { + t.Fatalf("st.Length = %v, want %v", st.Length, len(b)) + } + } else { + b, err := proto.Marshal(e.responses[e.respIdx]) + if err != nil { + t.Fatalf("failed to marshal message: %v", err) + } + if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) { + t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx]) + } + e.respIdx++ + if string(st.Data) != string(b) { + t.Fatalf("st.Data = %v, want %v", st.Data, b) + } + if st.Length != len(b) { + t.Fatalf("st.Length = %v, want %v", st.Length, len(b)) + } + } + // TODO check WireLength and SentTime. + if st.SentTime.IsZero() { + t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime) + } +} + +func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) { + var ( + ok bool + st *stats.OutTrailer + ) + if st, ok = d.s.(*stats.OutTrailer); !ok { + t.Fatalf("got %T, want OutTrailer", d.s) + } + if d.ctx == nil { + t.Fatalf("d.ctx = nil, want <non-nil>") + } + if st.Client { + t.Fatalf("st.Client = true, want false") + } + // TODO check real length, not just > 0. + if st.WireLength <= 0 { + t.Fatalf("st.WireLength = 0, want > 0") + } +} + +func checkEnd(t *testing.T, d *gotData, e *expectedData) { + var ( + ok bool + st *stats.End + ) + if st, ok = d.s.(*stats.End); !ok { + t.Fatalf("got %T, want End", d.s) + } + if d.ctx == nil { + t.Fatalf("d.ctx = nil, want <non-nil>") + } + if st.EndTime.IsZero() { + t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime) + } + if grpc.Code(st.Error) != grpc.Code(e.err) || grpc.ErrorDesc(st.Error) != grpc.ErrorDesc(e.err) { + t.Fatalf("st.Error = %v, want %v", st.Error, e.err) + } +} + +func TestServerStatsUnaryRPC(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + mu.Lock() + defer mu.Unlock() + if !s.IsClient() { + got = append(got, &gotData{ctx, false, s}) + } + }) + stats.Start() + defer stats.Stop() + + te := newTest(t, "") + te.startServer(&testServer{}) + defer te.tearDown() + + req, resp, err := te.doUnaryCall(&rpcConfig{success: true}) + if err != nil { + t.Fatalf(err.Error()) + } + te.srv.GracefulStop() // Wait for the server to stop.
+ + expect := &expectedData{ + method: "/grpc.testing.TestService/UnaryCall", + serverAddr: te.srvAddr, + requests: []*testpb.SimpleRequest{req}, + responses: []*testpb.SimpleResponse{resp}, + } + + checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkInPayload, + checkOutHeader, + checkOutPayload, + checkOutTrailer, + checkEnd, + } + + if len(got) != len(checkFuncs) { + t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) + } + + for i := 0; i < len(got)-1; i++ { + if got[i].ctx != got[i+1].ctx { + t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + } + } + + for i, f := range checkFuncs { + mu.Lock() + f(t, got[i], expect) + mu.Unlock() + } +} + +func TestServerStatsUnaryRPCError(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + mu.Lock() + defer mu.Unlock() + if !s.IsClient() { + got = append(got, &gotData{ctx, false, s}) + } + }) + stats.Start() + defer stats.Stop() + + te := newTest(t, "") + te.startServer(&testServer{}) + defer te.tearDown() + + req, resp, err := te.doUnaryCall(&rpcConfig{success: false}) + if err == nil { + t.Fatalf("got error <nil>; want <non-nil>") + } + te.srv.GracefulStop() // Wait for the server to stop. + + expect := &expectedData{ + method: "/grpc.testing.TestService/UnaryCall", + serverAddr: te.srvAddr, + requests: []*testpb.SimpleRequest{req}, + responses: []*testpb.SimpleResponse{resp}, + err: err, + } + + checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkInPayload, + checkOutHeader, + checkOutTrailer, + checkEnd, + } + + if len(got) != len(checkFuncs) { + t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) + } + + for i := 0; i < len(got)-1; i++ { + if got[i].ctx != got[i+1].ctx { + t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + } + } + + for i, f := range checkFuncs { + mu.Lock() + f(t, got[i], expect) + mu.Unlock() + } +} + +func TestServerStatsStreamingRPC(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + mu.Lock() + defer mu.Unlock() + if !s.IsClient() { + got = append(got, &gotData{ctx, false, s}) + } + }) + stats.Start() + defer stats.Stop() + + te := newTest(t, "gzip") + te.startServer(&testServer{}) + defer te.tearDown() + + count := 5 + reqs, resps, err := te.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: true}) + if err != nil { + t.Fatalf(err.Error()) + } + te.srv.GracefulStop() // Wait for the server to stop. + + expect := &expectedData{ + method: "/grpc.testing.TestService/FullDuplexCall", + serverAddr: te.srvAddr, + compression: "gzip", + requests: reqs, + responses: resps, + } + + checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkOutHeader, + } + ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkInPayload, + checkOutPayload, + } + for i := 0; i < count; i++ { + checkFuncs = append(checkFuncs, ioPayFuncs...)
+ } + checkFuncs = append(checkFuncs, checkOutTrailer, checkEnd) + + if len(got) != len(checkFuncs) { + t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) + } + + for i := 0; i < len(got)-1; i++ { + if got[i].ctx != got[i+1].ctx { + t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + } + } + + for i, f := range checkFuncs { + mu.Lock() + f(t, got[i], expect) + mu.Unlock() + } +} + +func TestServerStatsStreamingRPCError(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + mu.Lock() + defer mu.Unlock() + if !s.IsClient() { + got = append(got, &gotData{ctx, false, s}) + } + }) + stats.Start() + defer stats.Stop() + + te := newTest(t, "gzip") + te.startServer(&testServer{}) + defer te.tearDown() + + count := 5 + reqs, resps, err := te.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: false}) + if err == nil { + t.Fatalf("got error <nil>; want <non-nil>") + } + te.srv.GracefulStop() // Wait for the server to stop. + + expect := &expectedData{ + method: "/grpc.testing.TestService/FullDuplexCall", + serverAddr: te.srvAddr, + compression: "gzip", + requests: reqs, + responses: resps, + err: err, + } + + checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkInHeader, + checkBegin, + checkOutHeader, + checkInPayload, + checkOutTrailer, + checkEnd, + } + + if len(got) != len(checkFuncs) { + t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) + } + + for i := 0; i < len(got)-1; i++ { + if got[i].ctx != got[i+1].ctx { + t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + } + } + + for i, f := range checkFuncs { + mu.Lock() + f(t, got[i], expect) + mu.Unlock() + } +} + +type checkFuncWithCount struct { + f func(t *testing.T, d *gotData, e *expectedData) + c int // expected count +} + +func TestClientStatsUnaryRPC(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + mu.Lock() + defer mu.Unlock() + if s.IsClient() { + got = append(got, &gotData{ctx, true, s}) + } + }) + stats.Start() + defer stats.Stop() + + te := newTest(t, "") + te.startServer(&testServer{}) + defer te.tearDown() + + failfast := false + req, resp, err := te.doUnaryCall(&rpcConfig{success: true, failfast: failfast}) + if err != nil { + t.Fatalf(err.Error()) + } + te.srv.GracefulStop() // Wait for the server to stop.
+ + expect := &expectedData{ + method: "/grpc.testing.TestService/UnaryCall", + serverAddr: te.srvAddr, + requests: []*testpb.SimpleRequest{req}, + responses: []*testpb.SimpleResponse{resp}, + failfast: failfast, + } + + checkFuncs := map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, 1}, + inHeader: {checkInHeader, 1}, + inpay: {checkInPayload, 1}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + } + + var expectLen int + for _, v := range checkFuncs { + expectLen += v.c + } + if len(got) != expectLen { + t.Fatalf("got %v stats, want %v stats", len(got), expectLen) + } + + for i := 0; i < len(got)-1; i++ { + if got[i].ctx != got[i+1].ctx { + t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + } + } + + for _, s := range got { + mu.Lock() + switch s.s.(type) { + case *stats.Begin: + if checkFuncs[begin].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[begin].f(t, s, expect) + checkFuncs[begin].c-- + case *stats.OutHeader: + if checkFuncs[outHeader].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[outHeader].f(t, s, expect) + checkFuncs[outHeader].c-- + case *stats.OutPayload: + if checkFuncs[outPayload].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[outPayload].f(t, s, expect) + checkFuncs[outPayload].c-- + case *stats.InHeader: + if checkFuncs[inHeader].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[inHeader].f(t, s, expect) + checkFuncs[inHeader].c-- + case *stats.InPayload: + if checkFuncs[inpay].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[inpay].f(t, s, expect) + checkFuncs[inpay].c-- + case *stats.InTrailer: + if checkFuncs[inTrailer].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[inTrailer].f(t, s, expect) + checkFuncs[inTrailer].c-- + case *stats.End: + if checkFuncs[end].c <= 0 { + t.Fatalf("unexpected stats: %T", s) + } + checkFuncs[end].f(t, s, expect) + checkFuncs[end].c-- + default: + t.Fatalf("unexpected stats: %T", s) + } + mu.Unlock() + } +} + +func TestClientStatsUnaryRPCError(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + mu.Lock() + defer mu.Unlock() + if s.IsClient() { + got = append(got, &gotData{ctx, true, s}) + } + }) + stats.Start() + defer stats.Stop() + + te := newTest(t, "") + te.startServer(&testServer{}) + defer te.tearDown() + + failfast := true + req, resp, err := te.doUnaryCall(&rpcConfig{success: false, failfast: failfast}) + if err == nil { + t.Fatalf("got error <nil>; want <non-nil>") + } + te.srv.GracefulStop() // Wait for the server to stop.
+
+ expect := &expectedData{
+ method: "/grpc.testing.TestService/UnaryCall",
+ serverAddr: te.srvAddr,
+ requests: []*testpb.SimpleRequest{req},
+ responses: []*testpb.SimpleResponse{resp},
+ err: err,
+ failfast: failfast,
+ }
+
+ checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
+ checkBegin,
+ checkOutHeader,
+ checkOutPayload,
+ checkInHeader,
+ checkInTrailer,
+ checkEnd,
+ }
+
+ if len(got) != len(checkFuncs) {
+ t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
+ }
+
+ for i := 0; i < len(got)-1; i++ {
+ if got[i].ctx != got[i+1].ctx {
+ t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
+ }
+ }
+
+ for i, f := range checkFuncs {
+ mu.Lock()
+ f(t, got[i], expect)
+ mu.Unlock()
+ }
+}
+
+func TestClientStatsStreamingRPC(t *testing.T) {
+ var (
+ mu sync.Mutex
+ got []*gotData
+ )
+ stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+ mu.Lock()
+ defer mu.Unlock()
+ if s.IsClient() {
+ got = append(got, &gotData{ctx, true, s})
+ }
+ })
+ stats.Start()
+ defer stats.Stop()
+
+ te := newTest(t, "gzip")
+ te.startServer(&testServer{})
+ defer te.tearDown()
+
+ count := 5
+ failfast := false
+ reqs, resps, err := te.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: true, failfast: failfast})
+ if err != nil {
+ t.Fatalf(err.Error())
+ }
+ te.srv.GracefulStop() // Wait for the server to stop.
+
+ expect := &expectedData{
+ method: "/grpc.testing.TestService/FullDuplexCall",
+ serverAddr: te.srvAddr,
+ compression: "gzip",
+ requests: reqs,
+ responses: resps,
+ failfast: failfast,
+ }
+
+ checkFuncs := map[int]*checkFuncWithCount{
+ begin: {checkBegin, 1},
+ outHeader: {checkOutHeader, 1},
+ outPayload: {checkOutPayload, count},
+ inHeader: {checkInHeader, 1},
+ inpay: {checkInPayload, count},
+ inTrailer: {checkInTrailer, 1},
+ end: {checkEnd, 1},
+ }
+
+ var expectLen int
+ for _, v := range checkFuncs {
+ expectLen += v.c
+ }
+ if len(got) != expectLen {
+ t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
+ }
+
+ for i := 0; i < len(got)-1; i++ {
+ if got[i].ctx != got[i+1].ctx {
+ t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
+ }
+ }
+
+ for _, s := range got {
+ mu.Lock()
+ switch s.s.(type) {
+ case *stats.Begin:
+ if checkFuncs[begin].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[begin].f(t, s, expect)
+ checkFuncs[begin].c--
+ case *stats.OutHeader:
+ if checkFuncs[outHeader].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[outHeader].f(t, s, expect)
+ checkFuncs[outHeader].c--
+ case *stats.OutPayload:
+ if checkFuncs[outPayload].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[outPayload].f(t, s, expect)
+ checkFuncs[outPayload].c--
+ case *stats.InHeader:
+ if checkFuncs[inHeader].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[inHeader].f(t, s, expect)
+ checkFuncs[inHeader].c--
+ case *stats.InPayload:
+ if checkFuncs[inpay].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[inpay].f(t, s, expect)
+ checkFuncs[inpay].c--
+ case *stats.InTrailer:
+ if checkFuncs[inTrailer].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[inTrailer].f(t, s, expect)
+ checkFuncs[inTrailer].c--
+ case *stats.End:
+ if checkFuncs[end].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[end].f(t, s, expect)
+ checkFuncs[end].c--
+ default:
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ mu.Unlock()
+ }
+}
+
+func TestClientStatsStreamingRPCError(t *testing.T) {
+ var (
+ mu sync.Mutex
+ got []*gotData
+ )
+ stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
+ mu.Lock()
+ defer mu.Unlock()
+ if s.IsClient() {
+ got = append(got, &gotData{ctx, true, s})
+ }
+ })
+ stats.Start()
+ defer stats.Stop()
+
+ te := newTest(t, "gzip")
+ te.startServer(&testServer{})
+ defer te.tearDown()
+
+ count := 5
+ failfast := true
+ reqs, resps, err := te.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: false, failfast: failfast})
+ if err == nil {
+ t.Fatalf("got error <nil>; want <non-nil>")
+ }
+ te.srv.GracefulStop() // Wait for the server to stop.
+
+ expect := &expectedData{
+ method: "/grpc.testing.TestService/FullDuplexCall",
+ serverAddr: te.srvAddr,
+ compression: "gzip",
+ requests: reqs,
+ responses: resps,
+ err: err,
+ failfast: failfast,
+ }
+
+ checkFuncs := map[int]*checkFuncWithCount{
+ begin: {checkBegin, 1},
+ outHeader: {checkOutHeader, 1},
+ outPayload: {checkOutPayload, 1},
+ inHeader: {checkInHeader, 1},
+ inTrailer: {checkInTrailer, 1},
+ end: {checkEnd, 1},
+ }
+
+ var expectLen int
+ for _, v := range checkFuncs {
+ expectLen += v.c
+ }
+ if len(got) != expectLen {
+ t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
+ }
+
+ for i := 0; i < len(got)-1; i++ {
+ if got[i].ctx != got[i+1].ctx {
+ t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s)
+ }
+ }
+
+ for _, s := range got {
+ mu.Lock()
+ switch s.s.(type) {
+ case *stats.Begin:
+ if checkFuncs[begin].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[begin].f(t, s, expect)
+ checkFuncs[begin].c--
+ case *stats.OutHeader:
+ if checkFuncs[outHeader].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[outHeader].f(t, s, expect)
+ checkFuncs[outHeader].c--
+ case *stats.OutPayload:
+ if checkFuncs[outPayload].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[outPayload].f(t, s, expect)
+ checkFuncs[outPayload].c--
+ case *stats.InHeader:
+ if checkFuncs[inHeader].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[inHeader].f(t, s, expect)
+ checkFuncs[inHeader].c--
+ case *stats.InPayload:
+ if checkFuncs[inpay].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[inpay].f(t, s, expect)
+ checkFuncs[inpay].c--
+ case *stats.InTrailer:
+ if checkFuncs[inTrailer].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[inTrailer].f(t, s, expect)
+ checkFuncs[inTrailer].c--
+ case *stats.End:
+ if checkFuncs[end].c <= 0 {
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ checkFuncs[end].f(t, s, expect)
+ checkFuncs[end].c--
+ default:
+ t.Fatalf("unexpected stats: %T", s.s)
+ }
+ mu.Unlock()
+ }
+}
diff --git a/stream.go b/stream.go
index 46810544..95c8acf8 100644
--- a/stream.go
+++ b/stream.go
@@ -45,6 +45,7 @@ import (
 "golang.org/x/net/trace"
 "google.golang.org/grpc/codes"
 "google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/stats"
 "google.golang.org/grpc/transport"
 )
 
@@ -97,7 +98,7 @@ type ClientStream interface {
 // NewClientStream creates a new Stream for the client side. This is called
 // by generated code.
-func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
+func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) {
 if cc.dopts.streamInt != nil {
 return cc.dopts.streamInt(ctx, desc, cc, method, newClientStream, opts...)
} @@ -143,6 +144,24 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } }() } + if stats.On() { + begin := &stats.Begin{ + Client: true, + BeginTime: time.Now(), + FailFast: c.failFast, + } + stats.Handle(ctx, begin) + } + defer func() { + if err != nil && stats.On() { + // Only handle end stats if err != nil. + end := &stats.End{ + Client: true, + Error: err, + } + stats.Handle(ctx, end) + } + }() gopts := BalancerGetOptions{ BlockingWait: !c.failFast, } @@ -194,6 +213,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth tracing: EnableTracing, trInfo: trInfo, + + statsCtx: ctx, } if cc.dopts.cp != nil { cs.cbuf = new(bytes.Buffer) @@ -246,6 +267,11 @@ type clientStream struct { // trInfo.tr is set when the clientStream is created (if EnableTracing is true), // and is set to nil when the clientStream's finish method is called. trInfo traceInfo + + // statsCtx keeps the user context for stats handling. + // All stats collection should use the statsCtx (instead of the stream context) + // so that all the generated stats for a particular RPC can be associated in the processing phase. + statsCtx context.Context } func (cs *clientStream) Context() context.Context { @@ -274,6 +300,8 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } cs.mu.Unlock() } + // TODO Investigate how to signal the stats handling party. + // generate error stats if err != nil && err != io.EOF? defer func() { if err != nil { cs.finish(err) @@ -296,7 +324,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } err = toRPCErr(err) }() - out, err := encode(cs.codec, m, cs.cp, cs.cbuf) + var outPayload *stats.OutPayload + if stats.On() { + outPayload = &stats.OutPayload{ + Client: true, + } + } + out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload) defer func() { if cs.cbuf != nil { cs.cbuf.Reset() @@ -305,11 +339,37 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if err != nil { return Errorf(codes.Internal, "grpc: %v", err) } - return cs.t.Write(cs.s, out, &transport.Options{Last: false}) + err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) + if err == nil && outPayload != nil { + outPayload.SentTime = time.Now() + stats.Handle(cs.statsCtx, outPayload) + } + return err } func (cs *clientStream) RecvMsg(m interface{}) (err error) { - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32) + defer func() { + if err != nil && stats.On() { + // Only generate End if err != nil. + // If err == nil, it's not the last RecvMsg. + // The last RecvMsg gets either an RPC error or io.EOF. + end := &stats.End{ + Client: true, + EndTime: time.Now(), + } + if err != io.EOF { + end.Error = toRPCErr(err) + } + stats.Handle(cs.statsCtx, end) + } + }() + var inPayload *stats.InPayload + if stats.On() { + inPayload = &stats.InPayload{ + Client: true, + } + } + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload) defer func() { // err != nil indicates the termination of the stream. if err != nil { @@ -324,11 +384,15 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { } cs.mu.Unlock() } + if inPayload != nil { + stats.Handle(cs.statsCtx, inPayload) + } if !cs.desc.ClientStreams || cs.desc.ServerStreams { return } // Special handling for client streaming rpc. - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32) + // This recv expects EOF or errors, so we don't collect inPayload. 
+ err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil)
 cs.closeTransportStream(err)
 if err == nil {
 return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
 }
@@ -482,7 +546,11 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
 ss.mu.Unlock()
 }
 }()
- out, err := encode(ss.codec, m, ss.cp, ss.cbuf)
+ var outPayload *stats.OutPayload
+ if stats.On() {
+ outPayload = &stats.OutPayload{}
+ }
+ out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
 defer func() {
 if ss.cbuf != nil {
 ss.cbuf.Reset()
@@ -495,6 +563,10 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
 if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
 return toRPCErr(err)
 }
+ if outPayload != nil {
+ outPayload.SentTime = time.Now()
+ stats.Handle(ss.s.Context(), outPayload)
+ }
 return nil
 }
 
@@ -513,7 +585,11 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
 ss.mu.Unlock()
 }
 }()
- if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize); err != nil {
+ var inPayload *stats.InPayload
+ if stats.On() {
+ inPayload = &stats.InPayload{}
+ }
+ if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil {
 if err == io.EOF {
 return err
 }
@@ -522,5 +598,8 @@
 }
 return toRPCErr(err)
 }
+ if inPayload != nil {
+ stats.Handle(ss.s.Context(), inPayload)
+ }
 return nil
 }
diff --git a/transport/handler_server.go b/transport/handler_server.go
index 114e3490..10b6dc0b 100644
--- a/transport/handler_server.go
+++ b/transport/handler_server.go
@@ -268,7 +268,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
 })
 }
 
-func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
+func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
 // With this transport type there will be exactly 1 stream: this HTTP request.
var ctx context.Context diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go index 84fc917f..9843d36b 100644 --- a/transport/handler_server_test.go +++ b/transport/handler_server_test.go @@ -300,7 +300,10 @@ func TestHandlerTransport_HandleStreams(t *testing.T) { st.bodyw.Close() // no body st.ht.WriteStatus(s, codes.OK, "") } - st.ht.HandleStreams(func(s *Stream) { go handleStream(s) }) + st.ht.HandleStreams( + func(s *Stream) { go handleStream(s) }, + func(ctx context.Context, method string) context.Context { return ctx }, + ) wantHeader := http.Header{ "Date": nil, "Content-Type": {"application/grpc"}, @@ -327,7 +330,10 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) handleStream := func(s *Stream) { st.ht.WriteStatus(s, statusCode, msg) } - st.ht.HandleStreams(func(s *Stream) { go handleStream(s) }) + st.ht.HandleStreams( + func(s *Stream) { go handleStream(s) }, + func(ctx context.Context, method string) context.Context { return ctx }, + ) wantHeader := http.Header{ "Date": nil, "Content-Type": {"application/grpc"}, @@ -375,7 +381,10 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { } ht.WriteStatus(s, codes.DeadlineExceeded, "too slow") } - ht.HandleStreams(func(s *Stream) { go runStream(s) }) + ht.HandleStreams( + func(s *Stream) { go runStream(s) }, + func(ctx context.Context, method string) context.Context { return ctx }, + ) wantHeader := http.Header{ "Date": nil, "Content-Type": {"application/grpc"}, diff --git a/transport/http2_client.go b/transport/http2_client.go index 2b0f6801..8d31aa65 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -51,16 +51,19 @@ import ( "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" ) // http2Client implements the ClientTransport interface with HTTP2. type http2Client struct { - target string // server name/addr - userAgent string - md interface{} - conn net.Conn // underlying communication channel - authInfo credentials.AuthInfo // auth info about the connection - nextID uint32 // the next stream ID to be used + target string // server name/addr + userAgent string + md interface{} + conn net.Conn // underlying communication channel + remoteAddr net.Addr + localAddr net.Addr + authInfo credentials.AuthInfo // auth info about the connection + nextID uint32 // the next stream ID to be used // writableChan synchronizes write access to the transport. // A writer acquires the write lock by sending a value on writableChan @@ -175,11 +178,13 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( } var buf bytes.Buffer t := &http2Client{ - target: addr.Addr, - userAgent: ua, - md: addr.Metadata, - conn: conn, - authInfo: authInfo, + target: addr.Addr, + userAgent: ua, + md: addr.Metadata, + conn: conn, + remoteAddr: conn.RemoteAddr(), + localAddr: conn.LocalAddr(), + authInfo: authInfo, // The client initiated stream id is odd starting from 1. nextID: 1, writableChan: make(chan int, 1), @@ -270,12 +275,13 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // streams. func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { pr := &peer.Peer{ - Addr: t.conn.RemoteAddr(), + Addr: t.remoteAddr, } // Attach Auth info if there is any. 
if t.authInfo != nil { pr.AuthInfo = t.authInfo } + userCtx := ctx ctx = peer.NewContext(ctx, pr) authData := make(map[string]string) for _, c := range t.creds { @@ -347,6 +353,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, ErrConnClosing } s := t.newStream(ctx, callHdr) + s.clientStatsCtx = userCtx t.activeStreams[s.id] = s // This stream is not counted when applySetings(...) initialize t.streamsQuota. @@ -413,6 +420,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } } first := true + bufLen := t.hBuf.Len() // Sends the headers in a single batch even when they span multiple frames. for !endHeaders { size := t.hBuf.Len() @@ -447,6 +455,17 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, connectionErrorf(true, err, "transport: %v", err) } } + if stats.On() { + outHeader := &stats.OutHeader{ + Client: true, + WireLength: bufLen, + FullMethod: callHdr.Method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: callHdr.SendCompress, + } + stats.Handle(s.clientStatsCtx, outHeader) + } t.writableChan <- 0 return s, nil } @@ -874,6 +893,24 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } endStream := frame.StreamEnded() + var isHeader bool + defer func() { + if stats.On() { + if isHeader { + inHeader := &stats.InHeader{ + Client: true, + WireLength: int(frame.Header().Length), + } + stats.Handle(s.clientStatsCtx, inHeader) + } else { + inTrailer := &stats.InTrailer{ + Client: true, + WireLength: int(frame.Header().Length), + } + stats.Handle(s.clientStatsCtx, inTrailer) + } + } + }() s.mu.Lock() if !endStream { @@ -885,6 +922,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } close(s.headerChan) s.headerDone = true + isHeader = true } if !endStream || s.state == streamDone { s.mu.Unlock() diff --git a/transport/http2_server.go b/transport/http2_server.go index a20208cf..db9beb90 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -50,6 +50,7 @@ import ( "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" "google.golang.org/grpc/tap" ) @@ -60,6 +61,8 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { conn net.Conn + remoteAddr net.Addr + localAddr net.Addr maxStreamID uint32 // max stream ID ever seen authInfo credentials.AuthInfo // auth info about the connection inTapHandle tap.ServerInHandle @@ -125,6 +128,8 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err var buf bytes.Buffer t := &http2Server{ conn: conn, + remoteAddr: conn.RemoteAddr(), + localAddr: conn.LocalAddr(), authInfo: config.AuthInfo, framer: framer, hBuf: &buf, @@ -146,7 +151,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err } // operateHeader takes action on the decoded headers. 
-func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) { +func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) { buf := newRecvBuffer() s := &Stream{ id: frame.Header().StreamID, @@ -177,7 +182,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.ctx, s.cancel = context.WithCancel(context.TODO()) } pr := &peer.Peer{ - Addr: t.conn.RemoteAddr(), + Addr: t.remoteAddr, } // Attach Auth info if there is any. if t.authInfo != nil { @@ -234,13 +239,25 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.windowHandler = func(n int) { t.updateWindow(s, uint32(n)) } + s.ctx = traceCtx(s.ctx, s.method) + if stats.On() { + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: s.recvCompress, + WireLength: int(frame.Header().Length), + } + stats.Handle(s.ctx, inHeader) + } handle(s) return } // HandleStreams receives incoming streams using the given handler. This is // typically run in a separate goroutine. -func (t *http2Server) HandleStreams(handle func(*Stream)) { +// traceCtx attaches trace to ctx and returns the new context. +func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) { // Check the validity of client preface. preface := make([]byte, len(clientPreface)) if _, err := io.ReadFull(t.conn, preface); err != nil { @@ -295,7 +312,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { } switch frame := frame.(type) { case *http2.MetaHeadersFrame: - if t.operateHeaders(frame, handle) { + if t.operateHeaders(frame, handle, traceCtx) { t.Close() break } @@ -508,9 +525,16 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) } } + bufLen := t.hBuf.Len() if err := t.writeHeaders(s, t.hBuf, false); err != nil { return err } + if stats.On() { + outHeader := &stats.OutHeader{ + WireLength: bufLen, + } + stats.Handle(s.Context(), outHeader) + } t.writableChan <- 0 return nil } @@ -563,10 +587,17 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) } } + bufLen := t.hBuf.Len() if err := t.writeHeaders(s, t.hBuf, true); err != nil { t.Close() return err } + if stats.On() { + outTrailer := &stats.OutTrailer{ + WireLength: bufLen, + } + stats.Handle(s.Context(), outTrailer) + } t.closeStream(s) t.writableChan <- 0 return nil @@ -783,7 +814,7 @@ func (t *http2Server) closeStream(s *Stream) { } func (t *http2Server) RemoteAddr() net.Addr { - return t.conn.RemoteAddr() + return t.remoteAddr } func (t *http2Server) Drain() { diff --git a/transport/transport.go b/transport/transport.go index 7dd02a02..f7100670 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -45,7 +45,6 @@ import ( "sync" "golang.org/x/net/context" - "golang.org/x/net/trace" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" @@ -168,6 +167,11 @@ type Stream struct { id uint32 // nil for client side Stream. st ServerTransport + // clientStatsCtx keeps the user context for stats handling. + // It's only valid on client side. Server side stats context is same as s.ctx. 
+ // All client side stats collection should use the clientStatsCtx (instead of the stream context) + // so that all the generated stats for a particular RPC can be associated in the processing phase. + clientStatsCtx context.Context // ctx is the associated context of the stream. ctx context.Context // cancel is always nil for client side Stream. @@ -267,11 +271,6 @@ func (s *Stream) Context() context.Context { return s.ctx } -// TraceContext recreates the context of s with a trace.Trace. -func (s *Stream) TraceContext(tr trace.Trace) { - s.ctx = trace.NewContext(s.ctx, tr) -} - // Method returns the method for the stream. func (s *Stream) Method() string { return s.method @@ -474,7 +473,7 @@ type ClientTransport interface { // Write methods for a given Stream will be called serially. type ServerTransport interface { // HandleStreams receives incoming streams using the given handler. - HandleStreams(func(*Stream)) + HandleStreams(func(*Stream), func(context.Context, string) context.Context) // WriteHeader sends the header metadata for the given stream. // WriteHeader may not be called on all streams. diff --git a/transport/transport_test.go b/transport/transport_test.go index b7659154..1ca6eb1a 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -197,22 +197,33 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { h := &testStreamHandler{transport.(*http2Server)} switch ht { case suspended: - go transport.HandleStreams(h.handleStreamSuspension) + go transport.HandleStreams(h.handleStreamSuspension, + func(ctx context.Context, method string) context.Context { + return ctx + }) case misbehaved: go transport.HandleStreams(func(s *Stream) { go h.handleStreamMisbehave(t, s) + }, func(ctx context.Context, method string) context.Context { + return ctx }) case encodingRequiredStatus: go transport.HandleStreams(func(s *Stream) { go h.handleStreamEncodingRequiredStatus(t, s) + }, func(ctx context.Context, method string) context.Context { + return ctx }) case invalidHeaderField: go transport.HandleStreams(func(s *Stream) { go h.handleStreamInvalidHeaderField(t, s) + }, func(ctx context.Context, method string) context.Context { + return ctx }) default: go transport.HandleStreams(func(s *Stream) { go h.handleStream(t, s) + }, func(ctx context.Context, method string) context.Context { + return ctx }) } }
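
Note on usage: the tests above assert a fixed per-RPC event sequence (client side: Begin, OutHeader, OutPayload, InHeader, InPayload, InTrailer, End), all delivered with the same context so they can be correlated. For orientation, here is a minimal, hypothetical sketch of how an application would hook into the stats API this patch introduces. It uses only the entry points exercised by the tests (stats.RegisterHandler, stats.Start, stats.Stop) and the event types added above; the handler body and log output are illustrative, not part of the patch.

package main

import (
	"log"

	"golang.org/x/net/context"
	"google.golang.org/grpc/stats"
)

func main() {
	// The handler is invoked once per event. Events belonging to the same
	// RPC arrive with the same context (clientStatsCtx on the client side,
	// the stream context on the server side), which is what the
	// context-equality checks in the tests rely on.
	stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {
		switch s := s.(type) {
		case *stats.Begin:
			log.Printf("client=%v: RPC began at %v (failfast=%v)", s.IsClient(), s.BeginTime, s.FailFast)
		case *stats.OutPayload:
			log.Printf("client=%v: payload sent at %v", s.IsClient(), s.SentTime)
		case *stats.End:
			log.Printf("client=%v: RPC ended at %v, err=%v", s.IsClient(), s.EndTime, s.Error)
		default:
			log.Printf("client=%v: %T", s.IsClient(), s)
		}
	})
	stats.Start() // from here on stats.On() reports true and events are delivered
	defer stats.Stop()

	// ... dial a ClientConn and issue RPCs here ...
}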