client: send RST_STREAM on client-side errors to prevent server from blocking (#1823)

This commit is contained in:
lyuxuan
2018-01-24 11:18:05 -08:00
committed by GitHub
parent 82e9f61ddd
commit c22018a9fb
3 changed files with 81 additions and 3 deletions

View File

@ -4099,6 +4099,7 @@ type funcServer struct {
testpb.TestServiceServer testpb.TestServiceServer
unaryCall func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) unaryCall func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error)
streamingInputCall func(stream testpb.TestService_StreamingInputCallServer) error streamingInputCall func(stream testpb.TestService_StreamingInputCallServer) error
fullDuplexCall func(stream testpb.TestService_FullDuplexCallServer) error
} }
func (s *funcServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { func (s *funcServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
@ -4109,6 +4110,10 @@ func (s *funcServer) StreamingInputCall(stream testpb.TestService_StreamingInput
return s.streamingInputCall(stream) return s.streamingInputCall(stream)
} }
func (s *funcServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
return s.fullDuplexCall(stream)
}
func TestClientRequestBodyErrorUnexpectedEOF(t *testing.T) { func TestClientRequestBodyErrorUnexpectedEOF(t *testing.T) {
defer leakcheck.Check(t) defer leakcheck.Check(t)
for _, e := range listTestEnv() { for _, e := range listTestEnv() {
@ -4230,6 +4235,76 @@ func testClientRequestBodyErrorCancelStreamingInput(t *testing.T, e env) {
}) })
} }
func TestClientResourceExhaustedCancelFullDuplex(t *testing.T) {
defer leakcheck.Check(t)
for _, e := range listTestEnv() {
if e.httpHandler {
// httpHanlder write won't be blocked on flow control window.
continue
}
testClientResourceExhaustedCancelFullDuplex(t, e)
}
}
func testClientResourceExhaustedCancelFullDuplex(t *testing.T, e env) {
te := newTest(t, e)
recvErr := make(chan error, 1)
ts := &funcServer{fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
defer close(recvErr)
_, err := stream.Recv()
if err != nil {
return status.Errorf(codes.Internal, "stream.Recv() got error: %v, want <nil>", err)
}
// create a payload that's larger than the default flow control window.
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 10)
if err != nil {
return err
}
resp := &testpb.StreamingOutputCallResponse{
Payload: payload,
}
ce := make(chan error)
go func() {
var err error
for {
if err = stream.Send(resp); err != nil {
break
}
}
ce <- err
}()
select {
case err = <-ce:
case <-time.After(10 * time.Second):
err = errors.New("10s timeout reached")
}
recvErr <- err
return err
}}
te.startServer(ts)
defer te.tearDown()
// set a low limit on receive message size to error with Resource Exhausted on
// client side when server send a large message.
te.maxClientReceiveMsgSize = newInt(10)
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
stream, err := tc.FullDuplexCall(context.Background())
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
req := &testpb.StreamingOutputCallRequest{}
if err := stream.Send(req); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, req, err)
}
if _, err := stream.Recv(); status.Code(err) != codes.ResourceExhausted {
t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.ResourceExhausted)
}
err = <-recvErr
if status.Code(err) != codes.Canceled {
t.Fatalf("server got error %v, want error code: %s", err, codes.Canceled)
}
}
type clientTimeoutCreds struct { type clientTimeoutCreds struct {
timeoutReturned bool timeoutReturned bool
} }

View File

@ -571,6 +571,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
s.mu.Lock() s.mu.Lock()
rstStream = s.rstStream rstStream = s.rstStream
rstError = s.rstError rstError = s.rstError
rstRecv := s.rstReceived
if s.state == streamDone { if s.state == streamDone {
s.mu.Unlock() s.mu.Unlock()
return return
@ -581,7 +582,7 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
} }
s.state = streamDone s.state = streamDone
s.mu.Unlock() s.mu.Unlock()
if _, ok := err.(StreamError); ok { if err != nil && !rstStream && !rstRecv {
rstStream = true rstStream = true
rstError = http2.ErrCodeCancel rstError = http2.ErrCodeCancel
} }
@ -919,6 +920,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
statusCode = codes.Unknown statusCode = codes.Unknown
} }
s.finish(status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode)) s.finish(status.Newf(statusCode, "stream terminated by RST_STREAM with error code: %v", f.ErrCode))
s.rstReceived = true
s.mu.Unlock() s.mu.Unlock()
s.write(recvMsg{err: io.EOF}) s.write(recvMsg{err: io.EOF})
} }

View File

@ -243,6 +243,7 @@ type Stream struct {
rstStream bool // indicates whether a RST_STREAM frame needs to be sent rstStream bool // indicates whether a RST_STREAM frame needs to be sent
rstError http2.ErrCode // the error that needs to be sent along with the RST_STREAM frame rstError http2.ErrCode // the error that needs to be sent along with the RST_STREAM frame
rstReceived bool // indicates whether a RST_STREAM frame has been received from the other endpoint
bytesReceived bool // indicates whether any bytes have been received on this stream bytesReceived bool // indicates whether any bytes have been received on this stream
unprocessed bool // set if the server sends a refused stream or GOAWAY including this stream unprocessed bool // set if the server sends a refused stream or GOAWAY including this stream