make Codec configurable when creating grpc.ClientConn and grpc.Server
This commit is contained in:
12
call.go
12
call.go
@ -47,7 +47,7 @@ import (
|
|||||||
// On error, it returns the error and indicates whether the call should be retried.
|
// On error, it returns the error and indicates whether the call should be retried.
|
||||||
//
|
//
|
||||||
// TODO(zhaoq): Check whether the received message sequence is valid.
|
// TODO(zhaoq): Check whether the received message sequence is valid.
|
||||||
func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
|
func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
|
||||||
// Try to acquire header metadata from the server if there is any.
|
// Try to acquire header metadata from the server if there is any.
|
||||||
var err error
|
var err error
|
||||||
c.headerMD, err = stream.Header()
|
c.headerMD, err = stream.Header()
|
||||||
@ -56,7 +56,7 @@ func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.St
|
|||||||
}
|
}
|
||||||
p := &parser{s: stream}
|
p := &parser{s: stream}
|
||||||
for {
|
for {
|
||||||
if err = recv(p, protoCodec{}, reply); err != nil {
|
if err = recv(p, codec, reply); err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -68,7 +68,7 @@ func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.St
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sendRequest writes out various information of an RPC such as Context and Message.
|
// sendRequest writes out various information of an RPC such as Context and Message.
|
||||||
func sendRequest(ctx context.Context, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
|
func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
|
||||||
stream, err := t.NewStream(ctx, callHdr)
|
stream, err := t.NewStream(ctx, callHdr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -81,7 +81,7 @@ func sendRequest(ctx context.Context, callHdr *transport.CallHdr, t transport.Cl
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
// TODO(zhaoq): Support compression.
|
// TODO(zhaoq): Support compression.
|
||||||
outBuf, err := encode(protoCodec{}, args, compressionNone)
|
outBuf, err := encode(codec, args, compressionNone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
||||||
}
|
}
|
||||||
@ -148,7 +148,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||||||
}
|
}
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
stream, err = sendRequest(ctx, callHdr, t, args, topts)
|
stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(transport.ConnectionError); ok {
|
if _, ok := err.(transport.ConnectionError); ok {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
@ -160,7 +160,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
|
|||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
// Receive the response
|
// Receive the response
|
||||||
lastErr = recvResponse(t, &c, stream, reply)
|
lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply)
|
||||||
if _, ok := lastErr.(transport.ConnectionError); ok {
|
if _, ok := lastErr.(transport.ConnectionError); ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -55,29 +55,43 @@ var (
|
|||||||
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
|
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// dialOptions configure a Dial call. dialOptions are set by the DialOption
|
||||||
|
// values passed to Dial.
|
||||||
|
type dialOptions struct {
|
||||||
|
codec Codec
|
||||||
|
copts transport.ConnectOptions
|
||||||
|
}
|
||||||
|
|
||||||
// DialOption configures how we set up the connection.
|
// DialOption configures how we set up the connection.
|
||||||
type DialOption func(*transport.DialOptions)
|
type DialOption func(*dialOptions)
|
||||||
|
|
||||||
|
// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
|
||||||
|
func WithCodec(codec Codec) DialOption {
|
||||||
|
return func(o *dialOptions) {
|
||||||
|
o.codec = codec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithTransportCredentials returns a DialOption which configures a
|
// WithTransportCredentials returns a DialOption which configures a
|
||||||
// connection level security credentials (e.g., TLS/SSL).
|
// connection level security credentials (e.g., TLS/SSL).
|
||||||
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
|
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
|
||||||
return func(o *transport.DialOptions) {
|
return func(o *dialOptions) {
|
||||||
o.AuthOptions = append(o.AuthOptions, creds)
|
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithPerRPCCredentials returns a DialOption which sets
|
// WithPerRPCCredentials returns a DialOption which sets
|
||||||
// credentials which will place auth state on each outbound RPC.
|
// credentials which will place auth state on each outbound RPC.
|
||||||
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
|
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
|
||||||
return func(o *transport.DialOptions) {
|
return func(o *dialOptions) {
|
||||||
o.AuthOptions = append(o.AuthOptions, creds)
|
o.copts.AuthOptions = append(o.copts.AuthOptions, creds)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithTimeout returns a DialOption that configures a timeout for dialing a client connection.
|
// WithTimeout returns a DialOption that configures a timeout for dialing a client connection.
|
||||||
func WithTimeout(d time.Duration) DialOption {
|
func WithTimeout(d time.Duration) DialOption {
|
||||||
return func(o *transport.DialOptions) {
|
return func(o *dialOptions) {
|
||||||
o.Timeout = d
|
o.copts.Timeout = d
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,6 +108,10 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
|||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(&cc.dopts)
|
opt(&cc.dopts)
|
||||||
}
|
}
|
||||||
|
if cc.dopts.codec == nil {
|
||||||
|
// Set the default codec.
|
||||||
|
cc.dopts.codec = &protoCodec{}
|
||||||
|
}
|
||||||
if err := cc.resetTransport(false); err != nil {
|
if err := cc.resetTransport(false); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -106,7 +124,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
|||||||
// ClientConn represents a client connection to an RPC service.
|
// ClientConn represents a client connection to an RPC service.
|
||||||
type ClientConn struct {
|
type ClientConn struct {
|
||||||
target string
|
target string
|
||||||
dopts transport.DialOptions
|
dopts dialOptions
|
||||||
shutdownChan chan struct{}
|
shutdownChan chan struct{}
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@ -140,23 +158,23 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error {
|
|||||||
t.Close()
|
t.Close()
|
||||||
}
|
}
|
||||||
// Adjust timeout for the current try.
|
// Adjust timeout for the current try.
|
||||||
dopts := cc.dopts
|
copts := cc.dopts.copts
|
||||||
if dopts.Timeout < 0 {
|
if copts.Timeout < 0 {
|
||||||
cc.Close()
|
cc.Close()
|
||||||
return ErrClientConnTimeout
|
return ErrClientConnTimeout
|
||||||
}
|
}
|
||||||
if dopts.Timeout > 0 {
|
if copts.Timeout > 0 {
|
||||||
dopts.Timeout -= time.Since(start)
|
copts.Timeout -= time.Since(start)
|
||||||
if dopts.Timeout <= 0 {
|
if copts.Timeout <= 0 {
|
||||||
cc.Close()
|
cc.Close()
|
||||||
return ErrClientConnTimeout
|
return ErrClientConnTimeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
newTransport, err := transport.NewClientTransport(cc.target, &dopts)
|
newTransport, err := transport.NewClientTransport(cc.target, &copts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sleepTime := backoff(retries)
|
sleepTime := backoff(retries)
|
||||||
// Fail early before falling into sleep.
|
// Fail early before falling into sleep.
|
||||||
if cc.dopts.Timeout > 0 && cc.dopts.Timeout < sleepTime+time.Since(start) {
|
if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) {
|
||||||
cc.Close()
|
cc.Close()
|
||||||
return ErrClientConnTimeout
|
return ErrClientConnTimeout
|
||||||
}
|
}
|
||||||
|
14
server.go
14
server.go
@ -85,12 +85,19 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
|
codec Codec
|
||||||
maxConcurrentStreams uint32
|
maxConcurrentStreams uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// A ServerOption sets options.
|
// A ServerOption sets options.
|
||||||
type ServerOption func(*options)
|
type ServerOption func(*options)
|
||||||
|
|
||||||
|
func CustomCodec(codec Codec) ServerOption {
|
||||||
|
return func(o *options) {
|
||||||
|
o.codec = codec
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MaxConcurrentStreams returns an Option that will apply a limit on the number
|
// MaxConcurrentStreams returns an Option that will apply a limit on the number
|
||||||
// of concurrent streams to each ServerTransport.
|
// of concurrent streams to each ServerTransport.
|
||||||
func MaxConcurrentStreams(n uint32) ServerOption {
|
func MaxConcurrentStreams(n uint32) ServerOption {
|
||||||
@ -106,6 +113,10 @@ func NewServer(opt ...ServerOption) *Server {
|
|||||||
for _, o := range opt {
|
for _, o := range opt {
|
||||||
o(&opts)
|
o(&opts)
|
||||||
}
|
}
|
||||||
|
if opts.codec == nil {
|
||||||
|
// Set the default codec.
|
||||||
|
opts.codec = &protoCodec{}
|
||||||
|
}
|
||||||
return &Server{
|
return &Server{
|
||||||
lis: make(map[net.Listener]bool),
|
lis: make(map[net.Listener]bool),
|
||||||
opts: opts,
|
opts: opts,
|
||||||
@ -203,7 +214,7 @@ func (s *Server) Serve(lis net.Listener) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error {
|
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error {
|
||||||
p, err := encode(protoCodec{}, msg, pf)
|
p, err := encode(s.opts.codec, msg, pf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// This typically indicates a fatal issue (e.g., memory
|
// This typically indicates a fatal issue (e.g., memory
|
||||||
// corruption or hardware faults) the application program
|
// corruption or hardware faults) the application program
|
||||||
@ -286,6 +297,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
t: t,
|
t: t,
|
||||||
s: stream,
|
s: stream,
|
||||||
p: &parser{s: stream},
|
p: &parser{s: stream},
|
||||||
|
codec: s.opts.codec,
|
||||||
}
|
}
|
||||||
if appErr := sd.Handler(srv.server, ss); appErr != nil {
|
if appErr := sd.Handler(srv.server, ss); appErr != nil {
|
||||||
if err, ok := appErr.(rpcError); ok {
|
if err, ok := appErr.(rpcError); ok {
|
||||||
|
15
stream.go
15
stream.go
@ -66,7 +66,7 @@ type Stream interface {
|
|||||||
// side. On server side, it simply returns the error to the caller.
|
// side. On server side, it simply returns the error to the caller.
|
||||||
// SendMsg is called by generated code.
|
// SendMsg is called by generated code.
|
||||||
SendMsg(m interface{}) error
|
SendMsg(m interface{}) error
|
||||||
// RecvMsg blocks until it receives a proto message or the stream is
|
// RecvMsg blocks until it receives a message or the stream is
|
||||||
// done. On client side, it returns io.EOF when the stream is done. On
|
// done. On client side, it returns io.EOF when the stream is done. On
|
||||||
// any other error, it aborts the streama nd returns an RPC status. On
|
// any other error, it aborts the streama nd returns an RPC status. On
|
||||||
// server side, it simply returns the error to the caller.
|
// server side, it simply returns the error to the caller.
|
||||||
@ -116,6 +116,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||||||
s: s,
|
s: s,
|
||||||
p: &parser{s: s},
|
p: &parser{s: s},
|
||||||
desc: desc,
|
desc: desc,
|
||||||
|
codec: cc.dopts.codec,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,6 +126,7 @@ type clientStream struct {
|
|||||||
s *transport.Stream
|
s *transport.Stream
|
||||||
p *parser
|
p *parser
|
||||||
desc *StreamDesc
|
desc *StreamDesc
|
||||||
|
codec Codec
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) Context() context.Context {
|
func (cs *clientStream) Context() context.Context {
|
||||||
@ -155,7 +157,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
err = toRPCErr(err)
|
err = toRPCErr(err)
|
||||||
}()
|
}()
|
||||||
out, err := encode(protoCodec{}, m, compressionNone)
|
out, err := encode(cs.codec, m, compressionNone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
return transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
||||||
}
|
}
|
||||||
@ -163,13 +165,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
||||||
err = recv(cs.p, protoCodec{}, m)
|
err = recv(cs.p, cs.codec, m)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
|
if !cs.desc.ClientStreams || cs.desc.ServerStreams {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Special handling for client streaming rpc.
|
// Special handling for client streaming rpc.
|
||||||
err = recv(cs.p, protoCodec{}, m)
|
err = recv(cs.p, cs.codec, m)
|
||||||
cs.t.CloseStream(cs.s, err)
|
cs.t.CloseStream(cs.s, err)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||||
@ -224,6 +226,7 @@ type serverStream struct {
|
|||||||
t transport.ServerTransport
|
t transport.ServerTransport
|
||||||
s *transport.Stream
|
s *transport.Stream
|
||||||
p *parser
|
p *parser
|
||||||
|
codec Codec
|
||||||
statusCode codes.Code
|
statusCode codes.Code
|
||||||
statusDesc string
|
statusDesc string
|
||||||
}
|
}
|
||||||
@ -245,7 +248,7 @@ func (ss *serverStream) SetTrailer(md metadata.MD) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ss *serverStream) SendMsg(m interface{}) error {
|
func (ss *serverStream) SendMsg(m interface{}) error {
|
||||||
out, err := encode(protoCodec{}, m, compressionNone)
|
out, err := encode(ss.codec, m, compressionNone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
err = transport.StreamErrorf(codes.Internal, "grpc: %v", err)
|
||||||
return err
|
return err
|
||||||
@ -254,5 +257,5 @@ func (ss *serverStream) SendMsg(m interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ss *serverStream) RecvMsg(m interface{}) error {
|
func (ss *serverStream) RecvMsg(m interface{}) error {
|
||||||
return recv(ss.p, protoCodec{}, m)
|
return recv(ss.p, ss.codec, m)
|
||||||
}
|
}
|
||||||
|
@ -100,7 +100,7 @@ type http2Client struct {
|
|||||||
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
||||||
// and starts to receive messages on it. Non-nil error returns if construction
|
// and starts to receive messages on it. Non-nil error returns if construction
|
||||||
// fails.
|
// fails.
|
||||||
func newHTTP2Client(addr string, opts *DialOptions) (_ ClientTransport, err error) {
|
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
|
||||||
var (
|
var (
|
||||||
connErr error
|
connErr error
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
@ -313,16 +313,16 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
|
|||||||
return newHTTP2Server(conn, maxStreams)
|
return newHTTP2Server(conn, maxStreams)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialOptions covers all relevant options for dialing a server.
|
// ConnectOptions covers all relevant options for dialing a server.
|
||||||
type DialOptions struct {
|
type ConnectOptions struct {
|
||||||
Protocol string
|
Protocol string
|
||||||
AuthOptions []credentials.Credentials
|
AuthOptions []credentials.Credentials
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientTransport establishes the transport with the required DialOptions
|
// NewClientTransport establishes the transport with the required ConnectOptions
|
||||||
// and returns it to the caller.
|
// and returns it to the caller.
|
||||||
func NewClientTransport(target string, opts *DialOptions) (ClientTransport, error) {
|
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) {
|
||||||
return newHTTP2Client(target, opts)
|
return newHTTP2Client(target, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,12 +182,12 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create credentials %v", err)
|
t.Fatalf("Failed to create credentials %v", err)
|
||||||
}
|
}
|
||||||
dopts := DialOptions{
|
dopts := ConnectOptions{
|
||||||
AuthOptions: []credentials.Credentials{creds},
|
AuthOptions: []credentials.Credentials{creds},
|
||||||
}
|
}
|
||||||
ct, connErr = NewClientTransport(addr, &dopts)
|
ct, connErr = NewClientTransport(addr, &dopts)
|
||||||
} else {
|
} else {
|
||||||
ct, connErr = NewClientTransport(addr, &DialOptions{})
|
ct, connErr = NewClientTransport(addr, &ConnectOptions{})
|
||||||
}
|
}
|
||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
t.Fatalf("failed to create transport: %v", connErr)
|
t.Fatalf("failed to create transport: %v", connErr)
|
||||||
|
Reference in New Issue
Block a user