make the test happy

This commit is contained in:
iamqizhao
2015-08-24 11:40:40 -07:00
parent d12ff72146
commit 97574c6499
5 changed files with 41 additions and 16 deletions

View File

@ -169,7 +169,6 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]str
info := make(map[string][]string)
info[transportSecurityType] = []string{"tls"}
for _, certs := range state.VerifiedChains {
fmt.Println("DEBUG: reach here")
for _, cert := range certs {
info[x509CN] = append(info[x509CN], cert.Subject.CommonName)
for _, san := range cert.DNSNames {

View File

@ -118,10 +118,13 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
}
func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
if _, ok := metadata.FromContext(stream.Context()); ok {
// For testing purpose, returns an error if there is attached metadata.
if md, ok := metadata.FromContext(stream.Context()); ok {
delete(md, "transport_security_type")
// For testing purpose, returns an error if there is attached metadata other than transport_security_type.
if len(md) > 0 {
return grpc.Errorf(codes.DataLoss, "got extra metadata")
}
}
cs := args.GetResponseParameters()
for _, c := range cs {
if us := c.GetIntervalUs(); us > 0 {
@ -588,6 +591,10 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
if err != nil {
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
}
if e.security == "tls" {
delete(header, "transport_security_type")
delete(trailer, "transport_security_type")
}
if !reflect.DeepEqual(testMetadata, header) {
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
}
@ -775,11 +782,17 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
}
go func() {
headerMD, err := stream.Header()
if e.security == "tls" {
delete(headerMD, "transport_security_type")
}
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
// test the cached value.
headerMD, err = stream.Header()
if e.security == "tls" {
delete(headerMD, "transport_security_type")
}
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
@ -810,6 +823,9 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
}
}
trailerMD := stream.Trailer()
if e.security == "tls" {
delete(trailerMD, "transport_security_type")
}
if !reflect.DeepEqual(testMetadata, trailerMD) {
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata)
}
@ -860,7 +876,7 @@ func testServerStreaming(t *testing.T, e env) {
respCnt++
}
if rpcStatus != io.EOF {
t.Fatalf("Failed to finish the server streaming rpc: %v, want <EOF>", err)
t.Fatalf("Failed to finish the server streaming rpc: %v, want <EOF>", rpcStatus)
}
if respCnt != len(respSizes) {
t.Fatalf("Got %d reply, want %d", len(respSizes), respCnt)

View File

@ -56,6 +56,7 @@ type http2Client struct {
target string // server name/addr
userAgent string
conn net.Conn // underlying communication channel
authInfo map[string][]string // auth info about the connection
nextID uint32 // the next stream ID to be used
// writableChan synchronizes write access to the transport.
@ -114,6 +115,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)
}
var authInfo map[string][]string
for _, c := range opts.AuthOptions {
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
scheme = "https"
@ -124,7 +126,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
if timeout > 0 {
timeout -= time.Since(startT)
}
conn, _, connErr = ccreds.ClientHandshake(addr, conn, timeout)
conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
break
}
}
@ -168,6 +170,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
target: addr,
userAgent: ua,
conn: conn,
authInfo: authInfo,
// The client initiated stream id is odd starting from 1.
nextID: 1,
writableChan: make(chan int, 1),
@ -701,7 +704,7 @@ func (t *http2Client) reader() {
}
t.handleSettings(sf)
hDec := newHPACKDecoder()
hDec := newHPACKDecoder(t.authInfo)
var curStream *Stream
// loop to keep reading incoming messages on this transport.
for {

View File

@ -58,7 +58,7 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe
type http2Server struct {
conn net.Conn
maxStreamID uint32 // max stream ID ever seen
authInfo map[string][]string // basic auth info about the connection
authInfo map[string][]string // auth info about the connection
// writableChan synchronizes write access to the transport.
// A writer acquires the write lock by sending a value on writableChan
// and releases it by receiving from writableChan.
@ -236,8 +236,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
}
t.handleSettings(sf)
hDec := newHPACKDecoder()
hDec.state.mdata = t.authInfo
hDec := newHPACKDecoder(t.authInfo)
var curStream *Stream
var wg sync.WaitGroup
defer wg.Wait()

View File

@ -106,6 +106,7 @@ type decodeState struct {
// An hpackDecoder decodes HTTP2 headers which may span multiple frames.
type hpackDecoder struct {
h *hpack.Decoder
mdata map[string][]string // persistent metadata with this decoder
state decodeState
err error // The err when decoding
}
@ -138,8 +139,12 @@ func isReservedHeader(hdr string) bool {
}
}
func newHPACKDecoder() *hpackDecoder {
func newHPACKDecoder(mdata map[string][]string) *hpackDecoder {
d := &hpackDecoder{}
for k, v := range mdata {
d.mdata = make(map[string][]string)
d.mdata[k] = v
}
d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) {
switch f.Name {
case "grpc-status":
@ -174,6 +179,9 @@ func newHPACKDecoder() *hpackDecoder {
}
if d.state.mdata == nil {
d.state.mdata = make(map[string][]string)
for k, v := range d.mdata {
d.state.mdata[k] = v
}
}
k, v, err := metadata.DecodeKeyValue(f.Name, f.Value)
if err != nil {