interop: Fix unimplemented method test (#2040)

* Don't send nil requests.

* Fix import name and get rid of condition.

* Let registered encoder deal with nil requests.

* Break encode into encode and compress.
This commit is contained in:
mmukhi
2018-05-02 16:08:12 -07:00
committed by GitHub
parent 7c204fd174
commit 3592bccfd9
8 changed files with 71 additions and 53 deletions

View File

@ -104,7 +104,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
} }
} }
// send a response back to end the stream. // send a response back to end the stream.
data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil) data, err := encode(testCodec{}, &expectedResponse, nil)
if err != nil { if err != nil {
t.Errorf("Failed to encode the response: %v", err) t.Errorf("Failed to encode the response: %v", err)
return return

View File

@ -65,6 +65,9 @@ func marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error) {
} }
func (codec) Marshal(v interface{}) ([]byte, error) { func (codec) Marshal(v interface{}) ([]byte, error) {
if v == nil {
return nil, nil
}
if pm, ok := v.(proto.Marshaler); ok { if pm, ok := v.(proto.Marshaler); ok {
// object can marshal itself, no need for buffer // object can marshal itself, no need for buffer
return pm.Marshal() return pm.Marshal()

View File

@ -573,9 +573,8 @@ func DoUnimplementedService(tc testpb.UnimplementedServiceClient) {
// DoUnimplementedMethod attempts to call an unimplemented method. // DoUnimplementedMethod attempts to call an unimplemented method.
func DoUnimplementedMethod(cc *grpc.ClientConn) { func DoUnimplementedMethod(cc *grpc.ClientConn) {
var req, reply proto.Message if err := grpc.Invoke(context.Background(), "/grpc.testing.TestService/UnimplementedCall", &testpb.Empty{}, &testpb.Empty{}, cc); err == nil || status.Code(err) != codes.Unimplemented {
if err := cc.Invoke(context.Background(), "/grpc.testing.TestService/UnimplementedCall", req, reply); err == nil || status.Code(err) != codes.Unimplemented { grpclog.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented)
grpclog.Fatalf("ClientConn.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented)
} }
} }

View File

@ -435,45 +435,49 @@ func recvMsg(s *transport.Stream, maxRecvMsgSize int) (bool, []byte, error) {
} }
// encode serializes msg and returns a buffer of msg. // encode serializes msg and returns a buffer of msg.
// If msg is nil, it generates an empty buffer. func encode(c baseCodec, msg interface{}, outPayload *stats.OutPayload) ([]byte, error) {
// TODO(ddyihai): eliminate extra Compressor parameter. b, err := c.Marshal(msg)
func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, error) { if err != nil {
var ( return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
b []byte }
cbuf *bytes.Buffer if b == nil {
) // If there was no error while marshalling, yet payload was nil,
if msg != nil { // we update it to an empty slice, since a nil payload leads to
var err error // an empty data frame(no gRPC message header is added).
b, err = c.Marshal(msg) b = []byte{}
if err != nil { }
return nil, status.Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error()) if outPayload != nil {
outPayload.Payload = msg
// TODO truncate large payload.
outPayload.Data = b
outPayload.Length = len(b)
}
return b, nil
}
// compress the message if there is a compressor registered.
// TODO(mmukhi, dfawley): eliminate extra Compressor parameter.
func compress(b []byte, cp Compressor, compressor encoding.Compressor, outPayload *stats.OutPayload) ([]byte, bool, error) {
if len(b) <= 0 || (compressor == nil && cp == nil) {
return b, false, nil
}
cbuf := new(bytes.Buffer)
// Has compressor, check Compressor is set by UseCompressor first.
if compressor != nil {
z, _ := compressor.Compress(cbuf)
if _, err := z.Write(b); err != nil {
return nil, false, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
} }
if outPayload != nil { z.Close()
outPayload.Payload = msg } else {
// TODO truncate large payload. // If Compressor is not set by UseCompressor, use default Compressor
outPayload.Data = b if err := cp.Do(cbuf, b); err != nil {
outPayload.Length = len(b) return nil, false, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
if compressor != nil || cp != nil {
cbuf = new(bytes.Buffer)
// Has compressor, check Compressor is set by UseCompressor first.
if compressor != nil {
z, _ := compressor.Compress(cbuf)
if _, err := z.Write(b); err != nil {
return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
z.Close()
} else {
// If Compressor is not set by UseCompressor, use default Compressor
if err := cp.Do(cbuf, b); err != nil {
return nil, status.Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
}
b = cbuf.Bytes()
} }
} }
b = cbuf.Bytes()
if uint(len(b)) > math.MaxUint32 { if uint(len(b)) > math.MaxUint32 {
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b)) return nil, false, status.Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
} }
if outPayload != nil { if outPayload != nil {
@ -481,7 +485,7 @@ func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPa
// before it's put on wire. // before it's put on wire.
outPayload.WireLength = 5 + len(b) outPayload.WireLength = 5 + len(b)
} }
return b, nil return b, true, nil
} }
func checkRecvPayload(recvCompress string, haveCompressor bool) *status.Status { func checkRecvPayload(recvCompress string, haveCompressor bool) *status.Status {

View File

@ -48,16 +48,15 @@ func TestEncode(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
// input // input
msg proto.Message msg proto.Message
cp Compressor
// outputs // outputs
data []byte data []byte
err error err error
}{ }{
{nil, nil, []byte{}, nil}, {&perfpb.Buffer{}, []byte{}, nil},
} { } {
data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil) data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil)
if err != test.err || !bytes.Equal(data, test.data) { if err != test.err || !bytes.Equal(data, test.data) {
t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, data, err, test.data, test.err) t.Fatalf("encode(_, _, _, _) = %v, %v\nwant %v, %v", data, err, test.data, test.err)
} }
} }
} }
@ -156,7 +155,7 @@ func TestParseDialTarget(t *testing.T) {
func bmEncode(b *testing.B, mSize int) { func bmEncode(b *testing.B, mSize int) {
cdc := encoding.GetCodec(protoenc.Name) cdc := encoding.GetCodec(protoenc.Name)
msg := &perfpb.Buffer{Body: make([]byte, mSize)} msg := &perfpb.Buffer{Body: make([]byte, mSize)}
encodeData, _ := encode(cdc, msg, nil, nil, nil) encodeData, _ := encode(cdc, msg, nil)
// 5 bytes of gRPC-specific message header // 5 bytes of gRPC-specific message header
// is added to the message before it is written // is added to the message before it is written
// to the wire. // to the wire.
@ -164,7 +163,7 @@ func bmEncode(b *testing.B, mSize int) {
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
encode(cdc, msg, nil, nil, nil) encode(cdc, msg, nil)
} }
b.SetBytes(encodedSz) b.SetBytes(encodedSz)
} }

