Merge pull request #552 from bradfitz/concurrency

Fix flakiness of TestCancelNoIO with http.Handler-based server transport
Qi Zhao committed 2016-02-12 17:16:34 -08:00
5 changed files with 328 additions and 110 deletions

View File

@@ -34,6 +34,7 @@
package grpc_test
import (
"bytes"
"flag"
"fmt"
"io"
@@ -374,9 +375,71 @@ func listTestEnv() (envs []env) {
return envs
}
// serverSetUp is the old way to start a test server. New callers should use newTest.
// TODO(bradfitz): update all tests to newTest and delete this.
func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cp grpc.Compressor, dc grpc.Decompressor, e env) (s *grpc.Server, addr string) {
t.Logf("Running test in %s environment...", e.name)
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.RPCCompressor(cp), grpc.RPCDecompressor(dc)}
te := &test{
t: t,
e: e,
healthServer: hs,
maxStream: maxStream,
cp: cp,
dc: dc,
}
if servON {
te.testServer = &testServer{security: e.security}
}
te.startServer()
return te.srv, te.srvAddr
}
// test is an end-to-end test. It should be created with the newTest
// func, modified as needed, and then started with its startServer method.
// It should be cleaned up with the tearDown method.
type test struct {
t *testing.T
e env
// Configurable knobs, after newTest returns:
testServer testpb.TestServiceServer // nil means none
healthServer *health.HealthServer // nil means disabled
maxStream uint32
cp grpc.Compressor // nil means no server compression
dc grpc.Decompressor // nil means no server decompression
userAgent string
// srv and srvAddr are set once startServer is called.
srv *grpc.Server
srvAddr string
cc *grpc.ClientConn // nil until requested via clientConn
}
func (te *test) tearDown() {
te.srv.Stop()
if te.cc != nil {
te.cc.Close()
}
}
// newTest returns a new test using the provided testing.T and
// environment. It is returned with default values. Tests should
// modify it before calling its startServer and clientConn methods.
func newTest(t *testing.T, e env) *test {
return &test{
t: t,
e: e,
testServer: &testServer{security: e.security},
maxStream: math.MaxUint32,
}
}
// startServer starts a gRPC server listening. Callers should defer a
// call to te.tearDown to clean up.
func (te *test) startServer() {
e := te.e
te.t.Logf("Running test in %s environment...", e.name)
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream), grpc.RPCCompressor(te.cp), grpc.RPCDecompressor(te.dc)}
la := ":0"
switch e.network {
case "unix":
@@ -385,37 +448,46 @@ func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream u
}
lis, err := net.Listen(e.network, la)
if err != nil {
t.Fatalf("Failed to listen: %v", err)
te.t.Fatalf("Failed to listen: %v", err)
}
if e.security == "tls" {
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
t.Fatalf("Failed to generate credentials %v", err)
te.t.Fatalf("Failed to generate credentials %v", err)
}
sopts = append(sopts, grpc.Creds(creds))
}
s = grpc.NewServer(sopts...)
s := grpc.NewServer(sopts...)
te.srv = s
if e.httpHandler {
s.TestingUseHandlerImpl()
}
if hs != nil {
healthpb.RegisterHealthServer(s, hs)
if te.healthServer != nil {
healthpb.RegisterHealthServer(s, te.healthServer)
}
if servON {
testpb.RegisterTestServiceServer(s, &testServer{security: e.security})
if te.testServer != nil {
testpb.RegisterTestServiceServer(s, te.testServer)
}
go s.Serve(lis)
addr = la
addr := la
switch e.network {
case "unix":
default:
_, port, err := net.SplitHostPort(lis.Addr().String())
if err != nil {
t.Fatalf("Failed to parse listener address: %v", err)
te.t.Fatalf("Failed to parse listener address: %v", err)
}
addr = "localhost:" + port
}
return
go s.Serve(lis)
te.srvAddr = addr
}
func (te *test) clientConn() *grpc.ClientConn {
if te.cc == nil {
te.cc = clientSetUp(te.t, te.srvAddr, te.cp, te.dc, te.userAgent, te.e)
}
return te.cc
}
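The harness above is meant to be used as its comments describe: create the test with newTest, adjust its knobs, call startServer, and defer tearDown. As a hedged illustration only (testFooBar and the choice of EmptyCall are placeholders, not part of this diff; the pattern mirrors testServerStreaming_Concurrent further down), a caller might look like:

// Hypothetical usage of the new test harness (not part of this commit).
func testFooBar(t *testing.T, e env) {
	te := newTest(t, e)      // defaults: testServer set, maxStream = math.MaxUint32
	te.userAgent = "test-ua" // adjust knobs before startServer if needed
	te.startServer()
	defer te.tearDown()

	tc := testpb.NewTestServiceClient(te.clientConn())
	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
		t.Fatalf("EmptyCall(_, _) = _, %v, want _, <nil>", err)
	}
}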
func clientSetUp(t *testing.T, addr string, cp grpc.Compressor, dc grpc.Decompressor, ua string, e env) (cc *grpc.ClientConn) {
@@ -888,17 +960,28 @@ func testCancelNoIO(t *testing.T, e env) {
cc := clientSetUp(t, addr, nil, nil, "", e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
ctx, cancel := context.WithCancel(context.Background())
// Start one blocked RPC for which we'll never send streaming
// input. This will consume the 1 maximum concurrent streams,
// causing future RPCs to hang.
ctx, cancelFirst := context.WithCancel(context.Background())
_, err := tc.StreamingInputCall(ctx)
if err != nil {
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
}
// Loop until receiving the new max stream setting from the server.
// Loop until the ClientConn receives the initial settings
// frame from the server, notifying it about the maximum
// concurrent streams. We know when it's received it because
// an RPC will fail with codes.DeadlineExceeded instead of
// succeeding.
// TODO(bradfitz): add internal test hook for this (Issue 534)
for {
ctx, _ := context.WithTimeout(context.Background(), time.Second)
ctx, cancelSecond := context.WithTimeout(context.Background(), 250*time.Millisecond)
_, err := tc.StreamingInputCall(ctx)
cancelSecond()
if err == nil {
time.Sleep(time.Second)
time.Sleep(50 * time.Millisecond)
continue
}
if grpc.Code(err) == codes.DeadlineExceeded {
@@ -906,19 +989,23 @@ func testCancelNoIO(t *testing.T, e env) {
}
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded)
}
// If there are any RPCs slipping before the client receives the max streams setting,
// let them be expired.
time.Sleep(2 * time.Second)
// If there are any RPCs in flight before the client receives
// the max streams setting, let them be expired.
// TODO(bradfitz): add internal test hook for this (Issue 534)
time.Sleep(500 * time.Millisecond)
ch := make(chan struct{})
go func() {
defer close(ch)
// This should be blocked until the 1st is canceled.
ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
ctx, cancelThird := context.WithTimeout(context.Background(), 2*time.Second)
if _, err := tc.StreamingInputCall(ctx); err != nil {
t.Errorf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
}
cancelThird()
}()
cancel()
cancelFirst()
<-ch
}
@@ -1169,6 +1256,87 @@ func testFailedServerStreaming(t *testing.T, e env) {
}
}
// concurrentSendServer is a TestServiceServer whose
// StreamingOutputCall makes ten serial Send calls, sending payloads
// "0".."9", inclusive. TestServerStreaming_Concurrent verifies they
// were received in the correct order, and that there were no races.
//
// All other TestServiceServer methods crash if called.
type concurrentSendServer struct {
testpb.TestServiceServer
}
func (s concurrentSendServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
for i := 0; i < 10; i++ {
stream.Send(&testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{
Body: []byte{'0' + uint8(i)},
},
})
}
return nil
}
// Tests doing a bunch of concurrent streaming output calls.
func TestServerStreaming_Concurrent(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testServerStreaming_Concurrent(t, e)
}
}
func testServerStreaming_Concurrent(t *testing.T, e env) {
et := newTest(t, e)
et.testServer = concurrentSendServer{}
et.startServer()
defer et.tearDown()
cc := et.clientConn()
tc := testpb.NewTestServiceClient(cc)
doStreamingCall := func() {
req := &testpb.StreamingOutputCallRequest{}
stream, err := tc.StreamingOutputCall(context.Background(), req)
if err != nil {
t.Errorf("%v.StreamingOutputCall(_) = _, %v, want <nil>", tc, err)
return
}
var ngot int
var buf bytes.Buffer
for {
reply, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
ngot++
if buf.Len() > 0 {
buf.WriteByte(',')
}
buf.Write(reply.GetPayload().GetBody())
}
if want := 10; ngot != want {
t.Errorf("Got %d replies, want %d", ngot, want)
}
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
t.Errorf("Got replies %q; want %q", got, want)
}
}
var wg sync.WaitGroup
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
doStreamingCall()
}()
}
wg.Wait()
}
func TestClientStreaming(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {

View File

@@ -75,10 +75,10 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
}
st := &serverHandlerTransport{
rw: w,
req: r,
closedCh: make(chan struct{}),
wroteStatus: make(chan struct{}),
rw: w,
req: r,
closedCh: make(chan struct{}),
writes: make(chan func()),
}
if v := r.Header.Get("grpc-timeout"); v != "" {
@@ -132,7 +132,10 @@ type serverHandlerTransport struct {
closeOnce sync.Once
closedCh chan struct{} // closed on Close
wroteStatus chan struct{} // closed on WriteStatus
// writes is a channel of code to run serialized in the
// ServeHTTP (HandleStreams) goroutine. The channel is closed
// when WriteStatus is called.
writes chan func()
}
func (ht *serverHandlerTransport) Close() error {
@@ -166,31 +169,43 @@ func (a strAddr) Network() string {
func (a strAddr) String() string { return string(a) }
func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
ht.writeCommonHeaders(s)
// And flush, in case no header or body has been sent yet.
// This forces a separation of headers and trailers if this is the
// first call (for example, in end2end tests's TestNoService).
ht.rw.(http.Flusher).Flush()
h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
if statusDesc != "" {
h.Set("Grpc-Message", statusDesc)
// do runs fn in the ServeHTTP goroutine.
func (ht *serverHandlerTransport) do(fn func()) error {
select {
case ht.writes <- fn:
return nil
case <-ht.closedCh:
return ErrConnClosing
}
if md := s.Trailer(); len(md) > 0 {
for k, vv := range md {
for _, v := range vv {
// http2 ResponseWriter mechanism to
// send undeclared Trailers after the
// headers have possibly been written.
h.Add(http2.TrailerPrefix+k, v)
}
func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
err := ht.do(func() {
ht.writeCommonHeaders(s)
// And flush, in case no header or body has been sent yet.
// This forces a separation of headers and trailers if this is the
// first call (for example, in end2end tests's TestNoService).
ht.rw.(http.Flusher).Flush()
h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
if statusDesc != "" {
h.Set("Grpc-Message", statusDesc)
}
if md := s.Trailer(); len(md) > 0 {
for k, vv := range md {
for _, v := range vv {
// http2 ResponseWriter mechanism to
// send undeclared Trailers after the
// headers have possibly been written.
h.Add(http2.TrailerPrefix+k, v)
}
}
}
}
close(ht.wroteStatus)
return nil
})
close(ht.writes)
return err
}
// writeCommonHeaders sets common headers on the first write
@@ -219,28 +234,30 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
}
func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
ht.writeCommonHeaders(s)
ht.rw.Write(data)
if !opts.Delay {
ht.rw.(http.Flusher).Flush()
}
return nil
return ht.do(func() {
ht.writeCommonHeaders(s)
ht.rw.Write(data)
if !opts.Delay {
ht.rw.(http.Flusher).Flush()
}
})
}
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
ht.writeCommonHeaders(s)
h := ht.rw.Header()
for k, vv := range md {
for _, v := range vv {
h.Add(k, v)
return ht.do(func() {
ht.writeCommonHeaders(s)
h := ht.rw.Header()
for k, vv := range md {
for _, v := range vv {
h.Add(k, v)
}
}
}
ht.rw.WriteHeader(200)
ht.rw.(http.Flusher).Flush()
return nil
ht.rw.WriteHeader(200)
ht.rw.(http.Flusher).Flush()
})
}
func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.
var ctx context.Context
@@ -251,12 +268,18 @@ func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
ctx, cancel = context.WithCancel(context.Background())
}
// requestOver is closed when either the request's context is done
// or the status has been written via WriteStatus.
requestOver := make(chan struct{})
// clientGone receives a single value if peer is gone, either
// because the underlying connection is dead or because the
// peer sends an http2 RST_STREAM.
clientGone := ht.rw.(http.CloseNotifier).CloseNotify()
go func() {
select {
case <-requestOver:
return
case <-ht.closedCh:
case <-clientGone:
}
@@ -285,10 +308,6 @@ func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
s.ctx = newContextWithStream(ctx, s)
s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
// requestOver is closed when either the request's context is done
// or the status has been written via WriteStatus.
requestOver := make(chan struct{})
// readerDone is closed when the Body.Read-ing goroutine exits.
readerDone := make(chan struct{})
go func() {
@@ -296,34 +315,40 @@ func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
for {
buf := make([]byte, 1024) // TODO: minimize garbage, optimize recvBuffer code/ownership
n, err := req.Body.Read(buf)
select {
case <-requestOver:
return
default:
}
if n > 0 {
s.buf.put(&recvMsg{data: buf[:n]})
}
if err != nil {
s.buf.put(&recvMsg{err: err})
break
return
}
}
}()
// runStream is provided by the *grpc.Server.serveStreams.
// It starts a goroutine handling s and exits immediately.
runStream(s)
// startStream is provided by the *grpc.Server's serveStreams.
// It starts a goroutine serving s and exits immediately.
// The goroutine that is started is the one that then calls
// into ht, calling WriteHeader, Write, WriteStatus, Close, etc.
startStream(s)
// Wait for the stream to be done. It is considered done when
// either its context is done, or we've written its status.
select {
case <-ctx.Done():
case <-ht.wroteStatus:
}
ht.runStream()
close(requestOver)
// Wait for reading goroutine to finish.
req.Body.Close()
<-readerDone
}
func (ht *serverHandlerTransport) runStream() {
for {
select {
case fn, ok := <-ht.writes:
if !ok {
return
}
fn()
case <-ht.closedCh:
return
}
}
}
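Taken together, the writes channel, do, WriteStatus, and runStream implement a familiar Go serialization pattern: callers hand closures to a single goroutine (here the ServeHTTP/HandleStreams goroutine), and closing the channel signals that the final write has been issued. A minimal, self-contained sketch of that pattern under illustrative names (serializer, submit, and loop are not from this commit):

package main

import (
	"errors"
	"fmt"
)

// serializer runs submitted functions one at a time on the goroutine
// that calls loop, mirroring how serverHandlerTransport funnels writes
// onto the ServeHTTP goroutine.
type serializer struct {
	work chan func()
	quit chan struct{} // closed to abandon pending work, like closedCh
}

func (s *serializer) submit(fn func()) error {
	select {
	case s.work <- fn:
		return nil
	case <-s.quit:
		return errors.New("serializer closed")
	}
}

// loop executes queued functions until work is closed (the analogue of
// WriteStatus closing ht.writes) or quit is closed.
func (s *serializer) loop() {
	for {
		select {
		case fn, ok := <-s.work:
			if !ok {
				return
			}
			fn()
		case <-s.quit:
			return
		}
	}
}

func main() {
	s := &serializer{work: make(chan func()), quit: make(chan struct{})}
	go func() {
		s.submit(func() { fmt.Println("runs on the loop goroutine") })
		close(s.work) // final write issued; let loop drain and exit
	}()
	s.loop()
}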

View File

@@ -293,13 +293,14 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
func TestHandlerTransport_HandleStreams(t *testing.T) {
st := newHandleStreamTest(t)
st.ht.HandleStreams(func(s *Stream) {
handleStream := func(s *Stream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
}
st.bodyw.Close() // no body
st.ht.WriteStatus(s, codes.OK, "")
})
}
st.ht.HandleStreams(func(s *Stream) { go handleStream(s) })
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
@@ -323,9 +324,10 @@ func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
st := newHandleStreamTest(t)
st.ht.HandleStreams(func(s *Stream) {
handleStream := func(s *Stream) {
st.ht.WriteStatus(s, statusCode, msg)
})
}
st.ht.HandleStreams(func(s *Stream) { go handleStream(s) })
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
@@ -358,7 +360,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ht.HandleStreams(func(s *Stream) {
runStream := func(s *Stream) {
defer bodyw.Close()
select {
case <-s.ctx.Done():
@@ -372,7 +374,8 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
return
}
ht.WriteStatus(s, codes.DeadlineExceeded, "too slow")
})
}
ht.HandleStreams(func(s *Stream) { go runStream(s) })
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
@@ -381,6 +384,6 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
"Grpc-Message": {"too slow"},
}
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer Map: %#v; want %#v", rw.HeaderMap, wantHeader)
t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)
}
}

View File

@@ -62,8 +62,8 @@ type http2Server struct {
maxStreamID uint32 // max stream ID ever seen
authInfo credentials.AuthInfo // auth info about the connection
// writableChan synchronizes write access to the transport.
// A writer acquires the write lock by sending a value on writableChan
// and releases it by receiving from writableChan.
// A writer acquires the write lock by receiving a value on writableChan
// and releases it by sending on writableChan.
writableChan chan int
// shutdownChan is closed when Close is called.
// Blocking operations should select on shutdownChan to avoid
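The corrected comment describes a channel used as a write lock: the channel holds a single token, a writer takes the lock by receiving the token and gives it back by sending. A rough, self-contained sketch of that idiom (illustrative only, not the actual http2Server code):

package main

import "fmt"

func main() {
	// A one-slot channel acting as a write lock; the token is seeded
	// up front so the lock starts out available.
	writableChan := make(chan int, 1)
	writableChan <- 0

	<-writableChan // acquire: receive the token (blocks while another writer holds it)
	fmt.Println("write lock held; serialized frame writes would happen here")
	writableChan <- 0 // release: return the token for the next writer
}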

View File

@@ -352,30 +352,40 @@ func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, e
// Options provides additional hints and information for message
// transmission.
type Options struct {
// Indicate whether it is the last piece for this stream.
// Last indicates whether this write is the last piece for
// this stream.
Last bool
// The hint to transport impl whether the data could be buffered for
// batching write. Transport impl can feel free to ignore it.
// Delay is a hint to the transport implementation for whether
// the data could be buffered for a batching write. The
// Transport implementation may ignore the hint.
Delay bool
}
// CallHdr carries the information of a particular RPC.
type CallHdr struct {
// Host specifies peer host.
// Host specifies the peer's host.
Host string
// Method specifies the operation to perform.
Method string
// RecvCompress specifies the compression algorithm applied on inbound messages.
// RecvCompress specifies the compression algorithm applied on
// inbound messages.
RecvCompress string
// SendCompress specifies the compression algorithm applied on outbound message.
// SendCompress specifies the compression algorithm applied on
// outbound message.
SendCompress string
// Flush indicates if new stream command should be sent to the peer without
// waiting for the first data. This is a hint though. The transport may modify
// the flush decision for performance purpose.
// Flush indicates whether a new stream command should be sent
// to the peer without waiting for the first data. This is
// only a hint. The transport may modify the flush decision
// for performance purposes.
Flush bool
}
// ClientTransport is the common interface for all gRPC client side transport
// ClientTransport is the common interface for all gRPC client-side transport
// implementations.
type ClientTransport interface {
// Close tears down this transport. Once it returns, the transport
@@ -404,21 +414,33 @@ type ClientTransport interface {
Error() <-chan struct{}
}
// ServerTransport is the common interface for all gRPC server side transport
// ServerTransport is the common interface for all gRPC server-side transport
// implementations.
//
// Methods may be called concurrently from multiple goroutines, but
// Write methods for a given Stream will be called serially.
type ServerTransport interface {
// WriteStatus sends the status of a stream to the client.
WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error
// Write sends the data for the given stream.
Write(s *Stream, data []byte, opts *Options) error
// WriteHeader sends the header metedata for the given stream.
WriteHeader(s *Stream, md metadata.MD) error
// HandleStreams receives incoming streams using the given handler.
HandleStreams(func(*Stream))
// WriteHeader sends the header metadata for the given stream.
// WriteHeader may not be called on all streams.
WriteHeader(s *Stream, md metadata.MD) error
// Write sends the data for the given stream.
// Write may not be called on all streams.
Write(s *Stream, data []byte, opts *Options) error
// WriteStatus sends the status of a stream to the client.
// WriteStatus is the final call made on a stream and always
// occurs.
WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error
// Close tears down the transport. Once it is called, the transport
// should not be accessed any more. All the pending streams and their
// handlers will be terminated asynchronously.
Close() error
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr
}
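The reordered methods and new comments spell out the per-stream contract: writes for a given stream arrive serially as an optional WriteHeader, zero or more Write calls, and exactly one final WriteStatus. A hedged sketch of a caller honoring that contract (the respond helper and its arguments are placeholders, not part of this interface; the transport import path reflects where this package lived at the time):

package example

import (
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/transport"
)

// respond drives one stream through the documented write sequence.
// st and s are assumed to come from a ServerTransport's HandleStreams.
func respond(st transport.ServerTransport, s *transport.Stream, payload []byte) error {
	// Optional: header metadata may or may not be sent for a stream.
	if err := st.WriteHeader(s, metadata.Pairs("k", "v")); err != nil {
		return err
	}
	// Zero or more data writes; Delay is only a buffering hint.
	if err := st.Write(s, payload, &transport.Options{Last: true, Delay: false}); err != nil {
		return err
	}
	// Exactly one WriteStatus always ends the stream.
	return st.WriteStatus(s, codes.OK, "")
}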