make the test happy

This commit is contained in:
iamqizhao
2015-08-24 11:40:40 -07:00
parent d12ff72146
commit 97574c6499
5 changed files with 41 additions and 16 deletions

@ -169,7 +169,6 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]str
info := make(map[string][]string) info := make(map[string][]string)
info[transportSecurityType] = []string{"tls"} info[transportSecurityType] = []string{"tls"}
for _, certs := range state.VerifiedChains { for _, certs := range state.VerifiedChains {
fmt.Println("DEBUG: reach here")
for _, cert := range certs { for _, cert := range certs {
info[x509CN] = append(info[x509CN], cert.Subject.CommonName) info[x509CN] = append(info[x509CN], cert.Subject.CommonName)
for _, san := range cert.DNSNames { for _, san := range cert.DNSNames {

@ -118,10 +118,13 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
} }
func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error { func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
if _, ok := metadata.FromContext(stream.Context()); ok { if md, ok := metadata.FromContext(stream.Context()); ok {
// For testing purpose, returns an error if there is attached metadata. delete(md, "transport_security_type")
// For testing purpose, returns an error if there is attached metadata other than transport_security_type.
if len(md) > 0 {
return grpc.Errorf(codes.DataLoss, "got extra metadata") return grpc.Errorf(codes.DataLoss, "got extra metadata")
} }
}
cs := args.GetResponseParameters() cs := args.GetResponseParameters()
for _, c := range cs { for _, c := range cs {
if us := c.GetIntervalUs(); us > 0 { if us := c.GetIntervalUs(); us > 0 {
@ -588,6 +591,10 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
if err != nil { if err != nil {
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err) t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
} }
if e.security == "tls" {
delete(header, "transport_security_type")
delete(trailer, "transport_security_type")
}
if !reflect.DeepEqual(testMetadata, header) { if !reflect.DeepEqual(testMetadata, header) {
t.Fatalf("Received header metadata %v, want %v", header, testMetadata) t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
} }
@ -775,11 +782,17 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
} }
go func() { go func() {
headerMD, err := stream.Header() headerMD, err := stream.Header()
if e.security == "tls" {
delete(headerMD, "transport_security_type")
}
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) { if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata) t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
} }
// test the cached value. // test the cached value.
headerMD, err = stream.Header() headerMD, err = stream.Header()
if e.security == "tls" {
delete(headerMD, "transport_security_type")
}
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) { if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata) t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
} }
@ -810,6 +823,9 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
} }
} }
trailerMD := stream.Trailer() trailerMD := stream.Trailer()
if e.security == "tls" {
delete(trailerMD, "transport_security_type")
}
if !reflect.DeepEqual(testMetadata, trailerMD) { if !reflect.DeepEqual(testMetadata, trailerMD) {
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata) t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata)
} }
@ -860,7 +876,7 @@ func testServerStreaming(t *testing.T, e env) {
respCnt++ respCnt++
} }
if rpcStatus != io.EOF { if rpcStatus != io.EOF {
t.Fatalf("Failed to finish the server streaming rpc: %v, want <EOF>", err) t.Fatalf("Failed to finish the server streaming rpc: %v, want <EOF>", rpcStatus)
} }
if respCnt != len(respSizes) { if respCnt != len(respSizes) {
t.Fatalf("Got %d reply, want %d", len(respSizes), respCnt) t.Fatalf("Got %d reply, want %d", len(respSizes), respCnt)

@ -56,6 +56,7 @@ type http2Client struct {
target string // server name/addr target string // server name/addr
userAgent string userAgent string
conn net.Conn // underlying communication channel conn net.Conn // underlying communication channel
authInfo map[string][]string // auth info about the connection
nextID uint32 // the next stream ID to be used nextID uint32 // the next stream ID to be used
// writableChan synchronizes write access to the transport. // writableChan synchronizes write access to the transport.
@ -114,6 +115,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
if connErr != nil { if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr) return nil, ConnectionErrorf("transport: %v", connErr)
} }
var authInfo map[string][]string
for _, c := range opts.AuthOptions { for _, c := range opts.AuthOptions {
if ccreds, ok := c.(credentials.TransportAuthenticator); ok { if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
scheme = "https" scheme = "https"
@ -124,7 +126,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
if timeout > 0 { if timeout > 0 {
timeout -= time.Since(startT) timeout -= time.Since(startT)
} }
conn, _, connErr = ccreds.ClientHandshake(addr, conn, timeout) conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
break break
} }
} }
@ -168,6 +170,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
target: addr, target: addr,
userAgent: ua, userAgent: ua,
conn: conn, conn: conn,
authInfo: authInfo,
// The client initiated stream id is odd starting from 1. // The client initiated stream id is odd starting from 1.
nextID: 1, nextID: 1,
writableChan: make(chan int, 1), writableChan: make(chan int, 1),
@ -701,7 +704,7 @@ func (t *http2Client) reader() {
} }
t.handleSettings(sf) t.handleSettings(sf)
hDec := newHPACKDecoder() hDec := newHPACKDecoder(t.authInfo)
var curStream *Stream var curStream *Stream
// loop to keep reading incoming messages on this transport. // loop to keep reading incoming messages on this transport.
for { for {

@ -58,7 +58,7 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe
type http2Server struct { type http2Server struct {
conn net.Conn conn net.Conn
maxStreamID uint32 // max stream ID ever seen maxStreamID uint32 // max stream ID ever seen
authInfo map[string][]string // basic auth info about the connection authInfo map[string][]string // auth info about the connection
// writableChan synchronizes write access to the transport. // writableChan synchronizes write access to the transport.
// A writer acquires the write lock by sending a value on writableChan // A writer acquires the write lock by sending a value on writableChan
// and releases it by receiving from writableChan. // and releases it by receiving from writableChan.
@ -236,8 +236,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
} }
t.handleSettings(sf) t.handleSettings(sf)
hDec := newHPACKDecoder() hDec := newHPACKDecoder(t.authInfo)
hDec.state.mdata = t.authInfo
var curStream *Stream var curStream *Stream
var wg sync.WaitGroup var wg sync.WaitGroup
defer wg.Wait() defer wg.Wait()

@ -106,6 +106,7 @@ type decodeState struct {
// An hpackDecoder decodes HTTP2 headers which may span multiple frames. // An hpackDecoder decodes HTTP2 headers which may span multiple frames.
type hpackDecoder struct { type hpackDecoder struct {
h *hpack.Decoder h *hpack.Decoder
mdata map[string][]string // persistent metadata with this decoder
state decodeState state decodeState
err error // The err when decoding err error // The err when decoding
} }
@ -138,8 +139,12 @@ func isReservedHeader(hdr string) bool {
} }
} }
func newHPACKDecoder() *hpackDecoder { func newHPACKDecoder(mdata map[string][]string) *hpackDecoder {
d := &hpackDecoder{} d := &hpackDecoder{}
for k, v := range mdata {
d.mdata = make(map[string][]string)
d.mdata[k] = v
}
d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) { d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) {
switch f.Name { switch f.Name {
case "grpc-status": case "grpc-status":
@ -174,6 +179,9 @@ func newHPACKDecoder() *hpackDecoder {
} }
if d.state.mdata == nil { if d.state.mdata == nil {
d.state.mdata = make(map[string][]string) d.state.mdata = make(map[string][]string)
for k, v := range d.mdata {
d.state.mdata[k] = v
}
} }
k, v, err := metadata.DecodeKeyValue(f.Name, f.Value) k, v, err := metadata.DecodeKeyValue(f.Name, f.Value)
if err != nil { if err != nil {