diff --git a/clientconn.go b/clientconn.go index fe0a3a04..c2df3e7b 100644 --- a/clientconn.go +++ b/clientconn.go @@ -38,9 +38,9 @@ import ( "sync" "time" + "golang.org/x/net/context" "google.golang.org/grpc/credentials" "google.golang.org/grpc/transport" - "golang.org/x/net/context" ) type dialOptions struct { diff --git a/compiler/go_generator.cc b/compiler/go_generator.cc index b403cc8c..2330f4c2 100644 --- a/compiler/go_generator.cc +++ b/compiler/go_generator.cc @@ -151,7 +151,8 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer, const google::protobuf::MethodDescriptor* method, map* vars, const set& imports, - const map& import_alias) { + const map& import_alias, + int* stream_ind) { (*vars)["Method"] = method->name(); (*vars)["Request"] = GetFullMessageQualifiedName(method->input_type(), imports, import_alias); @@ -171,12 +172,15 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer, printer->Print("\t}\n"); printer->Print("\treturn out, nil\n"); printer->Print("}\n\n"); - } else if (BidiStreaming(method)) { + return; + } + (*vars)["StreamInd"] = std::to_string(*stream_ind); + if (BidiStreaming(method)) { printer->Print( *vars, "func (c *$ServiceStruct$Client) $Method$(ctx context.Context, opts " "...grpc.CallOption) ($Service$_$Method$Client, error) {\n" - "\tstream, err := grpc.NewClientStream(ctx, c.cc, " + "\tstream, err := grpc.NewClientStream(ctx, &_$Service$_serviceDesc.Streams[$StreamInd$], c.cc, " "\"/$Package$$Service$/$Method$\", opts...)\n" "\tif err != nil {\n" "\t\treturn nil, err\n" @@ -214,7 +218,7 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer, "func (c *$ServiceStruct$Client) $Method$(ctx context.Context, m " "*$Request$, " "opts ...grpc.CallOption) ($Service$_$Method$Client, error) {\n" - "\tstream, err := grpc.NewClientStream(ctx, c.cc, " + "\tstream, err := grpc.NewClientStream(ctx, &_$Service$_serviceDesc.Streams[$StreamInd$], c.cc, " "\"/$Package$$Service$/$Method$\", opts...)\n" "\tif err != nil {\n" "\t\treturn nil, err\n" @@ -252,7 +256,7 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer, *vars, "func (c *$ServiceStruct$Client) $Method$(ctx context.Context, opts " "...grpc.CallOption) ($Service$_$Method$Client, error) {\n" - "\tstream, err := grpc.NewClientStream(ctx, c.cc, " + "\tstream, err := grpc.NewClientStream(ctx, &_$Service$_serviceDesc.Streams[$StreamInd$], c.cc, " "\"/$Package$$Service$/$Method$\", opts...)\n" "\tif err != nil {\n" "\t\treturn nil, err\n" @@ -282,18 +286,13 @@ void PrintClientMethodImpl(google::protobuf::io::Printer* printer, "\t\treturn nil, err\n" "\t}\n" "\tm := new($Response$)\n" - "\tif err := x.ClientStream.RecvProto(m); err != nil {\n" + "\tif err := x.ClientStream.RecvProto(m); err != io.EOF {\n" "\t\treturn nil, err\n" "\t}\n" - "\t// Read EOF.\n" - "\tif err := x.ClientStream.RecvProto(m); err == io.EOF {\n" - "\t\treturn m, nil\n" - "\t}\n" - "\t// gRPC protocol violation.\n" - "\treturn m, fmt.Errorf(\"Violate gRPC client streaming protocol: no " - "EOF after the response.\")\n" + "\treturn m, nil\n" "}\n\n"); } + (*stream_ind)++; } void PrintClient(google::protobuf::io::Printer* printer, @@ -318,8 +317,10 @@ void PrintClient(google::protobuf::io::Printer* printer, "func New$Service$Client(cc *grpc.ClientConn) $Service$Client {\n" "\treturn &$ServiceStruct$Client{cc}\n" "}\n\n"); + int stream_ind = 0; for (int i = 0; i < service->method_count(); ++i) { - PrintClientMethodImpl(printer, service->method(i), vars, imports, import_alias); + PrintClientMethodImpl( + printer, service->method(i), vars, imports, import_alias, &stream_ind); } } @@ -489,6 +490,12 @@ void PrintServerStreamingMethodDesc( printer->Print("\t\t{\n"); printer->Print(*vars, "\t\t\tStreamName:\t\"$Method$\",\n"); printer->Print(*vars, "\t\t\tHandler:\t_$Service$_$Method$_Handler,\n"); + if (method->client_streaming()) { + printer->Print(*vars, "\t\t\tClientStreams:\ttrue,\n"); + } + if (method->server_streaming()) { + printer->Print(*vars, "\t\t\tServerStreams:\ttrue,\n"); + } printer->Print("\t\t},\n"); } @@ -505,7 +512,7 @@ void PrintServer(google::protobuf::io::Printer* printer, printer->Print("}\n\n"); printer->Print(*vars, - "func RegisterService(s *grpc.Server, srv $Service$Server) {\n" + "func Register$Service$Server(s *grpc.Server, srv $Service$Server) {\n" "\ts.RegisterService(&_$Service$_serviceDesc, srv)\n" "}\n\n"); @@ -613,7 +620,6 @@ string GetServices(const google::protobuf::FileDescriptor* file, printer.Print("import (\n"); if (HasClientOnlyStreaming(file)) { printer.Print( - "\t\"fmt\"\n" "\t\"io\"\n"); } printer.Print( diff --git a/interop/grpc_testing/test.pb.go b/interop/grpc_testing/test.pb.go index b3ae2b10..0159b8db 100755 --- a/interop/grpc_testing/test.pb.go +++ b/interop/grpc_testing/test.pb.go @@ -59,9 +59,9 @@ import math "math" import ( errors "errors" - io "io" context "golang.org/x/net/context" grpc "google.golang.org/grpc" + io "io" ) // Reference imports to suppress errors if they are not otherwise used. @@ -430,7 +430,7 @@ func (c *testServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, op } func (c *testServiceClient) StreamingOutputCall(ctx context.Context, in *StreamingOutputCallRequest, opts ...grpc.CallOption) (TestService_StreamingOutputCallClient, error) { - stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/StreamingOutputCall", opts...) + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[0], c.cc, "/grpc.testing.TestService/StreamingOutputCall", opts...) if err != nil { return nil, err } @@ -462,7 +462,7 @@ func (x *testServiceStreamingOutputCallClient) Recv() (*StreamingOutputCallRespo } func (c *testServiceClient) StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingInputCallClient, error) { - stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/StreamingInputCall", opts...) + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[1], c.cc, "/grpc.testing.TestService/StreamingInputCall", opts...) if err != nil { return nil, err } @@ -489,20 +489,14 @@ func (x *testServiceStreamingInputCallClient) CloseAndRecv() (*StreamingInputCal return nil, err } m := new(StreamingInputCallResponse) - if err := x.ClientStream.RecvProto(m); err != nil { + if err := x.ClientStream.RecvProto(m); err != io.EOF { return nil, err } - // Read EOF. - dummy := new(StreamingInputCallResponse) - if err := x.ClientStream.RecvProto(dummy); err != io.EOF { - // gRPC protocol violation. - return nil, errors.New("gRPC client streaming protocol violation: no EOF after final response") - } return m, nil } func (c *testServiceClient) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) { - stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...) + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[2], c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...) if err != nil { return nil, err } @@ -533,7 +527,7 @@ func (x *testServiceFullDuplexCallClient) Recv() (*StreamingOutputCallResponse, } func (c *testServiceClient) HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_HalfDuplexCallClient, error) { - stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/HalfDuplexCall", opts...) + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[3], c.cc, "/grpc.testing.TestService/HalfDuplexCall", opts...) if err != nil { return nil, err } @@ -730,20 +724,26 @@ var _TestService_serviceDesc = grpc.ServiceDesc{ }, Streams: []grpc.StreamDesc{ { - StreamName: "StreamingOutputCall", - Handler: _TestService_StreamingOutputCall_Handler, + StreamName: "StreamingOutputCall", + Handler: _TestService_StreamingOutputCall_Handler, + ServerStreams: true, }, { - StreamName: "StreamingInputCall", - Handler: _TestService_StreamingInputCall_Handler, + StreamName: "StreamingInputCall", + Handler: _TestService_StreamingInputCall_Handler, + ClientStreams: true, }, { - StreamName: "FullDuplexCall", - Handler: _TestService_FullDuplexCall_Handler, + StreamName: "FullDuplexCall", + Handler: _TestService_FullDuplexCall_Handler, + ClientStreams: true, + ServerStreams: true, }, { - StreamName: "HalfDuplexCall", - Handler: _TestService_HalfDuplexCall_Handler, + StreamName: "HalfDuplexCall", + Handler: _TestService_HalfDuplexCall_Handler, + ClientStreams: true, + ServerStreams: true, }, }, } diff --git a/rpc_util.go b/rpc_util.go index 2a536fb3..fa6e8b06 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -43,10 +43,10 @@ import ( "time" "github.com/golang/protobuf/proto" + "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/transport" - "golang.org/x/net/context" ) // CallOption configures a Call before it starts or extracts information from diff --git a/rpc_util_test.go b/rpc_util_test.go index 082c697c..fe2472ae 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -42,9 +42,9 @@ import ( "time" "github.com/golang/protobuf/proto" + "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/transport" - "golang.org/x/net/context" ) func TestSimpleParsing(t *testing.T) { diff --git a/server.go b/server.go index 714f7542..02b478b8 100644 --- a/server.go +++ b/server.go @@ -43,10 +43,10 @@ import ( "sync" "github.com/golang/protobuf/proto" + "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/transport" - "golang.org/x/net/context" ) type methodHandler func(srv interface{}, ctx context.Context, buf []byte) (proto.Message, error) @@ -57,14 +57,6 @@ type MethodDesc struct { Handler methodHandler } -type streamHandler func(srv interface{}, stream ServerStream) error - -// StreamDesc represents a streaming RPC service's method specification. -type StreamDesc struct { - StreamName string - Handler streamHandler -} - // ServiceDesc represents an RPC service's specification. type ServiceDesc struct { ServiceName string diff --git a/stream.go b/stream.go index 2a3be4ab..2e5cbcb2 100644 --- a/stream.go +++ b/stream.go @@ -34,6 +34,7 @@ package grpc import ( + "fmt" "io" "github.com/golang/protobuf/proto" @@ -43,6 +44,18 @@ import ( "google.golang.org/grpc/transport" ) +type streamHandler func(srv interface{}, stream ServerStream) error + +// StreamDesc represents a streaming RPC service's method specification. +type StreamDesc struct { + StreamName string + Handler streamHandler + + // At least one of these is true. + ServerStreams bool + ClientStreams bool +} + // Stream defines the common interface a client or server stream has to satisfy. type Stream interface { // Context returns the context for this stream. @@ -80,7 +93,7 @@ type ClientStream interface { // NewClientStream creates a new Stream for the client side. This is called // by generated code. -func NewClientStream(ctx context.Context, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { +func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { // TODO(zhaoq): CallOption is omitted. Add support when it is needed. callHdr := &transport.CallHdr{ Host: cc.target, @@ -95,17 +108,19 @@ func NewClientStream(ctx context.Context, cc *ClientConn, method string, opts .. return nil, toRPCErr(err) } return &clientStream{ - t: t, - s: s, - p: &parser{s: s}, + t: t, + s: s, + p: &parser{s: s}, + desc: desc, }, nil } // clientStream implements a client side Stream. type clientStream struct { - t transport.ClientTransport - s *transport.Stream - p *parser + t transport.ClientTransport + s *transport.Stream + p *parser + desc *StreamDesc } func (cs *clientStream) Context() context.Context { @@ -146,7 +161,14 @@ func (cs *clientStream) SendProto(m proto.Message) (err error) { func (cs *clientStream) RecvProto(m proto.Message) (err error) { err = recvProto(cs.p, m) if err == nil { - return + if !cs.desc.ClientStreams || cs.desc.ServerStreams { + return + } + // Special handling for client streaming rpc. + if err = recvProto(cs.p, m); err != io.EOF { + cs.t.CloseStream(cs.s, err) + return fmt.Errorf("gRPC client streaming protocol violation: %v, want ", err) + } } if _, ok := err.(transport.ConnectionError); !ok { cs.t.CloseStream(cs.s, err) diff --git a/test/grpc_testing/test.pb.go b/test/grpc_testing/test.pb.go index b3ae2b10..0159b8db 100755 --- a/test/grpc_testing/test.pb.go +++ b/test/grpc_testing/test.pb.go @@ -59,9 +59,9 @@ import math "math" import ( errors "errors" - io "io" context "golang.org/x/net/context" grpc "google.golang.org/grpc" + io "io" ) // Reference imports to suppress errors if they are not otherwise used. @@ -430,7 +430,7 @@ func (c *testServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, op } func (c *testServiceClient) StreamingOutputCall(ctx context.Context, in *StreamingOutputCallRequest, opts ...grpc.CallOption) (TestService_StreamingOutputCallClient, error) { - stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/StreamingOutputCall", opts...) + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[0], c.cc, "/grpc.testing.TestService/StreamingOutputCall", opts...) if err != nil { return nil, err } @@ -462,7 +462,7 @@ func (x *testServiceStreamingOutputCallClient) Recv() (*StreamingOutputCallRespo } func (c *testServiceClient) StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingInputCallClient, error) { - stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/StreamingInputCall", opts...) + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[1], c.cc, "/grpc.testing.TestService/StreamingInputCall", opts...) if err != nil { return nil, err } @@ -489,20 +489,14 @@ func (x *testServiceStreamingInputCallClient) CloseAndRecv() (*StreamingInputCal return nil, err } m := new(StreamingInputCallResponse) - if err := x.ClientStream.RecvProto(m); err != nil { + if err := x.ClientStream.RecvProto(m); err != io.EOF { return nil, err } - // Read EOF. - dummy := new(StreamingInputCallResponse) - if err := x.ClientStream.RecvProto(dummy); err != io.EOF { - // gRPC protocol violation. - return nil, errors.New("gRPC client streaming protocol violation: no EOF after final response") - } return m, nil } func (c *testServiceClient) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) { - stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...) + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[2], c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...) if err != nil { return nil, err } @@ -533,7 +527,7 @@ func (x *testServiceFullDuplexCallClient) Recv() (*StreamingOutputCallResponse, } func (c *testServiceClient) HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_HalfDuplexCallClient, error) { - stream, err := grpc.NewClientStream(ctx, c.cc, "/grpc.testing.TestService/HalfDuplexCall", opts...) + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[3], c.cc, "/grpc.testing.TestService/HalfDuplexCall", opts...) if err != nil { return nil, err } @@ -730,20 +724,26 @@ var _TestService_serviceDesc = grpc.ServiceDesc{ }, Streams: []grpc.StreamDesc{ { - StreamName: "StreamingOutputCall", - Handler: _TestService_StreamingOutputCall_Handler, + StreamName: "StreamingOutputCall", + Handler: _TestService_StreamingOutputCall_Handler, + ServerStreams: true, }, { - StreamName: "StreamingInputCall", - Handler: _TestService_StreamingInputCall_Handler, + StreamName: "StreamingInputCall", + Handler: _TestService_StreamingInputCall_Handler, + ClientStreams: true, }, { - StreamName: "FullDuplexCall", - Handler: _TestService_FullDuplexCall_Handler, + StreamName: "FullDuplexCall", + Handler: _TestService_FullDuplexCall_Handler, + ClientStreams: true, + ServerStreams: true, }, { - StreamName: "HalfDuplexCall", - Handler: _TestService_HalfDuplexCall_Handler, + StreamName: "HalfDuplexCall", + Handler: _TestService_HalfDuplexCall_Handler, + ClientStreams: true, + ServerStreams: true, }, }, }