Separate incoming and outgoing metadata in context
This will prevent the incoming RPCs' metadata from appearing in outgoing RPCs unless it is explicitly copied, e.g.: incomingMD, ok := metadata.FromContext(ctx) if ok { ctx = metadata.NewContext(ctx, incomingMD) } Fixes #1148
This commit is contained in:
@ -66,11 +66,11 @@ md := metadata.Pairs(
|
||||
|
||||
## Retrieving metadata from context
|
||||
|
||||
Metadata can be retrieved from context using `FromContext`:
|
||||
Metadata can be retrieved from context using `FromIncomingContext`:
|
||||
|
||||
```go
|
||||
func (s *server) SomeRPC(ctx context.Context, in *pb.SomeRequest) (*pb.SomeResponse, err) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
// do something with metadata
|
||||
}
|
||||
```
|
||||
@ -88,7 +88,7 @@ To send metadata to server, the client can wrap the metadata into a context usin
|
||||
md := metadata.Pairs("key", "val")
|
||||
|
||||
// create a new context with this metadata
|
||||
ctx := metadata.NewContext(context.Background(), md)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), md)
|
||||
|
||||
// make unary RPC
|
||||
response, err := client.SomeRPC(ctx, someRequest)
|
||||
@ -96,6 +96,9 @@ response, err := client.SomeRPC(ctx, someRequest)
|
||||
// or make streaming RPC
|
||||
stream, err := client.SomeStreamingRPC(ctx)
|
||||
```
|
||||
|
||||
To read this back from the context on the client (e.g. in an interceptor) before the RPC is sent, use `FromOutgoingContext`.
|
||||
|
||||
### Receiving metadata
|
||||
|
||||
Metadata that a client can receive includes header and trailer.
|
||||
@ -152,7 +155,7 @@ For streaming calls, the server needs to get context from the stream.
|
||||
|
||||
```go
|
||||
func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someResponse, error) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
// do something with metadata
|
||||
}
|
||||
```
|
||||
@ -161,7 +164,7 @@ func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someRespo
|
||||
|
||||
```go
|
||||
func (s *server) SomeStreamingRPC(stream pb.Service_SomeStreamingRPCServer) error {
|
||||
md, ok := metadata.FromContext(stream.Context()) // get context from stream
|
||||
md, ok := metadata.FromIncomingContext(stream.Context()) // get context from stream
|
||||
// do something with metadata
|
||||
}
|
||||
```
|
||||
|
@ -215,7 +215,7 @@ type helloServer struct {
|
||||
}
|
||||
|
||||
func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")
|
||||
}
|
||||
|
@ -392,7 +392,7 @@ func DoPerRPCCreds(tc testpb.TestServiceClient, serviceAccountKeyFile, oauthScop
|
||||
}
|
||||
token := GetToken(serviceAccountKeyFile, oauthScope)
|
||||
kv := map[string]string{"authorization": token.TokenType + " " + token.AccessToken}
|
||||
ctx := metadata.NewContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}})
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}})
|
||||
reply, err := tc.UnaryCall(ctx, req)
|
||||
if err != nil {
|
||||
grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err)
|
||||
@ -416,7 +416,7 @@ var (
|
||||
|
||||
// DoCancelAfterBegin cancels the RPC after metadata has been sent but before payloads are sent.
|
||||
func DoCancelAfterBegin(tc testpb.TestServiceClient, args ...grpc.CallOption) {
|
||||
ctx, cancel := context.WithCancel(metadata.NewContext(context.Background(), testMetadata))
|
||||
ctx, cancel := context.WithCancel(metadata.NewOutgoingContext(context.Background(), testMetadata))
|
||||
stream, err := tc.StreamingInputCall(ctx, args...)
|
||||
if err != nil {
|
||||
grpclog.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err)
|
||||
@ -491,7 +491,7 @@ func DoCustomMetadata(tc testpb.TestServiceClient, args ...grpc.CallOption) {
|
||||
ResponseSize: proto.Int32(int32(1)),
|
||||
Payload: pl,
|
||||
}
|
||||
ctx := metadata.NewContext(context.Background(), customMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), customMetadata)
|
||||
var header, trailer metadata.MD
|
||||
args = append(args, grpc.Header(&header), grpc.Trailer(&trailer))
|
||||
reply, err := tc.UnaryCall(
|
||||
@ -627,7 +627,7 @@ func serverNewPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error)
|
||||
|
||||
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
status := in.GetResponseStatus()
|
||||
if md, ok := metadata.FromContext(ctx); ok {
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
if initialMetadata, ok := md[initialMetadataKey]; ok {
|
||||
header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
|
||||
grpc.SendHeader(ctx, header)
|
||||
@ -686,7 +686,7 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput
|
||||
}
|
||||
|
||||
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
if md, ok := metadata.FromContext(stream.Context()); ok {
|
||||
if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
|
||||
if initialMetadata, ok := md[initialMetadataKey]; ok {
|
||||
header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
|
||||
stream.SendHeader(header)
|
||||
|
@ -136,17 +136,41 @@ func Join(mds ...MD) MD {
|
||||
return out
|
||||
}
|
||||
|
||||
type mdKey struct{}
|
||||
type mdIncomingKey struct{}
|
||||
type mdOutgoingKey struct{}
|
||||
|
||||
// NewContext creates a new context with md attached.
|
||||
// NewContext is a wrapper for NewOutgoingContext(ctx, md). Deprecated.
|
||||
func NewContext(ctx context.Context, md MD) context.Context {
|
||||
return context.WithValue(ctx, mdKey{}, md)
|
||||
return NewOutgoingContext(ctx, md)
|
||||
}
|
||||
|
||||
// FromContext returns the MD in ctx if it exists.
|
||||
// The returned md should be immutable, writing to it may cause races.
|
||||
// Modification should be made to the copies of the returned md.
|
||||
// NewIncomingContext creates a new context with incoming md attached.
|
||||
func NewIncomingContext(ctx context.Context, md MD) context.Context {
|
||||
return context.WithValue(ctx, mdIncomingKey{}, md)
|
||||
}
|
||||
|
||||
// NewOutgoingContext creates a new context with outgoing md attached.
|
||||
func NewOutgoingContext(ctx context.Context, md MD) context.Context {
|
||||
return context.WithValue(ctx, mdOutgoingKey{}, md)
|
||||
}
|
||||
|
||||
// FromContext is a wrapper for FromIncomingContext(ctx). Deprecated.
|
||||
func FromContext(ctx context.Context) (md MD, ok bool) {
|
||||
md, ok = ctx.Value(mdKey{}).(MD)
|
||||
return FromIncomingContext(ctx)
|
||||
}
|
||||
|
||||
// FromIncomingContext returns the incoming MD in ctx if it exists. The
|
||||
// returned md should be immutable, writing to it may cause races.
|
||||
// Modification should be made to the copies of the returned md.
|
||||
func FromIncomingContext(ctx context.Context) (md MD, ok bool) {
|
||||
md, ok = ctx.Value(mdIncomingKey{}).(MD)
|
||||
return
|
||||
}
|
||||
|
||||
// FromOutgoingContext returns the outgoing MD in ctx if it exists. The
|
||||
// returned md should be immutable, writing to it may cause races.
|
||||
// Modification should be made to the copies of the returned md.
|
||||
func FromOutgoingContext(ctx context.Context) (md MD, ok bool) {
|
||||
md, ok = ctx.Value(mdOutgoingKey{}).(MD)
|
||||
return
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ var (
|
||||
type testServer struct{}
|
||||
|
||||
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if ok {
|
||||
if err := grpc.SendHeader(ctx, md); err != nil {
|
||||
return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
|
||||
@ -93,7 +93,7 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
|
||||
}
|
||||
|
||||
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
md, ok := metadata.FromContext(stream.Context())
|
||||
md, ok := metadata.FromIncomingContext(stream.Context())
|
||||
if ok {
|
||||
if err := stream.SendHeader(md); err != nil {
|
||||
return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
|
||||
@ -237,7 +237,7 @@ func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.Simple
|
||||
} else {
|
||||
req = &testpb.SimpleRequest{Id: errorID}
|
||||
}
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
|
||||
resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(c.failfast))
|
||||
return req, resp, err
|
||||
@ -250,7 +250,7 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest
|
||||
err error
|
||||
)
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
stream, err := tc.FullDuplexCall(metadata.NewContext(context.Background(), testMetadata), grpc.FailFast(c.failfast))
|
||||
stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.FailFast(c.failfast))
|
||||
if err != nil {
|
||||
return reqs, resps, err
|
||||
}
|
||||
|
@ -118,7 +118,7 @@ type testServer struct {
|
||||
}
|
||||
|
||||
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
if md, ok := metadata.FromContext(ctx); ok {
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
// For testing purpose, returns an error if user-agent is failAppUA.
|
||||
// To test that client gets the correct error.
|
||||
if ua, ok := md["user-agent"]; !ok || strings.HasPrefix(ua[0], failAppUA) {
|
||||
@ -152,7 +152,7 @@ func newPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error) {
|
||||
}
|
||||
|
||||
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if ok {
|
||||
if _, exists := md[":authority"]; !exists {
|
||||
return nil, grpc.Errorf(codes.DataLoss, "expected an :authority metadata: %v", md)
|
||||
@ -223,7 +223,7 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
|
||||
}
|
||||
|
||||
func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
|
||||
if md, ok := metadata.FromContext(stream.Context()); ok {
|
||||
if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
|
||||
if _, exists := md[":authority"]; !exists {
|
||||
return grpc.Errorf(codes.DataLoss, "expected an :authority metadata: %v", md)
|
||||
}
|
||||
@ -274,7 +274,7 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput
|
||||
}
|
||||
|
||||
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
md, ok := metadata.FromContext(stream.Context())
|
||||
md, ok := metadata.FromIncomingContext(stream.Context())
|
||||
if ok {
|
||||
if s.setAndSendHeader {
|
||||
if err := stream.SetHeader(md); err != nil {
|
||||
@ -1385,7 +1385,7 @@ func testFailedEmptyUnary(t *testing.T, e env) {
|
||||
defer te.tearDown()
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
wantErr := detailedError
|
||||
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(err, wantErr) {
|
||||
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr)
|
||||
@ -1602,7 +1602,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
|
||||
Payload: payload,
|
||||
}
|
||||
var header, trailer metadata.MD
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.Trailer(&trailer)); err != nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||
}
|
||||
@ -1648,7 +1648,7 @@ func testMultipleSetTrailerUnaryRPC(t *testing.T, e env) {
|
||||
Payload: payload,
|
||||
}
|
||||
var trailer metadata.MD
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
if _, err := tc.UnaryCall(ctx, req, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||
}
|
||||
@ -1671,7 +1671,7 @@ func testMultipleSetTrailerStreamingRPC(t *testing.T, e env) {
|
||||
defer te.tearDown()
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
@ -1722,7 +1722,7 @@ func testSetAndSendHeaderUnaryRPC(t *testing.T, e env) {
|
||||
Payload: payload,
|
||||
}
|
||||
var header metadata.MD
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err != nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||
}
|
||||
@ -1766,7 +1766,7 @@ func testMultipleSetHeaderUnaryRPC(t *testing.T, e env) {
|
||||
}
|
||||
|
||||
var header metadata.MD
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err != nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||
}
|
||||
@ -1809,7 +1809,7 @@ func testMultipleSetHeaderUnaryRPCError(t *testing.T, e env) {
|
||||
Payload: payload,
|
||||
}
|
||||
var header metadata.MD
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err == nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <non-nil>", ctx, err)
|
||||
}
|
||||
@ -1841,7 +1841,7 @@ func testSetAndSendHeaderStreamingRPC(t *testing.T, e env) {
|
||||
argSize = 1
|
||||
respSize = 1
|
||||
)
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
stream, err := tc.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
@ -1885,7 +1885,7 @@ func testMultipleSetHeaderStreamingRPC(t *testing.T, e env) {
|
||||
argSize = 1
|
||||
respSize = 1
|
||||
)
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
stream, err := tc.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
@ -1949,7 +1949,7 @@ func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) {
|
||||
argSize = 1
|
||||
respSize = -1
|
||||
)
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
stream, err := tc.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
@ -2014,7 +2014,7 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) {
|
||||
ResponseSize: proto.Int32(314),
|
||||
Payload: payload,
|
||||
}
|
||||
ctx := metadata.NewContext(context.Background(), malformedHTTP2Metadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), malformedHTTP2Metadata)
|
||||
if _, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.Internal {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _) = _, %v; want _, %s", ctx, err, codes.Internal)
|
||||
}
|
||||
@ -2344,7 +2344,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
|
||||
defer te.tearDown()
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
|
||||
ctx := metadata.NewContext(te.ctx, testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(te.ctx, testMetadata)
|
||||
stream, err := tc.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
@ -2483,7 +2483,7 @@ func testFailedServerStreaming(t *testing.T, e env) {
|
||||
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||
ResponseParameters: respParam,
|
||||
}
|
||||
ctx := metadata.NewContext(te.ctx, testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(te.ctx, testMetadata)
|
||||
stream, err := tc.StreamingOutputCall(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want <nil>", tc, err)
|
||||
@ -2887,7 +2887,7 @@ func testCompressOK(t *testing.T, e env) {
|
||||
ResponseSize: proto.Int32(respSize),
|
||||
Payload: payload,
|
||||
}
|
||||
ctx := metadata.NewContext(context.Background(), metadata.Pairs("something", "something"))
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("something", "something"))
|
||||
if _, err := tc.UnaryCall(ctx, req); err != nil {
|
||||
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
|
||||
}
|
||||
@ -3679,3 +3679,168 @@ func (fw *filterWriter) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
return fw.dst.Write(p)
|
||||
}
|
||||
|
||||
// stubServer is a server that is easy to customize within individual test
|
||||
// cases.
|
||||
type stubServer struct {
|
||||
// Guarantees we satisfy this interface; panics if unimplemented methods are called.
|
||||
testpb.TestServiceServer
|
||||
|
||||
// Customizable implementations of server handlers.
|
||||
emptyCall func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error)
|
||||
fullDuplexCall func(stream testpb.TestService_FullDuplexCallServer) error
|
||||
|
||||
// A client connected to this service the test may use. Created in Start().
|
||||
client testpb.TestServiceClient
|
||||
|
||||
cleanups []func() // Lambdas executed in Stop(); populated by Start().
|
||||
}
|
||||
|
||||
func (ss *stubServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
return ss.emptyCall(ctx, in)
|
||||
}
|
||||
|
||||
func (ss *stubServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
return ss.fullDuplexCall(stream)
|
||||
}
|
||||
|
||||
// Start starts the server and creates a client connected to it.
|
||||
func (ss *stubServer) Start() error {
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`net.Listen("tcp", ":0") = %v`, err)
|
||||
}
|
||||
ss.cleanups = append(ss.cleanups, func() { lis.Close() })
|
||||
|
||||
s := grpc.NewServer()
|
||||
testpb.RegisterTestServiceServer(s, ss)
|
||||
go s.Serve(lis)
|
||||
ss.cleanups = append(ss.cleanups, s.Stop)
|
||||
|
||||
cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock())
|
||||
if err != nil {
|
||||
return fmt.Errorf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
|
||||
}
|
||||
ss.cleanups = append(ss.cleanups, func() { cc.Close() })
|
||||
|
||||
ss.client = testpb.NewTestServiceClient(cc)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ss *stubServer) Stop() {
|
||||
for i := len(ss.cleanups) - 1; i >= 0; i-- {
|
||||
ss.cleanups[i]()
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnaryProxyDoesNotForwardMetadata(t *testing.T) {
|
||||
const mdkey = "somedata"
|
||||
|
||||
// endpoint ensures mdkey is NOT in metadata and returns an error if it is.
|
||||
endpoint := &stubServer{
|
||||
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] != nil {
|
||||
return nil, status.Errorf(codes.Internal, "endpoint: md=%v; want !contains(%q)", md, mdkey)
|
||||
}
|
||||
return &testpb.Empty{}, nil
|
||||
},
|
||||
}
|
||||
if err := endpoint.Start(); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer endpoint.Stop()
|
||||
|
||||
// proxy ensures mdkey IS in metadata, then forwards the RPC to endpoint
|
||||
// without explicitly copying the metadata.
|
||||
proxy := &stubServer{
|
||||
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] == nil {
|
||||
return nil, status.Errorf(codes.Internal, "proxy: md=%v; want contains(%q)", md, mdkey)
|
||||
}
|
||||
return endpoint.client.EmptyCall(ctx, in)
|
||||
},
|
||||
}
|
||||
if err := proxy.Start(); err != nil {
|
||||
t.Fatalf("Error starting proxy server: %v", err)
|
||||
}
|
||||
defer proxy.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
md := metadata.Pairs(mdkey, "val")
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
|
||||
// Sanity check that endpoint properly errors when it sees mdkey.
|
||||
_, err := endpoint.client.EmptyCall(ctx, &testpb.Empty{})
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.Internal {
|
||||
t.Fatalf("endpoint.client.EmptyCall(_, _) = _, %v; want _, <status with Code()=Internal>", err)
|
||||
}
|
||||
|
||||
if _, err := proxy.client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingProxyDoesNotForwardMetadata(t *testing.T) {
|
||||
const mdkey = "somedata"
|
||||
|
||||
// doFDC performs a FullDuplexCall with client and returns the error from the
|
||||
// first stream.Recv call, or nil if that error is io.EOF. Calls t.Fatal if
|
||||
// the stream cannot be established.
|
||||
doFDC := func(ctx context.Context, client testpb.TestServiceClient) error {
|
||||
stream, err := client.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unwanted error: %v", err)
|
||||
}
|
||||
if _, err := stream.Recv(); err != io.EOF {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// endpoint ensures mdkey is NOT in metadata and returns an error if it is.
|
||||
endpoint := &stubServer{
|
||||
fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
ctx := stream.Context()
|
||||
if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] != nil {
|
||||
return status.Errorf(codes.Internal, "endpoint: md=%v; want !contains(%q)", md, mdkey)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if err := endpoint.Start(); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer endpoint.Stop()
|
||||
|
||||
// proxy ensures mdkey IS in metadata, then forwards the RPC to endpoint
|
||||
// without explicitly copying the metadata.
|
||||
proxy := &stubServer{
|
||||
fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
ctx := stream.Context()
|
||||
if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] == nil {
|
||||
return status.Errorf(codes.Internal, "endpoint: md=%v; want !contains(%q)", md, mdkey)
|
||||
}
|
||||
return doFDC(ctx, endpoint.client)
|
||||
},
|
||||
}
|
||||
if err := proxy.Start(); err != nil {
|
||||
t.Fatalf("Error starting proxy server: %v", err)
|
||||
}
|
||||
defer proxy.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
md := metadata.Pairs(mdkey, "val")
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
|
||||
// Sanity check that endpoint properly errors when it sees mdkey in ctx.
|
||||
err := doFDC(ctx, endpoint.client)
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.Internal {
|
||||
t.Fatalf("stream.Recv() = _, %v; want _, <status with Code()=Internal>", err)
|
||||
}
|
||||
|
||||
if err := doFDC(ctx, proxy.client); err != nil {
|
||||
t.Fatalf("doFDC(_, proxy.client) = %v; want nil", err)
|
||||
}
|
||||
}
|
||||
|
@ -319,7 +319,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
||||
if req.TLS != nil {
|
||||
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
|
||||
}
|
||||
ctx = metadata.NewContext(ctx, ht.headerMD)
|
||||
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
|
||||
ctx = peer.NewContext(ctx, pr)
|
||||
s.ctx = newContextWithStream(ctx, s)
|
||||
s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
|
||||
|
@ -432,7 +432,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
||||
hasMD bool
|
||||
endHeaders bool
|
||||
)
|
||||
if md, ok := metadata.FromContext(ctx); ok {
|
||||
if md, ok := metadata.FromOutgoingContext(ctx); ok {
|
||||
hasMD = true
|
||||
for k, v := range md {
|
||||
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
|
||||
|
@ -261,7 +261,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
||||
s.ctx = newContextWithStream(s.ctx, s)
|
||||
// Attach the received metadata to the context.
|
||||
if len(state.mdata) > 0 {
|
||||
s.ctx = metadata.NewContext(s.ctx, state.mdata)
|
||||
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
|
||||
}
|
||||
|
||||
s.dec = &recvBufferReader{
|
||||
|
Reference in New Issue
Block a user