diff --git a/credentials/credentials.go b/credentials/credentials.go index 11560dec..a98bbb7a 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -96,8 +96,8 @@ type TransportAuthenticator interface { const ( transportSecurityType = "transport_security_type" - x509CN = "x509_common_name" - x509SAN = "x509_suject_alternative_name" + x509CN = "x509_common_name" + x509SAN = "x509_suject_alternative_name" ) // tlsCreds is the credentials required for authenticating a connection using TLS. @@ -169,7 +169,6 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]str info := make(map[string][]string) info[transportSecurityType] = []string{"tls"} for _, certs := range state.VerifiedChains { - fmt.Println("DEBUG: reach here") for _, cert := range certs { info[x509CN] = append(info[x509CN], cert.Subject.CommonName) for _, san := range cert.DNSNames { diff --git a/test/end2end_test.go b/test/end2end_test.go index 14748354..d7b83ff6 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -118,9 +118,12 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* } func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error { - if _, ok := metadata.FromContext(stream.Context()); ok { - // For testing purpose, returns an error if there is attached metadata. - return grpc.Errorf(codes.DataLoss, "got extra metadata") + if md, ok := metadata.FromContext(stream.Context()); ok { + delete(md, "transport_security_type") + // For testing purpose, returns an error if there is attached metadata other than transport_security_type. + if len(md) > 0 { + return grpc.Errorf(codes.DataLoss, "got extra metadata") + } } cs := args.GetResponseParameters() for _, c := range cs { @@ -588,6 +591,10 @@ func testMetadataUnaryRPC(t *testing.T, e env) { if err != nil { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } + if e.security == "tls" { + delete(header, "transport_security_type") + delete(trailer, "transport_security_type") + } if !reflect.DeepEqual(testMetadata, header) { t.Fatalf("Received header metadata %v, want %v", header, testMetadata) } @@ -775,11 +782,17 @@ func testMetadataStreamingRPC(t *testing.T, e env) { } go func() { headerMD, err := stream.Header() + if e.security == "tls" { + delete(headerMD, "transport_security_type") + } if err != nil || !reflect.DeepEqual(testMetadata, headerMD) { t.Errorf("#1 %v.Header() = %v, %v, want %v, ", stream, headerMD, err, testMetadata) } // test the cached value. headerMD, err = stream.Header() + if e.security == "tls" { + delete(headerMD, "transport_security_type") + } if err != nil || !reflect.DeepEqual(testMetadata, headerMD) { t.Errorf("#2 %v.Header() = %v, %v, want %v, ", stream, headerMD, err, testMetadata) } @@ -810,6 +823,9 @@ func testMetadataStreamingRPC(t *testing.T, e env) { } } trailerMD := stream.Trailer() + if e.security == "tls" { + delete(trailerMD, "transport_security_type") + } if !reflect.DeepEqual(testMetadata, trailerMD) { t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata) } @@ -860,7 +876,7 @@ func testServerStreaming(t *testing.T, e env) { respCnt++ } if rpcStatus != io.EOF { - t.Fatalf("Failed to finish the server streaming rpc: %v, want ", err) + t.Fatalf("Failed to finish the server streaming rpc: %v, want ", rpcStatus) } if respCnt != len(respSizes) { t.Fatalf("Got %d reply, want %d", len(respSizes), respCnt) diff --git a/transport/http2_client.go b/transport/http2_client.go index 8dd4b7d3..3e60058c 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -55,8 +55,9 @@ import ( type http2Client struct { target string // server name/addr userAgent string - conn net.Conn // underlying communication channel - nextID uint32 // the next stream ID to be used + conn net.Conn // underlying communication channel + authInfo map[string][]string // auth info about the connection + nextID uint32 // the next stream ID to be used // writableChan synchronizes write access to the transport. // A writer acquires the write lock by sending a value on writableChan @@ -114,6 +115,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e if connErr != nil { return nil, ConnectionErrorf("transport: %v", connErr) } + var authInfo map[string][]string for _, c := range opts.AuthOptions { if ccreds, ok := c.(credentials.TransportAuthenticator); ok { scheme = "https" @@ -124,7 +126,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e if timeout > 0 { timeout -= time.Since(startT) } - conn, _, connErr = ccreds.ClientHandshake(addr, conn, timeout) + conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout) break } } @@ -168,6 +170,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e target: addr, userAgent: ua, conn: conn, + authInfo: authInfo, // The client initiated stream id is odd starting from 1. nextID: 1, writableChan: make(chan int, 1), @@ -701,7 +704,7 @@ func (t *http2Client) reader() { } t.handleSettings(sf) - hDec := newHPACKDecoder() + hDec := newHPACKDecoder(t.authInfo) var curStream *Stream // loop to keep reading incoming messages on this transport. for { diff --git a/transport/http2_server.go b/transport/http2_server.go index 8df2aa24..5d1bff3b 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -57,8 +57,8 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { conn net.Conn - maxStreamID uint32 // max stream ID ever seen - authInfo map[string][]string // basic auth info about the connection + maxStreamID uint32 // max stream ID ever seen + authInfo map[string][]string // auth info about the connection // writableChan synchronizes write access to the transport. // A writer acquires the write lock by sending a value on writableChan // and releases it by receiving from writableChan. @@ -236,8 +236,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { } t.handleSettings(sf) - hDec := newHPACKDecoder() - hDec.state.mdata = t.authInfo + hDec := newHPACKDecoder(t.authInfo) var curStream *Stream var wg sync.WaitGroup defer wg.Wait() diff --git a/transport/http_util.go b/transport/http_util.go index ef07924c..5fe65e37 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -106,6 +106,7 @@ type decodeState struct { // An hpackDecoder decodes HTTP2 headers which may span multiple frames. type hpackDecoder struct { h *hpack.Decoder + mdata map[string][]string // persistent metadata with this decoder state decodeState err error // The err when decoding } @@ -138,8 +139,12 @@ func isReservedHeader(hdr string) bool { } } -func newHPACKDecoder() *hpackDecoder { +func newHPACKDecoder(mdata map[string][]string) *hpackDecoder { d := &hpackDecoder{} + for k, v := range mdata { + d.mdata = make(map[string][]string) + d.mdata[k] = v + } d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) { switch f.Name { case "grpc-status": @@ -174,6 +179,9 @@ func newHPACKDecoder() *hpackDecoder { } if d.state.mdata == nil { d.state.mdata = make(map[string][]string) + for k, v := range d.mdata { + d.state.mdata[k] = v + } } k, v, err := metadata.DecodeKeyValue(f.Name, f.Value) if err != nil {