make Codec configurable when creating grpc.ClientConn and grpc.Server
This commit is contained in:
		
							
								
								
									
										12
									
								
								call.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								call.go
									
									
									
									
									
								
							| @ -47,7 +47,7 @@ import ( | ||||
| // On error, it returns the error and indicates whether the call should be retried. | ||||
| // | ||||
| // TODO(zhaoq): Check whether the received message sequence is valid. | ||||
| func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { | ||||
| func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { | ||||
| 	// Try to acquire header metadata from the server if there is any. | ||||
| 	var err error | ||||
| 	c.headerMD, err = stream.Header() | ||||
| @ -56,7 +56,7 @@ func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.St | ||||
| 	} | ||||
| 	p := &parser{s: stream} | ||||
| 	for { | ||||
| 		if err = recv(p, protoCodec{}, reply); err != nil { | ||||
| 		if err = recv(p, codec, reply); err != nil { | ||||
| 			if err == io.EOF { | ||||
| 				break | ||||
| 			} | ||||
| @ -68,7 +68,7 @@ func recvResponse(t transport.ClientTransport, c *callInfo, stream *transport.St | ||||
| } | ||||
|  | ||||
| // sendRequest writes out various information of an RPC such as Context and Message. | ||||
| func sendRequest(ctx context.Context, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { | ||||
| func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { | ||||
| 	stream, err := t.NewStream(ctx, callHdr) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -81,7 +81,7 @@ func sendRequest(ctx context.Context, callHdr *transport.CallHdr, t transport.Cl | ||||
| 		} | ||||
| 	}() | ||||
| 	// TODO(zhaoq): Support compression. | ||||
| 	outBuf, err := encode(protoCodec{}, args, compressionNone) | ||||
| 	outBuf, err := encode(codec, args, compressionNone) | ||||
| 	if err != nil { | ||||
| 		return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) | ||||
| 	} | ||||
| @ -148,7 +148,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli | ||||
| 			} | ||||
| 			return toRPCErr(err) | ||||
| 		} | ||||
| 		stream, err = sendRequest(ctx, callHdr, t, args, topts) | ||||
| 		stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts) | ||||
| 		if err != nil { | ||||
| 			if _, ok := err.(transport.ConnectionError); ok { | ||||
| 				lastErr = err | ||||
| @ -160,7 +160,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli | ||||
| 			return toRPCErr(err) | ||||
| 		} | ||||
| 		// Receive the response | ||||
| 		lastErr = recvResponse(t, &c, stream, reply) | ||||
| 		lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply) | ||||
| 		if _, ok := lastErr.(transport.ConnectionError); ok { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| @ -55,29 +55,43 @@ var ( | ||||
| 	ErrClientConnTimeout = errors.New("grpc: timed out trying to connect") | ||||
| ) | ||||
|  | ||||
| // dialOptions configure a Dial call. dialOptions are set by the DialOption | ||||
| // values passed to Dial. | ||||
| type dialOptions struct { | ||||
| 	codec Codec | ||||
| 	copts transport.ConnectOptions | ||||
| } | ||||
|  | ||||
| // DialOption configures how we set up the connection. | ||||
| type DialOption func(*transport.DialOptions) | ||||
| type DialOption func(*dialOptions) | ||||
|  | ||||
| // WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling. | ||||
| func WithCodec(codec Codec) DialOption { | ||||
| 	return func(o *dialOptions) { | ||||
| 		o.codec = codec | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithTransportCredentials returns a DialOption which configures a | ||||
| // connection level security credentials (e.g., TLS/SSL). | ||||
| func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { | ||||
| 	return func(o *transport.DialOptions) { | ||||
| 		o.AuthOptions = append(o.AuthOptions, creds) | ||||
| 	return func(o *dialOptions) { | ||||
| 		o.copts.AuthOptions = append(o.copts.AuthOptions, creds) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithPerRPCCredentials returns a DialOption which sets | ||||
| // credentials which will place auth state on each outbound RPC. | ||||
| func WithPerRPCCredentials(creds credentials.Credentials) DialOption { | ||||
| 	return func(o *transport.DialOptions) { | ||||
| 		o.AuthOptions = append(o.AuthOptions, creds) | ||||
| 	return func(o *dialOptions) { | ||||
| 		o.copts.AuthOptions = append(o.copts.AuthOptions, creds) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithTimeout returns a DialOption that configures a timeout for dialing a client connection. | ||||
| func WithTimeout(d time.Duration) DialOption { | ||||
| 	return func(o *transport.DialOptions) { | ||||
| 		o.Timeout = d | ||||
| 	return func(o *dialOptions) { | ||||
| 		o.copts.Timeout = d | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -94,6 +108,10 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { | ||||
| 	for _, opt := range opts { | ||||
| 		opt(&cc.dopts) | ||||
| 	} | ||||
| 	if cc.dopts.codec == nil { | ||||
| 		// Set the default codec. | ||||
| 		cc.dopts.codec = &protoCodec{} | ||||
| 	} | ||||
| 	if err := cc.resetTransport(false); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @ -106,7 +124,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { | ||||
| // ClientConn represents a client connection to an RPC service. | ||||
| type ClientConn struct { | ||||
| 	target       string | ||||
| 	dopts        transport.DialOptions | ||||
| 	dopts        dialOptions | ||||
| 	shutdownChan chan struct{} | ||||
|  | ||||
| 	mu sync.Mutex | ||||
| @ -140,23 +158,23 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error { | ||||
| 			t.Close() | ||||
| 		} | ||||
| 		// Adjust timeout for the current try. | ||||
| 		dopts := cc.dopts | ||||
| 		if dopts.Timeout < 0 { | ||||
| 		copts := cc.dopts.copts | ||||
| 		if copts.Timeout < 0 { | ||||
| 			cc.Close() | ||||
| 			return ErrClientConnTimeout | ||||
| 		} | ||||
| 		if dopts.Timeout > 0 { | ||||
| 			dopts.Timeout -= time.Since(start) | ||||
| 			if dopts.Timeout <= 0 { | ||||
| 		if copts.Timeout > 0 { | ||||
| 			copts.Timeout -= time.Since(start) | ||||
| 			if copts.Timeout <= 0 { | ||||
| 				cc.Close() | ||||
| 				return ErrClientConnTimeout | ||||
| 			} | ||||
| 		} | ||||
| 		newTransport, err := transport.NewClientTransport(cc.target, &dopts) | ||||
| 		newTransport, err := transport.NewClientTransport(cc.target, &copts) | ||||
| 		if err != nil { | ||||
| 			sleepTime := backoff(retries) | ||||
| 			// Fail early before falling into sleep. | ||||
| 			if cc.dopts.Timeout > 0 && cc.dopts.Timeout < sleepTime+time.Since(start) { | ||||
| 			if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) { | ||||
| 				cc.Close() | ||||
| 				return ErrClientConnTimeout | ||||
| 			} | ||||
|  | ||||
							
								
								
									
										14
									
								
								server.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								server.go
									
									
									
									
									
								
							| @ -85,12 +85,19 @@ type Server struct { | ||||
| } | ||||
|  | ||||
| type options struct { | ||||
| 	codec Codec | ||||
| 	maxConcurrentStreams uint32 | ||||
| } | ||||
|  | ||||
| // A ServerOption sets options. | ||||
| type ServerOption func(*options) | ||||
|  | ||||
| func CustomCodec(codec Codec) ServerOption { | ||||
| 	return func(o *options) { | ||||
| 		o.codec = codec | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // MaxConcurrentStreams returns an Option that will apply a limit on the number | ||||
| // of concurrent streams to each ServerTransport. | ||||
| func MaxConcurrentStreams(n uint32) ServerOption { | ||||
| @ -106,6 +113,10 @@ func NewServer(opt ...ServerOption) *Server { | ||||
| 	for _, o := range opt { | ||||
| 		o(&opts) | ||||
| 	} | ||||
| 	if opts.codec == nil { | ||||
| 		// Set the default codec. | ||||
| 		opts.codec = &protoCodec{} | ||||
| 	} | ||||
| 	return &Server{ | ||||
| 		lis:   make(map[net.Listener]bool), | ||||
| 		opts:  opts, | ||||
| @ -203,7 +214,7 @@ func (s *Server) Serve(lis net.Listener) error { | ||||
| } | ||||
|  | ||||
| func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error { | ||||
| 	p, err := encode(protoCodec{}, msg, pf) | ||||
| 	p, err := encode(s.opts.codec, msg, pf) | ||||
| 	if err != nil { | ||||
| 		// This typically indicates a fatal issue (e.g., memory | ||||
| 		// corruption or hardware faults) the application program | ||||
| @ -286,6 +297,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp | ||||
| 		t: t, | ||||
| 		s: stream, | ||||
| 		p: &parser{s: stream}, | ||||
| 		codec: s.opts.codec, | ||||
| 	} | ||||
| 	if appErr := sd.Handler(srv.server, ss); appErr != nil { | ||||
| 		if err, ok := appErr.(rpcError); ok { | ||||
|  | ||||
							
								
								
									
										15
									
								
								stream.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								stream.go
									
									
									
									
									
								
							| @ -66,7 +66,7 @@ type Stream interface { | ||||
| 	// side. On server side, it simply returns the error to the caller. | ||||
| 	// SendMsg is called by generated code. | ||||
| 	SendMsg(m interface{}) error | ||||
| 	// RecvMsg blocks until it receives a proto message or the stream is | ||||
| 	// RecvMsg blocks until it receives a message or the stream is | ||||
| 	// done. On client side, it returns io.EOF when the stream is done. On | ||||
| 	// any other error, it aborts the streama nd returns an RPC status. On | ||||
| 	// server side, it simply returns the error to the caller. | ||||
| @ -116,6 +116,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth | ||||
| 		s:    s, | ||||
| 		p:    &parser{s: s}, | ||||
| 		desc: desc, | ||||
| 		codec: cc.dopts.codec, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @ -125,6 +126,7 @@ type clientStream struct { | ||||
| 	s    *transport.Stream | ||||
| 	p    *parser | ||||
| 	desc *StreamDesc | ||||
| 	codec Codec | ||||
| } | ||||
|  | ||||
| func (cs *clientStream) Context() context.Context { | ||||
| @ -155,7 +157,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { | ||||
| 		} | ||||
| 		err = toRPCErr(err) | ||||
| 	}() | ||||
| 	out, err := encode(protoCodec{}, m, compressionNone) | ||||
| 	out, err := encode(cs.codec, m, compressionNone) | ||||
| 	if err != nil { | ||||
| 		return transport.StreamErrorf(codes.Internal, "grpc: %v", err) | ||||
| 	} | ||||
| @ -163,13 +165,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { | ||||
| } | ||||
|  | ||||
| func (cs *clientStream) RecvMsg(m interface{}) (err error) { | ||||
| 	err = recv(cs.p, protoCodec{}, m) | ||||
| 	err = recv(cs.p, cs.codec, m) | ||||
| 	if err == nil { | ||||
| 		if !cs.desc.ClientStreams || cs.desc.ServerStreams { | ||||
| 			return | ||||
| 		} | ||||
| 		// Special handling for client streaming rpc. | ||||
| 		err = recv(cs.p, protoCodec{}, m) | ||||
| 		err = recv(cs.p, cs.codec, m) | ||||
| 		cs.t.CloseStream(cs.s, err) | ||||
| 		if err == nil { | ||||
| 			return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>")) | ||||
| @ -224,6 +226,7 @@ type serverStream struct { | ||||
| 	t          transport.ServerTransport | ||||
| 	s          *transport.Stream | ||||
| 	p          *parser | ||||
| 	codec      Codec | ||||
| 	statusCode codes.Code | ||||
| 	statusDesc string | ||||
| } | ||||
| @ -245,7 +248,7 @@ func (ss *serverStream) SetTrailer(md metadata.MD) { | ||||
| } | ||||
|  | ||||
| func (ss *serverStream) SendMsg(m interface{}) error { | ||||
| 	out, err := encode(protoCodec{}, m, compressionNone) | ||||
| 	out, err := encode(ss.codec, m, compressionNone) | ||||
| 	if err != nil { | ||||
| 		err = transport.StreamErrorf(codes.Internal, "grpc: %v", err) | ||||
| 		return err | ||||
| @ -254,5 +257,5 @@ func (ss *serverStream) SendMsg(m interface{}) error { | ||||
| } | ||||
|  | ||||
| func (ss *serverStream) RecvMsg(m interface{}) error { | ||||
| 	return recv(ss.p, protoCodec{}, m) | ||||
| 	return recv(ss.p, ss.codec, m) | ||||
| } | ||||
|  | ||||
| @ -100,7 +100,7 @@ type http2Client struct { | ||||
| // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 | ||||
| // and starts to receive messages on it. Non-nil error returns if construction | ||||
| // fails. | ||||
| func newHTTP2Client(addr string, opts *DialOptions) (_ ClientTransport, err error) { | ||||
| func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { | ||||
| 	var ( | ||||
| 		connErr error | ||||
| 		conn    net.Conn | ||||
|  | ||||
| @ -313,16 +313,16 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv | ||||
| 	return newHTTP2Server(conn, maxStreams) | ||||
| } | ||||
|  | ||||
| // DialOptions covers all relevant options for dialing a server. | ||||
| type DialOptions struct { | ||||
| // ConnectOptions covers all relevant options for dialing a server. | ||||
| type ConnectOptions struct { | ||||
| 	Protocol    string | ||||
| 	AuthOptions []credentials.Credentials | ||||
| 	Timeout     time.Duration | ||||
| } | ||||
|  | ||||
| // NewClientTransport establishes the transport with the required DialOptions | ||||
| // NewClientTransport establishes the transport with the required ConnectOptions | ||||
| // and returns it to the caller. | ||||
| func NewClientTransport(target string, opts *DialOptions) (ClientTransport, error) { | ||||
| func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { | ||||
| 	return newHTTP2Client(target, opts) | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -182,12 +182,12 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("Failed to create credentials %v", err) | ||||
| 		} | ||||
| 		dopts := DialOptions{ | ||||
| 		dopts := ConnectOptions{ | ||||
| 			AuthOptions: []credentials.Credentials{creds}, | ||||
| 		} | ||||
| 		ct, connErr = NewClientTransport(addr, &dopts) | ||||
| 	} else { | ||||
| 		ct, connErr = NewClientTransport(addr, &DialOptions{}) | ||||
| 		ct, connErr = NewClientTransport(addr, &ConnectOptions{}) | ||||
| 	} | ||||
| 	if connErr != nil { | ||||
| 		t.Fatalf("failed to create transport: %v", connErr) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 iamqizhao
					iamqizhao