Merge pull request #116 from iamqizhao/master

Fix some bugs
Author: Qi Zhao
Date:   2015-03-13 00:26:30 -07:00
7 changed files with 38 additions and 82 deletions

call.go

@@ -127,8 +127,10 @@ func Invoke(ctx context.Context, method string, args, reply proto.Message, cc *C
 		Last:  true,
 		Delay: false,
 	}
-	ts := 0
-	var lastErr error // record the error that happened
+	var (
+		ts      int   // track the transport sequence number
+		lastErr error // record the error that happened
+	)
 	for {
 		var (
 			err error
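The surviving half of this hunk groups the retry loop's state into a single var block instead of mixing a short declaration with a var statement. A rough, self-contained sketch of the pattern; the real Invoke plumbing is elided, and invokeWithRetry and its callback are hypothetical:

    package main

    import "fmt"

    func invokeWithRetry(attempts int, call func(seq int) error) error {
        var (
            ts      int   // track the transport sequence number
            lastErr error // record the error that happened
        )
        for i := 0; i < attempts; i++ {
            if err := call(ts); err != nil {
                lastErr = err
                ts++ // ask for a fresh transport on the next attempt
                continue
            }
            return nil
        }
        return fmt.Errorf("all %d attempts failed, last error: %v", attempts, lastErr)
    }

    func main() {
        err := invokeWithRetry(3, func(seq int) error {
            return fmt.Errorf("transport %d unavailable", seq)
        })
        fmt.Println(err)
    }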

clientconn.go

@@ -156,7 +156,7 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error {
 		if err != nil {
 			sleepTime := backoff(retries)
 			// Fail early before falling into sleep.
-			if cc.dopts.Timeout > 0 && cc.dopts.Timeout < sleepTime + time.Since(start) {
+			if cc.dopts.Timeout > 0 && cc.dopts.Timeout < sleepTime+time.Since(start) {
 				cc.Close()
 				return ErrClientConnTimeout
 			}
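The guarded line is a fail-early check: if the configured timeout would expire during the upcoming backoff sleep, the dial gives up immediately instead of sleeping first (the code change itself is only a gofmt spacing fix). A minimal, self-contained sketch of the same guard, with illustrative backoff values rather than the library's real policy:

    package main

    import (
        "errors"
        "fmt"
        "time"
    )

    var errConnTimeout = errors.New("conn: timed out trying to connect")

    func connectWithBackoff(timeout time.Duration, dial func() error) error {
        start := time.Now()
        for retries := 0; ; retries++ {
            if dial() == nil {
                return nil
            }
            sleepTime := time.Duration(retries+1) * 100 * time.Millisecond
            // Fail early before falling into sleep: if the deadline would
            // pass mid-sleep, report the timeout now.
            if timeout > 0 && timeout < sleepTime+time.Since(start) {
                return errConnTimeout
            }
            time.Sleep(sleepTime)
        }
    }

    func main() {
        err := connectWithBackoff(300*time.Millisecond, func() error {
            return errors.New("refused")
        })
        fmt.Println(err) // conn: timed out trying to connect
    }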

rpc_util.go

@@ -200,7 +200,7 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {
 	}
 }
 
-// toRPCErr converts a transport error into a rpcError if possible.
+// toRPCErr converts an error into a rpcError.
 func toRPCErr(err error) error {
 	switch e := err.(type) {
 	case transport.StreamError:
@@ -214,7 +214,7 @@ func toRPCErr(err error) error {
 			desc: e.Desc,
 		}
 	}
-	return Errorf(codes.Unknown, "grpc: failed to convert %v to rpcErr", err)
+	return Errorf(codes.Unknown, "%v", err)
 }
 
 // convertCode converts a standard Go error into its canonical code. Note that
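The new fallback passes the original error text through verbatim under codes.Unknown instead of burying it in a "failed to convert" wrapper. A hedged sketch of that conversion shape, using stand-in types rather than grpc's internals:

    package main

    import (
        "errors"
        "fmt"
    )

    type code int

    const (
        codeUnknown code = iota
        codeInternal
    )

    type rpcError struct {
        code code
        desc string
    }

    func (e rpcError) Error() string {
        return fmt.Sprintf("rpc error: code = %d desc = %q", e.code, e.desc)
    }

    type streamError struct{ desc string }

    func (e streamError) Error() string { return e.desc }

    func toRPCErr(err error) error {
        switch e := err.(type) {
        case streamError:
            return rpcError{code: codeInternal, desc: e.desc}
        }
        // Preserve the original message rather than wrapping it in a
        // "failed to convert" string.
        return rpcError{code: codeUnknown, desc: err.Error()}
    }

    func main() {
        fmt.Println(toRPCErr(errors.New("boom")))
    }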

test/end2end_test.go

@@ -239,6 +239,7 @@ func TestReconnectTimeout(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Failed to dial to the server %q: %v", addr, err)
 	}
+	// Close unaccepted connection (i.e., conn).
 	lis.Close()
 	tc := testpb.NewTestServiceClient(conn)
 	waitC := make(chan struct{})
@@ -251,9 +252,8 @@ func TestReconnectTimeout(t *testing.T) {
 			ResponseSize: proto.Int32(int32(respSize)),
 			Payload:      newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize)),
 		}
-		_, err := tc.UnaryCall(context.Background(), req)
-		if err != grpc.Errorf(codes.Internal, "%v", grpc.ErrClientConnClosing) {
-			t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %v", err, grpc.Errorf(codes.Internal, "%v", grpc.ErrClientConnClosing))
+		if _, err := tc.UnaryCall(context.Background(), req); err == nil {
+			t.Fatalf("TestService/UnaryCall(_, _) = _, <nil>, want _, non-nil")
 		}
 	}()
 	// Block untill reconnect times out.
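The assertion is loosened from comparing against a freshly constructed error value, which is brittle because it must match the error's type and message byte for byte, to requiring any non-nil error. A self-contained illustration of the pattern, with all names hypothetical:

    package example

    import (
        "errors"
        "testing"
    )

    type client struct{ closed bool }

    func (c *client) UnaryCall() (string, error) {
        if c.closed {
            return "", errors.New("client is closing")
        }
        return "ok", nil
    }

    func TestCallFailsAfterClose(t *testing.T) {
        c := &client{closed: true}
        // Assert only that the call fails, not the exact error value.
        if _, err := c.UnaryCall(); err == nil {
            t.Fatalf("UnaryCall() = _, <nil>, want _, non-nil")
        }
    }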

transport/http2_client.go

@@ -209,16 +209,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 			return nil, ContextErr(context.DeadlineExceeded)
 		}
 	}
-	// HPACK encodes various headers.
-	t.hBuf.Reset()
-	t.hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
-	t.hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme})
-	t.hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method})
-	t.hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
-	t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
-	t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
+	var authData map[string]string
 	for _, c := range t.authCreds {
-		m, err := c.GetRequestMetadata(ctx)
+		authData, err = c.GetRequestMetadata(ctx)
 		select {
 		case <-ctx.Done():
 			return nil, ContextErr(ctx.Err())
@@ -227,13 +220,23 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 		if err != nil {
 			return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err)
 		}
-		for k, v := range m {
-			t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
-		}
 	}
+	// HPACK encodes various headers. Note that once WriteField(...) is
+	// called, the corresponding headers/continuation frame has to be sent
+	// because hpack.Encoder is stateful.
+	t.hBuf.Reset()
+	t.hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"})
+	t.hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme})
+	t.hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method})
+	t.hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host})
+	t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
+	t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"})
 	if timeout > 0 {
 		t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
 	}
+	for k, v := range authData {
+		t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
+	}
 	if md, ok := metadata.FromContext(ctx); ok {
 		for k, v := range md {
 			t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
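The reordering matters because hpack.Encoder is stateful: every WriteField may update the dynamic table that the peer's decoder mirrors, so once encoding starts, the resulting headers frame must actually be sent. GetRequestMetadata, which can fail or block, therefore has to finish before the first WriteField. A small demonstration of that statefulness using golang.org/x/net/http2/hpack (the current import path for the hpack package; the header values are illustrative):

    package main

    import (
        "bytes"
        "fmt"

        "golang.org/x/net/http2/hpack"
    )

    func main() {
        var buf bytes.Buffer
        enc := hpack.NewEncoder(&buf)

        enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
        first := buf.Len()
        enc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
        second := buf.Len() - first

        // The second encoding is shorter: it is an index into dynamic-table
        // state that the peer's decoder must observe in order, frame by frame.
        fmt.Printf("first: %d bytes, second: %d bytes\n", first, second)
    }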

transport/http2_server.go

@@ -205,6 +205,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
 		frame, err := t.framer.ReadFrame()
 		if err != nil {
+			log.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
 			t.Close()
 			return
 		}

transport/transport_test.go

@@ -230,8 +230,8 @@ func TestClientSendAndReceive(t *testing.T) {
 	if recvErr != io.EOF {
 		t.Fatalf("Error: %v; want <EOF>", recvErr)
 	}
-	closeClient(ct, t)
-	closeServer(server, t)
+	ct.Close()
+	server.Close()
 }
 
 func TestClientErrorNotify(t *testing.T) {
@@ -248,10 +248,10 @@ func TestClientErrorNotify(t *testing.T) {
 		t.Fatalf("wrong stream id: %d", s.id)
 	}
 	// Tear down the server.
-	go closeServer(server, t)
+	go server.Close()
 	// ct.reader should detect the error and activate ct.Error().
 	<-ct.Error()
-	closeClient(ct, t)
+	ct.Close()
 }
 
 func performOneRPC(ct ClientTransport) {
@@ -284,11 +284,11 @@ func TestClientMix(t *testing.T) {
 	s, ct := setUp(t, true, 0, math.MaxUint32, false)
 	go func(s *server) {
 		time.Sleep(5 * time.Second)
-		closeServer(s, t)
+		s.Close()
 	}(s)
 	go func(ct ClientTransport) {
 		<-ct.Error()
-		closeClient(ct, t)
+		ct.Close()
 	}(ct)
 	for i := 0; i < 1000; i++ {
 		time.Sleep(10 * time.Millisecond)
@@ -299,8 +299,8 @@ func TestClientMix(t *testing.T) {
 func TestExceedMaxStreamsLimit(t *testing.T) {
 	server, ct := setUp(t, true, 0, 1, false)
 	defer func() {
-		closeClient(ct, t)
-		closeServer(server, t)
+		ct.Close()
+		server.Close()
 	}()
 	callHdr := &CallHdr{
 		Host:   "localhost",
@@ -374,8 +374,8 @@ func TestLargeMessage(t *testing.T) {
 		}()
 	}
 	wg.Wait()
-	closeClient(ct, t)
-	closeServer(server, t)
+	ct.Close()
+	server.Close()
 }
 
 func TestLargeMessageSuspension(t *testing.T) {
@@ -396,8 +396,8 @@ func TestLargeMessageSuspension(t *testing.T) {
 	if err == nil || err != expectedErr {
 		t.Fatalf("Write got %v, want %v", err, expectedErr)
 	}
-	closeClient(ct, t)
-	closeServer(server, t)
+	ct.Close()
+	server.Close()
 }
 
 func TestStreamContext(t *testing.T) {
@@ -408,53 +408,3 @@ func TestStreamContext(t *testing.T) {
 		t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, *s, ok, expectedStream)
 	}
 }
-
-// closeClient shuts down the ClientTransport and reports any errors to the
-// test framework and terminates the current test case.
-func closeClient(ct ClientTransport, t *testing.T) {
-	if err := ct.Close(); err != nil {
-		t.Fatalf("ct.Close() = %v, want <nil>", err)
-	}
-}
-
-// closeServerWithErr shuts down the testing server, closing the associated
-// transports. It returns the first error it encounters, if any.
-func closeServerWithErr(s *server) error {
-	// Keep consistent with s.Close().
-	s.lis.Close()
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	for c := range s.conns {
-		if err := c.Close(); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-// closeServer shuts down the and testing server, closing the associated
-// transport. It reports any errors to the test framework and terminates the
-// current test case.
-func closeServer(s *server, t *testing.T) {
-	if err := closeServerWithErr(s); err != nil {
-		t.Fatalf("server.Close() = %v, want <nil>", err)
-	}
-}
-
-func TestClientServerDuplicatedClose(t *testing.T) {
-	server, ct := setUp(t, true, 0, math.MaxUint32, false)
-	if err := ct.Close(); err != nil {
-		t.Fatalf("ct.Close() = %v, want <nil>", err)
-	}
-	if err := ct.Close(); err == nil {
-		// Duplicated closes should gracefully issue an error.
-		t.Fatalf("ct.Close() = <nil>, want non-nil")
-	}
-	if err := closeServerWithErr(server); err != nil {
-		t.Fatalf("closeServerWithErr(server) = %v, want <nil>", err)
-	}
-	if err := closeServerWithErr(server); err == nil {
-		// Duplicated closes should gracefully issue an error.
-		t.Fatalf("closeServerWithErr(server) = <nil>, want non-nil")
-	}
-}
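The deleted TestClientServerDuplicatedClose pinned down that a second Close reports an error rather than silently succeeding; with the closeClient/closeServer helpers gone, the remaining tests no longer assert anything about Close's return value during teardown. A sketch of the contract the deleted test verified, using stand-in types rather than the transport package's API:

    package main

    import (
        "errors"
        "fmt"
        "sync"
    )

    type conn struct {
        mu     sync.Mutex
        closed bool
    }

    // Close errors on the second call instead of being a silent no-op,
    // mirroring the behavior the deleted test checked.
    func (c *conn) Close() error {
        c.mu.Lock()
        defer c.mu.Unlock()
        if c.closed {
            return errors.New("transport: Close() was already called")
        }
        c.closed = true
        return nil
    }

    func main() {
        c := &conn{}
        fmt.Println(c.Close()) // <nil>
        fmt.Println(c.Close()) // transport: Close() was already called
    }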