merge the conflict

This commit is contained in:
iamqizhao
2016-07-28 13:17:23 -07:00
24 changed files with 328 additions and 202 deletions

View File

@ -1,17 +1,21 @@
language: go language: go
go: go:
- 1.5.3 - 1.5.4
- 1.6 - 1.6.3
go_import_path: google.golang.org/grpc
before_install: before_install:
- go get golang.org/x/tools/cmd/goimports
- go get github.com/golang/lint/golint
- go get github.com/axw/gocov/gocov - go get github.com/axw/gocov/gocov
- go get github.com/mattn/goveralls - go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover - go get golang.org/x/tools/cmd/cover
install:
- mkdir -p "$GOPATH/src/google.golang.org"
- mv "$TRAVIS_BUILD_DIR" "$GOPATH/src/google.golang.org/grpc"
script: script:
- '! gofmt -s -d -l . 2>&1 | read'
- '! goimports -l . | read'
- '! golint ./... | grep -vE "(_string|\.pb)\.go:"'
- '! go tool vet -all . 2>&1 | grep -vE "constant [0-9]+ not a string in call to Errorf"'
- make test testrace - make test testrace

View File

@ -243,7 +243,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
} }
// Remove the server. // Remove the server.
updates := []*naming.Update{&naming.Update{ updates := []*naming.Update{{
Op: naming.Delete, Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port, Addr: "127.0.0.1:" + servers[0].port,
}} }}
@ -287,7 +287,7 @@ func TestGetOnWaitChannel(t *testing.T) {
t.Fatalf("Failed to create ClientConn: %v", err) t.Fatalf("Failed to create ClientConn: %v", err)
} }
// Remove all servers so that all upcoming RPCs will block on waitCh. // Remove all servers so that all upcoming RPCs will block on waitCh.
updates := []*naming.Update{&naming.Update{ updates := []*naming.Update{{
Op: naming.Delete, Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port, Addr: "127.0.0.1:" + servers[0].port,
}} }}
@ -310,7 +310,7 @@ func TestGetOnWaitChannel(t *testing.T) {
} }
}() }()
// Add a connected server to get the above RPC through. // Add a connected server to get the above RPC through.
updates = []*naming.Update{&naming.Update{ updates = []*naming.Update{{
Op: naming.Add, Op: naming.Add,
Addr: "127.0.0.1:" + servers[0].port, Addr: "127.0.0.1:" + servers[0].port,
}} }}

View File

@ -58,7 +58,7 @@ func closeLoopUnary() {
for i := 0; i < *maxConcurrentRPCs; i++ { for i := 0; i < *maxConcurrentRPCs; i++ {
go func() { go func() {
for _ = range ch { for range ch {
start := time.Now() start := time.Now()
unaryCaller(tc) unaryCaller(tc)
elapse := time.Since(start) elapse := time.Since(start)

View File

@ -196,7 +196,7 @@ func WithTimeout(d time.Duration) DialOption {
} }
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses. // WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption { func WithDialer(f func(string, time.Duration, <-chan struct{}) (net.Conn, error)) DialOption {
return func(o *dialOptions) { return func(o *dialOptions) {
o.copts.Dialer = f o.copts.Dialer = f
} }
@ -364,8 +364,8 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
cc: cc, cc: cc,
addr: addr, addr: addr,
dopts: cc.dopts, dopts: cc.dopts,
shutdownChan: make(chan struct{}),
} }
ac.dopts.copts.Cancel = make(chan struct{})
if EnableTracing { if EnableTracing {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr) ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
} }
@ -471,7 +471,6 @@ type addrConn struct {
cc *ClientConn cc *ClientConn
addr Address addr Address
dopts dialOptions dopts dialOptions
shutdownChan chan struct{}
events trace.EventLog events trace.EventLog
mu sync.Mutex mu sync.Mutex
@ -558,12 +557,13 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
t.Close() t.Close()
} }
sleepTime := ac.dopts.bs.backoff(retries) sleepTime := ac.dopts.bs.backoff(retries)
ac.dopts.copts.Timeout = sleepTime copts := ac.dopts.copts
copts.Timeout = sleepTime
if sleepTime < minConnectTimeout { if sleepTime < minConnectTimeout {
ac.dopts.copts.Timeout = minConnectTimeout copts.Timeout = minConnectTimeout
} }
connectTime := time.Now() connectTime := time.Now()
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts) newTransport, err := transport.NewClientTransport(ac.addr.Addr, copts)
if err != nil { if err != nil {
ac.mu.Lock() ac.mu.Lock()
if ac.state == Shutdown { if ac.state == Shutdown {
@ -586,7 +586,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
closeTransport = false closeTransport = false
select { select {
case <-time.After(sleepTime): case <-time.After(sleepTime):
case <-ac.shutdownChan: case <-ac.dopts.copts.Cancel:
} }
retries++ retries++
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr) grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
@ -621,9 +621,9 @@ func (ac *addrConn) transportMonitor() {
t := ac.transport t := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
select { select {
// shutdownChan is needed to detect the teardown when // Cancel is needed to detect the teardown when
// the addrConn is idle (i.e., no RPC in flight). // the addrConn is idle (i.e., no RPC in flight).
case <-ac.shutdownChan: case <-ac.dopts.copts.Cancel:
select { select {
case <-t.Error(): case <-t.Error():
t.Close() t.Close()
@ -647,7 +647,7 @@ func (ac *addrConn) transportMonitor() {
return return
case <-t.Error(): case <-t.Error():
select { select {
case <-ac.shutdownChan: case <-ac.dopts.copts.Cancel:
t.Close() t.Close()
return return
case <-t.GoAway(): case <-t.GoAway():
@ -750,8 +750,8 @@ func (ac *addrConn) tearDown(err error) {
if ac.transport != nil && err != errConnDrain { if ac.transport != nil && err != errConnDrain {
ac.transport.Close() ac.transport.Close()
} }
if ac.shutdownChan != nil { if ac.dopts.copts.Cancel != nil {
close(ac.shutdownChan) close(ac.dopts.copts.Cancel)
} }
return return
} }

View File

@ -153,7 +153,7 @@ func runRouteChat(client pb.RouteGuideClient) {
func randomPoint(r *rand.Rand) *pb.Point { func randomPoint(r *rand.Rand) *pb.Point {
lat := (r.Int31n(180) - 90) * 1e7 lat := (r.Int31n(180) - 90) * 1e7
long := (r.Int31n(360) - 180) * 1e7 long := (r.Int31n(360) - 180) * 1e7
return &pb.Point{lat, long} return &pb.Point{Latitude: lat, Longitude: long}
} }
func main() { func main() {
@ -186,13 +186,16 @@ func main() {
client := pb.NewRouteGuideClient(conn) client := pb.NewRouteGuideClient(conn)
// Looking for a valid feature // Looking for a valid feature
printFeature(client, &pb.Point{409146138, -746188906}) printFeature(client, &pb.Point{Latitude: 409146138, Longitude: -746188906})
// Feature missing. // Feature missing.
printFeature(client, &pb.Point{0, 0}) printFeature(client, &pb.Point{Latitude: 0, Longitude: 0})
// Looking for features between 40, -75 and 42, -73. // Looking for features between 40, -75 and 42, -73.
printFeatures(client, &pb.Rectangle{&pb.Point{Latitude: 400000000, Longitude: -750000000}, &pb.Point{Latitude: 420000000, Longitude: -730000000}}) printFeatures(client, &pb.Rectangle{
Lo: &pb.Point{Latitude: 400000000, Longitude: -750000000},
Hi: &pb.Point{Latitude: 420000000, Longitude: -730000000},
})
// RecordRoute // RecordRoute
runRecordRoute(client) runRecordRoute(client)

View File

@ -79,7 +79,7 @@ func (s *routeGuideServer) GetFeature(ctx context.Context, point *pb.Point) (*pb
} }
} }
// No feature was found, return an unnamed feature // No feature was found, return an unnamed feature
return &pb.Feature{"", point}, nil return &pb.Feature{Location: point}, nil
} }
// ListFeatures lists all features contained within the given bounding Rectangle. // ListFeatures lists all features contained within the given bounding Rectangle.

View File

@ -11,19 +11,22 @@ import (
healthpb "google.golang.org/grpc/health/grpc_health_v1" healthpb "google.golang.org/grpc/health/grpc_health_v1"
) )
type HealthServer struct { // Server implements `service Health`.
type Server struct {
mu sync.Mutex mu sync.Mutex
// statusMap stores the serving status of the services this HealthServer monitors. // statusMap stores the serving status of the services this Server monitors.
statusMap map[string]healthpb.HealthCheckResponse_ServingStatus statusMap map[string]healthpb.HealthCheckResponse_ServingStatus
} }
func NewHealthServer() *HealthServer { // NewServer returns a new Server.
return &HealthServer{ func NewServer() *Server {
return &Server{
statusMap: make(map[string]healthpb.HealthCheckResponse_ServingStatus), statusMap: make(map[string]healthpb.HealthCheckResponse_ServingStatus),
} }
} }
func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { // Check implements `service Health`.
func (s *Server) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if in.Service == "" { if in.Service == "" {
@ -42,7 +45,7 @@ func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckReques
// SetServingStatus is called when need to reset the serving status of a service // SetServingStatus is called when need to reset the serving status of a service
// or insert a new service entry into the statusMap. // or insert a new service entry into the statusMap.
func (s *HealthServer) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) { func (s *Server) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) {
s.mu.Lock() s.mu.Lock()
s.statusMap[service] = status s.statusMap[service] = status
s.mu.Unlock() s.mu.Unlock()

View File

@ -60,15 +60,21 @@ func encodeKeyValue(k, v string) (string, string) {
// DecodeKeyValue returns the original key and value corresponding to the // DecodeKeyValue returns the original key and value corresponding to the
// encoded data in k, v. // encoded data in k, v.
// If k is a binary header and v contains comma, v is split on comma before decoded,
// and the decoded v will be joined with comma before returned.
func DecodeKeyValue(k, v string) (string, string, error) { func DecodeKeyValue(k, v string) (string, string, error) {
if !strings.HasSuffix(k, binHdrSuffix) { if !strings.HasSuffix(k, binHdrSuffix) {
return k, v, nil return k, v, nil
} }
val, err := base64.StdEncoding.DecodeString(v) vvs := strings.Split(v, ",")
for i, vv := range vvs {
val, err := base64.StdEncoding.DecodeString(vv)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
return k, string(val), nil vvs[i] = string(val)
}
return k, strings.Join(vvs, ","), nil
} }
// MD is a mapping from metadata keys to values. Users should use the following // MD is a mapping from metadata keys to values. Users should use the following

View File

@ -74,6 +74,8 @@ func TestDecodeKeyValue(t *testing.T) {
{"a", "abc", "a", "abc", nil}, {"a", "abc", "a", "abc", nil},
{"key-bin", "Zm9vAGJhcg==", "key-bin", "foo\x00bar", nil}, {"key-bin", "Zm9vAGJhcg==", "key-bin", "foo\x00bar", nil},
{"key-bin", "woA=", "key-bin", binaryValue, nil}, {"key-bin", "woA=", "key-bin", binaryValue, nil},
{"a", "abc,efg", "a", "abc,efg", nil},
{"key-bin", "Zm9vAGJhcg==,Zm9vAGJhcg==", "key-bin", "foo\x00bar,foo\x00bar", nil},
} { } {
k, v, err := DecodeKeyValue(test.kin, test.vin) k, v, err := DecodeKeyValue(test.kin, test.vin)
if k != test.kout || !reflect.DeepEqual(v, test.vout) || !reflect.DeepEqual(err, test.err) { if k != test.kout || !reflect.DeepEqual(v, test.vout) || !reflect.DeepEqual(err, test.err) {

View File

@ -70,7 +70,7 @@ import (
type serverReflectionServer struct { type serverReflectionServer struct {
s *grpc.Server s *grpc.Server
// TODO add more cache if necessary // TODO add more cache if necessary
serviceInfo map[string]*grpc.ServiceInfo // cache for s.GetServiceInfo() serviceInfo map[string]grpc.ServiceInfo // cache for s.GetServiceInfo()
} }
// Register registers the server reflection service on the given gRPC server. // Register registers the server reflection service on the given gRPC server.

View File

@ -184,8 +184,8 @@ func TestContextErr(t *testing.T) {
func TestErrorsWithSameParameters(t *testing.T) { func TestErrorsWithSameParameters(t *testing.T) {
const description = "some description" const description = "some description"
e1 := Errorf(codes.AlreadyExists, description) e1 := Errorf(codes.AlreadyExists, description).(*rpcError)
e2 := Errorf(codes.AlreadyExists, description) e2 := Errorf(codes.AlreadyExists, description).(*rpcError)
if e1 == e2 { if e1 == e2 {
t.Fatalf("Error interfaces should not be considered equal - e1: %p - %v e2: %p - %v", e1, e1, e2, e2) t.Fatalf("Error interfaces should not be considered equal - e1: %p - %v e2: %p - %v", e1, e1, e2, e2)
} }

View File

@ -269,8 +269,8 @@ type ServiceInfo struct {
// GetServiceInfo returns a map from service names to ServiceInfo. // GetServiceInfo returns a map from service names to ServiceInfo.
// Service names include the package names, in the form of <package>.<service>. // Service names include the package names, in the form of <package>.<service>.
func (s *Server) GetServiceInfo() map[string]*ServiceInfo { func (s *Server) GetServiceInfo() map[string]ServiceInfo {
ret := make(map[string]*ServiceInfo) ret := make(map[string]ServiceInfo)
for n, srv := range s.m { for n, srv := range s.m {
methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd)) methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
for m := range srv.md { for m := range srv.md {
@ -288,7 +288,7 @@ func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
}) })
} }
ret[n] = &ServiceInfo{ ret[n] = ServiceInfo{
Methods: methods, Methods: methods,
Metadata: srv.mdata, Metadata: srv.mdata,
} }

View File

@ -90,15 +90,15 @@ func TestGetServiceInfo(t *testing.T) {
server.RegisterService(&testSd, &testServer{}) server.RegisterService(&testSd, &testServer{})
info := server.GetServiceInfo() info := server.GetServiceInfo()
want := map[string]*ServiceInfo{ want := map[string]ServiceInfo{
"grpc.testing.EmptyService": &ServiceInfo{ "grpc.testing.EmptyService": {
Methods: []MethodInfo{ Methods: []MethodInfo{
MethodInfo{ {
Name: "EmptyCall", Name: "EmptyCall",
IsClientStream: false, IsClientStream: false,
IsServerStream: false, IsServerStream: false,
}, },
MethodInfo{ {
Name: "EmptyStream", Name: "EmptyStream",
IsClientStream: true, IsClientStream: true,
IsServerStream: false, IsServerStream: false,
@ -108,6 +108,6 @@ func TestGetServiceInfo(t *testing.T) {
} }
if !reflect.DeepEqual(info, want) { if !reflect.DeepEqual(info, want) {
t.Errorf("GetServiceInfo() = %q, want %q", info, want) t.Errorf("GetServiceInfo() = %+v, want %+v", info, want)
} }
} }

View File

@ -300,39 +300,35 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ
const tlsDir = "testdata/" const tlsDir = "testdata/"
func unixDialer(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout)
}
type env struct { type env struct {
name string name string
network string // The type of network such as tcp, unix, etc. network string // The type of network such as tcp, unix, etc.
dialer func(addr string, timeout time.Duration) (net.Conn, error)
security string // The security protocol such as TLS, SSH, etc. security string // The security protocol such as TLS, SSH, etc.
httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS
} }
func (e env) runnable() bool { func (e env) runnable() bool {
if runtime.GOOS == "windows" && strings.HasPrefix(e.name, "unix-") { if runtime.GOOS == "windows" && e.network == "unix" {
return false return false
} }
return true return true
} }
func (e env) getDialer() func(addr string, timeout time.Duration) (net.Conn, error) { func (e env) dialer(addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) {
if e.dialer != nil { // NB: Go 1.6 added a Cancel field on net.Dialer, which would allow this
return e.dialer // to be written as
} //
return func(addr string, timeout time.Duration) (net.Conn, error) { // `(&net.Dialer{Cancel: cancel, Timeout: timeout}).Dial(e.network, addr)`
return net.DialTimeout("tcp", addr, timeout) //
} // but that would break compatibility with earlier Go versions.
return net.DialTimeout(e.network, addr, timeout)
} }
var ( var (
tcpClearEnv = env{name: "tcp-clear", network: "tcp"} tcpClearEnv = env{name: "tcp-clear", network: "tcp"}
tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"} tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"}
unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer} unixClearEnv = env{name: "unix-clear", network: "unix"}
unixTLSEnv = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"} unixTLSEnv = env{name: "unix-tls", network: "unix", security: "tls"}
handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true} handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true}
allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv} allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv}
) )
@ -371,7 +367,7 @@ type test struct {
// Configurable knobs, after newTest returns: // Configurable knobs, after newTest returns:
testServer testpb.TestServiceServer // nil means none testServer testpb.TestServiceServer // nil means none
healthServer *health.HealthServer // nil means disabled healthServer *health.Server // nil means disabled
maxStream uint32 maxStream uint32
userAgent string userAgent string
clientCompression bool clientCompression bool
@ -515,9 +511,7 @@ func (te *test) declareLogNoise(phrases ...string) {
} }
func (te *test) withServerTester(fn func(st *serverTester)) { func (te *test) withServerTester(fn func(st *serverTester)) {
var c net.Conn c, err := te.e.dialer(te.srvAddr, 10*time.Second, nil)
var err error
c, err = te.e.getDialer()(te.srvAddr, 10*time.Second)
if err != nil { if err != nil {
te.t.Fatal(err) te.t.Fatal(err)
} }
@ -857,7 +851,7 @@ func TestHealthCheckOnSuccess(t *testing.T) {
func testHealthCheckOnSuccess(t *testing.T, e env) { func testHealthCheckOnSuccess(t *testing.T, e env) {
te := newTest(t, e) te := newTest(t, e)
hs := health.NewHealthServer() hs := health.NewServer()
hs.SetServingStatus("grpc.health.v1.Health", 1) hs.SetServingStatus("grpc.health.v1.Health", 1)
te.healthServer = hs te.healthServer = hs
te.startServer(&testServer{security: e.security}) te.startServer(&testServer{security: e.security})
@ -883,7 +877,7 @@ func testHealthCheckOnFailure(t *testing.T, e env) {
"Failed to dial ", "Failed to dial ",
"grpc: the client connection is closing; please retry", "grpc: the client connection is closing; please retry",
) )
hs := health.NewHealthServer() hs := health.NewServer()
hs.SetServingStatus("grpc.health.v1.HealthCheck", 1) hs.SetServingStatus("grpc.health.v1.HealthCheck", 1)
te.healthServer = hs te.healthServer = hs
te.startServer(&testServer{security: e.security}) te.startServer(&testServer{security: e.security})
@ -927,7 +921,7 @@ func TestHealthCheckServingStatus(t *testing.T) {
func testHealthCheckServingStatus(t *testing.T, e env) { func testHealthCheckServingStatus(t *testing.T, e env) {
te := newTest(t, e) te := newTest(t, e)
hs := health.NewHealthServer() hs := health.NewServer()
te.healthServer = hs te.healthServer = hs
te.startServer(&testServer{security: e.security}) te.startServer(&testServer{security: e.security})
defer te.tearDown() defer te.tearDown()

45
transport/go16.go Normal file
View File

@ -0,0 +1,45 @@
// +build go1.6
/*
* Copyright 2014, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
)
// newDialer constructs a net.Dialer.
func newDialer(timeout time.Duration, cancel <-chan struct{}) *net.Dialer {
return &net.Dialer{Cancel: cancel, Timeout: timeout}
}

View File

@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
} }
if v := r.Header.Get("grpc-timeout"); v != "" { if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := timeoutDecode(v) to, err := decodeTimeout(v)
if err != nil { if err != nil {
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err) return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
} }
@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
h := ht.rw.Header() h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
if statusDesc != "" { if statusDesc != "" {
h.Set("Grpc-Message", grpcMessageEncode(statusDesc)) h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
} }
if md := s.Trailer(); len(md) > 0 { if md := s.Trailer(); len(md) > 0 {
for k, vv := range md { for k, vv := range md {

View File

@ -333,7 +333,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
"Content-Type": {"application/grpc"}, "Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message"}, "Trailer": {"Grpc-Status", "Grpc-Message"},
"Grpc-Status": {fmt.Sprint(uint32(statusCode))}, "Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {grpcMessageEncode(msg)}, "Grpc-Message": {encodeGrpcMessage(msg)},
} }
if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader) t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
@ -381,7 +381,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
"Content-Type": {"application/grpc"}, "Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message"}, "Trailer": {"Grpc-Status", "Grpc-Message"},
"Grpc-Status": {"4"}, "Grpc-Status": {"4"},
"Grpc-Message": {grpcMessageEncode("too slow")}, "Grpc-Message": {encodeGrpcMessage("too slow")},
} }
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) { if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader) t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)

View File

@ -107,20 +107,21 @@ type http2Client struct {
prevGoAwayID uint32 prevGoAwayID uint32
} }
func dial(fn func(string, time.Duration, <-chan struct{}) (net.Conn, error), addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) {
if fn != nil {
return fn(addr, timeout, cancel)
}
return newDialer(timeout, cancel).Dial("tcp", addr)
}
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction // and starts to receive messages on it. Non-nil error returns if construction
// fails. // fails.
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { func newHTTP2Client(addr string, opts ConnectOptions) (_ ClientTransport, err error) {
if opts.Dialer == nil {
// Set the default Dialer.
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
}
scheme := "http" scheme := "http"
startT := time.Now() startT := time.Now()
timeout := opts.Timeout timeout := opts.Timeout
conn, connErr := opts.Dialer(addr, timeout) conn, connErr := dial(opts.Dialer, addr, timeout, opts.Cancel)
if connErr != nil { if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf("transport: %v", connErr)
} }
@ -341,7 +342,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
} }
if timeout > 0 { if timeout > 0 {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
} }
for k, v := range authData { for k, v := range authData {
// Capital header names are illegal in HTTP/2. // Capital header names are illegal in HTTP/2.

View File

@ -265,6 +265,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
t.controlBuf.put(&resetStream{se.StreamID, se.Code}) t.controlBuf.put(&resetStream{se.StreamID, se.Code})
continue continue
} }
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close() t.Close()
return return
} }
@ -507,7 +508,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
Name: "grpc-status", Name: "grpc-status",
Value: strconv.Itoa(int(statusCode)), Value: strconv.Itoa(int(statusCode)),
}) })
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: grpcMessageEncode(statusDesc)}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
// Attach the trailer metadata. // Attach the trailer metadata.
for k, v := range s.trailer { for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent. // Clients don't tolerate reading restricted headers after some non restricted ones were sent.

View File

@ -35,6 +35,7 @@ package transport
import ( import (
"bufio" "bufio"
"bytes"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
} }
d.statusCode = codes.Code(code) d.statusCode = codes.Code(code)
case "grpc-message": case "grpc-message":
d.statusDesc = grpcMessageDecode(f.Value) d.statusDesc = decodeGrpcMessage(f.Value)
case "grpc-timeout": case "grpc-timeout":
d.timeoutSet = true d.timeoutSet = true
var err error var err error
d.timeout, err = timeoutDecode(f.Value) d.timeout, err = decodeTimeout(f.Value)
if err != nil { if err != nil {
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err)) d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
return return
@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 {
} }
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. // TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
func timeoutEncode(t time.Duration) string { func encodeTimeout(t time.Duration) string {
if d := div(t, time.Nanosecond); d <= maxTimeoutValue { if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n" return strconv.FormatInt(d, 10) + "n"
} }
@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string {
return strconv.FormatInt(div(t, time.Hour), 10) + "H" return strconv.FormatInt(div(t, time.Hour), 10) + "H"
} }
func timeoutDecode(s string) (time.Duration, error) { func decodeTimeout(s string) (time.Duration, error) {
size := len(s) size := len(s)
if size < 2 { if size < 2 {
return 0, fmt.Errorf("transport: timeout string is too short: %q", s) return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) {
return d * time.Duration(t), nil return d * time.Duration(t), nil
} }
const (
spaceByte = ' '
tildaByte = '~'
percentByte = '%'
)
// encodeGrpcMessage is used to encode status code in header field
// "grpc-message".
// It checks to see if each individual byte in msg is an
// allowable byte, and then either percent encoding or passing it through.
// When percent encoding, the byte is converted into hexadecimal notation
// with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c >= spaceByte && c < tildaByte && c != percentByte {
buf.WriteByte(c)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", c))
}
}
return buf.String()
}
// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
func decodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
if msg[i] == percentByte && i+2 < lenMsg {
return decodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func decodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
if err != nil {
buf.WriteByte(c)
} else {
buf.WriteByte(byte(parsed))
i += 2
}
} else {
buf.WriteByte(c)
}
}
return buf.String()
}
type framer struct { type framer struct {
numWriters int32 numWriters int32
reader io.Reader reader io.Reader

View File

@ -59,7 +59,7 @@ func TestTimeoutEncode(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to parse duration string %s: %v", test.in, err) t.Fatalf("failed to parse duration string %s: %v", test.in, err)
} }
out := timeoutEncode(d) out := encodeTimeout(d)
if out != test.out { if out != test.out {
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out) t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
} }
@ -79,7 +79,7 @@ func TestTimeoutDecode(t *testing.T) {
{"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")}, {"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")},
{"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")}, {"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")},
} { } {
d, err := timeoutDecode(test.s) d, err := decodeTimeout(test.s)
if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) { if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) {
t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err) t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err)
} }
@ -107,3 +107,38 @@ func TestValidContentType(t *testing.T) {
} }
} }
} }
func TestEncodeGrpcMessage(t *testing.T) {
for _, tt := range []struct {
input string
expected string
}{
{"", ""},
{"Hello", "Hello"},
{"my favorite character is \u0000", "my favorite character is %00"},
{"my favorite character is %", "my favorite character is %25"},
} {
actual := encodeGrpcMessage(tt.input)
if tt.expected != actual {
t.Errorf("encodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
}
}
}
func TestDecodeGrpcMessage(t *testing.T) {
for _, tt := range []struct {
input string
expected string
}{
{"", ""},
{"Hello", "Hello"},
{"H%61o", "Hao"},
{"H%6", "H%6"},
{"%G0", "%G0"},
} {
actual := decodeGrpcMessage(tt.input)
if tt.expected != actual {
t.Errorf("dncodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
}
}
}

45
transport/pre_go16.go Normal file
View File

@ -0,0 +1,45 @@
// +build !go1.6
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
)
// newDialer constructs a net.Dialer.
func newDialer(timeout time.Duration, _ <-chan struct{}) *net.Dialer {
return &net.Dialer{Timeout: timeout}
}

View File

@ -43,7 +43,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"strconv"
"sync" "sync"
"time" "time"
@ -354,8 +353,10 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authI
type ConnectOptions struct { type ConnectOptions struct {
// UserAgent is the application user agent. // UserAgent is the application user agent.
UserAgent string UserAgent string
// Cancel is closed to indicate that dialing should be cancelled.
Cancel chan struct{}
// Dialer specifies how to dial a network address. // Dialer specifies how to dial a network address.
Dialer func(string, time.Duration) (net.Conn, error) Dialer func(string, time.Duration, <-chan struct{}) (net.Conn, error)
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs. // PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection. // TransportCredentials stores the Authenticator required to setup a client connection.
@ -366,7 +367,7 @@ type ConnectOptions struct {
// NewClientTransport establishes the transport with the required ConnectOptions // NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller. // and returns it to the caller.
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { func NewClientTransport(target string, opts ConnectOptions) (ClientTransport, error) {
return newHTTP2Client(target, opts) return newHTTP2Client(target, opts)
} }
@ -559,74 +560,3 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-
return i, nil return i, nil
} }
} }
const (
spaceByte = ' '
tildaByte = '~'
percentByte = '%'
)
// grpcMessageEncode encodes the grpc-message field in the same
// manner as https://github.com/grpc/grpc-java/pull/1517.
func grpcMessageEncode(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
return grpcMessageEncodeUnchecked(msg)
}
}
return msg
}
func grpcMessageEncodeUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c >= spaceByte && c < tildaByte && c != percentByte {
_ = buf.WriteByte(c)
} else {
_, _ = buf.WriteString(fmt.Sprintf("%%%02X", c))
}
}
return buf.String()
}
// grpcMessageDecode decodes the grpc-message field in the same
// manner as https://github.com/grpc/grpc-java/pull/1517.
func grpcMessageDecode(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
if msg[i] == percentByte && i+2 < lenMsg {
return grpcMessageDecodeUnchecked(msg)
}
}
return msg
}
func grpcMessageDecodeUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
if err != nil {
_ = buf.WriteByte(c)
} else {
_ = buf.WriteByte(byte(parsed))
i += 2
}
} else {
_ = buf.WriteByte(c)
}
}
return buf.String()
}

View File

@ -37,7 +37,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math" "math"
"net" "net"
"reflect" "reflect"
@ -131,7 +130,7 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) { func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) {
// raw newline is not accepted by http2 framer so it must be encoded. // raw newline is not accepted by http2 framer so it must be encoded.
h.t.WriteStatus(s, codes.Internal, "\n") h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc)
} }
// start starts server. Other goroutines should block on s.readyChan for further operations. // start starts server. Other goroutines should block on s.readyChan for further operations.
@ -221,7 +220,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client
ct ClientTransport ct ClientTransport
connErr error connErr error
) )
ct, connErr = NewClientTransport(addr, &ConnectOptions{}) ct, connErr = NewClientTransport(addr, ConnectOptions{})
if connErr != nil { if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr) t.Fatalf("failed to create transport: %v", connErr)
} }
@ -714,6 +713,11 @@ func TestClientWithMisbehavedServer(t *testing.T) {
server.stop() server.stop()
} }
var (
encodingTestStatusCode = codes.Internal
encodingTestStatusDesc = "\n"
)
func TestEncodingRequiredStatus(t *testing.T) { func TestEncodingRequiredStatus(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus) server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
callHdr := &CallHdr{ callHdr := &CallHdr{
@ -731,8 +735,12 @@ func TestEncodingRequiredStatus(t *testing.T) {
if err := ct.Write(s, expectedRequest, &opts); err != nil { if err := ct.Write(s, expectedRequest, &opts); err != nil {
t.Fatalf("Failed to write the request: %v", err) t.Fatalf("Failed to write the request: %v", err)
} }
if _, err = ioutil.ReadAll(s); err != nil { p := make([]byte, http2MaxFrameLen)
t.Fatal(err) if _, err := s.dec.Read(p); err != io.EOF {
t.Fatalf("Read got error %v, want %v", err, io.EOF)
}
if s.StatusCode() != encodingTestStatusCode || s.StatusDesc() != encodingTestStatusDesc {
t.Fatalf("stream with status code %d, status desc %v, want %d, %v", s.StatusCode(), s.StatusDesc(), encodingTestStatusCode, encodingTestStatusDesc)
} }
ct.Close() ct.Close()
server.stop() server.stop()
@ -769,29 +777,3 @@ func TestIsReservedHeader(t *testing.T) {
} }
} }
} }
func TestGrpcMessageEncode(t *testing.T) {
testGrpcMessageEncode(t, "my favorite character is \u0000", "my favorite character is %00")
testGrpcMessageEncode(t, "my favorite character is %", "my favorite character is %25")
}
func TestGrpcMessageDecode(t *testing.T) {
testGrpcMessageDecode(t, "Hello", "Hello")
testGrpcMessageDecode(t, "H%61o", "Hao")
testGrpcMessageDecode(t, "H%6", "H%6")
testGrpcMessageDecode(t, "%G0", "%G0")
}
func testGrpcMessageEncode(t *testing.T, input string, expected string) {
actual := grpcMessageEncode(input)
if expected != actual {
t.Errorf("Expected %s from grpcMessageEncode, got %s", expected, actual)
}
}
func testGrpcMessageDecode(t *testing.T, input string, expected string) {
actual := grpcMessageDecode(input)
if expected != actual {
t.Errorf("Expected %s from grpcMessageDecode, got %s", expected, actual)
}
}