make the test happy
This commit is contained in:
@ -169,7 +169,6 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]str
|
|||||||
info := make(map[string][]string)
|
info := make(map[string][]string)
|
||||||
info[transportSecurityType] = []string{"tls"}
|
info[transportSecurityType] = []string{"tls"}
|
||||||
for _, certs := range state.VerifiedChains {
|
for _, certs := range state.VerifiedChains {
|
||||||
fmt.Println("DEBUG: reach here")
|
|
||||||
for _, cert := range certs {
|
for _, cert := range certs {
|
||||||
info[x509CN] = append(info[x509CN], cert.Subject.CommonName)
|
info[x509CN] = append(info[x509CN], cert.Subject.CommonName)
|
||||||
for _, san := range cert.DNSNames {
|
for _, san := range cert.DNSNames {
|
||||||
|
@ -118,10 +118,13 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
|
func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
|
||||||
if _, ok := metadata.FromContext(stream.Context()); ok {
|
if md, ok := metadata.FromContext(stream.Context()); ok {
|
||||||
// For testing purpose, returns an error if there is attached metadata.
|
delete(md, "transport_security_type")
|
||||||
|
// For testing purpose, returns an error if there is attached metadata other than transport_security_type.
|
||||||
|
if len(md) > 0 {
|
||||||
return grpc.Errorf(codes.DataLoss, "got extra metadata")
|
return grpc.Errorf(codes.DataLoss, "got extra metadata")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
cs := args.GetResponseParameters()
|
cs := args.GetResponseParameters()
|
||||||
for _, c := range cs {
|
for _, c := range cs {
|
||||||
if us := c.GetIntervalUs(); us > 0 {
|
if us := c.GetIntervalUs(); us > 0 {
|
||||||
@ -588,6 +591,10 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||||
}
|
}
|
||||||
|
if e.security == "tls" {
|
||||||
|
delete(header, "transport_security_type")
|
||||||
|
delete(trailer, "transport_security_type")
|
||||||
|
}
|
||||||
if !reflect.DeepEqual(testMetadata, header) {
|
if !reflect.DeepEqual(testMetadata, header) {
|
||||||
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
|
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
|
||||||
}
|
}
|
||||||
@ -775,11 +782,17 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
|
|||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
headerMD, err := stream.Header()
|
headerMD, err := stream.Header()
|
||||||
|
if e.security == "tls" {
|
||||||
|
delete(headerMD, "transport_security_type")
|
||||||
|
}
|
||||||
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
|
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
|
||||||
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
|
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
|
||||||
}
|
}
|
||||||
// test the cached value.
|
// test the cached value.
|
||||||
headerMD, err = stream.Header()
|
headerMD, err = stream.Header()
|
||||||
|
if e.security == "tls" {
|
||||||
|
delete(headerMD, "transport_security_type")
|
||||||
|
}
|
||||||
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
|
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
|
||||||
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
|
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
|
||||||
}
|
}
|
||||||
@ -810,6 +823,9 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
trailerMD := stream.Trailer()
|
trailerMD := stream.Trailer()
|
||||||
|
if e.security == "tls" {
|
||||||
|
delete(trailerMD, "transport_security_type")
|
||||||
|
}
|
||||||
if !reflect.DeepEqual(testMetadata, trailerMD) {
|
if !reflect.DeepEqual(testMetadata, trailerMD) {
|
||||||
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata)
|
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata)
|
||||||
}
|
}
|
||||||
@ -860,7 +876,7 @@ func testServerStreaming(t *testing.T, e env) {
|
|||||||
respCnt++
|
respCnt++
|
||||||
}
|
}
|
||||||
if rpcStatus != io.EOF {
|
if rpcStatus != io.EOF {
|
||||||
t.Fatalf("Failed to finish the server streaming rpc: %v, want <EOF>", err)
|
t.Fatalf("Failed to finish the server streaming rpc: %v, want <EOF>", rpcStatus)
|
||||||
}
|
}
|
||||||
if respCnt != len(respSizes) {
|
if respCnt != len(respSizes) {
|
||||||
t.Fatalf("Got %d reply, want %d", len(respSizes), respCnt)
|
t.Fatalf("Got %d reply, want %d", len(respSizes), respCnt)
|
||||||
|
@ -56,6 +56,7 @@ type http2Client struct {
|
|||||||
target string // server name/addr
|
target string // server name/addr
|
||||||
userAgent string
|
userAgent string
|
||||||
conn net.Conn // underlying communication channel
|
conn net.Conn // underlying communication channel
|
||||||
|
authInfo map[string][]string // auth info about the connection
|
||||||
nextID uint32 // the next stream ID to be used
|
nextID uint32 // the next stream ID to be used
|
||||||
|
|
||||||
// writableChan synchronizes write access to the transport.
|
// writableChan synchronizes write access to the transport.
|
||||||
@ -114,6 +115,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
if connErr != nil {
|
if connErr != nil {
|
||||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||||
}
|
}
|
||||||
|
var authInfo map[string][]string
|
||||||
for _, c := range opts.AuthOptions {
|
for _, c := range opts.AuthOptions {
|
||||||
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
||||||
scheme = "https"
|
scheme = "https"
|
||||||
@ -124,7 +126,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
if timeout > 0 {
|
if timeout > 0 {
|
||||||
timeout -= time.Since(startT)
|
timeout -= time.Since(startT)
|
||||||
}
|
}
|
||||||
conn, _, connErr = ccreds.ClientHandshake(addr, conn, timeout)
|
conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -168,6 +170,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||||||
target: addr,
|
target: addr,
|
||||||
userAgent: ua,
|
userAgent: ua,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
authInfo: authInfo,
|
||||||
// The client initiated stream id is odd starting from 1.
|
// The client initiated stream id is odd starting from 1.
|
||||||
nextID: 1,
|
nextID: 1,
|
||||||
writableChan: make(chan int, 1),
|
writableChan: make(chan int, 1),
|
||||||
@ -701,7 +704,7 @@ func (t *http2Client) reader() {
|
|||||||
}
|
}
|
||||||
t.handleSettings(sf)
|
t.handleSettings(sf)
|
||||||
|
|
||||||
hDec := newHPACKDecoder()
|
hDec := newHPACKDecoder(t.authInfo)
|
||||||
var curStream *Stream
|
var curStream *Stream
|
||||||
// loop to keep reading incoming messages on this transport.
|
// loop to keep reading incoming messages on this transport.
|
||||||
for {
|
for {
|
||||||
|
@ -58,7 +58,7 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe
|
|||||||
type http2Server struct {
|
type http2Server struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
maxStreamID uint32 // max stream ID ever seen
|
maxStreamID uint32 // max stream ID ever seen
|
||||||
authInfo map[string][]string // basic auth info about the connection
|
authInfo map[string][]string // auth info about the connection
|
||||||
// writableChan synchronizes write access to the transport.
|
// writableChan synchronizes write access to the transport.
|
||||||
// A writer acquires the write lock by sending a value on writableChan
|
// A writer acquires the write lock by sending a value on writableChan
|
||||||
// and releases it by receiving from writableChan.
|
// and releases it by receiving from writableChan.
|
||||||
@ -236,8 +236,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
|||||||
}
|
}
|
||||||
t.handleSettings(sf)
|
t.handleSettings(sf)
|
||||||
|
|
||||||
hDec := newHPACKDecoder()
|
hDec := newHPACKDecoder(t.authInfo)
|
||||||
hDec.state.mdata = t.authInfo
|
|
||||||
var curStream *Stream
|
var curStream *Stream
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
defer wg.Wait()
|
defer wg.Wait()
|
||||||
|
@ -106,6 +106,7 @@ type decodeState struct {
|
|||||||
// An hpackDecoder decodes HTTP2 headers which may span multiple frames.
|
// An hpackDecoder decodes HTTP2 headers which may span multiple frames.
|
||||||
type hpackDecoder struct {
|
type hpackDecoder struct {
|
||||||
h *hpack.Decoder
|
h *hpack.Decoder
|
||||||
|
mdata map[string][]string // persistent metadata with this decoder
|
||||||
state decodeState
|
state decodeState
|
||||||
err error // The err when decoding
|
err error // The err when decoding
|
||||||
}
|
}
|
||||||
@ -138,8 +139,12 @@ func isReservedHeader(hdr string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHPACKDecoder() *hpackDecoder {
|
func newHPACKDecoder(mdata map[string][]string) *hpackDecoder {
|
||||||
d := &hpackDecoder{}
|
d := &hpackDecoder{}
|
||||||
|
for k, v := range mdata {
|
||||||
|
d.mdata = make(map[string][]string)
|
||||||
|
d.mdata[k] = v
|
||||||
|
}
|
||||||
d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) {
|
d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) {
|
||||||
switch f.Name {
|
switch f.Name {
|
||||||
case "grpc-status":
|
case "grpc-status":
|
||||||
@ -174,6 +179,9 @@ func newHPACKDecoder() *hpackDecoder {
|
|||||||
}
|
}
|
||||||
if d.state.mdata == nil {
|
if d.state.mdata == nil {
|
||||||
d.state.mdata = make(map[string][]string)
|
d.state.mdata = make(map[string][]string)
|
||||||
|
for k, v := range d.mdata {
|
||||||
|
d.state.mdata[k] = v
|
||||||
|
}
|
||||||
}
|
}
|
||||||
k, v, err := metadata.DecodeKeyValue(f.Name, f.Value)
|
k, v, err := metadata.DecodeKeyValue(f.Name, f.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Reference in New Issue
Block a user