diff --git a/call.go b/call.go index aad89fa4..a516017c 100644 --- a/call.go +++ b/call.go @@ -47,7 +47,7 @@ import ( // On error, it returns the error and indicates whether the call should be retried. // // TODO(zhaoq): Check whether the received message sequence is valid. -func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { +func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { // Try to acquire header metadata from the server if there is any. var err error c.headerMD, err = stream.Header() @@ -56,7 +56,7 @@ func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.St } p := &parser{s: stream} for { - if err = recv(p, protoCodec{}, reply); err != nil { + if err = recv(p, codec, reply); err != nil { if err == io.EOF { break } @@ -68,7 +68,7 @@ func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.St } // sendRequest writes out various information of an RPC such as Context and Message. -func sendRequest(ctx context.Context, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { +func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { stream, err := t.NewStream(ctx, callHdr) if err != nil { return nil, err @@ -81,7 +81,7 @@ func sendRequest(ctx context.Context, callHdr *transport.CallHdr, t transport.Cl } }() // TODO(zhaoq): Support compression. - outBuf, err := encode(protoCodec{}, args, compressionNone) + outBuf, err := encode(codec, args, compressionNone) if err != nil { return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) } @@ -148,7 +148,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli } return toRPCErr(err) } - stream, err = sendRequest(ctx, callHdr, t, args, topts) + stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts) if err != nil { if _, ok := err.(transport.ConnectionError); ok { lastErr = err @@ -160,7 +160,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli return toRPCErr(err) } // Receive the response - lastErr = recvResponse(t, &c, stream, reply) + lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply) if _, ok := lastErr.(transport.ConnectionError); ok { continue } diff --git a/clientconn.go b/clientconn.go index 8ac660a4..6b78a768 100644 --- a/clientconn.go +++ b/clientconn.go @@ -55,29 +55,43 @@ var ( ErrClientConnTimeout = errors.New("grpc: timed out trying to connect") ) +// dialOptions configure a Dial call. dialOptions are set by the DialOption +// values passed to Dial. +type dialOptions struct { + codec Codec + copts transport.ConnectOptions +} + // DialOption configures how we set up the connection. -type DialOption func(*transport.DialOptions) +type DialOption func(*dialOptions) + +// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling. +func WithCodec(codec Codec) DialOption { + return func(o *dialOptions) { + o.codec = codec + } +} // WithTransportCredentials returns a DialOption which configures a // connection level security credentials (e.g., TLS/SSL). func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { - return func(o *transport.DialOptions) { - o.AuthOptions = append(o.AuthOptions, creds) + return func(o *dialOptions) { + o.copts.AuthOptions = append(o.copts.AuthOptions, creds) } } // WithPerRPCCredentials returns a DialOption which sets // credentials which will place auth state on each outbound RPC. func WithPerRPCCredentials(creds credentials.Credentials) DialOption { - return func(o *transport.DialOptions) { - o.AuthOptions = append(o.AuthOptions, creds) + return func(o *dialOptions) { + o.copts.AuthOptions = append(o.copts.AuthOptions, creds) } } // WithTimeout returns a DialOption that configures a timeout for dialing a client connection. func WithTimeout(d time.Duration) DialOption { - return func(o *transport.DialOptions) { - o.Timeout = d + return func(o *dialOptions) { + o.copts.Timeout = d } } @@ -94,6 +108,10 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { for _, opt := range opts { opt(&cc.dopts) } + if cc.dopts.codec == nil { + // Set the default codec. + cc.dopts.codec = &protoCodec{} + } if err := cc.resetTransport(false); err != nil { return nil, err } @@ -106,7 +124,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { // ClientConn represents a client connection to an RPC service. type ClientConn struct { target string - dopts transport.DialOptions + dopts dialOptions shutdownChan chan struct{} mu sync.Mutex @@ -140,23 +158,23 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error { t.Close() } // Adjust timeout for the current try. - dopts := cc.dopts - if dopts.Timeout < 0 { + copts := cc.dopts.copts + if copts.Timeout < 0 { cc.Close() return ErrClientConnTimeout } - if dopts.Timeout > 0 { - dopts.Timeout -= time.Since(start) - if dopts.Timeout <= 0 { + if copts.Timeout > 0 { + copts.Timeout -= time.Since(start) + if copts.Timeout <= 0 { cc.Close() return ErrClientConnTimeout } } - newTransport, err := transport.NewClientTransport(cc.target, &dopts) + newTransport, err := transport.NewClientTransport(cc.target, &copts) if err != nil { sleepTime := backoff(retries) // Fail early before falling into sleep. - if cc.dopts.Timeout > 0 && cc.dopts.Timeout < sleepTime+time.Since(start) { + if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) { cc.Close() return ErrClientConnTimeout } diff --git a/server.go b/server.go index e50f2b5f..d2369bf9 100644 --- a/server.go +++ b/server.go @@ -85,12 +85,19 @@ type Server struct { } type options struct { + codec Codec maxConcurrentStreams uint32 } // A ServerOption sets options. type ServerOption func(*options) +func CustomCodec(codec Codec) ServerOption { + return func(o *options) { + o.codec = codec + } +} + // MaxConcurrentStreams returns an Option that will apply a limit on the number // of concurrent streams to each ServerTransport. func MaxConcurrentStreams(n uint32) ServerOption { @@ -106,6 +113,10 @@ func NewServer(opt ...ServerOption) *Server { for _, o := range opt { o(&opts) } + if opts.codec == nil { + // Set the default codec. + opts.codec = &protoCodec{} + } return &Server{ lis: make(map[net.Listener]bool), opts: opts, @@ -203,7 +214,7 @@ func (s *Server) Serve(lis net.Listener) error { } func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error { - p, err := encode(protoCodec{}, msg, pf) + p, err := encode(s.opts.codec, msg, pf) if err != nil { // This typically indicates a fatal issue (e.g., memory // corruption or hardware faults) the application program @@ -286,6 +297,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp t: t, s: stream, p: &parser{s: stream}, + codec: s.opts.codec, } if appErr := sd.Handler(srv.server, ss); appErr != nil { if err, ok := appErr.(rpcError); ok { diff --git a/stream.go b/stream.go index 625535a6..46e538f8 100644 --- a/stream.go +++ b/stream.go @@ -66,7 +66,7 @@ type Stream interface { // side. On server side, it simply returns the error to the caller. // SendMsg is called by generated code. SendMsg(m interface{}) error - // RecvMsg blocks until it receives a proto message or the stream is + // RecvMsg blocks until it receives a message or the stream is // done. On client side, it returns io.EOF when the stream is done. On // any other error, it aborts the streama nd returns an RPC status. On // server side, it simply returns the error to the caller. @@ -116,6 +116,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth s: s, p: &parser{s: s}, desc: desc, + codec: cc.dopts.codec, }, nil } @@ -125,6 +126,7 @@ type clientStream struct { s *transport.Stream p *parser desc *StreamDesc + codec Codec } func (cs *clientStream) Context() context.Context { @@ -155,7 +157,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } err = toRPCErr(err) }() - out, err := encode(protoCodec{}, m, compressionNone) + out, err := encode(cs.codec, m, compressionNone) if err != nil { return transport.StreamErrorf(codes.Internal, "grpc: %v", err) } @@ -163,13 +165,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } func (cs *clientStream) RecvMsg(m interface{}) (err error) { - err = recv(cs.p, protoCodec{}, m) + err = recv(cs.p, cs.codec, m) if err == nil { if !cs.desc.ClientStreams || cs.desc.ServerStreams { return } // Special handling for client streaming rpc. - err = recv(cs.p, protoCodec{}, m) + err = recv(cs.p, cs.codec, m) cs.t.CloseStream(cs.s, err) if err == nil { return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) @@ -224,6 +226,7 @@ type serverStream struct { t transport.ServerTransport s *transport.Stream p *parser + codec Codec statusCode codes.Code statusDesc string } @@ -245,7 +248,7 @@ func (ss *serverStream) SetTrailer(md metadata.MD) { } func (ss *serverStream) SendMsg(m interface{}) error { - out, err := encode(protoCodec{}, m, compressionNone) + out, err := encode(ss.codec, m, compressionNone) if err != nil { err = transport.StreamErrorf(codes.Internal, "grpc: %v", err) return err @@ -254,5 +257,5 @@ func (ss *serverStream) SendMsg(m interface{}) error { } func (ss *serverStream) RecvMsg(m interface{}) error { - return recv(ss.p, protoCodec{}, m) + return recv(ss.p, ss.codec, m) } diff --git a/transport/http2_client.go b/transport/http2_client.go index 3c1044be..c61b4c5b 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -100,7 +100,7 @@ type http2Client struct { // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // and starts to receive messages on it. Non-nil error returns if construction // fails. -func newHTTP2Client(addr string, opts *DialOptions) (_ ClientTransport, err error) { +func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { var ( connErr error conn net.Conn diff --git a/transport/transport.go b/transport/transport.go index 56b31b87..a887857d 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -313,16 +313,16 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv return newHTTP2Server(conn, maxStreams) } -// DialOptions covers all relevant options for dialing a server. -type DialOptions struct { +// ConnectOptions covers all relevant options for dialing a server. +type ConnectOptions struct { Protocol string AuthOptions []credentials.Credentials Timeout time.Duration } -// NewClientTransport establishes the transport with the required DialOptions +// NewClientTransport establishes the transport with the required ConnectOptions // and returns it to the caller. -func NewClientTransport(target string, opts *DialOptions) (ClientTransport, error) { +func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { return newHTTP2Client(target, opts) } diff --git a/transport/transport_test.go b/transport/transport_test.go index 39792e4a..f017f1cb 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -182,12 +182,12 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool) if err != nil { t.Fatalf("Failed to create credentials %v", err) } - dopts := DialOptions{ + dopts := ConnectOptions{ AuthOptions: []credentials.Credentials{creds}, } ct, connErr = NewClientTransport(addr, &dopts) } else { - ct, connErr = NewClientTransport(addr, &DialOptions{}) + ct, connErr = NewClientTransport(addr, &ConnectOptions{}) } if connErr != nil { t.Fatalf("failed to create transport: %v", connErr)