Merge pull request #552 from bradfitz/concurrency
Fix flakiness of TestCancelNoIO with http.Handler-based server transport
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -34,6 +34,7 @@
 package grpc_test
 
 import (
+	"bytes"
	"flag"
	"fmt"
	"io"
@@ -374,9 +375,71 @@ func listTestEnv() (envs []env) {
 	return envs
 }
 
+// serverSetUp is the old way to start a test server. New callers should use newTest.
+// TODO(bradfitz): update all tests to newTest and delete this.
 func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cp grpc.Compressor, dc grpc.Decompressor, e env) (s *grpc.Server, addr string) {
-	t.Logf("Running test in %s environment...", e.name)
-	sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.RPCCompressor(cp), grpc.RPCDecompressor(dc)}
+	te := &test{
+		t:            t,
+		e:            e,
+		healthServer: hs,
+		maxStream:    maxStream,
+		cp:           cp,
+		dc:           dc,
+	}
+	if servON {
+		te.testServer = &testServer{security: e.security}
+	}
+	te.startServer()
+	return te.srv, te.srvAddr
+}
+
+// test is an end-to-end test. It should be created with the newTest
+// func, modified as needed, and then started with its startServer method.
+// It should be cleaned up with the tearDown method.
+type test struct {
+	t *testing.T
+	e env
+
+	// Configurable knobs, after newTest returns:
+	testServer   testpb.TestServiceServer // nil means none
+	healthServer *health.HealthServer     // nil means disabled
+	maxStream    uint32
+	cp           grpc.Compressor   // nil means no server compression
+	dc           grpc.Decompressor // nil means no server decompression
+	userAgent    string
+
+	// srv and srvAddr are set once startServer is called.
+	srv     *grpc.Server
+	srvAddr string
+
+	cc *grpc.ClientConn // nil until requested via clientConn
+}
+
+func (te *test) tearDown() {
+	te.srv.Stop()
+	if te.cc != nil {
+		te.cc.Close()
+	}
+}
+
+// newTest returns a new test using the provided testing.T and
+// environment. It is returned with default values. Tests should
+// modify it before calling its startServer and clientConn methods.
+func newTest(t *testing.T, e env) *test {
+	return &test{
+		t:          t,
+		e:          e,
+		testServer: &testServer{security: e.security},
+		maxStream:  math.MaxUint32,
+	}
+}
+
+// startServer starts a gRPC server listening. Callers should defer a
+// call to te.tearDown to clean up.
+func (te *test) startServer() {
+	e := te.e
+	te.t.Logf("Running test in %s environment...", e.name)
+	sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream), grpc.RPCCompressor(te.cp), grpc.RPCDecompressor(te.dc)}
 	la := ":0"
 	switch e.network {
 	case "unix":
@@ -385,37 +448,46 @@ func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream u
 	}
 	lis, err := net.Listen(e.network, la)
 	if err != nil {
-		t.Fatalf("Failed to listen: %v", err)
+		te.t.Fatalf("Failed to listen: %v", err)
 	}
 	if e.security == "tls" {
 		creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
 		if err != nil {
-			t.Fatalf("Failed to generate credentials %v", err)
+			te.t.Fatalf("Failed to generate credentials %v", err)
 		}
 		sopts = append(sopts, grpc.Creds(creds))
 	}
-	s = grpc.NewServer(sopts...)
+	s := grpc.NewServer(sopts...)
+	te.srv = s
 	if e.httpHandler {
 		s.TestingUseHandlerImpl()
 	}
-	if hs != nil {
-		healthpb.RegisterHealthServer(s, hs)
+	if te.healthServer != nil {
+		healthpb.RegisterHealthServer(s, te.healthServer)
 	}
-	if servON {
-		testpb.RegisterTestServiceServer(s, &testServer{security: e.security})
+	if te.testServer != nil {
+		testpb.RegisterTestServiceServer(s, te.testServer)
 	}
-	go s.Serve(lis)
-	addr = la
+	addr := la
 	switch e.network {
 	case "unix":
 	default:
 		_, port, err := net.SplitHostPort(lis.Addr().String())
 		if err != nil {
-			t.Fatalf("Failed to parse listener address: %v", err)
+			te.t.Fatalf("Failed to parse listener address: %v", err)
 		}
 		addr = "localhost:" + port
 	}
-	return
+
+	go s.Serve(lis)
+	te.srvAddr = addr
+}
+
+func (te *test) clientConn() *grpc.ClientConn {
+	if te.cc == nil {
+		te.cc = clientSetUp(te.t, te.srvAddr, te.cp, te.dc, te.userAgent, te.e)
+	}
+	return te.cc
 }
 
 func clientSetUp(t *testing.T, addr string, cp grpc.Compressor, dc grpc.Decompressor, ua string, e env) (cc *grpc.ClientConn) {
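
Aside: the new harness follows a create, configure, start, tear down sequence. A hedged usage sketch of how a caller would use it (not part of the commit; testFooBar and the EmptyCall RPC are illustrative, and this assumes it sits in the same grpc_test package as the code above):

    // testFooBar is a hypothetical test body using the new harness.
    func testFooBar(t *testing.T, e env) {
    	te := newTest(t, e)
    	te.userAgent = "test-ua" // adjust knobs before startServer
    	te.startServer()
    	defer te.tearDown()

    	cc := te.clientConn() // dialed lazily; closed by tearDown
    	tc := testpb.NewTestServiceClient(cc)
    	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
    		t.Fatalf("EmptyCall failed: %v", err)
    	}
    }
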
@@ -888,17 +960,28 @@ func testCancelNoIO(t *testing.T, e env) {
 	cc := clientSetUp(t, addr, nil, nil, "", e)
 	tc := testpb.NewTestServiceClient(cc)
 	defer tearDown(s, cc)
-	ctx, cancel := context.WithCancel(context.Background())
+
+	// Start one blocked RPC for which we'll never send streaming
+	// input. This will consume the 1 maximum concurrent streams,
+	// causing future RPCs to hang.
+	ctx, cancelFirst := context.WithCancel(context.Background())
 	_, err := tc.StreamingInputCall(ctx)
 	if err != nil {
 		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
 	}
-	// Loop until receiving the new max stream setting from the server.
+
+	// Loop until the ClientConn receives the initial settings
+	// frame from the server, notifying it about the maximum
+	// concurrent streams. We know when it's received it because
+	// an RPC will fail with codes.DeadlineExceeded instead of
+	// succeeding.
+	// TODO(bradfitz): add internal test hook for this (Issue 534)
 	for {
-		ctx, _ := context.WithTimeout(context.Background(), time.Second)
+		ctx, cancelSecond := context.WithTimeout(context.Background(), 250*time.Millisecond)
 		_, err := tc.StreamingInputCall(ctx)
+		cancelSecond()
 		if err == nil {
-			time.Sleep(time.Second)
+			time.Sleep(50 * time.Millisecond)
 			continue
 		}
 		if grpc.Code(err) == codes.DeadlineExceeded {
@@ -906,19 +989,23 @@ func testCancelNoIO(t *testing.T, e env) {
 		}
 		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded)
 	}
-	// If there are any RPCs slipping before the client receives the max streams setting,
-	// let them be expired.
-	time.Sleep(2 * time.Second)
+	// If there are any RPCs in flight before the client receives
+	// the max streams setting, let them be expired.
+	// TODO(bradfitz): add internal test hook for this (Issue 534)
+	time.Sleep(500 * time.Millisecond)
+
 	ch := make(chan struct{})
 	go func() {
 		defer close(ch)
+
 		// This should be blocked until the 1st is canceled.
-		ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+		ctx, cancelThird := context.WithTimeout(context.Background(), 2*time.Second)
 		if _, err := tc.StreamingInputCall(ctx); err != nil {
 			t.Errorf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
 		}
+		cancelThird()
 	}()
-	cancel()
+	cancelFirst()
 	<-ch
 }
 
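
Aside: the cancelFirst/cancelSecond/cancelThird names exist so that every context's CancelFunc is actually invoked; a discarded CancelFunc from WithTimeout leaves a timer goroutine running until the deadline fires, which leakCheck would flag. A minimal standalone sketch of that discipline (illustrative; assumes the golang.org/x/net/context package this repo used at the time):

    package main

    import (
    	"fmt"
    	"time"

    	"golang.org/x/net/context"
    )

    func main() {
    	// Keep and call the CancelFunc instead of writing ctx, _ := ...
    	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
    	defer cancel() // releases the timer as soon as we return

    	select {
    	case <-time.After(time.Second): // simulated slow RPC
    		fmt.Println("work finished")
    	case <-ctx.Done():
    		fmt.Println("gave up:", ctx.Err())
    	}
    }
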
@@ -1169,6 +1256,87 @@ func testFailedServerStreaming(t *testing.T, e env) {
 	}
 }
 
+// concurrentSendServer is a TestServiceServer whose
+// StreamingOutputCall makes ten serial Send calls, sending payloads
+// "0".."9", inclusive. TestServerStreaming_Concurrent verifies they
+// were received in the correct order, and that there were no races.
+//
+// All other TestServiceServer methods crash if called.
+type concurrentSendServer struct {
+	testpb.TestServiceServer
+}
+
+func (s concurrentSendServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
+	for i := 0; i < 10; i++ {
+		stream.Send(&testpb.StreamingOutputCallResponse{
+			Payload: &testpb.Payload{
+				Body: []byte{'0' + uint8(i)},
+			},
+		})
+	}
+	return nil
+}
+
+// Tests doing a bunch of concurrent streaming output calls.
+func TestServerStreaming_Concurrent(t *testing.T) {
+	defer leakCheck(t)()
+	for _, e := range listTestEnv() {
+		testServerStreaming_Concurrent(t, e)
+	}
+}
+
+func testServerStreaming_Concurrent(t *testing.T, e env) {
+	et := newTest(t, e)
+	et.testServer = concurrentSendServer{}
+	et.startServer()
+	defer et.tearDown()
+
+	cc := et.clientConn()
+	tc := testpb.NewTestServiceClient(cc)
+
+	doStreamingCall := func() {
+		req := &testpb.StreamingOutputCallRequest{}
+		stream, err := tc.StreamingOutputCall(context.Background(), req)
+		if err != nil {
+			t.Errorf("%v.StreamingOutputCall(_) = _, %v, want <nil>", tc, err)
+			return
+		}
+		var ngot int
+		var buf bytes.Buffer
+		for {
+			reply, err := stream.Recv()
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				t.Fatal(err)
+			}
+			ngot++
+			if buf.Len() > 0 {
+				buf.WriteByte(',')
+			}
+			buf.Write(reply.GetPayload().GetBody())
+		}
+		if want := 10; ngot != want {
+			t.Errorf("Got %d replies, want %d", ngot, want)
+		}
+		if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
+			t.Errorf("Got replies %q; want %q", got, want)
+		}
+	}
+
+	var wg sync.WaitGroup
+	for i := 0; i < 20; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			doStreamingCall()
+		}()
+	}
+	wg.Wait()
+
+}
+
 func TestClientStreaming(t *testing.T) {
 	defer leakCheck(t)()
 	for _, e := range listTestEnv() {
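
Aside: concurrentSendServer embeds the testpb.TestServiceServer interface so only StreamingOutputCall needs a real implementation; calling any other method goes through the nil embedded interface and panics, as the comment above warns. A minimal standalone sketch of that embedding trick (Greeter and partialGreeter are illustrative names):

    package main

    import "fmt"

    type Greeter interface {
    	Hello() string
    	Goodbye() string
    }

    // partialGreeter implements only Hello; Goodbye is promoted from
    // the nil embedded Greeter and panics if called.
    type partialGreeter struct {
    	Greeter
    }

    func (partialGreeter) Hello() string { return "hello" }

    func main() {
    	var g Greeter = partialGreeter{}
    	fmt.Println(g.Hello())
    	// g.Goodbye() would panic: nil pointer dereference.
    }
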
@ -75,10 +75,10 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
|
|||||||
}
|
}
|
||||||
|
|
||||||
st := &serverHandlerTransport{
|
st := &serverHandlerTransport{
|
||||||
rw: w,
|
rw: w,
|
||||||
req: r,
|
req: r,
|
||||||
closedCh: make(chan struct{}),
|
closedCh: make(chan struct{}),
|
||||||
wroteStatus: make(chan struct{}),
|
writes: make(chan func()),
|
||||||
}
|
}
|
||||||
|
|
||||||
if v := r.Header.Get("grpc-timeout"); v != "" {
|
if v := r.Header.Get("grpc-timeout"); v != "" {
|
||||||
@ -132,7 +132,10 @@ type serverHandlerTransport struct {
|
|||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
closedCh chan struct{} // closed on Close
|
closedCh chan struct{} // closed on Close
|
||||||
|
|
||||||
wroteStatus chan struct{} // closed on WriteStatus
|
// writes is a channel of code to run serialized in the
|
||||||
|
// ServeHTTP (HandleStreams) goroutine. The channel is closed
|
||||||
|
// when WriteStatus is called.
|
||||||
|
writes chan func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) Close() error {
|
func (ht *serverHandlerTransport) Close() error {
|
||||||
@ -166,31 +169,43 @@ func (a strAddr) Network() string {
|
|||||||
|
|
||||||
func (a strAddr) String() string { return string(a) }
|
func (a strAddr) String() string { return string(a) }
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
|
// do runs fn in the ServeHTTP goroutine.
|
||||||
ht.writeCommonHeaders(s)
|
func (ht *serverHandlerTransport) do(fn func()) error {
|
||||||
|
select {
|
||||||
// And flush, in case no header or body has been sent yet.
|
case ht.writes <- fn:
|
||||||
// This forces a separation of headers and trailers if this is the
|
return nil
|
||||||
// first call (for example, in end2end tests's TestNoService).
|
case <-ht.closedCh:
|
||||||
ht.rw.(http.Flusher).Flush()
|
return ErrConnClosing
|
||||||
|
|
||||||
h := ht.rw.Header()
|
|
||||||
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
|
|
||||||
if statusDesc != "" {
|
|
||||||
h.Set("Grpc-Message", statusDesc)
|
|
||||||
}
|
}
|
||||||
if md := s.Trailer(); len(md) > 0 {
|
}
|
||||||
for k, vv := range md {
|
|
||||||
for _, v := range vv {
|
func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
|
||||||
// http2 ResponseWriter mechanism to
|
err := ht.do(func() {
|
||||||
// send undeclared Trailers after the
|
ht.writeCommonHeaders(s)
|
||||||
// headers have possibly been written.
|
|
||||||
h.Add(http2.TrailerPrefix+k, v)
|
// And flush, in case no header or body has been sent yet.
|
||||||
|
// This forces a separation of headers and trailers if this is the
|
||||||
|
// first call (for example, in end2end tests's TestNoService).
|
||||||
|
ht.rw.(http.Flusher).Flush()
|
||||||
|
|
||||||
|
h := ht.rw.Header()
|
||||||
|
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
|
||||||
|
if statusDesc != "" {
|
||||||
|
h.Set("Grpc-Message", statusDesc)
|
||||||
|
}
|
||||||
|
if md := s.Trailer(); len(md) > 0 {
|
||||||
|
for k, vv := range md {
|
||||||
|
for _, v := range vv {
|
||||||
|
// http2 ResponseWriter mechanism to
|
||||||
|
// send undeclared Trailers after the
|
||||||
|
// headers have possibly been written.
|
||||||
|
h.Add(http2.TrailerPrefix+k, v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
close(ht.wroteStatus)
|
close(ht.writes)
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeCommonHeaders sets common headers on the first write
|
// writeCommonHeaders sets common headers on the first write
|
||||||
@ -219,28 +234,30 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
|
func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
|
||||||
ht.writeCommonHeaders(s)
|
return ht.do(func() {
|
||||||
ht.rw.Write(data)
|
ht.writeCommonHeaders(s)
|
||||||
if !opts.Delay {
|
ht.rw.Write(data)
|
||||||
ht.rw.(http.Flusher).Flush()
|
if !opts.Delay {
|
||||||
}
|
ht.rw.(http.Flusher).Flush()
|
||||||
return nil
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
||||||
ht.writeCommonHeaders(s)
|
return ht.do(func() {
|
||||||
h := ht.rw.Header()
|
ht.writeCommonHeaders(s)
|
||||||
for k, vv := range md {
|
h := ht.rw.Header()
|
||||||
for _, v := range vv {
|
for k, vv := range md {
|
||||||
h.Add(k, v)
|
for _, v := range vv {
|
||||||
|
h.Add(k, v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
ht.rw.WriteHeader(200)
|
||||||
ht.rw.WriteHeader(200)
|
ht.rw.(http.Flusher).Flush()
|
||||||
ht.rw.(http.Flusher).Flush()
|
})
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
|
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
|
||||||
// With this transport type there will be exactly 1 stream: this HTTP request.
|
// With this transport type there will be exactly 1 stream: this HTTP request.
|
||||||
|
|
||||||
var ctx context.Context
|
var ctx context.Context
|
||||||
@ -251,12 +268,18 @@ func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
|
|||||||
ctx, cancel = context.WithCancel(context.Background())
|
ctx, cancel = context.WithCancel(context.Background())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// requestOver is closed when either the request's context is done
|
||||||
|
// or the status has been written via WriteStatus.
|
||||||
|
requestOver := make(chan struct{})
|
||||||
|
|
||||||
// clientGone receives a single value if peer is gone, either
|
// clientGone receives a single value if peer is gone, either
|
||||||
// because the underlying connection is dead or because the
|
// because the underlying connection is dead or because the
|
||||||
// peer sends an http2 RST_STREAM.
|
// peer sends an http2 RST_STREAM.
|
||||||
clientGone := ht.rw.(http.CloseNotifier).CloseNotify()
|
clientGone := ht.rw.(http.CloseNotifier).CloseNotify()
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
|
case <-requestOver:
|
||||||
|
return
|
||||||
case <-ht.closedCh:
|
case <-ht.closedCh:
|
||||||
case <-clientGone:
|
case <-clientGone:
|
||||||
}
|
}
|
||||||
@ -285,10 +308,6 @@ func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
|
|||||||
s.ctx = newContextWithStream(ctx, s)
|
s.ctx = newContextWithStream(ctx, s)
|
||||||
s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
|
s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
|
||||||
|
|
||||||
// requestOver is closed when either the request's context is done
|
|
||||||
// or the status has been written via WriteStatus.
|
|
||||||
requestOver := make(chan struct{})
|
|
||||||
|
|
||||||
// readerDone is closed when the Body.Read-ing goroutine exits.
|
// readerDone is closed when the Body.Read-ing goroutine exits.
|
||||||
readerDone := make(chan struct{})
|
readerDone := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
@ -296,34 +315,40 @@ func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) {
|
|||||||
for {
|
for {
|
||||||
buf := make([]byte, 1024) // TODO: minimize garbage, optimize recvBuffer code/ownership
|
buf := make([]byte, 1024) // TODO: minimize garbage, optimize recvBuffer code/ownership
|
||||||
n, err := req.Body.Read(buf)
|
n, err := req.Body.Read(buf)
|
||||||
select {
|
|
||||||
case <-requestOver:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
s.buf.put(&recvMsg{data: buf[:n]})
|
s.buf.put(&recvMsg{data: buf[:n]})
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.buf.put(&recvMsg{err: err})
|
s.buf.put(&recvMsg{err: err})
|
||||||
break
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// runStream is provided by the *grpc.Server.serveStreams.
|
// startStream is provided by the *grpc.Server's serveStreams.
|
||||||
// It starts a goroutine handling s and exits immediately.
|
// It starts a goroutine serving s and exits immediately.
|
||||||
runStream(s)
|
// The goroutine that is started is the one that then calls
|
||||||
|
// into ht, calling WriteHeader, Write, WriteStatus, Close, etc.
|
||||||
|
startStream(s)
|
||||||
|
|
||||||
// Wait for the stream to be done. It is considered done when
|
ht.runStream()
|
||||||
// either its context is done, or we've written its status.
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
case <-ht.wroteStatus:
|
|
||||||
}
|
|
||||||
close(requestOver)
|
close(requestOver)
|
||||||
|
|
||||||
// Wait for reading goroutine to finish.
|
// Wait for reading goroutine to finish.
|
||||||
req.Body.Close()
|
req.Body.Close()
|
||||||
<-readerDone
|
<-readerDone
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ht *serverHandlerTransport) runStream() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case fn, ok := <-ht.writes:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fn()
|
||||||
|
case <-ht.closedCh:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
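
Aside: the heart of the flakiness fix is that all writes to the http.ResponseWriter are funneled as closures through a single channel and executed by the ServeHTTP goroutine, instead of racing from RPC-handling goroutines. A minimal standalone sketch of the pattern (serialWriter and errClosed are illustrative, not the real types):

    package main

    import (
    	"errors"
    	"fmt"
    )

    var errClosed = errors.New("transport closed")

    // serialWriter mirrors the writes/closedCh pair above.
    type serialWriter struct {
    	writes   chan func()
    	closedCh chan struct{}
    }

    // do queues fn to run on the draining goroutine, as
    // serverHandlerTransport.do does above.
    func (w *serialWriter) do(fn func()) error {
    	select {
    	case w.writes <- fn:
    		return nil
    	case <-w.closedCh:
    		return errClosed
    	}
    }

    // run drains queued closures until writes is closed (WriteStatus
    // in the real code) or the transport is closed.
    func (w *serialWriter) run() {
    	for {
    		select {
    		case fn, ok := <-w.writes:
    			if !ok {
    				return
    			}
    			fn()
    		case <-w.closedCh:
    			return
    		}
    	}
    }

    func main() {
    	w := &serialWriter{writes: make(chan func()), closedCh: make(chan struct{})}
    	go func() {
    		w.do(func() { fmt.Println("write 1") })
    		w.do(func() { fmt.Println("write 2") })
    		close(w.writes) // like WriteStatus: no more writes
    	}()
    	w.run()
    }
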
--- a/transport/handler_server_test.go
+++ b/transport/handler_server_test.go
@@ -293,13 +293,14 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
 
 func TestHandlerTransport_HandleStreams(t *testing.T) {
 	st := newHandleStreamTest(t)
-	st.ht.HandleStreams(func(s *Stream) {
+	handleStream := func(s *Stream) {
 		if want := "/service/foo.bar"; s.method != want {
 			t.Errorf("stream method = %q; want %q", s.method, want)
 		}
 		st.bodyw.Close() // no body
 		st.ht.WriteStatus(s, codes.OK, "")
-	})
+	}
+	st.ht.HandleStreams(func(s *Stream) { go handleStream(s) })
 	wantHeader := http.Header{
 		"Date":         nil,
 		"Content-Type": {"application/grpc"},
@@ -323,9 +324,10 @@ func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
 
 func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
 	st := newHandleStreamTest(t)
-	st.ht.HandleStreams(func(s *Stream) {
+	handleStream := func(s *Stream) {
 		st.ht.WriteStatus(s, statusCode, msg)
-	})
+	}
+	st.ht.HandleStreams(func(s *Stream) { go handleStream(s) })
 	wantHeader := http.Header{
 		"Date":         nil,
 		"Content-Type": {"application/grpc"},
@@ -358,7 +360,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	ht.HandleStreams(func(s *Stream) {
+	runStream := func(s *Stream) {
 		defer bodyw.Close()
 		select {
 		case <-s.ctx.Done():
@@ -372,7 +374,8 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
 			return
 		}
 		ht.WriteStatus(s, codes.DeadlineExceeded, "too slow")
-	})
+	}
+	ht.HandleStreams(func(s *Stream) { go runStream(s) })
 	wantHeader := http.Header{
 		"Date":         nil,
 		"Content-Type": {"application/grpc"},
@@ -381,6 +384,6 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
 		"Grpc-Message": {"too slow"},
 	}
 	if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
-		t.Errorf("Header+Trailer Map: %#v; want %#v", rw.HeaderMap, wantHeader)
+		t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)
 	}
 }
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -62,8 +62,8 @@ type http2Server struct {
 	maxStreamID uint32               // max stream ID ever seen
 	authInfo    credentials.AuthInfo // auth info about the connection
 	// writableChan synchronizes write access to the transport.
-	// A writer acquires the write lock by sending a value on writableChan
-	// and releases it by receiving from writableChan.
+	// A writer acquires the write lock by receiving a value on writableChan
+	// and releases it by sending on writableChan.
 	writableChan chan int
 	// shutdownChan is closed when Close is called.
 	// Blocking operations should select on shutdownChan to avoid
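
Aside: the corrected comment describes writableChan as a mutex implemented with a channel token: receive to lock, send to unlock. A minimal standalone sketch (illustrative, not the transport's actual code):

    package main

    import "fmt"

    func main() {
    	writableChan := make(chan int, 1)
    	writableChan <- 0 // transport starts out writable

    	<-writableChan // acquire the write lock by receiving the token
    	fmt.Println("holding the write lock")
    	writableChan <- 0 // release it by sending the token back
    }
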
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -352,30 +352,40 @@ func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, e
 // Options provides additional hints and information for message
 // transmission.
 type Options struct {
-	// Indicate whether it is the last piece for this stream.
+	// Last indicates whether this write is the last piece for
+	// this stream.
 	Last bool
-	// The hint to transport impl whether the data could be buffered for
-	// batching write. Transport impl can feel free to ignore it.
+
+	// Delay is a hint to the transport implementation for whether
+	// the data could be buffered for a batching write. The
+	// Transport implementation may ignore the hint.
 	Delay bool
 }
 
 // CallHdr carries the information of a particular RPC.
 type CallHdr struct {
-	// Host specifies peer host.
+	// Host specifies the peer's host.
 	Host string
+
 	// Method specifies the operation to perform.
 	Method string
-	// RecvCompress specifies the compression algorithm applied on inbound messages.
+
+	// RecvCompress specifies the compression algorithm applied on
+	// inbound messages.
 	RecvCompress string
-	// SendCompress specifies the compression algorithm applied on outbound message.
+
+	// SendCompress specifies the compression algorithm applied on
+	// outbound message.
 	SendCompress string
-	// Flush indicates if new stream command should be sent to the peer without
-	// waiting for the first data. This is a hint though. The transport may modify
-	// the flush decision for performance purpose.
+
+	// Flush indicates whether a new stream command should be sent
+	// to the peer without waiting for the first data. This is
+	// only a hint. The transport may modify the flush decision
+	// for performance purposes.
 	Flush bool
 }
 
-// ClientTransport is the common interface for all gRPC client side transport
+// ClientTransport is the common interface for all gRPC client-side transport
 // implementations.
 type ClientTransport interface {
 	// Close tears down this transport. Once it returns, the transport
@@ -404,21 +414,33 @@ type ClientTransport interface {
 	Error() <-chan struct{}
 }
 
-// ServerTransport is the common interface for all gRPC server side transport
+// ServerTransport is the common interface for all gRPC server-side transport
 // implementations.
+//
+// Methods may be called concurrently from multiple goroutines, but
+// Write methods for a given Stream will be called serially.
 type ServerTransport interface {
-	// WriteStatus sends the status of a stream to the client.
-	WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error
-	// Write sends the data for the given stream.
-	Write(s *Stream, data []byte, opts *Options) error
-	// WriteHeader sends the header metedata for the given stream.
-	WriteHeader(s *Stream, md metadata.MD) error
 	// HandleStreams receives incoming streams using the given handler.
 	HandleStreams(func(*Stream))
+
+	// WriteHeader sends the header metadata for the given stream.
+	// WriteHeader may not be called on all streams.
+	WriteHeader(s *Stream, md metadata.MD) error
+
+	// Write sends the data for the given stream.
+	// Write may not be called on all streams.
+	Write(s *Stream, data []byte, opts *Options) error
+
+	// WriteStatus sends the status of a stream to the client.
+	// WriteStatus is the final call made on a stream and always
+	// occurs.
+	WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error
+
 	// Close tears down the transport. Once it is called, the transport
 	// should not be accessed any more. All the pending streams and their
 	// handlers will be terminated asynchronously.
 	Close() error
+
 	// RemoteAddr returns the remote network address.
 	RemoteAddr() net.Addr
 }
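
Aside: the reordered ServerTransport methods now read in per-stream lifecycle order: an optional WriteHeader, zero or more Writes, then exactly one WriteStatus. A minimal standalone sketch of that call order (streamWriter and printWriter are toy stand-ins, not the real transport types):

    package main

    import "fmt"

    type streamWriter interface {
    	WriteHeader(md map[string][]string) error
    	Write(data []byte) error
    	WriteStatus(code int, desc string) error
    }

    type printWriter struct{}

    func (printWriter) WriteHeader(md map[string][]string) error { fmt.Println("header:", md); return nil }
    func (printWriter) Write(data []byte) error                  { fmt.Printf("data: %q\n", data); return nil }
    func (printWriter) WriteStatus(code int, desc string) error  { fmt.Println("status:", code, desc); return nil }

    // serveStream drives a stream in the documented order.
    func serveStream(w streamWriter, replies [][]byte) {
    	w.WriteHeader(map[string][]string{"content-type": {"application/grpc"}}) // may be skipped
    	for _, r := range replies {
    		w.Write(r) // zero or more times
    	}
    	w.WriteStatus(0, "OK") // final call; always occurs
    }

    func main() {
    	serveStream(printWriter{}, [][]byte{[]byte("a"), []byte("b")})
    }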