From 12478347785fae8cafc8b1e4e9d47d17e137c355 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Thu, 25 Aug 2016 15:17:50 -0700 Subject: [PATCH] Allow multiple calls to setTrailer --- server.go | 4 +- stream.go | 4 +- test/end2end_test.go | 109 ++++++++++++++++++++++++++++++++++++----- transport/transport.go | 12 +---- 4 files changed, 103 insertions(+), 26 deletions(-) diff --git a/server.go b/server.go index a1f3ed56..debbd79a 100644 --- a/server.go +++ b/server.go @@ -886,8 +886,8 @@ func SendHeader(ctx context.Context, md metadata.MD) error { } // SetTrailer sets the trailer metadata that will be sent when an RPC returns. -// It may be called at most once from a unary RPC handler. The ctx is the RPC -// handler's Context or one derived from it. +// When called more than once, all the provided metadata will be merged. +// The ctx is the RPC handler's Context or one derived from it. func SetTrailer(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil diff --git a/stream.go b/stream.go index e1b4759e..68d777b5 100644 --- a/stream.go +++ b/stream.go @@ -414,8 +414,8 @@ type ServerStream interface { // after SendProto. It fails if called multiple times or if // called after SendProto. SendHeader(metadata.MD) error - // SetTrailer sets the trailer metadata which will be sent with the - // RPC status. + // SetTrailer sets the trailer metadata which will be sent with the RPC status. + // When called more than once, all the provided metadata will be merged. SetTrailer(metadata.MD) Stream } diff --git a/test/end2end_test.go b/test/end2end_test.go index 33962b9b..131db299 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -79,6 +79,10 @@ var ( "tkey1": []string{"trailerValue1"}, "tkey2": []string{"trailerValue2"}, } + testTrailerMetadata2 = metadata.MD{ + "tkey1": []string{"trailerValue12"}, + "tkey2": []string{"trailerValue22"}, + } // capital "Key" is illegal in HTTP/2. malformedHTTP2Metadata = metadata.MD{ "Key": []string{"foo"}, @@ -89,8 +93,9 @@ var ( var raceMode bool // set by race_test.go in race mode type testServer struct { - security string // indicate the authentication protocol used by this server. - earlyFail bool // whether to error out the execution of a service handler prematurely. + security string // indicate the authentication protocol used by this server. + earlyFail bool // whether to error out the execution of a service handler prematurely. + multipleSetTrailer bool // whether to call setTrailer multiple times. } func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { @@ -136,14 +141,21 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* if err := grpc.SendHeader(ctx, md); err != nil { return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want %v", md, err, nil) } - grpc.SetTrailer(ctx, testTrailerMetadata) + if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil { + return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want ", testTrailerMetadata, err) + } + if s.multipleSetTrailer { + if err := grpc.SetTrailer(ctx, testTrailerMetadata2); err != nil { + return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want ", testTrailerMetadata2, err) + } + } } pr, ok := peer.FromContext(ctx) if !ok { - return nil, fmt.Errorf("failed to get peer from ctx") + return nil, grpc.Errorf(codes.DataLoss, "failed to get peer from ctx") } if pr.Addr == net.Addr(nil) { - return nil, fmt.Errorf("failed to get peer address") + return nil, grpc.Errorf(codes.DataLoss, "failed to get peer address") } if s.security != "" { // Check Auth info @@ -153,13 +165,13 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* authType = info.AuthType() serverName = info.State.ServerName default: - return nil, fmt.Errorf("Unknown AuthInfo type") + return nil, grpc.Errorf(codes.Unauthenticated, "Unknown AuthInfo type") } if authType != s.security { - return nil, fmt.Errorf("Wrong auth type: got %q, want %q", authType, s.security) + return nil, grpc.Errorf(codes.Unauthenticated, "Wrong auth type: got %q, want %q", authType, s.security) } if serverName != "x.test.youtube.com" { - return nil, fmt.Errorf("Unknown server name %q", serverName) + return nil, grpc.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName) } } // Simulate some service delay. @@ -229,9 +241,12 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ md, ok := metadata.FromContext(stream.Context()) if ok { if err := stream.SendHeader(md); err != nil { - return fmt.Errorf("%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) + return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) + } + stream.SetTrailer(testTrailerMetadata) + if s.multipleSetTrailer { + stream.SetTrailer(testTrailerMetadata2) } - stream.SetTrailer(md) } for { in, err := stream.Recv() @@ -1193,6 +1208,76 @@ func testMetadataUnaryRPC(t *testing.T, e env) { } } +func TestMultipleSetTrailerUnaryRPC(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testMultipleSetTrailerUnaryRPC(t, e) + } +} + +func testMultipleSetTrailerUnaryRPC(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, multipleSetTrailer: true}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const ( + argSize = 1 + respSize = 1 + ) + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(respSize), + Payload: payload, + } + var trailer metadata.MD + ctx := metadata.NewContext(context.Background(), testMetadata) + if _, err := tc.UnaryCall(ctx, req, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { + t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) + } + expectedTrailer := metadata.Join(testTrailerMetadata, testTrailerMetadata2) + if !reflect.DeepEqual(trailer, expectedTrailer) { + t.Fatalf("Received trailer metadata %v, want %v", trailer, expectedTrailer) + } +} + +func TestMultipleSetTrailerStreamingRPC(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + testMultipleSetTrailerStreamingRPC(t, e) + } +} + +func testMultipleSetTrailerStreamingRPC(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, multipleSetTrailer: true}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + ctx := metadata.NewContext(context.Background(), testMetadata) + stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.CloseSend(); err != nil { + t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil) + } + if _, err := stream.Recv(); err != io.EOF { + t.Fatalf("%v failed to complele the FullDuplexCall: %v", stream, err) + } + + trailer := stream.Trailer() + expectedTrailer := metadata.Join(testTrailerMetadata, testTrailerMetadata2) + if !reflect.DeepEqual(trailer, expectedTrailer) { + t.Fatalf("Received trailer metadata %v, want %v", trailer, expectedTrailer) + } +} + // TestMalformedHTTP2Metedata verfies the returned error when the client // sends an illegal metadata. func TestMalformedHTTP2Metadata(t *testing.T) { @@ -1601,8 +1686,8 @@ func testMetadataStreamingRPC(t *testing.T, e env) { } } trailerMD := stream.Trailer() - if !reflect.DeepEqual(testMetadata, trailerMD) { - t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata) + if !reflect.DeepEqual(testTrailerMetadata, trailerMD) { + t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testTrailerMetadata) } } diff --git a/transport/transport.go b/transport/transport.go index e2f691ef..3d6b6a6d 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -39,7 +39,6 @@ package transport // import "google.golang.org/grpc/transport" import ( "bytes" - "errors" "fmt" "io" "net" @@ -287,19 +286,12 @@ func (s *Stream) StatusDesc() string { return s.statusDesc } -// ErrIllegalTrailerSet indicates that the trailer has already been set or it -// is too late to do so. -var ErrIllegalTrailerSet = errors.New("transport: trailer has been set") - // SetTrailer sets the trailer metadata which will be sent with the RPC status -// by the server. This can only be called at most once. Server side only. +// by the server. This can be called multiple times. Server side only. func (s *Stream) SetTrailer(md metadata.MD) error { s.mu.Lock() defer s.mu.Unlock() - if s.trailer != nil { - return ErrIllegalTrailerSet - } - s.trailer = md.Copy() + s.trailer = metadata.Join(s.trailer, md) return nil }