From c01ea6e3591de0f3055e8119f469427ec35cbfde Mon Sep 17 00:00:00 2001
From: iamqizhao <toqizhao@gmail.com>
Date: Tue, 29 Sep 2015 10:24:03 -0700
Subject: [PATCH] revise Picker API

---
 call.go       |  2 +-
 clientconn.go | 48 +++++++++++++++++++++++++-----------------------
 picker.go     | 21 +++++++++++----------
 stream.go     |  2 +-
 4 files changed, 38 insertions(+), 35 deletions(-)

diff --git a/call.go b/call.go
index 5b64c243..8b688091 100644
--- a/call.go
+++ b/call.go
@@ -150,7 +150,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 		if lastErr != nil && c.failFast {
 			return toRPCErr(lastErr)
 		}
-		conn, err = cc.picker.Pick()
+		conn, err = cc.dopts.picker.Pick()
 		if err != nil {
 			return toRPCErr(err)
 		}
diff --git a/clientconn.go b/clientconn.go
index 7186a23d..ea3ccd0b 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -66,7 +66,7 @@ var (
 	// established or re-established within the specified timeout.
 	ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
 	// ErrTransientFailure indicates the connection failed due to a transient error.
-	ErrTransientFailure = errors.New("transient connection failure")
+	ErrTransientFailure = errors.New("grpc: transient connection failure")
 	// minimum time to give a connection to complete
 	minConnectTimeout = 20 * time.Second
 )
@@ -145,18 +145,19 @@ func WithUserAgent(s string) DialOption {
 
 // Dial creates a client connection the given target.
 func Dial(target string, opts ...DialOption) (*ClientConn, error) {
-	var dopts dialOptions
+	cc := &ClientConn{
+		target: target,
+	}
 	for _, opt := range opts {
-		opt(&dopts)
+		opt(&cc.dopts)
 	}
-	if dopts.picker == nil {
-		p, err := newUnicastPicker(target, dopts)
-		if err != nil {
-			return nil, err
-		}
-		dopts.picker = p
+	if cc.dopts.picker == nil {
+		cc.dopts.picker = &unicastPicker{}
 	}
-	return &ClientConn{dopts.picker}, nil
+	if err := cc.dopts.picker.Init(cc); err != nil {
+		return nil, err
+	}
+	return cc, nil
 }
 
 // ConnectivityState indicates the state of a client connection.
@@ -194,25 +195,26 @@ func (s ConnectivityState) String() string {
 
 // ClientConn represents a client connection to an RPC service.
 type ClientConn struct {
-	picker Picker
+	target string
+	dopts  dialOptions
 }
 
 // State returns the connectivity state of cc.
 // This is EXPERIMENTAL API.
 func (cc *ClientConn) State() ConnectivityState {
-	return cc.picker.State()
+	return cc.dopts.picker.State()
 }
 
 // WaitForStateChange blocks until the state changes to something other than the sourceState
 // or timeout fires on cc. It returns false if timeout fires, and true otherwise.
 // This is EXPERIMENTAL API.
 func (cc *ClientConn) WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool {
-	return cc.picker.WaitForStateChange(timeout, sourceState)
+	return cc.dopts.picker.WaitForStateChange(timeout, sourceState)
 }
 
 // Close starts to tear down the ClientConn.
 func (cc *ClientConn) Close() error {
-	return cc.picker.Close()
+	return cc.dopts.picker.Close()
 }
 
 // Conn is a client connection to a single destination.
@@ -233,17 +235,17 @@ type Conn struct {
 }
 
 // NewConn creates a Conn.
-func NewConn(target string, dopts dialOptions) (*Conn, error) {
-	if target == "" {
+func NewConn(cc *ClientConn) (*Conn, error) {
+	if cc.target == "" {
 		return nil, ErrUnspecTarget
 	}
 	c := &Conn{
-		target:       target,
-		dopts:        dopts,
+		target:       cc.target,
+		dopts:        cc.dopts,
 		shutdownChan: make(chan struct{}),
 	}
 	if EnableTracing {
-		c.events = trace.NewEventLog("grpc.ClientConn", target)
+		c.events = trace.NewEventLog("grpc.ClientConn", c.target)
 	}
 	if !c.dopts.insecure {
 		var ok bool
@@ -263,11 +265,11 @@ func NewConn(target string, dopts dialOptions) (*Conn, error) {
 			}
 		}
 	}
-	colonPos := strings.LastIndex(target, ":")
+	colonPos := strings.LastIndex(c.target, ":")
 	if colonPos == -1 {
-		colonPos = len(target)
+		colonPos = len(c.target)
 	}
-	c.authority = target[:colonPos]
+	c.authority = c.target[:colonPos]
 	if c.dopts.codec == nil {
 		// Set the default codec.
 		c.dopts.codec = protoCodec{}
@@ -284,7 +286,7 @@ func NewConn(target string, dopts dialOptions) (*Conn, error) {
 		// Start a goroutine connecting to the server asynchronously.
 		go func() {
 			if err := c.resetTransport(false); err != nil {
-				grpclog.Printf("Failed to dial %s: %v; please retry.", target, err)
+				grpclog.Printf("Failed to dial %s: %v; please retry.", c.target, err)
 				c.Close()
 				return
 			}
diff --git a/picker.go b/picker.go
index 05759f3a..79f98868 100644
--- a/picker.go
+++ b/picker.go
@@ -40,6 +40,8 @@ import (
 // Picker picks a Conn for RPC requests.
 // This is EXPERIMENTAL and Please do not implement your own Picker for now.
 type Picker interface {
+	// Init does initial processing for the Picker, e.g., initiate some connections.
+	Init(cc *ClientConn) error
 	// Pick returns the Conn to use for the upcoming RPC. It may return different
 	// Conn's up to the implementation.
 	Pick() (*Conn, error)
@@ -53,22 +55,21 @@ type Picker interface {
 	Close() error
 }
 
-func newUnicastPicker(target string, dopts dialOptions) (Picker, error) {
-	c, err := NewConn(target, dopts)
-	if err != nil {
-		return nil, err
-	}
-	return &unicastPicker{
-		conn: c,
-	}, nil
-}
-
 // unicastPicker is the default Picker which is used when there is no custom Picker
 // specified by users. It always picks the same Conn.
 type unicastPicker struct {
 	conn *Conn
 }
 
+func (p *unicastPicker) Init(cc *ClientConn) error {
+	c, err := NewConn(cc)
+	if err != nil {
+		return err
+	}
+	p.conn = c
+	return nil
+}
+
 func (p *unicastPicker) Pick() (*Conn, error) {
 	return p.conn, nil
 }
diff --git a/stream.go b/stream.go
index a9d7c49c..605b873d 100644
--- a/stream.go
+++ b/stream.go
@@ -102,7 +102,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 		err  error
 	)
 	for {
-		conn, err = cc.picker.Pick()
+		conn, err = cc.dopts.picker.Pick()
 		if err != nil {
 			return nil, toRPCErr(err)
 		}