diff --git a/server.go b/server.go index 58827c53..bfb9c606 100644 --- a/server.go +++ b/server.go @@ -464,6 +464,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } }() } + if s.opts.cp != nil { + // NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686. + stream.SetSendCompress(s.opts.cp.Type()) + } p := &parser{r: stream} for { pf, req, err := p.recvMsg() @@ -549,9 +553,6 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Last: true, Delay: false, } - if s.opts.cp != nil { - stream.SetSendCompress(s.opts.cp.Type()) - } if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil { switch err := err.(type) { case transport.ConnectionError: diff --git a/test/end2end_test.go b/test/end2end_test.go index bf1622d8..09bcc392 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -130,6 +130,9 @@ func newPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error) { func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { md, ok := metadata.FromContext(ctx) if ok { + if _, exists := md[":authority"]; !exists { + return nil, grpc.Errorf(codes.DataLoss, "expected an :authority metadata: %v", md) + } if err := grpc.SendHeader(ctx, md); err != nil { return nil, fmt.Errorf("grpc.SendHeader(%v, %v) = %v, want %v", ctx, md, err, nil) } @@ -167,6 +170,7 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* if err != nil { return nil, err } + return &testpb.SimpleResponse{ Payload: payload, }, nil @@ -174,8 +178,11 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error { if md, ok := metadata.FromContext(stream.Context()); ok { - // For testing purpose, returns an error if there is attached metadata. - if len(md) > 0 { + if _, exists := md[":authority"]; !exists { + return grpc.Errorf(codes.DataLoss, "expected an :authority metadata: %v", md) + } + // For testing purpose, returns an error if there is attached metadata except for authority. + if len(md) > 1 { return grpc.Errorf(codes.DataLoss, "got extra metadata") } } @@ -1733,7 +1740,8 @@ func testCompressOK(t *testing.T, e env) { ResponseSize: proto.Int32(respSize), Payload: payload, } - if _, err := tc.UnaryCall(context.Background(), req); err != nil { + ctx := metadata.NewContext(context.Background(), metadata.Pairs("something", "something")) + if _, err := tc.UnaryCall(ctx, req); err != nil { t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) } // Streaming RPC diff --git a/transport/handler_server.go b/transport/handler_server.go index fef541db..7a4ae07b 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -92,9 +92,12 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr } var metakv []string + if r.Host != "" { + metakv = append(metakv, ":authority", r.Host) + } for k, vv := range r.Header { k = strings.ToLower(k) - if isReservedHeader(k) { + if isReservedHeader(k) && !isWhitelistedPseudoHeader(k){ continue } for _, v := range vv { @@ -108,7 +111,6 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr } } metakv = append(metakv, k, v) - } } st.headerMD = metadata.Pairs(metakv...) @@ -196,6 +198,10 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, } if md := s.Trailer(); len(md) > 0 { for k, vv := range md { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + if isReservedHeader(k) { + continue + } for _, v := range vv { // http2 ResponseWriter mechanism to // send undeclared Trailers after the @@ -249,6 +255,10 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { ht.writeCommonHeaders(s) h := ht.rw.Header() for k, vv := range md { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + if isReservedHeader(k) { + continue + } for _, v := range vv { h.Add(k, v) } diff --git a/transport/http2_client.go b/transport/http2_client.go index be521ffb..459d14d6 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -344,6 +344,10 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea if md, ok := metadata.FromContext(ctx); ok { hasMD = true for k, v := range md { + // HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set. + if isReservedHeader(k) { + continue + } for _, entry := range v { t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) } diff --git a/transport/http2_server.go b/transport/http2_server.go index 21b63116..1c4d5852 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -460,6 +460,10 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) } for k, v := range md { + if isReservedHeader(k) { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + continue + } for _, entry := range v { t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) } @@ -502,6 +506,10 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc}) // Attach the trailer metadata. for k, v := range s.trailer { + // Clients don't tolerate reading restricted headers after some non restricted ones were sent. + if isReservedHeader(k) { + continue + } for _, entry := range v { t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) } diff --git a/transport/http_util.go b/transport/http_util.go index 7a3594ac..a4b1b07d 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -127,6 +127,17 @@ func isReservedHeader(hdr string) bool { } } +// isWhitelistedPseudoHeader checks whether hdr belongs to HTTP2 pseudoheaders +// that should be propagated into metadata visible to users. +func isWhitelistedPseudoHeader(hdr string) bool { + switch hdr { + case ":authority": + return true + default: + return false + } +} + func (d *decodeState) setErr(err error) { if d.err == nil { d.err = err @@ -162,7 +173,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { case ":path": d.method = f.Value default: - if !isReservedHeader(f.Name) { + if !isReservedHeader(f.Name) || isWhitelistedPseudoHeader(f.Name) { if f.Name == "user-agent" { i := strings.LastIndex(f.Value, " ") if i == -1 {