diff --git a/transport/transport_test.go b/transport/transport_test.go index 6a15bf30..9c4147cb 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -37,6 +37,7 @@ import ( "bytes" "fmt" "io" + "io/ioutil" "math" "net" "reflect" @@ -75,7 +76,7 @@ const ( normal hType = iota suspended misbehaved - malformedStatus + encodingRequiredStatus ) func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { @@ -128,9 +129,8 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { } } -func (h *testStreamHandler) handleStreamMalformedStatus(t *testing.T, s *Stream) { - // raw newline is not accepted by http2 framer and a http2.StreamError is - // generated. +func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) { + // raw newline is not accepted by http2 framer so it must be encoded. h.t.WriteStatus(s, codes.Internal, "\n") } @@ -179,9 +179,9 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) { go transport.HandleStreams(func(s *Stream) { go h.handleStreamMisbehave(t, s) }) - case malformedStatus: + case encodingRequiredStatus: go transport.HandleStreams(func(s *Stream) { - go h.handleStreamMalformedStatus(t, s) + go h.handleStreamEncodingRequiredStatus(t, s) }) default: go transport.HandleStreams(func(s *Stream) { @@ -663,8 +663,8 @@ func TestClientWithMisbehavedServer(t *testing.T) { server.stop() } -func TestMalformedStatus(t *testing.T) { - server, ct := setUp(t, 0, math.MaxUint32, malformedStatus) +func TestEncodingRequiredStatus(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus) callHdr := &CallHdr{ Host: "localhost", Method: "foo", @@ -680,10 +680,8 @@ func TestMalformedStatus(t *testing.T) { if err := ct.Write(s, expectedRequest, &opts); err != nil { t.Fatalf("Failed to write the request: %v", err) } - p := make([]byte, http2MaxFrameLen) - expectedErr := StreamErrorf(codes.Internal, "invalid header field value \"\\n\"") - if _, err = s.dec.Read(p); err != expectedErr { - t.Fatalf("Read the err %v, want %v", err, expectedErr) + if _, err = ioutil.ReadAll(s); err != nil { + t.Fatal(err) } ct.Close() server.stop()