From 4ed348913c8a37cfab46b8684ff1ed794b13cb86 Mon Sep 17 00:00:00 2001 From: MakMukhi Date: Mon, 9 Jan 2017 13:29:20 -0800 Subject: [PATCH 1/3] ClientHandshake to return AuthInfo (#956) * Initial commit * Initial commit 2 * minor update * goimport update * resolved race condition * added test for TLSInfo on server side * Post review updates * port review changes debug debug * refactoring and added third function * post review changes * post review changes * post review updates * post review commit * post review commit * post review update * post review update * post review update * post review update * post review commit * post review update --- credentials/credentials.go | 4 +- credentials/credentials_test.go | 160 ++++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 3 deletions(-) diff --git a/credentials/credentials.go b/credentials/credentials.go index 5555ef02..4d45c3e3 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -165,9 +165,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net case <-ctx.Done(): return nil, nil, ctx.Err() } - // TODO(zhaoq): Omit the auth info for client now. It is more for - // information than anything else. - return conn, nil, nil + return conn, TLSInfo{conn.ConnectionState()}, nil } func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) { diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go index caf35b2f..a5db3867 100644 --- a/credentials/credentials_test.go +++ b/credentials/credentials_test.go @@ -34,7 +34,11 @@ package credentials import ( + "crypto/tls" + "net" "testing" + + "golang.org/x/net/context" ) func TestTLSOverrideServerName(t *testing.T) { @@ -58,4 +62,160 @@ func TestTLSClone(t *testing.T) { if c.Info().ServerName != expectedServerName { t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) } + +} + +const tlsDir = "../test/testdata/" + +type serverHandshake func(net.Conn) (AuthInfo, error) + +func TestClientHandshakeReturnsAuthInfo(t *testing.T) { + done := make(chan AuthInfo, 1) + lis := launchServer(t, tlsServerHandshake, done) + defer lis.Close() + lisAddr := lis.Addr().String() + clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) + // wait until server sends serverAuthInfo or fails. + serverAuthInfo, ok := <-done + if !ok { + t.Fatalf("Error at server-side") + } + if !compare(clientAuthInfo, serverAuthInfo) { + t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) + } +} + +func TestServerHandshakeReturnsAuthInfo(t *testing.T) { + done := make(chan AuthInfo, 1) + lis := launchServer(t, gRPCServerHandshake, done) + defer lis.Close() + clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String()) + // wait until server sends serverAuthInfo or fails. + serverAuthInfo, ok := <-done + if !ok { + t.Fatalf("Error at server-side") + } + if !compare(clientAuthInfo, serverAuthInfo) { + t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo) + } +} + +func TestServerAndClientHandshake(t *testing.T) { + done := make(chan AuthInfo, 1) + lis := launchServer(t, gRPCServerHandshake, done) + defer lis.Close() + clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String()) + // wait until server sends serverAuthInfo or fails. + serverAuthInfo, ok := <-done + if !ok { + t.Fatalf("Error at server-side") + } + if !compare(clientAuthInfo, serverAuthInfo) { + t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo) + } +} + +func compare(a1, a2 AuthInfo) bool { + if a1.AuthType() != a2.AuthType() { + return false + } + switch a1.AuthType() { + case "tls": + state1 := a1.(TLSInfo).State + state2 := a2.(TLSInfo).State + if state1.Version == state2.Version && + state1.HandshakeComplete == state2.HandshakeComplete && + state1.CipherSuite == state2.CipherSuite && + state1.NegotiatedProtocol == state2.NegotiatedProtocol { + return true + } + return false + default: + return false + } +} + +func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + go serverHandle(t, hs, done, lis) + return lis +} + +// Is run in a seperate goroutine. +func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) { + serverRawConn, err := lis.Accept() + if err != nil { + t.Errorf("Server failed to accept connection: %v", err) + close(done) + return + } + serverAuthInfo, err := hs(serverRawConn) + if err != nil { + t.Errorf("Server failed while handshake. Error: %v", err) + close(done) + return + } + done <- serverAuthInfo +} + +func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo { + conn, err := net.Dial("tcp", lisAddr) + if err != nil { + t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err) + } + defer conn.Close() + clientAuthInfo, err := hs(conn, lisAddr) + if err != nil { + t.Fatalf("Error on client while handshake. Error: %v", err) + } + return clientAuthInfo +} + +// Server handshake implementation in gRPC. +func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) { + serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") + if err != nil { + return nil, err + } + _, serverAuthInfo, err := serverTLS.ServerHandshake(conn) + if err != nil { + return nil, err + } + return serverAuthInfo, nil +} + +// Client handshake implementation in gRPC. +func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) { + clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true}) + _, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn) + if err != nil { + return nil, err + } + return authInfo, nil +} + +func tlsServerHandshake(conn net.Conn) (AuthInfo, error) { + cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key") + if err != nil { + return nil, err + } + serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}} + serverConn := tls.Server(conn, serverTLSConfig) + err = serverConn.Handshake() + if err != nil { + return nil, err + } + return TLSInfo{State: serverConn.ConnectionState()}, nil +} + +func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) { + clientTLSConfig := &tls.Config{InsecureSkipVerify: true} + clientConn := tls.Client(conn, clientTLSConfig) + if err := clientConn.Handshake(); err != nil { + return nil, err + } + return TLSInfo{State: clientConn.ConnectionState()}, nil } From cb653e4b6150b81ba5157618b57c6f910a6a99f7 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Mon, 9 Jan 2017 17:11:32 -0800 Subject: [PATCH 2/3] Change stats APIs (#1030) Change stats API from one static handler to one handler per server or client. --- call.go | 25 +++--- clientconn.go | 9 ++ server.go | 58 +++++++------ stats/handlers.go | 106 ++++-------------------- stats/stats_test.go | 169 ++++++++++++++------------------------ stream.go | 39 +++++---- transport/http2_client.go | 23 +++--- transport/http2_server.go | 27 +++--- transport/transport.go | 10 ++- 9 files changed, 193 insertions(+), 273 deletions(-) diff --git a/call.go b/call.go index 4d8023d9..ba177219 100644 --- a/call.go +++ b/call.go @@ -66,7 +66,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran } p := &parser{r: stream} var inPayload *stats.InPayload - if stats.On() { + if dopts.copts.StatsHandler != nil { inPayload = &stats.InPayload{ Client: true, } @@ -82,14 +82,14 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK { // TODO in the current implementation, inTrailer may be handled before inPayload in some cases. // Fix the order if necessary. - stats.HandleRPC(ctx, inPayload) + dopts.copts.StatsHandler.HandleRPC(ctx, inPayload) } c.trailerMD = stream.Trailer() return nil } // sendRequest writes out various information of an RPC such as Context and Message. -func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { +func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { stream, err := t.NewStream(ctx, callHdr) if err != nil { return nil, err @@ -109,19 +109,19 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd if compressor != nil { cbuf = new(bytes.Buffer) } - if stats.On() { + if dopts.copts.StatsHandler != nil { outPayload = &stats.OutPayload{ Client: true, } } - outBuf, err := encode(codec, args, compressor, cbuf, outPayload) + outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload) if err != nil { return nil, Errorf(codes.Internal, "grpc: %v", err) } err = t.Write(stream, outBuf, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() - stats.HandleRPC(ctx, outPayload) + dopts.copts.StatsHandler.HandleRPC(ctx, outPayload) } // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method // does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following @@ -179,23 +179,24 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } }() } - if stats.On() { - ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) + sh := cc.dopts.copts.StatsHandler + if sh != nil { + ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) begin := &stats.Begin{ Client: true, BeginTime: time.Now(), FailFast: c.failFast, } - stats.HandleRPC(ctx, begin) + sh.HandleRPC(ctx, begin) } defer func() { - if stats.On() { + if sh != nil { end := &stats.End{ Client: true, EndTime: time.Now(), Error: e, } - stats.HandleRPC(ctx, end) + sh.HandleRPC(ctx, end) } }() topts := &transport.Options{ @@ -241,7 +242,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) } - stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts) + stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, t, args, topts) if err != nil { if put != nil { put() diff --git a/clientconn.go b/clientconn.go index aa6b63de..146166a7 100644 --- a/clientconn.go +++ b/clientconn.go @@ -45,6 +45,7 @@ import ( "golang.org/x/net/trace" "google.golang.org/grpc/credentials" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/stats" "google.golang.org/grpc/transport" ) @@ -222,6 +223,14 @@ func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { } } +// WithStatsHandler returns a DialOption that specifies the stats handler +// for all the RPCs and underlying network connections in this ClientConn. +func WithStatsHandler(h stats.Handler) DialOption { + return func(o *dialOptions) { + o.copts.StatsHandler = h + } +} + // FailOnNonTempDialError returns a DialOption that specified if gRPC fails on non-temporary dial errors. // If f is true, and dialer returns a non-temporary error, gRPC will fail the connection to the network // address and won't try to reconnect. diff --git a/server.go b/server.go index b52a5630..985226d6 100644 --- a/server.go +++ b/server.go @@ -113,6 +113,7 @@ type options struct { unaryInt UnaryServerInterceptor streamInt StreamServerInterceptor inTapHandle tap.ServerInHandle + statsHandler stats.Handler maxConcurrentStreams uint32 useHandlerImpl bool // use http.Handler-based server } @@ -200,6 +201,13 @@ func InTapHandle(h tap.ServerInHandle) ServerOption { } } +// StatsHandler returns a ServerOption that sets the stats handler for the server. +func StatsHandler(h stats.Handler) ServerOption { + return func(o *options) { + o.statsHandler = h + } +} + // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -438,9 +446,10 @@ func (s *Server) handleRawConn(rawConn net.Conn) { // transport.NewServerTransport). func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { config := &transport.ServerConfig{ - MaxStreams: s.opts.maxConcurrentStreams, - AuthInfo: authInfo, - InTapHandle: s.opts.inTapHandle, + MaxStreams: s.opts.maxConcurrentStreams, + AuthInfo: authInfo, + InTapHandle: s.opts.inTapHandle, + StatsHandler: s.opts.statsHandler, } st, err := transport.NewServerTransport("http2", c, config) if err != nil { @@ -567,7 +576,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str if cp != nil { cbuf = new(bytes.Buffer) } - if stats.On() { + if s.opts.statsHandler != nil { outPayload = &stats.OutPayload{} } p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload) @@ -584,27 +593,28 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str err = t.Write(stream, p, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() - stats.HandleRPC(stream.Context(), outPayload) + s.opts.statsHandler.HandleRPC(stream.Context(), outPayload) } return err } func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { - if stats.On() { + sh := s.opts.statsHandler + if sh != nil { begin := &stats.Begin{ BeginTime: time.Now(), } - stats.HandleRPC(stream.Context(), begin) + sh.HandleRPC(stream.Context(), begin) } defer func() { - if stats.On() { + if sh != nil { end := &stats.End{ EndTime: time.Now(), } if err != nil && err != io.EOF { end.Error = toRPCErr(err) } - stats.HandleRPC(stream.Context(), end) + sh.HandleRPC(stream.Context(), end) } }() if trInfo != nil { @@ -665,7 +675,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } } var inPayload *stats.InPayload - if stats.On() { + if sh != nil { inPayload = &stats.InPayload{ RecvTime: time.Now(), } @@ -699,7 +709,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. inPayload.Payload = v inPayload.Data = req inPayload.Length = len(req) - stats.HandleRPC(stream.Context(), inPayload) + sh.HandleRPC(stream.Context(), inPayload) } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) @@ -756,35 +766,37 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) { - if stats.On() { + sh := s.opts.statsHandler + if sh != nil { begin := &stats.Begin{ BeginTime: time.Now(), } - stats.HandleRPC(stream.Context(), begin) + sh.HandleRPC(stream.Context(), begin) } defer func() { - if stats.On() { + if sh != nil { end := &stats.End{ EndTime: time.Now(), } if err != nil && err != io.EOF { end.Error = toRPCErr(err) } - stats.HandleRPC(stream.Context(), end) + sh.HandleRPC(stream.Context(), end) } }() if s.opts.cp != nil { stream.SetSendCompress(s.opts.cp.Type()) } ss := &serverStream{ - t: t, - s: stream, - p: &parser{r: stream}, - codec: s.opts.codec, - cp: s.opts.cp, - dc: s.opts.dc, - maxMsgSize: s.opts.maxMsgSize, - trInfo: trInfo, + t: t, + s: stream, + p: &parser{r: stream}, + codec: s.opts.codec, + cp: s.opts.cp, + dc: s.opts.dc, + maxMsgSize: s.opts.maxMsgSize, + trInfo: trInfo, + statsHandler: sh, } if ss.cp != nil { ss.cbuf = new(bytes.Buffer) diff --git a/stats/handlers.go b/stats/handlers.go index ce47786d..26e1a8e2 100644 --- a/stats/handlers.go +++ b/stats/handlers.go @@ -35,10 +35,8 @@ package stats import ( "net" - "sync/atomic" "golang.org/x/net/context" - "google.golang.org/grpc/grpclog" ) // ConnTagInfo defines the relevant information needed by connection context tagger. @@ -56,91 +54,23 @@ type RPCTagInfo struct { FullMethodName string } -var ( - on = new(int32) - rpcHandler func(context.Context, RPCStats) - connHandler func(context.Context, ConnStats) - connTagger func(context.Context, *ConnTagInfo) context.Context - rpcTagger func(context.Context, *RPCTagInfo) context.Context -) +// Handler defines the interface for the related stats handling (e.g., RPCs, connections). +type Handler interface { + // TagRPC can attach some information to the given context. + // The returned context is used in the rest lifetime of the RPC. + TagRPC(context.Context, *RPCTagInfo) context.Context + // HandleRPC processes the RPC stats. + HandleRPC(context.Context, RPCStats) -// HandleRPC processes the RPC stats using the rpc handler registered by the user. -func HandleRPC(ctx context.Context, s RPCStats) { - if rpcHandler == nil { - return - } - rpcHandler(ctx, s) -} - -// RegisterRPCHandler registers the user handler function for RPC stats processing. -// It should be called only once. The later call will overwrite the former value if it is called multiple times. -// This handler function will be called to process the rpc stats. -func RegisterRPCHandler(f func(context.Context, RPCStats)) { - rpcHandler = f -} - -// HandleConn processes the stats using the call back function registered by user. -func HandleConn(ctx context.Context, s ConnStats) { - if connHandler == nil { - return - } - connHandler(ctx, s) -} - -// RegisterConnHandler registers the user handler function for conn stats. -// It should be called only once. The later call will overwrite the former value if it is called multiple times. -// This handler function will be called to process the conn stats. -func RegisterConnHandler(f func(context.Context, ConnStats)) { - connHandler = f -} - -// TagConn calls user registered connection context tagger. -func TagConn(ctx context.Context, info *ConnTagInfo) context.Context { - if connTagger == nil { - return ctx - } - return connTagger(ctx, info) -} - -// RegisterConnTagger registers the user connection context tagger function. -// The connection context tagger can attach some information to the given context. -// The returned context will be used for stats handling. -// For conn stats handling, the context used in connHandler for this -// connection will be derived from the context returned. -// For RPC stats handling, -// - On server side, the context used in rpcHandler for all RPCs on this -// connection will be derived from the context returned. -// - On client side, the context is not derived from the context returned. -func RegisterConnTagger(t func(context.Context, *ConnTagInfo) context.Context) { - connTagger = t -} - -// TagRPC calls the user registered RPC context tagger. -func TagRPC(ctx context.Context, info *RPCTagInfo) context.Context { - if rpcTagger == nil { - return ctx - } - return rpcTagger(ctx, info) -} - -// RegisterRPCTagger registers the user RPC context tagger function. -// The RPC context tagger can attach some information to the given context. -// The context used in stats rpcHandler for this RPC will be derived from the -// context returned. -func RegisterRPCTagger(t func(context.Context, *RPCTagInfo) context.Context) { - rpcTagger = t -} - -// Start starts the stats collection and processing if there is a registered stats handle. -func Start() { - if rpcHandler == nil && connHandler == nil { - grpclog.Println("rpcHandler and connHandler are both nil when starting stats. Stats is not started") - return - } - atomic.StoreInt32(on, 1) -} - -// On indicates whether the stats collection and processing is on. -func On() bool { - return atomic.CompareAndSwapInt32(on, 1, 1) + // TagConn can attach some information to the given context. + // The returned context will be used for stats handling. + // For conn stats handling, the context used in HandleConn for this + // connection will be derived from the context returned. + // For RPC stats handling, + // - On server side, the context used in HandleRPC for all RPCs on this + // connection will be derived from the context returned. + // - On client side, the context is not derived from the context returned. + TagConn(context.Context, *ConnTagInfo) context.Context + // HandleConn processes the Conn stats. + HandleConn(context.Context, ConnStats) } diff --git a/stats/stats_test.go b/stats/stats_test.go index a1e116de..3e5424be 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -57,40 +57,6 @@ func init() { type connCtxKey struct{} type rpcCtxKey struct{} -func TestTagConnCtx(t *testing.T) { - defer stats.RegisterConnTagger(nil) - ctx1 := context.Background() - stats.RegisterConnTagger(nil) - ctx2 := stats.TagConn(ctx1, nil) - if ctx2 != ctx1 { - t.Fatalf("nil conn ctx tagger should not modify context, got %v; want %v", ctx2, ctx1) - } - stats.RegisterConnTagger(func(ctx context.Context, info *stats.ConnTagInfo) context.Context { - return context.WithValue(ctx, connCtxKey{}, "connctxvalue") - }) - ctx3 := stats.TagConn(ctx1, nil) - if v, ok := ctx3.Value(connCtxKey{}).(string); !ok || v != "connctxvalue" { - t.Fatalf("got context %v; want %v", ctx3, context.WithValue(ctx1, connCtxKey{}, "connctxvalue")) - } -} - -func TestTagRPCCtx(t *testing.T) { - defer stats.RegisterRPCTagger(nil) - ctx1 := context.Background() - stats.RegisterRPCTagger(nil) - ctx2 := stats.TagRPC(ctx1, nil) - if ctx2 != ctx1 { - t.Fatalf("nil rpc ctx tagger should not modify context, got %v; want %v", ctx2, ctx1) - } - stats.RegisterRPCTagger(func(ctx context.Context, info *stats.RPCTagInfo) context.Context { - return context.WithValue(ctx, rpcCtxKey{}, "rpcctxvalue") - }) - ctx3 := stats.TagRPC(ctx1, nil) - if v, ok := ctx3.Value(rpcCtxKey{}).(string); !ok || v != "rpcctxvalue" { - t.Fatalf("got context %v; want %v", ctx3, context.WithValue(ctx1, rpcCtxKey{}, "rpcctxvalue")) - } -} - var ( // For headers: testMetadata = metadata.MD{ @@ -158,8 +124,10 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ // func, modified as needed, and then started with its startServer method. // It should be cleaned up with the tearDown method. type test struct { - t *testing.T - compress string + t *testing.T + compress string + clientStatsHandler stats.Handler + serverStatsHandler stats.Handler testServer testpb.TestServiceServer // nil means none // srv and srvAddr are set once startServer is called. @@ -184,8 +152,13 @@ type testConfig struct { // newTest returns a new test using the provided testing.T and // environment. It is returned with default values. Tests should // modify it before calling its startServer and clientConn methods. -func newTest(t *testing.T, tc *testConfig) *test { - te := &test{t: t, compress: tc.compress} +func newTest(t *testing.T, tc *testConfig, ch stats.Handler, sh stats.Handler) *test { + te := &test{ + t: t, + compress: tc.compress, + clientStatsHandler: ch, + serverStatsHandler: sh, + } return te } @@ -204,6 +177,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) { grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), ) } + if te.serverStatsHandler != nil { + opts = append(opts, grpc.StatsHandler(te.serverStatsHandler)) + } s := grpc.NewServer(opts...) te.srv = s if te.testServer != nil { @@ -230,6 +206,9 @@ func (te *test) clientConn() *grpc.ClientConn { grpc.WithDecompressor(grpc.NewGZIPDecompressor()), ) } + if te.clientStatsHandler != nil { + opts = append(opts, grpc.WithStatsHandler(te.clientStatsHandler)) + } var err error te.cc, err = grpc.Dial(te.srvAddr, opts...) @@ -617,14 +596,32 @@ func checkConnEnd(t *testing.T, d *gotData, e *expectedData) { st.IsClient() // TODO remove this. } -func tagConnCtx(ctx context.Context, info *stats.ConnTagInfo) context.Context { +type statshandler struct { + mu sync.Mutex + gotRPC []*gotData + gotConn []*gotData +} + +func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { return context.WithValue(ctx, connCtxKey{}, info) } -func tagRPCCtx(ctx context.Context, info *stats.RPCTagInfo) context.Context { +func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { return context.WithValue(ctx, rpcCtxKey{}, info) } +func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) { + h.mu.Lock() + defer h.mu.Unlock() + h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s}) +} + +func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) { + h.mu.Lock() + defer h.mu.Unlock() + h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s}) +} + func checkConnStats(t *testing.T, got []*gotData) { if len(got) <= 0 || len(got)%2 != 0 { for i, g := range got { @@ -662,30 +659,8 @@ func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkF } func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) { - var ( - mu sync.Mutex - gotRPC []*gotData - gotConn []*gotData - ) - stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) { - mu.Lock() - defer mu.Unlock() - if !s.IsClient() { - gotRPC = append(gotRPC, &gotData{ctx, false, s}) - } - }) - stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) { - mu.Lock() - defer mu.Unlock() - if !s.IsClient() { - gotConn = append(gotConn, &gotData{ctx, false, s}) - } - }) - stats.RegisterConnTagger(tagConnCtx) - stats.RegisterRPCTagger(tagRPCCtx) - stats.Start() - - te := newTest(t, tc) + h := &statshandler{} + te := newTest(t, tc, nil, h) te.startServer(&testServer{}) defer te.tearDown() @@ -709,22 +684,22 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f te.srv.GracefulStop() // Wait for the server to stop. for { - mu.Lock() - if len(gotRPC) >= len(checkFuncs) { - mu.Unlock() + h.mu.Lock() + if len(h.gotRPC) >= len(checkFuncs) { + h.mu.Unlock() break } - mu.Unlock() + h.mu.Unlock() time.Sleep(10 * time.Millisecond) } for { - mu.Lock() - if _, ok := gotConn[len(gotConn)-1].s.(*stats.ConnEnd); ok { - mu.Unlock() + h.mu.Lock() + if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { + h.mu.Unlock() break } - mu.Unlock() + h.mu.Unlock() time.Sleep(10 * time.Millisecond) } @@ -741,8 +716,8 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f expect.method = "/grpc.testing.TestService/FullDuplexCall" } - checkConnStats(t, gotConn) - checkServerStats(t, gotRPC, expect, checkFuncs) + checkConnStats(t, h.gotConn) + checkServerStats(t, h.gotRPC, expect, checkFuncs) } func TestServerStatsUnaryRPC(t *testing.T) { @@ -891,30 +866,8 @@ func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkF } func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) { - var ( - mu sync.Mutex - gotRPC []*gotData - gotConn []*gotData - ) - stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) { - mu.Lock() - defer mu.Unlock() - if s.IsClient() { - gotRPC = append(gotRPC, &gotData{ctx, true, s}) - } - }) - stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) { - mu.Lock() - defer mu.Unlock() - if s.IsClient() { - gotConn = append(gotConn, &gotData{ctx, true, s}) - } - }) - stats.RegisterConnTagger(tagConnCtx) - stats.RegisterRPCTagger(tagRPCCtx) - stats.Start() - - te := newTest(t, tc) + h := &statshandler{} + te := newTest(t, tc, h, nil) te.startServer(&testServer{}) defer te.tearDown() @@ -942,22 +895,22 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map lenRPCStats += v.c } for { - mu.Lock() - if len(gotRPC) >= lenRPCStats { - mu.Unlock() + h.mu.Lock() + if len(h.gotRPC) >= lenRPCStats { + h.mu.Unlock() break } - mu.Unlock() + h.mu.Unlock() time.Sleep(10 * time.Millisecond) } for { - mu.Lock() - if _, ok := gotConn[len(gotConn)-1].s.(*stats.ConnEnd); ok { - mu.Unlock() + h.mu.Lock() + if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { + h.mu.Unlock() break } - mu.Unlock() + h.mu.Unlock() time.Sleep(10 * time.Millisecond) } @@ -975,8 +928,8 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map expect.method = "/grpc.testing.TestService/FullDuplexCall" } - checkConnStats(t, gotConn) - checkClientStats(t, gotRPC, expect, checkFuncs) + checkConnStats(t, h.gotConn) + checkClientStats(t, h.gotRPC, expect, checkFuncs) } func TestClientStatsUnaryRPC(t *testing.T) { diff --git a/stream.go b/stream.go index d3a4debf..bb468dc3 100644 --- a/stream.go +++ b/stream.go @@ -151,23 +151,24 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } }() } - if stats.On() { - ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) + sh := cc.dopts.copts.StatsHandler + if sh != nil { + ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) begin := &stats.Begin{ Client: true, BeginTime: time.Now(), FailFast: c.failFast, } - stats.HandleRPC(ctx, begin) + sh.HandleRPC(ctx, begin) } defer func() { - if err != nil && stats.On() { + if err != nil && sh != nil { // Only handle end stats if err != nil. end := &stats.End{ Client: true, Error: err, } - stats.HandleRPC(ctx, end) + sh.HandleRPC(ctx, end) } }() gopts := BalancerGetOptions{ @@ -223,7 +224,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth tracing: EnableTracing, trInfo: trInfo, - statsCtx: ctx, + statsCtx: ctx, + statsHandler: cc.dopts.copts.StatsHandler, } if cc.dopts.cp != nil { cs.cbuf = new(bytes.Buffer) @@ -281,7 +283,8 @@ type clientStream struct { // statsCtx keeps the user context for stats handling. // All stats collection should use the statsCtx (instead of the stream context) // so that all the generated stats for a particular RPC can be associated in the processing phase. - statsCtx context.Context + statsCtx context.Context + statsHandler stats.Handler } func (cs *clientStream) Context() context.Context { @@ -335,7 +338,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { err = toRPCErr(err) }() var outPayload *stats.OutPayload - if stats.On() { + if cs.statsHandler != nil { outPayload = &stats.OutPayload{ Client: true, } @@ -352,14 +355,14 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() - stats.HandleRPC(cs.statsCtx, outPayload) + cs.statsHandler.HandleRPC(cs.statsCtx, outPayload) } return err } func (cs *clientStream) RecvMsg(m interface{}) (err error) { defer func() { - if err != nil && stats.On() { + if err != nil && cs.statsHandler != nil { // Only generate End if err != nil. // If err == nil, it's not the last RecvMsg. // The last RecvMsg gets either an RPC error or io.EOF. @@ -370,11 +373,11 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { if err != io.EOF { end.Error = toRPCErr(err) } - stats.HandleRPC(cs.statsCtx, end) + cs.statsHandler.HandleRPC(cs.statsCtx, end) } }() var inPayload *stats.InPayload - if stats.On() { + if cs.statsHandler != nil { inPayload = &stats.InPayload{ Client: true, } @@ -395,7 +398,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { cs.mu.Unlock() } if inPayload != nil { - stats.HandleRPC(cs.statsCtx, inPayload) + cs.statsHandler.HandleRPC(cs.statsCtx, inPayload) } if !cs.desc.ClientStreams || cs.desc.ServerStreams { return @@ -520,6 +523,8 @@ type serverStream struct { statusDesc string trInfo *traceInfo + statsHandler stats.Handler + mu sync.Mutex // protects trInfo.tr after the service handler runs. } @@ -562,7 +567,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { } }() var outPayload *stats.OutPayload - if stats.On() { + if ss.statsHandler != nil { outPayload = &stats.OutPayload{} } out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload) @@ -580,7 +585,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { } if outPayload != nil { outPayload.SentTime = time.Now() - stats.HandleRPC(ss.s.Context(), outPayload) + ss.statsHandler.HandleRPC(ss.s.Context(), outPayload) } return nil } @@ -601,7 +606,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { } }() var inPayload *stats.InPayload - if stats.On() { + if ss.statsHandler != nil { inPayload = &stats.InPayload{} } if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil { @@ -614,7 +619,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { return toRPCErr(err) } if inPayload != nil { - stats.HandleRPC(ss.s.Context(), inPayload) + ss.statsHandler.HandleRPC(ss.s.Context(), inPayload) } return nil } diff --git a/transport/http2_client.go b/transport/http2_client.go index 605b1e5a..892f8ba6 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -99,6 +99,8 @@ type http2Client struct { creds []credentials.PerRPCCredentials + statsHandler stats.Handler + mu sync.Mutex // guard the following variables state transportState // the state of underlying connection activeStreams map[uint32]*Stream @@ -208,16 +210,17 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( creds: opts.PerRPCCredentials, maxStreams: math.MaxInt32, streamSendQuota: defaultWindowSize, + statsHandler: opts.StatsHandler, } - if stats.On() { - t.ctx = stats.TagConn(t.ctx, &stats.ConnTagInfo{ + if t.statsHandler != nil { + t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{ RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, }) connBegin := &stats.ConnBegin{ Client: true, } - stats.HandleConn(t.ctx, connBegin) + t.statsHandler.HandleConn(t.ctx, connBegin) } // Start the reader goroutine for incoming message. Each transport has // a dedicated goroutine which reads HTTP2 frame from network. Then it @@ -470,7 +473,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea return nil, connectionErrorf(true, err, "transport: %v", err) } } - if stats.On() { + if t.statsHandler != nil { outHeader := &stats.OutHeader{ Client: true, WireLength: bufLen, @@ -479,7 +482,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea LocalAddr: t.localAddr, Compression: callHdr.SendCompress, } - stats.HandleRPC(s.clientStatsCtx, outHeader) + t.statsHandler.HandleRPC(s.clientStatsCtx, outHeader) } t.writableChan <- 0 return s, nil @@ -559,11 +562,11 @@ func (t *http2Client) Close() (err error) { s.mu.Unlock() s.write(recvMsg{err: ErrConnClosing}) } - if stats.On() { + if t.statsHandler != nil { connEnd := &stats.ConnEnd{ Client: true, } - stats.HandleConn(t.ctx, connEnd) + t.statsHandler.HandleConn(t.ctx, connEnd) } return } @@ -911,19 +914,19 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { endStream := frame.StreamEnded() var isHeader bool defer func() { - if stats.On() { + if t.statsHandler != nil { if isHeader { inHeader := &stats.InHeader{ Client: true, WireLength: int(frame.Header().Length), } - stats.HandleRPC(s.clientStatsCtx, inHeader) + t.statsHandler.HandleRPC(s.clientStatsCtx, inHeader) } else { inTrailer := &stats.InTrailer{ Client: true, WireLength: int(frame.Header().Length), } - stats.HandleRPC(s.clientStatsCtx, inTrailer) + t.statsHandler.HandleRPC(s.clientStatsCtx, inTrailer) } } }() diff --git a/transport/http2_server.go b/transport/http2_server.go index 316188e7..a095dd0e 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -88,6 +88,8 @@ type http2Server struct { // sendQuotaPool provides flow control to outbound message. sendQuotaPool *quotaPool + stats stats.Handler + mu sync.Mutex // guard the following state transportState activeStreams map[uint32]*Stream @@ -146,14 +148,15 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err shutdownChan: make(chan struct{}), activeStreams: make(map[uint32]*Stream), streamSendQuota: defaultWindowSize, + stats: config.StatsHandler, } - if stats.On() { - t.ctx = stats.TagConn(t.ctx, &stats.ConnTagInfo{ + if t.stats != nil { + t.ctx = t.stats.TagConn(t.ctx, &stats.ConnTagInfo{ RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, }) connBegin := &stats.ConnBegin{} - stats.HandleConn(t.ctx, connBegin) + t.stats.HandleConn(t.ctx, connBegin) } go t.controller() t.writableChan <- 0 @@ -250,8 +253,8 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( t.updateWindow(s, uint32(n)) } s.ctx = traceCtx(s.ctx, s.method) - if stats.On() { - s.ctx = stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + if t.stats != nil { + s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) inHeader := &stats.InHeader{ FullMethod: s.method, RemoteAddr: t.remoteAddr, @@ -259,7 +262,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( Compression: s.recvCompress, WireLength: int(frame.Header().Length), } - stats.HandleRPC(s.ctx, inHeader) + t.stats.HandleRPC(s.ctx, inHeader) } handle(s) return @@ -540,11 +543,11 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { if err := t.writeHeaders(s, t.hBuf, false); err != nil { return err } - if stats.On() { + if t.stats != nil { outHeader := &stats.OutHeader{ WireLength: bufLen, } - stats.HandleRPC(s.Context(), outHeader) + t.stats.HandleRPC(s.Context(), outHeader) } t.writableChan <- 0 return nil @@ -603,11 +606,11 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s t.Close() return err } - if stats.On() { + if t.stats != nil { outTrailer := &stats.OutTrailer{ WireLength: bufLen, } - stats.HandleRPC(s.Context(), outTrailer) + t.stats.HandleRPC(s.Context(), outTrailer) } t.closeStream(s) t.writableChan <- 0 @@ -789,9 +792,9 @@ func (t *http2Server) Close() (err error) { for _, s := range streams { s.cancel() } - if stats.On() { + if t.stats != nil { connEnd := &stats.ConnEnd{} - stats.HandleConn(t.ctx, connEnd) + t.stats.HandleConn(t.ctx, connEnd) } return } diff --git a/transport/transport.go b/transport/transport.go index 4726bb2c..d4659918 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -48,6 +48,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" "google.golang.org/grpc/tap" ) @@ -357,9 +358,10 @@ const ( // ServerConfig consists of all the configurations to establish a server transport. type ServerConfig struct { - MaxStreams uint32 - AuthInfo credentials.AuthInfo - InTapHandle tap.ServerInHandle + MaxStreams uint32 + AuthInfo credentials.AuthInfo + InTapHandle tap.ServerInHandle + StatsHandler stats.Handler } // NewServerTransport creates a ServerTransport with conn or non-nil error @@ -380,6 +382,8 @@ type ConnectOptions struct { PerRPCCredentials []credentials.PerRPCCredentials // TransportCredentials stores the Authenticator required to setup a client connection. TransportCredentials credentials.TransportCredentials + // StatsHandler stores the handler for stats. + StatsHandler stats.Handler } // TargetInfo contains the information of the target such as network address and metadata. From 50955793b0183f9de69bd78e2ec251cf20aab121 Mon Sep 17 00:00:00 2001 From: MakMukhi Date: Wed, 11 Jan 2017 11:10:52 -0800 Subject: [PATCH 3/3] Debugging tests for AuthInfo (#1046) * debug * fix --- credentials/credentials_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go index a5db3867..1609374c 100644 --- a/credentials/credentials_test.go +++ b/credentials/credentials_test.go @@ -155,6 +155,7 @@ func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net. serverAuthInfo, err := hs(serverRawConn) if err != nil { t.Errorf("Server failed while handshake. Error: %v", err) + serverRawConn.Close() close(done) return }