diff --git a/clientconn.go b/clientconn.go index 199f7747..78706f69 100644 --- a/clientconn.go +++ b/clientconn.go @@ -473,6 +473,14 @@ func defaultDialOptions() dialOptions { } } +// WithMaxHeaderListSize returns a DialOption that specifies the maximum (uncompressed) size of +// header list that the client is prepared to accept. +func WithMaxHeaderListSize(s uint32) DialOption { + return func(o *dialOptions) { + o.copts.MaxHeaderListSize = &s + } +} + // Dial creates a client connection to the given target. func Dial(target string, opts ...DialOption) (*ClientConn, error) { return DialContext(context.Background(), target, opts...) diff --git a/server.go b/server.go index e29bd2c4..a85297bc 100644 --- a/server.go +++ b/server.go @@ -135,6 +135,7 @@ type options struct { writeBufferSize int readBufferSize int connectionTimeout time.Duration + maxHeaderListSize *uint32 } var defaultServerOptions = options{ @@ -343,6 +344,14 @@ func ConnectionTimeout(d time.Duration) ServerOption { } } +// MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size +// of header list that the server is prepared to accept. +func MaxHeaderListSize(s uint32) ServerOption { + return func(o *options) { + o.maxHeaderListSize = &s + } +} + // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -665,6 +674,7 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr WriteBufferSize: s.opts.writeBufferSize, ReadBufferSize: s.opts.readBufferSize, ChannelzParentID: s.channelzID, + MaxHeaderListSize: s.opts.maxHeaderListSize, } st, err := transport.NewServerTransport("http2", c, config) if err != nil { diff --git a/test/end2end_test.go b/test/end2end_test.go index e47b7534..0087df56 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -452,6 +452,8 @@ type test struct { maxClientSendMsgSize *int maxServerReceiveMsgSize *int maxServerSendMsgSize *int + maxClientHeaderListSize *uint32 + maxServerHeaderListSize *uint32 userAgent string // clientCompression and serverCompression are set to test the deprecated API // WithCompressor and WithDecompressor. @@ -546,6 +548,9 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network, if te.maxServerSendMsgSize != nil { sopts = append(sopts, grpc.MaxSendMsgSize(*te.maxServerSendMsgSize)) } + if te.maxServerHeaderListSize != nil { + sopts = append(sopts, grpc.MaxHeaderListSize(*te.maxServerHeaderListSize)) + } if te.tapHandle != nil { sopts = append(sopts, grpc.InTapHandle(te.tapHandle)) } @@ -697,6 +702,9 @@ func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string) if te.maxClientSendMsgSize != nil { opts = append(opts, grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(*te.maxClientSendMsgSize))) } + if te.maxClientHeaderListSize != nil { + opts = append(opts, grpc.WithMaxHeaderListSize(*te.maxClientHeaderListSize)) + } switch te.e.security { case "tls": creds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), "x.test.youtube.com") @@ -6454,3 +6462,166 @@ func TestDisabledIOBuffers(t *testing.T) { t.Fatalf("stream.Recv() = _, %v, want _, io.EOF", err) } } + +func TestServerMaxHeaderListSizeClientUserViolation(t *testing.T) { + defer leakcheck.Check(t) + for _, e := range listTestEnv() { + if e.httpHandler { + continue + } + testServerMaxHeaderListSizeClientUserViolation(t, e) + } +} + +func testServerMaxHeaderListSizeClientUserViolation(t *testing.T, e env) { + te := newTest(t, e) + te.maxServerHeaderListSize = new(uint32) + *te.maxServerHeaderListSize = 216 + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + metadata.AppendToOutgoingContext(ctx, "oversize", string(make([]byte, 216))) + var err error + if err = verifyResultWithDelay(func() (bool, error) { + if _, err = tc.EmptyCall(ctx, &testpb.Empty{}); err != nil && status.Code(err) == codes.Internal { + return true, nil + } + return false, fmt.Errorf("tc.EmptyCall() = _, err: %v, want _, error code: %v", err, codes.Internal) + }); err != nil { + t.Fatal(err) + } +} + +func TestClientMaxHeaderListSizeServerUserViolation(t *testing.T) { + defer leakcheck.Check(t) + for _, e := range listTestEnv() { + if e.httpHandler == true { + continue + } + testClientMaxHeaderListSizeServerUserViolation(t, e) + } +} + +func testClientMaxHeaderListSizeServerUserViolation(t *testing.T, e env) { + te := newTest(t, e) + te.maxClientHeaderListSize = new(uint32) + *te.maxClientHeaderListSize = 1 // any header server sends will violate + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var err error + if err = verifyResultWithDelay(func() (bool, error) { + if _, err = tc.EmptyCall(ctx, &testpb.Empty{}); err != nil && status.Code(err) == codes.Internal { + return true, nil + } + return false, fmt.Errorf("tc.EmptyCall() = _, err: %v, want _, error code: %v", err, codes.Internal) + }); err != nil { + t.Fatal(err) + } +} + +func TestServerMaxHeaderListSizeClientIntentionalViolation(t *testing.T) { + defer leakcheck.Check(t) + for _, e := range listTestEnv() { + if e.httpHandler == true || e.security == "tls" { + continue + } + testServerMaxHeaderListSizeClientIntentionalViolation(t, e) + } +} + +func testServerMaxHeaderListSizeClientIntentionalViolation(t *testing.T, e env) { + te := newTest(t, e) + te.maxServerHeaderListSize = new(uint32) + *te.maxServerHeaderListSize = 512 + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + + cc, dw := te.clientConnWithConnControl() + tc := &testServiceClientWrapper{TestServiceClient: testpb.NewTestServiceClient(cc)} + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + stream, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want _, <nil>", tc, err) + } + rcw := dw.getRawConnWrapper() + val := make([]string, 512) + for i := range val { + val[i] = "a" + } + // allow for client to send the initial header + time.Sleep(100 * time.Millisecond) + rcw.writeHeaders(http2.HeadersFrameParam{ + StreamID: tc.getCurrentStreamID(), + BlockFragment: rcw.encodeHeader("oversize", strings.Join(val, "")), + EndStream: false, + EndHeaders: true, + }) + if _, err := stream.Recv(); err == nil || status.Code(err) != codes.Internal { + t.Fatalf("stream.Recv() = _, %v, want _, error code: %v", err, codes.Internal) + } +} + +func TestClientMaxHeaderListSizeServerIntentionalViolation(t *testing.T) { + defer leakcheck.Check(t) + for _, e := range listTestEnv() { + if e.httpHandler == true || e.security == "tls" { + continue + } + testClientMaxHeaderListSizeServerIntentionalViolation(t, e) + } +} + +func testClientMaxHeaderListSizeServerIntentionalViolation(t *testing.T, e env) { + te := newTest(t, e) + te.maxClientHeaderListSize = new(uint32) + *te.maxClientHeaderListSize = 200 + lw := te.startServerWithConnControl(&testServer{security: e.security, setHeaderOnly: true}) + defer te.tearDown() + cc, _ := te.clientConnWithConnControl() + tc := &testServiceClientWrapper{TestServiceClient: testpb.NewTestServiceClient(cc)} + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + stream, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want _, <nil>", tc, err) + } + var i int + var rcw *rawConnWrapper + for i = 0; i < 100; i++ { + rcw = lw.getLastConn() + if rcw != nil { + break + } + time.Sleep(10 * time.Millisecond) + continue + } + if i == 100 { + t.Fatalf("failed to create server transport after 1s") + } + + val := make([]string, 200) + for i := range val { + val[i] = "a" + } + // allow for client to send the initial header. + time.Sleep(100 * time.Millisecond) + rcw.writeHeaders(http2.HeadersFrameParam{ + StreamID: tc.getCurrentStreamID(), + BlockFragment: rcw.encodeHeader("oversize", strings.Join(val, "")), + EndStream: false, + EndHeaders: true, + }) + if _, err := stream.Recv(); err == nil || status.Code(err) != codes.Internal { + t.Fatalf("stream.Recv() = _, %v, want _, error code: %v", err, codes.Internal) + } +} diff --git a/transport/controlbuf.go b/transport/controlbuf.go index 853d9ef3..ce135c4d 100644 --- a/transport/controlbuf.go +++ b/transport/controlbuf.go @@ -285,6 +285,21 @@ func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{ return true, nil } +// Note argument f should never be nil. +func (c *controlBuffer) execute(f func(it interface{}) bool, it interface{}) (bool, error) { + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return false, c.err + } + if !f(it) { // f wasn't successful + c.mu.Unlock() + return false, nil + } + c.mu.Unlock() + return true, nil +} + func (c *controlBuffer) get(block bool) (interface{}, error) { for { c.mu.Lock() diff --git a/transport/defaults.go b/transport/defaults.go new file mode 100644 index 00000000..9fa306b2 --- /dev/null +++ b/transport/defaults.go @@ -0,0 +1,49 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport + +import ( + "math" + "time" +) + +const ( + // The default value of flow control window size in HTTP2 spec. + defaultWindowSize = 65535 + // The initial window size for flow control. + initialWindowSize = defaultWindowSize // for an RPC + infinity = time.Duration(math.MaxInt64) + defaultClientKeepaliveTime = infinity + defaultClientKeepaliveTimeout = 20 * time.Second + defaultMaxStreamsClient = 100 + defaultMaxConnectionIdle = infinity + defaultMaxConnectionAge = infinity + defaultMaxConnectionAgeGrace = infinity + defaultServerKeepaliveTime = 2 * time.Hour + defaultServerKeepaliveTimeout = 20 * time.Second + defaultKeepalivePolicyMinTime = 5 * time.Minute + // max window limit set by HTTP2 Specs. + maxWindowSize = math.MaxInt32 + // defaultWriteQuota is the default value for number of data + // bytes that each stream can schedule before some of it being + // flushed out. + defaultWriteQuota = 64 * 1024 + defaultClientMaxHeaderListSize = uint32(16 << 20) + defaultServerMaxHeaderListSize = uint32(16 << 20) +) diff --git a/transport/flowcontrol.go b/transport/flowcontrol.go index bbf98b6f..5ea997a7 100644 --- a/transport/flowcontrol.go +++ b/transport/flowcontrol.go @@ -23,30 +23,6 @@ import ( "math" "sync" "sync/atomic" - "time" -) - -const ( - // The default value of flow control window size in HTTP2 spec. - defaultWindowSize = 65535 - // The initial window size for flow control. - initialWindowSize = defaultWindowSize // for an RPC - infinity = time.Duration(math.MaxInt64) - defaultClientKeepaliveTime = infinity - defaultClientKeepaliveTimeout = 20 * time.Second - defaultMaxStreamsClient = 100 - defaultMaxConnectionIdle = infinity - defaultMaxConnectionAge = infinity - defaultMaxConnectionAgeGrace = infinity - defaultServerKeepaliveTime = 2 * time.Hour - defaultServerKeepaliveTimeout = 20 * time.Second - defaultKeepalivePolicyMinTime = 5 * time.Minute - // max window limit set by HTTP2 Specs. - maxWindowSize = math.MaxInt32 - // defaultWriteQuota is the default value for number of data - // bytes that each stream can schedule before some of it being - // flushed out. - defaultWriteQuota = 64 * 1024 ) // writeQuota is a soft limit on the amount of data a stream can diff --git a/transport/http2_client.go b/transport/http2_client.go index 528efcd4..3f11089a 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -85,6 +85,9 @@ type http2Client struct { initialWindowSize int32 + // configured by peer through SETTINGS_MAX_HEADER_LIST_SIZE + maxSendHeaderListSize *uint32 + bdpEst *bdpEstimator // onSuccess is a callback that client transport calls upon // receiving server preface to signal that a succefull HTTP2 @@ -199,6 +202,10 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne } writeBufSize := opts.WriteBufferSize readBufSize := opts.ReadBufferSize + maxHeaderListSize := defaultClientMaxHeaderListSize + if opts.MaxHeaderListSize != nil { + maxHeaderListSize = *opts.MaxHeaderListSize + } t := &http2Client{ ctx: ctx, ctxDone: ctx.Done(), // Cache Done chan. @@ -213,7 +220,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne writerDone: make(chan struct{}), goAway: make(chan struct{}), awakenKeepalive: make(chan struct{}, 1), - framer: newFramer(conn, writeBufSize, readBufSize), + framer: newFramer(conn, writeBufSize, readBufSize, maxHeaderListSize), fc: &trInFlow{limit: uint32(icwz)}, scheme: scheme, activeStreams: make(map[uint32]*Stream), @@ -273,14 +280,21 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne t.Close() return nil, connectionErrorf(true, err, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) } + var ss []http2.Setting + if t.initialWindowSize != defaultWindowSize { - err = t.framer.fr.WriteSettings(http2.Setting{ + ss = append(ss, http2.Setting{ ID: http2.SettingInitialWindowSize, Val: uint32(t.initialWindowSize), }) - } else { - err = t.framer.fr.WriteSettings() } + if opts.MaxHeaderListSize != nil { + ss = append(ss, http2.Setting{ + ID: http2.SettingMaxHeaderListSize, + Val: *opts.MaxHeaderListSize, + }) + } + err = t.framer.fr.WriteSettings(ss...) if err != nil { t.Close() return nil, connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err) @@ -588,14 +602,40 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } return true } + var hdrListSizeErr error + checkForHeaderListSize := func(it interface{}) bool { + if t.maxSendHeaderListSize == nil { + return true + } + hdrFrame := it.(*headerFrame) + var sz int64 + for _, f := range hdrFrame.hf { + if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { + hdrListSizeErr = streamErrorf(codes.Internal, "header list size to send violates the maximum size (%d bytes) set by server", *t.maxSendHeaderListSize) + return false + } + } + return true + } for { - success, err := t.controlBuf.executeAndPut(checkForStreamQuota, hdr) + success, err := t.controlBuf.executeAndPut(func(it interface{}) bool { + if !checkForStreamQuota(it) { + return false + } + if !checkForHeaderListSize(it) { + return false + } + return true + }, hdr) if err != nil { return nil, err } if success { break } + if hdrListSizeErr != nil { + return nil, hdrListSizeErr + } firstTry = false select { case <-ch: @@ -917,13 +957,20 @@ func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) { } var maxStreams *uint32 var ss []http2.Setting + var updateFuncs []func() f.ForeachSetting(func(s http2.Setting) error { - if s.ID == http2.SettingMaxConcurrentStreams { + switch s.ID { + case http2.SettingMaxConcurrentStreams: maxStreams = new(uint32) *maxStreams = s.Val - return nil + case http2.SettingMaxHeaderListSize: + updateFuncs = append(updateFuncs, func() { + t.maxSendHeaderListSize = new(uint32) + *t.maxSendHeaderListSize = s.Val + }) + default: + ss = append(ss, s) } - ss = append(ss, s) return nil }) if isFirst && maxStreams == nil { @@ -933,21 +980,24 @@ func (t *http2Client) handleSettings(f *http2.SettingsFrame, isFirst bool) { sf := &incomingSettings{ ss: ss, } - if maxStreams == nil { - t.controlBuf.put(sf) - return + if maxStreams != nil { + updateStreamQuota := func() { + delta := int64(*maxStreams) - int64(t.maxConcurrentStreams) + t.maxConcurrentStreams = *maxStreams + t.streamQuota += delta + if delta > 0 && t.waitingStreams > 0 { + close(t.streamsQuotaAvailable) // wake all of them up. + t.streamsQuotaAvailable = make(chan struct{}, 1) + } + } + updateFuncs = append(updateFuncs, updateStreamQuota) } - updateStreamQuota := func(interface{}) bool { - delta := int64(*maxStreams) - int64(t.maxConcurrentStreams) - t.maxConcurrentStreams = *maxStreams - t.streamQuota += delta - if delta > 0 && t.waitingStreams > 0 { - close(t.streamsQuotaAvailable) // wake all of them up. - t.streamsQuotaAvailable = make(chan struct{}, 1) + t.controlBuf.executeAndPut(func(interface{}) bool { + for _, f := range updateFuncs { + f() } return true - } - t.controlBuf.executeAndPut(updateStreamQuota, sf) + }, sf) } func (t *http2Client) handlePing(f *http2.PingFrame) { @@ -1058,7 +1108,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } atomic.StoreUint32(&s.bytesReceived, 1) var state decodeState - if err := state.decodeResponseHeader(frame); err != nil { + if err := state.decodeHeader(frame); err != nil { t.closeStream(s, err, true, http2.ErrCodeProtocol, status.New(codes.Internal, err.Error()), nil, false) // Something wrong. Stops reading even when there is remaining. return diff --git a/transport/http2_server.go b/transport/http2_server.go index 6b1ceabe..4a9a6753 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -48,9 +48,14 @@ import ( "google.golang.org/grpc/tap" ) -// ErrIllegalHeaderWrite indicates that setting header is illegal because of -// the stream's state. -var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called") +var ( + // ErrIllegalHeaderWrite indicates that setting header is illegal because of + // the stream's state. + ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called") + // ErrHeaderListSizeLimitViolation indicates that the header list size is larger + // than the limit set by peer. + ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") +) // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { @@ -89,9 +94,10 @@ type http2Server struct { // Flag to signify that number of ping strikes should be reset to 0. // This is set whenever data or header frames are sent. // 1 means yes. - resetPingStrikes uint32 // Accessed atomically. - initialWindowSize int32 - bdpEst *bdpEstimator + resetPingStrikes uint32 // Accessed atomically. + initialWindowSize int32 + bdpEst *bdpEstimator + maxSendHeaderListSize *uint32 mu sync.Mutex // guard the following @@ -132,7 +138,11 @@ type http2Server struct { func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { writeBufSize := config.WriteBufferSize readBufSize := config.ReadBufferSize - framer := newFramer(conn, writeBufSize, readBufSize) + maxHeaderListSize := defaultServerMaxHeaderListSize + if config.MaxHeaderListSize != nil { + maxHeaderListSize = *config.MaxHeaderListSize + } + framer := newFramer(conn, writeBufSize, readBufSize, maxHeaderListSize) // Send initial settings as connection preface to client. var isettings []http2.Setting // TODO(zhaoq): Have a better way to signal "no limit" because 0 is @@ -162,6 +172,12 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err ID: http2.SettingInitialWindowSize, Val: uint32(iwz)}) } + if config.MaxHeaderListSize != nil { + isettings = append(isettings, http2.Setting{ + ID: http2.SettingMaxHeaderListSize, + Val: *config.MaxHeaderListSize, + }) + } if err := framer.fr.WriteSettings(isettings...); err != nil { return nil, connectionErrorf(false, err, "transport: %v", err) } @@ -281,19 +297,17 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err // operateHeader takes action on the decoded headers. func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) { streamID := frame.Header().StreamID - var state decodeState - for _, hf := range frame.Fields { - if err := state.processHeaderField(hf); err != nil { - if se, ok := err.(StreamError); ok { - t.controlBuf.put(&cleanupStream{ - streamID: streamID, - rst: true, - rstCode: statusCodeConvTab[se.Code], - onWrite: func() {}, - }) - } - return + state := decodeState{serverSide: true} + if err := state.decodeHeader(frame); err != nil { + if se, ok := err.(StreamError); ok { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: statusCodeConvTab[se.Code], + onWrite: func() {}, + }) } + return } buf := newRecvBuffer() @@ -613,11 +627,25 @@ func (t *http2Server) handleSettings(f *http2.SettingsFrame) { return } var ss []http2.Setting + var updateFuncs []func() f.ForeachSetting(func(s http2.Setting) error { - ss = append(ss, s) + switch s.ID { + case http2.SettingMaxHeaderListSize: + updateFuncs = append(updateFuncs, func() { + t.maxSendHeaderListSize = new(uint32) + *t.maxSendHeaderListSize = s.Val + }) + default: + ss = append(ss, s) + } return nil }) - t.controlBuf.put(&incomingSettings{ + t.controlBuf.executeAndPut(func(interface{}) bool { + for _, f := range updateFuncs { + f() + } + return true + }, &incomingSettings{ ss: ss, }) } @@ -697,6 +725,21 @@ func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD) return headerFields } +func (t *http2Server) checkForHeaderListSize(it interface{}) bool { + if t.maxSendHeaderListSize == nil { + return true + } + hdrFrame := it.(*headerFrame) + var sz int64 + for _, f := range hdrFrame.hf { + if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { + errorf("header list size to send violates the maximum size (%d bytes) set by client", *t.maxSendHeaderListSize) + return false + } + } + return true +} + // WriteHeader sends the header metedata md back to the client. func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { if s.updateHeaderSent() || s.getState() == streamDone { @@ -710,12 +753,15 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { s.header = md } } - t.writeHeaderLocked(s) + if err := t.writeHeaderLocked(s); err != nil { + s.hdrMu.Unlock() + return err + } s.hdrMu.Unlock() return nil } -func (t *http2Server) writeHeaderLocked(s *Stream) { +func (t *http2Server) writeHeaderLocked(s *Stream) error { // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields // first and create a slice of that exact size. headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. @@ -725,7 +771,7 @@ func (t *http2Server) writeHeaderLocked(s *Stream) { headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) } headerFields = appendHeaderFieldsFromMD(headerFields, s.header) - t.controlBuf.put(&headerFrame{ + success, err := t.controlBuf.executeAndPut(t.checkForHeaderListSize, &headerFrame{ streamID: s.id, hf: headerFields, endStream: false, @@ -733,12 +779,20 @@ func (t *http2Server) writeHeaderLocked(s *Stream) { atomic.StoreUint32(&t.resetPingStrikes, 1) }, }) + if !success { + if err != nil { + return err + } + t.closeStream(s, true, http2.ErrCodeInternal, nil, false) + return ErrHeaderListSizeLimitViolation + } if t.stats != nil { // Note: WireLength is not set in outHeader. // TODO(mmukhi): Revisit this later, if needed. outHeader := &stats.OutHeader{} t.stats.HandleRPC(s.Context(), outHeader) } + return nil } // WriteStatus sends stream status to the client and terminates the stream. @@ -755,7 +809,10 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else. if !s.updateHeaderSent() { // No headers have been sent. if len(s.header) > 0 { // Send a separate header frame. - t.writeHeaderLocked(s) + if err := t.writeHeaderLocked(s); err != nil { + s.hdrMu.Unlock() + return err + } } else { // Send a trailer only response. headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) @@ -785,6 +842,14 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { }, } s.hdrMu.Unlock() + success, err := t.controlBuf.execute(t.checkForHeaderListSize, trailingHeader) + if !success { + if err != nil { + return err + } + t.closeStream(s, true, http2.ErrCodeInternal, nil, false) + return ErrHeaderListSizeLimitViolation + } t.closeStream(s, false, 0, trailingHeader, true) if t.stats != nil { t.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) diff --git a/transport/http_util.go b/transport/http_util.go index 0c1b2b1c..dea43e7e 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -119,6 +119,8 @@ type decodeState struct { statsTags []byte statsTrace []byte contentSubtype string + // whether decoding on server side or not + serverSide bool } // isReservedHeader checks whether hdr belongs to HTTP2 headers @@ -235,13 +237,22 @@ func decodeMetadataHeader(k, v string) (string, error) { return v, nil } -func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error { +func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) error { + // frame.Truncated is set to true when framer detects that the current header + // list size hits MaxHeaderListSize limit. + if frame.Truncated { + return streamErrorf(codes.Internal, "peer header list size exceeded limit") + } for _, hf := range frame.Fields { if err := d.processHeaderField(hf); err != nil { return err } } + if d.serverSide { + return nil + } + // If grpc status exists, no need to check further. if d.rawStatusCode != nil || d.statusGen != nil { return nil @@ -270,7 +281,6 @@ func (d *decodeState) decodeResponseHeader(frame *http2.MetaHeadersFrame) error code := int(codes.Unknown) d.rawStatusCode = &code return nil - } func (d *decodeState) addMetadata(k, v string) { @@ -581,7 +591,7 @@ type framer struct { fr *http2.Framer } -func newFramer(conn net.Conn, writeBufferSize, readBufferSize int) *framer { +func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, maxHeaderListSize uint32) *framer { if writeBufferSize < 0 { writeBufferSize = 0 } @@ -597,6 +607,7 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int) *framer { // Opt-in to Frame reuse API on framer to reduce garbage. // Frames aren't safe to read from after a subsequent call to ReadFrame. f.fr.SetReuseFrames() + f.fr.MaxHeaderListSize = maxHeaderListSize f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil) return f } diff --git a/transport/transport.go b/transport/transport.go index 62d6e6bb..e724ca93 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -454,6 +454,7 @@ type ServerConfig struct { WriteBufferSize int ReadBufferSize int ChannelzParentID int64 + MaxHeaderListSize *uint32 } // NewServerTransport creates a ServerTransport with conn or non-nil error @@ -491,6 +492,8 @@ type ConnectOptions struct { ReadBufferSize int // ChannelzParentID sets the addrConn id which initiate the creation of this client transport. ChannelzParentID int64 + // MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received. + MaxHeaderListSize *uint32 } // TargetInfo contains the information of the target such as network address and metadata. diff --git a/transport/transport_test.go b/transport/transport_test.go index 59045290..fb95e7c5 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -595,7 +595,7 @@ func TestKeepaliveServer(t *testing.T) { if n, err := client.Write(clientPreface); err != nil || n != len(clientPreface) { t.Fatalf("Error writing client preface; n=%v, err=%v", n, err) } - framer := newFramer(client, defaultWriteBufSize, defaultReadBufSize) + framer := newFramer(client, defaultWriteBufSize, defaultReadBufSize, 0) if err := framer.fr.WriteSettings(http2.Setting{}); err != nil { t.Fatal("Error writing settings frame:", err) }