make Codec configurable when creating grpc.ClientConn and grpc.Server
This commit is contained in:
12
call.go
12
call.go
@ -47,7 +47,7 @@ import (
|
||||
// On error, it returns the error and indicates whether the call should be retried.
|
||||
//
|
||||
// TODO(zhaoq): Check whether the received message sequence is valid.
|
||||
func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
|
||||
func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
|
||||
// Try to acquire header metadata from the server if there is any.
|
||||
var err error
|
||||
c.headerMD, err = stream.Header()
|
||||
@ -56,7 +56,7 @@ func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.St
|
||||
}
|
||||
p := &parser{s: stream}
|
||||
for {
|
||||
if err = recv(p, protoCodec{}, reply); err != nil {
|
||||
if err = recv(p, codec, reply); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
@ -68,7 +68,7 @@ func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.St
|
||||
}
|
||||
|
||||
// sendRequest writes out various information of an RPC such as Context and Message.
|
||||
func sendRequest(ctx context.Context, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
|
||||
func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
|
||||
stream, err := t.NewStream(ctx, callHdr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -81,7 +81,7 @@ func sendRequest(ctx context.Context, callHdr *transport.CallHdr, t transport.Cl
|
||||
}
|
||||
}()
|
||||
// TODO(zhaoq): Support compression.
|
||||
outBuf, err := encode(protoCodec{}, args, compressionNone)
|
||||
outBuf, err := encode(codec, args, compressionNone)
|
||||
if err != nil {
|
||||
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
||||
}
|
||||
@ -148,7 +148,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
||||
}
|
||||
return toRPCErr(err)
|
||||
}
|
||||
stream, err = sendRequest(ctx, callHdr, t, args, topts)
|
||||
stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts)
|
||||
if err != nil {
|
||||
if _, ok := err.(transport.ConnectionError); ok {
|
||||
lastErr = err
|
||||
@ -160,7 +160,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
||||
return toRPCErr(err)
|
||||
}
|
||||
// Receive the response
|
||||
lastErr = recvResponse(t, &c, stream, reply)
|
||||
lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply)
|
||||
if _, ok := lastErr.(transport.ConnectionError); ok {
|
||||
continue
|
||||
}
|
||||
|
@ -55,29 +55,43 @@ var (
|
||||
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
|
||||
)
|
||||
|
||||
// dialOptions configure a Dial call. dialOptions are set by the DialOption
|
||||
// values passed to Dial.
|
||||
type dialOptions struct {
|
||||
codec Codec
|
||||
copts transport.ConnectOptions
|
||||
}
|
||||
|
||||
// DialOption configures how we set up the connection.
|
||||
type DialOption func(*transport.DialOptions)
|
||||
type DialOption func(*dialOptions)
|
||||
|
||||
// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
|
||||
func WithCodec(codec Codec) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.codec = codec
|
||||
}
|
||||
}
|
||||
|
||||
// WithTransportCredentials returns a DialOption which configures a
|
||||
// connection level security credentials (e.g., TLS/SSL).
|
||||
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
|
||||
return func(o *transport.DialOptions) {
|
||||
o.AuthOptions = append(o.AuthOptions, creds)
|
||||
return func(o *dialOptions) {
|
||||
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
||||
}
|
||||
}
|
||||
|
||||
// WithPerRPCCredentials returns a DialOption which sets
|
||||
// credentials which will place auth state on each outbound RPC.
|
||||
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
|
||||
return func(o *transport.DialOptions) {
|
||||
o.AuthOptions = append(o.AuthOptions, creds)
|
||||
return func(o *dialOptions) {
|
||||
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout returns a DialOption that configures a timeout for dialing a client connection.
|
||||
func WithTimeout(d time.Duration) DialOption {
|
||||
return func(o *transport.DialOptions) {
|
||||
o.Timeout = d
|
||||
return func(o *dialOptions) {
|
||||
o.copts.Timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,6 +108,10 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
||||
for _, opt := range opts {
|
||||
opt(&cc.dopts)
|
||||
}
|
||||
if cc.dopts.codec == nil {
|
||||
// Set the default codec.
|
||||
cc.dopts.codec = &protoCodec{}
|
||||
}
|
||||
if err := cc.resetTransport(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -106,7 +124,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
||||
// ClientConn represents a client connection to an RPC service.
|
||||
type ClientConn struct {
|
||||
target string
|
||||
dopts transport.DialOptions
|
||||
dopts dialOptions
|
||||
shutdownChan chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
@ -140,23 +158,23 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error {
|
||||
t.Close()
|
||||
}
|
||||
// Adjust timeout for the current try.
|
||||
dopts := cc.dopts
|
||||
if dopts.Timeout < 0 {
|
||||
copts := cc.dopts.copts
|
||||
if copts.Timeout < 0 {
|
||||
cc.Close()
|
||||
return ErrClientConnTimeout
|
||||
}
|
||||
if dopts.Timeout > 0 {
|
||||
dopts.Timeout -= time.Since(start)
|
||||
if dopts.Timeout <= 0 {
|
||||
if copts.Timeout > 0 {
|
||||
copts.Timeout -= time.Since(start)
|
||||
if copts.Timeout <= 0 {
|
||||
cc.Close()
|
||||
return ErrClientConnTimeout
|
||||
}
|
||||
}
|
||||
newTransport, err := transport.NewClientTransport(cc.target, &dopts)
|
||||
newTransport, err := transport.NewClientTransport(cc.target, &copts)
|
||||
if err != nil {
|
||||
sleepTime := backoff(retries)
|
||||
// Fail early before falling into sleep.
|
||||
if cc.dopts.Timeout > 0 && cc.dopts.Timeout < sleepTime+time.Since(start) {
|
||||
if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) {
|
||||
cc.Close()
|
||||
return ErrClientConnTimeout
|
||||
}
|
||||
|
14
server.go
14
server.go
@ -85,12 +85,19 @@ type Server struct {
|
||||
}
|
||||
|
||||
type options struct {
|
||||
codec Codec
|
||||
maxConcurrentStreams uint32
|
||||
}
|
||||
|
||||
// A ServerOption sets options.
|
||||
type ServerOption func(*options)
|
||||
|
||||
func CustomCodec(codec Codec) ServerOption {
|
||||
return func(o *options) {
|
||||
o.codec = codec
|
||||
}
|
||||
}
|
||||
|
||||
// MaxConcurrentStreams returns an Option that will apply a limit on the number
|
||||
// of concurrent streams to each ServerTransport.
|
||||
func MaxConcurrentStreams(n uint32) ServerOption {
|
||||
@ -106,6 +113,10 @@ func NewServer(opt ...ServerOption) *Server {
|
||||
for _, o := range opt {
|
||||
o(&opts)
|
||||
}
|
||||
if opts.codec == nil {
|
||||
// Set the default codec.
|
||||
opts.codec = &protoCodec{}
|
||||
}
|
||||
return &Server{
|
||||
lis: make(map[net.Listener]bool),
|
||||
opts: opts,
|
||||
@ -203,7 +214,7 @@ func (s *Server) Serve(lis net.Listener) error {
|
||||
}
|
||||
|
||||
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error {
|
||||
p, err := encode(protoCodec{}, msg, pf)
|
||||
p, err := encode(s.opts.codec, msg, pf)
|
||||
if err != nil {
|
||||
// This typically indicates a fatal issue (e.g., memory
|
||||
// corruption or hardware faults) the application program
|
||||
@ -286,6 +297,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
||||
t: t,
|
||||
s: stream,
|
||||
p: &parser{s: stream},
|
||||
codec: s.opts.codec,
|
||||
}
|
||||
if appErr := sd.Handler(srv.server, ss); appErr != nil {
|
||||
if err, ok := appErr.(rpcError); ok {
|
||||
|
15
stream.go
15
stream.go
@ -66,7 +66,7 @@ type Stream interface {
|
||||
// side. On server side, it simply returns the error to the caller.
|
||||
// SendMsg is called by generated code.
|
||||
SendMsg(m interface{}) error
|
||||
// RecvMsg blocks until it receives a proto message or the stream is
|
||||
// RecvMsg blocks until it receives a message or the stream is
|
||||
// done. On client side, it returns io.EOF when the stream is done. On
|
||||
// any other error, it aborts the streama nd returns an RPC status. On
|
||||
// server side, it simply returns the error to the caller.
|
||||
@ -116,6 +116,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
||||
s: s,
|
||||
p: &parser{s: s},
|
||||
desc: desc,
|
||||
codec: cc.dopts.codec,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -125,6 +126,7 @@ type clientStream struct {
|
||||
s *transport.Stream
|
||||
p *parser
|
||||
desc *StreamDesc
|
||||
codec Codec
|
||||
}
|
||||
|
||||
func (cs *clientStream) Context() context.Context {
|
||||
@ -155,7 +157,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
||||
}
|
||||
err = toRPCErr(err)
|
||||
}()
|
||||
out, err := encode(protoCodec{}, m, compressionNone)
|
||||
out, err := encode(cs.codec, m, compressionNone)
|
||||
if err != nil {
|
||||
return transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
||||
}
|
||||
@ -163,13 +165,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
||||
}
|
||||
|
||||
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
||||
err = recv(cs.p, protoCodec{}, m)
|
||||
err = recv(cs.p, cs.codec, m)
|
||||
if err == nil {
|
||||
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
|
||||
return
|
||||
}
|
||||
// Special handling for client streaming rpc.
|
||||
err = recv(cs.p, protoCodec{}, m)
|
||||
err = recv(cs.p, cs.codec, m)
|
||||
cs.t.CloseStream(cs.s, err)
|
||||
if err == nil {
|
||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||
@ -224,6 +226,7 @@ type serverStream struct {
|
||||
t transport.ServerTransport
|
||||
s *transport.Stream
|
||||
p *parser
|
||||
codec Codec
|
||||
statusCode codes.Code
|
||||
statusDesc string
|
||||
}
|
||||
@ -245,7 +248,7 @@ func (ss *serverStream) SetTrailer(md metadata.MD) {
|
||||
}
|
||||
|
||||
func (ss *serverStream) SendMsg(m interface{}) error {
|
||||
out, err := encode(protoCodec{}, m, compressionNone)
|
||||
out, err := encode(ss.codec, m, compressionNone)
|
||||
if err != nil {
|
||||
err = transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
||||
return err
|
||||
@ -254,5 +257,5 @@ func (ss *serverStream) SendMsg(m interface{}) error {
|
||||
}
|
||||
|
||||
func (ss *serverStream) RecvMsg(m interface{}) error {
|
||||
return recv(ss.p, protoCodec{}, m)
|
||||
return recv(ss.p, ss.codec, m)
|
||||
}
|
||||
|
@ -100,7 +100,7 @@ type http2Client struct {
|
||||
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
||||
// and starts to receive messages on it. Non-nil error returns if construction
|
||||
// fails.
|
||||
func newHTTP2Client(addr string, opts *DialOptions) (_ ClientTransport, err error) {
|
||||
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
|
||||
var (
|
||||
connErr error
|
||||
conn net.Conn
|
||||
|
@ -313,16 +313,16 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
|
||||
return newHTTP2Server(conn, maxStreams)
|
||||
}
|
||||
|
||||
// DialOptions covers all relevant options for dialing a server.
|
||||
type DialOptions struct {
|
||||
// ConnectOptions covers all relevant options for dialing a server.
|
||||
type ConnectOptions struct {
|
||||
Protocol string
|
||||
AuthOptions []credentials.Credentials
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// NewClientTransport establishes the transport with the required DialOptions
|
||||
// NewClientTransport establishes the transport with the required ConnectOptions
|
||||
// and returns it to the caller.
|
||||
func NewClientTransport(target string, opts *DialOptions) (ClientTransport, error) {
|
||||
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) {
|
||||
return newHTTP2Client(target, opts)
|
||||
}
|
||||
|
||||
|
@ -182,12 +182,12 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create credentials %v", err)
|
||||
}
|
||||
dopts := DialOptions{
|
||||
dopts := ConnectOptions{
|
||||
AuthOptions: []credentials.Credentials{creds},
|
||||
}
|
||||
ct, connErr = NewClientTransport(addr, &dopts)
|
||||
} else {
|
||||
ct, connErr = NewClientTransport(addr, &DialOptions{})
|
||||
ct, connErr = NewClientTransport(addr, &ConnectOptions{})
|
||||
}
|
||||
if connErr != nil {
|
||||
t.Fatalf("failed to create transport: %v", connErr)
|
||||
|
Reference in New Issue
Block a user