transport: support stats.Handler in serverHandlerTransport (#1840)
This commit is contained in:
@ -694,7 +694,7 @@ func (s *Server) serveUsingHandler(conn net.Conn) {
|
||||
// available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL
|
||||
// and subject to change.
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
st, err := transport.NewServerHandlerTransport(w, r)
|
||||
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@ -40,13 +40,14 @@ import (
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/stats"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// NewServerHandlerTransport returns a ServerTransport handling gRPC
|
||||
// from inside an http.Handler. It requires that the http Server
|
||||
// supports HTTP/2.
|
||||
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) {
|
||||
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) {
|
||||
if r.ProtoMajor != 2 {
|
||||
return nil, errors.New("gRPC requires HTTP/2")
|
||||
}
|
||||
@ -73,6 +74,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
|
||||
writes: make(chan func()),
|
||||
contentType: contentType,
|
||||
contentSubtype: contentSubtype,
|
||||
stats: stats,
|
||||
}
|
||||
|
||||
if v := r.Header.Get("grpc-timeout"); v != "" {
|
||||
@ -137,6 +139,8 @@ type serverHandlerTransport struct {
|
||||
// we store both contentType and contentSubtype so we don't keep recreating them
|
||||
// TODO make sure this is consistent across handler_server and http2_server
|
||||
contentSubtype string
|
||||
|
||||
stats stats.Handler
|
||||
}
|
||||
|
||||
func (ht *serverHandlerTransport) Close() error {
|
||||
@ -230,6 +234,9 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
|
||||
})
|
||||
|
||||
if err == nil { // transport has not been closed
|
||||
if ht.stats != nil {
|
||||
ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
|
||||
}
|
||||
ht.Close()
|
||||
close(ht.writes)
|
||||
}
|
||||
@ -274,7 +281,7 @@ func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts
|
||||
}
|
||||
|
||||
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
||||
return ht.do(func() {
|
||||
err := ht.do(func() {
|
||||
ht.writeCommonHeaders(s)
|
||||
h := ht.rw.Header()
|
||||
for k, vv := range md {
|
||||
@ -290,6 +297,13 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
||||
ht.rw.WriteHeader(200)
|
||||
ht.rw.(http.Flusher).Flush()
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
if ht.stats != nil {
|
||||
ht.stats.HandleRPC(s.Context(), &stats.OutHeader{})
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
|
||||
@ -342,6 +356,15 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
||||
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
|
||||
ctx = peer.NewContext(ctx, pr)
|
||||
s.ctx = newContextWithStream(ctx, s)
|
||||
if ht.stats != nil {
|
||||
s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
||||
inHeader := &stats.InHeader{
|
||||
FullMethod: s.method,
|
||||
RemoteAddr: ht.RemoteAddr(),
|
||||
Compression: s.recvCompress,
|
||||
}
|
||||
ht.stats.HandleRPC(s.ctx, inHeader)
|
||||
}
|
||||
s.trReader = &transportReader{
|
||||
reader: &recvBufferReader{ctx: s.ctx, recv: s.buf},
|
||||
windowHandler: func(int) {},
|
||||
|
||||
@ -218,7 +218,7 @@ func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
|
||||
if tt.modrw != nil {
|
||||
rw = tt.modrw(rw)
|
||||
}
|
||||
got, gotErr := NewServerHandlerTransport(rw, tt.req)
|
||||
got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
|
||||
if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
|
||||
t.Errorf("%s: error = %v; want %q", tt.name, gotErr, tt.wantErr)
|
||||
continue
|
||||
@ -272,7 +272,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
|
||||
Body: bodyr,
|
||||
}
|
||||
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
||||
ht, err := NewServerHandlerTransport(rw, req)
|
||||
ht, err := NewServerHandlerTransport(rw, req, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -357,7 +357,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
|
||||
Body: bodyr,
|
||||
}
|
||||
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
||||
ht, err := NewServerHandlerTransport(rw, req)
|
||||
ht, err := NewServerHandlerTransport(rw, req, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user