Merge remote-tracking branch 'upstream/master' into service_config_doc_fix
This commit is contained in:
25
call.go
25
call.go
@ -66,7 +66,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
|
|||||||
}
|
}
|
||||||
p := &parser{r: stream}
|
p := &parser{r: stream}
|
||||||
var inPayload *stats.InPayload
|
var inPayload *stats.InPayload
|
||||||
if stats.On() {
|
if dopts.copts.StatsHandler != nil {
|
||||||
inPayload = &stats.InPayload{
|
inPayload = &stats.InPayload{
|
||||||
Client: true,
|
Client: true,
|
||||||
}
|
}
|
||||||
@ -82,14 +82,14 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
|
|||||||
if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK {
|
if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK {
|
||||||
// TODO in the current implementation, inTrailer may be handled before inPayload in some cases.
|
// TODO in the current implementation, inTrailer may be handled before inPayload in some cases.
|
||||||
// Fix the order if necessary.
|
// Fix the order if necessary.
|
||||||
stats.HandleRPC(ctx, inPayload)
|
dopts.copts.StatsHandler.HandleRPC(ctx, inPayload)
|
||||||
}
|
}
|
||||||
c.trailerMD = stream.Trailer()
|
c.trailerMD = stream.Trailer()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendRequest writes out various information of an RPC such as Context and Message.
|
// sendRequest writes out various information of an RPC such as Context and Message.
|
||||||
func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
|
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
|
||||||
stream, err := t.NewStream(ctx, callHdr)
|
stream, err := t.NewStream(ctx, callHdr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -109,19 +109,19 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
|
|||||||
if compressor != nil {
|
if compressor != nil {
|
||||||
cbuf = new(bytes.Buffer)
|
cbuf = new(bytes.Buffer)
|
||||||
}
|
}
|
||||||
if stats.On() {
|
if dopts.copts.StatsHandler != nil {
|
||||||
outPayload = &stats.OutPayload{
|
outPayload = &stats.OutPayload{
|
||||||
Client: true,
|
Client: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
outBuf, err := encode(codec, args, compressor, cbuf, outPayload)
|
outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, Errorf(codes.Internal, "grpc: %v", err)
|
return nil, Errorf(codes.Internal, "grpc: %v", err)
|
||||||
}
|
}
|
||||||
err = t.Write(stream, outBuf, opts)
|
err = t.Write(stream, outBuf, opts)
|
||||||
if err == nil && outPayload != nil {
|
if err == nil && outPayload != nil {
|
||||||
outPayload.SentTime = time.Now()
|
outPayload.SentTime = time.Now()
|
||||||
stats.HandleRPC(ctx, outPayload)
|
dopts.copts.StatsHandler.HandleRPC(ctx, outPayload)
|
||||||
}
|
}
|
||||||
// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
|
// t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method
|
||||||
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
|
// does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following
|
||||||
@ -179,23 +179,24 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
if stats.On() {
|
sh := cc.dopts.copts.StatsHandler
|
||||||
ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
|
if sh != nil {
|
||||||
|
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
|
||||||
begin := &stats.Begin{
|
begin := &stats.Begin{
|
||||||
Client: true,
|
Client: true,
|
||||||
BeginTime: time.Now(),
|
BeginTime: time.Now(),
|
||||||
FailFast: c.failFast,
|
FailFast: c.failFast,
|
||||||
}
|
}
|
||||||
stats.HandleRPC(ctx, begin)
|
sh.HandleRPC(ctx, begin)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if stats.On() {
|
if sh != nil {
|
||||||
end := &stats.End{
|
end := &stats.End{
|
||||||
Client: true,
|
Client: true,
|
||||||
EndTime: time.Now(),
|
EndTime: time.Now(),
|
||||||
Error: e,
|
Error: e,
|
||||||
}
|
}
|
||||||
stats.HandleRPC(ctx, end)
|
sh.HandleRPC(ctx, end)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
topts := &transport.Options{
|
topts := &transport.Options{
|
||||||
@ -241,7 +242,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||||||
if c.traceInfo.tr != nil {
|
if c.traceInfo.tr != nil {
|
||||||
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
|
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
|
||||||
}
|
}
|
||||||
stream, err = sendRequest(ctx, cc.dopts.codec, cc.dopts.cp, callHdr, t, args, topts)
|
stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, t, args, topts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if put != nil {
|
if put != nil {
|
||||||
put()
|
put()
|
||||||
|
@ -45,6 +45,7 @@ import (
|
|||||||
"golang.org/x/net/trace"
|
"golang.org/x/net/trace"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/grpclog"
|
"google.golang.org/grpc/grpclog"
|
||||||
|
"google.golang.org/grpc/stats"
|
||||||
"google.golang.org/grpc/transport"
|
"google.golang.org/grpc/transport"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -222,6 +223,14 @@ func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithStatsHandler returns a DialOption that specifies the stats handler
|
||||||
|
// for all the RPCs and underlying network connections in this ClientConn.
|
||||||
|
func WithStatsHandler(h stats.Handler) DialOption {
|
||||||
|
return func(o *dialOptions) {
|
||||||
|
o.copts.StatsHandler = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// FailOnNonTempDialError returns a DialOption that specified if gRPC fails on non-temporary dial errors.
|
// FailOnNonTempDialError returns a DialOption that specified if gRPC fails on non-temporary dial errors.
|
||||||
// If f is true, and dialer returns a non-temporary error, gRPC will fail the connection to the network
|
// If f is true, and dialer returns a non-temporary error, gRPC will fail the connection to the network
|
||||||
// address and won't try to reconnect.
|
// address and won't try to reconnect.
|
||||||
|
@ -165,9 +165,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, nil, ctx.Err()
|
return nil, nil, ctx.Err()
|
||||||
}
|
}
|
||||||
// TODO(zhaoq): Omit the auth info for client now. It is more for
|
return conn, TLSInfo{conn.ConnectionState()}, nil
|
||||||
// information than anything else.
|
|
||||||
return conn, nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
|
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) {
|
||||||
|
@ -34,7 +34,11 @@
|
|||||||
package credentials
|
package credentials
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTLSOverrideServerName(t *testing.T) {
|
func TestTLSOverrideServerName(t *testing.T) {
|
||||||
@ -58,4 +62,161 @@ func TestTLSClone(t *testing.T) {
|
|||||||
if c.Info().ServerName != expectedServerName {
|
if c.Info().ServerName != expectedServerName {
|
||||||
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
|
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
const tlsDir = "../test/testdata/"
|
||||||
|
|
||||||
|
type serverHandshake func(net.Conn) (AuthInfo, error)
|
||||||
|
|
||||||
|
func TestClientHandshakeReturnsAuthInfo(t *testing.T) {
|
||||||
|
done := make(chan AuthInfo, 1)
|
||||||
|
lis := launchServer(t, tlsServerHandshake, done)
|
||||||
|
defer lis.Close()
|
||||||
|
lisAddr := lis.Addr().String()
|
||||||
|
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
|
||||||
|
// wait until server sends serverAuthInfo or fails.
|
||||||
|
serverAuthInfo, ok := <-done
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Error at server-side")
|
||||||
|
}
|
||||||
|
if !compare(clientAuthInfo, serverAuthInfo) {
|
||||||
|
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerHandshakeReturnsAuthInfo(t *testing.T) {
|
||||||
|
done := make(chan AuthInfo, 1)
|
||||||
|
lis := launchServer(t, gRPCServerHandshake, done)
|
||||||
|
defer lis.Close()
|
||||||
|
clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())
|
||||||
|
// wait until server sends serverAuthInfo or fails.
|
||||||
|
serverAuthInfo, ok := <-done
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Error at server-side")
|
||||||
|
}
|
||||||
|
if !compare(clientAuthInfo, serverAuthInfo) {
|
||||||
|
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerAndClientHandshake(t *testing.T) {
|
||||||
|
done := make(chan AuthInfo, 1)
|
||||||
|
lis := launchServer(t, gRPCServerHandshake, done)
|
||||||
|
defer lis.Close()
|
||||||
|
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())
|
||||||
|
// wait until server sends serverAuthInfo or fails.
|
||||||
|
serverAuthInfo, ok := <-done
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Error at server-side")
|
||||||
|
}
|
||||||
|
if !compare(clientAuthInfo, serverAuthInfo) {
|
||||||
|
t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compare(a1, a2 AuthInfo) bool {
|
||||||
|
if a1.AuthType() != a2.AuthType() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch a1.AuthType() {
|
||||||
|
case "tls":
|
||||||
|
state1 := a1.(TLSInfo).State
|
||||||
|
state2 := a2.(TLSInfo).State
|
||||||
|
if state1.Version == state2.Version &&
|
||||||
|
state1.HandshakeComplete == state2.HandshakeComplete &&
|
||||||
|
state1.CipherSuite == state2.CipherSuite &&
|
||||||
|
state1.NegotiatedProtocol == state2.NegotiatedProtocol {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener {
|
||||||
|
lis, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to listen: %v", err)
|
||||||
|
}
|
||||||
|
go serverHandle(t, hs, done, lis)
|
||||||
|
return lis
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is run in a seperate goroutine.
|
||||||
|
func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) {
|
||||||
|
serverRawConn, err := lis.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Server failed to accept connection: %v", err)
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serverAuthInfo, err := hs(serverRawConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Server failed while handshake. Error: %v", err)
|
||||||
|
serverRawConn.Close()
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
done <- serverAuthInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo {
|
||||||
|
conn, err := net.Dial("tcp", lisAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
clientAuthInfo, err := hs(conn, lisAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error on client while handshake. Error: %v", err)
|
||||||
|
}
|
||||||
|
return clientAuthInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server handshake implementation in gRPC.
|
||||||
|
func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) {
|
||||||
|
serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return serverAuthInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client handshake implementation in gRPC.
|
||||||
|
func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) {
|
||||||
|
clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true})
|
||||||
|
_, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return authInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func tlsServerHandshake(conn net.Conn) (AuthInfo, error) {
|
||||||
|
cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||||
|
serverConn := tls.Server(conn, serverTLSConfig)
|
||||||
|
err = serverConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return TLSInfo{State: serverConn.ConnectionState()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) {
|
||||||
|
clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
|
||||||
|
clientConn := tls.Client(conn, clientTLSConfig)
|
||||||
|
if err := clientConn.Handshake(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return TLSInfo{State: clientConn.ConnectionState()}, nil
|
||||||
}
|
}
|
||||||
|
36
server.go
36
server.go
@ -113,6 +113,7 @@ type options struct {
|
|||||||
unaryInt UnaryServerInterceptor
|
unaryInt UnaryServerInterceptor
|
||||||
streamInt StreamServerInterceptor
|
streamInt StreamServerInterceptor
|
||||||
inTapHandle tap.ServerInHandle
|
inTapHandle tap.ServerInHandle
|
||||||
|
statsHandler stats.Handler
|
||||||
maxConcurrentStreams uint32
|
maxConcurrentStreams uint32
|
||||||
useHandlerImpl bool // use http.Handler-based server
|
useHandlerImpl bool // use http.Handler-based server
|
||||||
}
|
}
|
||||||
@ -200,6 +201,13 @@ func InTapHandle(h tap.ServerInHandle) ServerOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StatsHandler returns a ServerOption that sets the stats handler for the server.
|
||||||
|
func StatsHandler(h stats.Handler) ServerOption {
|
||||||
|
return func(o *options) {
|
||||||
|
o.statsHandler = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewServer creates a gRPC server which has no service registered and has not
|
// NewServer creates a gRPC server which has no service registered and has not
|
||||||
// started to accept requests yet.
|
// started to accept requests yet.
|
||||||
func NewServer(opt ...ServerOption) *Server {
|
func NewServer(opt ...ServerOption) *Server {
|
||||||
@ -441,6 +449,7 @@ func (s *Server) serveHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo)
|
|||||||
MaxStreams: s.opts.maxConcurrentStreams,
|
MaxStreams: s.opts.maxConcurrentStreams,
|
||||||
AuthInfo: authInfo,
|
AuthInfo: authInfo,
|
||||||
InTapHandle: s.opts.inTapHandle,
|
InTapHandle: s.opts.inTapHandle,
|
||||||
|
StatsHandler: s.opts.statsHandler,
|
||||||
}
|
}
|
||||||
st, err := transport.NewServerTransport("http2", c, config)
|
st, err := transport.NewServerTransport("http2", c, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -567,7 +576,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
|
|||||||
if cp != nil {
|
if cp != nil {
|
||||||
cbuf = new(bytes.Buffer)
|
cbuf = new(bytes.Buffer)
|
||||||
}
|
}
|
||||||
if stats.On() {
|
if s.opts.statsHandler != nil {
|
||||||
outPayload = &stats.OutPayload{}
|
outPayload = &stats.OutPayload{}
|
||||||
}
|
}
|
||||||
p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
|
p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
|
||||||
@ -584,27 +593,28 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
|
|||||||
err = t.Write(stream, p, opts)
|
err = t.Write(stream, p, opts)
|
||||||
if err == nil && outPayload != nil {
|
if err == nil && outPayload != nil {
|
||||||
outPayload.SentTime = time.Now()
|
outPayload.SentTime = time.Now()
|
||||||
stats.HandleRPC(stream.Context(), outPayload)
|
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
|
func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
|
||||||
if stats.On() {
|
sh := s.opts.statsHandler
|
||||||
|
if sh != nil {
|
||||||
begin := &stats.Begin{
|
begin := &stats.Begin{
|
||||||
BeginTime: time.Now(),
|
BeginTime: time.Now(),
|
||||||
}
|
}
|
||||||
stats.HandleRPC(stream.Context(), begin)
|
sh.HandleRPC(stream.Context(), begin)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if stats.On() {
|
if sh != nil {
|
||||||
end := &stats.End{
|
end := &stats.End{
|
||||||
EndTime: time.Now(),
|
EndTime: time.Now(),
|
||||||
}
|
}
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
end.Error = toRPCErr(err)
|
end.Error = toRPCErr(err)
|
||||||
}
|
}
|
||||||
stats.HandleRPC(stream.Context(), end)
|
sh.HandleRPC(stream.Context(), end)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if trInfo != nil {
|
if trInfo != nil {
|
||||||
@ -665,7 +675,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
var inPayload *stats.InPayload
|
var inPayload *stats.InPayload
|
||||||
if stats.On() {
|
if sh != nil {
|
||||||
inPayload = &stats.InPayload{
|
inPayload = &stats.InPayload{
|
||||||
RecvTime: time.Now(),
|
RecvTime: time.Now(),
|
||||||
}
|
}
|
||||||
@ -699,7 +709,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
inPayload.Payload = v
|
inPayload.Payload = v
|
||||||
inPayload.Data = req
|
inPayload.Data = req
|
||||||
inPayload.Length = len(req)
|
inPayload.Length = len(req)
|
||||||
stats.HandleRPC(stream.Context(), inPayload)
|
sh.HandleRPC(stream.Context(), inPayload)
|
||||||
}
|
}
|
||||||
if trInfo != nil {
|
if trInfo != nil {
|
||||||
trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
|
trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)
|
||||||
@ -756,21 +766,22 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
|
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc, trInfo *traceInfo) (err error) {
|
||||||
if stats.On() {
|
sh := s.opts.statsHandler
|
||||||
|
if sh != nil {
|
||||||
begin := &stats.Begin{
|
begin := &stats.Begin{
|
||||||
BeginTime: time.Now(),
|
BeginTime: time.Now(),
|
||||||
}
|
}
|
||||||
stats.HandleRPC(stream.Context(), begin)
|
sh.HandleRPC(stream.Context(), begin)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if stats.On() {
|
if sh != nil {
|
||||||
end := &stats.End{
|
end := &stats.End{
|
||||||
EndTime: time.Now(),
|
EndTime: time.Now(),
|
||||||
}
|
}
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
end.Error = toRPCErr(err)
|
end.Error = toRPCErr(err)
|
||||||
}
|
}
|
||||||
stats.HandleRPC(stream.Context(), end)
|
sh.HandleRPC(stream.Context(), end)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if s.opts.cp != nil {
|
if s.opts.cp != nil {
|
||||||
@ -785,6 +796,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
dc: s.opts.dc,
|
dc: s.opts.dc,
|
||||||
maxMsgSize: s.opts.maxMsgSize,
|
maxMsgSize: s.opts.maxMsgSize,
|
||||||
trInfo: trInfo,
|
trInfo: trInfo,
|
||||||
|
statsHandler: sh,
|
||||||
}
|
}
|
||||||
if ss.cp != nil {
|
if ss.cp != nil {
|
||||||
ss.cbuf = new(bytes.Buffer)
|
ss.cbuf = new(bytes.Buffer)
|
||||||
|
@ -35,10 +35,8 @@ package stats
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"google.golang.org/grpc/grpclog"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnTagInfo defines the relevant information needed by connection context tagger.
|
// ConnTagInfo defines the relevant information needed by connection context tagger.
|
||||||
@ -56,91 +54,23 @@ type RPCTagInfo struct {
|
|||||||
FullMethodName string
|
FullMethodName string
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
// Handler defines the interface for the related stats handling (e.g., RPCs, connections).
|
||||||
on = new(int32)
|
type Handler interface {
|
||||||
rpcHandler func(context.Context, RPCStats)
|
// TagRPC can attach some information to the given context.
|
||||||
connHandler func(context.Context, ConnStats)
|
// The returned context is used in the rest lifetime of the RPC.
|
||||||
connTagger func(context.Context, *ConnTagInfo) context.Context
|
TagRPC(context.Context, *RPCTagInfo) context.Context
|
||||||
rpcTagger func(context.Context, *RPCTagInfo) context.Context
|
// HandleRPC processes the RPC stats.
|
||||||
)
|
HandleRPC(context.Context, RPCStats)
|
||||||
|
|
||||||
// HandleRPC processes the RPC stats using the rpc handler registered by the user.
|
// TagConn can attach some information to the given context.
|
||||||
func HandleRPC(ctx context.Context, s RPCStats) {
|
// The returned context will be used for stats handling.
|
||||||
if rpcHandler == nil {
|
// For conn stats handling, the context used in HandleConn for this
|
||||||
return
|
// connection will be derived from the context returned.
|
||||||
}
|
// For RPC stats handling,
|
||||||
rpcHandler(ctx, s)
|
// - On server side, the context used in HandleRPC for all RPCs on this
|
||||||
}
|
// connection will be derived from the context returned.
|
||||||
|
// - On client side, the context is not derived from the context returned.
|
||||||
// RegisterRPCHandler registers the user handler function for RPC stats processing.
|
TagConn(context.Context, *ConnTagInfo) context.Context
|
||||||
// It should be called only once. The later call will overwrite the former value if it is called multiple times.
|
// HandleConn processes the Conn stats.
|
||||||
// This handler function will be called to process the rpc stats.
|
HandleConn(context.Context, ConnStats)
|
||||||
func RegisterRPCHandler(f func(context.Context, RPCStats)) {
|
|
||||||
rpcHandler = f
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleConn processes the stats using the call back function registered by user.
|
|
||||||
func HandleConn(ctx context.Context, s ConnStats) {
|
|
||||||
if connHandler == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
connHandler(ctx, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterConnHandler registers the user handler function for conn stats.
|
|
||||||
// It should be called only once. The later call will overwrite the former value if it is called multiple times.
|
|
||||||
// This handler function will be called to process the conn stats.
|
|
||||||
func RegisterConnHandler(f func(context.Context, ConnStats)) {
|
|
||||||
connHandler = f
|
|
||||||
}
|
|
||||||
|
|
||||||
// TagConn calls user registered connection context tagger.
|
|
||||||
func TagConn(ctx context.Context, info *ConnTagInfo) context.Context {
|
|
||||||
if connTagger == nil {
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
return connTagger(ctx, info)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterConnTagger registers the user connection context tagger function.
|
|
||||||
// The connection context tagger can attach some information to the given context.
|
|
||||||
// The returned context will be used for stats handling.
|
|
||||||
// For conn stats handling, the context used in connHandler for this
|
|
||||||
// connection will be derived from the context returned.
|
|
||||||
// For RPC stats handling,
|
|
||||||
// - On server side, the context used in rpcHandler for all RPCs on this
|
|
||||||
// connection will be derived from the context returned.
|
|
||||||
// - On client side, the context is not derived from the context returned.
|
|
||||||
func RegisterConnTagger(t func(context.Context, *ConnTagInfo) context.Context) {
|
|
||||||
connTagger = t
|
|
||||||
}
|
|
||||||
|
|
||||||
// TagRPC calls the user registered RPC context tagger.
|
|
||||||
func TagRPC(ctx context.Context, info *RPCTagInfo) context.Context {
|
|
||||||
if rpcTagger == nil {
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
return rpcTagger(ctx, info)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterRPCTagger registers the user RPC context tagger function.
|
|
||||||
// The RPC context tagger can attach some information to the given context.
|
|
||||||
// The context used in stats rpcHandler for this RPC will be derived from the
|
|
||||||
// context returned.
|
|
||||||
func RegisterRPCTagger(t func(context.Context, *RPCTagInfo) context.Context) {
|
|
||||||
rpcTagger = t
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts the stats collection and processing if there is a registered stats handle.
|
|
||||||
func Start() {
|
|
||||||
if rpcHandler == nil && connHandler == nil {
|
|
||||||
grpclog.Println("rpcHandler and connHandler are both nil when starting stats. Stats is not started")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
atomic.StoreInt32(on, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// On indicates whether the stats collection and processing is on.
|
|
||||||
func On() bool {
|
|
||||||
return atomic.CompareAndSwapInt32(on, 1, 1)
|
|
||||||
}
|
}
|
||||||
|
@ -57,40 +57,6 @@ func init() {
|
|||||||
type connCtxKey struct{}
|
type connCtxKey struct{}
|
||||||
type rpcCtxKey struct{}
|
type rpcCtxKey struct{}
|
||||||
|
|
||||||
func TestTagConnCtx(t *testing.T) {
|
|
||||||
defer stats.RegisterConnTagger(nil)
|
|
||||||
ctx1 := context.Background()
|
|
||||||
stats.RegisterConnTagger(nil)
|
|
||||||
ctx2 := stats.TagConn(ctx1, nil)
|
|
||||||
if ctx2 != ctx1 {
|
|
||||||
t.Fatalf("nil conn ctx tagger should not modify context, got %v; want %v", ctx2, ctx1)
|
|
||||||
}
|
|
||||||
stats.RegisterConnTagger(func(ctx context.Context, info *stats.ConnTagInfo) context.Context {
|
|
||||||
return context.WithValue(ctx, connCtxKey{}, "connctxvalue")
|
|
||||||
})
|
|
||||||
ctx3 := stats.TagConn(ctx1, nil)
|
|
||||||
if v, ok := ctx3.Value(connCtxKey{}).(string); !ok || v != "connctxvalue" {
|
|
||||||
t.Fatalf("got context %v; want %v", ctx3, context.WithValue(ctx1, connCtxKey{}, "connctxvalue"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTagRPCCtx(t *testing.T) {
|
|
||||||
defer stats.RegisterRPCTagger(nil)
|
|
||||||
ctx1 := context.Background()
|
|
||||||
stats.RegisterRPCTagger(nil)
|
|
||||||
ctx2 := stats.TagRPC(ctx1, nil)
|
|
||||||
if ctx2 != ctx1 {
|
|
||||||
t.Fatalf("nil rpc ctx tagger should not modify context, got %v; want %v", ctx2, ctx1)
|
|
||||||
}
|
|
||||||
stats.RegisterRPCTagger(func(ctx context.Context, info *stats.RPCTagInfo) context.Context {
|
|
||||||
return context.WithValue(ctx, rpcCtxKey{}, "rpcctxvalue")
|
|
||||||
})
|
|
||||||
ctx3 := stats.TagRPC(ctx1, nil)
|
|
||||||
if v, ok := ctx3.Value(rpcCtxKey{}).(string); !ok || v != "rpcctxvalue" {
|
|
||||||
t.Fatalf("got context %v; want %v", ctx3, context.WithValue(ctx1, rpcCtxKey{}, "rpcctxvalue"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// For headers:
|
// For headers:
|
||||||
testMetadata = metadata.MD{
|
testMetadata = metadata.MD{
|
||||||
@ -160,6 +126,8 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ
|
|||||||
type test struct {
|
type test struct {
|
||||||
t *testing.T
|
t *testing.T
|
||||||
compress string
|
compress string
|
||||||
|
clientStatsHandler stats.Handler
|
||||||
|
serverStatsHandler stats.Handler
|
||||||
|
|
||||||
testServer testpb.TestServiceServer // nil means none
|
testServer testpb.TestServiceServer // nil means none
|
||||||
// srv and srvAddr are set once startServer is called.
|
// srv and srvAddr are set once startServer is called.
|
||||||
@ -184,8 +152,13 @@ type testConfig struct {
|
|||||||
// newTest returns a new test using the provided testing.T and
|
// newTest returns a new test using the provided testing.T and
|
||||||
// environment. It is returned with default values. Tests should
|
// environment. It is returned with default values. Tests should
|
||||||
// modify it before calling its startServer and clientConn methods.
|
// modify it before calling its startServer and clientConn methods.
|
||||||
func newTest(t *testing.T, tc *testConfig) *test {
|
func newTest(t *testing.T, tc *testConfig, ch stats.Handler, sh stats.Handler) *test {
|
||||||
te := &test{t: t, compress: tc.compress}
|
te := &test{
|
||||||
|
t: t,
|
||||||
|
compress: tc.compress,
|
||||||
|
clientStatsHandler: ch,
|
||||||
|
serverStatsHandler: sh,
|
||||||
|
}
|
||||||
return te
|
return te
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -204,6 +177,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
|
|||||||
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
|
grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
if te.serverStatsHandler != nil {
|
||||||
|
opts = append(opts, grpc.StatsHandler(te.serverStatsHandler))
|
||||||
|
}
|
||||||
s := grpc.NewServer(opts...)
|
s := grpc.NewServer(opts...)
|
||||||
te.srv = s
|
te.srv = s
|
||||||
if te.testServer != nil {
|
if te.testServer != nil {
|
||||||
@ -230,6 +206,9 @@ func (te *test) clientConn() *grpc.ClientConn {
|
|||||||
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
|
grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
if te.clientStatsHandler != nil {
|
||||||
|
opts = append(opts, grpc.WithStatsHandler(te.clientStatsHandler))
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
te.cc, err = grpc.Dial(te.srvAddr, opts...)
|
te.cc, err = grpc.Dial(te.srvAddr, opts...)
|
||||||
@ -617,14 +596,32 @@ func checkConnEnd(t *testing.T, d *gotData, e *expectedData) {
|
|||||||
st.IsClient() // TODO remove this.
|
st.IsClient() // TODO remove this.
|
||||||
}
|
}
|
||||||
|
|
||||||
func tagConnCtx(ctx context.Context, info *stats.ConnTagInfo) context.Context {
|
type statshandler struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
gotRPC []*gotData
|
||||||
|
gotConn []*gotData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
|
||||||
return context.WithValue(ctx, connCtxKey{}, info)
|
return context.WithValue(ctx, connCtxKey{}, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
func tagRPCCtx(ctx context.Context, info *stats.RPCTagInfo) context.Context {
|
func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
|
||||||
return context.WithValue(ctx, rpcCtxKey{}, info)
|
return context.WithValue(ctx, rpcCtxKey{}, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s})
|
||||||
|
}
|
||||||
|
|
||||||
func checkConnStats(t *testing.T, got []*gotData) {
|
func checkConnStats(t *testing.T, got []*gotData) {
|
||||||
if len(got) <= 0 || len(got)%2 != 0 {
|
if len(got) <= 0 || len(got)%2 != 0 {
|
||||||
for i, g := range got {
|
for i, g := range got {
|
||||||
@ -662,30 +659,8 @@ func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkF
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
|
func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
|
||||||
var (
|
h := &statshandler{}
|
||||||
mu sync.Mutex
|
te := newTest(t, tc, nil, h)
|
||||||
gotRPC []*gotData
|
|
||||||
gotConn []*gotData
|
|
||||||
)
|
|
||||||
stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
if !s.IsClient() {
|
|
||||||
gotRPC = append(gotRPC, &gotData{ctx, false, s})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
if !s.IsClient() {
|
|
||||||
gotConn = append(gotConn, &gotData{ctx, false, s})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
stats.RegisterConnTagger(tagConnCtx)
|
|
||||||
stats.RegisterRPCTagger(tagRPCCtx)
|
|
||||||
stats.Start()
|
|
||||||
|
|
||||||
te := newTest(t, tc)
|
|
||||||
te.startServer(&testServer{})
|
te.startServer(&testServer{})
|
||||||
defer te.tearDown()
|
defer te.tearDown()
|
||||||
|
|
||||||
@ -709,22 +684,22 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f
|
|||||||
te.srv.GracefulStop() // Wait for the server to stop.
|
te.srv.GracefulStop() // Wait for the server to stop.
|
||||||
|
|
||||||
for {
|
for {
|
||||||
mu.Lock()
|
h.mu.Lock()
|
||||||
if len(gotRPC) >= len(checkFuncs) {
|
if len(h.gotRPC) >= len(checkFuncs) {
|
||||||
mu.Unlock()
|
h.mu.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
mu.Unlock()
|
h.mu.Unlock()
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
mu.Lock()
|
h.mu.Lock()
|
||||||
if _, ok := gotConn[len(gotConn)-1].s.(*stats.ConnEnd); ok {
|
if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
|
||||||
mu.Unlock()
|
h.mu.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
mu.Unlock()
|
h.mu.Unlock()
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -741,8 +716,8 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f
|
|||||||
expect.method = "/grpc.testing.TestService/FullDuplexCall"
|
expect.method = "/grpc.testing.TestService/FullDuplexCall"
|
||||||
}
|
}
|
||||||
|
|
||||||
checkConnStats(t, gotConn)
|
checkConnStats(t, h.gotConn)
|
||||||
checkServerStats(t, gotRPC, expect, checkFuncs)
|
checkServerStats(t, h.gotRPC, expect, checkFuncs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerStatsUnaryRPC(t *testing.T) {
|
func TestServerStatsUnaryRPC(t *testing.T) {
|
||||||
@ -891,30 +866,8 @@ func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkF
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
|
func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
|
||||||
var (
|
h := &statshandler{}
|
||||||
mu sync.Mutex
|
te := newTest(t, tc, h, nil)
|
||||||
gotRPC []*gotData
|
|
||||||
gotConn []*gotData
|
|
||||||
)
|
|
||||||
stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
if s.IsClient() {
|
|
||||||
gotRPC = append(gotRPC, &gotData{ctx, true, s})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
if s.IsClient() {
|
|
||||||
gotConn = append(gotConn, &gotData{ctx, true, s})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
stats.RegisterConnTagger(tagConnCtx)
|
|
||||||
stats.RegisterRPCTagger(tagRPCCtx)
|
|
||||||
stats.Start()
|
|
||||||
|
|
||||||
te := newTest(t, tc)
|
|
||||||
te.startServer(&testServer{})
|
te.startServer(&testServer{})
|
||||||
defer te.tearDown()
|
defer te.tearDown()
|
||||||
|
|
||||||
@ -942,22 +895,22 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map
|
|||||||
lenRPCStats += v.c
|
lenRPCStats += v.c
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
mu.Lock()
|
h.mu.Lock()
|
||||||
if len(gotRPC) >= lenRPCStats {
|
if len(h.gotRPC) >= lenRPCStats {
|
||||||
mu.Unlock()
|
h.mu.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
mu.Unlock()
|
h.mu.Unlock()
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
mu.Lock()
|
h.mu.Lock()
|
||||||
if _, ok := gotConn[len(gotConn)-1].s.(*stats.ConnEnd); ok {
|
if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
|
||||||
mu.Unlock()
|
h.mu.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
mu.Unlock()
|
h.mu.Unlock()
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -975,8 +928,8 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map
|
|||||||
expect.method = "/grpc.testing.TestService/FullDuplexCall"
|
expect.method = "/grpc.testing.TestService/FullDuplexCall"
|
||||||
}
|
}
|
||||||
|
|
||||||
checkConnStats(t, gotConn)
|
checkConnStats(t, h.gotConn)
|
||||||
checkClientStats(t, gotRPC, expect, checkFuncs)
|
checkClientStats(t, h.gotRPC, expect, checkFuncs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientStatsUnaryRPC(t *testing.T) {
|
func TestClientStatsUnaryRPC(t *testing.T) {
|
||||||
|
35
stream.go
35
stream.go
@ -151,23 +151,24 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
if stats.On() {
|
sh := cc.dopts.copts.StatsHandler
|
||||||
ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
|
if sh != nil {
|
||||||
|
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method})
|
||||||
begin := &stats.Begin{
|
begin := &stats.Begin{
|
||||||
Client: true,
|
Client: true,
|
||||||
BeginTime: time.Now(),
|
BeginTime: time.Now(),
|
||||||
FailFast: c.failFast,
|
FailFast: c.failFast,
|
||||||
}
|
}
|
||||||
stats.HandleRPC(ctx, begin)
|
sh.HandleRPC(ctx, begin)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil && stats.On() {
|
if err != nil && sh != nil {
|
||||||
// Only handle end stats if err != nil.
|
// Only handle end stats if err != nil.
|
||||||
end := &stats.End{
|
end := &stats.End{
|
||||||
Client: true,
|
Client: true,
|
||||||
Error: err,
|
Error: err,
|
||||||
}
|
}
|
||||||
stats.HandleRPC(ctx, end)
|
sh.HandleRPC(ctx, end)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
gopts := BalancerGetOptions{
|
gopts := BalancerGetOptions{
|
||||||
@ -224,6 +225,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||||||
trInfo: trInfo,
|
trInfo: trInfo,
|
||||||
|
|
||||||
statsCtx: ctx,
|
statsCtx: ctx,
|
||||||
|
statsHandler: cc.dopts.copts.StatsHandler,
|
||||||
}
|
}
|
||||||
if cc.dopts.cp != nil {
|
if cc.dopts.cp != nil {
|
||||||
cs.cbuf = new(bytes.Buffer)
|
cs.cbuf = new(bytes.Buffer)
|
||||||
@ -282,6 +284,7 @@ type clientStream struct {
|
|||||||
// All stats collection should use the statsCtx (instead of the stream context)
|
// All stats collection should use the statsCtx (instead of the stream context)
|
||||||
// so that all the generated stats for a particular RPC can be associated in the processing phase.
|
// so that all the generated stats for a particular RPC can be associated in the processing phase.
|
||||||
statsCtx context.Context
|
statsCtx context.Context
|
||||||
|
statsHandler stats.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) Context() context.Context {
|
func (cs *clientStream) Context() context.Context {
|
||||||
@ -335,7 +338,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||||||
err = toRPCErr(err)
|
err = toRPCErr(err)
|
||||||
}()
|
}()
|
||||||
var outPayload *stats.OutPayload
|
var outPayload *stats.OutPayload
|
||||||
if stats.On() {
|
if cs.statsHandler != nil {
|
||||||
outPayload = &stats.OutPayload{
|
outPayload = &stats.OutPayload{
|
||||||
Client: true,
|
Client: true,
|
||||||
}
|
}
|
||||||
@ -352,14 +355,14 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||||||
err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
|
err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
|
||||||
if err == nil && outPayload != nil {
|
if err == nil && outPayload != nil {
|
||||||
outPayload.SentTime = time.Now()
|
outPayload.SentTime = time.Now()
|
||||||
stats.HandleRPC(cs.statsCtx, outPayload)
|
cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil && stats.On() {
|
if err != nil && cs.statsHandler != nil {
|
||||||
// Only generate End if err != nil.
|
// Only generate End if err != nil.
|
||||||
// If err == nil, it's not the last RecvMsg.
|
// If err == nil, it's not the last RecvMsg.
|
||||||
// The last RecvMsg gets either an RPC error or io.EOF.
|
// The last RecvMsg gets either an RPC error or io.EOF.
|
||||||
@ -370,11 +373,11 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
end.Error = toRPCErr(err)
|
end.Error = toRPCErr(err)
|
||||||
}
|
}
|
||||||
stats.HandleRPC(cs.statsCtx, end)
|
cs.statsHandler.HandleRPC(cs.statsCtx, end)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var inPayload *stats.InPayload
|
var inPayload *stats.InPayload
|
||||||
if stats.On() {
|
if cs.statsHandler != nil {
|
||||||
inPayload = &stats.InPayload{
|
inPayload = &stats.InPayload{
|
||||||
Client: true,
|
Client: true,
|
||||||
}
|
}
|
||||||
@ -395,7 +398,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||||||
cs.mu.Unlock()
|
cs.mu.Unlock()
|
||||||
}
|
}
|
||||||
if inPayload != nil {
|
if inPayload != nil {
|
||||||
stats.HandleRPC(cs.statsCtx, inPayload)
|
cs.statsHandler.HandleRPC(cs.statsCtx, inPayload)
|
||||||
}
|
}
|
||||||
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
|
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
|
||||||
return
|
return
|
||||||
@ -520,6 +523,8 @@ type serverStream struct {
|
|||||||
statusDesc string
|
statusDesc string
|
||||||
trInfo *traceInfo
|
trInfo *traceInfo
|
||||||
|
|
||||||
|
statsHandler stats.Handler
|
||||||
|
|
||||||
mu sync.Mutex // protects trInfo.tr after the service handler runs.
|
mu sync.Mutex // protects trInfo.tr after the service handler runs.
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -562,7 +567,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var outPayload *stats.OutPayload
|
var outPayload *stats.OutPayload
|
||||||
if stats.On() {
|
if ss.statsHandler != nil {
|
||||||
outPayload = &stats.OutPayload{}
|
outPayload = &stats.OutPayload{}
|
||||||
}
|
}
|
||||||
out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
|
out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
|
||||||
@ -580,7 +585,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
if outPayload != nil {
|
if outPayload != nil {
|
||||||
outPayload.SentTime = time.Now()
|
outPayload.SentTime = time.Now()
|
||||||
stats.HandleRPC(ss.s.Context(), outPayload)
|
ss.statsHandler.HandleRPC(ss.s.Context(), outPayload)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -601,7 +606,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var inPayload *stats.InPayload
|
var inPayload *stats.InPayload
|
||||||
if stats.On() {
|
if ss.statsHandler != nil {
|
||||||
inPayload = &stats.InPayload{}
|
inPayload = &stats.InPayload{}
|
||||||
}
|
}
|
||||||
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil {
|
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, inPayload); err != nil {
|
||||||
@ -614,7 +619,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
|
|||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
if inPayload != nil {
|
if inPayload != nil {
|
||||||
stats.HandleRPC(ss.s.Context(), inPayload)
|
ss.statsHandler.HandleRPC(ss.s.Context(), inPayload)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -99,6 +99,8 @@ type http2Client struct {
|
|||||||
|
|
||||||
creds []credentials.PerRPCCredentials
|
creds []credentials.PerRPCCredentials
|
||||||
|
|
||||||
|
statsHandler stats.Handler
|
||||||
|
|
||||||
mu sync.Mutex // guard the following variables
|
mu sync.Mutex // guard the following variables
|
||||||
state transportState // the state of underlying connection
|
state transportState // the state of underlying connection
|
||||||
activeStreams map[uint32]*Stream
|
activeStreams map[uint32]*Stream
|
||||||
@ -208,16 +210,17 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
|
|||||||
creds: opts.PerRPCCredentials,
|
creds: opts.PerRPCCredentials,
|
||||||
maxStreams: math.MaxInt32,
|
maxStreams: math.MaxInt32,
|
||||||
streamSendQuota: defaultWindowSize,
|
streamSendQuota: defaultWindowSize,
|
||||||
|
statsHandler: opts.StatsHandler,
|
||||||
}
|
}
|
||||||
if stats.On() {
|
if t.statsHandler != nil {
|
||||||
t.ctx = stats.TagConn(t.ctx, &stats.ConnTagInfo{
|
t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{
|
||||||
RemoteAddr: t.remoteAddr,
|
RemoteAddr: t.remoteAddr,
|
||||||
LocalAddr: t.localAddr,
|
LocalAddr: t.localAddr,
|
||||||
})
|
})
|
||||||
connBegin := &stats.ConnBegin{
|
connBegin := &stats.ConnBegin{
|
||||||
Client: true,
|
Client: true,
|
||||||
}
|
}
|
||||||
stats.HandleConn(t.ctx, connBegin)
|
t.statsHandler.HandleConn(t.ctx, connBegin)
|
||||||
}
|
}
|
||||||
// Start the reader goroutine for incoming message. Each transport has
|
// Start the reader goroutine for incoming message. Each transport has
|
||||||
// a dedicated goroutine which reads HTTP2 frame from network. Then it
|
// a dedicated goroutine which reads HTTP2 frame from network. Then it
|
||||||
@ -470,7 +473,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
return nil, connectionErrorf(true, err, "transport: %v", err)
|
return nil, connectionErrorf(true, err, "transport: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if stats.On() {
|
if t.statsHandler != nil {
|
||||||
outHeader := &stats.OutHeader{
|
outHeader := &stats.OutHeader{
|
||||||
Client: true,
|
Client: true,
|
||||||
WireLength: bufLen,
|
WireLength: bufLen,
|
||||||
@ -479,7 +482,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||||||
LocalAddr: t.localAddr,
|
LocalAddr: t.localAddr,
|
||||||
Compression: callHdr.SendCompress,
|
Compression: callHdr.SendCompress,
|
||||||
}
|
}
|
||||||
stats.HandleRPC(s.clientStatsCtx, outHeader)
|
t.statsHandler.HandleRPC(s.clientStatsCtx, outHeader)
|
||||||
}
|
}
|
||||||
t.writableChan <- 0
|
t.writableChan <- 0
|
||||||
return s, nil
|
return s, nil
|
||||||
@ -559,11 +562,11 @@ func (t *http2Client) Close() (err error) {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
s.write(recvMsg{err: ErrConnClosing})
|
s.write(recvMsg{err: ErrConnClosing})
|
||||||
}
|
}
|
||||||
if stats.On() {
|
if t.statsHandler != nil {
|
||||||
connEnd := &stats.ConnEnd{
|
connEnd := &stats.ConnEnd{
|
||||||
Client: true,
|
Client: true,
|
||||||
}
|
}
|
||||||
stats.HandleConn(t.ctx, connEnd)
|
t.statsHandler.HandleConn(t.ctx, connEnd)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -911,19 +914,19 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
|
|||||||
endStream := frame.StreamEnded()
|
endStream := frame.StreamEnded()
|
||||||
var isHeader bool
|
var isHeader bool
|
||||||
defer func() {
|
defer func() {
|
||||||
if stats.On() {
|
if t.statsHandler != nil {
|
||||||
if isHeader {
|
if isHeader {
|
||||||
inHeader := &stats.InHeader{
|
inHeader := &stats.InHeader{
|
||||||
Client: true,
|
Client: true,
|
||||||
WireLength: int(frame.Header().Length),
|
WireLength: int(frame.Header().Length),
|
||||||
}
|
}
|
||||||
stats.HandleRPC(s.clientStatsCtx, inHeader)
|
t.statsHandler.HandleRPC(s.clientStatsCtx, inHeader)
|
||||||
} else {
|
} else {
|
||||||
inTrailer := &stats.InTrailer{
|
inTrailer := &stats.InTrailer{
|
||||||
Client: true,
|
Client: true,
|
||||||
WireLength: int(frame.Header().Length),
|
WireLength: int(frame.Header().Length),
|
||||||
}
|
}
|
||||||
stats.HandleRPC(s.clientStatsCtx, inTrailer)
|
t.statsHandler.HandleRPC(s.clientStatsCtx, inTrailer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -88,6 +88,8 @@ type http2Server struct {
|
|||||||
// sendQuotaPool provides flow control to outbound message.
|
// sendQuotaPool provides flow control to outbound message.
|
||||||
sendQuotaPool *quotaPool
|
sendQuotaPool *quotaPool
|
||||||
|
|
||||||
|
stats stats.Handler
|
||||||
|
|
||||||
mu sync.Mutex // guard the following
|
mu sync.Mutex // guard the following
|
||||||
state transportState
|
state transportState
|
||||||
activeStreams map[uint32]*Stream
|
activeStreams map[uint32]*Stream
|
||||||
@ -146,14 +148,15 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
|
|||||||
shutdownChan: make(chan struct{}),
|
shutdownChan: make(chan struct{}),
|
||||||
activeStreams: make(map[uint32]*Stream),
|
activeStreams: make(map[uint32]*Stream),
|
||||||
streamSendQuota: defaultWindowSize,
|
streamSendQuota: defaultWindowSize,
|
||||||
|
stats: config.StatsHandler,
|
||||||
}
|
}
|
||||||
if stats.On() {
|
if t.stats != nil {
|
||||||
t.ctx = stats.TagConn(t.ctx, &stats.ConnTagInfo{
|
t.ctx = t.stats.TagConn(t.ctx, &stats.ConnTagInfo{
|
||||||
RemoteAddr: t.remoteAddr,
|
RemoteAddr: t.remoteAddr,
|
||||||
LocalAddr: t.localAddr,
|
LocalAddr: t.localAddr,
|
||||||
})
|
})
|
||||||
connBegin := &stats.ConnBegin{}
|
connBegin := &stats.ConnBegin{}
|
||||||
stats.HandleConn(t.ctx, connBegin)
|
t.stats.HandleConn(t.ctx, connBegin)
|
||||||
}
|
}
|
||||||
go t.controller()
|
go t.controller()
|
||||||
t.writableChan <- 0
|
t.writableChan <- 0
|
||||||
@ -250,8 +253,8 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
t.updateWindow(s, uint32(n))
|
t.updateWindow(s, uint32(n))
|
||||||
}
|
}
|
||||||
s.ctx = traceCtx(s.ctx, s.method)
|
s.ctx = traceCtx(s.ctx, s.method)
|
||||||
if stats.On() {
|
if t.stats != nil {
|
||||||
s.ctx = stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
||||||
inHeader := &stats.InHeader{
|
inHeader := &stats.InHeader{
|
||||||
FullMethod: s.method,
|
FullMethod: s.method,
|
||||||
RemoteAddr: t.remoteAddr,
|
RemoteAddr: t.remoteAddr,
|
||||||
@ -259,7 +262,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
Compression: s.recvCompress,
|
Compression: s.recvCompress,
|
||||||
WireLength: int(frame.Header().Length),
|
WireLength: int(frame.Header().Length),
|
||||||
}
|
}
|
||||||
stats.HandleRPC(s.ctx, inHeader)
|
t.stats.HandleRPC(s.ctx, inHeader)
|
||||||
}
|
}
|
||||||
handle(s)
|
handle(s)
|
||||||
return
|
return
|
||||||
@ -540,11 +543,11 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
|
|||||||
if err := t.writeHeaders(s, t.hBuf, false); err != nil {
|
if err := t.writeHeaders(s, t.hBuf, false); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if stats.On() {
|
if t.stats != nil {
|
||||||
outHeader := &stats.OutHeader{
|
outHeader := &stats.OutHeader{
|
||||||
WireLength: bufLen,
|
WireLength: bufLen,
|
||||||
}
|
}
|
||||||
stats.HandleRPC(s.Context(), outHeader)
|
t.stats.HandleRPC(s.Context(), outHeader)
|
||||||
}
|
}
|
||||||
t.writableChan <- 0
|
t.writableChan <- 0
|
||||||
return nil
|
return nil
|
||||||
@ -603,11 +606,11 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
|
|||||||
t.Close()
|
t.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if stats.On() {
|
if t.stats != nil {
|
||||||
outTrailer := &stats.OutTrailer{
|
outTrailer := &stats.OutTrailer{
|
||||||
WireLength: bufLen,
|
WireLength: bufLen,
|
||||||
}
|
}
|
||||||
stats.HandleRPC(s.Context(), outTrailer)
|
t.stats.HandleRPC(s.Context(), outTrailer)
|
||||||
}
|
}
|
||||||
t.closeStream(s)
|
t.closeStream(s)
|
||||||
t.writableChan <- 0
|
t.writableChan <- 0
|
||||||
@ -789,9 +792,9 @@ func (t *http2Server) Close() (err error) {
|
|||||||
for _, s := range streams {
|
for _, s := range streams {
|
||||||
s.cancel()
|
s.cancel()
|
||||||
}
|
}
|
||||||
if stats.On() {
|
if t.stats != nil {
|
||||||
connEnd := &stats.ConnEnd{}
|
connEnd := &stats.ConnEnd{}
|
||||||
stats.HandleConn(t.ctx, connEnd)
|
t.stats.HandleConn(t.ctx, connEnd)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -48,6 +48,7 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/stats"
|
||||||
"google.golang.org/grpc/tap"
|
"google.golang.org/grpc/tap"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -360,6 +361,7 @@ type ServerConfig struct {
|
|||||||
MaxStreams uint32
|
MaxStreams uint32
|
||||||
AuthInfo credentials.AuthInfo
|
AuthInfo credentials.AuthInfo
|
||||||
InTapHandle tap.ServerInHandle
|
InTapHandle tap.ServerInHandle
|
||||||
|
StatsHandler stats.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServerTransport creates a ServerTransport with conn or non-nil error
|
// NewServerTransport creates a ServerTransport with conn or non-nil error
|
||||||
@ -380,6 +382,8 @@ type ConnectOptions struct {
|
|||||||
PerRPCCredentials []credentials.PerRPCCredentials
|
PerRPCCredentials []credentials.PerRPCCredentials
|
||||||
// TransportCredentials stores the Authenticator required to setup a client connection.
|
// TransportCredentials stores the Authenticator required to setup a client connection.
|
||||||
TransportCredentials credentials.TransportCredentials
|
TransportCredentials credentials.TransportCredentials
|
||||||
|
// StatsHandler stores the handler for stats.
|
||||||
|
StatsHandler stats.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// TargetInfo contains the information of the target such as network address and metadata.
|
// TargetInfo contains the information of the target such as network address and metadata.
|
||||||
|
Reference in New Issue
Block a user