diff --git a/call.go b/call.go index 0115a28d..7ca088a8 100644 --- a/call.go +++ b/call.go @@ -116,7 +116,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli o.after(&c) } }() - + conn, err := cc.picker.Pick() + if err != nil { + return toRPCErr(err) + } if EnableTracing { c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) defer c.traceInfo.tr.Finish() @@ -134,7 +137,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli }() } callHdr := &transport.CallHdr{ - Host: cc.authority, + Host: conn.authority, Method: method, } topts := &transport.Options{ @@ -154,7 +157,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if lastErr != nil && c.failFast { return toRPCErr(lastErr) } - t, err = cc.wait(ctx) + t, err = conn.wait(ctx) if err != nil { if lastErr != nil { // This was a retry; return the error from the last attempt. @@ -165,7 +168,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli if c.traceInfo.tr != nil { c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) } - stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts) + stream, err = sendRequest(ctx, conn.dopts.codec, callHdr, t, args, topts) if err != nil { if _, ok := err.(transport.ConnectionError); ok { lastErr = err @@ -177,7 +180,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli return toRPCErr(err) } // Receive the response - lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply) + lastErr = recvResponse(conn.dopts.codec, t, &c, stream, reply) if _, ok := lastErr.(transport.ConnectionError); ok { continue } diff --git a/clientconn.go b/clientconn.go index 87f302fc..2b8e699d 100644 --- a/clientconn.go +++ b/clientconn.go @@ -73,6 +73,7 @@ var ( // values passed to Dial. type dialOptions struct { codec Codec + picker Picker block bool insecure bool copts transport.ConnectOptions @@ -142,88 +143,18 @@ func WithUserAgent(s string) DialOption { // Dial creates a client connection the given target. func Dial(target string, opts ...DialOption) (*ClientConn, error) { - if target == "" { - return nil, ErrUnspecTarget - } - cc := &ClientConn{ - target: target, - shutdownChan: make(chan struct{}), - } - if EnableTracing { - cc.events = trace.NewEventLog("grpc.ClientConn", target) - } + var dopts dialOptions for _, opt := range opts { - opt(&cc.dopts) + opt(&dopts) } - if !cc.dopts.insecure { - var ok bool - for _, c := range cc.dopts.copts.AuthOptions { - if _, ok := c.(credentials.TransportAuthenticator); !ok { - continue - } - ok = true - } - if !ok { - return nil, ErrNoTransportSecurity - } - } else { - for _, c := range cc.dopts.copts.AuthOptions { - if c.RequireTransportSecurity() { - return nil, ErrCredentialsMisuse - } - } - } - colonPos := strings.LastIndex(target, ":") - if colonPos == -1 { - colonPos = len(target) - } - cc.authority = target[:colonPos] - if cc.dopts.codec == nil { - // Set the default codec. - cc.dopts.codec = protoCodec{} - } - cc.stateCV = sync.NewCond(&cc.mu) - if cc.dopts.block { - if err := cc.resetTransport(false); err != nil { - cc.mu.Lock() - cc.errorf("dial failed: %v", err) - cc.mu.Unlock() - cc.Close() + if dopts.picker == nil { + p, err := newSimplePicker(target, dopts) + if err != nil { return nil, err } - // Start to monitor the error status of transport. - go cc.transportMonitor() - } else { - // Start a goroutine connecting to the server asynchronously. - go func() { - if err := cc.resetTransport(false); err != nil { - cc.mu.Lock() - cc.errorf("dial failed: %v", err) - cc.mu.Unlock() - grpclog.Printf("Failed to dial %s: %v; please retry.", target, err) - cc.Close() - return - } - go cc.transportMonitor() - }() - } - return cc, nil -} - -// printf records an event in cc's event log, unless cc has been closed. -// REQUIRES cc.mu is held. -func (cc *ClientConn) printf(format string, a ...interface{}) { - if cc.events != nil { - cc.events.Printf(format, a...) - } -} - -// errorf records an error in cc's event log, unless cc has been closed. -// REQUIRES cc.mu is held. -func (cc *ClientConn) errorf(format string, a ...interface{}) { - if cc.events != nil { - cc.events.Errorf(format, a...) + dopts.picker = p } + return &ClientConn{dopts.picker}, nil } // ConnectivityState indicates the state of a client connection. @@ -261,6 +192,36 @@ func (s ConnectivityState) String() string { // ClientConn represents a client connection to an RPC service. type ClientConn struct { + picker Picker +} + +// State returns the connectivity state of the Conn used for next upcoming RPC. +func (cc *ClientConn) State() ConnectivityState { + c := cc.picker.Peek() + if c == nil { + return Idle + } + return c.getState() +} + +// WaitForStateChange blocks until the state changes to something other than the sourceState +// or timeout fires on the Conn used for next upcoming RPC. It returns false if the Conn is nil +// or timeout fires, and true otherwise. +func (cc *ClientConn) WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool { + c := cc.picker.Peek() + if c == nil { + return false + } + return c.waitForStateChange(timeout, sourceState) +} + +// Close starts to tear down the ClientConn. +func (cc *ClientConn) Close() error { + return cc.picker.Close() +} + +// Conn is a client connection to a single destination. +type Conn struct { target string authority string dopts dialOptions @@ -276,16 +237,94 @@ type ClientConn struct { transport transport.ClientTransport } -// State returns the connectivity state of the ClientConn -func (cc *ClientConn) State() ConnectivityState { +// NewConn creates a Conn. +func NewConn(target string, dopts dialOptions) (*Conn, error) { + if target == "" { + return nil, ErrUnspecTarget + } + c := &Conn{ + target: target, + dopts: dopts, + shutdownChan: make(chan struct{}), + } + if EnableTracing { + c.events = trace.NewEventLog("grpc.ClientConn", target) + } + if !c.dopts.insecure { + var ok bool + for _, cd := range c.dopts.copts.AuthOptions { + if _, ok := cd.(credentials.TransportAuthenticator); !ok { + continue + } + ok = true + } + if !ok { + return nil, ErrNoTransportSecurity + } + } else { + for _, cd := range c.dopts.copts.AuthOptions { + if cd.RequireTransportSecurity() { + return nil, ErrCredentialsMisuse + } + } + } + colonPos := strings.LastIndex(target, ":") + if colonPos == -1 { + colonPos = len(target) + } + c.authority = target[:colonPos] + if c.dopts.codec == nil { + // Set the default codec. + c.dopts.codec = protoCodec{} + } + c.stateCV = sync.NewCond(&c.mu) + if c.dopts.block { + if err := c.resetTransport(false); err != nil { + c.Close() + return nil, err + } + // Start to monitor the error status of transport. + go c.transportMonitor() + } else { + // Start a goroutine connecting to the server asynchronously. + go func() { + if err := c.resetTransport(false); err != nil { + grpclog.Printf("Failed to dial %s: %v; please retry.", target, err) + c.Close() + return + } + go c.transportMonitor() + }() + } + return c, nil +} + +// printf records an event in cc's event log, unless cc has been closed. +// REQUIRES cc.mu is held. +func (cc *Conn) printf(format string, a ...interface{}) { + if cc.events != nil { + cc.events.Printf(format, a...) + } +} + +// errorf records an error in cc's event log, unless cc has been closed. +// REQUIRES cc.mu is held. +func (cc *Conn) errorf(format string, a ...interface{}) { + if cc.events != nil { + cc.events.Errorf(format, a...) + } +} + +// getState returns the connectivity state of the Conn +func (cc *Conn) getState() ConnectivityState { cc.mu.Lock() defer cc.mu.Unlock() return cc.state } -// WaitForStateChange blocks until the state changes to something other than the sourceState +// waitForStateChange blocks until the state changes to something other than the sourceState // or timeout fires. It returns false if timeout fires and true otherwise. -func (cc *ClientConn) WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool { +func (cc *Conn) waitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool { start := time.Now() cc.mu.Lock() defer cc.mu.Unlock() @@ -317,7 +356,7 @@ func (cc *ClientConn) WaitForStateChange(timeout time.Duration, sourceState Conn return true } -func (cc *ClientConn) resetTransport(closeTransport bool) error { +func (cc *Conn) resetTransport(closeTransport bool) error { var retries int start := time.Now() for { @@ -402,7 +441,7 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error { // Run in a goroutine to track the error in transport and create the // new transport if an error happens. It returns when the channel is closing. -func (cc *ClientConn) transportMonitor() { +func (cc *Conn) transportMonitor() { for { select { // shutdownChan is needed to detect the teardown when @@ -429,7 +468,7 @@ func (cc *ClientConn) transportMonitor() { // When wait returns, either the new transport is up or ClientConn is // closing. -func (cc *ClientConn) wait(ctx context.Context) (transport.ClientTransport, error) { +func (cc *Conn) wait(ctx context.Context) (transport.ClientTransport, error) { for { cc.mu.Lock() switch { @@ -456,12 +495,12 @@ func (cc *ClientConn) wait(ctx context.Context) (transport.ClientTransport, erro } } -// Close starts to tear down the ClientConn. Returns ErrClientConnClosing if +// Close starts to tear down the Conn. Returns ErrClientConnClosing if // it has been closed (mostly due to dial time-out). // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in // some edge cases (e.g., the caller opens and closes many ClientConn's in a // tight loop. -func (cc *ClientConn) Close() error { +func (cc *Conn) Close() error { cc.mu.Lock() defer cc.mu.Unlock() if cc.state == Shutdown { diff --git a/picker.go b/picker.go new file mode 100644 index 00000000..7c633471 --- /dev/null +++ b/picker.go @@ -0,0 +1,45 @@ +package grpc + +// Picker picks a Conn for RPC requests. +// This is EXPERIMENTAL and Please do not implement your own Picker for now. +type Picker interface { + // Pick returns the Conn to use for the upcoming RPC. It may return different + // Conn's up to the implementation. + Pick() (*Conn, error) + // Peek returns the Conn use use for the next upcoming RPC. It returns the same + // Conn until next time Pick gets invoked. + Peek() *Conn + // Close closes all the Conn's owned by this Picker. + Close() error +} + +func newSimplePicker(target string, dopts dialOptions) (Picker, error) { + c, err := NewConn(target, dopts) + if err != nil { + return nil, err + } + return &simplePicker{ + conn: c, + }, nil +} + +// simplePicker is default Picker which is used when there is no custom Picker +// specified by users. It always picks the same Conn. +type simplePicker struct { + conn *Conn +} + +func (p *simplePicker) Pick() (*Conn, error) { + return p.conn, nil +} + +func (p *simplePicker) Peek() *Conn { + return p.conn +} + +func (p *simplePicker) Close() error { + if p.conn != nil { + return p.conn.Close() + } + return nil +} diff --git a/stream.go b/stream.go index e14664cb..d66e1a41 100644 --- a/stream.go +++ b/stream.go @@ -96,14 +96,18 @@ type ClientStream interface { // NewClientStream creates a new Stream for the client side. This is called // by generated code. func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { + conn, err := cc.picker.Pick() + if err != nil { + return nil, toRPCErr(err) + } // TODO(zhaoq): CallOption is omitted. Add support when it is needed. callHdr := &transport.CallHdr{ - Host: cc.authority, + Host: conn.authority, Method: method, } cs := &clientStream{ desc: desc, - codec: cc.dopts.codec, + codec: conn.dopts.codec, tracing: EnableTracing, } if cs.tracing { @@ -114,7 +118,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } cs.traceInfo.tr.LazyLog(&cs.traceInfo.firstLine, false) } - t, err := cc.wait(ctx) + t, err := conn.wait(ctx) if err != nil { return nil, toRPCErr(err) }