Merge pull request #175 from iamqizhao/master

Add cancel_after_begin and cancel_afer_first_response test cases
This commit is contained in:
Qi Zhao
2015-04-22 15:13:10 -07:00
2 changed files with 62 additions and 7 deletions

View File

@ -45,7 +45,9 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
testpb "google.golang.org/grpc/interop/grpc_testing" testpb "google.golang.org/grpc/interop/grpc_testing"
) )
@ -66,7 +68,9 @@ var (
server_streaming : single request with response streaming; server_streaming : single request with response streaming;
ping_pong : full-duplex streaming; ping_pong : full-duplex streaming;
compute_engine_creds: large_unary with compute engine auth; compute_engine_creds: large_unary with compute engine auth;
service_account_creds: large_unary with service account auth.`) service_account_creds: large_unary with service account auth;
cancel_after_begin: cancellation after metadata has been sent but before payloads are sent;
cancel_after_first_response: cancellation after receiving 1st message from the server.`)
) )
var ( var (
@ -297,6 +301,57 @@ func doServiceAccountCreds(tc testpb.TestServiceClient) {
log.Println("ServiceAccountCreds done") log.Println("ServiceAccountCreds done")
} }
var (
testMetadata = metadata.MD{
"key1": "value1",
"key2": "value2",
}
)
func doCancelAfterBegin(tc testpb.TestServiceClient) {
ctx, cancel := context.WithCancel(metadata.NewContext(context.Background(), testMetadata))
stream, err := tc.StreamingInputCall(ctx)
if err != nil {
log.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err)
}
cancel()
_, err = stream.CloseAndRecv()
if grpc.Code(err) != codes.Canceled {
log.Fatalf("%v.CloseAndRecv() got error code %d, want %d", stream, grpc.Code(err), codes.Canceled)
}
log.Println("CancelAfterBegin done")
}
func doCancelAfterFirstResponse(tc testpb.TestServiceClient) {
ctx, cancel := context.WithCancel(context.Background())
stream, err := tc.FullDuplexCall(ctx)
if err != nil {
log.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err)
}
respParam := []*testpb.ResponseParameters{
{
Size: proto.Int32(31415),
},
}
pl := newPayload(testpb.PayloadType_COMPRESSABLE, 27182)
req := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: respParam,
Payload: pl,
}
if err := stream.Send(req); err != nil {
log.Fatalf("%v.Send(%v) = %v", stream, req, err)
}
if _, err := stream.Recv(); err != nil {
log.Fatalf("%v.Recv() = %v", stream, err)
}
cancel()
if _, err := stream.Recv(); grpc.Code(err) != codes.Canceled {
log.Fatalf("%v compleled with error code %d, want %d", stream, grpc.Code(err), codes.Canceled)
}
log.Println("CancelAfterFirstResponse done")
}
func main() { func main() {
flag.Parse() flag.Parse()
serverAddr := net.JoinHostPort(*serverHost, strconv.Itoa(*serverPort)) serverAddr := net.JoinHostPort(*serverHost, strconv.Itoa(*serverPort))
@ -354,6 +409,10 @@ func main() {
log.Fatalf("TLS is not enabled. TLS is required to execute service_account_creds test case.") log.Fatalf("TLS is not enabled. TLS is required to execute service_account_creds test case.")
} }
doServiceAccountCreds(tc) doServiceAccountCreds(tc)
case "cancel_after_begin":
doCancelAfterBegin(tc)
case "cancel_after_first_response":
doCancelAfterFirstResponse(tc)
default: default:
log.Fatal("Unsupported test case: ", *testCase) log.Fatal("Unsupported test case: ", *testCase)
} }

View File

@ -283,9 +283,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
statusDesc = err.Error() statusDesc = err.Error()
} }
} }
if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { t.WriteStatus(stream, statusCode, statusDesc)
log.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err)
}
default: default:
panic(fmt.Sprintf("payload format to be supported: %d", pf)) panic(fmt.Sprintf("payload format to be supported: %d", pf))
} }
@ -308,9 +306,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
ss.statusDesc = appErr.Error() ss.statusDesc = appErr.Error()
} }
} }
if err := t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc); err != nil { t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc)
log.Printf("grpc: Server.processStreamingRPC failed to write status: %v", err)
}
} }
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) { func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {