make Codec configurable when creating grpc.ClientConn and grpc.Server

This commit is contained in:
iamqizhao
2015-04-01 14:02:26 -07:00
parent 828af96d42
commit 9a5de0e954
7 changed files with 68 additions and 35 deletions

12
call.go
View File

@ -47,7 +47,7 @@ import (
// On error, it returns the error and indicates whether the call should be retried.
//
// TODO(zhaoq): Check whether the received message sequence is valid.
func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any.
var err error
c.headerMD, err = stream.Header()
@ -56,7 +56,7 @@ func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.St
}
p := &parser{s: stream}
for {
if err = recv(p, protoCodec{}, reply); err != nil {
if err = recv(p, codec, reply); err != nil {
if err == io.EOF {
break
}
@ -68,7 +68,7 @@ func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.St
}
// sendRequest writes out various information of an RPC such as Context and Message.
func sendRequest(ctx context.Context, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
return nil, err
@ -81,7 +81,7 @@ func sendRequest(ctx context.Context, callHdr *transport.CallHdr, t transport.Cl
}
}()
// TODO(zhaoq): Support compression.
outBuf, err := encode(protoCodec{}, args, compressionNone)
outBuf, err := encode(codec, args, compressionNone)
if err != nil {
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
}
@ -148,7 +148,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
return toRPCErr(err)
}
stream, err = sendRequest(ctx, callHdr, t, args, topts)
stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts)
if err != nil {
if _, ok := err.(transport.ConnectionError); ok {
lastErr = err
@ -160,7 +160,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
return toRPCErr(err)
}
// Receive the response
lastErr = recvResponse(t, &c, stream, reply)
lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply)
if _, ok := lastErr.(transport.ConnectionError); ok {
continue
}

View File

@ -55,29 +55,43 @@ var (
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
)
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
codec Codec
copts transport.ConnectOptions
}
// DialOption configures how we set up the connection.
type DialOption func(*transport.DialOptions)
type DialOption func(*dialOptions)
// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
func WithCodec(codec Codec) DialOption {
return func(o *dialOptions) {
o.codec = codec
}
}
// WithTransportCredentials returns a DialOption which configures a
// connection level security credentials (e.g., TLS/SSL).
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
return func(o *transport.DialOptions) {
o.AuthOptions = append(o.AuthOptions, creds)
return func(o *dialOptions) {
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
}
}
// WithPerRPCCredentials returns a DialOption which sets
// credentials which will place auth state on each outbound RPC.
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
return func(o *transport.DialOptions) {
o.AuthOptions = append(o.AuthOptions, creds)
return func(o *dialOptions) {
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
}
}
// WithTimeout returns a DialOption that configures a timeout for dialing a client connection.
func WithTimeout(d time.Duration) DialOption {
return func(o *transport.DialOptions) {
o.Timeout = d
return func(o *dialOptions) {
o.copts.Timeout = d
}
}
@ -94,6 +108,10 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
for _, opt := range opts {
opt(&cc.dopts)
}
if cc.dopts.codec == nil {
// Set the default codec.
cc.dopts.codec = &protoCodec{}
}
if err := cc.resetTransport(false); err != nil {
return nil, err
}
@ -106,7 +124,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
// ClientConn represents a client connection to an RPC service.
type ClientConn struct {
target string
dopts transport.DialOptions
dopts dialOptions
shutdownChan chan struct{}
mu sync.Mutex
@ -140,23 +158,23 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error {
t.Close()
}
// Adjust timeout for the current try.
dopts := cc.dopts
if dopts.Timeout < 0 {
copts := cc.dopts.copts
if copts.Timeout < 0 {
cc.Close()
return ErrClientConnTimeout
}
if dopts.Timeout > 0 {
dopts.Timeout -= time.Since(start)
if dopts.Timeout <= 0 {
if copts.Timeout > 0 {
copts.Timeout -= time.Since(start)
if copts.Timeout <= 0 {
cc.Close()
return ErrClientConnTimeout
}
}
newTransport, err := transport.NewClientTransport(cc.target, &dopts)
newTransport, err := transport.NewClientTransport(cc.target, &copts)
if err != nil {
sleepTime := backoff(retries)
// Fail early before falling into sleep.
if cc.dopts.Timeout > 0 && cc.dopts.Timeout < sleepTime+time.Since(start) {
if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) {
cc.Close()
return ErrClientConnTimeout
}

View File

@ -85,12 +85,19 @@ type Server struct {
}
type options struct {
codec Codec
maxConcurrentStreams uint32
}
// A ServerOption sets options.
type ServerOption func(*options)
func CustomCodec(codec Codec) ServerOption {
return func(o *options) {
o.codec = codec
}
}
// MaxConcurrentStreams returns an Option that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
@ -106,6 +113,10 @@ func NewServer(opt ...ServerOption) *Server {
for _, o := range opt {
o(&opts)
}
if opts.codec == nil {
// Set the default codec.
opts.codec = &protoCodec{}
}
return &Server{
lis: make(map[net.Listener]bool),
opts: opts,
@ -203,7 +214,7 @@ func (s *Server) Serve(lis net.Listener) error {
}
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error {
p, err := encode(protoCodec{}, msg, pf)
p, err := encode(s.opts.codec, msg, pf)
if err != nil {
// This typically indicates a fatal issue (e.g., memory
// corruption or hardware faults) the application program
@ -286,6 +297,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
t: t,
s: stream,
p: &parser{s: stream},
codec: s.opts.codec,
}
if appErr := sd.Handler(srv.server, ss); appErr != nil {
if err, ok := appErr.(rpcError); ok {

View File

@ -66,7 +66,7 @@ type Stream interface {
// side. On server side, it simply returns the error to the caller.
// SendMsg is called by generated code.
SendMsg(m interface{}) error
// RecvMsg blocks until it receives a proto message or the stream is
// RecvMsg blocks until it receives a message or the stream is
// done. On client side, it returns io.EOF when the stream is done. On
// any other error, it aborts the streama nd returns an RPC status. On
// server side, it simply returns the error to the caller.
@ -116,6 +116,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
s: s,
p: &parser{s: s},
desc: desc,
codec: cc.dopts.codec,
}, nil
}
@ -125,6 +126,7 @@ type clientStream struct {
s *transport.Stream
p *parser
desc *StreamDesc
codec Codec
}
func (cs *clientStream) Context() context.Context {
@ -155,7 +157,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
}
err = toRPCErr(err)
}()
out, err := encode(protoCodec{}, m, compressionNone)
out, err := encode(cs.codec, m, compressionNone)
if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: %v", err)
}
@ -163,13 +165,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
}
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
err = recv(cs.p, protoCodec{}, m)
err = recv(cs.p, cs.codec, m)
if err == nil {
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
return
}
// Special handling for client streaming rpc.
err = recv(cs.p, protoCodec{}, m)
err = recv(cs.p, cs.codec, m)
cs.t.CloseStream(cs.s, err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
@ -224,6 +226,7 @@ type serverStream struct {
t transport.ServerTransport
s *transport.Stream
p *parser
codec Codec
statusCode codes.Code
statusDesc string
}
@ -245,7 +248,7 @@ func (ss *serverStream) SetTrailer(md metadata.MD) {
}
func (ss *serverStream) SendMsg(m interface{}) error {
out, err := encode(protoCodec{}, m, compressionNone)
out, err := encode(ss.codec, m, compressionNone)
if err != nil {
err = transport.StreamErrorf(codes.Internal, "grpc: %v", err)
return err
@ -254,5 +257,5 @@ func (ss *serverStream) SendMsg(m interface{}) error {
}
func (ss *serverStream) RecvMsg(m interface{}) error {
return recv(ss.p, protoCodec{}, m)
return recv(ss.p, ss.codec, m)
}

View File

@ -100,7 +100,7 @@ type http2Client struct {
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction
// fails.
func newHTTP2Client(addr string, opts *DialOptions) (_ ClientTransport, err error) {
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
var (
connErr error
conn net.Conn

View File

@ -313,16 +313,16 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
return newHTTP2Server(conn, maxStreams)
}
// DialOptions covers all relevant options for dialing a server.
type DialOptions struct {
// ConnectOptions covers all relevant options for dialing a server.
type ConnectOptions struct {
Protocol string
AuthOptions []credentials.Credentials
Timeout time.Duration
}
// NewClientTransport establishes the transport with the required DialOptions
// NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller.
func NewClientTransport(target string, opts *DialOptions) (ClientTransport, error) {
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) {
return newHTTP2Client(target, opts)
}

View File

@ -182,12 +182,12 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool)
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
dopts := DialOptions{
dopts := ConnectOptions{
AuthOptions: []credentials.Credentials{creds},
}
ct, connErr = NewClientTransport(addr, &dopts)
} else {
ct, connErr = NewClientTransport(addr, &DialOptions{})
ct, connErr = NewClientTransport(addr, &ConnectOptions{})
}
if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr)