From 9afcd0c6977d24efce56e0102c0e050e1cec554b Mon Sep 17 00:00:00 2001
From: iamqizhao <toqizhao@gmail.com>
Date: Wed, 23 Sep 2015 19:09:37 -0700
Subject: [PATCH] preliminary refactoring for custom naming and load balancing

---
 call.go       |  13 ++--
 clientconn.go | 211 ++++++++++++++++++++++++++++++--------------------
 picker.go     |  45 +++++++++++
 stream.go     |  10 ++-
 4 files changed, 185 insertions(+), 94 deletions(-)
 create mode 100644 picker.go

diff --git a/call.go b/call.go
index 0115a28d..7ca088a8 100644
--- a/call.go
+++ b/call.go
@@ -116,7 +116,10 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 			o.after(&c)
 		}
 	}()
-
+	conn, err := cc.picker.Pick()
+	if err != nil {
+		return toRPCErr(err)
+	}
 	if EnableTracing {
 		c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method)
 		defer c.traceInfo.tr.Finish()
@@ -134,7 +137,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 		}()
 	}
 	callHdr := &transport.CallHdr{
-		Host:   cc.authority,
+		Host:   conn.authority,
 		Method: method,
 	}
 	topts := &transport.Options{
@@ -154,7 +157,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 		if lastErr != nil && c.failFast {
 			return toRPCErr(lastErr)
 		}
-		t, err = cc.wait(ctx)
+		t, err = conn.wait(ctx)
 		if err != nil {
 			if lastErr != nil {
 				// This was a retry; return the error from the last attempt.
@@ -165,7 +168,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 		if c.traceInfo.tr != nil {
 			c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
 		}
-		stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts)
+		stream, err = sendRequest(ctx, conn.dopts.codec, callHdr, t, args, topts)
 		if err != nil {
 			if _, ok := err.(transport.ConnectionError); ok {
 				lastErr = err
@@ -177,7 +180,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
 			return toRPCErr(err)
 		}
 		// Receive the response
-		lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply)
+		lastErr = recvResponse(conn.dopts.codec, t, &c, stream, reply)
 		if _, ok := lastErr.(transport.ConnectionError); ok {
 			continue
 		}
diff --git a/clientconn.go b/clientconn.go
index 87f302fc..2b8e699d 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -73,6 +73,7 @@ var (
 // values passed to Dial.
 type dialOptions struct {
 	codec    Codec
+	picker   Picker
 	block    bool
 	insecure bool
 	copts    transport.ConnectOptions
@@ -142,88 +143,18 @@ func WithUserAgent(s string) DialOption {
 
 // Dial creates a client connection the given target.
 func Dial(target string, opts ...DialOption) (*ClientConn, error) {
-	if target == "" {
-		return nil, ErrUnspecTarget
-	}
-	cc := &ClientConn{
-		target:       target,
-		shutdownChan: make(chan struct{}),
-	}
-	if EnableTracing {
-		cc.events = trace.NewEventLog("grpc.ClientConn", target)
-	}
+	var dopts dialOptions
 	for _, opt := range opts {
-		opt(&cc.dopts)
+		opt(&dopts)
 	}
-	if !cc.dopts.insecure {
-		var ok bool
-		for _, c := range cc.dopts.copts.AuthOptions {
-			if _, ok := c.(credentials.TransportAuthenticator); !ok {
-				continue
-			}
-			ok = true
-		}
-		if !ok {
-			return nil, ErrNoTransportSecurity
-		}
-	} else {
-		for _, c := range cc.dopts.copts.AuthOptions {
-			if c.RequireTransportSecurity() {
-				return nil, ErrCredentialsMisuse
-			}
-		}
-	}
-	colonPos := strings.LastIndex(target, ":")
-	if colonPos == -1 {
-		colonPos = len(target)
-	}
-	cc.authority = target[:colonPos]
-	if cc.dopts.codec == nil {
-		// Set the default codec.
-		cc.dopts.codec = protoCodec{}
-	}
-	cc.stateCV = sync.NewCond(&cc.mu)
-	if cc.dopts.block {
-		if err := cc.resetTransport(false); err != nil {
-			cc.mu.Lock()
-			cc.errorf("dial failed: %v", err)
-			cc.mu.Unlock()
-			cc.Close()
+	if dopts.picker == nil {
+		p, err := newSimplePicker(target, dopts)
+		if err != nil {
 			return nil, err
 		}
-		// Start to monitor the error status of transport.
-		go cc.transportMonitor()
-	} else {
-		// Start a goroutine connecting to the server asynchronously.
-		go func() {
-			if err := cc.resetTransport(false); err != nil {
-				cc.mu.Lock()
-				cc.errorf("dial failed: %v", err)
-				cc.mu.Unlock()
-				grpclog.Printf("Failed to dial %s: %v; please retry.", target, err)
-				cc.Close()
-				return
-			}
-			go cc.transportMonitor()
-		}()
-	}
-	return cc, nil
-}
-
-// printf records an event in cc's event log, unless cc has been closed.
-// REQUIRES cc.mu is held.
-func (cc *ClientConn) printf(format string, a ...interface{}) {
-	if cc.events != nil {
-		cc.events.Printf(format, a...)
-	}
-}
-
-// errorf records an error in cc's event log, unless cc has been closed.
-// REQUIRES cc.mu is held.
-func (cc *ClientConn) errorf(format string, a ...interface{}) {
-	if cc.events != nil {
-		cc.events.Errorf(format, a...)
+		dopts.picker = p
 	}
+	return &ClientConn{dopts.picker}, nil
 }
 
 // ConnectivityState indicates the state of a client connection.
@@ -261,6 +192,36 @@ func (s ConnectivityState) String() string {
 
 // ClientConn represents a client connection to an RPC service.
 type ClientConn struct {
+	picker Picker
+}
+
+// State returns the connectivity state of the Conn used for next upcoming RPC.
+func (cc *ClientConn) State() ConnectivityState {
+	c := cc.picker.Peek()
+	if c == nil {
+		return Idle
+	}
+	return c.getState()
+}
+
+// WaitForStateChange blocks until the state changes to something other than the sourceState
+// or timeout fires on the Conn used for next upcoming RPC. It returns false if the Conn is nil
+// or timeout fires, and true otherwise.
+func (cc *ClientConn) WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool {
+	c := cc.picker.Peek()
+	if c == nil {
+		return false
+	}
+	return c.waitForStateChange(timeout, sourceState)
+}
+
+// Close starts to tear down the ClientConn.
+func (cc *ClientConn) Close() error {
+	return cc.picker.Close()
+}
+
+// Conn is a client connection to a single destination.
+type Conn struct {
 	target       string
 	authority    string
 	dopts        dialOptions
@@ -276,16 +237,94 @@ type ClientConn struct {
 	transport transport.ClientTransport
 }
 
-// State returns the connectivity state of the ClientConn
-func (cc *ClientConn) State() ConnectivityState {
+// NewConn creates a Conn.
+func NewConn(target string, dopts dialOptions) (*Conn, error) {
+	if target == "" {
+		return nil, ErrUnspecTarget
+	}
+	c := &Conn{
+		target:       target,
+		dopts:        dopts,
+		shutdownChan: make(chan struct{}),
+	}
+	if EnableTracing {
+		c.events = trace.NewEventLog("grpc.ClientConn", target)
+	}
+	if !c.dopts.insecure {
+		var ok bool
+		for _, cd := range c.dopts.copts.AuthOptions {
+			if _, ok := cd.(credentials.TransportAuthenticator); !ok {
+				continue
+			}
+			ok = true
+		}
+		if !ok {
+			return nil, ErrNoTransportSecurity
+		}
+	} else {
+		for _, cd := range c.dopts.copts.AuthOptions {
+			if cd.RequireTransportSecurity() {
+				return nil, ErrCredentialsMisuse
+			}
+		}
+	}
+	colonPos := strings.LastIndex(target, ":")
+	if colonPos == -1 {
+		colonPos = len(target)
+	}
+	c.authority = target[:colonPos]
+	if c.dopts.codec == nil {
+		// Set the default codec.
+		c.dopts.codec = protoCodec{}
+	}
+	c.stateCV = sync.NewCond(&c.mu)
+	if c.dopts.block {
+		if err := c.resetTransport(false); err != nil {
+			c.Close()
+			return nil, err
+		}
+		// Start to monitor the error status of transport.
+		go c.transportMonitor()
+	} else {
+		// Start a goroutine connecting to the server asynchronously.
+		go func() {
+			if err := c.resetTransport(false); err != nil {
+				grpclog.Printf("Failed to dial %s: %v; please retry.", target, err)
+				c.Close()
+				return
+			}
+			go c.transportMonitor()
+		}()
+	}
+	return c, nil
+}
+
+// printf records an event in cc's event log, unless cc has been closed.
+// REQUIRES cc.mu is held.
+func (cc *Conn) printf(format string, a ...interface{}) {
+	if cc.events != nil {
+		cc.events.Printf(format, a...)
+	}
+}
+
+// errorf records an error in cc's event log, unless cc has been closed.
+// REQUIRES cc.mu is held.
+func (cc *Conn) errorf(format string, a ...interface{}) {
+	if cc.events != nil {
+		cc.events.Errorf(format, a...)
+	}
+}
+
+// getState returns the connectivity state of the Conn
+func (cc *Conn) getState() ConnectivityState {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	return cc.state
 }
 
-// WaitForStateChange blocks until the state changes to something other than the sourceState
+// waitForStateChange blocks until the state changes to something other than the sourceState
 // or timeout fires. It returns false if timeout fires and true otherwise.
-func (cc *ClientConn) WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool {
+func (cc *Conn) waitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool {
 	start := time.Now()
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
@@ -317,7 +356,7 @@ func (cc *ClientConn) WaitForStateChange(timeout time.Duration, sourceState Conn
 	return true
 }
 
-func (cc *ClientConn) resetTransport(closeTransport bool) error {
+func (cc *Conn) resetTransport(closeTransport bool) error {
 	var retries int
 	start := time.Now()
 	for {
@@ -402,7 +441,7 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error {
 
 // Run in a goroutine to track the error in transport and create the
 // new transport if an error happens. It returns when the channel is closing.
-func (cc *ClientConn) transportMonitor() {
+func (cc *Conn) transportMonitor() {
 	for {
 		select {
 		// shutdownChan is needed to detect the teardown when
@@ -429,7 +468,7 @@ func (cc *ClientConn) transportMonitor() {
 
 // When wait returns, either the new transport is up or ClientConn is
 // closing.
-func (cc *ClientConn) wait(ctx context.Context) (transport.ClientTransport, error) {
+func (cc *Conn) wait(ctx context.Context) (transport.ClientTransport, error) {
 	for {
 		cc.mu.Lock()
 		switch {
@@ -456,12 +495,12 @@ func (cc *ClientConn) wait(ctx context.Context) (transport.ClientTransport, erro
 	}
 }
 
-// Close starts to tear down the ClientConn. Returns ErrClientConnClosing if
+// Close starts to tear down the Conn. Returns ErrClientConnClosing if
 // it has been closed (mostly due to dial time-out).
 // TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in
 // some edge cases (e.g., the caller opens and closes many ClientConn's in a
 // tight loop.
-func (cc *ClientConn) Close() error {
+func (cc *Conn) Close() error {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	if cc.state == Shutdown {
diff --git a/picker.go b/picker.go
new file mode 100644
index 00000000..7c633471
--- /dev/null
+++ b/picker.go
@@ -0,0 +1,45 @@
+package grpc
+
+// Picker picks a Conn for RPC requests.
+// This is EXPERIMENTAL and Please do not implement your own Picker for now.
+type Picker interface {
+	// Pick returns the Conn to use for the upcoming RPC. It may return different
+	// Conn's up to the implementation.
+	Pick() (*Conn, error)
+	// Peek returns the Conn use use for the next upcoming RPC. It returns the same
+	// Conn until next time Pick gets invoked.
+	Peek() *Conn
+	// Close closes all the Conn's owned by this Picker.
+	Close() error
+}
+
+func newSimplePicker(target string, dopts dialOptions) (Picker, error) {
+	c, err := NewConn(target, dopts)
+	if err != nil {
+		return nil, err
+	}
+	return &simplePicker{
+		conn: c,
+	}, nil
+}
+
+// simplePicker is default Picker which is used when there is no custom Picker
+// specified by users. It always picks the same Conn.
+type simplePicker struct {
+	conn *Conn
+}
+
+func (p *simplePicker) Pick() (*Conn, error) {
+	return p.conn, nil
+}
+
+func (p *simplePicker) Peek() *Conn {
+	return p.conn
+}
+
+func (p *simplePicker) Close() error {
+	if p.conn != nil {
+		return p.conn.Close()
+	}
+	return nil
+}
diff --git a/stream.go b/stream.go
index e14664cb..d66e1a41 100644
--- a/stream.go
+++ b/stream.go
@@ -96,14 +96,18 @@ type ClientStream interface {
 // NewClientStream creates a new Stream for the client side. This is called
 // by generated code.
 func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
+	conn, err := cc.picker.Pick()
+	if err != nil {
+		return nil, toRPCErr(err)
+	}
 	// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
 	callHdr := &transport.CallHdr{
-		Host:   cc.authority,
+		Host:   conn.authority,
 		Method: method,
 	}
 	cs := &clientStream{
 		desc:    desc,
-		codec:   cc.dopts.codec,
+		codec:   conn.dopts.codec,
 		tracing: EnableTracing,
 	}
 	if cs.tracing {
@@ -114,7 +118,7 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 		}
 		cs.traceInfo.tr.LazyLog(&cs.traceInfo.firstLine, false)
 	}
-	t, err := cc.wait(ctx)
+	t, err := conn.wait(ctx)
 	if err != nil {
 		return nil, toRPCErr(err)
 	}