From 0b4b292f3a440791d15e33b79d98ee984e775757 Mon Sep 17 00:00:00 2001 From: dfawley Date: Fri, 25 Aug 2017 13:18:45 -0700 Subject: [PATCH] transport: fix handling of InTapHandle's returned context (#1461) --- test/end2end_test.go | 49 +++++++++++++++++++++++++++----- transport/http2_server.go | 60 +++++++++++++++++++++------------------ 2 files changed, 75 insertions(+), 34 deletions(-) diff --git a/test/end2end_test.go b/test/end2end_test.go index 1021744f..19f76342 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -4477,14 +4477,14 @@ func (ss *stubServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallSer } // Start starts the server and creates a client connected to it. -func (ss *stubServer) Start() error { +func (ss *stubServer) Start(sopts []grpc.ServerOption) error { lis, err := net.Listen("tcp", "localhost:0") if err != nil { return fmt.Errorf(`net.Listen("tcp", "localhost:0") = %v`, err) } ss.cleanups = append(ss.cleanups, func() { lis.Close() }) - s := grpc.NewServer() + s := grpc.NewServer(sopts...) testpb.RegisterTestServiceServer(s, ss) go s.Serve(lis) ss.cleanups = append(ss.cleanups, s.Stop) @@ -4517,7 +4517,7 @@ func TestUnaryProxyDoesNotForwardMetadata(t *testing.T) { return &testpb.Empty{}, nil }, } - if err := endpoint.Start(); err != nil { + if err := endpoint.Start(nil); err != nil { t.Fatalf("Error starting endpoint server: %v", err) } defer endpoint.Stop() @@ -4532,7 +4532,7 @@ func TestUnaryProxyDoesNotForwardMetadata(t *testing.T) { return endpoint.client.EmptyCall(ctx, in) }, } - if err := proxy.Start(); err != nil { + if err := proxy.Start(nil); err != nil { t.Fatalf("Error starting proxy server: %v", err) } defer proxy.Stop() @@ -4580,7 +4580,7 @@ func TestStreamingProxyDoesNotForwardMetadata(t *testing.T) { return nil }, } - if err := endpoint.Start(); err != nil { + if err := endpoint.Start(nil); err != nil { t.Fatalf("Error starting endpoint server: %v", err) } defer endpoint.Stop() @@ -4596,7 +4596,7 @@ func TestStreamingProxyDoesNotForwardMetadata(t *testing.T) { return doFDC(ctx, endpoint.client) }, } - if err := proxy.Start(); err != nil { + if err := proxy.Start(nil); err != nil { t.Fatalf("Error starting proxy server: %v", err) } defer proxy.Stop() @@ -4642,7 +4642,7 @@ func TestStatsTagsAndTrace(t *testing.T) { return &testpb.Empty{}, nil }, } - if err := endpoint.Start(); err != nil { + if err := endpoint.Start(nil); err != nil { t.Fatalf("Error starting endpoint server: %v", err) } defer endpoint.Stop() @@ -4672,6 +4672,41 @@ func TestStatsTagsAndTrace(t *testing.T) { } } +func TestTapTimeout(t *testing.T) { + sopts := []grpc.ServerOption{ + grpc.InTapHandle(func(ctx context.Context, _ *tap.Info) (context.Context, error) { + c, cancel := context.WithCancel(ctx) + // Call cancel instead of setting a deadline so we can detect which error + // occurred -- this cancellation (desired) or the client's deadline + // expired (indicating this cancellation did not affect the RPC). + time.AfterFunc(10*time.Millisecond, cancel) + return c, nil + }), + } + + ss := &stubServer{ + emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + <-ctx.Done() + return &testpb.Empty{}, nil + }, + } + if err := ss.Start(sopts); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + // This was known to be flaky; test several times. + for i := 0; i < 10; i++ { + // Set our own deadline in case the server hangs. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + res, err := ss.client.EmptyCall(ctx, &testpb.Empty{}) + cancel() + if s, ok := status.FromError(err); !ok || s.Code() != codes.Canceled { + t.Fatalf("ss.client.EmptyCall(context.Background(), _) = %v, %v; want nil, ", res, err) + } + } +} + type windowSizeConfig struct { serverStream int32 serverConn int32 diff --git a/transport/http2_server.go b/transport/http2_server.go index 302651b5..eeeacf77 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -230,29 +230,32 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err // operateHeader takes action on the decoded headers. func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) { - buf := newRecvBuffer() - s := &Stream{ - id: frame.Header().StreamID, - st: t, - buf: buf, - fc: &inFlow{limit: uint32(t.initialWindowSize)}, - } + streamID := frame.Header().StreamID var state decodeState for _, hf := range frame.Fields { if err := state.processHeaderField(hf); err != nil { if se, ok := err.(StreamError); ok { - t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]}) + t.controlBuf.put(&resetStream{streamID, statusCodeConvTab[se.Code]}) } return } } + buf := newRecvBuffer() + s := &Stream{ + id: streamID, + st: t, + buf: buf, + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + recvCompress: state.encoding, + method: state.method, + } + if frame.StreamEnded() { // s is just created by the caller. No lock needed. s.state = streamReadDone } - s.recvCompress = state.encoding if state.timeoutSet { s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout) } else { @@ -280,17 +283,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( if state.statsTrace != nil { s.ctx = stats.SetIncomingTrace(s.ctx, state.statsTrace) } - s.trReader = &transportReader{ - reader: &recvBufferReader{ - ctx: s.ctx, - recv: s.buf, - }, - windowHandler: func(n int) { - t.updateWindow(s, uint32(n)) - }, - } - s.recvCompress = state.encoding - s.method = state.method if t.inTapHandle != nil { var err error info := &tap.Info{ @@ -310,18 +302,18 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } if uint32(len(t.activeStreams)) >= t.maxStreams { t.mu.Unlock() - t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) + t.controlBuf.put(&resetStream{streamID, http2.ErrCodeRefusedStream}) return } - if s.id%2 != 1 || s.id <= t.maxStreamID { + if streamID%2 != 1 || streamID <= t.maxStreamID { t.mu.Unlock() // illegal gRPC stream id. - errorf("transport: http2Server.HandleStreams received an illegal stream id: %v", s.id) + errorf("transport: http2Server.HandleStreams received an illegal stream id: %v", streamID) return true } - t.maxStreamID = s.id + t.maxStreamID = streamID s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) - t.activeStreams[s.id] = s + t.activeStreams[streamID] = s if len(t.activeStreams) == 1 { t.idle = time.Time{} } @@ -341,6 +333,15 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } t.stats.HandleRPC(s.ctx, inHeader) } + s.trReader = &transportReader{ + reader: &recvBufferReader{ + ctx: s.ctx, + recv: s.buf, + }, + windowHandler: func(n int) { + t.updateWindow(s, uint32(n)) + }, + } handle(s) return } @@ -774,8 +775,13 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { headersSent = true } - if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { - return err + // Always write a status regardless of context cancellation unless the stream + // is terminated (e.g. by a RST_STREAM, GOAWAY, or transport error). The + // server's application code is already done so it is fine to ignore s.ctx. + select { + case <-t.shutdownChan: + return ErrConnClosing + case <-t.writableChan: } t.hBuf.Reset() if !headersSent {