diff --git a/transport/http2_client.go b/transport/http2_client.go index 80583ab7..3e5ff731 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -968,18 +968,16 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } s.bytesReceived = true var state decodeState - for _, hf := range frame.Fields { - if err := state.processHeaderField(hf); err != nil { - s.mu.Lock() - if !s.headerDone { - close(s.headerChan) - s.headerDone = true - } - s.mu.Unlock() - s.write(recvMsg{err: err}) - // Something wrong. Stops reading even when there is remaining. - return + if err := state.decodeResponseHeader(frame); err != nil { + s.mu.Lock() + if !s.headerDone { + close(s.headerChan) + s.headerDone = true } + s.mu.Unlock() + s.write(recvMsg{err: err}) + // Something wrong. Stops reading even when there is remaining. + return } endStream := frame.StreamEnded() diff --git a/transport/http_util.go b/transport/http_util.go index 795d5d18..9b31717c 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -40,6 +40,7 @@ import ( "fmt" "io" "net" + "net/http" "strconv" "strings" "sync/atomic" @@ -88,6 +89,24 @@ var ( codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm, codes.PermissionDenied: http2.ErrCodeInadequateSecurity, } + httpStatusConvTab = map[int]codes.Code{ + // 400 Bad Request - INTERNAL. + http.StatusBadRequest: codes.Internal, + // 401 Unauthorized - UNAUTHENTICATED. + http.StatusUnauthorized: codes.Unauthenticated, + // 403 Forbidden - PERMISSION_DENIED. + http.StatusForbidden: codes.PermissionDenied, + // 404 Not Found - UNIMPLEMENTED. + http.StatusNotFound: codes.Unimplemented, + // 429 Too Many Requests - UNAVAILABLE. + http.StatusTooManyRequests: codes.Unavailable, + // 502 Bad Gateway - UNAVAILABLE. + http.StatusBadGateway: codes.Unavailable, + // 503 Service Unavailable - UNAVAILABLE. + http.StatusServiceUnavailable: codes.Unavailable, + // 504 Gateway timeout - UNAVAILABLE. + http.StatusGatewayTimeout: codes.Unavailable, + } ) // Records the states during HPACK decoding. Must be reset once the @@ -100,8 +119,9 @@ type decodeState struct { statusGen *status.Status // rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not // intended for direct access outside of parsing. - rawStatusCode int32 + rawStatusCode *int rawStatusMsg string + httpStatus *int // Server side only fields. timeoutSet bool timeout time.Duration @@ -159,7 +179,7 @@ func validContentType(t string) bool { func (d *decodeState) status() *status.Status { if d.statusGen == nil { // No status-details were provided; generate status using code/msg. - d.statusGen = status.New(codes.Code(d.rawStatusCode), d.rawStatusMsg) + d.statusGen = status.New(codes.Code(int32(*(d.rawStatusCode))), d.rawStatusMsg) } return d.statusGen } @@ -193,6 +213,44 @@ func decodeMetadataHeader(k, v string) (string, error) { return v, nil } +func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error { + for _, hf := range frame.Fields { + if err := d.processHeaderField(hf); err != nil { + return err + } + } + + // If grpc status exists, no need to check further. + if d.rawStatusCode != nil || d.statusGen != nil { + return nil + } + + // If grpc status doesn't exist and http status doesn't exist, + // then it's a malformed header. + if d.httpStatus == nil { + return streamErrorf(codes.Internal, "malformed header: doesn't contain status(gRPC or HTTP)") + } + + if *(d.httpStatus) != http.StatusOK { + code, ok := httpStatusConvTab[*(d.httpStatus)] + if !ok { + code = codes.Unknown + } + return streamErrorf(code, http.StatusText(*(d.httpStatus))) + } + + // gRPC status doesn't exist and http status is OK. + // Set rawStatusCode to be unknown and return nil error. + // So that, if the stream has ended this Unknown status + // will be propogated to the user. + // Otherwise, it will be ignored. In which case, status from + // a later trailer, that has StreamEnded flag set, is propogated. + code := int(codes.Unknown) + d.rawStatusCode = &code + return nil + +} + func (d *decodeState) processHeaderField(f hpack.HeaderField) error { switch f.Name { case "content-type": @@ -206,7 +264,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error { if err != nil { return streamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err) } - d.rawStatusCode = int32(code) + d.rawStatusCode = &code case "grpc-message": d.rawStatusMsg = decodeGrpcMessage(f.Value) case "grpc-status-details-bin": @@ -227,6 +285,12 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) error { } case ":path": d.method = f.Value + case ":status": + code, err := strconv.Atoi(f.Value) + if err != nil { + return streamErrorf(codes.Internal, "transport: malformed http-status: %v", err) + } + d.httpStatus = &code default: if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) { if d.mdata == nil { diff --git a/transport/transport_test.go b/transport/transport_test.go index 7429f2e2..0b534d2e 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -34,11 +34,13 @@ package transport import ( + "bufio" "bytes" "fmt" "io" "math" "net" + "net/http" "reflect" "strconv" "strings" @@ -1416,3 +1418,192 @@ func waitWhileTrue(t *testing.T, condition func() (bool, error)) { break } } + +// A function of type writeHeaders writes out +// http status with the given stream ID using the given framer. +type writeHeaders func(*http2.Framer, uint32, int) error + +func writeOneHeader(framer *http2.Framer, sid uint32, httpStatus int) error { + var buf bytes.Buffer + henc := hpack.NewEncoder(&buf) + henc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(httpStatus)}) + if err := framer.WriteHeaders(http2.HeadersFrameParam{ + StreamID: sid, + BlockFragment: buf.Bytes(), + EndStream: true, + EndHeaders: true, + }); err != nil { + return err + } + return nil +} + +func writeTwoHeaders(framer *http2.Framer, sid uint32, httpStatus int) error { + var buf bytes.Buffer + henc := hpack.NewEncoder(&buf) + henc.WriteField(hpack.HeaderField{ + Name: ":status", + Value: fmt.Sprint(http.StatusOK), + }) + if err := framer.WriteHeaders(http2.HeadersFrameParam{ + StreamID: sid, + BlockFragment: buf.Bytes(), + EndHeaders: true, + }); err != nil { + return err + } + buf.Reset() + henc.WriteField(hpack.HeaderField{ + Name: ":status", + Value: fmt.Sprint(httpStatus), + }) + if err := framer.WriteHeaders(http2.HeadersFrameParam{ + StreamID: sid, + BlockFragment: buf.Bytes(), + EndStream: true, + EndHeaders: true, + }); err != nil { + return err + } + return nil +} + +type httpServer struct { + conn net.Conn + httpStatus int + wh writeHeaders +} + +func (s *httpServer) start(t *testing.T, lis net.Listener) { + // Launch an HTTP server to send back header with httpStatus. + go func() { + var err error + s.conn, err = lis.Accept() + if err != nil { + t.Errorf("Error accepting connection: %v", err) + return + } + defer s.conn.Close() + // Read preface sent by client. + if _, err = io.ReadFull(s.conn, make([]byte, len(http2.ClientPreface))); err != nil { + t.Errorf("Error at server-side while reading preface from cleint. Err: %v", err) + return + } + reader := bufio.NewReaderSize(s.conn, http2IOBufSize) + writer := bufio.NewWriterSize(s.conn, http2IOBufSize) + framer := http2.NewFramer(writer, reader) + if err = framer.WriteSettingsAck(); err != nil { + t.Errorf("Error at server-side while sending Settings ack. Err: %v", err) + return + } + var sid uint32 + // Read frames until a header is received. + for { + frame, err := framer.ReadFrame() + if err != nil { + t.Errorf("Error at server-side while reading frame. Err: %v", err) + return + } + if hframe, ok := frame.(*http2.HeadersFrame); ok { + sid = hframe.Header().StreamID + break + } + } + if err = s.wh(framer, sid, s.httpStatus); err != nil { + t.Errorf("Error at server-side while writing headers. Err: %v", err) + return + } + writer.Flush() + }() +} + +func (s *httpServer) cleanUp() { + if s.conn != nil { + s.conn.Close() + } +} + +func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (stream *Stream, cleanUp func()) { + var ( + err error + lis net.Listener + server *httpServer + client ClientTransport + ) + cleanUp = func() { + if lis != nil { + lis.Close() + } + if server != nil { + server.cleanUp() + } + if client != nil { + client.Close() + } + } + defer func() { + if err != nil { + cleanUp() + } + }() + lis, err = net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen. Err: %v", err) + } + server = &httpServer{ + httpStatus: httpStatus, + wh: wh, + } + server.start(t, lis) + client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}) + if err != nil { + t.Fatalf("Error creating client. Err: %v", err) + } + stream, err = client.NewStream(context.Background(), &CallHdr{Method: "bogus/method", Flush: true}) + if err != nil { + t.Fatalf("Error creating stream at client-side. Err: %v", err) + } + return +} + +func TestHTTPToGRPCStatusMapping(t *testing.T) { + for k := range httpStatusConvTab { + testHTTPToGRPCStatusMapping(t, k, writeOneHeader) + } +} + +func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders) { + stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh) + defer cleanUp() + want := httpStatusConvTab[httpStatus] + _, err := stream.Read([]byte{}) + if err == nil { + t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want) + } + serr, ok := err.(StreamError) + if !ok { + t.Fatalf("err.(Type) = %T, want StreamError", err) + } + if want != serr.Code { + t.Fatalf("Want error code: %v, got: %v", want, serr.Code) + } +} + +func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) { + stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader) + defer cleanUp() + _, err := stream.Read([]byte{}) + if err != io.EOF { + t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err) + } + want := codes.Unknown + stream.mu.Lock() + defer stream.mu.Unlock() + if stream.status.Code() != want { + t.Fatalf("Status code of stream: %v, want: %v", stream.status.Code(), want) + } +} + +func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) { + testHTTPToGRPCStatusMapping(t, http.StatusUnauthorized, writeTwoHeaders) +}