diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 78f9ddc3..fbf01d5f 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -227,7 +227,9 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro if err == nil { // transport has not been closed if ht.stats != nil { - ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) + ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{ + Trailer: s.trailer.Copy(), + }) } } ht.Close() @@ -289,7 +291,9 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { if err == nil { if ht.stats != nil { - ht.stats.HandleRPC(s.Context(), &stats.OutHeader{}) + ht.stats.HandleRPC(s.Context(), &stats.OutHeader{ + Header: md.Copy(), + }) } } return err diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 294661a3..c18a29dc 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -669,12 +669,14 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea } } if t.statsHandler != nil { + header, _, _ := metadata.FromOutgoingContextRaw(ctx) outHeader := &stats.OutHeader{ Client: true, FullMethod: callHdr.Method, RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, Compression: callHdr.SendCompress, + Header: header.Copy(), } t.statsHandler.HandleRPC(s.ctx, outHeader) } @@ -1177,12 +1179,14 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { inHeader := &stats.InHeader{ Client: true, WireLength: int(frame.Header().Length), + Header: s.header.Copy(), } t.statsHandler.HandleRPC(s.ctx, inHeader) } else { inTrailer := &stats.InTrailer{ Client: true, WireLength: int(frame.Header().Length), + Trailer: s.trailer.Copy(), } t.statsHandler.HandleRPC(s.ctx, inTrailer) } diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 07603836..3368e6aa 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -416,6 +416,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( LocalAddr: t.localAddr, Compression: s.recvCompress, WireLength: int(frame.Header().Length), + Header: metadata.MD(state.data.mdata).Copy(), } t.stats.HandleRPC(s.ctx, inHeader) } @@ -808,7 +809,9 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error { if t.stats != nil { // Note: WireLength is not set in outHeader. // TODO(mmukhi): Revisit this later, if needed. - outHeader := &stats.OutHeader{} + outHeader := &stats.OutHeader{ + Header: s.header.Copy(), + } t.stats.HandleRPC(s.Context(), outHeader) } return nil @@ -871,7 +874,9 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { rst := s.getState() == streamActive t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true) if t.stats != nil { - t.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) + t.stats.HandleRPC(s.Context(), &stats.OutTrailer{ + Trailer: s.trailer.Copy(), + }) } return nil } diff --git a/stats/stats.go b/stats/stats.go index f3f593c8..9e22c393 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -91,6 +91,8 @@ type InHeader struct { LocalAddr net.Addr // Compression is the compression algorithm used for the RPC. Compression string + // Header contains the header metadata received. + Header metadata.MD } // IsClient indicates if the stats information is from client side. @@ -104,6 +106,9 @@ type InTrailer struct { Client bool // WireLength is the wire length of trailer. WireLength int + // Trailer contains the trailer metadata received from the server. This + // field is only valid if this InTrailer is from the client side. + Trailer metadata.MD } // IsClient indicates if the stats information is from client side. @@ -146,6 +151,8 @@ type OutHeader struct { LocalAddr net.Addr // Compression is the compression algorithm used for the RPC. Compression string + // Header contains the header metadata sent. + Header metadata.MD } // IsClient indicates if this stats information is from client side. @@ -159,6 +166,9 @@ type OutTrailer struct { Client bool // WireLength is the wire length of trailer. WireLength int + // Trailer contains the trailer metadata sent to the client. This + // field is only valid if this OutTrailer is from the server side. + Trailer metadata.MD } // IsClient indicates if this stats information is from client side. @@ -176,6 +186,7 @@ type End struct { EndTime time.Time // Trailer contains the trailer metadata received from the server. This // field is only valid if this End is from the client side. + // Deprecated: use Trailer in InTrailer instead. Trailer metadata.MD // Error is the error the RPC ended with. It is an error generated from // status.Status and can be converted back to status.Status using diff --git a/stats/stats_test.go b/stats/stats_test.go index eb286539..30248c05 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -44,12 +44,17 @@ type connCtxKey struct{} type rpcCtxKey struct{} var ( - // For headers: + // For headers sent to server: testMetadata = metadata.MD{ "key1": []string{"value1"}, "key2": []string{"value2"}, } - // For trailers: + // For headers sent from server: + testHeaderMetadata = metadata.MD{ + "hkey1": []string{"headerValue1"}, + "hkey2": []string{"headerValue2"}, + } + // For trailers sent from server: testTrailerMetadata = metadata.MD{ "tkey1": []string{"trailerValue1"}, "tkey2": []string{"trailerValue2"}, @@ -63,14 +68,11 @@ type testServer struct { } func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { - md, ok := metadata.FromIncomingContext(ctx) - if ok { - if err := grpc.SendHeader(ctx, md); err != nil { - return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want ", md, err) - } - if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil { - return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want ", testTrailerMetadata, err) - } + if err := grpc.SendHeader(ctx, testHeaderMetadata); err != nil { + return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want ", testHeaderMetadata, err) + } + if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil { + return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want ", testTrailerMetadata, err) } if in.Id == errorID { @@ -81,13 +83,10 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* } func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { - md, ok := metadata.FromIncomingContext(stream.Context()) - if ok { - if err := stream.SendHeader(md); err != nil { - return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) - } - stream.SetTrailer(testTrailerMetadata) + if err := stream.SendHeader(testHeaderMetadata); err != nil { + return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil) } + stream.SetTrailer(testTrailerMetadata) for { in, err := stream.Recv() if err == io.EOF { @@ -109,13 +108,10 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ } func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCallServer) error { - md, ok := metadata.FromIncomingContext(stream.Context()) - if ok { - if err := stream.SendHeader(md); err != nil { - return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) - } - stream.SetTrailer(testTrailerMetadata) + if err := stream.SendHeader(testHeaderMetadata); err != nil { + return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil) } + stream.SetTrailer(testTrailerMetadata) for { in, err := stream.Recv() if err == io.EOF { @@ -133,13 +129,10 @@ func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCall } func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.TestService_ServerStreamCallServer) error { - md, ok := metadata.FromIncomingContext(stream.Context()) - if ok { - if err := stream.SendHeader(md); err != nil { - return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) - } - stream.SetTrailer(testTrailerMetadata) + if err := stream.SendHeader(testHeaderMetadata); err != nil { + return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, testHeaderMetadata, err, nil) } + stream.SetTrailer(testTrailerMetadata) if in.Id == errorID { return fmt.Errorf("got error id: %v", in.Id) @@ -275,7 +268,6 @@ func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.Simple req = &testpb.SimpleRequest{Id: errorID} } ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) - resp, err = tc.UnaryCall(ctx, req, grpc.WaitForReady(!c.failfast)) return req, resp, err } @@ -440,7 +432,15 @@ func checkInHeader(t *testing.T, d *gotData, e *expectedData) { if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } - if !d.client { + if d.client { + // additional headers might be injected so instead of testing equality, test that all the + // expected headers keys have the expected header values. + for key := range testHeaderMetadata { + if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) { + t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key)) + } + } + } else { if st.FullMethod != e.method { t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) } @@ -450,6 +450,13 @@ func checkInHeader(t *testing.T, d *gotData, e *expectedData) { if st.Compression != e.compression { t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) } + // additional headers might be injected so instead of testing equality, test that all the + // expected headers keys have the expected header values. + for key := range testMetadata { + if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) { + t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key)) + } + } if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok { if connInfo.RemoteAddr != st.RemoteAddr { @@ -527,13 +534,20 @@ func checkInPayload(t *testing.T, d *gotData, e *expectedData) { func checkInTrailer(t *testing.T, d *gotData, e *expectedData) { var ( ok bool + st *stats.InTrailer ) - if _, ok = d.s.(*stats.InTrailer); !ok { + if st, ok = d.s.(*stats.InTrailer); !ok { t.Fatalf("got %T, want InTrailer", d.s) } if d.ctx == nil { t.Fatalf("d.ctx = nil, want ") } + if !st.Client { + t.Fatalf("st IsClient = false, want true") + } + if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) { + t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata) + } } func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { @@ -557,6 +571,13 @@ func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { if st.Compression != e.compression { t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) } + // additional headers might be injected so instead of testing equality, test that all the + // expected headers keys have the expected header values. + for key := range testMetadata { + if !reflect.DeepEqual(st.Header.Get(key), testMetadata.Get(key)) { + t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testMetadata.Get(key)) + } + } if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok { if rpcInfo.FullMethodName != st.FullMethod { @@ -565,6 +586,14 @@ func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { } else { t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx) } + } else { + // additional headers might be injected so instead of testing equality, test that all the + // expected headers keys have the expected header values. + for key := range testHeaderMetadata { + if !reflect.DeepEqual(st.Header.Get(key), testHeaderMetadata.Get(key)) { + t.Fatalf("st.Header[%s] = %v, want %v", key, st.Header.Get(key), testHeaderMetadata.Get(key)) + } + } } } @@ -635,6 +664,9 @@ func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) { if st.Client { t.Fatalf("st IsClient = true, want false") } + if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) { + t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata) + } } func checkEnd(t *testing.T, d *gotData, e *expectedData) {