From 110fd99e303b7a339c5e56a54721138d3eb7da00 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 24 Feb 2016 08:32:08 -0800 Subject: [PATCH] Fix crashes where transports returned errors unhandled by the message parser. The http.Handler-based transport body reader was returning error types not understood by the recvMsg parser. See #557 for some background and examples. Fix the http.Handler transport and add tests. I copied in a subset of the http2 package's serverTest type, adapted slightly to work with grpc. In the process of adding tests, I discovered that ErrUnexpectedEOF was also not handled by the regular server transport. Document the rules and fix that crash as well. Unrelated stuff in this CL: * make tests listen on localhost:0 instead of :0, to avoid Mac firewall pop-up dialogs. * rename parser.s field to parser.r, to be more idiomatic that it's an io.Reader and not anything fancier. (it's not acting like type stream, even if that's the typical concrete type) * move 5 byte temp buffer into parser, rather than allocating it for each new message. (drop in the bucket improvement in garbage; more to do later) * rename http2RSTErrConvTab to http2ErrConvTab, per Qi's earlier CL. Also add the HTTP/1.1-required error mapping for completeness, not that it should ever arise with gRPC, also per Qi's earlier CL referenced in #557. --- call.go | 2 +- call_test.go | 6 +- rpc_util.go | 38 +++-- rpc_util_test.go | 4 +- server.go | 7 +- stream.go | 4 +- test/end2end_test.go | 175 +++++++++++++++++++++- test/servertester_test.go | 289 ++++++++++++++++++++++++++++++++++++ transport/handler_server.go | 25 +++- transport/http2_client.go | 2 +- transport/http_util.go | 5 +- transport/transport_test.go | 6 +- 12 files changed, 532 insertions(+), 31 deletions(-) create mode 100644 test/servertester_test.go diff --git a/call.go b/call.go index d4ae68be..504a6e18 100644 --- a/call.go +++ b/call.go @@ -55,7 +55,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s if err != nil { return err } - p := &parser{s: stream} + p := &parser{r: stream} for { if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil { if err == io.EOF { diff --git a/call_test.go b/call_test.go index 58cef3cd..fb587e35 100644 --- a/call_test.go +++ b/call_test.go @@ -75,7 +75,7 @@ type testStreamHandler struct { } func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { - p := &parser{s: s} + p := &parser{r: s} for { pf, req, err := p.recvMsg() if err == io.EOF { @@ -125,9 +125,9 @@ func newTestServer() *server { func (s *server) start(t *testing.T, port int, maxStreams uint32) { var err error if port == 0 { - s.lis, err = net.Listen("tcp", ":0") + s.lis, err = net.Listen("tcp", "localhost:0") } else { - s.lis, err = net.Listen("tcp", ":"+strconv.Itoa(port)) + s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port)) } if err != nil { s.startedErr <- fmt.Errorf("failed to listen: %v", err) diff --git a/rpc_util.go b/rpc_util.go index fadf3394..66d34eb9 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -191,30 +191,44 @@ const ( // parser reads complelete gRPC messages from the underlying reader. type parser struct { - s io.Reader -} + // r is the underlying reader. + // See the comment on recvMsg for the permissible + // error types. + r io.Reader -// recvMsg is to read a complete gRPC message from the stream. It is blocking if -// the message has not been complete yet. It returns the message and its type, -// EOF is returned with nil msg and 0 pf if the entire stream is done. Other -// non-nil error is returned if something is wrong on reading. -func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { // The header of a gRPC message. Find more detail // at http://www.grpc.io/docs/guides/wire.html. - var buf [5]byte + header [5]byte +} - if _, err := io.ReadFull(p.s, buf[:]); err != nil { +// recvMsg reads a complete gRPC message from the stream. +// +// It returns the message and its payload (compression/encoding) +// format. The caller owns the returned msg memory. +// +// If there is an error, possible values are: +// * io.EOF, when no messages remain +// * io.ErrUnexpectedEOF +// * of type transport.ConnectionError +// * of type transport.StreamError +// No other error values or types must be returned, which also means +// that the underlying io.Reader must not return an incompatible +// error. +func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { + if _, err := io.ReadFull(p.r, p.header[:]); err != nil { return 0, nil, err } - pf = payloadFormat(buf[0]) - length := binary.BigEndian.Uint32(buf[1:]) + pf = payloadFormat(p.header[0]) + length := binary.BigEndian.Uint32(p.header[1:]) if length == 0 { return pf, nil, nil } + // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead + // of making it for each message: msg = make([]byte, int(length)) - if _, err := io.ReadFull(p.s, msg); err != nil { + if _, err := io.ReadFull(p.r, msg); err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } diff --git a/rpc_util_test.go b/rpc_util_test.go index 3f3749ae..f6327f13 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -65,7 +65,7 @@ func TestSimpleParsing(t *testing.T) { {append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone}, } { buf := bytes.NewReader(test.p) - parser := &parser{buf} + parser := &parser{r: buf} pt, b, err := parser.recvMsg() if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt { t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) @@ -77,7 +77,7 @@ func TestMultipleParsing(t *testing.T) { // Set a byte stream consists of 3 messages with their headers. p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'} b := bytes.NewReader(p) - parser := &parser{b} + parser := &parser{r: b} wantRecvs := []struct { pt payloadFormat diff --git a/server.go b/server.go index eb56b344..ec4485c2 100644 --- a/server.go +++ b/server.go @@ -446,13 +446,16 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } }() } - p := &parser{s: stream} + p := &parser{r: stream} for { pf, req, err := p.recvMsg() if err == io.EOF { // The entire stream is done (for unary RPC only). return err } + if err == io.ErrUnexpectedEOF { + err = transport.StreamError{Code: codes.Internal, Desc: "io.ErrUnexpectedEOF"} + } if err != nil { switch err := err.(type) { case transport.ConnectionError: @@ -558,7 +561,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss := &serverStream{ t: t, s: stream, - p: &parser{s: stream}, + p: &parser{r: stream}, codec: s.opts.codec, cp: s.opts.cp, dc: s.opts.dc, diff --git a/stream.go b/stream.go index ea685cc1..9cf3c010 100644 --- a/stream.go +++ b/stream.go @@ -109,7 +109,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth callHdr := &transport.CallHdr{ Host: cc.authority, Method: method, - Flush: desc.ServerStreams&&desc.ClientStreams, + Flush: desc.ServerStreams && desc.ClientStreams, } if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() @@ -141,7 +141,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } cs.t = t cs.s = s - cs.p = &parser{s: s} + cs.p = &parser{r: s} // Listen on ctx.Done() to detect cancellation when there is no pending // I/O operations on this stream. go func() { diff --git a/test/end2end_test.go b/test/end2end_test.go index d5f63ee8..035cbd9e 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -35,6 +35,8 @@ package grpc_test import ( "bytes" + "crypto/tls" + "errors" "flag" "fmt" "io" @@ -53,6 +55,7 @@ import ( "github.com/golang/protobuf/proto" "golang.org/x/net/context" + "golang.org/x/net/http2" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -62,6 +65,7 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" testpb "google.golang.org/grpc/test/grpc_testing" + "google.golang.org/grpc/transport" ) var ( @@ -290,7 +294,7 @@ func TestReconnectTimeout(t *testing.T) { ) defer restore() - lis, err := net.Listen("tcp", ":0") + lis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen: %v", err) } @@ -354,6 +358,15 @@ func (e env) runnable() bool { return true } +func (e env) getDialer() func(addr string, timeout time.Duration) (net.Conn, error) { + if e.dialer != nil { + return e.dialer + } + return func(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("tcp", addr, timeout) + } +} + var ( tcpClearEnv = env{name: "tcp-clear", network: "tcp"} tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"} @@ -451,7 +464,7 @@ func (te *test) startServer() { ) } - la := ":0" + la := "localhost:0" switch e.network { case "unix": la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now()) @@ -530,6 +543,25 @@ func (te *test) declareLogNoise(phrases ...string) { te.restoreLogs = declareLogNoise(te.t, phrases...) } +func (te *test) withServerTester(fn func(st *serverTester)) { + var c net.Conn + var err error + c, err = te.e.getDialer()(te.srvAddr, 10*time.Second) + if err != nil { + te.t.Fatal(err) + } + defer c.Close() + if te.e.security == "tls" { + c = tls.Client(c, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{http2.NextProtoTLS}, + }) + } + st := newServerTesterFromConn(te.t, c) + st.greet() + fn(st) +} + func TestTimeoutOnDeadServer(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() { @@ -1613,6 +1645,145 @@ func testCompressOK(t *testing.T, e env) { } } +// funcServer implements methods of TestServiceServer using funcs, +// similar to an http.HandlerFunc. +// Any unimplemented method will crash. Tests implement the method(s) +// they need. +type funcServer struct { + testpb.TestServiceServer + unaryCall func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) + streamingInputCall func(stream testpb.TestService_StreamingInputCallServer) error +} + +func (s *funcServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return s.unaryCall(ctx, in) +} + +func (s *funcServer) StreamingInputCall(stream testpb.TestService_StreamingInputCallServer) error { + return s.streamingInputCall(stream) +} + +func TestClientRequestBodyError_UnexpectedEOF(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testClientRequestBodyError_UnexpectedEOF(t, e) + } +} + +func testClientRequestBodyError_UnexpectedEOF(t *testing.T, e env) { + te := newTest(t, e) + te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + errUnexpectedCall := errors.New("unexpected call func server method") + t.Error(errUnexpectedCall) + return nil, errUnexpectedCall + }} + te.startServer() + defer te.tearDown() + te.withServerTester(func(st *serverTester) { + st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") + // Say we have 5 bytes coming, but set END_STREAM flag: + st.writeData(1, true, []byte{0, 0, 0, 0, 5}) + st.wantAnyFrame() // wait for server to crash (it used to crash) + }) +} + +func TestClientRequestBodyError_CloseAfterLength(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testClientRequestBodyError_CloseAfterLength(t, e) + } +} + +func testClientRequestBodyError_CloseAfterLength(t *testing.T, e env) { + te := newTest(t, e) + te.declareLogNoise("Server.processUnaryRPC failed to write status") + te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + errUnexpectedCall := errors.New("unexpected call func server method") + t.Error(errUnexpectedCall) + return nil, errUnexpectedCall + }} + te.startServer() + defer te.tearDown() + te.withServerTester(func(st *serverTester) { + st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") + // say we're sending 5 bytes, but then close the connection instead. + st.writeData(1, false, []byte{0, 0, 0, 0, 5}) + st.cc.Close() + }) +} + +func TestClientRequestBodyError_Cancel(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testClientRequestBodyError_Cancel(t, e) + } +} + +func testClientRequestBodyError_Cancel(t *testing.T, e env) { + te := newTest(t, e) + gotCall := make(chan bool, 1) + te.testServer = &funcServer{unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + gotCall <- true + return new(testpb.SimpleResponse), nil + }} + te.startServer() + defer te.tearDown() + te.withServerTester(func(st *serverTester) { + st.writeHeadersGRPC(1, "/grpc.testing.TestService/UnaryCall") + // Say we have 5 bytes coming, but cancel it instead. + st.writeData(1, false, []byte{0, 0, 0, 0, 5}) + st.writeRSTStream(1, http2.ErrCodeCancel) + + // Verify we didn't a call yet. + select { + case <-gotCall: + t.Fatal("unexpected call") + default: + } + + // And now send an uncanceled (but still invalid), just to get a response. + st.writeHeadersGRPC(3, "/grpc.testing.TestService/UnaryCall") + st.writeData(3, true, []byte{0, 0, 0, 0, 0}) + <-gotCall + st.wantAnyFrame() + }) +} + +func TestClientRequestBodyError_Cancel_StreamingInput(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testClientRequestBodyError_Cancel_StreamingInput(t, e) + } +} + +func testClientRequestBodyError_Cancel_StreamingInput(t *testing.T, e env) { + te := newTest(t, e) + recvErr := make(chan error, 1) + te.testServer = &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error { + _, err := stream.Recv() + recvErr <- err + return nil + }} + te.startServer() + defer te.tearDown() + te.withServerTester(func(st *serverTester) { + st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall") + // Say we have 5 bytes coming, but cancel it instead. + st.writeData(1, false, []byte{0, 0, 0, 0, 5}) + st.writeRSTStream(1, http2.ErrCodeCancel) + + var got error + select { + case got = <-recvErr: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for error") + } + if se, ok := got.(transport.StreamError); !ok || se.Code != codes.Canceled { + t.Errorf("error = %#v; want transport.StreamError with code Canceled") + } + }) +} + // interestingGoroutines returns all goroutines we care about for the purpose // of leak checking. It excludes testing or runtime ones. func interestingGoroutines() (gs []string) { diff --git a/test/servertester_test.go b/test/servertester_test.go new file mode 100644 index 00000000..0225a857 --- /dev/null +++ b/test/servertester_test.go @@ -0,0 +1,289 @@ +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package grpc_test + +import ( + "bytes" + "errors" + "io" + "strings" + "testing" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +// This is a subset of http2's serverTester type. +// +// serverTester wraps a io.ReadWriter (acting like the underlying +// network connection) and provides utility methods to read and write +// http2 frames. +// +// NOTE(bradfitz): this could eventually be exported somewhere. Others +// have asked for it too. For now I'm still experimenting with the +// API and don't feel like maintaining a stable testing API. + +type serverTester struct { + cc io.ReadWriteCloser // client conn + t testing.TB + fr *http2.Framer + + // writing headers: + headerBuf bytes.Buffer + hpackEnc *hpack.Encoder + + // reading frames: + frc chan http2.Frame + frErrc chan error + readTimer *time.Timer +} + +func newServerTesterFromConn(t testing.TB, cc io.ReadWriteCloser) *serverTester { + st := &serverTester{ + t: t, + cc: cc, + frc: make(chan http2.Frame, 1), + frErrc: make(chan error, 1), + } + st.hpackEnc = hpack.NewEncoder(&st.headerBuf) + st.fr = http2.NewFramer(cc, cc) + st.fr.ReadMetaHeaders = hpack.NewDecoder(4096 /*initialHeaderTableSize*/, nil) + + return st +} + +func (st *serverTester) readFrame() (http2.Frame, error) { + go func() { + fr, err := st.fr.ReadFrame() + if err != nil { + st.frErrc <- err + } else { + st.frc <- fr + } + }() + t := time.NewTimer(2 * time.Second) + defer t.Stop() + select { + case f := <-st.frc: + return f, nil + case err := <-st.frErrc: + return nil, err + case <-t.C: + return nil, errors.New("timeout waiting for frame") + } +} + +// greet initiates the client's HTTP/2 connection into a state where +// frames may be sent. +func (st *serverTester) greet() { + st.writePreface() + st.writeInitialSettings() + st.wantSettings() + st.writeSettingsAck() + for { + f, err := st.readFrame() + if err != nil { + st.t.Fatal(err) + } + switch f := f.(type) { + case *http2.WindowUpdateFrame: + // grpc's transport/http2_server sends this + // before the settings ack. The Go http2 + // server uses a setting instead. + case *http2.SettingsFrame: + if f.IsAck() { + return + } + st.t.Fatalf("during greet, got non-ACK settings frame") + default: + st.t.Fatalf("during greet, unexpected frame type %T", f) + } + } +} + +func (st *serverTester) writePreface() { + n, err := st.cc.Write([]byte(http2.ClientPreface)) + if err != nil { + st.t.Fatalf("Error writing client preface: %v", err) + } + if n != len(http2.ClientPreface) { + st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(http2.ClientPreface)) + } +} + +func (st *serverTester) writeInitialSettings() { + if err := st.fr.WriteSettings(); err != nil { + st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err) + } +} + +func (st *serverTester) writeSettingsAck() { + if err := st.fr.WriteSettingsAck(); err != nil { + st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err) + } +} + +func (st *serverTester) wantSettings() *http2.SettingsFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err) + } + sf, ok := f.(*http2.SettingsFrame) + if !ok { + st.t.Fatalf("got a %T; want *SettingsFrame", f) + } + return sf +} + +func (st *serverTester) wantSettingsAck() { + f, err := st.readFrame() + if err != nil { + st.t.Fatal(err) + } + sf, ok := f.(*http2.SettingsFrame) + if !ok { + st.t.Fatalf("Wanting a settings ACK, received a %T", f) + } + if !sf.IsAck() { + st.t.Fatal("Settings Frame didn't have ACK set") + } +} + +// wait for any activity from the server +func (st *serverTester) wantAnyFrame() http2.Frame { + f, err := st.fr.ReadFrame() + if err != nil { + st.t.Fatal(err) + } + return f +} + +func (st *serverTester) encodeHeaderField(k, v string) { + err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) + if err != nil { + st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) + } +} + +// encodeHeader encodes headers and returns their HPACK bytes. headers +// must contain an even number of key/value pairs. There may be +// multiple pairs for keys (e.g. "cookie"). The :method, :path, and +// :scheme headers default to GET, / and https. +func (st *serverTester) encodeHeader(headers ...string) []byte { + if len(headers)%2 == 1 { + panic("odd number of kv args") + } + + st.headerBuf.Reset() + + if len(headers) == 0 { + // Fast path, mostly for benchmarks, so test code doesn't pollute + // profiles when we're looking to improve server allocations. + st.encodeHeaderField(":method", "GET") + st.encodeHeaderField(":path", "/") + st.encodeHeaderField(":scheme", "https") + return st.headerBuf.Bytes() + } + + if len(headers) == 2 && headers[0] == ":method" { + // Another fast path for benchmarks. + st.encodeHeaderField(":method", headers[1]) + st.encodeHeaderField(":path", "/") + st.encodeHeaderField(":scheme", "https") + return st.headerBuf.Bytes() + } + + pseudoCount := map[string]int{} + keys := []string{":method", ":path", ":scheme"} + vals := map[string][]string{ + ":method": {"GET"}, + ":path": {"/"}, + ":scheme": {"https"}, + } + for len(headers) > 0 { + k, v := headers[0], headers[1] + headers = headers[2:] + if _, ok := vals[k]; !ok { + keys = append(keys, k) + } + if strings.HasPrefix(k, ":") { + pseudoCount[k]++ + if pseudoCount[k] == 1 { + vals[k] = []string{v} + } else { + // Allows testing of invalid headers w/ dup pseudo fields. + vals[k] = append(vals[k], v) + } + } else { + vals[k] = append(vals[k], v) + } + } + for _, k := range keys { + for _, v := range vals[k] { + st.encodeHeaderField(k, v) + } + } + return st.headerBuf.Bytes() +} + +func (st *serverTester) writeHeadersGRPC(streamID uint32, path string) { + st.writeHeaders(http2.HeadersFrameParam{ + StreamID: streamID, + BlockFragment: st.encodeHeader( + ":method", "POST", + ":path", path, + "content-type", "application/grpc", + "te", "trailers", + ), + EndStream: false, + EndHeaders: true, + }) +} + +func (st *serverTester) writeHeaders(p http2.HeadersFrameParam) { + if err := st.fr.WriteHeaders(p); err != nil { + st.t.Fatalf("Error writing HEADERS: %v", err) + } +} + +func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) { + if err := st.fr.WriteData(streamID, endStream, data); err != nil { + st.t.Fatalf("Error writing DATA: %v", err) + } +} + +func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) { + if err := st.fr.WriteRSTStream(streamID, code); err != nil { + st.t.Fatalf("Error writing RST_STREAM: %v", err) + } +} diff --git a/transport/handler_server.go b/transport/handler_server.go index 8bfbf970..63ba0537 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -40,6 +40,7 @@ package transport import ( "errors" "fmt" + "io" "net" "net/http" "strings" @@ -319,7 +320,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) { s.buf.put(&recvMsg{data: buf[:n]}) } if err != nil { - s.buf.put(&recvMsg{err: err}) + s.buf.put(&recvMsg{err: mapRecvMsgError(err)}) return } } @@ -352,3 +353,25 @@ func (ht *serverHandlerTransport) runStream() { } } } + +// mapRecvMsgError returns the non-nil err into the appropriate +// error value as expected by callers of *grpc.parser.recvMsg. +// In particular, in can only be: +// * io.EOF +// * io.ErrUnexpectedEOF +// * of type transport.ConnectionError +// * of type transport.StreamError +func mapRecvMsgError(err error) error { + if err == io.EOF || err == io.ErrUnexpectedEOF { + return err + } + if se, ok := err.(http2.StreamError); ok { + if code, ok := http2ErrConvTab[se.Code]; ok { + return StreamError{ + Code: code, + Desc: se.Error(), + } + } + } + return ConnectionError{Desc: err.Error()} +} diff --git a/transport/http2_client.go b/transport/http2_client.go index 7cf700fe..bb72fea3 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -637,7 +637,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { close(s.headerChan) s.headerDone = true } - s.statusCode, ok = http2RSTErrConvTab[http2.ErrCode(f.ErrCode)] + s.statusCode, ok = http2ErrConvTab[http2.ErrCode(f.ErrCode)] if !ok { grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode) } diff --git a/transport/http_util.go b/transport/http_util.go index 25173118..ea7b39bf 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -62,8 +62,8 @@ const ( ) var ( - clientPreface = []byte(http2.ClientPreface) - http2RSTErrConvTab = map[http2.ErrCode]codes.Code{ + clientPreface = []byte(http2.ClientPreface) + http2ErrConvTab = map[http2.ErrCode]codes.Code{ http2.ErrCodeNo: codes.Internal, http2.ErrCodeProtocol: codes.Internal, http2.ErrCodeInternal: codes.Internal, @@ -76,6 +76,7 @@ var ( http2.ErrCodeConnect: codes.Internal, http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted, http2.ErrCodeInadequateSecurity: codes.PermissionDenied, + http2.ErrCodeHTTP11Required: codes.FailedPrecondition, } statusCodeConvTab = map[codes.Code]http2.ErrCode{ codes.Internal: http2.ErrCodeInternal, diff --git a/transport/transport_test.go b/transport/transport_test.go index 07128d5d..cb11ab2f 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -131,9 +131,9 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { var err error if port == 0 { - s.lis, err = net.Listen("tcp", ":0") + s.lis, err = net.Listen("tcp", "localhost:0") } else { - s.lis, err = net.Listen("tcp", ":"+strconv.Itoa(port)) + s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port)) } if err != nil { s.startedErr <- fmt.Errorf("failed to listen: %v", err) @@ -568,7 +568,7 @@ func TestServerWithMisbehavedClient(t *testing.T) { sent++ } // Server sent a resetStream for s already. - code := http2RSTErrConvTab[http2.ErrCodeFlowControl] + code := http2ErrConvTab[http2.ErrCodeFlowControl] if _, err := io.ReadFull(s, make([]byte, 1)); err != io.EOF || s.statusCode != code { t.Fatalf("%v got err %v with statusCode %d, want err with statusCode %d", s, err, s.statusCode, code) }