This commit is contained in:
iamqizhao
2016-07-28 13:04:50 -07:00
27 changed files with 480 additions and 204 deletions

View File

@ -1,17 +1,21 @@
language: go
go:
- 1.5.3
- 1.6
- 1.5.4
- 1.6.3
go_import_path: google.golang.org/grpc
before_install:
- go get golang.org/x/tools/cmd/goimports
- go get github.com/golang/lint/golint
- go get github.com/axw/gocov/gocov
- go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover
install:
- mkdir -p "$GOPATH/src/google.golang.org"
- mv "$TRAVIS_BUILD_DIR" "$GOPATH/src/google.golang.org/grpc"
script:
- '! gofmt -s -d -l . 2>&1 | read'
- '! goimports -l . | read'
- '! golint ./... | grep -vE "(_string|\.pb)\.go:"'
- '! go tool vet -all . 2>&1 | grep -vE "constant [0-9]+ not a string in call to Errorf"'
- make test testrace

View File

@ -243,7 +243,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
}
// Remove the server.
updates := []*naming.Update{&naming.Update{
updates := []*naming.Update{{
Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port,
}}
@ -287,7 +287,7 @@ func TestGetOnWaitChannel(t *testing.T) {
t.Fatalf("Failed to create ClientConn: %v", err)
}
// Remove all servers so that all upcoming RPCs will block on waitCh.
updates := []*naming.Update{&naming.Update{
updates := []*naming.Update{{
Op: naming.Delete,
Addr: "127.0.0.1:" + servers[0].port,
}}
@ -310,7 +310,7 @@ func TestGetOnWaitChannel(t *testing.T) {
}
}()
// Add a connected server to get the above RPC through.
updates = []*naming.Update{&naming.Update{
updates = []*naming.Update{{
Op: naming.Add,
Addr: "127.0.0.1:" + servers[0].port,
}}

View File

@ -58,7 +58,7 @@ func closeLoopUnary() {
for i := 0; i < *maxConcurrentRPCs; i++ {
go func() {
for _ = range ch {
for range ch {
start := time.Now()
unaryCaller(tc)
elapse := time.Since(start)

View File

@ -196,7 +196,7 @@ func WithTimeout(d time.Duration) DialOption {
}
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption {
func WithDialer(f func(string, time.Duration, <-chan struct{}) (net.Conn, error)) DialOption {
return func(o *dialOptions) {
o.copts.Dialer = f
}
@ -361,11 +361,11 @@ func (cc *ClientConn) lbWatcher() {
func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
ac := &addrConn{
cc: cc,
addr: addr,
dopts: cc.dopts,
shutdownChan: make(chan struct{}),
cc: cc,
addr: addr,
dopts: cc.dopts,
}
ac.dopts.copts.Cancel = make(chan struct{})
if EnableTracing {
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
}
@ -468,11 +468,10 @@ func (cc *ClientConn) Close() error {
// addrConn is a network connection to a given address.
type addrConn struct {
cc *ClientConn
addr Address
dopts dialOptions
shutdownChan chan struct{}
events trace.EventLog
cc *ClientConn
addr Address
dopts dialOptions
events trace.EventLog
mu sync.Mutex
state ConnectivityState
@ -558,12 +557,13 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
t.Close()
}
sleepTime := ac.dopts.bs.backoff(retries)
ac.dopts.copts.Timeout = sleepTime
copts := ac.dopts.copts
copts.Timeout = sleepTime
if sleepTime < minConnectTimeout {
ac.dopts.copts.Timeout = minConnectTimeout
copts.Timeout = minConnectTimeout
}
connectTime := time.Now()
newTransport, err := transport.NewClientTransport(ac.addr.Addr, &ac.dopts.copts)
newTransport, err := transport.NewClientTransport(ac.addr.Addr, copts)
if err != nil {
ac.mu.Lock()
if ac.state == Shutdown {
@ -586,7 +586,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
closeTransport = false
select {
case <-time.After(sleepTime):
case <-ac.shutdownChan:
case <-ac.dopts.copts.Cancel:
}
retries++
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
@ -621,9 +621,9 @@ func (ac *addrConn) transportMonitor() {
t := ac.transport
ac.mu.Unlock()
select {
// shutdownChan is needed to detect the teardown when
// Cancel is needed to detect the teardown when
// the addrConn is idle (i.e., no RPC in flight).
case <-ac.shutdownChan:
case <-ac.dopts.copts.Cancel:
return
case <-t.GoAway():
ac.tearDown(errConnDrain)
@ -724,8 +724,8 @@ func (ac *addrConn) tearDown(err error) {
if ac.transport != nil && err != errConnDrain {
ac.transport.Close()
}
if ac.shutdownChan != nil {
close(ac.shutdownChan)
if ac.dopts.copts.Cancel != nil {
close(ac.dopts.copts.Cancel)
}
return
}

View File

@ -152,7 +152,7 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
})
}
// use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := *c.config
cfg := cloneTLSConfig(c.config)
if c.config.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
@ -160,7 +160,7 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
}
cfg.ServerName = addr[:colonPos]
}
conn := tls.Client(rawConn, &cfg)
conn := tls.Client(rawConn, cfg)
if timeout == 0 {
err = conn.Handshake()
} else {
@ -189,7 +189,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
// NewTLS uses c to construct a TransportCredentials based on TLS.
func NewTLS(c *tls.Config) TransportCredentials {
tc := &tlsCreds{c}
tc := &tlsCreds{cloneTLSConfig(c)}
tc.config.NextProtos = alpnProtoStr
return tc
}

View File

@ -0,0 +1,76 @@
// +build go1.7
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package credentials
import (
"crypto/tls"
)
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.
//
// If cfg is nil, a new zero tls.Config is returned.
//
// TODO replace this function with official clone function.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
Renegotiation: cfg.Renegotiation,
}
}

View File

@ -0,0 +1,74 @@
// +build !go1.7
/*
*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package credentials
import (
"crypto/tls"
)
// cloneTLSConfig returns a shallow clone of the exported
// fields of cfg, ignoring the unexported sync.Once, which
// contains a mutex and must not be copied.
//
// If cfg is nil, a new zero tls.Config is returned.
//
// TODO replace this function with official clone function.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
SessionTicketsDisabled: cfg.SessionTicketsDisabled,
SessionTicketKey: cfg.SessionTicketKey,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View File

@ -153,7 +153,7 @@ func runRouteChat(client pb.RouteGuideClient) {
func randomPoint(r *rand.Rand) *pb.Point {
lat := (r.Int31n(180) - 90) * 1e7
long := (r.Int31n(360) - 180) * 1e7
return &pb.Point{lat, long}
return &pb.Point{Latitude: lat, Longitude: long}
}
func main() {
@ -186,13 +186,16 @@ func main() {
client := pb.NewRouteGuideClient(conn)
// Looking for a valid feature
printFeature(client, &pb.Point{409146138, -746188906})
printFeature(client, &pb.Point{Latitude: 409146138, Longitude: -746188906})
// Feature missing.
printFeature(client, &pb.Point{0, 0})
printFeature(client, &pb.Point{Latitude: 0, Longitude: 0})
// Looking for features between 40, -75 and 42, -73.
printFeatures(client, &pb.Rectangle{&pb.Point{Latitude: 400000000, Longitude: -750000000}, &pb.Point{Latitude: 420000000, Longitude: -730000000}})
printFeatures(client, &pb.Rectangle{
Lo: &pb.Point{Latitude: 400000000, Longitude: -750000000},
Hi: &pb.Point{Latitude: 420000000, Longitude: -730000000},
})
// RecordRoute
runRecordRoute(client)

View File

@ -79,7 +79,7 @@ func (s *routeGuideServer) GetFeature(ctx context.Context, point *pb.Point) (*pb
}
}
// No feature was found, return an unnamed feature
return &pb.Feature{"", point}, nil
return &pb.Feature{Location: point}, nil
}
// ListFeatures lists all features contained within the given bounding Rectangle.

View File

@ -11,19 +11,22 @@ import (
healthpb "google.golang.org/grpc/health/grpc_health_v1"
)
type HealthServer struct {
// Server implements `service Health`.
type Server struct {
mu sync.Mutex
// statusMap stores the serving status of the services this HealthServer monitors.
// statusMap stores the serving status of the services this Server monitors.
statusMap map[string]healthpb.HealthCheckResponse_ServingStatus
}
func NewHealthServer() *HealthServer {
return &HealthServer{
// NewServer returns a new Server.
func NewServer() *Server {
return &Server{
statusMap: make(map[string]healthpb.HealthCheckResponse_ServingStatus),
}
}
func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
// Check implements `service Health`.
func (s *Server) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
if in.Service == "" {
@ -42,7 +45,7 @@ func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckReques
// SetServingStatus is called when need to reset the serving status of a service
// or insert a new service entry into the statusMap.
func (s *HealthServer) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) {
func (s *Server) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) {
s.mu.Lock()
s.statusMap[service] = status
s.mu.Unlock()

View File

@ -60,15 +60,21 @@ func encodeKeyValue(k, v string) (string, string) {
// DecodeKeyValue returns the original key and value corresponding to the
// encoded data in k, v.
// If k is a binary header and v contains comma, v is split on comma before decoded,
// and the decoded v will be joined with comma before returned.
func DecodeKeyValue(k, v string) (string, string, error) {
if !strings.HasSuffix(k, binHdrSuffix) {
return k, v, nil
}
val, err := base64.StdEncoding.DecodeString(v)
if err != nil {
return "", "", err
vvs := strings.Split(v, ",")
for i, vv := range vvs {
val, err := base64.StdEncoding.DecodeString(vv)
if err != nil {
return "", "", err
}
vvs[i] = string(val)
}
return k, string(val), nil
return k, strings.Join(vvs, ","), nil
}
// MD is a mapping from metadata keys to values. Users should use the following

View File

@ -74,6 +74,8 @@ func TestDecodeKeyValue(t *testing.T) {
{"a", "abc", "a", "abc", nil},
{"key-bin", "Zm9vAGJhcg==", "key-bin", "foo\x00bar", nil},
{"key-bin", "woA=", "key-bin", binaryValue, nil},
{"a", "abc,efg", "a", "abc,efg", nil},
{"key-bin", "Zm9vAGJhcg==,Zm9vAGJhcg==", "key-bin", "foo\x00bar,foo\x00bar", nil},
} {
k, v, err := DecodeKeyValue(test.kin, test.vin)
if k != test.kout || !reflect.DeepEqual(v, test.vout) || !reflect.DeepEqual(err, test.err) {

View File

@ -70,7 +70,7 @@ import (
type serverReflectionServer struct {
s *grpc.Server
// TODO add more cache if necessary
serviceInfo map[string]*grpc.ServiceInfo // cache for s.GetServiceInfo()
serviceInfo map[string]grpc.ServiceInfo // cache for s.GetServiceInfo()
}
// Register registers the server reflection service on the given gRPC server.

View File

@ -184,8 +184,8 @@ func TestContextErr(t *testing.T) {
func TestErrorsWithSameParameters(t *testing.T) {
const description = "some description"
e1 := Errorf(codes.AlreadyExists, description)
e2 := Errorf(codes.AlreadyExists, description)
e1 := Errorf(codes.AlreadyExists, description).(*rpcError)
e2 := Errorf(codes.AlreadyExists, description).(*rpcError)
if e1 == e2 {
t.Fatalf("Error interfaces should not be considered equal - e1: %p - %v e2: %p - %v", e1, e1, e2, e2)
}

View File

@ -281,8 +281,8 @@ type ServiceInfo struct {
// GetServiceInfo returns a map from service names to ServiceInfo.
// Service names include the package names, in the form of <package>.<service>.
func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
ret := make(map[string]*ServiceInfo)
func (s *Server) GetServiceInfo() map[string]ServiceInfo {
ret := make(map[string]ServiceInfo)
for n, srv := range s.m {
methods := make([]MethodInfo, 0, len(srv.md)+len(srv.sd))
for m := range srv.md {
@ -300,7 +300,7 @@ func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
})
}
ret[n] = &ServiceInfo{
ret[n] = ServiceInfo{
Methods: methods,
Metadata: srv.mdata,
}

View File

@ -90,15 +90,15 @@ func TestGetServiceInfo(t *testing.T) {
server.RegisterService(&testSd, &testServer{})
info := server.GetServiceInfo()
want := map[string]*ServiceInfo{
"grpc.testing.EmptyService": &ServiceInfo{
want := map[string]ServiceInfo{
"grpc.testing.EmptyService": {
Methods: []MethodInfo{
MethodInfo{
{
Name: "EmptyCall",
IsClientStream: false,
IsServerStream: false,
},
MethodInfo{
{
Name: "EmptyStream",
IsClientStream: true,
IsServerStream: false,
@ -108,6 +108,6 @@ func TestGetServiceInfo(t *testing.T) {
}
if !reflect.DeepEqual(info, want) {
t.Errorf("GetServiceInfo() = %q, want %q", info, want)
t.Errorf("GetServiceInfo() = %+v, want %+v", info, want)
}
}

View File

@ -300,39 +300,35 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ
const tlsDir = "testdata/"
func unixDialer(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("unix", addr, timeout)
}
type env struct {
name string
network string // The type of network such as tcp, unix, etc.
dialer func(addr string, timeout time.Duration) (net.Conn, error)
security string // The security protocol such as TLS, SSH, etc.
httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS
}
func (e env) runnable() bool {
if runtime.GOOS == "windows" && strings.HasPrefix(e.name, "unix-") {
if runtime.GOOS == "windows" && e.network == "unix" {
return false
}
return true
}
func (e env) getDialer() func(addr string, timeout time.Duration) (net.Conn, error) {
if e.dialer != nil {
return e.dialer
}
return func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
func (e env) dialer(addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) {
// NB: Go 1.6 added a Cancel field on net.Dialer, which would allow this
// to be written as
//
// `(&net.Dialer{Cancel: cancel, Timeout: timeout}).Dial(e.network, addr)`
//
// but that would break compatibility with earlier Go versions.
return net.DialTimeout(e.network, addr, timeout)
}
var (
tcpClearEnv = env{name: "tcp-clear", network: "tcp"}
tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"}
unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer}
unixTLSEnv = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"}
unixClearEnv = env{name: "unix-clear", network: "unix"}
unixTLSEnv = env{name: "unix-tls", network: "unix", security: "tls"}
handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true}
allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv}
)
@ -371,7 +367,7 @@ type test struct {
// Configurable knobs, after newTest returns:
testServer testpb.TestServiceServer // nil means none
healthServer *health.HealthServer // nil means disabled
healthServer *health.Server // nil means disabled
maxStream uint32
maxMsgSize int
userAgent string
@ -519,9 +515,7 @@ func (te *test) declareLogNoise(phrases ...string) {
}
func (te *test) withServerTester(fn func(st *serverTester)) {
var c net.Conn
var err error
c, err = te.e.getDialer()(te.srvAddr, 10*time.Second)
c, err := te.e.dialer(te.srvAddr, 10*time.Second, nil)
if err != nil {
te.t.Fatal(err)
}
@ -752,7 +746,7 @@ func TestHealthCheckOnSuccess(t *testing.T) {
func testHealthCheckOnSuccess(t *testing.T, e env) {
te := newTest(t, e)
hs := health.NewHealthServer()
hs := health.NewServer()
hs.SetServingStatus("grpc.health.v1.Health", 1)
te.healthServer = hs
te.startServer(&testServer{security: e.security})
@ -778,7 +772,7 @@ func testHealthCheckOnFailure(t *testing.T, e env) {
"Failed to dial ",
"grpc: the client connection is closing; please retry",
)
hs := health.NewHealthServer()
hs := health.NewServer()
hs.SetServingStatus("grpc.health.v1.HealthCheck", 1)
te.healthServer = hs
te.startServer(&testServer{security: e.security})
@ -822,7 +816,7 @@ func TestHealthCheckServingStatus(t *testing.T) {
func testHealthCheckServingStatus(t *testing.T, e env) {
te := newTest(t, e)
hs := health.NewHealthServer()
hs := health.NewServer()
te.healthServer = hs
te.startServer(&testServer{security: e.security})
defer te.tearDown()

45
transport/go16.go Normal file
View File

@ -0,0 +1,45 @@
// +build go1.6
/*
* Copyright 2014, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
)
// newDialer constructs a net.Dialer.
func newDialer(timeout time.Duration, cancel <-chan struct{}) *net.Dialer {
return &net.Dialer{Cancel: cancel, Timeout: timeout}
}

View File

@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
}
if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := timeoutDecode(v)
to, err := decodeTimeout(v)
if err != nil {
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
}
@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
if statusDesc != "" {
h.Set("Grpc-Message", grpcMessageEncode(statusDesc))
h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
}
if md := s.Trailer(); len(md) > 0 {
for k, vv := range md {

View File

@ -333,7 +333,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message"},
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {grpcMessageEncode(msg)},
"Grpc-Message": {encodeGrpcMessage(msg)},
}
if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
@ -381,7 +381,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message"},
"Grpc-Status": {"4"},
"Grpc-Message": {grpcMessageEncode("too slow")},
"Grpc-Message": {encodeGrpcMessage("too slow")},
}
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)

View File

@ -107,20 +107,21 @@ type http2Client struct {
prevGoAwayID uint32
}
func dial(fn func(string, time.Duration, <-chan struct{}) (net.Conn, error), addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) {
if fn != nil {
return fn(addr, timeout, cancel)
}
return newDialer(timeout, cancel).Dial("tcp", addr)
}
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction
// fails.
func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) {
if opts.Dialer == nil {
// Set the default Dialer.
opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
}
func newHTTP2Client(addr string, opts ConnectOptions) (_ ClientTransport, err error) {
scheme := "http"
startT := time.Now()
timeout := opts.Timeout
conn, connErr := opts.Dialer(addr, timeout)
conn, connErr := dial(opts.Dialer, addr, timeout, opts.Cancel)
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)
}
@ -341,7 +342,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
}
if timeout > 0 {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
}
for k, v := range authData {
// Capital header names are illegal in HTTP/2.

View File

@ -265,6 +265,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
t.controlBuf.put(&resetStream{se.StreamID, se.Code})
continue
}
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
t.Close()
return
}
@ -507,7 +508,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
Name: "grpc-status",
Value: strconv.Itoa(int(statusCode)),
})
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: grpcMessageEncode(statusDesc)})
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
// Attach the trailer metadata.
for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.

View File

@ -35,6 +35,7 @@ package transport
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
}
d.statusCode = codes.Code(code)
case "grpc-message":
d.statusDesc = grpcMessageDecode(f.Value)
d.statusDesc = decodeGrpcMessage(f.Value)
case "grpc-timeout":
d.timeoutSet = true
var err error
d.timeout, err = timeoutDecode(f.Value)
d.timeout, err = decodeTimeout(f.Value)
if err != nil {
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
return
@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 {
}
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
func timeoutEncode(t time.Duration) string {
func encodeTimeout(t time.Duration) string {
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
return strconv.FormatInt(d, 10) + "n"
}
@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string {
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
}
func timeoutDecode(s string) (time.Duration, error) {
func decodeTimeout(s string) (time.Duration, error) {
size := len(s)
if size < 2 {
return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) {
return d * time.Duration(t), nil
}
const (
spaceByte = ' '
tildaByte = '~'
percentByte = '%'
)
// encodeGrpcMessage is used to encode status code in header field
// "grpc-message".
// It checks to see if each individual byte in msg is an
// allowable byte, and then either percent encoding or passing it through.
// When percent encoding, the byte is converted into hexadecimal notation
// with a '%' prepended.
func encodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
return encodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func encodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c >= spaceByte && c < tildaByte && c != percentByte {
buf.WriteByte(c)
} else {
buf.WriteString(fmt.Sprintf("%%%02X", c))
}
}
return buf.String()
}
// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
func decodeGrpcMessage(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
if msg[i] == percentByte && i+2 < lenMsg {
return decodeGrpcMessageUnchecked(msg)
}
}
return msg
}
func decodeGrpcMessageUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
if err != nil {
buf.WriteByte(c)
} else {
buf.WriteByte(byte(parsed))
i += 2
}
} else {
buf.WriteByte(c)
}
}
return buf.String()
}
type framer struct {
numWriters int32
reader io.Reader

View File

@ -59,7 +59,7 @@ func TestTimeoutEncode(t *testing.T) {
if err != nil {
t.Fatalf("failed to parse duration string %s: %v", test.in, err)
}
out := timeoutEncode(d)
out := encodeTimeout(d)
if out != test.out {
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
}
@ -79,7 +79,7 @@ func TestTimeoutDecode(t *testing.T) {
{"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")},
{"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")},
} {
d, err := timeoutDecode(test.s)
d, err := decodeTimeout(test.s)
if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) {
t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err)
}
@ -107,3 +107,38 @@ func TestValidContentType(t *testing.T) {
}
}
}
func TestEncodeGrpcMessage(t *testing.T) {
for _, tt := range []struct {
input string
expected string
}{
{"", ""},
{"Hello", "Hello"},
{"my favorite character is \u0000", "my favorite character is %00"},
{"my favorite character is %", "my favorite character is %25"},
} {
actual := encodeGrpcMessage(tt.input)
if tt.expected != actual {
t.Errorf("encodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
}
}
}
func TestDecodeGrpcMessage(t *testing.T) {
for _, tt := range []struct {
input string
expected string
}{
{"", ""},
{"Hello", "Hello"},
{"H%61o", "Hao"},
{"H%6", "H%6"},
{"%G0", "%G0"},
} {
actual := decodeGrpcMessage(tt.input)
if tt.expected != actual {
t.Errorf("dncodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
}
}
}

45
transport/pre_go16.go Normal file
View File

@ -0,0 +1,45 @@
// +build !go1.6
/*
* Copyright 2016, Google Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
*/
package transport
import (
"net"
"time"
)
// newDialer constructs a net.Dialer.
func newDialer(timeout time.Duration, _ <-chan struct{}) *net.Dialer {
return &net.Dialer{Timeout: timeout}
}

View File

@ -43,7 +43,6 @@ import (
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
@ -354,8 +353,10 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authI
type ConnectOptions struct {
// UserAgent is the application user agent.
UserAgent string
// Cancel is closed to indicate that dialing should be cancelled.
Cancel chan struct{}
// Dialer specifies how to dial a network address.
Dialer func(string, time.Duration) (net.Conn, error)
Dialer func(string, time.Duration, <-chan struct{}) (net.Conn, error)
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
PerRPCCredentials []credentials.PerRPCCredentials
// TransportCredentials stores the Authenticator required to setup a client connection.
@ -366,7 +367,7 @@ type ConnectOptions struct {
// NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller.
func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) {
func NewClientTransport(target string, opts ConnectOptions) (ClientTransport, error) {
return newHTTP2Client(target, opts)
}
@ -559,74 +560,3 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-
return i, nil
}
}
const (
spaceByte = ' '
tildaByte = '~'
percentByte = '%'
)
// grpcMessageEncode encodes the grpc-message field in the same
// manner as https://github.com/grpc/grpc-java/pull/1517.
func grpcMessageEncode(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
return grpcMessageEncodeUnchecked(msg)
}
}
return msg
}
func grpcMessageEncodeUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c >= spaceByte && c < tildaByte && c != percentByte {
_ = buf.WriteByte(c)
} else {
_, _ = buf.WriteString(fmt.Sprintf("%%%02X", c))
}
}
return buf.String()
}
// grpcMessageDecode decodes the grpc-message field in the same
// manner as https://github.com/grpc/grpc-java/pull/1517.
func grpcMessageDecode(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
if msg[i] == percentByte && i+2 < lenMsg {
return grpcMessageDecodeUnchecked(msg)
}
}
return msg
}
func grpcMessageDecodeUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
if err != nil {
_ = buf.WriteByte(c)
} else {
_ = buf.WriteByte(byte(parsed))
i += 2
}
} else {
_ = buf.WriteByte(c)
}
}
return buf.String()
}

View File

@ -37,7 +37,6 @@ import (
"bytes"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"reflect"
@ -131,7 +130,7 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) {
// raw newline is not accepted by http2 framer so it must be encoded.
h.t.WriteStatus(s, codes.Internal, "\n")
h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc)
}
// start starts server. Other goroutines should block on s.readyChan for further operations.
@ -221,7 +220,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client
ct ClientTransport
connErr error
)
ct, connErr = NewClientTransport(addr, &ConnectOptions{})
ct, connErr = NewClientTransport(addr, ConnectOptions{})
if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr)
}
@ -714,6 +713,11 @@ func TestClientWithMisbehavedServer(t *testing.T) {
server.stop()
}
var (
encodingTestStatusCode = codes.Internal
encodingTestStatusDesc = "\n"
)
func TestEncodingRequiredStatus(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
callHdr := &CallHdr{
@ -731,8 +735,12 @@ func TestEncodingRequiredStatus(t *testing.T) {
if err := ct.Write(s, expectedRequest, &opts); err != nil {
t.Fatalf("Failed to write the request: %v", err)
}
if _, err = ioutil.ReadAll(s); err != nil {
t.Fatal(err)
p := make([]byte, http2MaxFrameLen)
if _, err := s.dec.Read(p); err != io.EOF {
t.Fatalf("Read got error %v, want %v", err, io.EOF)
}
if s.StatusCode() != encodingTestStatusCode || s.StatusDesc() != encodingTestStatusDesc {
t.Fatalf("stream with status code %d, status desc %v, want %d, %v", s.StatusCode(), s.StatusDesc(), encodingTestStatusCode, encodingTestStatusDesc)
}
ct.Close()
server.stop()
@ -769,29 +777,3 @@ func TestIsReservedHeader(t *testing.T) {
}
}
}
func TestGrpcMessageEncode(t *testing.T) {
testGrpcMessageEncode(t, "my favorite character is \u0000", "my favorite character is %00")
testGrpcMessageEncode(t, "my favorite character is %", "my favorite character is %25")
}
func TestGrpcMessageDecode(t *testing.T) {
testGrpcMessageDecode(t, "Hello", "Hello")
testGrpcMessageDecode(t, "H%61o", "Hao")
testGrpcMessageDecode(t, "H%6", "H%6")
testGrpcMessageDecode(t, "%G0", "%G0")
}
func testGrpcMessageEncode(t *testing.T, input string, expected string) {
actual := grpcMessageEncode(input)
if expected != actual {
t.Errorf("Expected %s from grpcMessageEncode, got %s", expected, actual)
}
}
func testGrpcMessageDecode(t *testing.T, input string, expected string) {
actual := grpcMessageDecode(input)
if expected != actual {
t.Errorf("Expected %s from grpcMessageDecode, got %s", expected, actual)
}
}