diff --git a/transport/flowcontrol.go b/transport/flowcontrol.go index 378f5c45..bbf98b6f 100644 --- a/transport/flowcontrol.go +++ b/transport/flowcontrol.go @@ -58,14 +58,20 @@ type writeQuota struct { ch chan struct{} // done is triggered in error case. done <-chan struct{} + // replenish is called by loopyWriter to give quota back to. + // It is implemented as a field so that it can be updated + // by tests. + replenish func(n int) } func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota { - return &writeQuota{ + w := &writeQuota{ quota: sz, ch: make(chan struct{}, 1), done: done, } + w.replenish = w.realReplenish + return w } func (w *writeQuota) get(sz int32) error { @@ -83,7 +89,7 @@ func (w *writeQuota) get(sz int32) error { } } -func (w *writeQuota) replenish(n int) { +func (w *writeQuota) realReplenish(n int) { sz := int32(n) a := atomic.AddInt32(&w.quota, sz) b := a - sz diff --git a/transport/transport_test.go b/transport/transport_test.go index 30b05e88..df695cbc 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -51,6 +51,7 @@ type server struct { mu sync.Mutex conns map[ServerTransport]bool h *testStreamHandler + ready chan struct{} } var ( @@ -62,8 +63,9 @@ var ( ) type testStreamHandler struct { - t *http2Server - notify chan struct{} + t *http2Server + notify chan struct{} + getNotified chan struct{} } type hType int @@ -76,7 +78,6 @@ const ( encodingRequiredStatus invalidHeaderField delayRead - delayWrite pingpong ) @@ -182,6 +183,10 @@ func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stre }) } +// handleStreamDelayRead delays reads so that the other side has to halt on +// stream-level flow control. +// This handler assumes dynamic flow control is turned off and assumes window +// sizes to be set to defaultWindowSize. func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { req := expectedRequest resp := expectedResponse @@ -189,11 +194,52 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { req = expectedRequestLarge resp = expectedResponseLarge } + var ( + mu sync.Mutex + total int + ) + s.wq.replenish = func(n int) { + mu.Lock() + total += n + mu.Unlock() + s.wq.realReplenish(n) + } + getTotal := func() int { + mu.Lock() + defer mu.Unlock() + return total + } + done := make(chan struct{}) + defer close(done) + go func() { + for { + select { + // Prevent goroutine from leaking. + case <-done: + return + default: + } + if getTotal() == defaultWindowSize { + // Signal the client to start reading and + // thereby send window update. + close(h.notify) + return + } + runtime.Gosched() + } + }() p := make([]byte, len(req)) - // Wait before reading. Give time to client to start sending - // before server starts reading. - time.Sleep(2 * time.Second) + // Let the other side run out of stream-level window before + // starting to read and thereby sending a window update. + timer := time.NewTimer(time.Second * 10) + select { + case <-h.getNotified: + timer.Stop() + case <-timer.C: + t.Errorf("Server timed-out.") + return + } _, err := s.Read(p) if err != nil { t.Errorf("s.Read(_) = _, %v, want _, ", err) @@ -204,41 +250,19 @@ func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { t.Errorf("handleStream got %v, want %v", p, req) return } - // send a response back to the client. + // This write will cause server to run out of stream level, + // flow control and the other side won't send a window update + // until that happens. if err := h.t.Write(s, nil, resp, &Options{}); err != nil { t.Errorf("server Write got %v, want ", err) return } - // send the trailer to end the stream. - if err := h.t.WriteStatus(s, status.New(codes.OK, "")); err != nil { - t.Errorf("server WriteStatus got %v, want ", err) - return - } -} - -func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) { - req := expectedRequest - resp := expectedResponse - if s.Method() == "foo.Large" { - req = expectedRequestLarge - resp = expectedResponseLarge - } - p := make([]byte, len(req)) - _, err := s.Read(p) + // Read one more time to ensure that everything remains fine and + // that the goroutine, that we launched earlier to signal client + // to read, gets enough time to process. + _, err = s.Read(p) if err != nil { - t.Errorf("s.Read(_) = _, %v, want _, ", err) - return - } - if !bytes.Equal(p, req) { - t.Errorf("handleStream got %v, want %v", p, req) - return - } - - // Wait before sending. Give time to client to start reading - // before server starts sending. - time.Sleep(2 * time.Second) - if err := h.t.Write(s, nil, resp, &Options{}); err != nil { - t.Errorf("server Write got %v, want ", err) + t.Errorf("s.Read(_) = _, %v, want _, nil", err) return } // send the trailer to end the stream. @@ -317,17 +341,16 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT return ctx }) case delayRead: + h.notify = make(chan struct{}) + h.getNotified = make(chan struct{}) + s.mu.Lock() + close(s.ready) + s.mu.Unlock() go transport.HandleStreams(func(s *Stream) { go h.handleStreamDelayRead(t, s) }, func(ctx context.Context, method string) context.Context { return ctx }) - case delayWrite: - go transport.HandleStreams(func(s *Stream) { - go h.handleStreamDelayWrite(t, s) - }, func(ctx context.Context, method string) context.Context { - return ctx - }) case pingpong: go transport.HandleStreams(func(s *Stream) { go h.handleStreamPingPong(t, s) @@ -366,7 +389,7 @@ func (s *server) stop() { } func setUpServerOnly(t *testing.T, port int, serverConfig *ServerConfig, ht hType) *server { - server := &server{startedErr: make(chan error, 1)} + server := &server{startedErr: make(chan error, 1), ready: make(chan struct{})} go server.start(t, port, serverConfig, ht) server.wait(t, 2*time.Second) return server @@ -957,83 +980,99 @@ func TestLargeMessage(t *testing.T) { } func TestLargeMessageWithDelayRead(t *testing.T) { - server, ct := setUp(t, 0, math.MaxUint32, delayRead) + // Disable dynamic flow control. + sc := &ServerConfig{ + InitialWindowSize: defaultWindowSize, + InitialConnWindowSize: defaultWindowSize, + } + co := ConnectOptions{ + InitialWindowSize: defaultWindowSize, + InitialConnWindowSize: defaultWindowSize, + } + server, ct := setUpWithOptions(t, 0, sc, delayRead, co, func() {}) + defer server.stop() + defer ct.Close() + server.mu.Lock() + ready := server.ready + server.mu.Unlock() callHdr := &CallHdr{ Host: "localhost", Method: "foo.Large", } - var wg sync.WaitGroup - for i := 0; i < 2; i++ { - wg.Add(1) - go func() { - defer wg.Done() - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) - defer cancel() - s, err := ct.NewStream(ctx, callHdr) - if err != nil { - t.Errorf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) - return - } - if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { - t.Errorf("%v.Write(_, _, _) = %v, want ", ct, err) - return - } - p := make([]byte, len(expectedResponseLarge)) - - // Give time to server to begin sending before client starts reading. - time.Sleep(2 * time.Second) - if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { - t.Errorf("s.Read(_) = _, %v, want _, ", err) - return - } - if _, err = s.Read(p); err != io.EOF { - t.Errorf("Failed to complete the stream %v; want ", err) - } - }() + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) + defer cancel() + s, err := ct.NewStream(ctx, callHdr) + if err != nil { + t.Fatalf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) + return } - wg.Wait() - ct.Close() - server.stop() -} - -func TestLargeMessageDelayWrite(t *testing.T) { - server, ct := setUp(t, 0, math.MaxUint32, delayWrite) - callHdr := &CallHdr{ - Host: "localhost", - Method: "foo.Large", + // Wait for server's handerler to be initialized + select { + case <-ready: + case <-ctx.Done(): + t.Fatalf("Client timed out waiting for server handler to be initialized.") } - var wg sync.WaitGroup - for i := 0; i < 2; i++ { - wg.Add(1) - go func() { - defer wg.Done() - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second*10)) - defer cancel() - s, err := ct.NewStream(ctx, callHdr) - if err != nil { - t.Errorf("%v.NewStream(_, _) = _, %v, want _, ", ct, err) - return - } - - // Give time to server to start reading before client starts sending. - time.Sleep(2 * time.Second) - if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { - t.Errorf("%v.Write(_, _, _) = %v, want ", ct, err) - return - } - p := make([]byte, len(expectedResponseLarge)) - if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { - t.Errorf("io.ReadFull(%v) = _, %v, want %v, ", err, p, expectedResponse) - return - } - if _, err = s.Read(p); err != io.EOF { - t.Errorf("Failed to complete the stream %v; want ", err) - } - }() + server.mu.Lock() + serviceHandler := server.h + server.mu.Unlock() + var ( + mu sync.Mutex + total int + ) + s.wq.replenish = func(n int) { + mu.Lock() + total += n + mu.Unlock() + s.wq.realReplenish(n) + } + getTotal := func() int { + mu.Lock() + defer mu.Unlock() + return total + } + done := make(chan struct{}) + defer close(done) + go func() { + for { + select { + // Prevent goroutine from leaking in case of error. + case <-done: + return + default: + } + if getTotal() == defaultWindowSize { + // unblock server to be able to read and + // thereby send stream level window update. + close(serviceHandler.getNotified) + return + } + runtime.Gosched() + } + }() + // This write will cause client to run out of stream level, + // flow control and the other side won't send a window update + // until that happens. + if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{}); err != nil { + t.Fatalf("write(_, _, _) = %v, want ", err) + } + p := make([]byte, len(expectedResponseLarge)) + + // Wait for the other side to run out of stream level flow control before + // reading and thereby sending a window update. + select { + case <-serviceHandler.notify: + case <-ctx.Done(): + t.Fatalf("Client timed out") + } + if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { + t.Fatalf("s.Read(_) = _, %v, want _, ", err) + } + if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true}); err != nil { + t.Fatalf("Write(_, _, _) = %v, want ", err) + } + if _, err = s.Read(p); err != io.EOF { + t.Fatalf("Failed to complete the stream %v; want ", err) } - wg.Wait() - ct.Close() - server.stop() } func TestGracefulClose(t *testing.T) {