diff --git a/test/end2end_test.go b/test/end2end_test.go index 8b64234b..27179ee3 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -80,6 +80,9 @@ var ( "tkey1": []string{"trailerValue1"}, "tkey2": []string{"trailerValue2"}, } + malformedHTTP2Metadata = metadata.MD{ + "Key": []string{"foo"}, + } testAppUA = "myApp1/1.0 myApp2/0.9" ) @@ -889,6 +892,35 @@ func testMetadataUnaryRPC(t *testing.T, e env) { } } +func TestMalformedHTTP2Metadata(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testMalformedHTTP2Metadata(t, e) + } +} + +func testMalformedHTTP2Metadata(t *testing.T, e env) { + te := newTest(t, e) + te.startServer() + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 2718) + if err != nil { + t.Fatal(err) + } + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(314), + Payload: payload, + } + ctx := metadata.NewContext(context.Background(), malformedHTTP2Metadata) + if _, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.Internal { + t.Fatalf("TestService.UnaryCall(%v, _) = _, %v; want _, %q", ctx, err, codes.Internal) + } +} + func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup) { defer wg.Done() const argSize = 2718 diff --git a/transport/http2_server.go b/transport/http2_server.go index cec441cf..03164236 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -246,6 +246,16 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { for { frame, err := t.framer.readFrame() if err != nil { + if se, ok := err.(http2.StreamError); ok { + t.mu.Lock() + s := t.activeStreams[se.StreamID] + t.mu.Unlock() + if s != nil { + t.closeStream(s) + } + t.controlBuf.put(&resetStream{se.StreamID, se.Code}) + continue + } t.Close() return }