transport: support stats.Handler in serverHandlerTransport (#1840)
This commit is contained in:
@ -694,7 +694,7 @@ func (s *Server) serveUsingHandler(conn net.Conn) {
|
|||||||
// available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL
|
// available through grpc-go's HTTP/2 server, and it is currently EXPERIMENTAL
|
||||||
// and subject to change.
|
// and subject to change.
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
st, err := transport.NewServerHandlerTransport(w, r)
|
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|||||||
@ -40,13 +40,14 @@ import (
|
|||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
|
"google.golang.org/grpc/stats"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewServerHandlerTransport returns a ServerTransport handling gRPC
|
// NewServerHandlerTransport returns a ServerTransport handling gRPC
|
||||||
// from inside an http.Handler. It requires that the http Server
|
// from inside an http.Handler. It requires that the http Server
|
||||||
// supports HTTP/2.
|
// supports HTTP/2.
|
||||||
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) {
|
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler) (ServerTransport, error) {
|
||||||
if r.ProtoMajor != 2 {
|
if r.ProtoMajor != 2 {
|
||||||
return nil, errors.New("gRPC requires HTTP/2")
|
return nil, errors.New("gRPC requires HTTP/2")
|
||||||
}
|
}
|
||||||
@ -73,6 +74,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
|
|||||||
writes: make(chan func()),
|
writes: make(chan func()),
|
||||||
contentType: contentType,
|
contentType: contentType,
|
||||||
contentSubtype: contentSubtype,
|
contentSubtype: contentSubtype,
|
||||||
|
stats: stats,
|
||||||
}
|
}
|
||||||
|
|
||||||
if v := r.Header.Get("grpc-timeout"); v != "" {
|
if v := r.Header.Get("grpc-timeout"); v != "" {
|
||||||
@ -137,6 +139,8 @@ type serverHandlerTransport struct {
|
|||||||
// we store both contentType and contentSubtype so we don't keep recreating them
|
// we store both contentType and contentSubtype so we don't keep recreating them
|
||||||
// TODO make sure this is consistent across handler_server and http2_server
|
// TODO make sure this is consistent across handler_server and http2_server
|
||||||
contentSubtype string
|
contentSubtype string
|
||||||
|
|
||||||
|
stats stats.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) Close() error {
|
func (ht *serverHandlerTransport) Close() error {
|
||||||
@ -230,6 +234,9 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err == nil { // transport has not been closed
|
if err == nil { // transport has not been closed
|
||||||
|
if ht.stats != nil {
|
||||||
|
ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{})
|
||||||
|
}
|
||||||
ht.Close()
|
ht.Close()
|
||||||
close(ht.writes)
|
close(ht.writes)
|
||||||
}
|
}
|
||||||
@ -274,7 +281,7 @@ func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
||||||
return ht.do(func() {
|
err := ht.do(func() {
|
||||||
ht.writeCommonHeaders(s)
|
ht.writeCommonHeaders(s)
|
||||||
h := ht.rw.Header()
|
h := ht.rw.Header()
|
||||||
for k, vv := range md {
|
for k, vv := range md {
|
||||||
@ -290,6 +297,13 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
|||||||
ht.rw.WriteHeader(200)
|
ht.rw.WriteHeader(200)
|
||||||
ht.rw.(http.Flusher).Flush()
|
ht.rw.(http.Flusher).Flush()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
if ht.stats != nil {
|
||||||
|
ht.stats.HandleRPC(s.Context(), &stats.OutHeader{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
|
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
|
||||||
@ -342,6 +356,15 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
|||||||
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
|
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
|
||||||
ctx = peer.NewContext(ctx, pr)
|
ctx = peer.NewContext(ctx, pr)
|
||||||
s.ctx = newContextWithStream(ctx, s)
|
s.ctx = newContextWithStream(ctx, s)
|
||||||
|
if ht.stats != nil {
|
||||||
|
s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
||||||
|
inHeader := &stats.InHeader{
|
||||||
|
FullMethod: s.method,
|
||||||
|
RemoteAddr: ht.RemoteAddr(),
|
||||||
|
Compression: s.recvCompress,
|
||||||
|
}
|
||||||
|
ht.stats.HandleRPC(s.ctx, inHeader)
|
||||||
|
}
|
||||||
s.trReader = &transportReader{
|
s.trReader = &transportReader{
|
||||||
reader: &recvBufferReader{ctx: s.ctx, recv: s.buf},
|
reader: &recvBufferReader{ctx: s.ctx, recv: s.buf},
|
||||||
windowHandler: func(int) {},
|
windowHandler: func(int) {},
|
||||||
|
|||||||
@ -218,7 +218,7 @@ func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
|
|||||||
if tt.modrw != nil {
|
if tt.modrw != nil {
|
||||||
rw = tt.modrw(rw)
|
rw = tt.modrw(rw)
|
||||||
}
|
}
|
||||||
got, gotErr := NewServerHandlerTransport(rw, tt.req)
|
got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
|
||||||
if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
|
if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
|
||||||
t.Errorf("%s: error = %v; want %q", tt.name, gotErr, tt.wantErr)
|
t.Errorf("%s: error = %v; want %q", tt.name, gotErr, tt.wantErr)
|
||||||
continue
|
continue
|
||||||
@ -272,7 +272,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
|
|||||||
Body: bodyr,
|
Body: bodyr,
|
||||||
}
|
}
|
||||||
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
||||||
ht, err := NewServerHandlerTransport(rw, req)
|
ht, err := NewServerHandlerTransport(rw, req, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -357,7 +357,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
|
|||||||
Body: bodyr,
|
Body: bodyr,
|
||||||
}
|
}
|
||||||
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
|
||||||
ht, err := NewServerHandlerTransport(rw, req)
|
ht, err := NewServerHandlerTransport(rw, req, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user