interop: remove test.proto clones/variants and use grpc-proto repo instead (#4129)
This commit is contained in:
@ -35,8 +35,8 @@ import (
|
||||
"google.golang.org/grpc/grpclog"
|
||||
iblog "google.golang.org/grpc/internal/binarylog"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
testpb "google.golang.org/grpc/interop/grpc_testing"
|
||||
"google.golang.org/grpc/metadata"
|
||||
testpb "google.golang.org/grpc/stats/grpc_testing"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
@ -114,6 +114,17 @@ var (
|
||||
globalRPCID uint64 // RPC id starts with 1, but we do ++ at the beginning of each test.
|
||||
)
|
||||
|
||||
func idToPayload(id int32) *testpb.Payload {
|
||||
return &testpb.Payload{Body: []byte{byte(id), byte(id >> 8), byte(id >> 16), byte(id >> 24)}}
|
||||
}
|
||||
|
||||
func payloadToID(p *testpb.Payload) int32 {
|
||||
if p == nil || len(p.Body) != 4 {
|
||||
panic("invalid payload")
|
||||
}
|
||||
return int32(p.Body[0]) + int32(p.Body[1])<<8 + int32(p.Body[2])<<16 + int32(p.Body[3])<<24
|
||||
}
|
||||
|
||||
type testServer struct {
|
||||
testpb.UnimplementedTestServiceServer
|
||||
te *test
|
||||
@ -130,11 +141,11 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
|
||||
}
|
||||
}
|
||||
|
||||
if in.Id == errorID {
|
||||
return nil, fmt.Errorf("got error id: %v", in.Id)
|
||||
if id := payloadToID(in.Payload); id == errorID {
|
||||
return nil, fmt.Errorf("got error id: %v", id)
|
||||
}
|
||||
|
||||
return &testpb.SimpleResponse{Id: in.Id}, nil
|
||||
return &testpb.SimpleResponse{Payload: in.Payload}, nil
|
||||
}
|
||||
|
||||
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
@ -155,17 +166,17 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ
|
||||
return err
|
||||
}
|
||||
|
||||
if in.Id == errorID {
|
||||
return fmt.Errorf("got error id: %v", in.Id)
|
||||
if id := payloadToID(in.Payload); id == errorID {
|
||||
return fmt.Errorf("got error id: %v", id)
|
||||
}
|
||||
|
||||
if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil {
|
||||
if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCallServer) error {
|
||||
func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInputCallServer) error {
|
||||
md, ok := metadata.FromIncomingContext(stream.Context())
|
||||
if ok {
|
||||
if err := stream.SendHeader(md); err != nil {
|
||||
@ -177,19 +188,19 @@ func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCall
|
||||
in, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
// read done.
|
||||
return stream.SendAndClose(&testpb.SimpleResponse{Id: int32(0)})
|
||||
return stream.SendAndClose(&testpb.StreamingInputCallResponse{AggregatedPayloadSize: 0})
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if in.Id == errorID {
|
||||
return fmt.Errorf("got error id: %v", in.Id)
|
||||
if id := payloadToID(in.Payload); id == errorID {
|
||||
return fmt.Errorf("got error id: %v", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.TestService_ServerStreamCallServer) error {
|
||||
func (s *testServer) StreamingOutputCall(in *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
|
||||
md, ok := metadata.FromIncomingContext(stream.Context())
|
||||
if ok {
|
||||
if err := stream.SendHeader(md); err != nil {
|
||||
@ -198,12 +209,12 @@ func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.Te
|
||||
stream.SetTrailer(testTrailerMetadata)
|
||||
}
|
||||
|
||||
if in.Id == errorID {
|
||||
return fmt.Errorf("got error id: %v", in.Id)
|
||||
if id := payloadToID(in.Payload); id == errorID {
|
||||
return fmt.Errorf("got error id: %v", id)
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil {
|
||||
if err := stream.Send(&testpb.StreamingOutputCallResponse{Payload: in.Payload}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -334,9 +345,9 @@ func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.Simple
|
||||
)
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
if c.success {
|
||||
req = &testpb.SimpleRequest{Id: errorID + 1}
|
||||
req = &testpb.SimpleRequest{Payload: idToPayload(errorID + 1)}
|
||||
} else {
|
||||
req = &testpb.SimpleRequest{Id: errorID}
|
||||
req = &testpb.SimpleRequest{Payload: idToPayload(errorID)}
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
@ -346,10 +357,10 @@ func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.Simple
|
||||
return req, resp, err
|
||||
}
|
||||
|
||||
func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest, []*testpb.SimpleResponse, error) {
|
||||
func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]proto.Message, []proto.Message, error) {
|
||||
var (
|
||||
reqs []*testpb.SimpleRequest
|
||||
resps []*testpb.SimpleResponse
|
||||
reqs []proto.Message
|
||||
resps []proto.Message
|
||||
err error
|
||||
)
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
@ -372,14 +383,14 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest
|
||||
startID = errorID
|
||||
}
|
||||
for i := 0; i < c.count; i++ {
|
||||
req := &testpb.SimpleRequest{
|
||||
Id: int32(i) + startID,
|
||||
req := &testpb.StreamingOutputCallRequest{
|
||||
Payload: idToPayload(int32(i) + startID),
|
||||
}
|
||||
reqs = append(reqs, req)
|
||||
if err = stream.Send(req); err != nil {
|
||||
return reqs, resps, err
|
||||
}
|
||||
var resp *testpb.SimpleResponse
|
||||
var resp *testpb.StreamingOutputCallResponse
|
||||
if resp, err = stream.Recv(); err != nil {
|
||||
return reqs, resps, err
|
||||
}
|
||||
@ -395,10 +406,10 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest
|
||||
return reqs, resps, nil
|
||||
}
|
||||
|
||||
func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
|
||||
func (te *test) doClientStreamCall(c *rpcConfig) ([]proto.Message, proto.Message, error) {
|
||||
var (
|
||||
reqs []*testpb.SimpleRequest
|
||||
resp *testpb.SimpleResponse
|
||||
reqs []proto.Message
|
||||
resp *testpb.StreamingInputCallResponse
|
||||
err error
|
||||
)
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
@ -406,7 +417,7 @@ func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *test
|
||||
defer cancel()
|
||||
ctx = metadata.NewOutgoingContext(ctx, testMetadata)
|
||||
|
||||
stream, err := tc.ClientStreamCall(ctx)
|
||||
stream, err := tc.StreamingInputCall(ctx)
|
||||
if err != nil {
|
||||
return reqs, resp, err
|
||||
}
|
||||
@ -415,8 +426,8 @@ func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *test
|
||||
startID = errorID
|
||||
}
|
||||
for i := 0; i < c.count; i++ {
|
||||
req := &testpb.SimpleRequest{
|
||||
Id: int32(i) + startID,
|
||||
req := &testpb.StreamingInputCallRequest{
|
||||
Payload: idToPayload(int32(i) + startID),
|
||||
}
|
||||
reqs = append(reqs, req)
|
||||
if err = stream.Send(req); err != nil {
|
||||
@ -427,10 +438,10 @@ func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *test
|
||||
return reqs, resp, err
|
||||
}
|
||||
|
||||
func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.SimpleRequest, []*testpb.SimpleResponse, error) {
|
||||
func (te *test) doServerStreamCall(c *rpcConfig) (proto.Message, []proto.Message, error) {
|
||||
var (
|
||||
req *testpb.SimpleRequest
|
||||
resps []*testpb.SimpleResponse
|
||||
req *testpb.StreamingOutputCallRequest
|
||||
resps []proto.Message
|
||||
err error
|
||||
)
|
||||
|
||||
@ -443,13 +454,13 @@ func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.SimpleRequest, []*test
|
||||
if !c.success {
|
||||
startID = errorID
|
||||
}
|
||||
req = &testpb.SimpleRequest{Id: startID}
|
||||
stream, err := tc.ServerStreamCall(ctx, req)
|
||||
req = &testpb.StreamingOutputCallRequest{Payload: idToPayload(startID)}
|
||||
stream, err := tc.StreamingOutputCall(ctx, req)
|
||||
if err != nil {
|
||||
return req, resps, err
|
||||
}
|
||||
for {
|
||||
var resp *testpb.SimpleResponse
|
||||
var resp *testpb.StreamingOutputCallResponse
|
||||
resp, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
return req, resps, nil
|
||||
@ -465,8 +476,8 @@ type expectedData struct {
|
||||
cc *rpcConfig
|
||||
|
||||
method string
|
||||
requests []*testpb.SimpleRequest
|
||||
responses []*testpb.SimpleResponse
|
||||
requests []proto.Message
|
||||
responses []proto.Message
|
||||
err error
|
||||
}
|
||||
|
||||
@ -534,7 +545,7 @@ func (ed *expectedData) newServerHeaderEntry(client bool, rpcID, inRPCID uint64)
|
||||
}
|
||||
}
|
||||
|
||||
func (ed *expectedData) newClientMessageEntry(client bool, rpcID, inRPCID uint64, msg *testpb.SimpleRequest) *pb.GrpcLogEntry {
|
||||
func (ed *expectedData) newClientMessageEntry(client bool, rpcID, inRPCID uint64, msg proto.Message) *pb.GrpcLogEntry {
|
||||
logger := pb.GrpcLogEntry_LOGGER_CLIENT
|
||||
if !client {
|
||||
logger = pb.GrpcLogEntry_LOGGER_SERVER
|
||||
@ -558,7 +569,7 @@ func (ed *expectedData) newClientMessageEntry(client bool, rpcID, inRPCID uint64
|
||||
}
|
||||
}
|
||||
|
||||
func (ed *expectedData) newServerMessageEntry(client bool, rpcID, inRPCID uint64, msg *testpb.SimpleResponse) *pb.GrpcLogEntry {
|
||||
func (ed *expectedData) newServerMessageEntry(client bool, rpcID, inRPCID uint64, msg proto.Message) *pb.GrpcLogEntry {
|
||||
logger := pb.GrpcLogEntry_LOGGER_CLIENT
|
||||
if !client {
|
||||
logger = pb.GrpcLogEntry_LOGGER_SERVER
|
||||
@ -795,20 +806,20 @@ func runRPCs(t *testing.T, tc *testConfig, cc *rpcConfig) *expectedData {
|
||||
case unaryRPC:
|
||||
expect.method = "/grpc.testing.TestService/UnaryCall"
|
||||
req, resp, err := te.doUnaryCall(cc)
|
||||
expect.requests = []*testpb.SimpleRequest{req}
|
||||
expect.responses = []*testpb.SimpleResponse{resp}
|
||||
expect.requests = []proto.Message{req}
|
||||
expect.responses = []proto.Message{resp}
|
||||
expect.err = err
|
||||
case clientStreamRPC:
|
||||
expect.method = "/grpc.testing.TestService/ClientStreamCall"
|
||||
expect.method = "/grpc.testing.TestService/StreamingInputCall"
|
||||
reqs, resp, err := te.doClientStreamCall(cc)
|
||||
expect.requests = reqs
|
||||
expect.responses = []*testpb.SimpleResponse{resp}
|
||||
expect.responses = []proto.Message{resp}
|
||||
expect.err = err
|
||||
case serverStreamRPC:
|
||||
expect.method = "/grpc.testing.TestService/ServerStreamCall"
|
||||
expect.method = "/grpc.testing.TestService/StreamingOutputCall"
|
||||
req, resps, err := te.doServerStreamCall(cc)
|
||||
expect.responses = resps
|
||||
expect.requests = []*testpb.SimpleRequest{req}
|
||||
expect.requests = []proto.Message{req}
|
||||
expect.err = err
|
||||
case fullDuplexStreamRPC, cancelRPC:
|
||||
expect.method = "/grpc.testing.TestService/FullDuplexCall"
|
||||
|
Reference in New Issue
Block a user