From e60698345eea0fc5c76d2f8db4f09122cfea3678 Mon Sep 17 00:00:00 2001 From: dfawley Date: Tue, 29 Aug 2017 11:04:15 -0700 Subject: [PATCH] Fix context warnings from govet. (#1486) Pre-req work for #1484 --- balancer_test.go | 4 ++- clientconn_test.go | 3 +- stream.go | 5 +++ test/end2end_test.go | 63 ++++++++++++++++++++++--------------- transport/transport_test.go | 15 ++++++--- 5 files changed, 58 insertions(+), 32 deletions(-) diff --git a/balancer_test.go b/balancer_test.go index 3eb611b6..86fc1967 100644 --- a/balancer_test.go +++ b/balancer_test.go @@ -284,10 +284,12 @@ func TestGetOnWaitChannel(t *testing.T) { r.w.inject(updates) for { var reply string - ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded { + cancel() break } + cancel() time.Sleep(10 * time.Millisecond) } var wg sync.WaitGroup diff --git a/clientconn_test.go b/clientconn_test.go index 95a99c02..bdaf3df1 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -293,7 +293,8 @@ func nonTemporaryErrorDialer(addr string, timeout time.Duration) (net.Conn, erro } func TestDialWithBlockErrorOnNonTemporaryErrorDialer(t *testing.T) { - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() if _, err := DialContext(ctx, "", WithInsecure(), WithDialer(nonTemporaryErrorDialer), WithBlock(), FailOnNonTempDialError(true)); err != nonTemporaryError { t.Fatalf("Dial(%q) = %v, want %v", "", err, nonTemporaryError) } diff --git a/stream.go b/stream.go index 2fcf3687..6f29a975 100644 --- a/stream.go +++ b/stream.go @@ -117,6 +117,11 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth if mc.Timeout != nil { ctx, cancel = context.WithTimeout(ctx, *mc.Timeout) + defer func() { + if err != nil { + cancel() + } + }() } opts = append(cc.dopts.callOptions, opts...) diff --git a/test/end2end_test.go b/test/end2end_test.go index 19f76342..574f4c7d 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -696,8 +696,9 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) } te.srv.Stop() - ctx, _ := context.WithTimeout(context.Background(), time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)) + cancel() if e.balancer && grpc.Code(err) != codes.DeadlineExceeded { // If e.balancer == nil, the ac will stop reconnecting because the dialer returns non-temp error, // the error will be an internal error. @@ -756,11 +757,12 @@ func testServerGoAway(t *testing.T, e env) { }() // Loop until the server side GoAway signal is propagated to the client. for { - ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) == codes.DeadlineExceeded { - continue + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil && grpc.Code(err) != codes.DeadlineExceeded { + cancel() + break } - break + cancel() } // A new RPC should fail. if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable && grpc.Code(err) != codes.Internal { @@ -809,11 +811,12 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { }() // Loop until the server side GoAway signal is propagated to the client. for { - ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { - continue + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err != nil { + cancel() + break } - break + cancel() } respParam := []*testpb.ResponseParameters{ { @@ -885,11 +888,12 @@ func testServerMultipleGoAwayPendingRPC(t *testing.T, e env) { }() // Loop until the server side GoAway signal is propagated to the client. for { - ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { - continue + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err != nil { + cancel() + break } - break + cancel() } select { case <-ch1: @@ -1004,11 +1008,12 @@ func testConcurrentServerStopAndGoAway(t *testing.T, e env) { }() // Loop until the server side GoAway signal is propagated to the client. for { - ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil { - continue + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err != nil { + cancel() + break } - break + cancel() } // Stop the server and close all the connections. te.srv.Stop() @@ -1285,14 +1290,16 @@ func testServiceConfigTimeout(t *testing.T, e env) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) // The following RPCs are expected to become non-fail-fast ones with 1ns deadline. - ctx, _ := context.WithTimeout(context.Background(), time.Nanosecond) + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) } - ctx, _ = context.WithTimeout(context.Background(), time.Nanosecond) + cancel() + ctx, cancel = context.WithTimeout(context.Background(), time.Nanosecond) if _, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded) } + cancel() // Generate a service config update. // Case2: Client API sets timeout to be 1hr and ServiceConfig sets timeout to be 1ns. Timeout should be 1ns (min of 1ns and 1hr) and the rpc will wait until deadline exceeds. @@ -1316,15 +1323,17 @@ func testServiceConfigTimeout(t *testing.T, e env) { break } - ctx, _ = context.WithTimeout(context.Background(), time.Hour) + ctx, cancel = context.WithTimeout(context.Background(), time.Hour) if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) } + cancel() - ctx, _ = context.WithTimeout(context.Background(), time.Hour) + ctx, cancel = context.WithTimeout(context.Background(), time.Hour) if _, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/FullDuplexCall(_) = _, %v, want %s", err, codes.DeadlineExceeded) } + cancel() } func TestServiceConfigMaxMsgSize(t *testing.T) { @@ -1846,7 +1855,8 @@ func testTap(t *testing.T, e env) { } func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) { - ctx, _ := context.WithTimeout(context.Background(), d) + ctx, cancel := context.WithTimeout(context.Background(), d) + defer cancel() hc := healthpb.NewHealthClient(cc) req := &healthpb.HealthCheckRequest{ Service: serviceName, @@ -2872,10 +2882,11 @@ func testRPCTimeout(t *testing.T, e env) { Payload: payload, } for i := -1; i <= 10; i++ { - ctx, _ := context.WithTimeout(context.Background(), time.Duration(i)*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(i)*time.Millisecond) if _, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("TestService/UnaryCallv(_, _) = _, %v; want , error code: %s", err, codes.DeadlineExceeded) } + cancel() } } @@ -3355,7 +3366,8 @@ func testClientStreaming(t *testing.T, e env, sizes []int) { defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) - ctx, _ := context.WithTimeout(te.ctx, time.Second*30) + ctx, cancel := context.WithTimeout(te.ctx, time.Second*30) + defer cancel() stream, err := tc.StreamingInputCall(ctx) if err != nil { t.Fatalf("%v.StreamingInputCall(_) = _, %v, want ", tc, err) @@ -3569,7 +3581,8 @@ func testStreamsQuotaRecovery(t *testing.T, e env) { Payload: payload, } // No rpc should go through due to the max streams limit. - ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() if _, err := tc.UnaryCall(ctx, req, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) } diff --git a/transport/transport_test.go b/transport/transport_test.go index 6b7d6b8c..adaa1723 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -1020,7 +1020,8 @@ func TestLargeMessageSuspension(t *testing.T) { Method: "foo.Large", } // Set a long enough timeout for writing a large message out. - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() s, err := ct.NewStream(ctx, callHdr) if err != nil { t.Fatalf("failed to open stream: %v", err) @@ -1846,8 +1847,9 @@ func TestAccountCheckExpandingWindow(t *testing.T) { st.fc.mu.Unlock() // Check flow conrtrol window on client stream is equal to out flow on server stream. - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) serverStreamSendQuota, err := wait(ctx, nil, nil, nil, sstream.sendQuotaPool.acquire()) + cancel() if err != nil { return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err) } @@ -1860,8 +1862,9 @@ func TestAccountCheckExpandingWindow(t *testing.T) { } // Check flow control window on server stream is equal to out flow on client stream. - ctx, _ = context.WithTimeout(context.Background(), time.Second) + ctx, cancel = context.WithTimeout(context.Background(), time.Second) clientStreamSendQuota, err := wait(ctx, nil, nil, nil, cstream.sendQuotaPool.acquire()) + cancel() if err != nil { return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err) } @@ -1874,8 +1877,9 @@ func TestAccountCheckExpandingWindow(t *testing.T) { } // Check flow control window on client transport is equal to out flow of server transport. - ctx, _ = context.WithTimeout(context.Background(), time.Second) + ctx, cancel = context.WithTimeout(context.Background(), time.Second) serverTrSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire()) + cancel() if err != nil { return true, fmt.Errorf("error while acquring server transport send quota. Err: %v", err) } @@ -1888,8 +1892,9 @@ func TestAccountCheckExpandingWindow(t *testing.T) { } // Check flow control window on server transport is equal to out flow of client transport. - ctx, _ = context.WithTimeout(context.Background(), time.Second) + ctx, cancel = context.WithTimeout(context.Background(), time.Second) clientTrSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire()) + cancel() if err != nil { return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err) }