diff --git a/transport/http_util_test.go b/transport/http_util_test.go index 62519490..9967161b 100644 --- a/transport/http_util_test.go +++ b/transport/http_util_test.go @@ -109,27 +109,36 @@ func TestValidContentType(t *testing.T) { } func TestGrpcMessageEncode(t *testing.T) { - testGrpcMessageEncode(t, "my favorite character is \u0000", "my favorite character is %00") - testGrpcMessageEncode(t, "my favorite character is %", "my favorite character is %25") + for _, tt := range []struct { + input string + expected string + }{ + {"", ""}, + {"Hello", "Hello"}, + {"my favorite character is \u0000", "my favorite character is %00"}, + {"my favorite character is %", "my favorite character is %25"}, + } { + actual := grpcMessageEncode(tt.input) + if tt.expected != actual { + t.Errorf("grpcMessageEncode(%v) = %v, want %v", tt.input, actual, tt.expected) + } + } } func TestGrpcMessageDecode(t *testing.T) { - testGrpcMessageDecode(t, "Hello", "Hello") - testGrpcMessageDecode(t, "H%61o", "Hao") - testGrpcMessageDecode(t, "H%6", "H%6") - testGrpcMessageDecode(t, "%G0", "%G0") -} - -func testGrpcMessageEncode(t *testing.T, input string, expected string) { - actual := grpcMessageEncode(input) - if expected != actual { - t.Errorf("grpcMessageEncode(%v) = %v, want %v", input, actual, expected) - } -} - -func testGrpcMessageDecode(t *testing.T, input string, expected string) { - actual := grpcMessageDecode(input) - if expected != actual { - t.Errorf("grpcMessageDncode(%v) = %v, want %v", input, actual, expected) + for _, tt := range []struct { + input string + expected string + }{ + {"", ""}, + {"Hello", "Hello"}, + {"H%61o", "Hao"}, + {"H%6", "H%6"}, + {"%G0", "%G0"}, + } { + actual := grpcMessageDecode(tt.input) + if tt.expected != actual { + t.Errorf("grpcMessageDncode(%v) = %v, want %v", tt.input, actual, tt.expected) + } } } diff --git a/transport/transport_test.go b/transport/transport_test.go index 6ad0c782..047e6543 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -37,7 +37,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "math" "net" "reflect" @@ -131,7 +130,7 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) { // raw newline is not accepted by http2 framer so it must be encoded. - h.t.WriteStatus(s, codes.Internal, "\n") + h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc) } // start starts server. Other goroutines should block on s.readyChan for further operations. @@ -714,6 +713,11 @@ func TestClientWithMisbehavedServer(t *testing.T) { server.stop() } +var ( + encodingTestStatusCode = codes.Internal + encodingTestStatusDesc = "\n" +) + func TestEncodingRequiredStatus(t *testing.T) { server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus) callHdr := &CallHdr{ @@ -731,8 +735,12 @@ func TestEncodingRequiredStatus(t *testing.T) { if err := ct.Write(s, expectedRequest, &opts); err != nil { t.Fatalf("Failed to write the request: %v", err) } - if _, err = ioutil.ReadAll(s); err != nil { - t.Fatalf("Read got err %v, want ", err) + p := make([]byte, http2MaxFrameLen) + if _, err := s.dec.Read(p); err != io.EOF { + t.Fatalf("Read got error %v, want %v", err, io.EOF) + } + if s.StatusCode() != encodingTestStatusCode || s.StatusDesc() != encodingTestStatusDesc { + t.Fatalf("stream with status code %d, status desc %v, want %d, %v", s.StatusCode(), s.StatusDesc(), encodingTestStatusCode, encodingTestStatusDesc) } ct.Close() server.stop()