diff --git a/server.go b/server.go index b80594f9..0f7ff5d6 100644 --- a/server.go +++ b/server.go @@ -694,7 +694,7 @@ func (s *Server) serveUsingHandler(conn net.Conn) { // available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL // and subject to change. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - st, err := transport.NewServerHandlerTransport(w, r) + st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/transport/handler_server.go b/transport/handler_server.go index ce8ebece..451d7e62 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -40,13 +40,14 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/stats" "google.golang.org/grpc/status" ) // NewServerHandlerTransport returns a ServerTransport handling gRPC // from inside an http.Handler. It requires that the http Server // supports HTTP/2. -func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) { +func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) { if r.ProtoMajor != 2 { return nil, errors.New("gRPC requires HTTP/2") } @@ -73,6 +74,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr writes: make(chan func()), contentType: contentType, contentSubtype: contentSubtype, + stats: stats, } if v := r.Header.Get("grpc-timeout"); v != "" { @@ -137,6 +139,8 @@ type serverHandlerTransport struct { // we store both contentType and contentSubtype so we don't keep recreating them // TODO make sure this is consistent across handler_server and http2_server contentSubtype string + + stats stats.Handler } func (ht *serverHandlerTransport) Close() error { @@ -230,6 +234,9 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro }) if err == nil { // transport has not been closed + if ht.stats != nil { + ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) + } ht.Close() close(ht.writes) } @@ -274,7 +281,7 @@ func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts } func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { - return ht.do(func() { + err := ht.do(func() { ht.writeCommonHeaders(s) h := ht.rw.Header() for k, vv := range md { @@ -290,6 +297,13 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { ht.rw.WriteHeader(200) ht.rw.(http.Flusher).Flush() }) + + if err == nil { + if ht.stats != nil { + ht.stats.HandleRPC(s.Context(), &stats.OutHeader{}) + } + } + return err } func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) { @@ -342,6 +356,15 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace ctx = metadata.NewIncomingContext(ctx, ht.headerMD) ctx = peer.NewContext(ctx, pr) s.ctx = newContextWithStream(ctx, s) + if ht.stats != nil { + s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) + inHeader := &stats.InHeader{ + FullMethod: s.method, + RemoteAddr: ht.RemoteAddr(), + Compression: s.recvCompress, + } + ht.stats.HandleRPC(s.ctx, inHeader) + } s.trReader = &transportReader{ reader: &recvBufferReader{ctx: s.ctx, recv: s.buf}, windowHandler: func(int) {}, diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go index b7e9120e..3261b8e3 100644 --- a/transport/handler_server_test.go +++ b/transport/handler_server_test.go @@ -218,7 +218,7 @@ func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { if tt.modrw != nil { rw = tt.modrw(rw) } - got, gotErr := NewServerHandlerTransport(rw, tt.req) + got, gotErr := NewServerHandlerTransport(rw, tt.req, nil) if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) { t.Errorf("%s: error = %v; want %q", tt.name, gotErr, tt.wantErr) continue @@ -272,7 +272,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest { Body: bodyr, } rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) - ht, err := NewServerHandlerTransport(rw, req) + ht, err := NewServerHandlerTransport(rw, req, nil) if err != nil { t.Fatal(err) } @@ -357,7 +357,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { Body: bodyr, } rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) - ht, err := NewServerHandlerTransport(rw, req) + ht, err := NewServerHandlerTransport(rw, req, nil) if err != nil { t.Fatal(err) }