make the test happy
This commit is contained in:
@ -169,7 +169,6 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]str
|
||||
info := make(map[string][]string)
|
||||
info[transportSecurityType] = []string{"tls"}
|
||||
for _, certs := range state.VerifiedChains {
|
||||
fmt.Println("DEBUG: reach here")
|
||||
for _, cert := range certs {
|
||||
info[x509CN] = append(info[x509CN], cert.Subject.CommonName)
|
||||
for _, san := range cert.DNSNames {
|
||||
|
@ -118,10 +118,13 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
|
||||
}
|
||||
|
||||
func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
|
||||
if _, ok := metadata.FromContext(stream.Context()); ok {
|
||||
// For testing purpose, returns an error if there is attached metadata.
|
||||
if md, ok := metadata.FromContext(stream.Context()); ok {
|
||||
delete(md, "transport_security_type")
|
||||
// For testing purpose, returns an error if there is attached metadata other than transport_security_type.
|
||||
if len(md) > 0 {
|
||||
return grpc.Errorf(codes.DataLoss, "got extra metadata")
|
||||
}
|
||||
}
|
||||
cs := args.GetResponseParameters()
|
||||
for _, c := range cs {
|
||||
if us := c.GetIntervalUs(); us > 0 {
|
||||
@ -588,6 +591,10 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
|
||||
if err != nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||
}
|
||||
if e.security == "tls" {
|
||||
delete(header, "transport_security_type")
|
||||
delete(trailer, "transport_security_type")
|
||||
}
|
||||
if !reflect.DeepEqual(testMetadata, header) {
|
||||
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
|
||||
}
|
||||
@ -775,11 +782,17 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
|
||||
}
|
||||
go func() {
|
||||
headerMD, err := stream.Header()
|
||||
if e.security == "tls" {
|
||||
delete(headerMD, "transport_security_type")
|
||||
}
|
||||
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
|
||||
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
|
||||
}
|
||||
// test the cached value.
|
||||
headerMD, err = stream.Header()
|
||||
if e.security == "tls" {
|
||||
delete(headerMD, "transport_security_type")
|
||||
}
|
||||
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
|
||||
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
|
||||
}
|
||||
@ -810,6 +823,9 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
|
||||
}
|
||||
}
|
||||
trailerMD := stream.Trailer()
|
||||
if e.security == "tls" {
|
||||
delete(trailerMD, "transport_security_type")
|
||||
}
|
||||
if !reflect.DeepEqual(testMetadata, trailerMD) {
|
||||
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata)
|
||||
}
|
||||
@ -860,7 +876,7 @@ func testServerStreaming(t *testing.T, e env) {
|
||||
respCnt++
|
||||
}
|
||||
if rpcStatus != io.EOF {
|
||||
t.Fatalf("Failed to finish the server streaming rpc: %v, want <EOF>", err)
|
||||
t.Fatalf("Failed to finish the server streaming rpc: %v, want <EOF>", rpcStatus)
|
||||
}
|
||||
if respCnt != len(respSizes) {
|
||||
t.Fatalf("Got %d reply, want %d", len(respSizes), respCnt)
|
||||
|
@ -56,6 +56,7 @@ type http2Client struct {
|
||||
target string // server name/addr
|
||||
userAgent string
|
||||
conn net.Conn // underlying communication channel
|
||||
authInfo map[string][]string // auth info about the connection
|
||||
nextID uint32 // the next stream ID to be used
|
||||
|
||||
// writableChan synchronizes write access to the transport.
|
||||
@ -114,6 +115,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
||||
if connErr != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||
}
|
||||
var authInfo map[string][]string
|
||||
for _, c := range opts.AuthOptions {
|
||||
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
||||
scheme = "https"
|
||||
@ -124,7 +126,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
||||
if timeout > 0 {
|
||||
timeout -= time.Since(startT)
|
||||
}
|
||||
conn, _, connErr = ccreds.ClientHandshake(addr, conn, timeout)
|
||||
conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -168,6 +170,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
||||
target: addr,
|
||||
userAgent: ua,
|
||||
conn: conn,
|
||||
authInfo: authInfo,
|
||||
// The client initiated stream id is odd starting from 1.
|
||||
nextID: 1,
|
||||
writableChan: make(chan int, 1),
|
||||
@ -701,7 +704,7 @@ func (t *http2Client) reader() {
|
||||
}
|
||||
t.handleSettings(sf)
|
||||
|
||||
hDec := newHPACKDecoder()
|
||||
hDec := newHPACKDecoder(t.authInfo)
|
||||
var curStream *Stream
|
||||
// loop to keep reading incoming messages on this transport.
|
||||
for {
|
||||
|
@ -58,7 +58,7 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe
|
||||
type http2Server struct {
|
||||
conn net.Conn
|
||||
maxStreamID uint32 // max stream ID ever seen
|
||||
authInfo map[string][]string // basic auth info about the connection
|
||||
authInfo map[string][]string // auth info about the connection
|
||||
// writableChan synchronizes write access to the transport.
|
||||
// A writer acquires the write lock by sending a value on writableChan
|
||||
// and releases it by receiving from writableChan.
|
||||
@ -236,8 +236,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
||||
}
|
||||
t.handleSettings(sf)
|
||||
|
||||
hDec := newHPACKDecoder()
|
||||
hDec.state.mdata = t.authInfo
|
||||
hDec := newHPACKDecoder(t.authInfo)
|
||||
var curStream *Stream
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
@ -106,6 +106,7 @@ type decodeState struct {
|
||||
// An hpackDecoder decodes HTTP2 headers which may span multiple frames.
|
||||
type hpackDecoder struct {
|
||||
h *hpack.Decoder
|
||||
mdata map[string][]string // persistent metadata with this decoder
|
||||
state decodeState
|
||||
err error // The err when decoding
|
||||
}
|
||||
@ -138,8 +139,12 @@ func isReservedHeader(hdr string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func newHPACKDecoder() *hpackDecoder {
|
||||
func newHPACKDecoder(mdata map[string][]string) *hpackDecoder {
|
||||
d := &hpackDecoder{}
|
||||
for k, v := range mdata {
|
||||
d.mdata = make(map[string][]string)
|
||||
d.mdata[k] = v
|
||||
}
|
||||
d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) {
|
||||
switch f.Name {
|
||||
case "grpc-status":
|
||||
@ -174,6 +179,9 @@ func newHPACKDecoder() *hpackDecoder {
|
||||
}
|
||||
if d.state.mdata == nil {
|
||||
d.state.mdata = make(map[string][]string)
|
||||
for k, v := range d.mdata {
|
||||
d.state.mdata[k] = v
|
||||
}
|
||||
}
|
||||
k, v, err := metadata.DecodeKeyValue(f.Name, f.Value)
|
||||
if err != nil {
|
||||
|
Reference in New Issue
Block a user