View File

@ -833,7 +833,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
if s.opts.statsHandler != nil { if s.opts.statsHandler != nil {
outPayload = &stats.OutPayload{} outPayload = &stats.OutPayload{}
} }
data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp) data, err := encode(s.getCodec(stream.ContentSubtype()), msg, outPayload)
if err != nil { if err != nil {
grpclog.Errorln("grpc: server failed to encode response: ", err) grpclog.Errorln("grpc: server failed to encode response: ", err)
return err return err
@ -841,7 +841,12 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
if len(data) > s.opts.maxSendMessageSize { if len(data) > s.opts.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
} }
opts.IsCompressed = cp != nil || comp != nil data, isCompressed, err := compress(data, cp, comp, outPayload)
opts.IsCompressed = isCompressed
if err != nil {
grpclog.Errorln("grpc: server failed to compress response: ", err)
return err
}
err = t.Write(stream, data, opts) err = t.Write(stream, data, opts)
if err == nil && outPayload != nil { if err == nil && outPayload != nil {
outPayload.SentTime = time.Now() outPayload.SentTime = time.Now()

View File

@ -470,19 +470,23 @@ func (a *csAttempt) sendMsg(m interface{}) (err error) {
Client: true, Client: true,
} }
} }
data, err := encode(cs.codec, m, cs.cp, outPayload, cs.comp) data, err := encode(cs.codec, m, outPayload)
if err != nil { if err != nil {
return err return err
} }
if len(data) > *cs.c.maxSendMessageSize { if len(data) > *cs.c.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
} }
data, isCompressed, err := compress(data, cs.cp, cs.comp, outPayload)
if err != nil {
return err
}
if !cs.desc.ClientStreams { if !cs.desc.ClientStreams {
cs.sentLast = true cs.sentLast = true
} }
opts := &transport.Options{ opts := &transport.Options{
Last: !cs.desc.ClientStreams, Last: !cs.desc.ClientStreams,
IsCompressed: cs.cp != nil || cs.comp != nil, IsCompressed: isCompressed,
} }
err = a.t.Write(a.s, data, opts) err = a.t.Write(a.s, data, opts)
if err == nil { if err == nil {
@ -701,16 +705,20 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
if ss.statsHandler != nil { if ss.statsHandler != nil {
outPayload = &stats.OutPayload{} outPayload = &stats.OutPayload{}
} }
data, err := encode(ss.codec, m, ss.cp, outPayload, ss.comp) data, err := encode(ss.codec, m, outPayload)
if err != nil { if err != nil {
return err return err
} }
if len(data) > ss.maxSendMessageSize { if len(data) > ss.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
} }
data, isCompressed, err := compress(data, ss.cp, ss.comp, outPayload)
if err != nil {
return err
}
opts := &transport.Options{ opts := &transport.Options{
Last: false, Last: false,
IsCompressed: ss.cp != nil || ss.comp != nil, IsCompressed: isCompressed,
} }
if err := ss.t.Write(ss.s, data, opts); err != nil { if err := ss.t.Write(ss.s, data, opts); err != nil {
return toRPCErr(err) return toRPCErr(err)

View File

@ -5923,7 +5923,7 @@ func TestMethodFromServerStream(t *testing.T) {
te.startServer(nil) te.startServer(nil)
defer te.tearDown() defer te.tearDown()
_ = te.clientConn().Invoke(context.Background(), testMethod, nil, nil) _ = te.clientConn().Invoke(context.Background(), testMethod, &testpb.Empty{}, &testpb.Empty{})
if !ok || method != testMethod { if !ok || method != testMethod {
t.Fatalf("Invoke with method %q, got %q, %v, want %q, true", testMethod, method, ok, testMethod) t.Fatalf("Invoke with method %q, got %q, %v, want %q, true", testMethod, method, ok, testMethod)
} }